2 PyTorch provides two global :class:`ConstraintRegistry` objects that link 3 :class:`~torch.distributions.constraints.Constraint` objects to 4 :class:`~torch.distributions.transforms.Transform` objects. These objects both 5 input constraints and return transforms, but they have different guarantees on 8 1. ``biject_to(constraint)`` looks up a bijective 9 :class:`~torch.distributions.transforms.Transform` from ``constraints.real`` 10 to the given ``constraint``. The returned transform is guaranteed to have 11 ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``. 12 2. ``transform_to(constraint)`` looks up a not-necessarily bijective 13 :class:`~torch.distributions.transforms.Transform` from ``constraints.real`` 14 to the given ``constraint``. The returned transform is not guaranteed to 15 implement ``.log_abs_det_jacobian()``. 17 The ``transform_to()`` registry is useful for performing unconstrained 18 optimization on constrained parameters of probability distributions, which are 19 indicated by each distribution's ``.arg_constraints`` dict. These transforms often 20 overparameterize a space in order to avoid rotation; they are thus more 21 suitable for coordinate-wise optimization algorithms like Adam:: 23 loc = torch.zeros(100, requires_grad=True) 24 unconstrained = torch.zeros(100, requires_grad=True) 25 scale = transform_to(Normal.arg_constraints['scale'])(unconstrained) 26 loss = -Normal(loc, scale).log_prob(data).sum() 28 The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where 29 samples from a probability distribution with constrained ``.support`` are 30 propagated in an unconstrained space, and algorithms are typically rotation 33 dist = Exponential(rate) 34 unconstrained = torch.zeros(100, requires_grad=True) 35 sample = biject_to(dist.support)(unconstrained) 36 potential_energy = -dist.log_prob(sample).sum() 40 An example where ``transform_to`` and ``biject_to`` differ is 41 ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a 42 :class:`~torch.distributions.transforms.SoftmaxTransform` that simply 43 exponentiates and normalizes its inputs; this is a cheap and mostly 44 coordinate-wise operation appropriate for algorithms like SVI. In 45 contrast, ``biject_to(constraints.simplex)`` returns a 46 :class:`~torch.distributions.transforms.StickBreakingTransform` that 47 bijects its input down to a one-fewer-dimensional space; this a more 48 expensive less numerically stable transform but is needed for algorithms 51 The ``biject_to`` and ``transform_to`` objects can be extended by user-defined 52 constraints and transforms using their ``.register()`` method either as a 53 function on singleton constraints:: 55 transform_to.register(my_constraint, my_transform) 57 or as a decorator on parameterized constraints:: 59 @transform_to.register(MyConstraintClass) 60 def my_factory(constraint): 61 assert isinstance(constraint, MyConstraintClass) 62 return MyTransform(constraint.param1, constraint.param2) 64 You can create your own registry by creating a new :class:`ConstraintRegistry` 81 Registry to link constraints to transforms. 85 super(ConstraintRegistry, self).__init__()
89 Registers a :class:`~torch.distributions.constraints.Constraint` 90 subclass in this registry. Usage:: 92 @my_registry.register(MyConstraintClass) 93 def construct_transform(constraint): 94 assert isinstance(constraint, MyConstraint) 95 return MyTransform(constraint.arg_constraints) 98 constraint (subclass of :class:`~torch.distributions.constraints.Constraint`): 99 A subclass of :class:`~torch.distributions.constraints.Constraint`, or 100 a singleton object of the desired class. 101 factory (callable): A callable that inputs a constraint object and returns 102 a :class:`~torch.distributions.transforms.Transform` object. 106 return lambda factory: self.
register(constraint, factory)
110 constraint = type(constraint)
113 raise TypeError(
'Expected constraint to be either a Constraint subclass or instance, ' 114 'but got {}'.format(constraint))
121 Looks up a transform to constrained space, given a constraint object. 124 constraint = Normal.arg_constraints['scale'] 125 scale = transform_to(constraint)(torch.zeros(1)) # constrained 126 u = transform_to(constraint).inv(scale) # unconstrained 129 constraint (:class:`~torch.distributions.constraints.Constraint`): 133 A :class:`~torch.distributions.transforms.Transform` object. 136 `NotImplementedError` if no transform has been registered. 140 factory = self.
_registry[type(constraint)]
142 raise NotImplementedError(
143 'Cannot transform {} constraints'.format(type(constraint).__name__))
144 return factory(constraint)
155 @biject_to.register(constraints.real)
156 @biject_to.register(constraints.real_vector)
157 @transform_to.register(constraints.real)
158 @transform_to.register(constraints.real_vector)
159 def _transform_to_real(constraint):
160 return transforms.identity_transform
163 @biject_to.register(constraints.positive)
164 @transform_to.register(constraints.positive)
165 def _transform_to_positive(constraint):
169 @biject_to.register(constraints.greater_than)
170 @biject_to.register(constraints.greater_than_eq)
171 @transform_to.register(constraints.greater_than)
172 @transform_to.register(constraints.greater_than_eq)
173 def _transform_to_greater_than(constraint):
178 @biject_to.register(constraints.less_than)
179 @transform_to.register(constraints.less_than)
180 def _transform_to_less_than(constraint):
185 @biject_to.register(constraints.interval)
186 @biject_to.register(constraints.half_open_interval)
187 @transform_to.register(constraints.interval)
188 @transform_to.register(constraints.half_open_interval)
189 def _transform_to_interval(constraint):
191 lower_is_0 = isinstance(constraint.lower_bound, numbers.Number)
and constraint.lower_bound == 0
192 upper_is_1 = isinstance(constraint.upper_bound, numbers.Number)
and constraint.upper_bound == 1
193 if lower_is_0
and upper_is_1:
196 loc = constraint.lower_bound
197 scale = constraint.upper_bound - constraint.lower_bound
202 @biject_to.register(constraints.simplex)
203 def _biject_to_simplex(constraint):
207 @transform_to.register(constraints.simplex)
208 def _transform_to_simplex(constraint):
213 @transform_to.register(constraints.lower_cholesky)
214 def _transform_to_lower_cholesky(constraint):
def register(self, constraint, factory=None)
def __call__(self, constraint)