Public Member Functions
    def __init__(self)
    def register(self, constraint, factory=None)
    def __call__(self, constraint)
Registry to link constraints to transforms.
Definition at line 79 of file constraint_registry.py.
def torch.distributions.constraint_registry.ConstraintRegistry.__call__(self, constraint)
Looks up a transform to constrained space, given a constraint object.
Usage::
    import torch
    from torch.distributions import Normal, transform_to
    constraint = Normal.arg_constraints['scale']
    scale = transform_to(constraint)(torch.zeros(1))  # constrained
    u = transform_to(constraint).inv(scale)           # unconstrained
Args:
constraint (:class:`~torch.distributions.constraints.Constraint`):
A constraint object.
Returns:
A :class:`~torch.distributions.transforms.Transform` object.
Raises:
`NotImplementedError` if no transform has been registered for `type(constraint)`.
Definition at line 119 of file constraint_registry.py.
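As a further illustration of `__call__`, the sketch below looks up a transform for a parameterized constraint and checks that its outputs satisfy the constraint, then shows the `NotImplementedError` raised by an empty registry. It assumes a PyTorch release that exports `transform_to` from `torch.distributions` and provides `constraints.interval` and `constraints.positive`; the specific transform returned for an interval is an implementation detail and is not inspected here.

```python
import torch
from torch.distributions import constraints, transform_to
from torch.distributions.constraint_registry import ConstraintRegistry

# transform_to is a ConstraintRegistry; calling it looks up the factory for
# type(constraint) and passes the constraint object to that factory.
interval = constraints.interval(-2.0, 3.0)
t = transform_to(interval)

x = torch.linspace(-5.0, 5.0, steps=5)  # arbitrary unconstrained values
y = t(x)                                # constrained values inside (-2, 3)
x_back = t.inv(y)                       # mapped back to unconstrained space

# A freshly constructed registry has no factories, so any lookup fails.
empty = ConstraintRegistry()
try:
    empty(constraints.positive)
    empty_raised = False
except NotImplementedError:
    empty_raised = True
```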
def torch.distributions.constraint_registry.ConstraintRegistry.register(self, constraint, factory=None)
Registers a :class:`~torch.distributions.constraints.Constraint`
subclass in this registry. Usage::
    @my_registry.register(MyConstraintClass)
    def construct_transform(constraint):
        # my_registry, MyConstraintClass and MyTransform are placeholders.
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.arg_constraints)
Args:
constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
A subclass of :class:`~torch.distributions.constraints.Constraint`, or
a singleton object of the desired class.
factory (callable): A callable that takes a constraint object and returns
a :class:`~torch.distributions.transforms.Transform` object. If omitted,
`register` returns a decorator that accepts the factory.
Definition at line 87 of file constraint_registry.py.
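A concrete, runnable version of the usage pattern above might look like the following sketch. `GreaterThanOne` and `_to_greater_than_one` are hypothetical names invented for this example; the transform classes (`ExpTransform`, `AffineTransform`, `ComposeTransform`) are assumed to be available from `torch.distributions.transforms`, as in current PyTorch releases. The example shows both the decorator form (no `factory` argument) and the direct form using a singleton instance.

```python
import torch
from torch.distributions import constraints, transforms
from torch.distributions.constraint_registry import ConstraintRegistry


class GreaterThanOne(constraints.Constraint):
    """Hypothetical constraint: every element must be strictly greater than 1."""

    def check(self, value):
        return value > 1


my_registry = ConstraintRegistry()


# Decorator form: with factory omitted, register() returns a decorator.
@my_registry.register(GreaterThanOne)
def _to_greater_than_one(constraint):
    # exp maps the reals onto (0, inf); adding 1 shifts that onto (1, inf).
    return transforms.ComposeTransform(
        [transforms.ExpTransform(), transforms.AffineTransform(loc=1.0, scale=1.0)]
    )


# Direct form: passing a singleton instance registers type(instance).
other_registry = ConstraintRegistry()
other_registry.register(GreaterThanOne(), _to_greater_than_one)

x = torch.linspace(-3.0, 3.0, steps=4)  # unconstrained inputs
y = my_registry(GreaterThanOne())(x)    # constrained outputs, all > 1
```

Because the decorator returns the factory unchanged, `_to_greater_than_one` remains an ordinary function that can be reused with another registry.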