10 Extension of the Distribution class, which applies a sequence of Transforms 11 to a base distribution. Let f be the composition of transforms applied:: 14 Y = f(X) ~ TransformedDistribution(BaseDistribution, f) 15 log p(Y) = log p(X) + log |det (dX/dY)| 17 Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the 18 maximum shape of its base distribution and its transforms, since transforms 19 can introduce correlations among events. 21 An example for the usage of :class:`TransformedDistribution` would be:: 23 # Building a Logistic Distribution 25 # f = a + b * logit(X) 26 # Y ~ f(X) ~ Logistic(a, b) 27 base_distribution = Uniform(0, 1) 28 transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)] 29 logistic = TransformedDistribution(base_distribution, transforms) 31 For more examples, please look at the implementations of 32 :class:`~torch.distributions.gumbel.Gumbel`, 33 :class:`~torch.distributions.half_cauchy.HalfCauchy`, 34 :class:`~torch.distributions.half_normal.HalfNormal`, 35 :class:`~torch.distributions.log_normal.LogNormal`, 36 :class:`~torch.distributions.pareto.Pareto`, 37 :class:`~torch.distributions.weibull.Weibull`, 38 :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and 39 :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical` 43 def __init__(self, base_distribution, transforms, validate_args=None):
45 if isinstance(transforms, Transform):
47 elif isinstance(transforms, list):
48 if not all(isinstance(t, Transform)
for t
in transforms):
49 raise ValueError(
"transforms must be a Transform or a list of Transforms")
52 raise ValueError(
"transforms must be a Transform or list, but was {}".format(transforms))
53 shape = self.base_dist.batch_shape + self.base_dist.event_shape
54 event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim
for t
in self.
transforms])
55 batch_shape = shape[:len(shape) - event_dim]
56 event_shape = shape[len(shape) - event_dim:]
57 super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)
59 def expand(self, batch_shape, _instance=None):
61 batch_shape = torch.Size(batch_shape)
62 base_dist_batch_shape = batch_shape + self.base_dist.batch_shape[len(self.
batch_shape):]
63 new.base_dist = self.base_dist.expand(base_dist_batch_shape)
65 super(TransformedDistribution, new).__init__(batch_shape, self.
event_shape, validate_args=
False)
69 @constraints.dependent_property
74 def has_rsample(self):
75 return self.base_dist.has_rsample
77 def sample(self, sample_shape=torch.Size()):
79 Generates a sample_shape shaped sample or sample_shape shaped batch of 80 samples if the distribution parameters are batched. Samples first from 81 base distribution and applies `transform()` for every transform in the 85 x = self.base_dist.sample(sample_shape)
90 def rsample(self, sample_shape=torch.Size()):
92 Generates a sample_shape shaped reparameterized sample or sample_shape 93 shaped batch of reparameterized samples if the distribution parameters 94 are batched. Samples first from base distribution and applies 95 `transform()` for every transform in the list. 97 x = self.base_dist.rsample(sample_shape)
104 Scores the sample by inverting the transform(s) and computing the score 105 using the score of the base distribution and the log abs det jacobian. 112 log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
113 event_dim - transform.event_dim)
116 log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
117 event_dim - len(self.base_dist.event_shape))
120 def _monotonize_cdf(self, value):
122 This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is 127 sign = sign * transform.sign
128 if isinstance(sign, int)
and sign == 1:
130 return sign * (value - 0.5) + 0.5
134 Computes the cumulative distribution function by inverting the 135 transform(s) and computing the score of the base distribution. 138 value = transform.inv(value)
140 self.base_dist._validate_sample(value)
141 value = self.base_dist.cdf(value)
147 Computes the inverse cumulative distribution function using 148 transform(s) and computing the score of the base distribution. 152 self.base_dist._validate_sample(value)
153 value = self.base_dist.icdf(value)
155 value = transform(value)
def _get_checked_instance(self, cls, _instance=None)