multivariate_normal.py
1 import math
2
3 import torch
4 from torch.distributions import constraints
5 from torch.distributions.distribution import Distribution
6 from torch.distributions.utils import _standard_normal, lazy_property
7
8
def _batch_mv(bmat, bvec):
    r"""
    Performs a batched matrix-vector product, with compatible but different batch shapes.

    This function takes as input `bmat`, containing :math:`n \times n` matrices, and
    `bvec`, containing length :math:`n` vectors.

    Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
    to a batch shape. They are not necessarily assumed to have the same batch shape,
    just ones which can be broadcasted.

    Returns:
        Tensor: the products, with the broadcast batch shape followed by a
        trailing dimension of length :math:`n`.
    """
    # Promote each vector to a column so matmul broadcasts the batch dimensions,
    # then drop the trailing singleton column dimension again.
    return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
21
22
def _solve_lower_triangular(flat_L, flat_rhs):
    # Solve ``flat_L @ Y = flat_rhs`` for lower-triangular ``flat_L`` with
    # whichever triangular solver this PyTorch build exposes.
    linalg = getattr(torch, "linalg", None)
    if linalg is not None and hasattr(linalg, "solve_triangular"):
        return linalg.solve_triangular(flat_L, flat_rhs, upper=False)
    # The legacy solvers return a (solution, cloned coefficient) pair, so the
    # solution has to be picked out before any tensor method is applied.
    if hasattr(torch, "triangular_solve"):
        return torch.triangular_solve(flat_rhs, flat_L, upper=False)[0]
    return torch.trtrs(flat_rhs, flat_L, upper=False)[0]


def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both `bL` and `bx`. They are not necessarily assumed to have
    the same batch shape, but the batch shape of `bL` should be broadcastable to that
    of `bx`.

    Args:
        bL (Tensor): lower-triangular factors; the last two dimensions are the matrix.
        bx (Tensor): vectors; the last dimension is the vector.

    Returns:
        Tensor: squared distances with shape ``bx.shape[:-1]``.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) so that a single
    # flattened triangular solve can be applied.
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n): every batch dimension shared
    # with bL is split into a (ratio, bL size) pair.
    bx_new_shape = bx.shape[:outer_batch_dims]
    for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n): all ratio dimensions
    # first, then all bL-sized dimensions, then the vector dimension.
    permute_dims = (list(range(outer_batch_dims)) +
                    list(range(outer_batch_dims, new_batch_dims, 2)) +
                    list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
                    [new_batch_dims])
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    # Solving L y = x gives y = L^{-1} x, and ||y||^2 is the squared distance.
    solved = _solve_lower_triangular(flat_L, flat_x_swap)  # shape = b x n x c
    M_swap = solved.pow(2).sum(-2)  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        # Re-interleave each ratio dimension with its matching bL-sized dimension.
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
67
68
class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via e.g. Cholesky decomposition of the covariance.

    Example:

        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.

        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
    """
    # Constraint attached to each constructor argument.
    arg_constraints = {'loc': constraints.real_vector,
                       'covariance_matrix': constraints.positive_definite,
                       'precision_matrix': constraints.positive_definite,
                       'scale_tril': constraints.lower_cholesky}
    support = constraints.real
    # Reparameterized sampling is available through ``rsample``.
    has_rsample = True
109
110  def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
111  if loc.dim() < 1:
112  raise ValueError("loc must be at least one-dimensional.")
113  if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
114  raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")
115
116  loc_ = loc.unsqueeze(-1) # temporarily add dim on right
117  if scale_tril is not None:
118  if scale_tril.dim() < 2:
119  raise ValueError("scale_tril matrix must be at least two-dimensional, "
120  "with optional leading batch dimensions")
121  self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
122  elif covariance_matrix is not None:
123  if covariance_matrix.dim() < 2:
124  raise ValueError("covariance_matrix must be at least two-dimensional, "
125  "with optional leading batch dimensions")
126  self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_)
127  else:
128  if precision_matrix.dim() < 2:
129  raise ValueError("precision_matrix must be at least two-dimensional, "
130  "with optional leading batch dimensions")
131  self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_)
132  self.loc = loc_[..., 0] # drop rightmost dim
133
134  batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
135  super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
136
137  if scale_tril is not None:
139  else:
140  if precision_matrix is not None:
141  self.covariance_matrix = torch.inverse(precision_matrix).expand_as(loc_)
143
144  def expand(self, batch_shape, _instance=None):
145  new = self._get_checked_instance(MultivariateNormal, _instance)
146  batch_shape = torch.Size(batch_shape)
147  loc_shape = batch_shape + self.event_shape
148  cov_shape = batch_shape + self.event_shape + self.event_shape
149  new.loc = self.loc.expand(loc_shape)
151  if 'covariance_matrix' in self.__dict__:
152  new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
153  if 'scale_tril' in self.__dict__:
154  new.scale_tril = self.scale_tril.expand(cov_shape)
155  if 'precision_matrix' in self.__dict__:
156  new.precision_matrix = self.precision_matrix.expand(cov_shape)
157  super(MultivariateNormal, new).__init__(batch_shape,
158  self.event_shape,
159  validate_args=False)
160  new._validate_args = self._validate_args
161  return new
162
163  @lazy_property
164  def scale_tril(self):
166  self._batch_shape + self._event_shape + self._event_shape)
167
168  @lazy_property
169  def covariance_matrix(self):
172  .expand(self._batch_shape + self._event_shape + self._event_shape))
173
174  @lazy_property
175  def precision_matrix(self):
176  # TODO: use torch.potri on scale_tril once a backwards pass is implemented.
179  self._batch_shape + self._event_shape + self._event_shape)
180
181  @property
182  def mean(self):
183  return self.loc
184
185  @property
186  def variance(self):
188  self._batch_shape + self._event_shape)
189
190  def rsample(self, sample_shape=torch.Size()):
191  shape = self._extended_shape(sample_shape)
192  eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
193  return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
194
195  def log_prob(self, value):
196  if self._validate_args:
197  self._validate_sample(value)
198  diff = value - self.loc