import math
import warnings
from functools import total_ordering

import torch
from torch._six import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential import Exponential
from .exp_family import ExponentialFamily
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
                                          _batch_lowrank_mahalanobis)
from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis)
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost

_KL_REGISTRY = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE = {}  # Memoized version mapping many specific (type, type) pairs to functions.

def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type,type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is raised. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
    """
    if not isinstance(type_p, type) or not issubclass(type_p, Distribution):
        raise TypeError('Expected type_p to be a Distribution subclass but got {}'.format(type_p))
    if not isinstance(type_q, type) or not issubclass(type_q, Distribution):
        raise TypeError('Expected type_q to be a Distribution subclass but got {}'.format(type_q))

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator

@total_ordering
class _Match(object):
    __slots__ = ['types']

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        return self.types == other.types

    def __le__(self, other):
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True

def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.
    """
    matches = [(super_p, super_q) for super_p, super_q in _KL_REGISTRY
               if issubclass(type_p, super_p) and issubclass(type_q, super_q)]
    if not matches:
        return NotImplemented
    # Check that the left- and right-lexicographic orders agree.
    left_p, left_q = min(_Match(*m) for m in matches).types
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn('Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.format(
            type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__),
            RuntimeWarning)
    return left_fun

def _infinite_like(tensor):
    """
    Helper function for obtaining an infinite KL divergence throughout.
    """
    return tensor.new_tensor(inf).expand_as(tensor)


def _x_log_x(tensor):
    """
    Utility function for calculating x log x.
    """
    return tensor * tensor.log()


def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary trailing batch dimensions.
    """
    n = bmat.size(-1)
    m = bmat.size(-2)
    flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
    return flat_trace.reshape(bmat.shape[:-2])

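# The helper above relies on the identity tr(X @ X.T) = sum of the squared
# entries of X (the squared Frobenius norm), which avoids materializing
# X @ X.T. A minimal sanity-check sketch (not part of the module):
#
#   x = torch.randn(5, 4, 3)
#   direct = (x @ x.transpose(-1, -2)).diagonal(dim1=-2, dim2=-1).sum(-1)
#   assert torch.allclose(_batch_trace_XXT(x), direct)
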
def kl_divergence(p, q):
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        fun = _dispatch_kl(type(p), type(q))
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError
    return fun(p, q)

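# Usage sketch for kl_divergence (uses the Normal pair registered below;
# not part of the module):
#
#   p = Normal(torch.zeros(3), torch.ones(3))
#   q = Normal(torch.ones(3), 2 * torch.ones(3))
#   kl_divergence(p, q)  # tensor of shape (3,): one KL value per batch element
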
_euler_gamma = 0.57721566490153286060

@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    t1 = p.probs * (p.probs / q.probs).log()
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * ((1 - p.probs) / (1 - q.probs)).log()
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2

@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5

@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    if (p.total_count < q.total_count).any():
        raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
    kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl

@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    t = p.probs * (p.logits - q.logits)
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)

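# The two masked assignments above encode the conventions 0 * log(0/q) = 0
# and p * log(p/0) = inf. Illustrative sketch (not part of the module):
#
#   p = Categorical(torch.tensor([0.5, 0.5, 0.0]))
#   q = Categorical(torch.tensor([0.5, 0.25, 0.25]))
#   kl_divergence(p, q)  # finite: p's support is contained in q's
#   kl_divergence(q, p)  # inf: q puts mass where p has none
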
@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)

@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1

@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    if not type(p) == type(q):
        raise NotImplementedError("The cross KL-divergence between different exponential "
                                  "families cannot be computed using Bregman divergences")
    p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal.clone()
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        result -= _sum_rightmost(term, len(q.event_shape))
    return result

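# The loop above evaluates the Bregman divergence of the log-normalizer F
# between natural parameters:
#   KL(p || q) = F(eta_q) - F(eta_p) - <eta_q - eta_p, grad F(eta_p)>,
# which is why autograd gradients of F at p's natural parameters appear.
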
@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4

@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)

@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits

@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    return _kl_normal_normal(p.base_dist, q.base_dist)

@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1

@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two Low Rank Multivariate Normals with "
                         "different event shapes cannot be computed")

    term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                   q._capacitance_tril) -
             _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
                                   p._capacitance_tril))
    term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                       q.loc - p.loc,
                                       q._capacitance_tril)
    # Expands term2 = tr(inv(qcov) @ pcov) according to
    # inv(qcov) @ pcov = [inv(qD) - A.t() @ A] @ (pW @ pW.t() + pD)
    qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                 q._unbroadcasted_cov_diag.unsqueeze(-2))
    A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
                              q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])

@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with "
                         "different event shapes cannot be computed")

    term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                   q._capacitance_tril) -
             2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
    term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                       q.loc - p.loc,
                                       q._capacitance_tril)
    # Expands term2 = tr(inv(qcov) @ pcov) according to
    # inv(qcov) @ pcov = [inv(qD) - A.t() @ A] @ p_tril @ p_tril.t()
    qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                 q._unbroadcasted_cov_diag.unsqueeze(-2))
    A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
    term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
                              q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])

@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with "
                         "different event shapes cannot be computed")

    term1 = (2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
             _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
                                   p._capacitance_tril))
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 = tr(inv(qcov) @ pcov) according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.t()) @ (pW @ pW.t() + pD)
    combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
                                                p._unbroadcasted_cov_factor.shape[:-2])
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(combined_batch_shape +
                                                      (n, p.cov_factor.size(-1)))
    p_cov_diag = (torch.diag_embed(p._unbroadcasted_cov_diag.sqrt())
                  .expand(combined_batch_shape + (n, n)))
    term21 = _batch_trace_XXT(torch.trtrs(p_cov_factor, q_scale_tril, upper=False)[0])
    term22 = _batch_trace_XXT(torch.trtrs(p_cov_diag, q_scale_tril, upper=False)[0])
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])

@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    # KL(p || q) = 0.5 * (log|Sigma_q|/|Sigma_p| - n + tr(inv(Sigma_q) Sigma_p)
    #              + (mu_q - mu_p)^T inv(Sigma_q) (mu_q - mu_p)),
    # computed via Cholesky factors to stay batched and numerically stable.
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two Multivariate Normals with "
                         "different event shapes cannot be computed")

    half_term1 = (q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
                  p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
    combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
                                                p._unbroadcasted_scale_tril.shape[:-2])
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(torch.trtrs(p_scale_tril, q_scale_tril, upper=False)[0])
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)

@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())

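# This is the closed form
#   KL(p || q) = log(s_q / s_p) + (s_p^2 + (m_p - m_q)^2) / (2 s_q^2) - 1/2.
# A Monte Carlo sanity-check sketch (not part of the module):
#
#   p, q = Normal(0., 1.), Normal(1., 2.)
#   x = p.sample((100000,))
#   (p.log_prob(x) - q.log_prob(x)).mean()  # ~= kl_divergence(p, q)
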
@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    return _kl_categorical_categorical(p._categorical, q._categorical)

@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    result[p.support.lower_bound < q.support.lower_bound] = inf
    return result

@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)

@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    extra_event_dim = len(p.event_shape)
    base_kl_divergence = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(base_kl_divergence, extra_event_dim)

@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    result = ((q.high - q.low) / (p.high - p.low)).log()
    result[(q.low > p.low) | (q.high < p.high)] = inf
    return result

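# KL between uniforms is the log of the length ratio when p's support lies
# inside q's, and inf otherwise. Sketch (not part of the module):
#
#   kl_divergence(Uniform(0., 1.), Uniform(0., 2.))  # log(2)
#   kl_divergence(Uniform(0., 2.), Uniform(0., 1.))  # inf
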
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)

@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    return _infinite_like(p.concentration1)

@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))

@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (q.concentration - 1) * (p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma())
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4

@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    t3 = (E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1) + E_beta.pow(2)) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal

@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
    return result

@register_kl(Exponential, Beta)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    return _infinite_like(p.rate)

@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return t1 + ratio + q.concentration.lgamma() + q.concentration * _euler_gamma - (1 + _euler_gamma)

@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3

@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    var_normal = q.scale.pow(2)
    rate_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
    t2 = rate_sqr.reciprocal()
    t3 = q.loc * p.rate.reciprocal()
    t4 = q.loc.pow(2) * 0.5
    return t1 - 1 + (t2 - t3 + t4) / var_normal

@register_kl(Gamma, Beta)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    return _infinite_like(p.concentration)

@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate

@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (p.concentration - 1) * p.concentration.digamma() - p.concentration.lgamma() - p.concentration
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = torch.exp(loc_scale_ratio) * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration) - loc_scale_ratio
    return t1 + t2 + t3

@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi) - p.concentration - p.concentration.lgamma()
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return t1 + (p.concentration - 1) * p.concentration.digamma() + (t2 - t3 + t4) / var_normal

@register_kl(Gumbel, Beta)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    return _infinite_like(p.loc)

@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)

@register_kl(Laplace, Beta)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    return _infinite_like(p.loc)

@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1

@register_kl(Normal, Beta)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    return _infinite_like(p.loc)

@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))

@register_kl(Pareto, Beta)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    return _infinite_like(p.scale)

@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    result[p.alpha <= 1] = inf
    return result

@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    result[p.alpha <= 1] = inf
    return result

@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / var_normal - 1
    result[p.alpha <= 2] = inf
    return result

@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    return _infinite_like(p.rate)

@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (q.concentration1 - 1) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
    t3 = (q.concentration0 - 1) * (_x_log_x((1 - p.high)) - _x_log_x((1 - p.low)) + common_term) / common_term
    t4 = q.concentration1.lgamma() + q.concentration0.lgamma() - (q.concentration1 + q.concentration0).lgamma()
    result = t3 + t4 - t1 - t2
    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
    return result

@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    result[p.low < q.support.lower_bound] = inf
    return result

@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    result[p.low < q.support.lower_bound] = inf
    return result

@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2

@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)

@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    support_uniform = p.high - p.low
    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
    result = t2 * (q.alpha + 1) - t1
    result[p.low < q.support.lower_bound] = inf
    return result

@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)

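# Sketch (not part of the module): reinterpreted batch dims contribute a sum
# of per-dimension KLs, so the result below is a scalar.
#
#   p = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
#   q = Independent(Normal(torch.ones(3), torch.ones(3)), 1)
#   kl_divergence(p, q)  # == kl_divergence(p.base_dist, q.base_dist).sum(-1)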