Creates a Negative Binomial distribution, i.e. the distribution
of the number of successful independent and identical Bernoulli trials
before :attr:`total_count` failures are achieved. The probability
of success of each Bernoulli trial is :attr:`probs`.

Args:
    total_count (float or Tensor): non-negative number of failures at which
        the trials stop, although the distribution is still valid for
        real-valued counts
    probs (Tensor): Event probabilities of success in the half open interval [0, 1)
    logits (Tensor): Event log-odds for probabilities of success
22 arg_constraints = {
# NOTE(review): this chunk carries fused source line numbers (e.g. "22 ", "23 ") and the
# docstring's delimiters are not visible here; restore from upstream before running.
# total_count only has to be >= 0, so real-valued (non-integer) counts are accepted.
'total_count': constraints.greater_than_eq(0),
23 'probs': constraints.half_open_interval(0., 1.),
24 'logits': constraints.real}
# Values in the support are non-negative integer counts.
25 support = constraints.nonnegative_integer
# Build the distribution from `total_count` plus exactly one of `probs` / `logits`.
# Raises ValueError when both or neither of `probs` and `logits` are given.
# NOTE(review): original lines 30-37 are missing from this view; they presumably
# broadcast the parameters and assign `self._param` (read below) -- confirm upstream.
27 def __init__(self, total_count, probs=None, logits=None, validate_args=None):
28 if (probs
is None) == (logits
is None):
29 raise ValueError(
"Either `probs` or `logits` must be specified, but not both.")
# The batch shape is taken from whichever parameter tensor is stored in `_param`.
38 batch_shape = self._param.size()
39 super(NegativeBinomial, self).__init__(batch_shape, validate_args=validate_args)
# Populate an instance `new` whose parameters are expanded to `batch_shape`.
# NOTE(review): original line 42 (which must define `new`), line 48 (presumably the
# `else:` for the logits branch) and lines 52-54 (presumably returning `new`) are
# missing from this view -- confirm against upstream.
41 def expand(self, batch_shape, _instance=None):
43 batch_shape = torch.Size(batch_shape)
44 new.total_count = self.total_count.expand(batch_shape)
# Expand whichever parameterization is present in the instance __dict__ and point
# `_param` at it; checking __dict__ directly presumably avoids forcing a lazily
# computed attribute -- confirm.
45 if 'probs' in self.__dict__:
46 new.probs = self.probs.expand(batch_shape)
47 new._param = new.probs
49 new.logits = self.logits.expand(batch_shape)
50 new._param = new.logits
# Argument validation is switched off for the expanded instance.
51 super(NegativeBinomial, new).__init__(batch_shape, validate_args=
False)
def _new(self, *args, **kwargs):
    """Create a new tensor of the same tensor type as ``self._param``.

    Forwards all positional and keyword arguments unchanged to the legacy
    ``Tensor.new`` constructor of the stored parameter tensor, e.g. a size
    (``_new(3)``) or data (``_new([1.0, 2.0])``).
    """
    return self._param.new(*args, **kwargs)
# Body of the `logits` accessor (its decorator/`def` lines, originally 65-67, are not
# visible here): converts the success probabilities `probs` to binary log-odds.
68 return probs_to_logits(self.
probs, is_binary=
True)
# Body of the `probs` accessor (its decorator/`def` lines are not visible here):
# converts the binary log-odds `logits` back to success probabilities.
72 return logits_to_probs(self.
logits, is_binary=
True)
def param_shape(self):
    """Return the shape (``torch.Size``) of the stored parameter tensor ``self._param``."""
    # NOTE(review): the line directly above this method (original line 74, presumably a
    # ``@property`` decorator) is not visible in this chunk -- confirm it is preserved.
    return self._param.size()
# Body of the `_gamma` accessor (its header lines, originally 77-79, are not visible
# here): a Gamma distribution with concentration `total_count` and rate exp(-logits),
# which equals (1 - probs) / probs. `sample` draws Poisson counts using rates sampled
# from this Gamma (a Gamma-Poisson mixture).
80 return torch.distributions.Gamma(concentration=self.
total_count,
81 rate=torch.exp(-self.
logits))
# Draw samples as a Gamma-Poisson mixture: first a rate tensor from `self._gamma`
# with the requested `sample_shape`, then a Poisson count for each rate.
# NOTE(review): original line 84 is missing from this view (presumably a context
# manager such as `with torch.no_grad():` around the two statements) -- confirm.
83 def sample(self, sample_shape=torch.Size()):
85 rate = self._gamma.sample(sample_shape=sample_shape)
86 return torch.poisson(rate)
# Log-probability of `value`, returned as an unnormalized log term minus a
# log-normalization term.
# NOTE(review): original lines 89-92, 94 and 96-97 are missing from this view. The
# visible fragments contain the `value * logsigmoid(logits)` term (value times
# log(probs)) and lgamma terms of the negative-binomial coefficient; the missing lines
# presumably add `total_count * logsigmoid(-logits)` and `lgamma(total_count)` -- confirm.
88 def log_prob(self, value):
93 value * F.logsigmoid(self.
logits))
95 log_normalization = (-torch.lgamma(self.
total_count + value) + torch.lgamma(1. + value) +
98 return log_unnormalized_prob - log_normalization
# NOTE(review): these are bare signatures with no trailing colon or body, and neither
# helper is defined or called in the visible lines. They presumably list helpers
# inherited from the distribution base class (likely used in the missing lines of
# `expand` and `log_prob`) -- confirm before treating them as definitions.
def _get_checked_instance(self, cls, _instance=None)
def _validate_sample(self, value)