1 from numbers
import Number
def _standard_gamma(concentration):
    """Draw unit-rate Gamma samples for each entry of ``concentration``.

    Thin wrapper over ``torch._standard_gamma``; callers divide the result by
    a rate to obtain samples of Gamma(concentration, rate).
    """
    sampler = torch._standard_gamma
    return sampler(concentration)
r"""
Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`.

Example::

    >>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
    >>> m.sample()  # Gamma distributed with concentration=1 and rate=1

Args:
    concentration (float or Tensor): shape parameter of the distribution
        (often referred to as alpha)
    rate (float or Tensor): rate = 1 / scale of the distribution
        (often referred to as beta)
"""
# Both parameters must be strictly positive.
arg_constraints = dict(concentration=constraints.positive, rate=constraints.positive)
# Samples live on the positive reals.
support = constraints.positive
# NOTE(review): presumably the expected carrier measure consumed by an
# exponential-family base class (e.g. for entropy); confirm against the base.
_mean_carrier_measure = 0
# NOTE(review): the two return statements below have lost their enclosing
# def lines in this view. The first computes concentration / rate and the
# second concentration / rate**2, which are the mean and the variance of a
# Gamma(shape=concentration, rate=rate) distribution.
# TODO: restore the definitions these bodies belong to (presumably the
# `mean` and `variance` properties) and confirm.
36 return self.concentration / self.
rate 40 return self.concentration / self.rate.pow(2)
def __init__(self, concentration, rate, validate_args=None):
    """Construct a Gamma(concentration, rate) distribution.

    Args:
        concentration (float or Tensor): shape parameter (alpha).
        rate (float or Tensor): rate = 1 / scale (beta).
        validate_args: forwarded unchanged to the base-class constructor.
    """
    self.concentration, self.rate = broadcast_all(concentration, rate)
    both_scalars = isinstance(concentration, Number) and isinstance(rate, Number)
    # Plain Python numbers give an empty batch shape. Otherwise the batch
    # shape is the shape of the parameters as returned by broadcast_all.
    batch_shape = torch.Size() if both_scalars else self.concentration.size()
    super(Gamma, self).__init__(batch_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
    """Return a Gamma distribution whose parameters are expanded to ``batch_shape``.

    Args:
        batch_shape: target batch shape (anything ``torch.Size`` accepts).
        _instance: optional pre-allocated instance. It is passed through to
            ``self._get_checked_instance``, which returns the object to fill.

    Returns:
        The new, expanded distribution. ``self`` is not modified.
    """
    # Bind the instance to populate. Without this, `new` below is unbound.
    new = self._get_checked_instance(Gamma, _instance)
    batch_shape = torch.Size(batch_shape)
    # Tensor.expand yields views, so no parameter data is copied.
    new.concentration = self.concentration.expand(batch_shape)
    new.rate = self.rate.expand(batch_shape)
    # The expanded parameters are views of already-constructed parameters,
    # so re-checking them in the base constructor is skipped...
    super(Gamma, new).__init__(batch_shape, validate_args=False)
    # ...but the new distribution keeps the caller's validation setting.
    new._validate_args = self._validate_args
    return new
def rsample(self, sample_shape=torch.Size()):
    """Draw a reparameterized (differentiable) sample.

    Args:
        sample_shape: leading sample dimensions. The output shape is
            ``self._extended_shape(sample_shape)`` as computed by the base-class
            helper (sample dims followed by the batch dims).

    Returns:
        Tensor of Gamma(concentration, rate) samples. Gradients flow back to
        ``concentration`` and ``rate``.
    """
    # `shape` must be bound before use; the output shape comes from the
    # base-class helper.
    shape = self._extended_shape(sample_shape)
    value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
    # detach() aliases the same storage, so clamp_ rewrites `value` in place
    # without autograd recording the clamp. This keeps samples >= the smallest
    # positive normal float, presumably so torch.log(value) in log_prob stays
    # finite when a draw underflows to 0.
    value.detach().clamp_(min=torch.finfo(value.dtype).tiny)
    return value
def log_prob(self, value):
    """Log-density of ``value`` under Gamma(concentration, rate).

    Computes ``a*log(b) + (a - 1)*log(x) - b*x - lgamma(a)`` with
    a = concentration and b = rate.

    Args:
        value: point(s) at which to evaluate. Tensors are used as given.
            Python numbers are converted to a tensor with ``rate``'s dtype
            and device.

    Returns:
        Tensor of log-densities, broadcast against the batch shape.
    """
    # torch.log rejects Python scalars, so convert non-tensors first. Tensor
    # inputs are left untouched to keep their existing dtype behaviour.
    if not isinstance(value, torch.Tensor):
        value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
    # Honour the validate_args setting passed to the constructor.
    if self._validate_args:
        self._validate_sample(value)
    return (self.concentration * torch.log(self.rate) +
            (self.concentration - 1) * torch.log(value) -
            self.rate * value - torch.lgamma(self.concentration))
# NOTE(review): the enclosing def line for this return statement is missing
# from this view. The expression
#     concentration - log(rate) + lgamma(concentration)
#         + (1 - concentration) * digamma(concentration)
# is the closed-form differential entropy of a Gamma(shape=concentration,
# rate=rate) distribution. Presumably it is the body of `entropy`; confirm
# and restore the def.
73 return (self.concentration - torch.log(self.
rate) + torch.lgamma(self.concentration) +
74 (1.0 - self.concentration) * torch.digamma(self.concentration))
def _natural_params(self):
    """Natural parameters of the Gamma density written as an exponential family.

    The density's log terms are (concentration - 1) * log(x) - rate * x, so the
    parameters paired with the statistics (log x, x) are returned as
    ``(concentration - 1, -rate)``.
    """
    theta_log_x = self.concentration - 1
    theta_x = -self.rate
    return theta_log_x, theta_x
def _log_normalizer(self, x, y):
    """Log-normalizer for natural parameters ``x`` and ``y``.

    Returns ``lgamma(x + 1) + (x + 1) * log(-1 / y)``. With
    x = concentration - 1 and y = -rate, this equals
    lgamma(concentration) - concentration * log(rate).
    """
    alpha = x + 1
    inv_rate = -y.reciprocal()
    return torch.lgamma(alpha) + alpha * torch.log(inv_rate)
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())
def _validate_sample(self, value)