import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property


def _batch_capacitance_tril(W, D):
    r"""
    Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W`
    and a batch of vectors :math:`D`.
    """
    m = W.size(-1)
    Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    K.view(-1, m * m)[:, ::m + 1] += 1  # add an identity matrix to K in-place
    return torch.cholesky(K)
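

# Illustrative sanity check, not part of the original module: verifies that the
# Cholesky factor returned by `_batch_capacitance_tril` reproduces the capacitance
# matrix I + W.T @ inv(D) @ W assembled directly. Shapes and tolerances here are
# arbitrary choices for the sketch.
def _check_batch_capacitance_tril():
    torch.manual_seed(0)
    W = torch.randn(5, 3, 2)    # batch of 5 factors: event size 3, rank 2
    D = torch.rand(5, 3) + 0.1  # batch of strictly positive diagonals
    L = _batch_capacitance_tril(W, D)
    C = torch.eye(2) + W.transpose(-1, -2) @ (W / D.unsqueeze(-1))
    assert torch.allclose(L @ L.transpose(-1, -2), C, atol=1e-6)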


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses "matrix determinant lemma"::
        log|W @ W.T + D| = log|C| + log|D|,
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
    """
    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1)
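

# Illustrative sanity check, not part of the original module: the lemma-based
# log-determinant should agree with `torch.logdet` applied to the dense matrix
# W @ W.T + diag(D) assembled explicitly.
def _check_batch_lowrank_logdet():
    torch.manual_seed(0)
    W = torch.randn(5, 3, 2)
    D = torch.rand(5, 3) + 0.1
    log_det = _batch_lowrank_logdet(W, D, _batch_capacitance_tril(W, D))
    dense = W @ W.transpose(-1, -2) + torch.diag_embed(D)
    assert torch.allclose(log_det, torch.logdet(dense), atol=1e-5)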


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses "Woodbury matrix identity"::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared
    Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2
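

# Illustrative sanity check, not part of the original module: the Woodbury-based
# squared Mahalanobis distance should agree with x.T @ inv(W @ W.T + diag(D)) @ x
# computed through a dense inverse.
def _check_batch_lowrank_mahalanobis():
    torch.manual_seed(0)
    W = torch.randn(5, 3, 2)
    D = torch.rand(5, 3) + 0.1
    x = torch.randn(5, 3)
    M = _batch_lowrank_mahalanobis(W, D, x, _batch_capacitance_tril(W, D))
    dense_inv = torch.inverse(W @ W.transpose(-1, -2) + torch.diag_embed(D))
    expected = (x.unsqueeze(-2) @ dense_inv @ x.unsqueeze(-1)).squeeze(-1).squeeze(-1)
    assert torch.allclose(M, expected, atol=1e-5)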


class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:

        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.tensor([1., 1.]))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we just need to compute the determinant and inverse of
        the small-size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """
    arg_constraints = {
"loc": constraints.real,
75 "cov_factor": constraints.real,
76 "cov_diag": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError("cov_factor must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
                             .format(event_shape[0]))
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))

        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
        except RuntimeError:
            raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
                             .format(loc.shape, cov_factor.shape, cov_diag.shape))
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape,
                                                        validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(batch_shape,
                                                       self.event_shape,
                                                       validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @lazy_property
    def variance(self):
        return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
                + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to increase the numerical stability of the
        # Cholesky decomposition (see GPML, Section 3.4.3):
        #     W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2
        # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below
        # by 1, so it is well-conditioned and safe to Cholesky-decompose.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
        K.view(-1, n * n)[:, ::n + 1] += 1  # add an identity matrix to K in-place
        scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
        return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
                                          self._unbroadcasted_cov_factor.transpose(-1, -2))
                             + torch.diag_embed(self._unbroadcasted_cov_diag))
        return covariance_matrix.expand(self._batch_shape + self._event_shape +
                                        self._event_shape)

    @lazy_property
    def precision_matrix(self):
        # Uses the "Woodbury matrix identity" to exploit the low-rank form:
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
        # where C is the capacitance matrix.
        Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2)
                   / self._unbroadcasted_cov_diag.unsqueeze(-2))
        A = torch.triangular_solve(Wt_Dinv, self._capacitance_tril, upper=False)[0]
        precision_matrix = (torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal())
                            - torch.matmul(A.transpose(-1, -2), A))
        return precision_matrix.expand(self._batch_shape + self._event_shape +
                                       self._event_shape)

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
                + self._unbroadcasted_cov_diag.sqrt() * eps_D)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
                                       self._unbroadcasted_cov_diag,
                                       diff,
                                       self._capacitance_tril)
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
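

# Minimal usage sketch, not part of the original module: builds a small
# low-rank distribution and compares its log_prob against the dense
# torch.distributions.MultivariateNormal given the assembled covariance.
if __name__ == "__main__":
    torch.manual_seed(0)
    loc = torch.zeros(3)
    cov_factor = torch.randn(3, 2)
    cov_diag = torch.rand(3) + 0.1
    low_rank = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
    dense = torch.distributions.MultivariateNormal(
        loc, covariance_matrix=cov_factor @ cov_factor.T + torch.diag(cov_diag))
    value = low_rank.rsample()
    print(low_rank.log_prob(value).item(), dense.log_prob(value).item())  # should agree closely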