# NOTE(review): lossy excerpt. Source line numbers are fused into the text and several lines (including the class header and the docstring delimiters) are not visible. The first line appears to be the ExpRelaxedCategorical class docstring, followed by its class attributes.
12 Creates an ExpRelaxedCategorical parameterized by 13 :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both). 14 Returns the log of a point in the simplex. Based on the interface to 15 :class:`OneHotCategorical`. 17 Implementation based on [1]. 19 See also: :class:`torch.distributions.OneHotCategorical` 22 temperature (Tensor): relaxation temperature 23 probs (Tensor): event probabilities 24 logits (Tensor): the log probability of each event. 26 [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables 27 (Maddison et al, 2017) 29 [2] Categorical Reparametrization with Gumbel-Softmax 32 arg_constraints = {
'probs': constraints.simplex,
33 'logits': constraints.real}
# Samples are log-space values (the log of a point in the simplex, per the docstring).
# Their support is therefore all reals rather than the simplex itself.
34 support = constraints.real
# NOTE(review): source lines 38-39 are not visible here. They must assign self._categorical, which is read on the next visible line, and presumably self.temperature, which other methods read. Confirm.
37 def __init__(self, temperature, probs=None, logits=None, validate_args=None):
# The batch shape comes from the wrapped categorical.
# The event shape is the last dimension of its parameter shape (presumably the number of categories).
40 batch_shape = self._categorical.batch_shape
41 event_shape = self._categorical.param_shape[-1:]
42 super(ExpRelaxedCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)
# Builds an instance `new` whose wrapped categorical is expanded to ``batch_shape``.
# NOTE(review): source lines 45, 47 and 50-52 are not visible. `new` must be created in one of the missing lines before source line 48, and any remaining attribute copies and the return statement sit in the missing lines. Confirm.
44 def expand(self, batch_shape, _instance=None):
46 batch_shape = torch.Size(batch_shape)
48 new._categorical = self._categorical.expand(batch_shape)
# The base initializer runs with validation disabled, presumably because self's parameters were already validated. Confirm.
49 super(ExpRelaxedCategorical, new).__init__(batch_shape, self.
event_shape, validate_args=
False)
53 def _new(self, *args, **kwargs):
54 return self._categorical._new(*args, **kwargs)
def param_shape(self):
    """Shape of the distribution's parameters, as reported by the wrapped categorical."""
    wrapped = self._categorical
    return wrapped.param_shape
# NOTE(review): presumably the body of a `logits` property. Its `def` line (and any decorator) is not visible in this excerpt. It returns the wrapped categorical's logits unchanged.
62 return self._categorical.logits
# NOTE(review): presumably the body of a `probs` property. Its `def` line (and any decorator) is not visible in this excerpt. It returns the wrapped categorical's probs unchanged.
66 return self._categorical.probs
# Draw a sample in log-space: generate Gumbel noise, then return log-normalized scores.
# NOTE(review): source lines 69 and 72 are not visible. They must define `shape` and `scores` (presumably scores = (logits + gumbels) / temperature). Confirm.
68 def rsample(self, sample_shape=torch.Size()):
# clamp_probs presumably keeps the uniforms strictly inside (0, 1) so the nested logs below stay finite.
70 uniforms = clamp_probs(torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device))
# -log(-log(u)) is the inverse CDF of the standard Gumbel distribution, so this turns U(0, 1) noise into Gumbel(0, 1) noise.
71 gumbels = -((-(uniforms.log())).log())
# Subtracting logsumexp over the last dimension makes exp(result) sum to 1 along that dimension.
73 return scores - scores.logsumexp(dim=-1, keepdim=
True)
# Log-density of `value` (a log-space point) under the ExpConcrete distribution of [1] in the class docstring.
# NOTE(review): source lines 77-78 and 82 are not visible. Line 82 must define `score` (presumably logits - value * temperature), and lines 77-78 may hold sample validation. Confirm.
75 def log_prob(self, value):
# K: number of events of the wrapped categorical (presumably the number of categories).
76 K = self._categorical._num_events
79 logits, value = broadcast_all(self.
logits, value)
# Normalizer: lgamma(K) = log((K - 1)!), plus (K - 1) * log(temperature).
# Subtracting a `.mul(-(K - 1))` term is an addition.
# new_tensor puts the scalar K on temperature's dtype/device.
80 log_scale = (self.temperature.new_tensor(float(K)).lgamma() -
81 self.temperature.log().mul(-(K - 1)))
# sum_k (score_k - logsumexp(score)) == sum(score) - K * logsumexp(score), reduced over the last dimension.
83 score = (score - score.logsumexp(dim=-1, keepdim=
True)).sum(-1)
84 return score + log_scale
# NOTE(review): lossy excerpt. The first line appears to be the RelaxedOneHotCategorical class docstring (line numbers fused in, some lines such as the sampling call in the example missing), followed by its class attributes.
89 Creates a RelaxedOneHotCategorical distribution parameterized by 90 :attr:`temperature`, and either :attr:`probs` or :attr:`logits`. 91 This is a relaxed version of the :class:`OneHotCategorical` distribution, so 92 its samples lie on the simplex, and are reparameterizable. 96 >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]), 97 torch.tensor([0.1, 0.2, 0.3, 0.4])) 99 tensor([ 0.1294, 0.2324, 0.3859, 0.2523]) 102 temperature (Tensor): relaxation temperature 103 probs (Tensor): event probabilities 104 logits (Tensor): the log probability of each event. 106 arg_constraints = {
'probs': constraints.simplex,
107 'logits': constraints.real}
# Samples lie on the probability simplex, as stated in the docstring.
108 support = constraints.simplex
# NOTE(review): source lines 112 and 114 are not visible. `base_dist` must be built on line 112 (presumably from temperature/probs/logits). Line 114 holds the remaining argument(s) of the super().__init__ call that continues across the missing line. Confirm.
111 def __init__(self, temperature, probs=None, logits=None, validate_args=None):
113 super(RelaxedOneHotCategorical, self).__init__(base_dist,
115 validate_args=validate_args)
# Expand to ``batch_shape`` by delegating to the parent class, passing a prepared instance.
# NOTE(review): source line 118 is not visible. It must define `new`, which is passed as _instance below. Confirm.
117 def expand(self, batch_shape, _instance=None):
119 return super(RelaxedOneHotCategorical, self).expand(batch_shape, _instance=new)
def temperature(self):
    """Relaxation temperature, read straight from the base distribution."""
    base = self.base_dist
    return base.temperature
# NOTE(review): presumably the body of a `logits` property. Its `def` line (and any decorator) is not visible in this excerpt. It returns the base distribution's logits unchanged.
127 return self.base_dist.logits
# NOTE(review): presumably the body of a `probs` property. Its `def` line (and any decorator) is not visible in this excerpt. It returns the base distribution's probs unchanged.
131 return self.base_dist.probs
# NOTE(review): signature-only stubs with no colon or body. They are presumably helper methods from the Distribution base class, listed for reference. Confirm.
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())
def _validate_sample(self, value)