Caffe2 - Python API
A deep learning, cross platform ML framework
torch.distributions.constraint_registry.ConstraintRegistry Class Reference

Public Member Functions

def __init__ (self)
def register (self, constraint, factory=None)
def __call__ (self, constraint)

Detailed Description

Registry to link constraints to transforms.

Definition at line 79 of file
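The following is a short sketch that assumes a recent PyTorch release. `torch.distributions` exposes two module-level registries built from this class, `biject_to` and `transform_to`. You call either one with a constraint object and it returns a transform, but the two can return different transforms for the same constraint. For `constraints.simplex`, `biject_to` must return a bijection, so three unconstrained values map to a four-element probability vector. `transform_to` may instead return a non-bijective map, and that map keeps the length at three.

```python
import torch
from torch.distributions import biject_to, constraints, transform_to
from torch.distributions.constraint_registry import ConstraintRegistry

# Both global lookup tables are instances of ConstraintRegistry.
assert isinstance(biject_to, ConstraintRegistry)
assert isinstance(transform_to, ConstraintRegistry)

x = torch.tensor([0.5, -1.0, 2.0])  # three unconstrained reals

# biject_to: must be a bijection, so 3 free values map to a point
# on the 4-element probability simplex.
y_biject = biject_to(constraints.simplex)(x)

# transform_to: need not be bijective, so 3 values map to a
# 3-element probability vector.
y_transform = transform_to(constraints.simplex)(x)
```

The PyTorch documentation suggests two different uses for these registries. `biject_to` suits samplers such as Hamiltonian Monte Carlo, which move through unconstrained space. `transform_to` suits unconstrained optimization of constrained parameters.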

Member Function Documentation

def torch.distributions.constraint_registry.ConstraintRegistry.__call__ (self, constraint)
Looks up a transform to constrained space, given a constraint object. Usage::

    import torch
    from torch.distributions import Normal, transform_to

    constraint = Normal.arg_constraints['scale']
    scale = transform_to(constraint)(torch.zeros(1))  # constrained
    u = transform_to(constraint).inv(scale)           # unconstrained

Args:
    constraint (:class:`~torch.distributions.constraints.Constraint`):
        A constraint object.

Returns:
    A :class:`~torch.distributions.transforms.Transform` object.

Raises:
    `NotImplementedError` if no transform has been registered.

Definition at line 119 of file
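The sketch below exercises the lookup end to end and assumes `transform_to` is importable from `torch.distributions`, as it is in current releases. First it checks that the transform for `Normal`'s `scale` constraint maps unconstrained values to positive ones. Then it checks that `.inv` maps them back. Last, it confirms that a lookup for a constraint type with no registered factory raises `NotImplementedError`. `UnregisteredConstraint` is a hypothetical class defined only for this demonstration.

```python
import torch
from torch.distributions import Normal, constraints, transform_to

# Normal's scale parameter is constrained to be strictly positive.
constraint = Normal.arg_constraints['scale']

# Look up the transform from unconstrained reals to that constrained space.
t = transform_to(constraint)

u = torch.tensor([-2.0, 0.0, 3.0])  # unconstrained values
scale = t(u)                        # constrained: every entry is > 0
u_back = t.inv(scale)               # mapped back to unconstrained space


# Hypothetical constraint class with no registered transform.
class UnregisteredConstraint(constraints.Constraint):
    def check(self, value):
        return value == value  # accepts any non-NaN value


try:
    transform_to(UnregisteredConstraint())
    raised = False
except NotImplementedError:
    raised = True
```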

def torch.distributions.constraint_registry.ConstraintRegistry.register (self, constraint, factory=None)
Registers a :class:`~torch.distributions.constraints.Constraint`
subclass in this registry. Usage::

    @my_registry.register(MyConstraint)
    def construct_transform(constraint):
        assert isinstance(constraint, MyConstraint)
        return MyTransform(constraint.arg_constraints)

Args:
    constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
        A subclass of :class:`~torch.distributions.constraints.Constraint`, or
        a singleton object of the desired class.
    factory (callable): A callable that takes a constraint object and returns
        a :class:`~torch.distributions.transforms.Transform` object. If omitted,
        `register` returns a decorator that registers the decorated function.

Definition at line 87 of file
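The following minimal sketch shows both ways to call `register`. It uses a fresh `ConstraintRegistry` instead of the global ones. `OpenInterval`, `PositiveOnly` and `open_interval_transform` are hypothetical names introduced only for this example. The `OpenInterval` factory composes PyTorch's `SigmoidTransform` and `AffineTransform` to map the real line onto `(low, high)`. The registry stores the constraint's class, so registering with one instance also covers every other instance of that class.

```python
import torch
from torch.distributions import constraints, transforms
from torch.distributions.constraint_registry import ConstraintRegistry


class OpenInterval(constraints.Constraint):
    """Hypothetical constraint: low < value < high (elementwise)."""

    def __init__(self, low, high):
        super().__init__()
        self.low = low
        self.high = high

    def check(self, value):
        return (self.low < value) & (value < self.high)


class PositiveOnly(constraints.Constraint):
    """Hypothetical constraint: value > 0 (elementwise)."""

    def check(self, value):
        return value > 0


registry = ConstraintRegistry()


# Decorator form: with factory omitted, register() returns a decorator
# that stores the decorated function as the factory for OpenInterval.
@registry.register(OpenInterval)
def open_interval_transform(constraint):
    assert isinstance(constraint, OpenInterval)
    # Sigmoid maps the real line onto (0, 1); the affine map rescales
    # (0, 1) onto (low, high) using this instance's own bounds.
    return transforms.ComposeTransform([
        transforms.SigmoidTransform(),
        transforms.AffineTransform(loc=constraint.low,
                                   scale=constraint.high - constraint.low),
    ])


# Direct form, passing an instance: the registry stores type(instance),
# so the factory applies to every PositiveOnly object.
registry.register(PositiveOnly(), lambda c: transforms.ExpTransform())

t = registry(OpenInterval(2.0, 6.0))
mid = t(torch.zeros(1))            # sigmoid(0) = 0.5, so 2 + 0.5 * 4 = 4.0
exp_t = registry(PositiveOnly())   # a different instance finds the same entry
```

Because entries are keyed by class, one factory can read per-instance parameters such as `low` and `high`. You do not need to register each bound separately.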

The documentation for this class was generated from the following file: