Caffe2 - Python API
A deep learning, cross platform ML framework
constraints.py
r"""
The following constraints are implemented:

- ``constraints.boolean``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.nonnegative_integer``
- ``constraints.positive``
- ``constraints.positive_definite``
- ``constraints.positive_integer``
- ``constraints.real``
- ``constraints.real_vector``
- ``constraints.simplex``
- ``constraints.unit_interval``

The helpers ``constraints.dependent_property`` and ``constraints.is_dependent``
are also provided for supports that depend on other parameters.
"""
20 
21 import torch
22 
# Names exported by ``from ... import *``: the ``Constraint`` base class,
# the dependent-constraint helpers, and the public constraint objects and
# constraint factories (order kept as originally declared).
__all__ = [
    'Constraint',
    'boolean', 'dependent', 'dependent_property',
    'greater_than', 'greater_than_eq',
    'integer_interval', 'interval', 'half_open_interval',
    'is_dependent', 'less_than',
    'lower_cholesky', 'lower_triangular',
    'nonnegative_integer', 'positive', 'positive_definite', 'positive_integer',
    'real', 'real_vector', 'simplex', 'unit_interval',
]
46 
47 
class Constraint(object):
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized. Subclasses override
    :meth:`check`.
    """
    def check(self, value):
        """
        Returns a byte tensor of `sample_shape + batch_shape` indicating
        whether each event in value satisfies this constraint.

        Raises:
            NotImplementedError: always, in this abstract base class.
        """
        raise NotImplementedError

    def __repr__(self):
        # Concrete constraints are private classes such as ``_Real``; strip
        # the single leading underscore so the repr reads ``Real()``. Names
        # without one (``Constraint`` itself, user subclasses) are kept whole
        # rather than losing their first character.
        name = self.__class__.__name__
        if name.startswith('_'):
            name = name[1:]
        return name + '()'
64 
65 
class _Dependent(Constraint):
    """
    Placeholder constraint for variables whose support is determined by other
    variables, so no fixed coordinate-wise test applies.

    Calling :meth:`check` always raises ``ValueError``.
    """
    def check(self, x):
        message = 'Cannot determine validity of dependent constraint'
        raise ValueError(message)
73 
74 
def is_dependent(constraint):
    """
    Return whether ``constraint`` is a dependent (placeholder) constraint,
    i.e. an instance of ``_Dependent`` or of a subclass such as
    ``_DependentProperty``.
    """
    placeholder_type = _Dependent
    return isinstance(constraint, placeholder_type)
77 
78 
class _DependentProperty(property, _Dependent):
    """
    Decorator that extends ``@property`` so it doubles as a dependent
    constraint.

    Accessed on the class, the descriptor object itself is returned; since it
    inherits from the dependent placeholder, its ``check`` raises
    ``ValueError``. Accessed on an instance, it evaluates the decorated
    function like an ordinary property.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high

            @constraints.dependent_property
            def support(self):
                return constraints.interval(self.low, self.high)
    """
95 
96 
class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`; every other value fails the check.
    """
    def check(self, value):
        is_zero = value == 0
        is_one = value == 1
        return is_zero | is_one
103 
104 
class _IntegerInterval(Constraint):
    """
    Constrain to integer values in the closed interval
    `[lower_bound, upper_bound]`.

    Args:
        lower_bound: smallest admissible value (inclusive).
        upper_bound: largest admissible value (inclusive).
    """
    def __init__(self, lower_bound, upper_bound):
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        # ``value % 1 == 0`` detects integral values without requiring an
        # integer dtype.
        integral = value % 1 == 0
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return integral & above_lower & below_upper

    def __repr__(self):
        return '{name}(lower_bound={lo}, upper_bound={hi})'.format(
            name=self.__class__.__name__[1:],
            lo=self.lower_bound,
            hi=self.upper_bound,
        )
