Caffe2 - Python API
A deep-learning, cross-platform ML framework
transformed_distribution.py
1 import torch
2 from torch.distributions import constraints
3 from torch.distributions.distribution import Distribution
4 from torch.distributions.transforms import Transform
5 from torch.distributions.utils import _sum_rightmost
6 
7 
r"""
Extension of the Distribution class, which applies a sequence of Transforms
to a base distribution. Let f be the composition of transforms applied::

    X ~ BaseDistribution
    Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
    log p(Y) = log p(X) + log |det (dX/dY)|

Note that the ``.event_shape`` of a :class:`TransformedDistribution` has as
many dimensions as the larger of its base distribution's event dimension and
the largest ``event_dim`` among its transforms, since transforms can introduce
correlations among events. The remaining leading dimensions of the base
distribution's ``batch_shape + event_shape`` form the ``batch_shape``.

An example for the usage of :class:`TransformedDistribution` would be::

    # Building a Logistic Distribution
    # X ~ Uniform(0, 1)
    # f = a + b * logit(X)
    # Y = f(X) ~ Logistic(a, b)
    base_distribution = Uniform(0, 1)
    transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
    logistic = TransformedDistribution(base_distribution, transforms)

For more examples, please look at the implementations of
:class:`~torch.distributions.gumbel.Gumbel`,
:class:`~torch.distributions.half_cauchy.HalfCauchy`,
:class:`~torch.distributions.half_normal.HalfNormal`,
:class:`~torch.distributions.log_normal.LogNormal`,
:class:`~torch.distributions.pareto.Pareto`,
:class:`~torch.distributions.weibull.Weibull`,
:class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
:class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
"""
# Empty: the constructor takes a base distribution and transforms but no
# tensor parameters of its own that would need constraining.
arg_constraints = {}
42 
def __init__(self, base_distribution, transforms, validate_args=None):
    """
    Build the distribution of Y = f(X) where X ~ ``base_distribution``.

    Args:
        base_distribution (Distribution): distribution of X.
        transforms (Transform or list of Transform): transforms applied to X
            in list order (a single Transform is wrapped in a list).
        validate_args (bool, optional): forwarded to ``Distribution.__init__``.

    Raises:
        ValueError: if ``transforms`` is neither a Transform nor a list of
            Transforms, or if the base distribution's full shape has fewer
            dimensions than the event dimensionality required by a transform.
    """
    self.base_dist = base_distribution
    if isinstance(transforms, Transform):
        self.transforms = [transforms, ]
    elif isinstance(transforms, list):
        if not all(isinstance(t, Transform) for t in transforms):
            raise ValueError("transforms must be a Transform or a list of Transforms")
        self.transforms = transforms
    else:
        raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
    # Split the base distribution's full shape into batch and event parts; the
    # event part must span every dimension any transform treats as one event.
    shape = self.base_dist.batch_shape + self.base_dist.event_shape
    event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim for t in self.transforms])
    if event_dim > len(shape):
        # Without this check ``len(shape) - event_dim`` is negative and the
        # slices below silently yield a too-short event_shape.
        raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
                         .format(event_dim, shape))
    batch_shape = shape[:len(shape) - event_dim]
    event_shape = shape[len(shape) - event_dim:]
    super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)
58 
def expand(self, batch_shape, _instance=None):
    """
    Return a copy of this distribution with its batch dimensions expanded to
    ``batch_shape``. The transform list is shared with ``self``, and the
    validation flag is copied over.
    """
    expanded = self._get_checked_instance(TransformedDistribution, _instance)
    target_batch = torch.Size(batch_shape)
    # Base batch dimensions past this distribution's batch_shape were folded
    # into the event part by __init__, so they are kept unchanged.
    trailing_dims = self.base_dist.batch_shape[len(self.batch_shape):]
    expanded.base_dist = self.base_dist.expand(target_batch + trailing_dims)
    expanded.transforms = self.transforms
    super(TransformedDistribution, expanded).__init__(target_batch, self.event_shape, validate_args=False)
    expanded._validate_args = self._validate_args
    return expanded
