1 from numbers
import Number
12 Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`. 16 >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5])) 17 >>> m.sample() # Beta distributed with concentration concentration1 and concentration0 21 concentration1 (float or Tensor): 1st concentration parameter of the distribution 22 (often referred to as alpha) 23 concentration0 (float or Tensor): 2nd concentration parameter of the distribution 24 (often referred to as beta) 26 arg_constraints = {
'concentration1': constraints.positive,
'concentration0': constraints.positive}
27 support = constraints.unit_interval
30 def __init__(self, concentration1, concentration0, validate_args=None):
31 if isinstance(concentration1, Number)
and isinstance(concentration0, Number):
32 concentration1_concentration0 =
torch.tensor([float(concentration1), float(concentration0)])
34 concentration1, concentration0 = broadcast_all(concentration1, concentration0)
35 concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
37 super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
39 def expand(self, batch_shape, _instance=None):
41 batch_shape = torch.Size(batch_shape)
42 new._dirichlet = self._dirichlet.expand(batch_shape)
43 super(Beta, new).__init__(batch_shape, validate_args=
False)
55 (total.pow(2) * (total + 1)))
57 def rsample(self, sample_shape=()):
58 value = self._dirichlet.rsample(sample_shape).select(-1, 0)
59 if isinstance(value, Number):
60 value = self._dirichlet.concentration.new_tensor(value)
63 def log_prob(self, value):
66 heads_tails = torch.stack([value, 1.0 - value], -1)
67 return self._dirichlet.log_prob(heads_tails)
70 return self._dirichlet.entropy()
73 def concentration1(self):
74 result = self._dirichlet.concentration[..., 0]
75 if isinstance(result, Number):
81 def concentration0(self):
82 result = self._dirichlet.concentration[..., 1]
83 if isinstance(result, Number):
89 def _natural_params(self):
def _log_normalizer(self, x, y):
    """Return the log of the Beta function, log B(x, y).

    Computed as ``lgamma(x) + lgamma(y) - lgamma(x + y)``.

    Args:
        x: first argument; presumably the first natural parameter
            (TODO confirm against ``_natural_params``). Must be accepted
            by ``torch.lgamma``.
        y: second argument, combined with ``x`` elementwise.

    Returns:
        Elementwise log B(x, y), with the broadcast shape of ``x`` and ``y``.
    """
    log_gamma_x = torch.lgamma(x)
    log_gamma_y = torch.lgamma(y)
    # log Gamma(x + y) is the normalizing denominator of B(x, y).
    log_gamma_total = torch.lgamma(x + y)
    return log_gamma_x + log_gamma_y - log_gamma_total
def _get_checked_instance(self, cls, _instance=None)
def _validate_sample(self, value)