Caffe2 - Python API
A deep learning, cross-platform ML framework
one_hot_categorical.py
1 import torch
2 from torch.distributions import constraints
3 from torch.distributions.categorical import Categorical
4 from torch.distributions.distribution import Distribution
5 
6 
r"""
Creates a one-hot categorical distribution parameterized by either
:attr:`probs` or :attr:`logits`.

Each sample is a one-hot vector of length ``probs.size(-1)``: exactly one
entry is 1 and every other entry is 0.

.. note:: :attr:`probs` must be non-negative and finite with a non-zero sum;
          it is normalized so that it sums to 1.

See also: :class:`torch.distributions.Categorical` for the exact
specification of :attr:`probs` and :attr:`logits`.

Example::

    >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
    >>> m.sample()  # each of the four one-hot vectors is equally likely
    tensor([ 0.,  0.,  0.,  1.])

Args:
    probs (Tensor): event probabilities
    logits (Tensor): event log probabilities
"""
# Admissible values for each constructor parameter.
arg_constraints = {
    'probs': constraints.simplex,
    'logits': constraints.real,
}
# Samples are one-hot vectors, which lie on the probability simplex.
support = constraints.simplex
# The finite set of one-hot vectors can be listed (see enumerate_support).
has_enumerate_support = True
34 
def __init__(self, probs=None, logits=None, validate_args=None):
    """Build the distribution on top of an internal :class:`Categorical`.

    Args:
        probs (Tensor, optional): event probabilities.
        logits (Tensor, optional): event log probabilities.
        validate_args (bool, optional): forwarded unchanged to the base
            ``Distribution`` initializer.
    """
    categorical = Categorical(probs, logits)
    # Stored before calling the base initializer, because the probs/logits
    # properties read from it.
    self._categorical = categorical
    # Batch shape is taken from the Categorical; the last parameter dimension
    # (the categories) becomes the single event dimension.
    super(OneHotCategorical, self).__init__(
        categorical.batch_shape,
        categorical.param_shape[-1:],
        validate_args=validate_args,
    )
40 
def expand(self, batch_shape, _instance=None):
    """Return a copy of this distribution broadcast to ``batch_shape``.

    Args:
        batch_shape (torch.Size or tuple): the desired batch shape.
        _instance: optional pre-allocated instance, passed through to
            ``_get_checked_instance``.

    Returns:
        A ``OneHotCategorical`` whose wrapped Categorical is expanded to
        ``batch_shape`` and whose event shape matches this one.
    """
    expanded = self._get_checked_instance(OneHotCategorical, _instance)
    target_shape = torch.Size(batch_shape)
    expanded._categorical = self._categorical.expand(target_shape)
    # Initialize with validation disabled, then copy this instance's
    # validation setting onto the new one.
    super(OneHotCategorical, expanded).__init__(
        target_shape, self.event_shape, validate_args=False
    )
    expanded._validate_args = self._validate_args
    return expanded
48 
def _new(self, *args, **kwargs):
    """Delegate to the wrapped Categorical's ``_new`` with the same arguments."""
    delegate = self._categorical._new
    return delegate(*args, **kwargs)
51 
@property
def _param(self):
    """Raw parameter tensor stored on the wrapped Categorical.

    ``enumerate_support`` reads its dtype and device.
    """
    categorical = self._categorical
    return categorical._param
55 
@property
def probs(self):
    """Event probabilities, as exposed by the wrapped Categorical."""
    categorical = self._categorical
    return categorical.probs
59 
@property
def logits(self):
    """Event log probabilities, as exposed by the wrapped Categorical."""
    categorical = self._categorical
    return categorical.logits
63 
@property
def mean(self):
    """Expected one-hot vector, which equals the probability vector itself."""
    categorical = self._categorical
    return categorical.probs
67 
@property
def variance(self):
    """Per-coordinate variance ``p * (1 - p)``.

    Each coordinate of a one-hot sample is 1 with probability ``p`` and 0
    otherwise.
    """
    p = self._categorical.probs
    complement = 1 - p
    return p * complement
71 
@property
def param_shape(self):
    """Shape of the wrapped Categorical's parameter tensor."""
    categorical = self._categorical
    return categorical.param_shape
75 
def sample(self, sample_shape=torch.Size()):
    """Draw one-hot samples.

    Category indices are drawn from the wrapped Categorical and one-hot
    encoded. The encoding appends a trailing dimension of size
    ``_num_events``. The result is cast to the dtype and device of ``probs``.

    Args:
        sample_shape (torch.Size or tuple): leading shape of the draw.
    """
    sample_shape = torch.Size(sample_shape)
    categorical = self._categorical
    reference = categorical.probs
    n_events = categorical._num_events
    drawn = categorical.sample(sample_shape)
    encoded = torch.nn.functional.one_hot(drawn, n_events)
    return encoded.to(reference)
82 
def log_prob(self, value):
    """Log-probability of a one-hot ``value``.

    The category is taken as the index returned by ``value.max(-1)`` along
    the last dimension. That index is then scored by the wrapped Categorical.
    When ``_validate_args`` is set, ``value`` is checked first.
    """
    if self._validate_args:
        self._validate_sample(value)
    _, category = value.max(-1)
    return self._categorical.log_prob(category)
88 
def entropy(self):
    """Entropy of the distribution, delegated to the wrapped Categorical.

    One-hot encoding is a one-to-one relabelling of categories, so the
    entropy is unchanged.
    """
    categorical = self._categorical
    return categorical.entropy()
91 
def enumerate_support(self, expand=True):
    """Return every one-hot vector in the support.

    The rows of an ``n x n`` identity matrix (``n = event_shape[0]``) are
    reshaped to ``(n, 1, ..., 1, n)``, with one singleton per batch
    dimension. When ``expand`` is true, the result is further expanded to
    ``(n,) + batch_shape + (n,)``.
    """
    n = self.event_shape[0]
    ref = self._param
    one_hots = torch.eye(n, dtype=ref.dtype, device=ref.device)
    singleton_batch = (1,) * len(self.batch_shape)
    one_hots = one_hots.view((n,) + singleton_batch + (n,))
    if not expand:
        return one_hots
    return one_hots.expand((n,) + self.batch_shape + (n,))
def _get_checked_instance(self, cls, _instance=None)