import warnings

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property


class Distribution(object):
    r"""
    Distribution is the abstract base class for probability distributions.
    """

    has_rsample = False
    has_enumerate_support = False
    _validate_args = False
False 19 def set_default_validate_args(value):
20 if value
not in [
True,
False]:
22 Distribution._validate_args = value
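
    # Illustrative usage sketch (not part of this file; assumes the concrete
    # torch.distributions.Normal distribution). Toggling the global default
    # controls whether __init__ runs the per-parameter constraint checks:
    #
    #   >>> torch.distributions.Distribution.set_default_validate_args(True)
    #   >>> torch.distributions.Normal(0.0, -1.0)  # raises: scale must be positive
    #   >>> torch.distributions.Distribution.set_default_validate_args(False)
    #   >>> torch.distributions.Normal(0.0, -1.0)  # constructs silently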

    def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_args=None):
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            self._validate_args = validate_args
        if self._validate_args:
            for param, constraint in self.arg_constraints.items():
                if constraints.is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):
                    continue  # skip checking lazily-constructed args
                if not constraint.check(getattr(self, param)).all():
                    raise ValueError("The parameter {} has invalid values".format(param))
        super(Distribution, self).__init__()
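
    # Illustrative sketch (assumes torch.distributions.Normal, whose
    # arg_constraints require a positive scale): validation can also be
    # requested per instance, regardless of the class-level default:
    #
    #   >>> torch.distributions.Normal(0.0, -1.0, validate_args=True)
    #   ValueError: The parameter scale has invalid values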

    def expand(self, batch_shape, _instance=None):
        """
        Returns a new distribution instance (or populates an existing instance
        provided by a derived class) with batch dimensions expanded to
        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
        the distribution's parameters. As such, this does not allocate new
        memory for the expanded distribution instance. Additionally,
        this does not repeat any args checking or parameter broadcasting in
        `__init__.py`, when an instance is first created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance: new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            New distribution instance with batch dimensions expanded to
            `batch_shape`.
        """
        raise NotImplementedError
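
    # Illustrative sketch (assumes torch.distributions.Normal): expanding
    # changes only the batch dimensions; the parameter tensors are views, so
    # no new parameter memory is allocated:
    #
    #   >>> m = torch.distributions.Normal(torch.zeros(2), torch.ones(2))
    #   >>> m.expand(torch.Size((3, 2))).batch_shape
    #   torch.Size([3, 2])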

    @property
    def batch_shape(self):
        """
        Returns the shape over which parameters are batched.
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """
        Returns the shape of a single sample (without batching).
        """
        return self._event_shape

    @property
    def arg_constraints(self):
        """
        Returns a dictionary from argument names to
        :class:`~torch.distributions.constraints.Constraint` objects that
        should be satisfied by each argument of this distribution. Args that
        are not tensors need not appear in this dict.
        """
        raise NotImplementedError
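
    # Illustrative sketch (assumes torch.distributions.MultivariateNormal):
    # batch_shape indexes independent parameterizations, while event_shape is
    # the shape of a single draw:
    #
    #   >>> m = torch.distributions.MultivariateNormal(torch.zeros(5, 3), torch.eye(3))
    #   >>> m.batch_shape, m.event_shape
    #   (torch.Size([5]), torch.Size([3]))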

    @property
    def support(self):
        """
        Returns a :class:`~torch.distributions.constraints.Constraint` object
        representing this distribution's support.
        """
        raise NotImplementedError

    @property
    def mean(self):
        """
        Returns the mean of the distribution.
        """
        raise NotImplementedError

    @property
    def variance(self):
        """
        Returns the variance of the distribution.
        """
        raise NotImplementedError

    @property
    def stddev(self):
        """
        Returns the standard deviation of the distribution.
        """
        return self.variance.sqrt()

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched.
        """
        with torch.no_grad():
            return self.rsample(sample_shape)
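
    # Illustrative sketch (assumes torch.distributions.Normal): sample_shape
    # is prepended to batch_shape + event_shape, and the no_grad block keeps
    # the result detached from the autograd graph:
    #
    #   >>> m = torch.distributions.Normal(torch.zeros(3), torch.ones(3))
    #   >>> m.sample((5,)).shape
    #   torch.Size([5, 3])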

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        """
        raise NotImplementedError
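
    # Illustrative sketch (assumes torch.distributions.Normal, which supports
    # the reparameterization trick): unlike sample(), rsample() keeps the
    # draw differentiable with respect to the distribution's parameters:
    #
    #   >>> loc = torch.zeros(3, requires_grad=True)
    #   >>> m = torch.distributions.Normal(loc, torch.ones(3))
    #   >>> m.rsample().sum().backward()
    #   >>> loc.grad.shape
    #   torch.Size([3])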

    def sample_n(self, n):
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched.
        """
        warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning)
        return self.sample(torch.Size((n,)))

    def log_prob(self, value):
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError
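
    # Illustrative sketch (assumes torch.distributions.Normal): for a
    # standard normal, log_prob(0) = -0.5 * log(2 * pi) ~ -0.9189:
    #
    #   >>> m = torch.distributions.Normal(0.0, 1.0)
    #   >>> m.log_prob(torch.tensor(0.0))
    #   tensor(-0.9189)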

    def cdf(self, value):
        """
        Returns the cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def icdf(self, value):
        """
        Returns the inverse cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError
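
    # Illustrative sketch (assumes torch.distributions.Normal): cdf and icdf
    # are inverses of each other where both are defined:
    #
    #   >>> m = torch.distributions.Normal(0.0, 1.0)
    #   >>> m.cdf(torch.tensor(0.0))
    #   tensor(0.5000)
    #   >>> m.icdf(m.cdf(torch.tensor(1.3)))
    #   tensor(1.3000)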

    def enumerate_support(self, expand=True):
        """
        Returns tensor containing all values supported by a discrete
        distribution. The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Returns:
            Tensor iterating over dimension 0.
        """
        raise NotImplementedError
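
    # Illustrative sketch (assumes torch.distributions.Bernoulli): dim 0
    # enumerates the cardinality-2 support across a batch of two, either
    # expanded over the batch or left as singleton batch dims:
    #
    #   >>> m = torch.distributions.Bernoulli(torch.tensor([0.1, 0.9]))
    #   >>> m.enumerate_support()
    #   tensor([[0., 0.],
    #           [1., 1.]])
    #   >>> m.enumerate_support(expand=False)
    #   tensor([[0.],
    #           [1.]])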

    def entropy(self):
        """
        Returns entropy of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        raise NotImplementedError
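
    # Illustrative sketch (assumes torch.distributions.Normal): the standard
    # normal has entropy 0.5 * log(2 * pi * e) ~ 1.4189:
    #
    #   >>> torch.distributions.Normal(0.0, 1.0).entropy()
    #   tensor(1.4189)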

    def perplexity(self):
        """
        Returns perplexity of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        return torch.exp(self.entropy())
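
    # Illustrative sketch (assumes torch.distributions.Categorical): a
    # uniform distribution over 4 outcomes has perplexity exp(log 4) = 4:
    #
    #   >>> torch.distributions.Categorical(torch.ones(4)).perplexity()
    #   tensor(4.)  # up to float rounding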

    def _extended_shape(self, sample_shape=torch.Size()):
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`. Note that the batch and event shapes of a distribution
        instance are fixed at the time of construction. If this is empty, the
        returned shape is upcast to (1,).

        Args:
            sample_shape (torch.Size): the size of the sample to be drawn.
        """
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        return sample_shape + self._batch_shape + self._event_shape
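
    # Illustrative sketch (assumes torch.distributions.MultivariateNormal):
    # the returned shape is sample_shape + batch_shape + event_shape:
    #
    #   >>> m = torch.distributions.MultivariateNormal(torch.zeros(5, 3), torch.eye(3))
    #   >>> m._extended_shape(torch.Size((2,)))
    #   torch.Size([2, 5, 3])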

    def _validate_sample(self, value):
        """
        Argument validation for distribution methods such as `log_prob`,
        `cdf` and `icdf`. The rightmost dimensions of a value to be
        scored via these methods must agree with the distribution's batch
        and event shapes.

        Args:
            value (Tensor): the tensor whose log probability is to be
                computed by the `log_prob` method.

        Raises:
            ValueError: when the rightmost dimensions of `value` do not match the
                distribution's batch and event shapes.
        """
        if not isinstance(value, torch.Tensor):
            raise ValueError('The value argument to log_prob must be a Tensor')

        event_dim_start = len(value.size()) - len(self._event_shape)
        if value.size()[event_dim_start:] != self._event_shape:
            raise ValueError('The right-most size of value must match event_shape: {} vs {}.'.
                             format(value.size(), self._event_shape))

        actual_shape = value.size()
        expected_shape = self._batch_shape + self._event_shape
        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
            if i != 1 and j != 1 and i != j:
                raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
                                 format(actual_shape, expected_shape))

        if not self.support.check(value).all():
            raise ValueError('The value argument must be within the support')
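
    # Illustrative sketch (assumes torch.distributions.MultivariateNormal
    # with validate_args=True, so log_prob routes through _validate_sample):
    # a value whose rightmost dimensions disagree with event_shape is
    # rejected before any density computation:
    #
    #   >>> m = torch.distributions.MultivariateNormal(
    #   ...     torch.zeros(3), torch.eye(3), validate_args=True)
    #   >>> m.log_prob(torch.zeros(2))
    #   ValueError: The right-most size of value must match event_shape: torch.Size([2]) vs torch.Size([3]).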

    def _get_checked_instance(self, cls, _instance=None):
        if _instance is None and type(self).__init__ != cls.__init__:
            raise NotImplementedError("Subclass {} of {} that defines a custom __init__ method "
                                      "must also define a custom .expand() method.".
                                      format(self.__class__.__name__, cls.__name__))
        return self.__new__(type(self)) if _instance is None else _instance

    def __repr__(self):
        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
        args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
                                 if self.__dict__[p].numel() == 1
                                 else self.__dict__[p].size()) for p in param_names])
        return self.__class__.__name__ + '(' + args_string + ')'
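
# Illustrative sketch (assumes torch.distributions.Normal): per the code
# above, __repr__ prints single-element parameters by value and larger
# parameter tensors by size:
#
#   >>> torch.distributions.Normal(torch.zeros(3), torch.ones(3))
#   Normal(loc: torch.Size([3]), scale: torch.Size([3]))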