Caffe2 - Python API
A deep learning, cross-platform ML framework
beta.py
1 from numbers import Number
2 
3 import torch
4 from torch.distributions import constraints
5 from torch.distributions.dirichlet import Dirichlet
6 from torch.distributions.exp_family import ExponentialFamily
7 from torch.distributions.utils import broadcast_all
8 
9 
r"""
Beta distribution on the unit interval, parameterized by
:attr:`concentration1` (alpha) and :attr:`concentration0` (beta).

Internally the distribution is backed by a two-category Dirichlet whose
first coordinate is the Beta variate and whose second is its complement.

Example::

    >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
    >>> m.sample()  # draw from Beta(concentration1, concentration0)
    tensor([ 0.1046])

Args:
    concentration1 (float or Tensor): first concentration parameter of the
        distribution (often referred to as alpha)
    concentration0 (float or Tensor): second concentration parameter of the
        distribution (often referred to as beta)
"""
# Reparameterized sampling is available (delegated to the Dirichlet).
has_rsample = True
# Samples live in [0, 1].
support = constraints.unit_interval
# Both concentrations must be strictly positive.
arg_constraints = {'concentration1': constraints.positive, 'concentration0': constraints.positive}
29 
def __init__(self, concentration1, concentration0, validate_args=None):
    """Build the backing two-category Dirichlet from the two concentrations.

    Args:
        concentration1 (float or Tensor): alpha; stored at index 0 of the
            Dirichlet concentration.
        concentration0 (float or Tensor): beta; stored at index 1.
        validate_args: forwarded unchanged to the base-class initializer.

    When both arguments are plain Python numbers, they are packed into a
    length-2 float tensor. Otherwise they are broadcast against each other
    and stacked along a new trailing dimension. The Dirichlet's batch shape
    becomes this distribution's batch shape.
    """
    both_scalars = isinstance(concentration1, Number) and isinstance(concentration0, Number)
    if not both_scalars:
        concentration1, concentration0 = broadcast_all(concentration1, concentration0)
        paired = torch.stack([concentration1, concentration0], -1)
    else:
        paired = torch.tensor([float(concentration1), float(concentration0)])
    self._dirichlet = Dirichlet(paired)
    super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
38 
def expand(self, batch_shape, _instance=None):
    """Return a Beta whose backing Dirichlet is expanded to ``batch_shape``.

    Args:
        batch_shape: target batch shape (any value accepted by ``torch.Size``).
        _instance: optional pre-allocated instance, passed through to
            ``_get_checked_instance``.

    Returns:
        The (possibly reused) instance, sharing this distribution's
        validation setting.
    """
    target_shape = torch.Size(batch_shape)
    expanded = self._get_checked_instance(Beta, _instance)
    expanded._dirichlet = self._dirichlet.expand(target_shape)
    # The base initializer is called with validation off; the caller's
    # validation flag is then copied over explicitly. Presumably this avoids
    # re-validating already-checked parameters (confirm against the base class).
    super(Beta, expanded).__init__(target_shape, validate_args=False)
    expanded._validate_args = self._validate_args
    return expanded
46 
@property
def mean(self):
    """Mean of the distribution: ``alpha / (alpha + beta)``."""
    alpha = self.concentration1
    return alpha / (alpha + self.concentration0)
50 
@property
def variance(self):
    """Variance: ``alpha * beta / ((alpha + beta)**2 * (alpha + beta + 1))``."""
    alpha, beta = self.concentration1, self.concentration0
    total = alpha + beta
    numerator = alpha * beta
    denominator = total.pow(2) * (total + 1)
    return numerator / denominator
56 
def rsample(self, sample_shape=()):
    """Draw a reparameterized sample.

    A sample is drawn from the backing two-category Dirichlet, and its first
    coordinate (the one paired with ``concentration1``) is returned.

    Args:
        sample_shape: leading sample shape, forwarded to the Dirichlet.

    Returns:
        Tensor: the sample, with the trailing category dimension removed.
    """
    # ``Tensor.select`` always returns a Tensor (0-dim at minimum), so no
    # Python-number special case is needed here.
    return self._dirichlet.rsample(sample_shape).select(-1, 0)
62 
def log_prob(self, value):
    """Log-density of ``value``.

    The value is checked first when validation is enabled. The result is
    the backing Dirichlet's log-density of the pair ``(value, 1 - value)``,
    stacked along a new trailing dimension.
    """
    if self._validate_args:
        self._validate_sample(value)
    complement = 1.0 - value
    return self._dirichlet.log_prob(torch.stack([value, complement], dim=-1))
68 
def entropy(self):
    """Entropy, delegated to the backing two-category Dirichlet."""
    dirichlet = self._dirichlet
    return dirichlet.entropy()
71 
@property
def concentration1(self):
    """First concentration (alpha): index 0 of the last Dirichlet dimension.

    Indexing a tensor with ``[..., 0]`` always yields a Tensor (0-dim when
    the concentration is 1-D), so the result needs no Python-number wrapping.
    """
    return self._dirichlet.concentration[..., 0]
79 
@property
def concentration0(self):
    """Second concentration (beta): index 1 of the last Dirichlet dimension.

    Indexing a tensor with ``[..., 1]`` always yields a Tensor (0-dim when
    the concentration is 1-D), so the result needs no Python-number wrapping.
    """
    return self._dirichlet.concentration[..., 1]
87 
88  @property
89  def _natural_params(self):
90  return (self.concentration1, self.concentration0)
91 
92  def _log_normalizer(self, x, y):
93  return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
def _get_checked_instance(self, cls, _instance=None)
def concentration1(self)
Definition: beta.py:73
def concentration0(self)
Definition: beta.py:81