Caffe2 - Python API
A deep learning, cross-platform ML framework
lowrank_multivariate_normal.py
import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property


def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.
    """
    m = W.size(-1)
    Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    K.view(-1, m * m)[:, ::m + 1] += 1  # add identity matrix to K
    return torch.cholesky(K)
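
# A minimal doctest-style sanity check (an editorial addition, not part of the
# original module): the result should match a dense Cholesky of the capacitance
# matrix I + W.T @ inv(D) @ W. W and D here are illustrative names for a single
# 3 x 2 factor and a length-3 diagonal:
#
#     >>> W = torch.randn(3, 2)
#     >>> D = torch.rand(3) + 0.5
#     >>> C = torch.eye(2) + W.t() @ torch.diag(1 / D) @ W
#     >>> torch.allclose(_batch_capacitance_tril(W, D), torch.cholesky(C))
#     True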

def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses "matrix determinant lemma"::

        log|W @ W.T + D| = log|C| + log|D|,

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
    """
    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1)
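
# Hedged sanity check (editorial addition): for a small dense matrix the lemma
# should agree with torch.logdet up to floating-point error.
#
#     >>> W, D = torch.randn(3, 2), torch.rand(3) + 0.5
#     >>> lhs = torch.logdet(W @ W.t() + torch.diag(D))
#     >>> rhs = _batch_lowrank_logdet(W, D, _batch_capacitance_tril(W, D))
#     >>> torch.allclose(lhs, rhs)
#     True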

def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses "Woodbury matrix identity"::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2
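
# Editorial sketch, not in the original source: the Woodbury-based distance
# should match a direct solve against the dense covariance.
#
#     >>> W, D, x = torch.randn(3, 2), torch.rand(3) + 0.5, torch.randn(3)
#     >>> dense = x @ torch.inverse(W @ W.t() + torch.diag(D)) @ x
#     >>> lowrank = _batch_lowrank_mahalanobis(W, D, x, _batch_capacitance_tril(W, D))
#     >>> torch.allclose(dense, lowrank)
#     True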

class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:

        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.tensor([1., 1.]))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """
    arg_constraints = {"loc": constraints.real,
                       "cov_factor": constraints.real,
                       "cov_diag": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError("cov_factor must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
                             .format(event_shape[0]))
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))

        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
        except RuntimeError:
            raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
                             .format(loc.shape, cov_factor.shape, cov_diag.shape))
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape,
                                                        validate_args=validate_args)
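
    # Illustrative usage (editorial addition): batch and event shapes are
    # inferred from the broadcast parameter shapes.
    #
    #     >>> d = LowRankMultivariateNormal(torch.zeros(4, 3),
    #     ...                               torch.randn(4, 3, 2),
    #     ...                               torch.ones(4, 3))
    #     >>> d.batch_shape, d.event_shape
    #     (torch.Size([4]), torch.Size([3]))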

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(batch_shape,
                                                       self.event_shape,
                                                       validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @lazy_property
    def variance(self):
        return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
                + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)
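
    # Editorial note: this is the diagonal of W @ W.T + diag(D) computed without
    # forming the full matrix; reusing the illustrative `d` from the __init__
    # example above, it should match the diagonal of the dense covariance.
    #
    #     >>> torch.allclose(d.variance,
    #     ...                d.covariance_matrix.diagonal(dim1=-2, dim2=-1))
    #     True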

    @lazy_property
    def scale_tril(self):
        # The following identity is used to improve the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^(1/2) @ (I + D^(-1/2) @ W @ W.T @ D^(-1/2)) @ D^(1/2)
        # The matrix "I + D^(-1/2) @ W @ W.T @ D^(-1/2)" has eigenvalues bounded from
        # below by 1, hence it is well-conditioned and safe to take the Cholesky
        # decomposition of.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
        K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
        return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)
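
    # Editorial check: scale_tril is lower-triangular with
    # scale_tril @ scale_tril.T equal to the full covariance matrix
    # (continuing the illustrative `d` from above).
    #
    #     >>> L = d.scale_tril
    #     >>> torch.allclose(L @ L.transpose(-1, -2), d.covariance_matrix)
    #     True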

    @lazy_property
    def covariance_matrix(self):
        covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
                                          self._unbroadcasted_cov_factor.transpose(-1, -2))
                             + torch.diag_embed(self._unbroadcasted_cov_diag))
        return covariance_matrix.expand(self._batch_shape + self._event_shape +
                                        self._event_shape)

    @lazy_property
    def precision_matrix(self):
        # We use "Woodbury matrix identity" to take advantage of low rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2)
                   / self._unbroadcasted_cov_diag.unsqueeze(-2))
        A = torch.trtrs(Wt_Dinv, self._capacitance_tril, upper=False)[0]
        precision_matrix = (torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal())
                            - torch.matmul(A.transpose(-1, -2), A))
        return precision_matrix.expand(self._batch_shape + self._event_shape +
                                       self._event_shape)
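
    # Editorial check: the Woodbury-based precision times the covariance should
    # be numerically close to the identity (again using `d` from above).
    #
    #     >>> P = d.precision_matrix
    #     >>> torch.allclose(P @ d.covariance_matrix,
    #     ...                torch.eye(3).expand_as(P), atol=1e-5)
    #     True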

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
                + self._unbroadcasted_cov_diag.sqrt() * eps_D)
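
    # Editorial note: this is the reparameterized sampler
    #     x = loc + W @ eps_W + sqrt(D) * eps_D,   eps_W, eps_D ~ N(0, I),
    # so Cov[x] = W @ W.T + diag(D) and gradients flow through loc, W and D.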

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
                                       self._unbroadcasted_cov_diag,
                                       diff,
                                       self._capacitance_tril)
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)
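
    # Editorial cross-check against the dense implementation (MultivariateNormal
    # from torch.distributions; `d` is the illustrative instance from above):
    #
    #     >>> from torch.distributions import MultivariateNormal
    #     >>> full = MultivariateNormal(d.loc, covariance_matrix=d.covariance_matrix)
    #     >>> x = d.sample()
    #     >>> torch.allclose(d.log_prob(x), full.log_prob(x), atol=1e-5)
    #     True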

    def entropy(self):
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
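
    # Editorial check: the closed form 0.5 * (n * (1 + log(2*pi)) + log|Sigma|)
    # should agree with the dense MultivariateNormal entropy (`full` as in the
    # log_prob example above).
    #
    #     >>> torch.allclose(d.entropy(), full.entropy(), atol=1e-5)
    #     True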