Caffe2 - Python API
A deep learning, cross platform ML framework
constraint_registry.py
1 r"""
2 PyTorch provides two global :class:`ConstraintRegistry` objects that link
3 :class:`~torch.distributions.constraints.Constraint` objects to
4 :class:`~torch.distributions.transforms.Transform` objects. Both objects take
5 constraints as input and return transforms, but they have different guarantees on
6 bijectivity.
7 
8 1. ``biject_to(constraint)`` looks up a bijective
9  :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
10  to the given ``constraint``. The returned transform is guaranteed to have
11  ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
12 2. ``transform_to(constraint)`` looks up a not-necessarily bijective
13  :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
14  to the given ``constraint``. The returned transform is not guaranteed to
15  implement ``.log_abs_det_jacobian()``.
16 
17 The ``transform_to()`` registry is useful for performing unconstrained
18 optimization on constrained parameters of probability distributions, which are
19 indicated by each distribution's ``.arg_constraints`` dict. These transforms often
20 overparameterize a space in order to avoid rotation; they are thus more
21 suitable for coordinate-wise optimization algorithms like Adam::
22 
23  loc = torch.zeros(100, requires_grad=True)
24  unconstrained = torch.zeros(100, requires_grad=True)
25  scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
26  loss = -Normal(loc, scale).log_prob(data).sum()
27 
28 The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
29 samples from a probability distribution with constrained ``.support`` are
30 propagated in an unconstrained space, and algorithms are typically rotation
31 invariant::
32 
33  dist = Exponential(rate)
34  unconstrained = torch.zeros(100, requires_grad=True)
35  sample = biject_to(dist.support)(unconstrained)
36  potential_energy = -dist.log_prob(sample).sum()
37 
38 .. note::
39 
40  An example where ``transform_to`` and ``biject_to`` differ is
41  ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
42  :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
43  exponentiates and normalizes its inputs; this is a cheap and mostly
44  coordinate-wise operation appropriate for algorithms like SVI. In
45  contrast, ``biject_to(constraints.simplex)`` returns a
46  :class:`~torch.distributions.transforms.StickBreakingTransform` that
47  bijects its input down to a one-fewer-dimensional space; this is a more
48  expensive, less numerically stable transform but is needed for algorithms
49  like HMC.
50 
51 The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
52 constraints and transforms using their ``.register()`` method either as a
53 function on singleton constraints::
54 
55  transform_to.register(my_constraint, my_transform)
56 
57 or as a decorator on parameterized constraints::
58 
59  @transform_to.register(MyConstraintClass)
60  def my_factory(constraint):
61  assert isinstance(constraint, MyConstraintClass)
62  return MyTransform(constraint.param1, constraint.param2)
63 
64 You can create your own registry by creating a new :class:`ConstraintRegistry`
65 object.
66 """
67 
68 import numbers
69 
70 from torch.distributions import constraints, transforms
71 
72 __all__ = [
73  'ConstraintRegistry',
74  'biject_to',
75  'transform_to',
76 ]
77 
78 
class ConstraintRegistry(object):
    """
    Registry to link constraints to transforms.

    Factories are keyed by the exact ``type`` of a
    :class:`~torch.distributions.constraints.Constraint`; a factory is a
    callable that takes a constraint object and returns a
    :class:`~torch.distributions.transforms.Transform`.
    """
    def __init__(self):
        # Maps a Constraint subclass -> factory(constraint) -> Transform.
        self._registry = {}
        super(ConstraintRegistry, self).__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object.

        Returns:
            ``factory`` itself. When ``factory`` is omitted, returns a decorator
            that registers the function it wraps and returns that function.

        Raises:
            TypeError: if ``constraint`` is neither a Constraint subclass nor an
                instance of one.
        """
        # Support use as decorator.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances: key on the instance's class.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(constraint, constraints.Constraint):
            raise TypeError('Expected constraint to be either a Constraint subclass or instance, '
                            'but got {}'.format(constraint))

        # Registering the same class again replaces the earlier factory.
        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            A :class:`~torch.distributions.transforms.Transform` object, as
            built by the factory registered for ``type(constraint)``.

        Raises:
            NotImplementedError: if no factory has been registered for the
                exact type of ``constraint`` (factories registered for a base
                class are not consulted).
        """
        # Look up by exact Constraint subclass. Using .get() instead of catching
        # KeyError keeps the internal KeyError out of the user's traceback
        # (Python 3 would otherwise chain it as the exception's __context__).
        # A stored factory is never None: register(c, None) returns a decorator.
        factory = self._registry.get(type(constraint))
        if factory is None:
            raise NotImplementedError(
                'Cannot transform {} constraints'.format(type(constraint).__name__))
        return factory(constraint)
145 
146 
# The two global registries described in the module docstring: ``biject_to``
# holds transforms guaranteed to be bijective, ``transform_to`` holds
# transforms that need not be. Each is its own independent ConstraintRegistry.
biject_to, transform_to = ConstraintRegistry(), ConstraintRegistry()
149 
150 
151 ################################################################################
152 # Registration Table
153 ################################################################################
154 
155 @biject_to.register(constraints.real)
156 @biject_to.register(constraints.real_vector)
157 @transform_to.register(constraints.real)
158 @transform_to.register(constraints.real_vector)
159 def _transform_to_real(constraint):
160  return transforms.identity_transform
161 
162 
163 @biject_to.register(constraints.positive)
164 @transform_to.register(constraints.positive)
165 def _transform_to_positive(constraint):
166  return transforms.ExpTransform()
167 
168 
169 @biject_to.register(constraints.greater_than)
170 @biject_to.register(constraints.greater_than_eq)
171 @transform_to.register(constraints.greater_than)
172 @transform_to.register(constraints.greater_than_eq)
173 def _transform_to_greater_than(constraint):
175  transforms.AffineTransform(constraint.lower_bound, 1)])
176 
177 
178 @biject_to.register(constraints.less_than)
179 @transform_to.register(constraints.less_than)
180 def _transform_to_less_than(constraint):
182  transforms.AffineTransform(constraint.upper_bound, -1)])
183 
184 
185 @biject_to.register(constraints.interval)
186 @biject_to.register(constraints.half_open_interval)
187 @transform_to.register(constraints.interval)
188 @transform_to.register(constraints.half_open_interval)
189 def _transform_to_interval(constraint):
190  # Handle the special case of the unit interval.
191  lower_is_0 = isinstance(constraint.lower_bound, numbers.Number) and constraint.lower_bound == 0
192  upper_is_1 = isinstance(constraint.upper_bound, numbers.Number) and constraint.upper_bound == 1
193  if lower_is_0 and upper_is_1:
195 
196  loc = constraint.lower_bound
197  scale = constraint.upper_bound - constraint.lower_bound
199  transforms.AffineTransform(loc, scale)])
200 
201 
202 @biject_to.register(constraints.simplex)
203 def _biject_to_simplex(constraint):
205 
206 
207 @transform_to.register(constraints.simplex)
208 def _transform_to_simplex(constraint):
210 
211 
212 # TODO define a bijection for LowerCholeskyTransform
213 @transform_to.register(constraints.lower_cholesky)
214 def _transform_to_lower_cholesky(constraint):