Public Member Functions
    def __init__(self)
    def register(self, constraint, factory=None)
    def __call__(self, constraint)
Registry to link constraints to transforms.
Definition at line 79 of file constraint_registry.py.
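At its core the registry is a dictionary keyed by constraint class. The following is a minimal pure-Python sketch of that idea, not the PyTorch source. It assumes the behaviour documented for `register` and `__call__` on this page (decorator support, singleton instances normalized to their class, `NotImplementedError` for unregistered constraints) plus exact-class lookup. `SketchRegistry`, `Constraint`, `Positive`, `Interval`, and the string-valued "transforms" are hypothetical stand-ins.

```python
from typing import Callable, Dict, Type


class Constraint:
    """Hypothetical stand-in for torch.distributions.constraints.Constraint."""


class Positive(Constraint):
    """Hypothetical parameterless constraint class."""


class Interval(Constraint):
    """Hypothetical parameterized constraint class."""

    def __init__(self, lower: float, upper: float) -> None:
        self.lower = lower
        self.upper = upper


positive = Positive()  # singleton instance, in the style of constraints.positive


class SketchRegistry:
    """Toy registry mapping constraint classes to transform factories."""

    def __init__(self) -> None:
        self._registry: Dict[Type[Constraint], Callable] = {}

    def register(self, constraint, factory=None):
        # Called without a factory: return a decorator that registers later.
        if factory is None:
            return lambda f: self.register(constraint, f)
        # Accept a singleton instance by keying on its class.
        if isinstance(constraint, Constraint):
            constraint = type(constraint)
        if not (isinstance(constraint, type) and issubclass(constraint, Constraint)):
            raise TypeError(f"expected a Constraint subclass or instance, got {constraint!r}")
        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        # Assumption: look up the exact class of the instance (no MRO walk).
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        # The factory receives the instance, so it can read its parameters.
        return factory(constraint)


registry = SketchRegistry()


@registry.register(positive)
def _positive_factory(constraint):
    return "exp"


@registry.register(Interval)
def _interval_factory(constraint):
    return f"sigmoid then affine to ({constraint.lower}, {constraint.upper})"


print(registry(positive))
print(registry(Interval(0.0, 2.0)))
```

Keying on the class rather than the instance is what lets one factory serve every parameterization of a constraint such as an interval, while the factory still sees the concrete bounds.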
def torch.distributions.constraint_registry.ConstraintRegistry.__call__(self, constraint)
Looks up a transform to constrained space, given a constraint object.

Usage::

    constraint = Normal.arg_constraints['scale']
    scale = transform_to(constraint)(torch.zeros(1))  # constrained
    u = transform_to(constraint).inv(scale)  # unconstrained

Args:
    constraint (:class:`~torch.distributions.constraints.Constraint`):
        A constraint object.

Returns:
    A :class:`~torch.distributions.transforms.Transform` object.

Raises:
    `NotImplementedError` if no transform has been registered.
Definition at line 119 of file constraint_registry.py.
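A minimal runnable sketch of a lookup through the built-in `transform_to` registry, assuming a recent PyTorch install. It maps unconstrained values into the support of `constraints.positive` and `constraints.interval(-1.0, 1.0)`, then inverts one of them. Because the exact transform classes returned can differ between PyTorch versions, it checks behaviour via `Constraint.check` rather than transform types. `MyConstraint` is a hypothetical class with no registered transform.

```python
import torch
from torch.distributions import constraints, transform_to

# Unconstrained inputs: arbitrary real numbers.
u = torch.tensor([-3.0, 0.0, 2.5])

# Look up a transform for each constraint and push u into its support.
positive_t = transform_to(constraints.positive)
interval_t = transform_to(constraints.interval(-1.0, 1.0))
pos = positive_t(u)
box = interval_t(u)

# Each output satisfies its constraint (check returns a bool tensor).
assert constraints.positive.check(pos).all()
assert constraints.interval(-1.0, 1.0).check(box).all()

# .inv maps back to unconstrained space, up to floating-point error.
u_back = positive_t.inv(pos)
assert torch.allclose(u_back, u)


class MyConstraint(constraints.Constraint):
    """Hypothetical constraint class with no registered transform."""


# Lookup fails for constraint classes that were never registered.
raised = False
try:
    transform_to(MyConstraint())
except NotImplementedError:
    raised = True
print("NotImplementedError raised:", raised)
```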
def torch.distributions.constraint_registry.ConstraintRegistry.register(self, constraint, factory=None)
Registers a :class:`~torch.distributions.constraints.Constraint` subclass in this registry.

Usage::

    @my_registry.register(MyConstraintClass)
    def construct_transform(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.arg_constraints)

Args:
    constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
        A subclass of :class:`~torch.distributions.constraints.Constraint`,
        or a singleton object of the desired class.
    factory (callable, optional): A callable that takes a constraint object
        and returns a :class:`~torch.distributions.transforms.Transform` object.
        If omitted, ``register`` returns a decorator, as in the usage example.
Definition at line 87 of file constraint_registry.py.
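A minimal sketch of both calling styles of `register` on fresh `ConstraintRegistry` objects, assuming a recent PyTorch install. `NegativeConstraint`, `negative`, and `negative_factory` are hypothetical names for illustration. The factory composes the existing `ExpTransform` and `AffineTransform` so that y = -exp(x) maps the real line onto (-inf, 0).

```python
import torch
from torch.distributions import constraints, transforms
from torch.distributions.constraint_registry import ConstraintRegistry


class NegativeConstraint(constraints.Constraint):
    """Hypothetical constraint: every element is strictly negative."""

    def check(self, value):
        return value < 0


negative = NegativeConstraint()  # singleton instance, in the style of constraints.positive

# Decorator style: factory=None, keyed on the class itself.
my_registry = ConstraintRegistry()


@my_registry.register(NegativeConstraint)
def negative_factory(constraint):
    """Hypothetical factory: y = -exp(x) maps the reals onto (-inf, 0)."""
    assert isinstance(constraint, NegativeConstraint)
    return transforms.ComposeTransform(
        [transforms.ExpTransform(), transforms.AffineTransform(0.0, -1.0)]
    )


# Direct style: pass a singleton instance; it is keyed on type(negative).
other_registry = ConstraintRegistry()
other_registry.register(negative, negative_factory)

x = torch.tensor([-2.0, 0.0, 3.0])
t = my_registry(negative)
y = t(x)
print(negative.check(y))
```

Because registration is keyed on the class, any other instance of `NegativeConstraint` resolves to the same factory in either registry.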