from math import inf
from numbers import Number

import torch
from torch.distributions import Categorical, constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
r"""
Creates a Multinomial distribution parameterized by :attr:`total_count` and
either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
:attr:`probs` indexes over categories. All other dimensions index over batches.

Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
called (see example below)

.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
          and it will be normalized to sum to 1.

-   :meth:`sample` requires a single shared `total_count` for all
    parameters and samples.
-   :meth:`log_prob` allows different `total_count` for each parameter and
    sample.

Example::

    >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
    >>> x = m.sample()  # equal probability of 0, 1, 2, 3
    tensor([ 21.,  24.,  30.,  25.])

    >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)

Args:
    total_count (int): number of trials
    probs (Tensor): event probabilities
    logits (Tensor): event log probabilities
"""
arg_constraints = {'probs': constraints.simplex,
                   'logits': constraints.real}
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
    """Initialize from ``total_count`` and exactly one of ``probs``/``logits``.

    Args:
        total_count (int): number of trials; must be a plain Python number
            (per-batch inhomogeneous counts are not supported).
        probs (Tensor): event probabilities (mutually exclusive with ``logits``).
        logits (Tensor): event log-probabilities (mutually exclusive with ``probs``).
        validate_args (bool, optional): whether to validate arguments.

    Raises:
        NotImplementedError: if ``total_count`` is not a scalar ``Number``.
    """
    if not isinstance(total_count, Number):
        raise NotImplementedError('inhomogeneous total_count is not supported')
    # Fix: store the trial count — `support` and `sample` read
    # self.total_count, but the original never assigned it.
    self.total_count = total_count
    # The wrapped per-trial categorical over the event categories.
    self._categorical = Categorical(probs=probs, logits=logits)
    batch_shape = self._categorical.batch_shape
    # Event shape = (num_categories,), the last parameter dimension.
    event_shape = self._categorical.param_shape[-1:]
    super(Multinomial, self).__init__(batch_shape, event_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
    """Return a new Multinomial with batch dimensions expanded to ``batch_shape``.

    The original body used ``new`` without assigning it and never returned it;
    reconstructed the standard expand pattern (checked instance, copy
    ``total_count``, expand the wrapped categorical, re-init, return).
    """
    new = self._get_checked_instance(Multinomial, _instance)
    batch_shape = torch.Size(batch_shape)
    new.total_count = self.total_count
    new._categorical = self._categorical.expand(batch_shape)
    super(Multinomial, new).__init__(batch_shape, self.event_shape, validate_args=False)
    new._validate_args = self._validate_args
    return new
70 def _new(self, *args, **kwargs):
71 return self._categorical._new(*args, **kwargs)
@constraints.dependent_property
def support(self):
    # Fix: the `def support(self):` line was missing, leaving a bare
    # decorator over a return statement (a syntax error).
    # Valid samples are integer per-category counts in [0, total_count].
    return constraints.integer_interval(0, self.total_count)
@property
def logits(self):
    # Fix: restored the missing `@property` / `def` header around this
    # orphaned return statement. Delegates to the wrapped categorical.
    return self._categorical.logits
@property
def probs(self):
    # Fix: restored the missing `@property` / `def` header around this
    # orphaned return statement. Delegates to the wrapped categorical.
    return self._categorical.probs
@property
def param_shape(self):
    # Fix: restored the missing `@property` decorator so this matches the
    # sibling `logits`/`probs` accessors instead of being a plain method.
    return self._categorical.param_shape
def sample(self, sample_shape=torch.Size()):
    """Draw multinomial count vectors.

    Args:
        sample_shape (torch.Size): shape of independent draws.

    Returns:
        Tensor of shape ``sample_shape + batch_shape + (num_categories,)``
        whose entries are per-category counts (same dtype as ``probs``).
    """
    sample_shape = torch.Size(sample_shape)
    # Draw total_count categorical trials, stacked on a leading axis.
    samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
    # samples.shape is (total_count, sample_shape, batch_shape); rotate the
    # trial axis to the end: (sample_shape, batch_shape, total_count).
    shifted_idx = list(range(samples.dim()))
    shifted_idx.append(shifted_idx.pop(0))
    samples = samples.permute(*shifted_idx)
    # Fix: the original used `counts` without assigning it — restore the
    # zeroed accumulator, then scatter one per sampled category index.
    counts = samples.new(self._extended_shape(sample_shape)).zero_()
    counts.scatter_add_(-1, samples, torch.ones_like(samples))
    return counts.type_as(self.probs)
def log_prob(self, value):
    """Log-probability of the count vector ``value`` under the multinomial pmf.

    The effective total count is taken from ``value.sum(-1)``, so each
    sample/parameter may use a different total count.
    """
    # Fix: restored the standard argument-validation guard that the
    # sibling-style Distribution methods use; it was missing here.
    if self._validate_args:
        self._validate_sample(value)
    logits, value = broadcast_all(self.logits.clone(), value)
    # log n!  with n the per-sample total count.
    log_factorial_n = torch.lgamma(value.sum(-1) + 1)
    # sum_k log x_k!
    log_factorial_xs = torch.lgamma(value + 1).sum(-1)
    # A zero count against a -inf logit would give 0 * -inf = nan; such a
    # category contributes nothing, so neutralize its logit first.
    logits[(value == 0) & (logits == -inf)] = 0
    log_powers = (logits * value).sum(-1)
    return log_factorial_n - log_factorial_xs + log_powers
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())
def _validate_sample(self, value)