Caffe2 - Python API
A deep learning, cross-platform ML framework
relaxed_categorical.py
1 import torch
2 from torch.distributions import constraints
3 from torch.distributions.categorical import Categorical
4 from torch.distributions.utils import clamp_probs, broadcast_all
5 from torch.distributions.distribution import Distribution
6 from torch.distributions.transformed_distribution import TransformedDistribution
7 from torch.distributions.transforms import ExpTransform
8 
9 
    r"""
    Creates an ExpRelaxedCategorical distribution parameterized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    Returns the log of a point in the simplex. Based on the interface to
    :class:`OneHotCategorical`.

    Implementation based on [1]; the sampling scheme is the (log-space)
    Gumbel-Softmax relaxation also described in [2].

    See also: :func:`torch.distributions.OneHotCategorical`

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): the log probability of each event.
        validate_args (bool, optional): forwarded to :class:`Distribution`.

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
    (Maddison et al., 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al., 2017)
    """
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real}
    # Samples are log-normalized scores (rsample returns scores - logsumexp),
    # i.e. unconstrained real vectors rather than points on the simplex itself.
    support = constraints.real
    # rsample is implemented, so gradients flow through samples.
    has_rsample = True
36 
37  def __init__(self, temperature, probs=None, logits=None, validate_args=None):
38  self._categorical = Categorical(probs, logits)
39  self.temperature = temperature
40  batch_shape = self._categorical.batch_shape
41  event_shape = self._categorical.param_shape[-1:]
42  super(ExpRelaxedCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)
43 
44  def expand(self, batch_shape, _instance=None):
45  new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
46  batch_shape = torch.Size(batch_shape)
47  new.temperature = self.temperature
48  new._categorical = self._categorical.expand(batch_shape)
49  super(ExpRelaxedCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
50  new._validate_args = self._validate_args
51  return new
52 
53  def _new(self, *args, **kwargs):
54  return self._categorical._new(*args, **kwargs)
55 
56  @property
57  def param_shape(self):
58  return self._categorical.param_shape
59 
60  @property
61  def logits(self):
62  return self._categorical.logits
63 
64  @property
65  def probs(self):
66  return self._categorical.probs
67 
68  def rsample(self, sample_shape=torch.Size()):
69  shape = self._extended_shape(sample_shape)
70  uniforms = clamp_probs(torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device))
71  gumbels = -((-(uniforms.log())).log())
72  scores = (self.logits + gumbels) / self.temperature
73  return scores - scores.logsumexp(dim=-1, keepdim=True)
74 
75  def log_prob(self, value):
76  K = self._categorical._num_events
77  if self._validate_args:
78  self._validate_sample(value)
79  logits, value = broadcast_all(self.logits, value)
80  log_scale = (self.temperature.new_tensor(float(K)).lgamma() -
81  self.temperature.log().mul(-(K - 1)))
82  score = logits - value.mul(self.temperature)
83  score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
84  return score + log_scale
85 
86 
    r"""
    Creates a RelaxedOneHotCategorical distribution parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
    This is a relaxed version of the :class:`OneHotCategorical` distribution, so
    its samples lie on the simplex and are reparametrizable.

    Example::

        >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
        ...                              torch.tensor([0.1, 0.2, 0.3, 0.4]))
        >>> m.sample()
        tensor([ 0.1294, 0.2324, 0.3859, 0.2523])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Tensor): event probabilities
        logits (Tensor): the log probability of each event.
        validate_args (bool, optional): whether to validate arguments and samples.
    """
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real}
    # Samples are exp() of ExpRelaxedCategorical samples (ExpTransform in
    # __init__), so they are points on the simplex.
    support = constraints.simplex
    has_rsample = True
110 
111  def __init__(self, temperature, probs=None, logits=None, validate_args=None):
112  base_dist = ExpRelaxedCategorical(temperature, probs, logits)
113  super(RelaxedOneHotCategorical, self).__init__(base_dist,
114  ExpTransform(),
115  validate_args=validate_args)
116 
117  def expand(self, batch_shape, _instance=None):
118  new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
119  return super(RelaxedOneHotCategorical, self).expand(batch_shape, _instance=new)
120 
121  @property
122  def temperature(self):
123  return self.base_dist.temperature
124 
125  @property
126  def logits(self):
127  return self.base_dist.logits
128 
129  @property
130  def probs(self):
131  return self.base_dist.probs
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())