    'LowerCholeskyTransform',
    'StickBreakingTransform',


class Transform(object):
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example, while the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether the transform is monotone
            increasing or decreasing.
        event_dim (int): Number of dimensions that are correlated together in
            the transform ``event_shape``. This should be 0 for pointwise
            transforms, 1 for transforms that act jointly on vectors, 2 for
            transforms that act jointly on matrices, etc.
    """
    bijective = False
    event_dim = 0

    def __init__(self, cache_size=0):
        self._cache_size = cache_size
        self._inv = None
        if cache_size == 0:
            pass  # default behavior
        elif cache_size == 1:
            self._cached_x_y = None, None
        else:
            raise ValueError('cache_size must be 0 or 1')
        super(Transform, self).__init__()
    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = _InverseTransform(self)
            self._inv = weakref.ref(inv)
        return inv

    @property
    def sign(self):
        """
        Returns the sign of the determinant of the Jacobian, if applicable.
        In general this only makes sense for bijective transforms.
        """
        raise NotImplementedError

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__.
        return not self.__eq__(other)

    def __call__(self, x):
        """
        Computes the transform `x => y`.
        """
        if self._cache_size == 0:
            return self._call(x)
        x_old, y_old = self._cached_x_y
        if x is x_old:
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        x_old, y_old = self._cached_x_y
        if y is y_old:
            return x_old
        x = self._inverse(y)
        self._cached_x_y = x, y
        return x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__ + '()'
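

# Illustrative usage sketch (hypothetical helper, not part of this module's
# API): demonstrates the caching behaviour described in the ``Transform``
# docstring. With ``cache_size=1`` the latest (x, y) pair is memoized, so
# inverting a cached output hands back the original input tensor, and
# ``log_abs_det_jacobian`` propagates gradients to ``x``.
def _example_transform_caching():
    import torch
    t = ExpTransform(cache_size=1)   # ExpTransform is defined later in this module
    x = torch.randn(3, requires_grad=True)
    y = t(x)                                        # forward pass; (x, y) is memoized
    assert t.inv(y) is x                            # cache hit returns the memoized input
    t.log_abs_det_jacobian(x, y).sum().backward()   # x receives gradients
    return x.grad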


class _InverseTransform(Transform):
    """
    Inverts a single :class:`Transform`.
    This class is private; please instead use the ``Transform.inv`` property.
    """
    def __init__(self, transform):
        super(_InverseTransform, self).__init__()
        self._inv = transform

    @constraints.dependent_property
    def domain(self):
        return self._inv.codomain

    @constraints.dependent_property
    def codomain(self):
        return self._inv.domain

    @property
    def bijective(self):
        return self._inv.bijective

    @property
    def sign(self):
        return self._inv.sign

    @property
    def event_dim(self):
        return self._inv.event_dim

    @property
    def inv(self):
        return self._inv

    def __eq__(self, other):
        if not isinstance(other, _InverseTransform):
            return False
        return self._inv == other._inv

    def __call__(self, x):
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        return -self._inv.log_abs_det_jacobian(y, x)


class ComposeTransform(Transform):
    """
    Composes multiple transforms in a chain.
    The transforms being composed are responsible for caching.

    Args:
        parts (list of :class:`Transform`): A list of transforms to compose.
    """
    def __init__(self, parts):
        super(ComposeTransform, self).__init__()
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        return self.parts == other.parts

    @constraints.dependent_property
    def domain(self):
        if not self.parts:
            return constraints.real
        return self.parts[0].domain

    @constraints.dependent_property
    def codomain(self):
        if not self.parts:
            return constraints.real
        return self.parts[-1].codomain

    @property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @property
    def event_dim(self):
        return max(p.event_dim for p in self.parts) if self.parts else 0

    @property
    def inv(self):
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return torch.zeros_like(x)
        result = 0
        for part in self.parts:
            y = part(x)
            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y),
                                             self.event_dim - part.event_dim)
            x = y
        return result

    def __repr__(self):
        fmt_string = self.__class__.__name__ + '(\n    '
        fmt_string += ',\n    '.join([p.__repr__() for p in self.parts])
        fmt_string += '\n)'
        return fmt_string
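

# Illustrative usage sketch (hypothetical helper): composing an affine map
# with an exponential, roughly how a log-normal-style transform could be
# assembled. The inverse applies the parts in reverse order, and the log det
# jacobian accumulates the per-part contributions.
def _example_compose_transform():
    import torch
    t = ComposeTransform([AffineTransform(loc=1., scale=2.), ExpTransform()])
    x = torch.randn(5)
    y = t(x)                             # y = exp(1 + 2 * x)
    ldj = t.log_abs_det_jacobian(x, y)   # log(2) + (1 + 2 * x), elementwise
    x_roundtrip = t.inv(y)               # approximately x
    return y, ldj, x_roundtrip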


class ExpTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \exp(x)`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, ExpTransform)

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        return x
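

# Illustrative check (hypothetical helper): for y = exp(x) the jacobian is
# dy/dx = exp(x) = y, so log_abs_det_jacobian(x, y) equals x (= log(y)).
def _example_exp_transform():
    import torch
    t = ExpTransform()
    x = torch.randn(4)
    y = t(x)
    assert torch.allclose(t.inv(y), x)
    assert torch.allclose(t.log_abs_det_jacobian(x, y), y.log())
    return y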


class PowerTransform(Transform):
    r"""
    Transform via the mapping :math:`y = x^{\text{exponent}}`.
    """
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __init__(self, exponent, cache_size=0):
        super(PowerTransform, self).__init__(cache_size=cache_size)
        self.exponent, = broadcast_all(exponent)

    def __eq__(self, other):
        if not isinstance(other, PowerTransform):
            return False
        return self.exponent.eq(other.exponent).all().item()

    def _call(self, x):
        return x.pow(self.exponent)

    def _inverse(self, y):
        return y.pow(1 / self.exponent)

    def log_abs_det_jacobian(self, x, y):
        return (self.exponent * y / x).abs().log()
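

# Illustrative check (hypothetical helper): PowerTransform with exponent 2
# maps the positive half-line onto itself; the jacobian term is
# log|exponent * y / x| = log|2 * x|.
def _example_power_transform():
    import torch
    t = PowerTransform(exponent=torch.tensor(2.))
    x = torch.rand(3) + 0.1            # strictly positive inputs
    y = t(x)                           # y = x ** 2
    assert torch.allclose(t.inv(y), x, atol=1e-6)
    return t.log_abs_det_jacobian(x, y)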


class SigmoidTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
    """
    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SigmoidTransform)

    def _call(self, x):
        return torch.sigmoid(x)

    def _inverse(self, y):
        return y.log() - (-y).log1p()

    def log_abs_det_jacobian(self, x, y):
        return -(y.reciprocal() + (1 - y).reciprocal()).log()
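

# Illustrative check (hypothetical helper): the inverse of the sigmoid is the
# logit function, so the round trip recovers the input up to float error.
def _example_sigmoid_transform():
    import torch
    t = SigmoidTransform()
    x = torch.randn(4)
    y = t(x)                           # values in (0, 1)
    assert torch.allclose(t.inv(y), x, atol=1e-4)
    return t.log_abs_det_jacobian(x, y)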


class AbsTransform(Transform):
    """
    Transform via the mapping :math:`y = |x|`.
    """
    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return x.abs()

    def _inverse(self, y):
        return y
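

# Illustrative check (hypothetical helper): AbsTransform is not bijective, so
# t.inv(t(x)) need not recover x; it only satisfies the weaker pseudoinverse
# property t(t.inv(t(x))) == t(x) noted in the ``Transform`` docstring.
def _example_abs_transform():
    import torch
    t = AbsTransform()
    x = torch.tensor([-2., 3.])
    y = t(x)                           # tensor([2., 3.])
    assert not torch.equal(t.inv(y), x)
    assert torch.equal(t(t.inv(y)), y)
    return y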


class AffineTransform(Transform):
    r"""
    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.

    Args:
        loc (Tensor or float): Location parameter.
        scale (Tensor or float): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """
    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super(AffineTransform, self).__init__(cache_size=cache_size)
        self.loc = loc
        self.scale = scale
        self.event_dim = event_dim

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False

        if isinstance(self.loc, numbers.Number) and isinstance(other.loc, numbers.Number):
            if self.loc != other.loc:
                return False
        else:
            if not (self.loc == other.loc).all().item():
                return False

        if isinstance(self.scale, numbers.Number) and isinstance(other.scale, numbers.Number):
            if self.scale != other.scale:
                return False
        else:
            if not (self.scale == other.scale).all().item():
                return False

        return True

    @property
    def sign(self):
        if isinstance(self.scale, numbers.Number):
            return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
        return self.scale.sign()

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        shape = x.shape
        scale = self.scale
        if isinstance(scale, numbers.Number):
            result = x.new_empty(shape).fill_(math.log(abs(scale)))
        else:
            result = torch.abs(scale).log()
        if self.event_dim:
            result_size = result.size()[:-self.event_dim] + (-1,)
            result = result.view(result_size).sum(-1)
            shape = shape[:-self.event_dim]
        return result.expand(shape)
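

# Illustrative check (hypothetical helper): with event_dim=1 the per-element
# log-scale terms are summed over the last dimension, so the log det jacobian
# has one fewer dimension than with the default pointwise event_dim=0.
def _example_affine_event_dim():
    import torch
    x = torch.randn(5, 3)
    scale = torch.tensor([1., 2., 3.])
    pointwise = AffineTransform(0., scale)                # event_dim=0
    vectorwise = AffineTransform(0., scale, event_dim=1)  # acts jointly on vectors
    y = pointwise(x)
    assert pointwise.log_abs_det_jacobian(x, y).shape == (5, 3)
    assert vectorwise.log_abs_det_jacobian(x, y).shape == (5,)
    return y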


class SoftmaxTransform(Transform):
    r"""
    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
    normalizing.

    This is not bijective and cannot be used for HMC. However this acts mostly
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """
    domain = constraints.real
    codomain = constraints.simplex
    event_dim = 1

    def __eq__(self, other):
        return isinstance(other, SoftmaxTransform)

    def _call(self, x):
        logprobs = x
        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
        return probs / probs.sum(-1, True)

    def _inverse(self, y):
        probs = y
        return probs.log()
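

# Illustrative check (hypothetical helper): SoftmaxTransform maps each
# unconstrained vector onto the probability simplex; being non-bijective, its
# inverse only recovers the input up to a per-event additive constant.
def _example_softmax_transform():
    import torch
    t = SoftmaxTransform()
    x = torch.randn(2, 4)
    y = t(x)
    assert torch.allclose(y.sum(-1), torch.ones(2))
    return t.inv(y)                    # equals x only up to a constant per row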


class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """
    domain = constraints.real
    codomain = constraints.simplex
    bijective = True
    event_dim = 1

    def __eq__(self, other):
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
        z = torch.sigmoid(x - offset.log())
        z_cumprod = (1 - z).cumprod(-1)
        y = pad(z, (0, 1), value=1) * pad(z_cumprod, (1, 0), value=1)
        return y

    def _inverse(self, y):
        shape = y.shape[:-1] + (y.shape[-1] - 1,)
        offset = (shape[-1] + 1) - y.new([1]).expand(shape).cumsum(-1)
        sf = (1 - y.cumsum(-1))[..., :-1]
        x = y[..., :-1].log() - sf.log() + offset.log()
        return x

    def log_abs_det_jacobian(self, x, y):
        offset = (x.shape[-1] + 1) - x.new([1]).expand(x.shape).cumsum(-1)
        z = torch.sigmoid(x - offset.log())
        detJ = ((1 - z).log() + y[..., :-1].log()).sum(-1)
        return detJ
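

# Illustrative check (hypothetical helper): the stick-breaking map sends a
# length-K vector of logits to a point on the (K+1)-simplex, so the last
# dimension grows by one and, being bijective, the round trip recovers x.
def _example_stick_breaking_transform():
    import torch
    t = StickBreakingTransform()
    x = torch.randn(3)
    y = t(x)
    assert y.shape == (4,)
    assert torch.allclose(y.sum(-1), torch.tensor(1.))
    assert torch.allclose(t.inv(y), x, atol=1e-4)
    return t.log_abs_det_jacobian(x, y)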


class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """
    domain = constraints.real
    codomain = constraints.lower_cholesky
    event_dim = 2

    def __eq__(self, other):
        return isinstance(other, LowerCholeskyTransform)

    def _call_on_event(self, x):
        return x.tril(-1) + x.diag().exp().diag()

    def _inverse_on_event(self, y):
        return y.tril(-1) + y.diag().log().diag()

    def _call(self, x):
        flat_x = x.contiguous().view((-1,) + x.shape[-2:])
        return torch.stack([self._call_on_event(flat_x[i])
                            for i in range(flat_x.size(0))]).view(x.shape)

    def _inverse(self, y):
        flat_y = y.contiguous().view((-1,) + y.shape[-2:])
        return torch.stack([self._inverse_on_event(flat_y[i])
                            for i in range(flat_y.size(0))]).view(y.shape)
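

# Illustrative check (hypothetical helper): the transform maps an arbitrary
# square matrix to a lower-triangular factor with positive diagonal, from
# which a positive definite matrix can be built as L @ L.t().
def _example_lower_cholesky_transform():
    import torch
    t = LowerCholeskyTransform()
    x = torch.randn(3, 3)
    L = t(x)                           # strictly lower part of x, diagonal = exp(diag(x))
    assert torch.equal(L.triu(1), torch.zeros(3, 3))
    assert (L.diag() > 0).all()
    return L @ L.t()                   # positive definite by construction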