import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.uniform import Uniform
from torch.distributions.utils import broadcast_all
# Euler-Mascheroni constant (gamma); appears in the Gumbel mean and entropy.
euler_constant = 0.57721566490153286060


class Gumbel(TransformedDistribution):
    r"""
    Samples from a Gumbel Distribution.

    The distribution is built by pushing a ``Uniform(tiny, 1 - eps)`` base
    distribution through a chain of transforms, so that a base sample ``u`` is
    mapped to ``loc - scale * log(-log(u))``.

    Example::

        >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
        >>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2

    Args:
        loc (float or Tensor): Location parameter of the distribution
        scale (float or Tensor): Scale parameter of the distribution
        validate_args (bool, optional): Whether to validate the parameters and
            the values passed to :meth:`log_prob`; forwarded to the base
            distribution as well.
    """

    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real

    def __init__(self, loc, scale, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        finfo = torch.finfo(self.loc.dtype)
        # Keep base samples strictly inside (0, 1): u == 0 or u == 1 would make
        # log(-log(u)) infinite.
        if isinstance(loc, Number) and isinstance(scale, Number):
            base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args)
        else:
            base_dist = Uniform(
                torch.full_like(self.loc, finfo.tiny),
                torch.full_like(self.loc, 1 - finfo.eps),
                validate_args=validate_args,
            )
        # u -> log(u) -> -log(u) -> log(-log(u)) -> loc - scale * log(-log(u))
        transforms = [
            ExpTransform().inv,
            AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
            ExpTransform().inv,
            AffineTransform(loc=self.loc, scale=-self.scale),
        ]
        super().__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a new Gumbel whose ``loc``/``scale`` are expanded to ``batch_shape``."""
        new = self._get_checked_instance(Gumbel, _instance)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        return super().expand(batch_shape, _instance=new)

    def log_prob(self, value):
        """Return ``y - exp(y) - log(scale)`` with ``y = (loc - value) / scale``.

        Defined in closed form instead of inheriting the transform-chain
        computation: inverting ``value`` back through ``exp(-exp(y))`` to the
        base uniform underflows in the tails, whereas this form stays finite.
        """
        if self._validate_args:
            self._validate_sample(value)
        y = (self.loc - value) / self.scale
        return (y - y.exp()) - self.scale.log()

    @property
    def mean(self):
        """Mean ``loc + scale * gamma``."""
        return self.loc + self.scale * euler_constant

    @property
    def stddev(self):
        """Standard deviation ``pi / sqrt(6) * scale``."""
        return (math.pi / math.sqrt(6)) * self.scale

    @property
    def variance(self):
        """Variance, i.e. the square of :attr:`stddev`."""
        return self.stddev.pow(2)

    def entropy(self):
        """Differential entropy ``log(scale) + 1 + gamma``."""
        return self.scale.log() + (1 + euler_constant)