import torch
from numbers import Number
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import SigmoidTransform
from torch.distributions.utils import (broadcast_all, clamp_probs, lazy_property,
                                       logits_to_probs, probs_to_logits)

class LogitRelaxedBernoulli(Distribution):
    r"""
    Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs`
    or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli
    distribution.

    Samples are logits of values in (0, 1). See [1] for more details.

    Args:
        temperature (Tensor): relaxation temperature
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`

    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random
    Variables (Maddison et al, 2017)

    [2] Categorical Reparametrization with Gumbel-Softmax
    (Jang et al, 2017)
    """
    arg_constraints = {'probs': constraints.unit_interval,
                       'logits': constraints.real}
    support = constraints.real
    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        self.temperature = temperature
        # Exactly one of `probs` / `logits` must be provided.
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            self.probs, = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            self.logits, = broadcast_all(logits)
        self._param = self.probs if probs is not None else self.logits
        # Scalar parameters yield an empty batch shape.
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super(LogitRelaxedBernoulli, self).__init__(batch_shape, validate_args=validate_args)
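
    # A minimal usage sketch of the constructor contract above (the tensor values
    # are illustrative assumptions, not library defaults):
    #     >>> LogitRelaxedBernoulli(torch.tensor(2.0), probs=0.3)             # ok
    #     >>> LogitRelaxedBernoulli(torch.tensor(2.0), logits=-0.85)          # ok
    #     >>> LogitRelaxedBernoulli(torch.tensor(2.0), probs=0.3, logits=0.)  # ValueError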
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LogitRelaxedBernoulli, _instance)
        batch_shape = torch.Size(batch_shape)
        new.temperature = self.temperature
        if 'probs' in self.__dict__:
            new.probs = self.probs.expand(batch_shape)
            new._param = new.probs
        else:
            new.logits = self.logits.expand(batch_shape)
            new._param = new.logits
        super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new
    def _new(self, *args, **kwargs):
        return self._param.new(*args, **kwargs)
    @lazy_property
    def logits(self):
        return probs_to_logits(self.probs, is_binary=True)

    @lazy_property
    def probs(self):
        return logits_to_probs(self.logits, is_binary=True)
    @property
    def param_shape(self):
        return self._param.size()
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        probs = clamp_probs(self.probs.expand(shape))
        uniforms = clamp_probs(torch.rand(shape, dtype=probs.dtype, device=probs.device))
        # Logistic reparameterization: (logit(U) + logit(p)) / temperature, with U ~ Uniform(0, 1).
        return (uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p()) / self.temperature

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        logits, value = broadcast_all(self.logits, value)
        diff = logits - value.mul(self.temperature)
        return self.temperature.log() + diff - 2 * diff.exp().log1p()
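
# A minimal usage sketch (illustrative only; the values below are assumptions):
# samples from LogitRelaxedBernoulli are unconstrained logits, and applying a
# sigmoid maps them into (0, 1), which is exactly what RelaxedBernoulli below
# does through a SigmoidTransform.
#
#     >>> dist = LogitRelaxedBernoulli(torch.tensor([0.5]), probs=torch.tensor([0.3]))
#     >>> y = dist.rsample()      # a logit; may lie anywhere on the real line
#     >>> torch.sigmoid(y)        # relaxed Bernoulli value in (0, 1)
#     >>> dist.log_prob(y)        # log-density of the logit-space sample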

class RelaxedBernoulli(TransformedDistribution):
    r"""
    Creates a RelaxedBernoulli distribution, parametrized by
    :attr:`temperature`, and either :attr:`probs` or :attr:`logits`
    (but not both). This is a relaxed version of the `Bernoulli`
    distribution, so the values are in (0, 1), and has reparametrizable
    samples.

    Example::

        >>> m = RelaxedBernoulli(torch.tensor([2.2]),
                                 torch.tensor([0.1, 0.2, 0.3, 0.99]))
        >>> m.sample()
        tensor([ 0.2951,  0.3442,  0.8918,  0.9021])

    Args:
        temperature (Tensor): relaxation temperature
        probs (Number, Tensor): the probability of sampling `1`
        logits (Number, Tensor): the log-odds of sampling `1`
    """
    arg_constraints = {'probs': constraints.unit_interval,
                       'logits': constraints.real}
    support = constraints.unit_interval
    has_rsample = True
    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        base_dist = LogitRelaxedBernoulli(temperature, probs, logits)
        super(RelaxedBernoulli, self).__init__(base_dist,
                                               SigmoidTransform(),
                                               validate_args=validate_args)
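
    # Construction sketch (illustrative; the parameter values are assumptions):
    # the distribution is the pushforward of the logit-space base through a sigmoid.
    #     >>> m = RelaxedBernoulli(torch.tensor([0.5]), probs=torch.tensor([0.7]))
    #     >>> m.base_dist       # the underlying LogitRelaxedBernoulli
    #     >>> m.transforms      # [SigmoidTransform()]
    #     >>> m.rsample()       # sigmoid of a base_dist.rsample() draw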
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(RelaxedBernoulli, _instance)
        return super(RelaxedBernoulli, self).expand(batch_shape, _instance=new)
    @property
    def temperature(self):
        return self.base_dist.temperature

    @property
    def logits(self):
        return self.base_dist.logits

    @property
    def probs(self):
        return self.base_dist.probs

# Helpers referenced above are inherited from
# torch.distributions.distribution.Distribution:
#     _get_checked_instance(cls, _instance=None)
#     _extended_shape(sample_shape=torch.Size())
#     _validate_sample(value)
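

# Illustrative smoke test, not part of the original module; the temperature and
# probability values below are arbitrary assumptions. It exercises the documented
# relationship: RelaxedBernoulli samples are sigmoids of LogitRelaxedBernoulli samples.
if __name__ == "__main__":
    torch.manual_seed(0)
    temperature = torch.tensor([2.2])
    probs = torch.tensor([0.1, 0.2, 0.3, 0.99])

    relaxed = RelaxedBernoulli(temperature, probs)
    sample = relaxed.sample()
    assert ((sample > 0) & (sample < 1)).all()  # relaxed samples live in (0, 1)

    # Reparameterized gradients flow through the logit-space base's rsample().
    probs.requires_grad_()
    logit_dist = LogitRelaxedBernoulli(temperature, probs)
    y = logit_dist.rsample()
    y.sum().backward()
    assert probs.grad is not None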