Caffe2 - Python API
A deep learning, cross platform ML framework
distribution.py
1 import torch
2 import warnings
3 from torch.distributions import constraints
4 from torch.distributions.utils import lazy_property
5 
6 
class Distribution(object):
    r"""
    Distribution is the abstract base class for probability distributions.
    """

    has_rsample = False
    has_enumerate_support = False
    _validate_args = False
    # NOTE: `support` and `arg_constraints` are shadowed by the same-named
    # properties defined below; subclasses typically override one or the other.
    support = None
    arg_constraints = {}

    @staticmethod
    def set_default_validate_args(value):
        """
        Sets whether argument validation is enabled or disabled by default
        for all Distribution instances.

        Args:
            value (bool): whether to enable validation.

        Raises:
            ValueError: if `value` is not a bool.
        """
        if value not in [True, False]:
            # Fix: the original raised a bare ValueError with no message,
            # giving the caller no hint about what went wrong.
            raise ValueError("Expected True or False, but got {}".format(value))
        Distribution._validate_args = value

    def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_args=None):
        """
        Args:
            batch_shape (torch.Size): shape over which parameters are batched.
            event_shape (torch.Size): shape of a single sample (without batching).
            validate_args (bool, optional): if given, overrides the class-level
                default for argument validation.

        Raises:
            ValueError: if validation is enabled and a parameter violates its
                constraint from `arg_constraints`.
        """
        self._batch_shape = batch_shape
        self._event_shape = event_shape
        if validate_args is not None:
            self._validate_args = validate_args
        if self._validate_args:
            for param, constraint in self.arg_constraints.items():
                if constraints.is_dependent(constraint):
                    continue  # skip constraints that cannot be checked
                if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):
                    continue  # skip checking lazily-constructed args
                if not constraint.check(getattr(self, param)).all():
                    raise ValueError("The parameter {} has invalid values".format(param))
        super(Distribution, self).__init__()

    def expand(self, batch_shape, _instance=None):
        """
        Returns a new distribution instance (or populates an existing instance
        provided by a derived class) with batch dimensions expanded to
        `batch_shape`. This method calls :class:`~torch.Tensor.expand` on
        the distribution's parameters. As such, this does not allocate new
        memory for the expanded distribution instance. Additionally,
        this does not repeat any args checking or parameter broadcasting in
        `__init__.py`, when an instance is first created.

        Args:
            batch_shape (torch.Size): the desired expanded size.
            _instance: new instance provided by subclasses that
                need to override `.expand`.

        Returns:
            New distribution instance with batch dimensions expanded to
            `batch_size`.
        """
        raise NotImplementedError

    @property
    def batch_shape(self):
        """
        Returns the shape over which parameters are batched.
        """
        return self._batch_shape

    @property
    def event_shape(self):
        """
        Returns the shape of a single sample (without batching).
        """
        return self._event_shape

    @property
    def arg_constraints(self):
        """
        Returns a dictionary from argument names to
        :class:`~torch.distributions.constraints.Constraint` objects that
        should be satisfied by each argument of this distribution. Args that
        are not tensors need not appear in this dict.
        """
        raise NotImplementedError

    @property
    def support(self):
        """
        Returns a :class:`~torch.distributions.constraints.Constraint` object
        representing this distribution's support.
        """
        raise NotImplementedError

    @property
    def mean(self):
        """
        Returns the mean of the distribution.
        """
        raise NotImplementedError

    @property
    def variance(self):
        """
        Returns the variance of the distribution.
        """
        raise NotImplementedError

    @property
    def stddev(self):
        """
        Returns the standard deviation of the distribution.
        """
        return self.variance.sqrt()

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched.
        """
        # Delegate to the reparameterized sampler but detach from autograd.
        with torch.no_grad():
            return self.rsample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched.
        """
        raise NotImplementedError

    def sample_n(self, n):
        """
        Generates n samples or n batches of samples if the distribution
        parameters are batched.

        .. deprecated:: use :meth:`sample` with a `(n,)` shape instead.
        """
        warnings.warn('sample_n will be deprecated. Use .sample((n,)) instead', UserWarning)
        return self.sample(torch.Size((n,)))

    def log_prob(self, value):
        """
        Returns the log of the probability density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def cdf(self, value):
        """
        Returns the cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def icdf(self, value):
        """
        Returns the inverse cumulative density/mass function evaluated at
        `value`.

        Args:
            value (Tensor):
        """
        raise NotImplementedError

    def enumerate_support(self, expand=True):
        """
        Returns tensor containing all values supported by a discrete
        distribution. The result will enumerate over dimension 0, so the shape
        of the result will be `(cardinality,) + batch_shape + event_shape`
        (where `event_shape = ()` for univariate distributions).

        Note that this enumerates over all batched tensors in lock-step
        `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens
        along dim 0, but with the remaining batch dimensions being
        singleton dimensions, `[[0], [1], ..`.

        To iterate over the full Cartesian product use
        `itertools.product(m.enumerate_support())`.

        Args:
            expand (bool): whether to expand the support over the
                batch dims to match the distribution's `batch_shape`.

        Returns:
            Tensor iterating over dimension 0.
        """
        raise NotImplementedError

    def entropy(self):
        """
        Returns entropy of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        raise NotImplementedError

    def perplexity(self):
        """
        Returns perplexity of distribution, batched over batch_shape.

        Returns:
            Tensor of shape batch_shape.
        """
        # Perplexity is defined as exp(entropy).
        return torch.exp(self.entropy())

    def _extended_shape(self, sample_shape=torch.Size()):
        """
        Returns the size of the sample returned by the distribution, given
        a `sample_shape`. Note, that the batch and event shapes of a distribution
        instance are fixed at the time of construction. If this is empty, the
        returned shape is upcast to (1,).

        Args:
            sample_shape (torch.Size): the size of the sample to be drawn.
        """
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        return sample_shape + self._batch_shape + self._event_shape

    def _validate_sample(self, value):
        """
        Argument validation for distribution methods such as `log_prob`,
        `cdf` and `icdf`. The rightmost dimensions of a value to be
        scored via these methods must agree with the distribution's batch
        and event shapes.

        Args:
            value (Tensor): the tensor whose log probability is to be
                computed by the `log_prob` method.

        Raises:
            ValueError: when the rightmost dimensions of `value` do not match
                the distribution's batch and event shapes, or when `value`
                lies outside the distribution's support.
        """
        if not isinstance(value, torch.Tensor):
            raise ValueError('The value argument to log_prob must be a Tensor')

        # The trailing dims of `value` must equal event_shape exactly.
        event_dim_start = len(value.size()) - len(self._event_shape)
        if value.size()[event_dim_start:] != self._event_shape:
            raise ValueError('The right-most size of value must match event_shape: {} vs {}.'.
                             format(value.size(), self._event_shape))

        # The remaining dims must be broadcastable with batch_shape + event_shape;
        # walk both shapes right-to-left applying the broadcasting rule.
        actual_shape = value.size()
        expected_shape = self._batch_shape + self._event_shape
        for i, j in zip(reversed(actual_shape), reversed(expected_shape)):
            if i != 1 and j != 1 and i != j:
                raise ValueError('Value is not broadcastable with batch_shape+event_shape: {} vs {}.'.
                                 format(actual_shape, expected_shape))

        if not self.support.check(value).all():
            raise ValueError('The value argument must be within the support')

    def _get_checked_instance(self, cls, _instance=None):
        # Guard for the default `.expand` machinery: a subclass with a custom
        # __init__ cannot safely be re-created via __new__ alone.
        if _instance is None and type(self).__init__ != cls.__init__:
            raise NotImplementedError("Subclass {} of {} that defines a custom __init__ method "
                                      "must also define a custom .expand() method.".
                                      format(self.__class__.__name__, cls.__name__))
        return self.__new__(type(self)) if _instance is None else _instance

    def __repr__(self):
        # Show each constrained parameter: its value when scalar, else its shape.
        param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__]
        args_string = ', '.join(['{}: {}'.format(p, self.__dict__[p]
                                if self.__dict__[p].numel() == 1
                                else self.__dict__[p].size()) for p in param_names])
        return self.__class__.__name__ + '(' + args_string + ')'
def rsample(self, sample_shape=torch.Size())
def expand(self, batch_shape, _instance=None)
Definition: distribution.py:39
def __init__(self, batch_shape=torch.Size(), event_shape=torch.Size(), validate_args=None)
Definition: distribution.py:24
def sample(self, sample_shape=torch.Size())