2 The following constraints are implemented: 4 - ``constraints.boolean`` 5 - ``constraints.dependent`` 6 - ``constraints.greater_than(lower_bound)`` 7 - ``constraints.integer_interval(lower_bound, upper_bound)`` 8 - ``constraints.interval(lower_bound, upper_bound)`` 9 - ``constraints.lower_cholesky`` 10 - ``constraints.lower_triangular`` 11 - ``constraints.nonnegative_integer`` 12 - ``constraints.positive`` 13 - ``constraints.positive_definite`` 14 - ``constraints.positive_integer`` 15 - ``constraints.real`` 16 - ``constraints.real_vector`` 17 - ``constraints.simplex`` 18 - ``constraints.unit_interval`` 37 'nonnegative_integer',
50 Abstract base class for constraints. 52 A constraint object represents a region over which a variable is valid, 53 e.g. within which a variable can be optimized. 57 Returns a byte tensor of `sample_shape + batch_shape` indicating 58 whether each event in value satisfies this constraint. 60 raise NotImplementedError
63 return self.__class__.__name__[1:] +
'()' 68 Placeholder for variables whose support depends on other variables. 69 These variables obey no simple coordinate-wise constraints. 72 raise ValueError(
'Cannot determine validity of dependent constraint')
75 def is_dependent(constraint):
76 return isinstance(constraint, _Dependent)
81 Decorator that extends @property to act like a `Dependent` constraint when 82 called on a class and act like a property when called on an object. 86 class Uniform(Distribution): 87 def __init__(self, low, high): 90 @constraints.dependent_property 92 return constraints.interval(self.low, self.high) 97 class _Boolean(Constraint):
99 Constrain to the two values `{0, 1}`. 101 def check(self, value):
102 return (value == 0) | (value == 1)
107 Constrain to an integer interval `[lower_bound, upper_bound]`. 109 def __init__(self, lower_bound, upper_bound):
113 def check(self, value):
117 fmt_string = self.__class__.__name__[1:]
124 Constrain to an integer interval `(-inf, upper_bound]`. 126 def __init__(self, upper_bound):
129 def check(self, value):
130 return (value % 1 == 0) & (value <= self.
upper_bound)
133 fmt_string = self.__class__.__name__[1:]
134 fmt_string +=
'(upper_bound={})'.format(self.
upper_bound)
140 Constrain to an integer interval `[lower_bound, inf)`. 142 def __init__(self, lower_bound):
145 def check(self, value):
146 return (value % 1 == 0) & (value >= self.
lower_bound)
149 fmt_string = self.__class__.__name__[1:]
150 fmt_string +=
'(lower_bound={})'.format(self.
lower_bound)
156 Trivially constrain to the extended real line `[-inf, inf]`. 158 def check(self, value):
159 return value == value
164 Constrain to a real half line `(lower_bound, inf]`. 166 def __init__(self, lower_bound):
169 def check(self, value):
173 fmt_string = self.__class__.__name__[1:]
174 fmt_string +=
'(lower_bound={})'.format(self.
lower_bound)
180 Constrain to a real half line `[lower_bound, inf)`. 182 def __init__(self, lower_bound):
185 def check(self, value):
189 fmt_string = self.__class__.__name__[1:]
190 fmt_string +=
'(lower_bound={})'.format(self.
lower_bound)
196 Constrain to a real half line `[-inf, upper_bound)`. 198 def __init__(self, upper_bound):
201 def check(self, value):
205 fmt_string = self.__class__.__name__[1:]
206 fmt_string +=
'(upper_bound={})'.format(self.
upper_bound)
212 Constrain to a real interval `[lower_bound, upper_bound]`. 214 def __init__(self, lower_bound, upper_bound):
218 def check(self, value):
222 fmt_string = self.__class__.__name__[1:]
229 Constrain to a real interval `[lower_bound, upper_bound)`. 231 def __init__(self, lower_bound, upper_bound):
235 def check(self, value):
239 fmt_string = self.__class__.__name__[1:]
246 Constrain to the unit simplex in the innermost (rightmost) dimension. 247 Specifically: `x >= 0` and `x.sum(-1) == 1`. 249 def check(self, value):
250 return (value >= 0).all() & ((value.sum(-1,
True) - 1).abs() < 1e-6).all()
255 Constrain to lower-triangular square matrices. 257 def check(self, value):
258 value_tril = value.tril()
259 return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
264 Constrain to lower-triangular square matrices with positive diagonals. 266 def check(self, value):
267 value_tril = value.tril()
268 lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
270 positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
271 return lower_triangular & positive_diagonal
276 Constrain to positive-definite matrices. 278 def check(self, value):
279 matrix_shape = value.shape[-2:]
280 batch_shape = value.unsqueeze(0).shape[:-2]
283 flattened_value = value.reshape((-1,) + matrix_shape)
284 return torch.stack([v.symeig(eigenvectors=
False)[0][:1] > 0.0
285 for v
in flattened_value]).view(batch_shape)
290 Constrain to real-valued vectors. This is the same as `constraints.real`, 291 but additionally reduces across the `event_shape` dimension. 293 def check(self, value):
294 return (value == value).all()
299 dependent_property = _DependentProperty
303 integer_interval = _IntegerInterval
307 greater_than = _GreaterThan
308 greater_than_eq = _GreaterThanEq
309 less_than = _LessThan
312 half_open_interval = _HalfOpenInterval