"""
Creates a one-hot categorical distribution parameterized by :attr:`probs` or
:attr:`logits`.

Samples are one-hot coded vectors of size ``probs.size(-1)``.

.. note:: :attr:`probs` must be non-negative, finite and have a non-zero sum,
          and it will be normalized to sum to 1.

See also: :func:`torch.distributions.Categorical` for specifications of
:attr:`probs` and :attr:`logits`.

Example::

    >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
    >>> m.sample()  # equal probability of 0, 1, 2, 3
    tensor([ 0.,  0.,  0.,  1.])

Args:
    probs (Tensor): event probabilities
    logits (Tensor): event log probabilities
"""
arg_constraints = {
    'probs': constraints.simplex,
    'logits': constraints.real,
}
support = constraints.simplex
has_enumerate_support = True


def __init__(self, probs=None, logits=None, validate_args=None):
    """Build the distribution on top of an internal ``Categorical``.

    Args:
        probs (Tensor, optional): event probabilities.
        logits (Tensor, optional): event log probabilities.
        validate_args (bool, optional): forwarded unchanged to the
            base-class initializer.
    """
    # Local import: no module-level import of Categorical is visible from
    # this part of the file, so bring the name into scope here.
    from torch.distributions.categorical import Categorical

    # The wrapped Categorical has to exist before its shapes are read below;
    # every other method delegates to it through ``self._categorical``.
    self._categorical = Categorical(probs, logits)
    batch_shape = self._categorical.batch_shape
    # The trailing parameter dimension becomes the event dimension.
    event_shape = self._categorical.param_shape[-1:]
    super(OneHotCategorical, self).__init__(
        batch_shape, event_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
    """Return a distribution whose batch dimensions are expanded.

    Args:
        batch_shape (torch.Size): the desired expanded batch shape; any
            value accepted by ``torch.Size(...)`` works.
        _instance: optional pre-allocated instance to populate; passed
            through to ``_get_checked_instance``.

    Returns:
        The populated instance, wrapping the expanded categorical.
    """
    new = self._get_checked_instance(OneHotCategorical, _instance)
    batch_shape = torch.Size(batch_shape)
    new._categorical = self._categorical.expand(batch_shape)
    # Initialise only the base class: ``new._categorical`` is already set,
    # and ``validate_args=False`` is passed so the base initializer is not
    # asked to validate again.
    super(OneHotCategorical, new).__init__(
        batch_shape, self.event_shape, validate_args=False)
    # NOTE(review): carries the caller's validation flag over to the copy,
    # following the usual torch.distributions ``expand`` convention - confirm
    # that ``_validate_args`` is the attribute the base class uses.
    new._validate_args = self._validate_args
    return new
def _new(self, *args, **kwargs):
    """Forward all arguments to the wrapped categorical's ``_new``."""
    wrapped = self._categorical
    return wrapped._new(*args, **kwargs)
@property
def _param(self):
    """Return the wrapped categorical's ``_param`` attribute.

    Exposed as a property because ``enumerate_support`` reads
    ``self._param.dtype`` and ``self._param.device`` as plain attributes.
    """
    return self._categorical._param
@property
def probs(self):
    """Return the event probabilities of the wrapped categorical."""
    return self._categorical.probs
@property
def logits(self):
    """Return the event log probabilities of the wrapped categorical."""
    return self._categorical.logits
@property
def mean(self):
    """Return the wrapped categorical's ``probs``.

    NOTE(review): the ``def`` header for this accessor was missing from the
    listing; the name ``mean`` is inferred (the expected value of a one-hot
    vector is its probability vector) - confirm against the full file.
    """
    return self._categorical.probs
@property
def variance(self):
    """Return ``probs * (1 - probs)`` computed from the wrapped categorical.

    NOTE(review): the ``def`` header for this accessor was missing from the
    listing; the name ``variance`` is inferred from the ``p * (1 - p)``
    form - confirm against the full file.
    """
    p = self._categorical.probs
    return p * (1 - p)
def param_shape(self):
    """Return the parameter shape reported by the wrapped categorical."""
    # NOTE(review): ``__init__`` reads ``self._categorical.param_shape`` as an
    # attribute; confirm whether a ``@property`` decorator precedes this def
    # in the full file.
    shape = self._categorical.param_shape
    return shape
def sample(self, sample_shape=torch.Size()):
    """Draw one-hot encoded samples.

    Args:
        sample_shape (torch.Size): leading shape of the draw; any value
            accepted by ``torch.Size(...)`` works.

    Returns:
        Tensor: one-hot vectors whose last dimension has
        ``self._categorical._num_events`` entries, converted to match
        ``probs``.
    """
    sample_shape = torch.Size(sample_shape)
    probs = self._categorical.probs
    num_events = self._categorical._num_events
    indices = self._categorical.sample(sample_shape)
    # ``one_hot`` yields an integer tensor; ``.to(probs)`` converts it to the
    # dtype and device of the probabilities.
    return torch.nn.functional.one_hot(indices, num_events).to(probs)
def log_prob(self, value):
    """Return the log probability of one-hot encoded ``value``.

    Args:
        value (Tensor): one-hot vectors along the last dimension.

    Returns:
        Tensor: ``self._categorical.log_prob`` evaluated at the category
        index recovered from ``value``.
    """
    # Validate only when enabled, so the check costs nothing otherwise.
    # NOTE(review): ``_validate_args`` is assumed to be the flag kept by the
    # base class - confirm.
    if self._validate_args:
        self._validate_sample(value)
    # ``max`` over the last dimension returns (values, indices); the index of
    # the largest entry is the encoded category.
    indices = value.max(-1)[1]
    return self._categorical.log_prob(indices)
def entropy(self):
    """Return the entropy reported by the wrapped categorical.

    NOTE(review): the ``def`` header was missing from the listing; the name
    is inferred from the delegated ``entropy()`` call - confirm.
    """
    return self._categorical.entropy()
def enumerate_support(self, expand=True):
    """Return every one-hot outcome, enumerated along a new leading dimension.

    Args:
        expand (bool): when true, broadcast the result over the batch
            dimensions; otherwise the batch dimensions are kept at size 1.

    Returns:
        Tensor: shape ``(n,) + batch_shape + (n,)`` when expanded, else
        ``(n,) + (1,) * len(batch_shape) + (n,)``, where ``n`` is
        ``self.event_shape[0]``.
    """
    n = self.event_shape[0]
    # Row i of the identity matrix is the one-hot vector for category i.
    values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
    # Insert singleton batch dimensions between the two size-n dimensions.
    values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
    if expand:
        values = values.expand((n,) + self.batch_shape + (n,))
    return values
# Signature-only stub: no trailing colon or body is visible in this listing.
# NOTE(review): presumably inherited from the base distribution class rather
# than defined here - confirm before relying on it.
def _get_checked_instance(self, cls, _instance=None)
# Signature-only stub: no trailing colon or body is visible in this listing.
# NOTE(review): presumably inherited from the base distribution class rather
# than defined here - confirm before relying on it.
def _validate_sample(self, value)