The ``distributions`` package contains parameterizable probability distributions
and sampling functions. This allows the construction of stochastic computation
graphs and stochastic gradient estimators for optimization. This package
generally follows the design of the `TensorFlow Distributions`_ package.

.. _`TensorFlow Distributions`:
    https://arxiv.org/abs/1711.10604

It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through. These are the score function estimator/likelihood ratio
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
seen as the basis for policy gradient methods in reinforcement learning, and the
pathwise derivative estimator is commonly seen in the reparameterization trick
in variational autoencoders. Whilst the score function only requires the value
of samples :math:`f(x)`, the pathwise derivative requires the derivative
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_.

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
    https://arxiv.org/abs/1506.05254

Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()
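Putting the two estimators side by side, the following is a minimal,
self-contained sketch on a toy Gaussian; the tensors here are illustrative
placeholders standing in for the outputs of a policy network and for a reward::

    import torch
    from torch.distributions import Normal

    loc = torch.zeros(3, requires_grad=True)   # stands in for a network output
    m = Normal(loc, torch.ones(3))

    # Score function estimator: sample() is not differentiable, but log_prob() is,
    # so gradients reach ``loc`` through the log-probability of the sample.
    x = m.sample()
    (-m.log_prob(x) * 1.0).sum().backward()    # 1.0 stands in for a reward

    # Pathwise estimator: rsample() reparameterizes the sample as loc + eps * scale,
    # so gradients flow from the sample itself back to ``loc``.
    loc.grad = None
    y = m.rsample()
    y.sum().backward()

Both calls leave a gradient in ``loc.grad``; the first only needs
:meth:`~torch.distributions.Distribution.log_prob`, while the second requires a
reparameterizable distribution (one with ``has_rsample == True``).
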
from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .kl import kl_divergence, register_kl
from .laplace import Laplace
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import *
from .uniform import Uniform
from .weibull import Weibull
__all__ = [
    'Bernoulli',
    'Beta',
    'Binomial',
    'Categorical',
    'Cauchy',
    'Chi2',
    'Dirichlet',
    'Distribution',
    'Exponential',
    'ExponentialFamily',
    'FisherSnedecor',
    'Gamma',
    'Geometric',
    'Gumbel',
    'HalfCauchy',
    'HalfNormal',
    'Independent',
    'Laplace',
    'LogNormal',
    'LogisticNormal',
    'LowRankMultivariateNormal',
    'Multinomial',
    'MultivariateNormal',
    'NegativeBinomial',
    'Normal',
    'OneHotCategorical',
    'Pareto',
    'RelaxedBernoulli',
    'RelaxedOneHotCategorical',
    'StudentT',
    'Poisson',
    'Uniform',
    'Weibull',
    'TransformedDistribution',
    'biject_to',
    'kl_divergence',
    'register_kl',
    'transform_to',
]
__all__.extend(transforms.__all__)