multivariate_normal.py
import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, lazy_property

def _batch_mv(bmat, bvec):
    r"""
    Performs a batched matrix-vector product, with compatible but different batch shapes.

    This function takes as input `bmat`, containing :math:`n \times n` matrices, and
    `bvec`, containing length :math:`n` vectors.

    Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
    to a batch shape. They need not have the same batch shape, only batch shapes
    that broadcast against each other.
    """
    return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
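
# A minimal usage sketch of `_batch_mv` (the shapes below are illustrative
# assumptions, not taken from the source): the batch dims of `bmat` and `bvec`
# broadcast, so a (5, 1)-batch of matrices applied to a (4,)-batch of vectors
# yields a (5, 4)-batch of result vectors.
# >>> bmat = torch.randn(5, 1, 3, 3)
# >>> bvec = torch.randn(4, 3)
# >>> _batch_mv(bmat, bvec).shape
# torch.Size([5, 4, 3])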


def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both `bL` and `bx`. They need not have the same batch
    shape, but the batch shape of `bL` must be broadcastable to that of `bx`.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply the
    # batched triangular solve below
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = bx.shape[:outer_batch_dims]
    for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (list(range(outer_batch_dims)) +
                    list(range(outer_batch_dims, new_batch_dims, 2)) +
                    list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
                    [new_batch_dims])
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    M_swap = torch.trtrs(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2)  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
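
# A small sanity-check sketch (the matrix and tolerance are illustrative
# assumptions): for an unbatched L and x, `_batch_mahalanobis(L, x)` should
# agree with the direct computation x^T (L L^T)^{-1} x.
# >>> L = torch.cholesky(torch.eye(3) * 2.0)
# >>> x = torch.randn(3)
# >>> M = _batch_mahalanobis(L, x)
# >>> direct = x @ torch.inverse(L @ L.t()) @ x
# >>> torch.allclose(M, direct, atol=1e-5)
# True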


class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via e.g. Cholesky decomposition of the covariance.

    Example:

        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix`, :attr:`precision_matrix`, or
        :attr:`scale_tril` can be specified.

        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower-triangular matrix using a Cholesky decomposition.
    """
    arg_constraints = {'loc': constraints.real_vector,
                       'covariance_matrix': constraints.positive_definite,
                       'precision_matrix': constraints.positive_definite,
                       'scale_tril': constraints.lower_cholesky}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
            raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")

        loc_ = loc.unsqueeze(-1)  # temporarily add dim on right
        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError("scale_tril matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
            self.scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError("covariance_matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
            self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_)
        else:
            if precision_matrix.dim() < 2:
                raise ValueError("precision_matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
            self.precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_)
        self.loc = loc_[..., 0]  # drop rightmost dim

        batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
        super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        else:
            if precision_matrix is not None:
                self.covariance_matrix = torch.inverse(precision_matrix).expand_as(loc_)
            self._unbroadcasted_scale_tril = torch.cholesky(self.covariance_matrix)

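    # The three parameterizations accepted above are interchangeable; a quick
    # sketch (the values are illustrative assumptions, not from the source):
    # >>> cov = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
    # >>> m1 = MultivariateNormal(torch.zeros(2), covariance_matrix=cov)
    # >>> m2 = MultivariateNormal(torch.zeros(2), precision_matrix=torch.inverse(cov))
    # >>> m3 = MultivariateNormal(torch.zeros(2), scale_tril=torch.cholesky(cov))
    # >>> x = torch.randn(2)
    # >>> torch.allclose(m1.log_prob(x), m3.log_prob(x), atol=1e-5)
    # True
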
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(MultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        cov_shape = batch_shape + self.event_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
        if 'covariance_matrix' in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if 'scale_tril' in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if 'precision_matrix' in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)
        super(MultivariateNormal, new).__init__(batch_shape,
                                                self.event_shape,
                                                validate_args=False)
        new._validate_args = self._validate_args
        return new

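    # `expand` broadcasts the distribution over a new batch shape without
    # copying parameters; a minimal sketch (shapes are illustrative assumptions):
    # >>> m = MultivariateNormal(torch.zeros(3), torch.eye(3))
    # >>> m.expand(torch.Size([4])).batch_shape
    # torch.Size([4])
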
    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        return (torch.matmul(self._unbroadcasted_scale_tril,
                             self._unbroadcasted_scale_tril.transpose(-1, -2))
                .expand(self._batch_shape + self._event_shape + self._event_shape))

    @lazy_property
    def precision_matrix(self):
        # TODO: use `torch.potri` on `scale_tril` once a backwards pass is implemented.
        scale_tril_inv = torch.inverse(self._unbroadcasted_scale_tril)
        return torch.matmul(scale_tril_inv.transpose(-1, -2), scale_tril_inv).expand(
            self._batch_shape + self._event_shape + self._event_shape)

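    # Whichever parameterization was given, the lazy properties recover the
    # others; a quick consistency check (values are illustrative assumptions):
    # >>> m = MultivariateNormal(torch.zeros(2), scale_tril=torch.cholesky(torch.eye(2) * 3.0))
    # >>> torch.allclose(m.precision_matrix @ m.covariance_matrix, torch.eye(2), atol=1e-5)
    # True
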
    @property
    def mean(self):
        return self.loc

    @property
    def variance(self):
        return self._unbroadcasted_scale_tril.pow(2).sum(-1).expand(
            self._batch_shape + self._event_shape)

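    # `variance` is the diagonal of the covariance matrix, computed here as the
    # row-wise squared norms of L, since diag(L L^T)_i = sum_j L_ij^2. A sketch
    # (the matrix is an illustrative assumption):
    # >>> m = MultivariateNormal(torch.zeros(2), torch.tensor([[2.0, 0.5], [0.5, 1.0]]))
    # >>> torch.allclose(m.variance, m.covariance_matrix.diagonal(dim1=-2, dim2=-1))
    # True
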
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)

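    # `rsample` uses the reparameterization trick: x = loc + L @ eps with
    # eps ~ N(0, I), so gradients flow back to the parameters. A minimal
    # sketch (shapes are illustrative assumptions):
    # >>> loc = torch.zeros(2, requires_grad=True)
    # >>> m = MultivariateNormal(loc, torch.eye(2))
    # >>> m.rsample(torch.Size([5])).shape
    # torch.Size([5, 2])
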
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
        half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det

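    # For the standard normal (L = I) the density reduces to
    # log p(x) = -0.5 * (d * log(2*pi) + x.x); a quick cross-check
    # (values are illustrative assumptions):
    # >>> m = MultivariateNormal(torch.zeros(3), torch.eye(3))
    # >>> x = torch.randn(3)
    # >>> expected = -0.5 * (3 * math.log(2 * math.pi) + x.dot(x))
    # >>> torch.allclose(m.log_prob(x), expected)
    # True
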
    def entropy(self):
        half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
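
    # For N(0, I) in d dimensions the entropy is 0.5 * d * (1 + log(2*pi));
    # a quick check (values are illustrative assumptions):
    # >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
    # >>> expected = 0.5 * 2 * (1.0 + math.log(2 * math.pi))
    # >>> torch.allclose(m.entropy(), torch.tensor(expected))
    # True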