 Caffe2 - Python API A deep learning, cross platform ML framework
exp_family.py
1 import torch
2 from torch.distributions.distribution import Distribution
3
4
6  r"""
7  ExponentialFamily is the abstract base class for probability distributions belonging to an
8  exponential family, whose probability mass/density function has the form defined below
9
10  .. math::
11
12  p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))
13
14  where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
15  :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
16  measure.
17
18  Note:
19  This class is an intermediary between the Distribution class and distributions which belong
20  to an exponential family mainly to check the correctness of the .entropy() and analytic KL
21  divergence methods. We use this class to compute the entropy and KL divergence using the AD
22  framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
23  Cross-entropies of Exponential Families).
24  """
25
26  @property
27  def _natural_params(self):
28  """
29  Abstract method for natural parameters. Returns a tuple of Tensors based
30  on the distribution
31  """
32  raise NotImplementedError
33
34  def _log_normalizer(self, *natural_params):
35  """
36  Abstract method for log normalizer function. Returns a log normalizer based on
37  the distribution and input
38  """
39  raise NotImplementedError
40
41  @property
42  def _mean_carrier_measure(self):
43  """
44  Abstract method for expected carrier measure, which is required for computing
45  entropy.
46  """
47  raise NotImplementedError
48
49  def entropy(self):
50  """
51  Method to compute the entropy using Bregman divergence of the log normalizer.
52  """
53  result = -self._mean_carrier_measure
54  nparams = [p.detach().requires_grad_() for p in self._natural_params]
55  lg_normal = self._log_normalizer(*nparams)