Caffe2 - Python API
A deep learning, cross-platform ML framework
transforms.py
import math
import numbers
import weakref

import torch
from torch.distributions import constraints
from torch.distributions.utils import (_sum_rightmost, broadcast_all,
                                        lazy_property)
from torch.nn.functional import pad
__all__ = [
    'AbsTransform',
    'AffineTransform',
    'ComposeTransform',
    'ExpTransform',
    'LowerCholeskyTransform',
    'PowerTransform',
    'SigmoidTransform',
    'SoftmaxTransform',
    'StickBreakingTransform',
    'Transform',
    'identity_transform',
]
class Transform(object):
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example while the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether the transform is monotone
            increasing or decreasing.
        event_dim (int): Number of dimensions that are correlated together in
            the transform ``event_shape``. This should be 0 for pointwise
            transforms, 1 for transforms that act jointly on vectors, 2 for
            transforms that act jointly on matrices, etc.
    """
    bijective = False
    event_dim = 0

    def __init__(self, cache_size=0):
        self._cache_size = cache_size
        self._inv = None
        if cache_size == 0:
            pass  # default behavior
        elif cache_size == 1:
            self._cached_x_y = None, None
        else:
            raise ValueError('cache_size must be 0 or 1')
        super(Transform, self).__init__()

    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = _InverseTransform(self)
            self._inv = weakref.ref(inv)
        return inv

    @property
    def sign(self):
        """
        Returns the sign of the determinant of the Jacobian, if applicable.
        In general this only makes sense for bijective transforms.
        """
        raise NotImplementedError

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Necessary for Python2
        return not self.__eq__(other)

    def __call__(self, x):
        """
        Computes the transform `x => y`.
        """
        if self._cache_size == 0:
            return self._call(x)
        x_old, y_old = self._cached_x_y
        if x is x_old:
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        x_old, y_old = self._cached_x_y
        if y is y_old:
            return x_old
        x = self._inverse(y)
        self._cached_x_y = x, y
        return x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__ + '()'
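# Illustrative sketch (not part of the original module): a minimal custom
# bijection built by subclassing Transform, plus a demonstration of the
# cache_size=1 behaviour described in the docstring above. The class name
# TanhTransform and the variables t, x, y are example-only.
class TanhTransform(Transform):
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # atanh(y)
        return 0.5 * ((1 + y).log() - (1 - y).log())

    def log_abs_det_jacobian(self, x, y):
        # d/dx tanh(x) = 1 - tanh(x)^2
        return (1 - y.pow(2)).log()


t = TanhTransform(cache_size=1)
x = torch.randn(4)
y = t(x)
assert t.inv(y) is x  # with caching, the exact input tensor is returned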
class _InverseTransform(Transform):
    """
    Inverts a single :class:`Transform`.
    This class is private; please instead use the ``Transform.inv`` property.
    """
    def __init__(self, transform):
        super(_InverseTransform, self).__init__()
        self._inv = transform

    @constraints.dependent_property
    def domain(self):
        return self._inv.codomain

    @constraints.dependent_property
    def codomain(self):
        return self._inv.domain

    @property
    def bijective(self):
        return self._inv.bijective

    @property
    def sign(self):
        return self._inv.sign

    @property
    def event_dim(self):
        return self._inv.event_dim

    @property
    def inv(self):
        return self._inv

    def __eq__(self, other):
        if not isinstance(other, _InverseTransform):
            return False
        return self._inv == other._inv

    def __call__(self, x):
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        return -self._inv.log_abs_det_jacobian(y, x)
class ComposeTransform(Transform):
    """
    Composes multiple transforms in a chain.
    The transforms being composed are responsible for caching.

    Args:
        parts (list of :class:`Transform`): A list of transforms to compose.
    """
    def __init__(self, parts):
        super(ComposeTransform, self).__init__()
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        return self.parts == other.parts

    @constraints.dependent_property
    def domain(self):
        if not self.parts:
            return constraints.real
        return self.parts[0].domain

    @constraints.dependent_property
    def codomain(self):
        if not self.parts:
            return constraints.real
        return self.parts[-1].codomain

    @lazy_property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @lazy_property
    def sign(self):
        sign = 1
        for p in self.parts:
            sign = sign * p.sign
        return sign

    @lazy_property
    def event_dim(self):
        return max(p.event_dim for p in self.parts) if self.parts else 0

    @property
    def inv(self):
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return torch.zeros_like(x)
        result = 0
        for part in self.parts:
            y = part(x)
            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
                                             self.event_dim - part.event_dim)
            x = y
        return result

    def __repr__(self):
        fmt_string = self.__class__.__name__ + '(\n    '
        fmt_string += ',\n    '.join([p.__repr__() for p in self.parts])
        fmt_string += '\n)'
        return fmt_string
identity_transform = ComposeTransform([])
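# Illustrative sketch (not part of the original module): the empty composition
# acts as the identity map and has an all-zero log-det-Jacobian. The variable
# x below is example-only.
x = torch.randn(3)
assert torch.equal(identity_transform(x), x)
assert torch.equal(identity_transform.log_abs_det_jacobian(x, x), torch.zeros_like(x))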
class ExpTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \exp(x)`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, ExpTransform)

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        return x
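# Illustrative sketch (not part of the original module): ExpTransform is the
# usual way to turn a normal distribution into a log-normal one via
# TransformedDistribution. The import below is needed only for this example.
from torch.distributions import Normal, TransformedDistribution
log_normal = TransformedDistribution(Normal(0.0, 1.0), [ExpTransform()])
s = log_normal.sample((5,))    # strictly positive samples
lp = log_normal.log_prob(s)    # uses log_abs_det_jacobian under the hood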
class PowerTransform(Transform):
    r"""
    Transform via the mapping :math:`y = x^{\text{exponent}}`.
    """
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __init__(self, exponent, cache_size=0):
        super(PowerTransform, self).__init__(cache_size=cache_size)
        self.exponent, = broadcast_all(exponent)

    def __eq__(self, other):
        if not isinstance(other, PowerTransform):
            return False
        return self.exponent.eq(other.exponent).all().item()

    def _call(self, x):
        return x.pow(self.exponent)

    def _inverse(self, y):
        return y.pow(1 / self.exponent)

    def log_abs_det_jacobian(self, x, y):
        return (self.exponent * y / x).abs().log()
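# Illustrative sketch (not part of the original module): with exponent 0.5,
# PowerTransform is the square-root map on the positive reals. The variables
# t and x are example-only.
t = PowerTransform(0.5)
x = torch.tensor([1.0, 4.0, 9.0])
assert torch.allclose(t(x), torch.tensor([1.0, 2.0, 3.0]), atol=1e-6)
assert torch.allclose(t.inv(t(x)), x, atol=1e-5)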
class SigmoidTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
    """
    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SigmoidTransform)

    def _call(self, x):
        return torch.sigmoid(x)

    def _inverse(self, y):
        return y.log() - (-y).log1p()

    def log_abs_det_jacobian(self, x, y):
        return -(y.reciprocal() + (1 - y).reciprocal()).log()
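# Illustrative sketch (not part of the original module): the inverse of the
# sigmoid is the logit function, so values in (0, 1) map back to real-valued
# logits. The variables t, logits and probs are example-only.
t = SigmoidTransform()
logits = torch.tensor([-2.0, 0.5, 3.0])
probs = t(logits)    # values in (0, 1)
assert torch.allclose(t.inv(probs), logits, atol=1e-6)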
class AbsTransform(Transform):
    r"""
    Transform via the mapping :math:`y = |x|`.
    """
    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return x.abs()

    def _inverse(self, y):
        return y
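# Illustrative sketch (not part of the original module): AbsTransform is not
# bijective, and its inverse simply returns y unchanged, which still satisfies
# the weaker pseudoinverse property t(t.inv(t(x))) == t(x) described in the
# Transform docstring. The variables t and x are example-only.
t = AbsTransform()
x = torch.tensor([-3.0, 2.0])
assert torch.equal(t(t.inv(t(x))), t(x))    # holds even though t.inv(t(x)) != x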
class AffineTransform(Transform):
    r"""
    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.

    Args:
        loc (Tensor or float): Location parameter.
        scale (Tensor or float): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """
    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super(AffineTransform, self).__init__(cache_size=cache_size)
        self.loc = loc
        self.scale = scale
        self.event_dim = event_dim

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False

        if isinstance(self.loc, numbers.Number) and isinstance(other.loc, numbers.Number):
            if self.loc != other.loc:
                return False
        else:
            if not (self.loc == other.loc).all().item():
                return False

        if isinstance(self.scale, numbers.Number) and isinstance(other.scale, numbers.Number):
            if self.scale != other.scale:
                return False
        else:
            if not (self.scale == other.scale).all().item():
                return False

        return True

    @property
    def sign(self):
        if isinstance(self.scale, numbers.Number):
            return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
        return self.scale.sign()

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        shape = x.shape
        scale = self.scale
        if isinstance(scale, numbers.Number):
            result = x.new_empty(shape).fill_(math.log(abs(scale)))
        else:
            result = torch.abs(scale).log()
        if self.event_dim:
            result_size = result.size()[:-self.event_dim] + (-1,)
            result = result.view(result_size).sum(-1)
            shape = shape[:-self.event_dim]
        return result.expand(shape)
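# Illustrative sketch (not part of the original module): an affine map followed
# by exp, i.e. y = exp(loc + scale * x), built with ComposeTransform; the
# log-det-Jacobian terms of the parts accumulate. Also shown: event_dim=1 sums
# the per-coordinate terms over the last dimension. All variable names below
# are example-only.
t = ComposeTransform([AffineTransform(loc=1.0, scale=2.0), ExpTransform()])
x = torch.tensor([-0.5, 0.0, 1.0])
y = t(x)
assert torch.allclose(y, (1.0 + 2.0 * x).exp(), atol=1e-6)
assert torch.allclose(t.inv(y), x, atol=1e-5)
ldj = t.log_abs_det_jacobian(x, y)
assert torch.allclose(ldj, math.log(2.0) + (1.0 + 2.0 * x), atol=1e-6)

t_vec = AffineTransform(loc=0.0, scale=3.0, event_dim=1)
xv = torch.randn(5, 4)
ldj_vec = t_vec.log_abs_det_jacobian(xv, t_vec(xv))
assert ldj_vec.shape == (5,)    # one value per batch element, summed over the event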
class SoftmaxTransform(Transform):
    r"""
    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
    normalizing.

    This is not bijective and cannot be used for HMC. However this acts mostly
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """
    domain = constraints.real
    codomain = constraints.simplex
    event_dim = 1

    def __eq__(self, other):
        return isinstance(other, SoftmaxTransform)

    def _call(self, x):
        logprobs = x
        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
        return probs / probs.sum(-1, True)

    def _inverse(self, y):
        probs = y
        return probs.log()
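# Illustrative sketch (not part of the original module): SoftmaxTransform maps
# arbitrary reals onto the probability simplex; it is many-to-one (adding a
# constant to every coordinate gives the same output), hence not bijective.
# The variables t, x and p are example-only.
t = SoftmaxTransform()
x = torch.randn(4)
p = t(x)
assert torch.allclose(p.sum(-1), torch.tensor(1.0), atol=1e-6)
assert torch.allclose(t(x + 10.0), p, atol=1e-5)    # invariant to a constant shift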
class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """
    domain = constraints.real
    codomain = constraints.simplex
    bijective = True
    event_dim = 1

    def __eq__(self, other):
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
        z = torch.sigmoid(x - offset.log())
        z_cumprod = (1 - z).cumprod(-1)
        y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
        return y

    def _inverse(self, y):
        shape = y.shape[:-1] + (y.shape[-1] - 1,)
        offset = (shape[-1] + 1) - y.new([1]).expand(shape).cumsum(-1)
        sf = (1 - y.cumsum(-1))[..., :-1]
        x = y[..., :-1].log() - sf.log() + offset.log()
        return x

    def log_abs_det_jacobian(self, x, y):
        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
        z = torch.sigmoid(x - offset.log())
        detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1)
        return detJ
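# Illustrative sketch (not part of the original module): the stick-breaking map
# sends K unconstrained values to a point on the (K+1)-element simplex, and
# unlike SoftmaxTransform it is invertible. The variables t, x and y are
# example-only.
t = StickBreakingTransform()
x = torch.tensor([-1.0, 0.5, 2.0])
y = t(x)
assert y.shape == (4,)    # one extra coordinate
assert torch.allclose(y.sum(-1), torch.tensor(1.0), atol=1e-6)
assert torch.allclose(t.inv(y), x, atol=1e-5)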
class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """
    domain = constraints.real
    codomain = constraints.lower_cholesky
    event_dim = 2

    def __eq__(self, other):
        return isinstance(other, LowerCholeskyTransform)

    def _call_on_event(self, x):
        return x.tril(-1) + x.diag().exp().diag()

    def _inverse_on_event(self, y):
        return y.tril(-1) + y.diag().log().diag()

    def _call(self, x):
        flat_x = x.contiguous().view((-1,) + x.shape[-2:])
        return torch.stack([self._call_on_event(flat_x[i]) for i in range(flat_x.size(0))]).view(x.shape)

    def _inverse(self, y):
        flat_y = y.contiguous().view((-1,) + y.shape[-2:])
        return torch.stack([self._inverse_on_event(flat_y[i]) for i in range(flat_y.size(0))]).view(y.shape)
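# Illustrative sketch (not part of the original module): mapping an
# unconstrained square matrix to a lower-triangular factor with positive
# diagonal, a convenient parameterization of the positive definite matrix
# L @ L.t(). The variables t, a and L are example-only.
t = LowerCholeskyTransform()
a = torch.randn(3, 3)
L = t(a)
assert torch.equal(L.triu(1), torch.zeros(3, 3))    # strictly upper triangle is zero
assert (torch.diag(L) > 0).all()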