 Caffe2 - Python API A deep learning, cross platform ML framework
__init__.py
1 r"""
The distributions package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization. This package
generally follows the design of the `TensorFlow Distributions`_ package.
6
7 .. _TensorFlow Distributions:
8  https://arxiv.org/abs/1711.10604
9
It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through. These are the score function estimator/likelihood ratio
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
seen as the basis for policy gradient methods in reinforcement learning, and the
pathwise derivative estimator is commonly seen in the reparameterization trick
in variational autoencoders. Whilst the score function only requires the value
of samples :math:`f(x)`, the pathwise derivative requires the derivative
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_ .
21
22 .. _Gradient Estimation Using Stochastic Computation Graphs:
23  https://arxiv.org/abs/1506.05254
24
25 Score function
26 ^^^^^^^^^^^^^^
27
When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:
31
32 .. math::
33
34  \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}
35
where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.
39
In practice we would sample an action from the output of a network, apply this
action in an environment, and then use `log_prob` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::
45
46  probs = policy_network(state)
47  # Note that this is equivalent to what used to be called multinomial
48  m = Categorical(probs)
49  action = m.sample()
50  next_state, reward = env.step(action)
51  loss = -m.log_prob(action) * reward
52  loss.backward()
53
54 Pathwise derivative
55 ^^^^^^^^^^^^^^^^^^^
56
The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::
64
65  params = policy_network(state)
66  m = Normal(*params)
67  # Any distribution with .has_rsample == True could work based on the application
68  action = m.rsample()
69  next_state, reward = env.step(action) # Assuming that reward is differentiable
70  loss = -reward
71  loss.backward()
72 """
73
74 from .bernoulli import Bernoulli
75 from .beta import Beta
76 from .binomial import Binomial
77 from .categorical import Categorical
78 from .cauchy import Cauchy
79 from .chi2 import Chi2
80 from .constraint_registry import biject_to, transform_to
81 from .dirichlet import Dirichlet
82 from .distribution import Distribution
83 from .exp_family import ExponentialFamily
84 from .exponential import Exponential
85 from .fishersnedecor import FisherSnedecor
86 from .gamma import Gamma
87 from .geometric import Geometric
88 from .gumbel import Gumbel
89 from .half_cauchy import HalfCauchy
90 from .half_normal import HalfNormal
91 from .independent import Independent
92 from .kl import kl_divergence, register_kl
93 from .laplace import Laplace
94 from .log_normal import LogNormal
95 from .logistic_normal import LogisticNormal
96 from .lowrank_multivariate_normal import LowRankMultivariateNormal
97 from .multinomial import Multinomial
98 from .multivariate_normal import MultivariateNormal
99 from .negative_binomial import NegativeBinomial
100 from .normal import Normal
101 from .one_hot_categorical import OneHotCategorical
102 from .pareto import Pareto
103 from .poisson import Poisson
104 from .relaxed_bernoulli import RelaxedBernoulli
105 from .relaxed_categorical import RelaxedOneHotCategorical
106 from .studentT import StudentT
107 from .transformed_distribution import TransformedDistribution
108 from .transforms import *
109 from .uniform import Uniform
110 from .weibull import Weibull
111
# Public API of the package: the distribution classes and helper functions
# imported above. ``from <package> import *`` exports exactly these names, so
# every imported public class must be listed here (HalfCauchy and HalfNormal
# are imported above and were previously missing from this list).
__all__ = [
    'Bernoulli',
    'Beta',
    'Binomial',
    'Categorical',
    'Cauchy',
    'Chi2',
    'Dirichlet',
    'Distribution',
    'Exponential',
    'ExponentialFamily',
    'FisherSnedecor',
    'Gamma',
    'Geometric',
    'Gumbel',
    'HalfCauchy',
    'HalfNormal',
    'Independent',
    'Laplace',
    'LogNormal',
    'LogisticNormal',
    'LowRankMultivariateNormal',
    'Multinomial',
    'MultivariateNormal',
    'NegativeBinomial',
    'Normal',
    'OneHotCategorical',
    'Pareto',
    'RelaxedBernoulli',
    'RelaxedOneHotCategorical',
    'StudentT',
    'Poisson',
    'Uniform',
    'Weibull',
    'TransformedDistribution',
    'biject_to',
    'kl_divergence',
    'register_kl',
    'transform_to',
]
# Also re-export whatever the ``transforms`` submodule declares public; those
# names are brought into this namespace by ``from .transforms import *`` above.
__all__.extend(transforms.__all__)