Caffe2 - Python API
A deep learning, cross platform ML framework
kl.py
import math
import warnings
from functools import total_ordering

import torch
from torch._six import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential import Exponential
from .exp_family import ExponentialFamily
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
                                          _batch_lowrank_mahalanobis)
from .multivariate_normal import (MultivariateNormal, _batch_mahalanobis)
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost

_KL_REGISTRY = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE = {}   # Memoized version mapping many specific (type, type) pairs to functions.


def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type, type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is raised. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...
        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
    """
    if not isinstance(type_p, type) or not issubclass(type_p, Distribution):
        raise TypeError('Expected type_p to be a Distribution subclass but got {}'.format(type_p))
    if not isinstance(type_q, type) or not issubclass(type_q, Distribution):
        raise TypeError('Expected type_q to be a Distribution subclass but got {}'.format(type_q))

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator

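# Illustrative sketch (not part of the original module): registering a pairwise
# KL and then querying it through kl_divergence. `MyDist` is a hypothetical
# Distribution subclass used only for this example.
#
#   >>> from torch.distributions import kl_divergence, register_kl
#   >>> @register_kl(MyDist, MyDist)
#   ... def _kl_mydist_mydist(p, q):
#   ...     return (p.probs * (p.probs / q.probs).log()).sum(-1)
#   >>> kl_divergence(MyDist(...), MyDist(...))  # now dispatches to the function above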

@total_ordering
class _Match(object):
    __slots__ = ['types']

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        return self.types == other.types

    def __le__(self, other):
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True

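# _Match orders (type, type) pairs by specificity: a pair compares "smaller"
# when each of its types is a subclass of the corresponding type in the other
# pair. For example, with classes used in this module:
#
#   _Match(Normal, Normal) <= _Match(Distribution, Distribution)   # True: more specific
#   _Match(Distribution, Distribution) <= _Match(Normal, Normal)   # False
#
# so min() over the candidate matches in _dispatch_kl picks the most specific
# registered implementation.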
def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.
    """
    matches = [(super_p, super_q) for super_p, super_q in _KL_REGISTRY
               if issubclass(type_p, super_p) and issubclass(type_q, super_q)]
    if not matches:
        return NotImplemented
    # Check that the left- and right- lexicographic orders agree.
    left_p, left_q = min(_Match(*m) for m in matches).types
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn('Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.format(
            type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__),
            RuntimeWarning)
    return left_fun

def _infinite_like(tensor):
    """
    Helper function for obtaining an infinite KL divergence throughout.
    """
    return tensor.new_tensor(inf).expand_as(tensor)


def _x_log_x(tensor):
    """
    Utility function for calculating x log x.
    """
    return tensor * tensor.log()


def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary trailing batch dimensions.
    """
    n = bmat.size(-1)
    m = bmat.size(-2)
    flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
    return flat_trace.reshape(bmat.shape[:-2])

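# Sanity check of the identity used above (a sketch, not part of the module):
# for any real matrix X, tr(X @ X.T) equals the sum of the squared entries of
# X, so the batched trace can be computed without ever forming X @ X.T.
#
#   >>> X = torch.randn(4, 2, 3, 5)
#   >>> expected = X.matmul(X.transpose(-1, -2)).diagonal(dim1=-2, dim2=-1).sum(-1)
#   >>> torch.allclose(_batch_trace_XXT(X), expected)
#   True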
def kl_divergence(p, q):
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        fun = _dispatch_kl(type(p), type(q))
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError
    return fun(p, q)

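# Example usage (a sketch, not part of the original module), relying on the
# Normal-Normal registration defined further below:
#
#   >>> import torch
#   >>> from torch.distributions import Normal, kl_divergence
#   >>> p = Normal(torch.zeros(3), torch.ones(3))
#   >>> q = Normal(torch.ones(3), 2 * torch.ones(3))
#   >>> kl_divergence(p, q).shape    # one KL value per batch element
#   torch.Size([3])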

################################################################################
# KL Divergence Implementations
################################################################################

_euler_gamma = 0.57721566490153286060

# Same distributions


@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    t1 = p.probs * (p.probs / q.probs).log()
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * ((1 - p.probs) / (1 - q.probs)).log()
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2

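# The Bernoulli case above implements the standard closed form
#     KL(Ber(p) || Ber(q)) = p * log(p / q) + (1 - p) * log((1 - p) / (1 - q)),
# with the 0 * log(0) terms defined as 0 and a zero q-probability giving an
# infinite divergence.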

@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    # from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    if (p.total_count < q.total_count).any():
        raise NotImplementedError('KL between Binomials where q.total_count > p.total_count is not implemented')
    kl = p.total_count * (p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p())
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    t = p.probs * (p.logits - q.logits)
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1

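# For reference, the Exponential case above matches the closed form
#     KL(Exp(rate_p) || Exp(rate_q)) = log(rate_p / rate_q) + rate_q / rate_p - 1,
# written in terms of rate_ratio = rate_q / rate_p.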

@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    if not type(p) == type(q):
        raise NotImplementedError("The cross KL-divergence between different exponential families cannot \
                                   be computed using Bregman divergences")
    p_nparams = [np.detach().requires_grad_() for np in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal.clone()
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        result -= _sum_rightmost(term, len(q.event_shape))
    return result

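# _kl_expfamily_expfamily uses the Bregman-divergence form of the KL between
# two members of the same exponential family with natural parameters eta_p,
# eta_q and log-normalizer A(.):
#     KL(p || q) = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>,
# where grad A(eta_p) is obtained above via torch.autograd.grad.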

@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two Low Rank Multivariate Normals with\
                          different event shapes cannot be computed")

    term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                   q._capacitance_tril) -
             _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
                                   p._capacitance_tril))
    term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                       q.loc - p.loc,
                                       q._capacitance_tril)
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                 q._unbroadcasted_cov_diag.unsqueeze(-2))
    A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
                              q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
                          different event shapes cannot be computed")

    term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                   q._capacitance_tril) -
             2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
    term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
                                       q.loc - p.loc,
                                       q._capacitance_tril)
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
                 q._unbroadcasted_cov_diag.unsqueeze(-2))
    A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
    term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
                              q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
                          different event shapes cannot be computed")

    term1 = (2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
             _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
                                   p._capacitance_tril))
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
                                                p._unbroadcasted_cov_factor.shape[:-2])
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(combined_batch_shape +
                                                      (n, p.cov_factor.size(-1)))
    p_cov_diag = (torch.diag_embed(p._unbroadcasted_cov_diag.sqrt())
                  .expand(combined_batch_shape + (n, n)))
    term21 = _batch_trace_XXT(torch.trtrs(p_cov_factor, q_scale_tril, upper=False)[0])
    term22 = _batch_trace_XXT(torch.trtrs(p_cov_diag, q_scale_tril, upper=False)[0])
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError("KL-divergence between two Multivariate Normals with\
                          different event shapes cannot be computed")

    half_term1 = (q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) -
                  p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
    combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
                                                p._unbroadcasted_scale_tril.shape[:-2])
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(torch.trtrs(p_scale_tril, q_scale_tril, upper=False)[0])
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)

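# The multivariate normal implementations above are batched evaluations of the
# standard closed form
#     KL(N(mu_p, S_p) || N(mu_q, S_q))
#         = 0.5 * (log det S_q - log det S_p - n
#                  + tr(S_q^{-1} S_p) + (mu_q - mu_p)^T S_q^{-1} (mu_q - mu_p)),
# with term1, term2 and term3 corresponding to the log-determinant, trace and
# Mahalanobis pieces respectively.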

@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())

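# In scalar form the Normal-Normal case above equals
#     KL(N(mu_p, sigma_p^2) || N(mu_q, sigma_q^2))
#         = log(sigma_q / sigma_p) + (sigma_p^2 + (mu_p - mu_q)^2) / (2 * sigma_q^2) - 0.5,
# which is exactly 0.5 * (var_ratio + t1 - 1 - log(var_ratio)) with
# var_ratio = sigma_p^2 / sigma_q^2 and t1 = ((mu_p - mu_q) / sigma_q)^2.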

@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    return _kl_categorical_categorical(p._categorical, q._categorical)


@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    result[p.support.lower_bound < q.support.lower_bound] = inf
    return result


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    # extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape)
    extra_event_dim = len(p.event_shape)
    base_kl_divergence = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(base_kl_divergence, extra_event_dim)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    result = ((q.high - q.low) / (p.high - p.low)).log()
    result[(q.low > p.low) | (q.high < p.high)] = inf
    return result


# Different distributions
@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (q.concentration - 1) * (p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma())
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4

# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    t3 = (E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1) + E_beta.pow(2)) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
    return result


@register_kl(Exponential, Beta)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return t1 + ratio + q.concentration.lgamma() + q.concentration * _euler_gamma - (1 + _euler_gamma)


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3

# TODO: Add Exponential-Laplace KL Divergence


@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    var_normal = q.scale.pow(2)
    rate_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
    t2 = rate_sqr.reciprocal()
    t3 = q.loc / p.rate
    t4 = q.loc.pow(2) * 0.5
    return t1 - 1 + (t2 - t3 + t4) / var_normal


@register_kl(Gamma, Beta)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    return _infinite_like(p.concentration)


@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate


@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (p.concentration - 1) * p.concentration.digamma() - p.concentration.lgamma() - p.concentration
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = torch.exp(loc_scale_ratio) * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration) - loc_scale_ratio
    return t1 + t2 + t3

# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi) - p.concentration - p.concentration.lgamma()
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return t1 + (p.concentration - 1) * p.concentration.digamma() + (t2 - t3 + t4) / var_normal


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    return _infinite_like(p.loc)

# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))

# TODO: Add Normal-Laplace KL Divergence


@register_kl(Pareto, Beta)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    result[p.alpha <= 1] = inf
    return result


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    result[p.alpha <= 1] = inf
    return result

# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / var_normal - 1
    result[p.alpha <= 2] = inf
    return result


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (q.concentration1 - 1) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
    t3 = (q.concentration0 - 1) * (_x_log_x((1 - p.high)) - _x_log_x((1 - p.low)) + common_term) / common_term
    t4 = q.concentration1.lgamma() + q.concentration0.lgamma() - (q.concentration1 + q.concentration0).lgamma()
    result = t3 + t4 - t1 - t2
    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
    return result


@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) / common_term
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2

# TODO: Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)


@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    support_uniform = p.high - p.low
    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
    result = t2 * (q.alpha + 1) - t1
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)