Caffe2 - Python API
A deep learning, cross-platform ML framework
poisson.py
1 from numbers import Number
2 
3 import torch
4 from torch.distributions import constraints
5 from torch.distributions.exp_family import ExponentialFamily
6 from torch.distributions.utils import broadcast_all
7 
8 
class Poisson(ExponentialFamily):
    r"""
    Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.

    Samples are nonnegative integers, with a pmf given by

    .. math::
      \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}

    Example::

        >>> m = Poisson(torch.tensor([4.0]))
        >>> m.sample()
        tensor([ 3.])

    Args:
        rate (Number, Tensor): the rate parameter; must be positive. Use a
            floating-point tensor, since sampling goes through
            ``torch.poisson``.
    """
    # ``rate`` is checked against this constraint when argument validation is
    # enabled; ``log_prob`` checks its input against ``support``.
    arg_constraints = {'rate': constraints.positive}
    support = constraints.nonnegative_integer

    @property
    def mean(self):
        """Mean of the distribution, which for a Poisson is the rate itself."""
        return self.rate

    @property
    def variance(self):
        """Variance of the distribution, which for a Poisson also equals the rate."""
        return self.rate

    def __init__(self, rate, validate_args=None):
        """Store ``rate`` and initialise the base distribution.

        Args:
            rate (Number, Tensor): the rate parameter.
            validate_args (bool, optional): forwarded to the base class to
                enable or disable argument/sample validation.
        """
        # broadcast_all returns a sequence; unpack its single element.
        self.rate, = broadcast_all(rate)
        # A plain Python number yields a scalar (empty) batch shape; a tensor
        # rate contributes its own shape as the batch shape.
        if isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.rate.size()
        super(Poisson, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a Poisson whose rate is expanded to ``batch_shape``.

        Args:
            batch_shape (torch.Size or sequence of int): the desired batch shape.
            _instance: optional pre-allocated instance (used by subclasses).
        """
        new = self._get_checked_instance(Poisson, _instance)
        batch_shape = torch.Size(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        # Skip argument validation for the expanded rate (it derives from the
        # already-constructed ``self.rate``), then inherit this instance's
        # validation setting for later calls such as ``log_prob``.
        super(Poisson, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        Sampling runs under ``torch.no_grad()``, so no gradient flows back to
        ``rate`` through the returned samples.
        """
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.poisson(self.rate.expand(shape))

    def log_prob(self, value):
        """Log of the pmf: ``k * log(rate) - rate - log(k!)``.

        Args:
            value (Tensor): counts ``k``; checked against ``support`` when
                validation is enabled.
        """
        if self._validate_args:
            self._validate_sample(value)
        rate, value = broadcast_all(self.rate, value)
        # lgamma(k + 1) == log(k!) for nonnegative integers k.
        return (rate.log() * value) - rate - (value + 1).lgamma()

    @property
    def _natural_params(self):
        """Natural parameter of the exponential-family form: ``log(rate)``."""
        return (torch.log(self.rate), )

    def _log_normalizer(self, x):
        """Log-normalizer ``exp(x)``; with ``x = log(rate)`` this equals ``rate``."""
        return torch.exp(x)
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())