120 
121 
class _IntegerLessThan(Constraint):
    """
    Constrain to integer values in the interval `(-inf, upper_bound]`.

    Args:
        upper_bound: largest admissible value (inclusive).
    """
    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, value):
        # ``value % 1 == 0`` detects integral values without requiring an
        # integer dtype.
        integral = value % 1 == 0
        not_above = value <= self.upper_bound
        return integral & not_above

    def __repr__(self):
        return '{name}(upper_bound={hi})'.format(
            name=self.__class__.__name__[1:],
            hi=self.upper_bound,
        )
136 
137 
class _IntegerGreaterThan(Constraint):
    """
    Constrain to integer values in the interval `[lower_bound, inf)`.

    Args:
        lower_bound: smallest admissible value (inclusive).
    """
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        # ``value % 1 == 0`` detects integral values without requiring an
        # integer dtype.
        integral = value % 1 == 0
        not_below = value >= self.lower_bound
        return integral & not_below

    def __repr__(self):
        return '{name}(lower_bound={lo})'.format(
            name=self.__class__.__name__[1:],
            lo=self.lower_bound,
        )
152 
153 
class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`; only NaN
    entries are rejected.
    """
    def check(self, value):
        # NaN is the only value that compares unequal to itself.
        not_nan = value == value
        return not_nan
160 
161 
class _GreaterThan(Constraint):
    """
    Constrain to the real half line `(lower_bound, inf]` (strictly above the
    bound).

    Args:
        lower_bound: exclusive lower bound.
    """
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        strictly_above = self.lower_bound < value
        return strictly_above

    def __repr__(self):
        return '{name}(lower_bound={lo})'.format(
            name=self.__class__.__name__[1:],
            lo=self.lower_bound,
        )
176 
177 
class _GreaterThanEq(Constraint):
    """
    Constrain to the real half line `[lower_bound, inf]` (the bound itself
    and +inf are both accepted).

    Args:
        lower_bound: inclusive lower bound.
    """
    def __init__(self, lower_bound):
        self.lower_bound = lower_bound

    def check(self, value):
        at_or_above = self.lower_bound <= value
        return at_or_above

    def __repr__(self):
        return '{name}(lower_bound={lo})'.format(
            name=self.__class__.__name__[1:],
            lo=self.lower_bound,
        )
192 
193 
class _LessThan(Constraint):
    """
    Constrain to the real half line `[-inf, upper_bound)` (strictly below the
    bound).

    Args:
        upper_bound: exclusive upper bound.
    """
    def __init__(self, upper_bound):
        self.upper_bound = upper_bound

    def check(self, value):
        strictly_below = value < self.upper_bound
        return strictly_below

    def __repr__(self):
        return '{name}(upper_bound={hi})'.format(
            name=self.__class__.__name__[1:],
            hi=self.upper_bound,
        )
208 
209 
class _Interval(Constraint):
    """
    Constrain to the closed real interval `[lower_bound, upper_bound]`.

    Args:
        lower_bound: inclusive lower bound.
        upper_bound: inclusive upper bound.
    """
    def __init__(self, lower_bound, upper_bound):
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        return '{name}(lower_bound={lo}, upper_bound={hi})'.format(
            name=self.__class__.__name__[1:],
            lo=self.lower_bound,
            hi=self.upper_bound,
        )
225 
226 
class _HalfOpenInterval(Constraint):
    """
    Constrain to the half-open real interval `[lower_bound, upper_bound)`.

    Args:
        lower_bound: inclusive lower bound.
        upper_bound: exclusive upper bound.
    """
    def __init__(self, lower_bound, upper_bound):
        self.lower_bound, self.upper_bound = lower_bound, upper_bound

    def check(self, value):
        above_lower = self.lower_bound <= value
        strictly_below_upper = value < self.upper_bound
        return above_lower & strictly_below_upper

    def __repr__(self):
        return '{name}(lower_bound={lo}, upper_bound={hi})'.format(
            name=self.__class__.__name__[1:],
            lo=self.lower_bound,
            hi=self.upper_bound,
        )
242 
243 
class _Simplex(Constraint):
    """
    Constrain to the unit simplex in the innermost (rightmost) dimension.
    Specifically: `x >= 0` and `x.sum(-1) == 1` (within an absolute
    tolerance of `1e-6`).

    :meth:`check` reduces only over the rightmost dimension, returning one
    result per vector (a tensor of shape `value.shape[:-1]`).
    """
    def check(self, value):
        # Reduce over the event (last) dimension only. Reducing over every
        # dimension would collapse the whole batch into a single scalar and
        # hide which vectors violate the constraint.
        nonnegative = (value >= 0).all(-1)
        sums_to_one = (value.sum(-1) - 1).abs() < 1e-6
        return nonnegative & sums_to_one
251 
252 
class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.

    Each matrix in the last two dimensions is compared with its lower
    triangle, giving one result per matrix. Note: the check itself does not
    verify that the matrices are square.
    """
    def check(self, value):
        # Any non-zero entry above the diagonal makes ``tril`` differ from
        # the input at that position.
        matches_tril = value.tril() == value
        flat_shape = value.shape[:-2] + (-1,)
        return matches_tril.view(flat_shape).min(-1)[0]
260 
261 
class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.

    A matrix passes when it equals its lower triangle and every entry on its
    main diagonal is strictly positive; one result is returned per matrix.
    """
    def check(self, value):
        lower_part = value.tril()
        flat_shape = value.shape[:-2] + (-1,)
        is_lower = (lower_part == value).view(flat_shape).min(-1)[0]

        diag_entries = value.diagonal(dim1=-2, dim2=-1)
        has_positive_diag = (diag_entries > 0).min(-1)[0]
        return is_lower & has_positive_diag
272 
273 
class _PositiveDefinite(Constraint):
    """
    Constrain to positive-definite matrices.

    :meth:`check` returns one result per matrix in the last two dimensions,
    i.e. a tensor of shape `value.shape[:-2]`. A matrix passes when its
    smallest eigenvalue is strictly positive. Only the upper triangle is read
    (matching ``symeig``'s default), so inputs are assumed to be symmetric.
    """
    def check(self, value):
        batch_shape = value.shape[:-2]
        eigvalsh = getattr(getattr(torch, 'linalg', None), 'eigvalsh', None)
        if eigvalsh is not None:
            # Batched routine: eigenvalues come back in ascending order, so
            # index 0 is the smallest. ``UPLO='U'`` reads the upper triangle,
            # as ``symeig`` does by default.
            return eigvalsh(value, UPLO='U')[..., 0] > 0.0

        # Fallback for PyTorch versions without ``torch.linalg.eigvalsh``.
        # ``symeig()`` also returns eigenvalues in ascending order.
        flattened_value = value.reshape((-1,) + value.shape[-2:])
        if flattened_value.shape[0] == 0:
            # ``torch.stack`` rejects an empty list; an empty batch has no
            # matrices to test, so return an empty result of the batch shape.
            return value.new_empty(batch_shape) > 0.0
        min_eigenvalues = torch.stack([v.symeig(eigenvectors=False)[0][0]
                                       for v in flattened_value])
        return min_eigenvalues.reshape(batch_shape) > 0.0
286 
287 
class _RealVector(Constraint):
    """
    Constrain to real-valued vectors. This is the same as `constraints.real`,
    but additionally reduces across the `event_shape` dimension.

    :meth:`check` returns one result per vector (shape `value.shape[:-1]`).
    A vector fails if any of its entries is NaN.
    """
    def check(self, value):
        # ``value == value`` is False only at NaN entries. Reduce across the
        # rightmost (event) dimension only, keeping the batch dimensions
        # instead of collapsing everything to one scalar.
        return (value == value).all(-1)
295 
296 
# Public interface.
#
# Parameterized constraints are exposed as the classes themselves; callers
# instantiate them with their bounds, e.g. ``interval(0., 2.)``.
dependent_property = _DependentProperty
integer_interval = _IntegerInterval
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
interval = _Interval
half_open_interval = _HalfOpenInterval

# Parameter-free constraints are module-level instances.
dependent = _Dependent()
boolean = _Boolean()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
real = _Real()
real_vector = _RealVector()
positive = greater_than(0.0)
unit_interval = interval(0.0, 1.0)
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
positive_definite = _PositiveDefinite()