68 
@constraints.dependent_property
def support(self):
    """
    Support of Y: the codomain of the final transform, or the base
    distribution's own support when the transform list is empty.
    """
    if not self.transforms:
        return self.base_dist.support
    return self.transforms[-1].codomain
72 
@property
def has_rsample(self):
    """Whether reparameterized sampling is available; mirrors the base distribution."""
    base = self.base_dist
    return base.has_rsample
76 
def sample(self, sample_shape=torch.Size()):
    """
    Draw a ``sample_shape``-shaped sample (or batch of samples when the
    parameters are batched) with gradient tracking disabled.

    A draw from the base distribution is passed through each transform in
    list order.
    """
    with torch.no_grad():
        draws = self.base_dist.sample(sample_shape)
        for fn in self.transforms:
            draws = fn(draws)
    return draws
89 
def rsample(self, sample_shape=torch.Size()):
    """
    Draw a ``sample_shape``-shaped reparameterized sample (or batch of samples
    when the parameters are batched).

    Unlike ``sample`` there is no ``torch.no_grad()`` context, so autograd can
    track the base ``rsample`` draw through each transform (applied in list
    order).
    """
    draws = self.base_dist.rsample(sample_shape)
    for fn in self.transforms:
        draws = fn(draws)
    return draws
101 
def log_prob(self, value):
    """
    Log-density of ``value`` via the change-of-variables formula.

    The transforms are inverted from last to first. Each one contributes
    ``-log|det J|``, summed over the event dimensions this distribution has
    beyond the transform's own ``event_dim``. The base distribution's log-prob
    at the fully inverted point is then added, summed over the event
    dimensions the base distribution does not already reduce.
    """
    n_event = len(self.event_shape)
    total = 0.0
    current = value
    for transform in reversed(self.transforms):
        previous = transform.inv(current)
        ladj = transform.log_abs_det_jacobian(previous, current)
        total = total - _sum_rightmost(ladj, n_event - transform.event_dim)
        current = previous
    base_term = self.base_dist.log_prob(current)
    return total + _sum_rightmost(base_term, n_event - len(self.base_dist.event_shape))
119 
def _monotonize_cdf(self, value):
    """
    Map ``value -> 1 - value`` where the composed transform is decreasing, so
    that :meth:`cdf` is monotone increasing.

    The overall direction is the product of every transform's ``sign``. When
    that product is the plain integer 1, ``value`` is returned unchanged.
    """
    overall = 1
    for transform in self.transforms:
        overall = overall * transform.sign
    if not isinstance(overall, int) or overall != 1:
        return overall * (value - 0.5) + 0.5
    return value
131 
def cdf(self, value):
    """
    Cumulative distribution function at ``value``.

    ``value`` is pulled back through the inverse of every transform (last to
    first), optionally validated against the base distribution's support, and
    passed to the base distribution's CDF. The result is then flipped by
    ``_monotonize_cdf`` when the overall transform is decreasing.
    """
    x = value
    for transform in reversed(self.transforms):
        x = transform.inv(x)
    if self._validate_args:
        self.base_dist._validate_sample(x)
    return self._monotonize_cdf(self.base_dist.cdf(x))
144 
def icdf(self, value):
    """
    Inverse cumulative distribution function at probability ``value``.

    The flip applied by ``cdf`` for decreasing transforms is applied again
    (it is its own inverse). The base distribution's inverse CDF is then
    taken, and the result is pushed through every transform in list order.

    Raises:
        ValueError: if argument validation is enabled and ``value`` lies
            outside the unit interval [0, 1].
    """
    if self._validate_args:
        # ``value`` is a probability, so check it against [0, 1] rather than
        # the base distribution's support. That support need not contain
        # [0, 1] (e.g. Uniform(2, 3)), which made valid inputs fail before.
        if not constraints.unit_interval.check(torch.as_tensor(value)).all():
            raise ValueError("The value argument to icdf must be within the unit interval [0, 1].")
    value = self._monotonize_cdf(value)
    value = self.base_dist.icdf(value)
    for transform in self.transforms:
        value = transform(value)
    return value
def _get_checked_instance(self, cls, _instance=None)