Caffe2 - Python API
A deep learning, cross-platform ML framework
logistic_normal.py
1 import torch
2 from torch.distributions import constraints
3 from torch.distributions.normal import Normal
4 from torch.distributions.transformed_distribution import TransformedDistribution
5 from torch.distributions.transforms import ComposeTransform, ExpTransform, StickBreakingTransform
6 
7 
9  r"""
10  Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale`
11  that define the base `Normal` distribution transformed with the
12  `StickBreakingTransform` such that::
13 
14  X ~ LogisticNormal(loc, scale)
15  Y = log(X / (1 - X.cumsum(-1)))[..., :-1] ~ Normal(loc, scale)
16 
17  Args:
18  loc (float or Tensor): mean of the base distribution
19  scale (float or Tensor): standard deviation of the base distribution
20 
21  Example::
22 
23  >>> # logistic-normal distributed with mean=(0, 0, 0) and stddev=(1, 1, 1)
24  >>> # of the base Normal distribution
25  >>> m = distributions.LogisticNormal(torch.tensor([0.0] * 3), torch.tensor([1.0] * 3))
26  >>> m.sample()
27  tensor([ 0.7653, 0.0341, 0.0579, 0.1427])
28 
29  """
# Constraints on the constructor parameters: ``loc`` may be any real
# value, ``scale`` must be strictly positive.
arg_constraints = {
    'loc': constraints.real,
    'scale': constraints.positive,
}
# Samples lie on the simplex.
support = constraints.simplex
# Reparameterized sampling is advertised as available.
has_rsample = True
33 
def __init__(self, loc, scale, validate_args=None):
    """Build the distribution on top of a base ``Normal(loc, scale)``.

    Args:
        loc (float or Tensor): mean of the base distribution
        scale (float or Tensor): standard deviation of the base distribution
        validate_args (bool, optional): passed through unchanged to the
            parent constructor.
    """
    base_dist = Normal(loc, scale)
    # The parent constructor requires the transform to apply to the base
    # distribution; the class docstring names StickBreakingTransform, so
    # it is passed explicitly here.
    super(LogisticNormal, self).__init__(base_dist,
                                         StickBreakingTransform(),
                                         validate_args=validate_args)
    # Adjust event shape since StickBreakingTransform adds 1 dimension
    self._event_shape = torch.Size([s + 1 for s in self._event_shape])
41 
def expand(self, batch_shape, _instance=None):
    """Delegate expansion to ``batch_shape`` to the parent class.

    The instance handed to the parent is obtained from
    ``_get_checked_instance`` using ``LogisticNormal`` and ``_instance``.
    """
    target = self._get_checked_instance(LogisticNormal, _instance)
    parent = super(LogisticNormal, self)
    return parent.expand(batch_shape, _instance=target)
45 
@property
def loc(self):
    """Return the ``loc`` attribute of the base distribution."""
    base = self.base_dist
    return base.loc
49 
@property
def scale(self):
    """Return the ``scale`` attribute of the base distribution."""
    base = self.base_dist
    return base.scale
def _get_checked_instance(self, cls, _instance=None)