9 def _batch_mv(bmat, bvec):
11 Performs a batched matrix-vector product, with compatible but different batch shapes. 13 This function takes as input `bmat`, containing :math:`n \times n` matrices, and 14 `bvec`, containing length :math:`n` vectors. 16 Both `bmat` and `bvec` may have any number of leading dimensions, which correspond 17 to a batch shape. They are not necessarily assumed to have the same batch shape, 18 just ones which can be broadcasted. 20 return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
23 def _batch_mahalanobis(bL, bx):
25 Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}` 26 for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`. 28 Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch 29 shape, but `bL` one should be able to broadcasted to `bx` one. 32 bx_batch_shape = bx.shape[:-1]
36 bx_batch_dims = len(bx_batch_shape)
37 bL_batch_dims = bL.dim() - 2
38 outer_batch_dims = bx_batch_dims - bL_batch_dims
39 old_batch_dims = outer_batch_dims + bL_batch_dims
40 new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
42 bx_new_shape = bx.shape[:outer_batch_dims]
43 for (sL, sx)
in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
44 bx_new_shape += (sx // sL, sL)
46 bx = bx.reshape(bx_new_shape)
48 permute_dims = (list(range(outer_batch_dims)) +
49 list(range(outer_batch_dims, new_batch_dims, 2)) +
50 list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
52 bx = bx.permute(permute_dims)
54 flat_L = bL.reshape(-1, n, n)
55 flat_x = bx.reshape(-1, flat_L.size(0), n)
56 flat_x_swap = flat_x.permute(1, 2, 0)
57 M_swap = torch.trtrs(flat_x_swap, flat_L, upper=
False)[0].pow(2).sum(-2)
61 permuted_M = M.reshape(bx.shape[:-1])
62 permute_inv_dims = list(range(outer_batch_dims))
63 for i
in range(bL_batch_dims):
64 permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
65 reshaped_M = permuted_M.permute(permute_inv_dims)
66 return reshaped_M.reshape(bx_batch_shape)
71 Creates a multivariate normal (also called Gaussian) distribution 72 parameterized by a mean vector and a covariance matrix. 74 The multivariate normal distribution can be parameterized either 75 in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}` 76 or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}` 77 or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued 78 diagonal entries, such that 79 :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix 80 can be obtained via e.g. Cholesky decomposition of the covariance. 84 >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2)) 85 >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I` 86 tensor([-0.2102, -0.5429]) 89 loc (Tensor): mean of the distribution 90 covariance_matrix (Tensor): positive-definite covariance matrix 91 precision_matrix (Tensor): positive-definite precision matrix 92 scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal 95 Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or 96 :attr:`scale_tril` can be specified. 98 Using :attr:`scale_tril` will be more efficient: all computations internally 99 are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or 100 :attr:`precision_matrix` is passed instead, it is only used to compute 101 the corresponding lower triangular matrices using a Cholesky decomposition. 103 arg_constraints = {
'loc': constraints.real_vector,
104 'covariance_matrix': constraints.positive_definite,
105 'precision_matrix': constraints.positive_definite,
106 'scale_tril': constraints.lower_cholesky}
107 support = constraints.real
110 def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
112 raise ValueError(
"loc must be at least one-dimensional.")
113 if (covariance_matrix
is not None) + (scale_tril
is not None) + (precision_matrix
is not None) != 1:
114 raise ValueError(
"Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")
116 loc_ = loc.unsqueeze(-1)
117 if scale_tril
is not None:
118 if scale_tril.dim() < 2:
119 raise ValueError(
"scale_tril matrix must be at least two-dimensional, " 120 "with optional leading batch dimensions")
121 self.
scale_tril, loc_ = torch.broadcast_tensors(scale_tril, loc_)
122 elif covariance_matrix
is not None:
123 if covariance_matrix.dim() < 2:
124 raise ValueError(
"covariance_matrix must be at least two-dimensional, " 125 "with optional leading batch dimensions")
128 if precision_matrix.dim() < 2:
129 raise ValueError(
"precision_matrix must be at least two-dimensional, " 130 "with optional leading batch dimensions")
131 self.
precision_matrix, loc_ = torch.broadcast_tensors(precision_matrix, loc_)
132 self.
loc = loc_[..., 0]
134 batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]
135 super(MultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=validate_args)
137 if scale_tril
is not None:
140 if precision_matrix
is not None:
144 def expand(self, batch_shape, _instance=None):
146 batch_shape = torch.Size(batch_shape)
149 new.loc = self.loc.expand(loc_shape)
151 if 'covariance_matrix' in self.__dict__:
152 new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
153 if 'scale_tril' in self.__dict__:
154 new.scale_tril = self.scale_tril.expand(cov_shape)
155 if 'precision_matrix' in self.__dict__:
156 new.precision_matrix = self.precision_matrix.expand(cov_shape)
157 super(MultivariateNormal, new).__init__(batch_shape,
164 def scale_tril(self):
165 return self._unbroadcasted_scale_tril.expand(
169 def covariance_matrix(self):
171 self._unbroadcasted_scale_tril.transpose(-1, -2))
175 def precision_matrix(self):
178 return torch.matmul(scale_tril_inv.transpose(-1, -2), scale_tril_inv).expand(
187 return self._unbroadcasted_scale_tril.pow(2).sum(-1).expand(
190 def rsample(self, sample_shape=torch.Size()):
192 eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
195 def log_prob(self, value):
198 diff = value - self.
loc 200 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
201 return -0.5 * (self.
_event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
204 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
205 H = 0.5 * self.
_event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())
def precision_matrix(self)
def _validate_sample(self, value)
_unbroadcasted_scale_tril