Caffe2 - Python API
A deep learning, cross-platform ML framework
gumbel.py
1 from numbers import Number
2 import math
3 import torch
4 from torch.distributions import constraints
5 from torch.distributions.uniform import Uniform
6 from torch.distributions.transformed_distribution import TransformedDistribution
7 from torch.distributions.transforms import AffineTransform, ExpTransform
8 from torch.distributions.utils import broadcast_all
9 
euler_constant = 0.57721566490153286060  # Euler-Mascheroni constant (gamma)


class Gumbel(TransformedDistribution):
    r"""
    Samples from a Gumbel Distribution.

    The distribution is a :class:`TransformedDistribution` whose base is
    ``Uniform(tiny, 1 - eps)`` and whose transform chain maps a uniform draw
    ``u`` through ``log(u) -> -log(u) -> log(-log(u)) -> loc - scale * log(-log(u))``,
    which is the inverse CDF of the Gumbel distribution.

    Examples::

        >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
        >>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
        tensor([ 1.0124])

    Args:
        loc (float or Tensor): Location parameter of the distribution
        scale (float or Tensor): Scale parameter of the distribution
        validate_args (bool, optional): forwarded unchanged to the
            ``TransformedDistribution`` constructor.
    """
    arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
    support = constraints.real

    def __init__(self, loc, scale, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        finfo = torch.finfo(self.loc.dtype)
        # The uniform base is kept inside (0, 1): a draw of exactly 0 would make
        # log(u) = -inf, and a draw of 1 would make -log(u) = 0 and therefore
        # log(-log(u)) = -inf.  `tiny` and `1 - eps` are the closest safe bounds
        # representable in the parameter dtype.
        if isinstance(loc, Number) and isinstance(scale, Number):
            # Both parameters are plain Python numbers: build the base from floats.
            base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
        else:
            # Tensor parameters: give the base the broadcast parameter shape so
            # each uniform draw lines up element-wise with loc/scale.
            base_dist = Uniform(torch.full_like(self.loc, finfo.tiny),
                                torch.full_like(self.loc, 1 - finfo.eps))
        transforms = [
            ExpTransform().inv,                                           # u -> log(u)
            AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),  # -> -log(u)
            ExpTransform().inv,                                           # -> log(-log(u))
            # -> loc - scale * log(-log(u)); uses the broadcast parameters so the
            # location matches self.scale and the values read by log_prob/mean.
            AffineTransform(loc=self.loc, scale=-self.scale),
        ]
        super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a copy of this distribution expanded to ``batch_shape``.

        Expands ``loc`` and ``scale`` on the instance obtained from the inherited
        ``_get_checked_instance`` helper. The base-class ``expand`` then handles
        the remaining state (base distribution and transforms).
        """
        new = self._get_checked_instance(Gumbel, _instance)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        return super(Gumbel, self).expand(batch_shape, _instance=new)

    # Explicitly defining the log probability function for Gumbel due to precision issues
    def log_prob(self, value):
        """Return the log density of ``value``.

        Computed in closed form as ``y - exp(y) - log(scale)`` with
        ``y = (loc - value) / scale``, instead of inverting the transform chain.
        The sample is checked with ``_validate_sample`` when argument
        validation is enabled.
        """
        if self._validate_args:
            self._validate_sample(value)
        y = (self.loc - value) / self.scale
        return (y - y.exp()) - self.scale.log()

    @property
    def mean(self):
        """Mean of the distribution: ``loc + scale * euler_constant``."""
        return self.loc + self.scale * euler_constant

    @property
    def stddev(self):
        """Standard deviation of the distribution: ``(pi / sqrt(6)) * scale``."""
        return (math.pi / math.sqrt(6)) * self.scale

    @property
    def variance(self):
        """Variance of the distribution, ``(pi**2 / 6) * scale**2``, derived from ``stddev``."""
        return self.stddev.pow(2)

    def entropy(self):
        """Differential entropy of the distribution: ``log(scale) + 1 + euler_constant``."""
        return self.scale.log() + (1 + euler_constant)
def _get_checked_instance(self, cls, _instance=None)