"""
Note [Randomized statistical tests]
-----------------------------------

This note describes how to maintain tests in this file as random sources
change. This file contains two types of randomized tests:

1. The easier type of randomized tests are tests that should always pass but are
   initialized with random data. If these fail something is wrong, but it's
   fine to use a fixed seed by inheriting from common.TestCase.

2. The trickier tests are statistical tests. These tests explicitly call
   set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
   These statistical tests have a known positive failure rate
   (we set failure_rate=1e-3 by default). We need to balance the strength of
   these tests against the annoyance of false alarms. One approach that works
   is to specifically set seeds in each of the randomized tests. When a random
   generator occasionally changes (as in #4312 vectorizing the Box-Muller
   sampler), some of these statistical tests may (rarely) fail. If one fails
   in this case, it's fine to increment the seed of the failing test (but you
   shouldn't need to increment it more than once; otherwise something is
   probably actually wrong). A minimal sketch of both test flavors follows the
   imports below.
"""

import math
import unittest
from collections import namedtuple
from itertools import product
from random import shuffle

import torch

from common_utils import (TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN,
                          load_tests, skipIfRocm)
from common_cuda import TEST_CUDA
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
                                 Cauchy, Chi2, Dirichlet, Distribution,
                                 Exponential, ExponentialFamily,
                                 FisherSnedecor, Gamma, Geometric, Gumbel,
                                 HalfCauchy, HalfNormal,
                                 Independent, Laplace, LogisticNormal,
                                 LogNormal, LowRankMultivariateNormal,
                                 Multinomial, MultivariateNormal,
                                 NegativeBinomial, Normal, OneHotCategorical, Pareto,
                                 Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
                                 StudentT, TransformedDistribution, Uniform,
                                 Weibull, constraints, kl_divergence, transform_to)
from torch.distributions.transforms import (AffineTransform, ComposeTransform,
                                            ExpTransform,
                                            LowerCholeskyTransform,
                                            PowerTransform, SigmoidTransform,
                                            StickBreakingTransform)
from torch.distributions.utils import probs_to_logits
from torch.autograd import gradcheck
from torch._six import inf

# NumPy and SciPy are optional dependencies of these tests; statistical
# checks that need them are skipped when they are unavailable.
TEST_NUMPY = True
try:
    import numpy as np
    import scipy.stats
    import scipy.special
except ImportError:
    TEST_NUMPY = False
# load_tests from common_utils is used to automatically filter tests for
# sharding. This re-assignment silences flake warnings about an unused import.
load_tests = load_tests
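
# A minimal sketch of the two randomized-test flavors described in the Note
# above (illustrative placeholders only, not actual tests in this file):
#
#     def test_foo_shape(self):              # type 1: passes for any seed
#         d = Normal(torch.randn(2, 3), torch.randn(2, 3).abs())
#         self.assertEqual(d.sample().size(), (2, 3))
#
#     def test_foo_sample(self):             # type 2: statistical test
#         set_rng_seed(0)  # see Note [Randomized statistical tests]
#         self._check_sampler_sampler(Normal(0.0, 1.0),
#                                     scipy.stats.norm(loc=0.0, scale=1.0),
#                                     'Normal(loc=0, scale=1)')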

def pairwise(Dist, *params):
    """
    Creates a pair of distributions `Dist` initialized to test each element of
    params with each other.
    """
    params1 = [torch.tensor([p] * len(p)) for p in params]
    params2 = [p.transpose(0, 1) for p in params1]
    return Dist(*params1), Dist(*params2)


def is_all_nan(tensor):
    """
    Checks whether all entries of a tensor are NaN.
    """
    return (tensor != tensor).all()
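
# Example (illustrative): pairwise(Normal, [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0])
# builds two batched Normals whose (loc, scale) grids are transposes of each
# other, so quantities like kl_divergence(p, q) cover every parameter pair.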

Example = namedtuple('Example', ['Dist', 'params'])
EXAMPLES = [
    Example(Bernoulli, [
        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
    ]),
    Example(Geometric, [
        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
    ]),
    Example(Beta, [
        {
            'concentration1': torch.randn(2, 3).exp().requires_grad_(),
            'concentration0': torch.randn(2, 3).exp().requires_grad_(),
        },
        {
            'concentration1': torch.randn(4).exp().requires_grad_(),
            'concentration0': torch.randn(4).exp().requires_grad_(),
        },
    ]),
    Example(Categorical, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
    ]),
    Example(Binomial, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True),
         'total_count': 10},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
         'total_count': 10},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
    ]),
    Example(NegativeBinomial, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True),
         'total_count': 10},
        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
         'total_count': 10},
        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True)},
        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True)},
    ]),
    Example(Multinomial, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True),
         'total_count': 10},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
         'total_count': 10},
    ]),
    Example(Cauchy, [
        {'loc': 0.0, 'scale': 1.0},
    ]),
    Example(Chi2, [
        {'df': torch.randn(2, 3).exp().requires_grad_()},
        {'df': torch.randn(1).exp().requires_grad_()},
    ]),
    Example(StudentT, [
        {'df': torch.randn(2, 3).exp().requires_grad_()},
        {'df': torch.randn(1).exp().requires_grad_()},
    ]),
    Example(Dirichlet, [
        {'concentration': torch.randn(2, 3).exp().requires_grad_()},
        {'concentration': torch.randn(4).exp().requires_grad_()},
    ]),
    Example(Exponential, [
        {'rate': torch.randn(5, 5).abs().requires_grad_()},
        {'rate': torch.randn(1).abs().requires_grad_()},
    ]),
    Example(FisherSnedecor, [
        {
            'df1': torch.randn(5, 5).abs().requires_grad_(),
            'df2': torch.randn(5, 5).abs().requires_grad_(),
        },
        {
            'df1': torch.randn(1).abs().requires_grad_(),
            'df2': torch.randn(1).abs().requires_grad_(),
        },
    ]),
    Example(Gamma, [
        {
            'concentration': torch.randn(2, 3).exp().requires_grad_(),
            'rate': torch.randn(2, 3).exp().requires_grad_(),
        },
        {
            'concentration': torch.randn(1).exp().requires_grad_(),
            'rate': torch.randn(1).exp().requires_grad_(),
        },
    ]),
    Example(Gumbel, [
        {'loc': torch.randn(5, 5, requires_grad=True),
         'scale': torch.randn(5, 5).abs().requires_grad_()},
        {'loc': torch.randn(1, requires_grad=True),
         'scale': torch.randn(1).abs().requires_grad_()},
    ]),
    Example(HalfCauchy, []),
    Example(HalfNormal, [
        {'scale': torch.randn(5, 5).abs().requires_grad_()},
        {'scale': torch.randn(1).abs().requires_grad_()},
        {'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)},
    ]),
    Example(Independent, [
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 0,
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 1,
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 2,
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
                                        torch.randn(2, 3, 5).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 2,
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
                                        torch.randn(2, 3, 5).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 3,
        },
    ]),
    Example(Laplace, [
        {'loc': torch.randn(5, 5, requires_grad=True),
         'scale': torch.randn(5, 5).abs().requires_grad_()},
        {'loc': torch.randn(1, requires_grad=True),
         'scale': torch.randn(1).abs().requires_grad_()},
        {'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)},
    ]),
    Example(LogNormal, [
        {'loc': torch.randn(5, 5, requires_grad=True),
         'scale': torch.randn(5, 5).abs().requires_grad_()},
        {'loc': torch.randn(1, requires_grad=True),
         'scale': torch.randn(1).abs().requires_grad_()},
        {'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)},
    ]),
    Example(LogisticNormal, [
        {'loc': torch.randn(5, 5).requires_grad_(),
         'scale': torch.randn(5, 5).abs().requires_grad_()},
        {'loc': torch.randn(1).requires_grad_(),
         'scale': torch.randn(1).abs().requires_grad_()},
        {'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)},
    ]),
    Example(LowRankMultivariateNormal, [
        {
            'loc': torch.randn(5, 2, requires_grad=True),
            'cov_factor': torch.randn(5, 2, 1, requires_grad=True),
            'cov_diag': torch.tensor([2.0, 0.25], requires_grad=True),
        },
        {
            'loc': torch.randn(4, 3, requires_grad=True),
            'cov_factor': torch.randn(3, 2, requires_grad=True),
            'cov_diag': torch.tensor([5.0, 1.5, 3.], requires_grad=True),
        },
    ]),
    Example(MultivariateNormal, [
        {
            'loc': torch.randn(5, 2, requires_grad=True),
            'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True),
        },
        {
            'loc': torch.randn(2, 3, requires_grad=True),
            'precision_matrix': torch.tensor([[2.0, 0.1, 0.0],
                                              [0.1, 0.25, 0.0],
                                              [0.0, 0.0, 0.3]], requires_grad=True),
        },
        {
            'loc': torch.randn(5, 3, 2, requires_grad=True),
            'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]],
                                        [[2.0, 0.0], [0.3, 0.25]],
                                        [[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True),
        },
        {
            'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
        },
    ]),
    Example(Normal, [
        {'loc': torch.randn(5, 5, requires_grad=True),
         'scale': torch.randn(5, 5).abs().requires_grad_()},
        {'loc': torch.randn(1, requires_grad=True),
         'scale': torch.randn(1).abs().requires_grad_()},
        {'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)},
    ]),
    Example(OneHotCategorical, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
    ]),
    Example(Pareto, [
        {
            'scale': torch.randn(5, 5).abs().requires_grad_(),
            'alpha': torch.randn(5, 5).abs().requires_grad_(),
        },
    ]),
    Example(Poisson, [
        {'rate': torch.randn(5, 5).abs().requires_grad_()},
        {'rate': torch.randn(3).abs().requires_grad_()},
    ]),
    Example(RelaxedBernoulli, [
        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
    ]),
    Example(RelaxedOneHotCategorical, [
        {'probs': torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)},
    ]),
    Example(TransformedDistribution, [
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'transforms': [],
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'transforms': ExpTransform(),
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
                                        torch.randn(2, 3, 5).abs().requires_grad_()),
            'transforms': [AffineTransform(torch.randn(3, 5), torch.randn(3, 5))],
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
                                        torch.randn(2, 3, 5).abs().requires_grad_()),
            'transforms': AffineTransform(1, 2),
        },
    ]),
    Example(Uniform, [
        {'low': torch.zeros(5, 5, requires_grad=True),
         'high': torch.ones(5, 5, requires_grad=True)},
        {'low': torch.zeros(1, requires_grad=True),
         'high': torch.ones(1, requires_grad=True)},
    ]),
    Example(Weibull, [
        {
            'scale': torch.randn(5, 5).abs().requires_grad_(),
            'concentration': torch.randn(1).abs().requires_grad_(),
        },
    ]),
]

# Parameter sets that argument validation should reject.
BAD_EXAMPLES = [
    Example(Bernoulli, [
        {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
    ]),
    Example(Beta, [
        {
            'concentration1': torch.tensor([0.0], requires_grad=True),
            'concentration0': torch.tensor([0.0], requires_grad=True),
        },
        {
            'concentration1': torch.tensor([-1.0], requires_grad=True),
            'concentration0': torch.tensor([-2.0], requires_grad=True),
        },
    ]),
    Example(Geometric, [
        {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
        {'probs': 1.00000001},
    ]),
    Example(Categorical, [
        {'probs': torch.tensor([[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[-1.0, 10.0], [0.0, -1.0]], requires_grad=True)},
    ]),
    Example(Binomial, [
        {'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True)},
    ]),
    Example(NegativeBinomial, [
        {'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True)},
    ]),
    Example(Cauchy, [
        {'loc': 0.0, 'scale': -1.0},
    ]),
    Example(Dirichlet, [
        {'concentration': torch.tensor([0.], requires_grad=True)},
        {'concentration': torch.tensor([-2.], requires_grad=True)},
    ]),
    Example(Exponential, []),
    Example(FisherSnedecor, []),
    Example(Gamma, [
        {'concentration': torch.tensor([0., 0.], requires_grad=True)},
        {'concentration': torch.tensor([1., 1.], requires_grad=True)},
    ]),
    Example(HalfCauchy, []),
    Example(HalfNormal, []),
    Example(MultivariateNormal, [
        {'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True)},
    ]),
    Example(Normal, [
        {'scale': torch.tensor([1e-5, -1e-5], requires_grad=True)},
    ]),
    Example(OneHotCategorical, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
    ]),
    Example(RelaxedBernoulli, [
        {'probs': torch.tensor([1.7, 0.2, 0.4], requires_grad=True)},
    ]),
    Example(RelaxedOneHotCategorical, [
        {'probs': torch.tensor([[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)},
    ]),
    Example(TransformedDistribution, [
        {
            'base_distribution': Normal(0, 1),
            'transforms': lambda x: x,
        },
        {
            'base_distribution': Normal(0, 1),
            'transforms': [lambda x: x],
        },
    ]),
    Example(Weibull, [
        {'concentration': torch.tensor([0.0], requires_grad=True)},
        {'concentration': torch.tensor([-1.0], requires_grad=True)},
    ]),
]

class TestDistributions(TestCase):
    _do_cuda_memory_leak_check = True

    def _gradcheck_log_prob(self, dist_ctor, ctor_params):
        # performs gradient checks on log_prob
        distribution = dist_ctor(*ctor_params)
        s = distribution.sample()
        if s.is_floating_point():
            s = s.detach().requires_grad_()

        expected_shape = distribution.batch_shape + distribution.event_shape
        self.assertEqual(s.size(), expected_shape)

        def apply_fn(s, *params):
            return dist_ctor(*params).log_prob(s)

        gradcheck(apply_fn, (s,) + tuple(ctor_params), raise_exception=True)
    def _check_log_prob(self, dist, assert_fn):
        # checks that the log_prob of each sample matches a reference function
        s = dist.sample()
        log_probs = dist.log_prob(s)
        log_probs_data_flat = log_probs.view(-1)
        s_data_flat = s.view(len(log_probs_data_flat), -1)
        for i, (val, log_prob) in enumerate(zip(s_data_flat, log_probs_data_flat)):
            assert_fn(i, val.squeeze(), log_prob)
    def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=False,
                               num_samples=10000, failure_rate=1e-3):
        # Checks that torch_dist.sample() matches ref_dist.rvs() in distribution.
        torch_samples = torch_dist.sample((num_samples,)).squeeze()
        torch_samples = torch_samples.cpu().numpy()
        ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
        if multivariate:
            # Project onto a random axis and compare the 1-D marginals.
            axis = np.random.normal(size=torch_samples.shape[-1])
            axis /= np.linalg.norm(axis)
            torch_samples = np.dot(torch_samples, axis)
            ref_samples = np.dot(ref_samples, axis)
        # Tag each sample with its source, merge, and sort by value; for
        # matching distributions the +1/-1 tags should be well mixed.
        samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
        shuffle(samples)  # necessary to prevent stable sort from making uneven bins for discrete
        samples.sort(key=lambda x: x[0])
        samples = np.array(samples)[:, 1]

        # Aggregate tags into bins; each bin mean should be close to zero.
        num_bins = 10
        samples_per_bin = len(samples) // num_bins
        bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
        stddev = samples_per_bin ** -0.5
        threshold = stddev * scipy.special.erfinv(1 - 2 * failure_rate / num_bins)
        message = '{}.sample() is biased:\n{}'.format(message, bins)
        for bias in bins:
            self.assertLess(-threshold, bias, message)
            self.assertLess(bias, threshold, message)
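
    # Worked numbers for the default arguments above: with num_samples=10000
    # the merged array holds 20000 tags, so with num_bins=10 we get
    # samples_per_bin = 2000, stddev = 2000 ** -0.5 ~= 0.0224, and
    # threshold ~= 0.0224 * erfinv(1 - 2 * 1e-3 / 10) ~= 0.059; an unbiased
    # sampler keeps every bin mean well inside that band.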
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def _check_sampler_discrete(self, torch_dist, ref_dist, message,
                                num_samples=10000, failure_rate=1e-3):
        """Runs a Chi2-test for the support, but ignores the tail instead of combining"""
        torch_samples = torch_dist.sample((num_samples,)).squeeze()
        torch_samples = torch_samples.cpu().numpy()
        unique, counts = np.unique(torch_samples, return_counts=True)
        pmf = ref_dist.pmf(unique)
        # Only keep bins large enough for the chi-squared approximation to hold.
        msk = (counts > 5) & ((pmf * num_samples) > 5)
        self.assertGreater(pmf[msk].sum(), 0.9,
                           "Distribution is too sparse for test; try increasing num_samples")
        chisq, p = scipy.stats.chisquare(counts[msk], pmf[msk] * num_samples)
        self.assertGreater(p, failure_rate, message)
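
    # Example usage (illustrative; mirrors test_geometric_sample below):
    #
    #     set_rng_seed(0)  # see Note [Randomized statistical tests]
    #     self._check_sampler_discrete(Geometric(0.2),
    #                                  scipy.stats.geom(p=0.2, loc=-1),
    #                                  'Geometric(prob=0.2)')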
    def _check_enumerate_support(self, dist, examples):
        for params, expected in examples:
            params = {k: torch.tensor(v) for k, v in params.items()}
            expected = torch.tensor(expected)
            d = dist(**params)
            actual = d.enumerate_support(expand=False)
            self.assertEqual(actual, expected)
            actual = d.enumerate_support(expand=True)
            expected_with_expand = expected.expand((-1,) + d.batch_shape + d.event_shape)
            self.assertEqual(actual, expected_with_expand)
    def test_repr(self):
        for Dist, params in EXAMPLES:
            for param in params:
                dist = Dist(**param)
                self.assertTrue(repr(dist).startswith(dist.__class__.__name__))
    def test_sample_detached(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                variable_params = [p for p in param.values() if getattr(p, 'requires_grad', False)]
                if not variable_params:
                    continue
                dist = Dist(**param)
                sample = dist.sample()
                self.assertFalse(sample.requires_grad,
                                 msg='{} example {}/{}, .sample() is not detached'.format(
                                     Dist.__name__, i + 1, len(params)))
    def test_rsample_requires_grad(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                if not any(getattr(p, 'requires_grad', False) for p in param.values()):
                    continue
                dist = Dist(**param)
                if not dist.has_rsample:
                    continue
                sample = dist.rsample()
                self.assertTrue(sample.requires_grad,
                                msg='{} example {}/{}, .rsample() does not require grad'.format(
                                    Dist.__name__, i + 1, len(params)))
    def test_enumerate_support_type(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    self.assertTrue(type(dist.sample()) is type(dist.enumerate_support()),
                                    msg=('{} example {}/{}, return type mismatch between ' +
                                         'sample and enumerate_support.').format(Dist.__name__, i + 1, len(params)))
                except NotImplementedError:
                    pass
    def test_lazy_property_grad(self):
        x = torch.randn(1, requires_grad=True)
        with torch.no_grad():
            ...

        mean = torch.randn(2)
        cov = torch.eye(2, requires_grad=True)
        distn = MultivariateNormal(mean, cov)
        with torch.no_grad():
            distn.sample()
        distn.scale_tril.sum().backward()
        self.assertIsNotNone(cov.grad)
    def test_has_examples(self):
        distributions_with_examples = {e.Dist for e in EXAMPLES}
        for Dist in globals().values():
            if isinstance(Dist, type) and issubclass(Dist, Distribution) \
                    and Dist is not Distribution and Dist is not ExponentialFamily:
                self.assertIn(Dist, distributions_with_examples,
                              "Please add {} to the EXAMPLES list in test_distributions.py".format(Dist.__name__))
    def test_distribution_expand(self):
        shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
        for Dist, params in EXAMPLES:
            for param in params:
                for shape in shapes:
                    d = Dist(**param)
                    expanded_shape = shape + d.batch_shape
                    original_shape = d.batch_shape + d.event_shape
                    expected_shape = shape + original_shape
                    expanded = d.expand(batch_shape=list(expanded_shape))
                    sample = expanded.sample()
                    actual_shape = expanded.sample().shape
                    self.assertEqual(actual_shape, expected_shape)
                    self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
                    self.assertEqual(expanded.batch_shape, expanded_shape)
                    try:
                        self.assertEqual(expanded.mean,
                                         d.mean.expand(expanded_shape + d.event_shape),
                                         allow_inf=True)
                        self.assertEqual(expanded.variance,
                                         d.variance.expand(expanded_shape + d.event_shape),
                                         allow_inf=True)
                    except NotImplementedError:
                        pass
    def test_distribution_subclass_expand(self):
        expand_by = torch.Size((2,))
        for Dist, params in EXAMPLES:

            class SubClass(Dist):
                pass

            for param in params:
                d = SubClass(**param)
                expanded_shape = expand_by + d.batch_shape
                original_shape = d.batch_shape + d.event_shape
                expected_shape = expand_by + original_shape
                expanded = d.expand(batch_shape=expanded_shape)
                sample = expanded.sample()
                actual_shape = expanded.sample().shape
                self.assertEqual(expanded.__class__, d.__class__)
                self.assertEqual(d.sample().shape, original_shape)
                self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
                self.assertEqual(actual_shape, expected_shape)
    def test_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(Bernoulli(p).sample((8,)).size(), (8, 3))
        self.assertFalse(Bernoulli(p).sample().requires_grad)
        self.assertEqual(Bernoulli(r).sample((8,)).size(), (8,))
        self.assertEqual(Bernoulli(r).sample().size(), ())
        self.assertEqual(Bernoulli(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(Bernoulli(s).sample().size(), ())
        self._gradcheck_log_prob(Bernoulli, (p,))

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx].detach()
            self.assertEqual(log_prob, math.log(prob if val else 1 - prob))

        self._check_log_prob(Bernoulli(p), ref_log_prob)
        self._check_log_prob(Bernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
        self.assertRaises(NotImplementedError, Bernoulli(r).rsample)

        # check entropy computation
        self.assertEqual(Bernoulli(p).entropy(), torch.tensor([0.6108, 0.5004, 0.6730]), prec=1e-4)
        self.assertEqual(Bernoulli(s).entropy(), torch.tensor(0.6108), prec=1e-4)
    def test_bernoulli_enumerate_support(self):
        examples = [
            ({"probs": [0.1]}, [[0], [1]]),
            ({"probs": [0.1, 0.9]}, [[0], [1]]),
            ({"probs": [[0.1, 0.2], [0.3, 0.4]]}, [[[0]], [[1]]]),
        ]
        self._check_enumerate_support(Bernoulli, examples)
    def test_bernoulli_3d(self):
        p = torch.full((2, 3, 5), 0.5).requires_grad_()
        self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
        self.assertEqual(Bernoulli(p).sample(sample_shape=(2, 5)).size(),
                         (2, 5, 2, 3, 5))
        self.assertEqual(Bernoulli(p).sample((2,)).size(), (2, 2, 3, 5))
    def test_geometric(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(Geometric(p).sample((8,)).size(), (8, 3))
        self.assertEqual(Geometric(1).sample(), 0)
        self.assertEqual(Geometric(1).log_prob(torch.tensor(1.)), -inf, allow_inf=True)
        self.assertEqual(Geometric(1).log_prob(torch.tensor(0.)), 0)
        self.assertFalse(Geometric(p).sample().requires_grad)
        self.assertEqual(Geometric(r).sample((8,)).size(), (8,))
        self.assertEqual(Geometric(r).sample().size(), ())
        self.assertEqual(Geometric(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(Geometric(s).sample().size(), ())
        self._gradcheck_log_prob(Geometric, (p,))
        self.assertRaises(ValueError, lambda: Geometric(0))
        self.assertRaises(NotImplementedError, Geometric(r).rsample)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_geometric_log_prob_and_entropy(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        s = 0.3

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx].detach()
            self.assertEqual(log_prob, scipy.stats.geom(prob, loc=-1).logpmf(val))

        self._check_log_prob(Geometric(p), ref_log_prob)
        self._check_log_prob(Geometric(logits=p.log() - (-p).log1p()), ref_log_prob)

        # check entropy computation
        self.assertEqual(Geometric(p).entropy(), scipy.stats.geom(p.detach().numpy(), loc=-1).entropy(), prec=1e-3)
        self.assertEqual(float(Geometric(s).entropy()), scipy.stats.geom(s, loc=-1).entropy().item(), prec=1e-3)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_geometric_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for prob in [0.01, 0.18, 0.8]:
            self._check_sampler_discrete(Geometric(prob),
                                         scipy.stats.geom(p=prob, loc=-1),
                                         'Geometric(prob={})'.format(prob))
    def test_binomial(self):
        p = torch.arange(0.05, 1, 0.1).requires_grad_()
        for total_count in [1, 2, 10]:
            self._gradcheck_log_prob(lambda p: Binomial(total_count, p), [p])
            self._gradcheck_log_prob(lambda p: Binomial(total_count, None, p.log()), [p])
        self.assertRaises(NotImplementedError, Binomial(10, p).rsample)
        self.assertRaises(NotImplementedError, Binomial(10, p).entropy)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_binomial_log_prob(self):
        probs = torch.arange(0.05, 1, 0.1)
        for total_count in [1, 2, 10]:

            def ref_log_prob(idx, x, log_prob):
                p = probs.view(-1)[idx].item()
                expected = scipy.stats.binom(total_count, p).logpmf(x)
                self.assertAlmostEqual(log_prob, expected, places=3)

            self._check_log_prob(Binomial(total_count, probs), ref_log_prob)
            logits = probs_to_logits(probs, is_binary=True)
            self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_binomial_log_prob_vectorized_count(self):
        probs = torch.tensor([0.2, 0.7])
        for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3.])),
                                    (torch.tensor([[10, 8]]), torch.tensor([[7., 3.]]))]:
            log_prob = Binomial(total_count, probs).log_prob(sample)
            expected = scipy.stats.binom(total_count.cpu().numpy(), probs.cpu().numpy()).logpmf(sample)
            self.assertAlmostEqual(log_prob, expected, places=4)
    def test_binomial_enumerate_support(self):
        examples = [
            ({"probs": [0.1], "total_count": 2}, [[0], [1], [2]]),
            ({"probs": [0.1, 0.9], "total_count": 2}, [[0], [1], [2]]),
            ({"probs": [[0.1, 0.2], [0.3, 0.4]], "total_count": 3}, [[[0]], [[1]], [[2]], [[3]]]),
        ]
        self._check_enumerate_support(Binomial, examples)
    def test_binomial_extreme_vals(self):
        total_count = 100
        bin0 = Binomial(total_count, 0)
        self.assertEqual(bin0.sample(), 0)
        self.assertAlmostEqual(bin0.log_prob(torch.tensor([0.]))[0], 0, places=3)
        self.assertEqual(float(bin0.log_prob(torch.tensor([1.])).exp()), 0, allow_inf=True)
        bin1 = Binomial(total_count, 1)
        self.assertEqual(bin1.sample(), total_count)
        self.assertAlmostEqual(bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, places=3)
        self.assertEqual(float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0, allow_inf=True)
        zero_counts = torch.zeros(torch.Size((2, 2)))
        bin2 = Binomial(zero_counts, 1)
        self.assertEqual(bin2.sample(), zero_counts)
        self.assertEqual(bin2.log_prob(zero_counts), zero_counts)
    def test_binomial_vectorized_count(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        total_count = torch.tensor([[4, 7], [3, 8]])
        bin0 = Binomial(total_count, torch.tensor(1.))
        self.assertEqual(bin0.sample(), total_count)
        bin1 = Binomial(total_count, torch.tensor(0.5))
        samples = bin1.sample(torch.Size((100000,)))
        self.assertTrue((samples <= total_count.type_as(samples)).all())
        self.assertEqual(samples.mean(dim=0), bin1.mean, prec=0.02)
        self.assertEqual(samples.var(dim=0), bin1.variance, prec=0.02)
    def test_negative_binomial(self):
        p = torch.arange(0.05, 1, 0.1).requires_grad_()
        for total_count in [1, 2, 10]:
            self._gradcheck_log_prob(lambda p: NegativeBinomial(total_count, p), [p])
            self._gradcheck_log_prob(lambda p: NegativeBinomial(total_count, None, p.log()), [p])
        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).rsample)
        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).entropy)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_negative_binomial_log_prob(self):
        probs = torch.arange(0.05, 1, 0.1)
        for total_count in [1, 2, 10]:

            def ref_log_prob(idx, x, log_prob):
                p = probs.view(-1)[idx].item()
                expected = scipy.stats.nbinom(total_count, 1 - p).logpmf(x)
                self.assertAlmostEqual(log_prob, expected, places=3)

            self._check_log_prob(NegativeBinomial(total_count, probs), ref_log_prob)
            logits = probs_to_logits(probs, is_binary=True)
            self._check_log_prob(NegativeBinomial(total_count, logits=logits), ref_log_prob)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_negative_binomial_log_prob_vectorized_count(self):
        probs = torch.tensor([0.2, 0.7])
        for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3.])),
                                    (torch.tensor([[10, 8]]), torch.tensor([[7., 3.]]))]:
            log_prob = NegativeBinomial(total_count, probs).log_prob(sample)
            expected = scipy.stats.nbinom(total_count.cpu().numpy(), 1 - probs.cpu().numpy()).logpmf(sample)
            self.assertAlmostEqual(log_prob, expected, places=4)
    def test_multinomial_1d(self):
        total_count = 10
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(Multinomial(total_count, p).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
        self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_multinomial_1d_log_prob(self):
        total_count = 10
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        dist = Multinomial(total_count, probs=p)
        x = dist.sample()
        log_prob = dist.log_prob(x)
        expected = torch.tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
        self.assertEqual(log_prob, expected)

        dist = Multinomial(total_count, logits=p.log())
        x = dist.sample()
        log_prob = dist.log_prob(x)
        expected = torch.tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
        self.assertEqual(log_prob, expected)
    def test_multinomial_2d(self):
        total_count = 10
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
        self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(Multinomial(total_count, p).sample((6,)).size(), (6, 2, 3))
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])

        # sample check for extreme value of probs
        self.assertEqual(Multinomial(total_count, s).sample(),
                         torch.tensor([[total_count, 0], [0, total_count]]))

        self.assertRaises(NotImplementedError, Multinomial(10, p).entropy)
    def test_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertTrue(is_all_nan(Categorical(p).mean))
        self.assertTrue(is_all_nan(Categorical(p).variance))
        self.assertEqual(Categorical(p).sample().size(), ())
        self.assertFalse(Categorical(p).sample().requires_grad)
        self.assertEqual(Categorical(p).sample((2, 2)).size(), (2, 2))
        self.assertEqual(Categorical(p).sample((1,)).size(), (1,))
        self._gradcheck_log_prob(Categorical, (p,))
        self.assertRaises(NotImplementedError, Categorical(p).rsample)
    def test_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(Categorical(p).mean.size(), (2,))
        self.assertEqual(Categorical(p).variance.size(), (2,))
        self.assertTrue(is_all_nan(Categorical(p).mean))
        self.assertTrue(is_all_nan(Categorical(p).variance))
        self.assertEqual(Categorical(p).sample().size(), (2,))
        self.assertEqual(Categorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2))
        self.assertEqual(Categorical(p).sample((6,)).size(), (6, 2))
        self._gradcheck_log_prob(Categorical, (p,))

        # sample check for extreme value of probs
        set_rng_seed(0)
        self.assertEqual(Categorical(s).sample(sample_shape=(2,)),
                         torch.tensor([[0, 1], [0, 1]]))

        def ref_log_prob(idx, val, log_prob):
            sample_prob = p[idx][val] / p[idx].sum()
            self.assertEqual(log_prob, math.log(sample_prob))

        self._check_log_prob(Categorical(p), ref_log_prob)
        self._check_log_prob(Categorical(logits=p.log()), ref_log_prob)

        # check entropy computation
        self.assertEqual(Categorical(p).entropy(), torch.tensor([1.0114, 1.0297]), prec=1e-4)
        self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0]))
    def test_categorical_enumerate_support(self):
        examples = [
            ({"probs": [0.1, 0.2, 0.7]}, [0, 1, 2]),
            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[0], [1]]),
        ]
        self._check_enumerate_support(Categorical, examples)
    def test_one_hot_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
        self.assertFalse(OneHotCategorical(p).sample().requires_grad)
        self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))
        self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)
    def test_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
        self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample((6,)).size(), (6, 2, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))

        dist = OneHotCategorical(p)
        x = dist.sample()
        self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))
    def test_one_hot_categorical_enumerate_support(self):
        examples = [
            ({"probs": [0.1, 0.2, 0.7]}, [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[[1, 0]], [[0, 1]]]),
        ]
        self._check_enumerate_support(OneHotCategorical, examples)
    def test_poisson_shape(self):
        rate = torch.randn(2, 3).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Poisson(rate).sample().size(), (2, 3))
        self.assertEqual(Poisson(rate).sample((7,)).size(), (7, 2, 3))
        self.assertEqual(Poisson(rate_1d).sample().size(), (1,))
        self.assertEqual(Poisson(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Poisson(2.0).sample((2,)).size(), (2,))
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_log_prob(self):
        rate = torch.randn(2, 3).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            l = rate.view(-1)[idx].detach()
            expected = scipy.stats.poisson.logpmf(x, l)
            self.assertAlmostEqual(log_prob, expected, places=3)

        set_rng_seed(0)
        self._check_log_prob(Poisson(rate), ref_log_prob)
        self._gradcheck_log_prob(Poisson, (rate,))
        self._gradcheck_log_prob(Poisson, (rate_1d,))
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for rate in [0.1, 1.0, 5.0]:
            self._check_sampler_discrete(Poisson(rate),
                                         scipy.stats.poisson(rate),
                                         'Poisson(lambda={})'.format(rate),
                                         failure_rate=1e-3)
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_gpu_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for rate in [0.12, 0.9, 4.0]:
            self._check_sampler_discrete(Poisson(torch.tensor([rate]).cuda()),
                                         scipy.stats.poisson(rate),
                                         'Poisson(lambda={}, cuda)'.format(rate),
                                         failure_rate=1e-3)
    def test_relaxed_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        temp = torch.tensor(0.67, requires_grad=True)
        self.assertEqual(RelaxedBernoulli(temp, p).sample((8,)).size(), (8, 3))
        self.assertFalse(RelaxedBernoulli(temp, p).sample().requires_grad)
        self.assertEqual(RelaxedBernoulli(temp, r).sample((8,)).size(), (8,))
        self.assertEqual(RelaxedBernoulli(temp, r).sample().size(), ())
        self.assertEqual(RelaxedBernoulli(temp, r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(RelaxedBernoulli(temp, s).sample().size(), ())
        self._gradcheck_log_prob(RelaxedBernoulli, (temp, p))
        self._gradcheck_log_prob(RelaxedBernoulli, (temp, r))

        # test that rsample is differentiable
        s = RelaxedBernoulli(temp, p).rsample()
        s.backward(torch.ones_like(s))
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_rounded_relaxed_bernoulli(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]

        class Rounded(object):
            def __init__(self, dist):
                self.dist = dist

            def sample(self, *args, **kwargs):
                return torch.round(self.dist.sample(*args, **kwargs))

        for probs, temp in product([0.1, 0.2, 0.8], [0.1, 1.0, 10.0]):
            self._check_sampler_discrete(Rounded(RelaxedBernoulli(temp, probs)),
                                         scipy.stats.bernoulli(probs),
                                         'Rounded(RelaxedBernoulli(temp={}, probs={}))'.format(temp, probs),
                                         failure_rate=1e-3)

        # at very high temperature the relaxed sample collapses to 0.5
        for probs in [0.001, 0.2, 0.999]:
            equal_probs = torch.tensor(0.5)
            dist = RelaxedBernoulli(1e10, probs)
            s = dist.rsample()
            self.assertEqual(equal_probs, s)
    def test_relaxed_one_hot_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        temp = torch.tensor(0.67, requires_grad=True)
        self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample().size(), (3,))
        self.assertFalse(RelaxedOneHotCategorical(probs=p, temperature=temp).sample().requires_grad)
        self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp, p))
    def test_relaxed_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        temp = torch.tensor([3.0], requires_grad=True)
        temp_2 = torch.tensor([0.3], requires_grad=True)
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3))
        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp, p))
        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp_2, p))
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_argmax_relaxed_categorical(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]

        class ArgMax(object):
            def __init__(self, dist):
                self.dist = dist

            def sample(self, *args, **kwargs):
                s = self.dist.sample(*args, **kwargs)
                _, idx = torch.max(s, -1)
                return idx

        class ScipyCategorical(object):
            def __init__(self, dist):
                self.dist = dist

            def pmf(self, samples):
                new_samples = np.zeros(samples.shape + self.dist.p.shape)
                new_samples[np.arange(samples.shape[0]), samples] = 1
                return self.dist.pmf(new_samples)

        # placeholder parameter grid
        for probs, temp in product([torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])],
                                   [0.1, 1.0, 10.0]):
            self._check_sampler_discrete(ArgMax(RelaxedOneHotCategorical(temp, probs)),
                                         ScipyCategorical(scipy.stats.multinomial(1, probs)),
                                         'Rounded(RelaxedOneHotCategorical(temp={}, probs={}))'.format(temp, probs),
                                         failure_rate=1e-3)

        # at very high temperature the relaxed sample collapses to uniform
        for probs in [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])]:
            equal_probs = torch.ones(probs.size()) / probs.size()[0]
            dist = RelaxedOneHotCategorical(1e10, probs)
            s = dist.rsample()
            self.assertEqual(equal_probs, s)
    def test_uniform(self):
        low = torch.zeros(5, 5, requires_grad=True)
        high = (torch.ones(5, 5) * 3).requires_grad_()
        low_1d = torch.zeros(1, requires_grad=True)
        high_1d = (torch.ones(1) * 3).requires_grad_()
        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

        # check log_prob computation when value outside range
        uniform = Uniform(low_1d, high_1d)
        above_high = torch.tensor([4.0])
        below_low = torch.tensor([-1.0])
        self.assertEqual(uniform.log_prob(above_high).item(), -inf, allow_inf=True)
        self.assertEqual(uniform.log_prob(below_low).item(), -inf, allow_inf=True)

        # check cdf computation when value outside range
        self.assertEqual(uniform.cdf(below_low).item(), 0)
        self.assertEqual(uniform.cdf(above_high).item(), 1)

        set_rng_seed(1)
        self._gradcheck_log_prob(Uniform, (low, high))
        self._gradcheck_log_prob(Uniform, (low, 1.0))
        self._gradcheck_log_prob(Uniform, (0.0, high))

        state = torch.get_rng_state()
        rand = low.new(low.size()).uniform_()
        torch.set_rng_state(state)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        self.assertEqual(low.grad, 1 - rand)
        self.assertEqual(high.grad, rand)
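        # (u = low + rand * (high - low), so du/dlow = 1 - rand and
        #  du/dhigh = rand, matching the asserted gradients.)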
    def test_cauchy(self):
        loc = torch.zeros(5, 5, requires_grad=True)
        scale = torch.ones(5, 5, requires_grad=True)
        loc_1d = torch.zeros(1, requires_grad=True)
        scale_1d = torch.ones(1, requires_grad=True)
        self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
        self.assertEqual(Cauchy(loc_1d, scale_1d).variance, inf, allow_inf=True)
        self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Cauchy(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Cauchy(0.0, 1.0).sample((1,)).size(), (1,))

        set_rng_seed(1)
        self._gradcheck_log_prob(Cauchy, (loc, scale))
        self._gradcheck_log_prob(Cauchy, (loc, 1.0))
        self._gradcheck_log_prob(Cauchy, (0.0, scale))

        state = torch.get_rng_state()
        eps = loc.new(loc.size()).cauchy_()
        torch.set_rng_state(state)
        c = Cauchy(loc, scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(loc.grad, torch.ones_like(scale))
        self.assertEqual(scale.grad, eps)
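        # (c = loc + scale * eps with eps ~ Cauchy(0, 1), so dc/dloc = 1 and
        #  dc/dscale = eps, matching the asserted gradients.)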
    def test_halfcauchy(self):
        scale = torch.ones(5, 5, requires_grad=True)
        scale_1d = torch.ones(1, requires_grad=True)
        self.assertTrue(is_all_nan(HalfCauchy(scale_1d).mean))
        self.assertEqual(HalfCauchy(scale_1d).variance, inf, allow_inf=True)
        self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
        self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(HalfCauchy(scale_1d).sample().size(), (1,))
        self.assertEqual(HalfCauchy(scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(HalfCauchy(1.0).sample((1,)).size(), (1,))

        set_rng_seed(1)
        self._gradcheck_log_prob(HalfCauchy, (scale,))
        self._gradcheck_log_prob(HalfCauchy, (1.0,))

        state = torch.get_rng_state()
        eps = scale.new(scale.size()).cauchy_().abs_()
        torch.set_rng_state(state)
        c = HalfCauchy(scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(scale.grad, eps)
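        # (c = scale * |eps| with eps ~ Cauchy(0, 1), so dc/dscale = |eps|,
        #  which is exactly the eps captured above, matching the assertion.)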
    def test_halfnormal(self):
        std = torch.randn(5, 5).abs().requires_grad_()
        std_1d = torch.randn(1, requires_grad=True)
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(HalfNormal(std).sample().size(), (5, 5))
        self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(HalfNormal(std_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(HalfNormal(std_1d).sample().size(), (1,))
        self.assertEqual(HalfNormal(.6).sample((1,)).size(), (1,))
        self.assertEqual(HalfNormal(50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of std
        set_rng_seed(1)
        self.assertEqual(HalfNormal(std_delta).sample(sample_shape=(1, 2)),
                         torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(HalfNormal, (std,))
        self._gradcheck_log_prob(HalfNormal, (1.0,))

        # check .log_prob() can broadcast.
        dist = HalfNormal(torch.ones(2, 1, 4))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_halfnormal_logprob(self):
        std = torch.randn(5, 1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            s = std.view(-1)[idx].detach()
            expected = scipy.stats.halfnorm(scale=s).logpdf(x)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(HalfNormal(std), ref_log_prob)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_halfnormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for std in [0.1, 1.0, 10.0]:
            self._check_sampler_sampler(HalfNormal(std),
                                        scipy.stats.halfnorm(scale=std),
                                        'HalfNormal(scale={})'.format(std))
    def test_lognormal(self):
        mean = torch.randn(5, 5, requires_grad=True)
        std = torch.randn(5, 5).abs().requires_grad_()
        mean_1d = torch.randn(1, requires_grad=True)
        std_1d = torch.randn(1).abs().requires_grad_()
        mean_delta = torch.tensor([1.0, 0.0])
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
        self.assertEqual(LogNormal(mean, std).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(LogNormal(mean_1d, std_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(LogNormal(mean_1d, std_1d).sample().size(), (1,))
        self.assertEqual(LogNormal(0.2, .6).sample((1,)).size(), (1,))
        self.assertEqual(LogNormal(-0.7, 50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of mean, std
        set_rng_seed(1)
        self.assertEqual(LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
                         torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(LogNormal, (mean, std))
        self._gradcheck_log_prob(LogNormal, (mean, 1.0))
        self._gradcheck_log_prob(LogNormal, (0.0, std))

        # check .log_prob() can broadcast.
        dist = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lognormal_logprob(self):
        mean = torch.randn(5, 1, requires_grad=True)
        std = torch.randn(5, 1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            m = mean.view(-1)[idx].detach()
            s = std.view(-1)[idx].detach()
            expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(LogNormal(mean, std), ref_log_prob)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lognormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(LogNormal(mean, std),
                                        scipy.stats.lognorm(scale=math.exp(mean), s=std),
                                        'LogNormal(loc={}, scale={})'.format(mean, std))
    def test_logisticnormal(self):
        mean = torch.randn(5, 5).requires_grad_()
        std = torch.randn(5, 5).abs().requires_grad_()
        mean_1d = torch.randn(1).requires_grad_()
        std_1d = torch.randn(1).requires_grad_()
        self.assertEqual(LogisticNormal(mean, std).sample().size(), (5, 6))
        self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
        self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
        self.assertEqual(LogisticNormal(0.2, .6).sample((1,)).size(), (2,))
        self.assertEqual(LogisticNormal(-0.7, 50.0).sample((1,)).size(), (2,))

        # sample check for extreme value of mean, std
        set_rng_seed(1)
        mean_delta = torch.tensor([1.0, 0.0])
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(LogisticNormal(mean_delta, std_delta).sample(),
                         torch.tensor([math.exp(1) / (1. + 1. + math.exp(1)),
                                       1. / (1. + 1. + math.exp(1)),
                                       1. / (1. + 1. + math.exp(1))]),
                         prec=1e-4)

        self._gradcheck_log_prob(LogisticNormal, (mean, std))
        self._gradcheck_log_prob(LogisticNormal, (mean, 1.0))
        self._gradcheck_log_prob(LogisticNormal, (0.0, std))
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_logisticnormal_logprob(self):
        mean = torch.randn(5, 7).requires_grad_()
        std = torch.randn(5, 7).abs().requires_grad_()

        dist = LogisticNormal(mean, std)
        assert dist.log_prob(dist.sample()).detach().cpu().numpy().shape == (5,)
    def _get_logistic_normal_ref_sampler(self, base_dist):

        def _sampler(num_samples):
            x = base_dist.rvs(num_samples)
            offset = np.log((x.shape[-1] + 1) - np.ones_like(x).cumsum(-1))
            z = 1. / (1. + np.exp(offset - x))
            z_cumprod = np.cumprod(1 - z, axis=-1)
            y1 = np.pad(z, ((0, 0), (0, 1)), mode='constant', constant_values=1.)
            y2 = np.pad(z_cumprod, ((0, 0), (1, 0)), mode='constant', constant_values=1.)
            return y1 * y2

        return _sampler
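
    # The reference sampler above applies the same stick-breaking transform as
    # LogisticNormal: with K logits x, z_i = sigmoid(x_i - log(K - i))
    # (0-indexed) breaks off a fraction of the remaining stick, giving
    # y_i = z_i * prod_{j<i}(1 - z_j) plus a final leftover component
    # prod_j(1 - z_j), so y has K + 1 entries summing to 1.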
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_logisticnormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        means = map(np.asarray, [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)])
        covs = map(np.diag, [(0.1, 0.1), (1.0, 1.0), (10.0, 10.0)])
        for mean, cov in product(means, covs):
            base_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
            ref_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
            ref_dist.rvs = self._get_logistic_normal_ref_sampler(base_dist)
            mean_th = torch.tensor(mean)
            std_th = torch.tensor(np.sqrt(np.diag(cov)))
            self._check_sampler_sampler(
                LogisticNormal(mean_th, std_th), ref_dist,
                'LogisticNormal(loc={}, scale={})'.format(mean_th, std_th),
                multivariate=True)
    def test_normal(self):
        loc = torch.randn(5, 5, requires_grad=True)
        scale = torch.randn(5, 5).abs().requires_grad_()
        loc_1d = torch.randn(1, requires_grad=True)
        scale_1d = torch.randn(1).abs().requires_grad_()
        loc_delta = torch.tensor([1.0, 0.0])
        scale_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(Normal(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Normal(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Normal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Normal(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Normal(0.2, .6).sample((1,)).size(), (1,))
        self.assertEqual(Normal(-0.7, 50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of loc, scale
        set_rng_seed(1)
        self.assertEqual(Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
                         torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(Normal, (loc, scale))
        self._gradcheck_log_prob(Normal, (loc, 1.0))
        self._gradcheck_log_prob(Normal, (0.0, scale))

        state = torch.get_rng_state()
        eps = torch.normal(torch.zeros_like(loc), torch.ones_like(scale))
        torch.set_rng_state(state)
        z = Normal(loc, scale).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(loc.grad, torch.ones_like(loc))
        self.assertEqual(scale.grad, eps)
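        # (z = loc + scale * eps with eps ~ Normal(0, 1), so dz/dloc = 1 and
        #  dz/dscale = eps, matching the asserted gradients.)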
        loc.grad.zero_()
        scale.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = loc.view(-1)[idx]
            s = scale.view(-1)[idx]
            expected = (math.exp(-(x - m) ** 2 / (2 * s ** 2)) /
                        math.sqrt(2 * math.pi * s ** 2))
            self.assertAlmostEqual(log_prob, math.log(expected), places=3)

        self._check_log_prob(Normal(loc, scale), ref_log_prob)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Normal(loc, scale),
                                        scipy.stats.norm(loc=loc, scale=scale),
                                        'Normal(mean={}, std={})'.format(loc, scale))
    def test_lowrank_multivariate_normal_shape(self):
        mean = torch.randn(5, 3, requires_grad=True)
        mean_no_batch = torch.randn(3, requires_grad=True)
        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

        # construct PSD covariance
        cov_factor = torch.randn(3, 1, requires_grad=True)
        cov_diag = torch.randn(3).abs().requires_grad_()

        # construct batch of PSD covariances
        cov_factor_batched = torch.randn(6, 5, 3, 2, requires_grad=True)
        cov_diag_batched = torch.randn(6, 5, 3).abs().requires_grad_()

        # ensure that sample, batch, event shapes all handled correctly
        self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag)
                         .sample().size(), (5, 3))
        self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
                         .sample().size(), (3,))
        self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
                         .sample().size(), (6, 5, 3))
        self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag)
                         .sample((2,)).size(), (2, 5, 3))
        self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
                         .sample((2,)).size(), (2, 3))
        self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
                         .sample((2,)).size(), (2, 6, 5, 3))
        self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag)
                         .sample((2, 7)).size(), (2, 7, 5, 3))
        self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
                         .sample((2, 7)).size(), (2, 7, 3))
        self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
                         .sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(LowRankMultivariateNormal(mean, cov_factor_batched, cov_diag_batched)
                         .sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor_batched, cov_diag_batched)
                         .sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor_batched, cov_diag_batched)
                         .sample((2, 7)).size(), (2, 7, 6, 5, 3))

        # check gradients
        self._gradcheck_log_prob(LowRankMultivariateNormal,
                                 (mean, cov_factor, cov_diag))
        self._gradcheck_log_prob(LowRankMultivariateNormal,
                                 (mean_multi_batch, cov_factor, cov_diag))
        self._gradcheck_log_prob(LowRankMultivariateNormal,
                                 (mean_multi_batch, cov_factor_batched, cov_diag_batched))
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_lowrank_multivariate_normal_log_prob(self):
        mean = torch.randn(3, requires_grad=True)
        cov_factor = torch.randn(3, 1, requires_grad=True)
        cov_diag = torch.randn(3).abs().requires_grad_()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()

        # check that logprob values match scipy logpdf
        dist1 = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        ref_dist = scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy())

        x = dist1.sample((10,))
        expected = ref_dist.logpdf(x.numpy())

        self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), places=3)

        # double-check that batched versions behave the same as unbatched
        mean = torch.randn(5, 3, requires_grad=True)
        cov_factor = torch.randn(5, 3, 2, requires_grad=True)
        cov_diag = torch.randn(5, 3).abs().requires_grad_()

        dist_batched = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        dist_unbatched = [LowRankMultivariateNormal(mean[i], cov_factor[i], cov_diag[i])
                          for i in range(mean.size(0))]

        x = dist_batched.sample((10,))
        batched_prob = dist_batched.log_prob(x)
        unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()

        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
        self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_lowrank_multivariate_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(5, requires_grad=True)
        cov_factor = torch.randn(5, 1, requires_grad=True)
        cov_diag = torch.randn(5).abs().requires_grad_()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()

        self._check_sampler_sampler(LowRankMultivariateNormal(mean, cov_factor, cov_diag),
                                    scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
                                    'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
                                    .format(mean, cov_factor, cov_diag),
                                    multivariate=True)
    def test_lowrank_multivariate_normal_properties(self):
        loc = torch.randn(5)
        cov_factor = torch.randn(5, 2)
        cov_diag = torch.randn(5).abs()
        cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
        m1 = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
        m2 = MultivariateNormal(loc=loc, covariance_matrix=cov)
        self.assertEqual(m1.mean, m2.mean)
        self.assertEqual(m1.variance, m2.variance)
        self.assertEqual(m1.covariance_matrix, m2.covariance_matrix)
        self.assertEqual(m1.scale_tril, m2.scale_tril)
        self.assertEqual(m1.precision_matrix, m2.precision_matrix)
        self.assertEqual(m1.entropy(), m2.entropy())
    def test_lowrank_multivariate_normal_moments(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(5)
        cov_factor = torch.randn(5, 2)
        cov_diag = torch.randn(5).abs()
        d = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
        samples = d.rsample((100000,))
        empirical_mean = samples.mean(0)
        self.assertEqual(d.mean, empirical_mean, prec=0.01)
        empirical_var = samples.var(0)
        self.assertEqual(d.variance, empirical_var, prec=0.02)
    def test_multivariate_normal_shape(self):
        mean = torch.randn(5, 3, requires_grad=True)
        mean_no_batch = torch.randn(3, requires_grad=True)
        mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)

        # construct PSD covariance
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.cholesky(cov, upper=False).requires_grad_()

        # construct batch of PSD covariances
        tmp = torch.randn(6, 5, 3, 10)
        cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
        prec_batched = [C.inverse() for C in cov_batched.view((-1, 3, 3))]
        prec_batched = torch.stack(prec_batched).view(cov_batched.shape)
        scale_tril_batched = [torch.cholesky(C, upper=False) for C in cov_batched.view((-1, 3, 3))]
        scale_tril_batched = torch.stack(scale_tril_batched).view(cov_batched.shape)

        # ensure that sample, batch, event shapes all handled correctly
        self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean_no_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean_multi_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, precision_matrix=prec).sample((2, 7)).size(), (2, 7, 5, 3))
        self.assertEqual(MultivariateNormal(mean, precision_matrix=prec_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
        self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3))
        self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))

        # check gradients
        self._gradcheck_log_prob(MultivariateNormal, (mean, cov))
        self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov))
        self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov_batched))
        self._gradcheck_log_prob(MultivariateNormal, (mean, None, prec))
        self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, prec_batched))
        self._gradcheck_log_prob(MultivariateNormal, (mean, None, None, scale_tril))
        self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, None, scale_tril_batched))
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_multivariate_normal_log_prob(self):
        mean = torch.randn(3, requires_grad=True)
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.cholesky(cov, upper=False).requires_grad_()

        # check that logprob values match scipy logpdf, and that covariance,
        # precision_matrix, and scale_tril parameterizations are equivalent
        dist1 = MultivariateNormal(mean, cov)
        dist2 = MultivariateNormal(mean, precision_matrix=prec)
        dist3 = MultivariateNormal(mean, scale_tril=scale_tril)
        ref_dist = scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy())

        x = dist1.sample((10,))
        expected = ref_dist.logpdf(x.numpy())

        self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), places=3)
        self.assertAlmostEqual(0.0, np.mean((dist2.log_prob(x).detach().numpy() - expected)**2), places=3)
        self.assertAlmostEqual(0.0, np.mean((dist3.log_prob(x).detach().numpy() - expected)**2), places=3)

        # double-check that batched versions behave the same as unbatched
        mean = torch.randn(5, 3, requires_grad=True)
        tmp = torch.randn(5, 3, 10)
        cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()

        dist_batched = MultivariateNormal(mean, cov)
        dist_unbatched = [MultivariateNormal(mean[i], cov[i]) for i in range(mean.size(0))]

        x = dist_batched.sample((10,))
        batched_prob = dist_batched.log_prob(x)
        unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()

        self.assertEqual(batched_prob.shape, unbatched_prob.shape)
        self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_multivariate_normal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        mean = torch.randn(3, requires_grad=True)
        tmp = torch.randn(3, 10)
        cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
        prec = cov.inverse().requires_grad_()
        scale_tril = torch.cholesky(cov, upper=False).requires_grad_()

        self._check_sampler_sampler(MultivariateNormal(mean, cov),
                                    scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
                                    'MultivariateNormal(loc={}, cov={})'.format(mean, cov),
                                    multivariate=True)
        self._check_sampler_sampler(MultivariateNormal(mean, precision_matrix=prec),
                                    scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
                                    'MultivariateNormal(loc={}, prec={})'.format(mean, prec),
                                    multivariate=True)
        self._check_sampler_sampler(MultivariateNormal(mean, scale_tril=scale_tril),
                                    scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
                                    'MultivariateNormal(loc={}, scale_tril={})'.format(mean, scale_tril),
                                    multivariate=True)

    def test_multivariate_normal_properties(self):
        loc = torch.randn(5)
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
        m = MultivariateNormal(loc=loc, scale_tril=scale_tril)
        self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
        self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0]))
        self.assertEqual(m.scale_tril, torch.cholesky(m.covariance_matrix, upper=False))

    def test_multivariate_normal_moments(self):
        mean = torch.randn(5)
        scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
        d = MultivariateNormal(mean, scale_tril=scale_tril)
        samples = d.rsample((100000,))
        empirical_mean = samples.mean(0)
        self.assertEqual(d.mean, empirical_mean, prec=0.01)
        empirical_var = samples.var(0)
        self.assertEqual(d.variance, empirical_var, prec=0.05)

    def test_exponential(self):
        rate = torch.randn(5, 5).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Exponential(rate).sample().size(), (5, 5))
        self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
        self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
        self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))

        self._gradcheck_log_prob(Exponential, (rate,))
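
        # rsample() reparameterizes as z = eps / rate with unit-rate noise
        # eps ~ Exponential(1), so with eps held fixed d(z)/d(rate) = -eps / rate**2.
        # Replaying the RNG state recovers the same eps that rsample() consumed,
        # which makes this gradient check exact.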
        state = torch.get_rng_state()
        eps = rate.new(rate.size()).exponential_()
        torch.set_rng_state(state)
        z = Exponential(rate).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(rate.grad, -eps / rate**2)
        rate.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = rate.view(-1)[idx]
            expected = math.log(m) - m * x
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Exponential(rate), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_exponential_sample(self):
        for rate in [1e-5, 1.0, 10.]:
            self._check_sampler_sampler(Exponential(rate),
                                        scipy.stats.expon(scale=1. / rate),
                                        'Exponential(rate={})'.format(rate))

    def test_laplace(self):
        loc = torch.randn(5, 5, requires_grad=True)
        scale = torch.randn(5, 5).abs().requires_grad_()
        loc_1d = torch.randn(1, requires_grad=True)
        scale_1d = torch.randn(1, requires_grad=True)
        loc_delta = torch.tensor([1.0, 0.0])
        scale_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Laplace(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Laplace(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Laplace(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Laplace(0.2, .6).sample((1,)).size(), (1,))
        self.assertEqual(Laplace(-0.7, 50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of scale: a near-zero scale collapses
        # samples onto the location
        self.assertEqual(Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
                         torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(Laplace, (loc, scale))
        self._gradcheck_log_prob(Laplace, (loc, 1.0))
        self._gradcheck_log_prob(Laplace, (0.0, scale))
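
        # Laplace rsample() uses the inverse CDF z = loc - scale * sign(u) * log1p(-|u|)
        # with u ~ Uniform(-1, 1). Re-drawing the same underlying noise as
        # eps ~ Uniform(-0.5, 0.5) gives u = 2 * eps, so with the noise fixed
        # dz/d(loc) = 1 and dz/d(scale) = -sign(eps) * log1p(-2 * |eps|),
        # which is exactly what the replayed-RNG check below asserts.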
        state = torch.get_rng_state()
        eps = torch.ones_like(loc).uniform_(-.5, .5)
        torch.set_rng_state(state)
        z = Laplace(loc, scale).rsample()
        z.backward(torch.ones_like(z))
        self.assertEqual(loc.grad, torch.ones_like(loc))
        self.assertEqual(scale.grad, -eps.sign() * torch.log1p(-2 * eps.abs()))
        loc.grad.zero_()
        scale.grad.zero_()
        self.assertEqual(z.size(), (5, 5))

        def ref_log_prob(idx, x, log_prob):
            m = loc.view(-1)[idx]
            s = scale.view(-1)[idx]
            expected = (-math.log(2 * s) - abs(x - m) / s)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Laplace(loc, scale), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_laplace_sample(self):
        for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Laplace(loc, scale),
                                        scipy.stats.laplace(loc=loc, scale=scale),
                                        'Laplace(loc={}, scale={})'.format(loc, scale))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_shape(self):
        alpha = torch.randn(2, 3).exp().requires_grad_()
        beta = torch.randn(2, 3).exp().requires_grad_()
        alpha_1d = torch.randn(1).exp().requires_grad_()
        beta_1d = torch.randn(1).exp().requires_grad_()
        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            a = alpha.view(-1)[idx].detach()
            b = beta.view(-1)[idx].detach()
            expected = scipy.stats.gamma.logpdf(x, a, scale=1 / b)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_gpu_shape(self):
        alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
        beta = torch.randn(2, 3).cuda().exp().requires_grad_()
        alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
        beta_1d = torch.randn(1).cuda().exp().requires_grad_()
        self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
        self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
        self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
        self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            a = alpha.view(-1)[idx].detach().cpu()
            b = beta.view(-1)[idx].detach().cpu()
            expected = scipy.stats.gamma.logpdf(x.cpu(), a, scale=1 / b)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma_sample(self):
        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Gamma(alpha, beta),
                                        scipy.stats.gamma(alpha, scale=1.0 / beta),
                                        'Gamma(concentration={}, rate={})'.format(alpha, beta))

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @skipIfRocm
    def test_gamma_gpu_sample(self):
        for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
            self._check_sampler_sampler(Gamma(a, b),
                                        scipy.stats.gamma(alpha, scale=1.0 / beta),
                                        'Gamma(alpha={}, beta={})'.format(alpha, beta))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_pareto(self):
        scale = torch.randn(2, 3).abs().requires_grad_()
        alpha = torch.randn(2, 3).abs().requires_grad_()
        scale_1d = torch.randn(1).abs().requires_grad_()
        alpha_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Pareto(scale_1d, 0.5).mean, inf, allow_inf=True)
        self.assertEqual(Pareto(scale_1d, 0.5).variance, inf, allow_inf=True)
        self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
        self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Pareto(scale_1d, alpha_1d).sample().size(), (1,))
        self.assertEqual(Pareto(1.0, 1.0).sample().size(), ())
        self.assertEqual(Pareto(1.0, 1.0).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            s = scale.view(-1)[idx].detach()
            a = alpha.view(-1)[idx].detach()
            expected = scipy.stats.pareto.logpdf(x, a, scale=s)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Pareto(scale, alpha), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_pareto_sample(self):
        for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Pareto(scale, alpha),
                                        scipy.stats.pareto(alpha, scale=scale),
                                        'Pareto(scale={}, alpha={})'.format(scale, alpha))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gumbel(self):
        loc = torch.randn(2, 3, requires_grad=True)
        scale = torch.randn(2, 3).abs().requires_grad_()
        loc_1d = torch.randn(1, requires_grad=True)
        scale_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
        self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ())
        self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            l = loc.view(-1)[idx].detach()
            s = scale.view(-1)[idx].detach()
            expected = scipy.stats.gumbel_r.logpdf(x, loc=l, scale=s)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Gumbel(loc, scale), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gumbel_sample(self):
        for loc, scale in product([-5.0, -1.0, -0.1, 0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Gumbel(loc, scale),
                                        scipy.stats.gumbel_r(loc=loc, scale=scale),
                                        'Gumbel(loc={}, scale={})'.format(loc, scale))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_fishersnedecor(self):
        df1 = torch.randn(2, 3).abs().requires_grad_()
        df2 = torch.randn(2, 3).abs().requires_grad_()
        df1_1d = torch.randn(1).abs()
        df2_1d = torch.randn(1).abs()
        self.assertTrue(is_all_nan(FisherSnedecor(1, 2).mean))       # mean is undefined for df2 <= 2
        self.assertTrue(is_all_nan(FisherSnedecor(1, 4).variance))   # variance is undefined for df2 <= 4
        self.assertEqual(FisherSnedecor(df1, df2).sample().size(), (2, 3))
        self.assertEqual(FisherSnedecor(df1, df2).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample().size(), (1,))
        self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(FisherSnedecor(1.0, 1.0).sample().size(), ())
        self.assertEqual(FisherSnedecor(1.0, 1.0).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            f1 = df1.view(-1)[idx].detach()
            f2 = df2.view(-1)[idx].detach()
            expected = scipy.stats.f.logpdf(x, f1, f2)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_fishersnedecor_sample(self):
        for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
            self._check_sampler_sampler(FisherSnedecor(df1, df2),
                                        scipy.stats.f(df1, df2),
                                        'FisherSnedecor(df1={}, df2={})'.format(df1, df2))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2_shape(self):
        df = torch.randn(2, 3).exp().requires_grad_()
        df_1d = torch.randn(1).exp().requires_grad_()
        self.assertEqual(Chi2(df).sample().size(), (2, 3))
        self.assertEqual(Chi2(df).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Chi2(df_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Chi2(df_1d).sample().size(), (1,))
        self.assertEqual(Chi2(torch.tensor(0.5, requires_grad=True)).sample().size(), ())
        self.assertEqual(Chi2(0.5).sample().size(), ())
        self.assertEqual(Chi2(0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            d = df.view(-1)[idx].detach()
            expected = scipy.stats.chi2.logpdf(x, d)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(Chi2(df), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2_sample(self):
        for df in [0.1, 1.0, 5.0]:
            self._check_sampler_sampler(Chi2(df),
                                        scipy.stats.chi2(df),
                                        'Chi2(df={})'.format(df))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_studentT(self):
        df = torch.randn(2, 3).exp().requires_grad_()
        df_1d = torch.randn(1).exp().requires_grad_()
        self.assertTrue(is_all_nan(StudentT(1).mean))      # mean is undefined for df <= 1
        self.assertTrue(is_all_nan(StudentT(1).variance))  # variance is undefined for df <= 1
        self.assertEqual(StudentT(2).variance, inf, allow_inf=True)  # variance is infinite for 1 < df <= 2
        self.assertEqual(StudentT(df).sample().size(), (2, 3))
        self.assertEqual(StudentT(df).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(StudentT(df_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(StudentT(df_1d).sample().size(), (1,))
        self.assertEqual(StudentT(torch.tensor(0.5, requires_grad=True)).sample().size(), ())
        self.assertEqual(StudentT(0.5).sample().size(), ())
        self.assertEqual(StudentT(0.5).sample((1,)).size(), (1,))

        def ref_log_prob(idx, x, log_prob):
            d = df.view(-1)[idx].detach()
            expected = scipy.stats.t.logpdf(x, d)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(StudentT(df), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_studentT_sample(self):
        for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(StudentT(df=df, loc=loc, scale=scale),
                                        scipy.stats.t(df=df, loc=loc, scale=scale),
                                        'StudentT(df={}, loc={}, scale={})'.format(df, loc, scale))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_studentT_log_prob(self):
        num_samples = 10
        for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
            dist = StudentT(df=df, loc=loc, scale=scale)
            x = dist.sample((num_samples,))
            actual_log_prob = dist.log_prob(x)
            for i in range(num_samples):
                expected_log_prob = scipy.stats.t.logpdf(x[i], df=df, loc=loc, scale=scale)
                self.assertAlmostEqual(float(actual_log_prob[i]), float(expected_log_prob), places=3)

    def test_dirichlet_shape(self):
        alpha = torch.randn(2, 3).exp().requires_grad_()
        alpha_1d = torch.randn(4).exp().requires_grad_()
        self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
        self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
        self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_log_prob(self):
        num_samples = 10
        alpha = torch.exp(torch.randn(5))
        dist = Dirichlet(alpha)
        x = dist.sample((num_samples,))
        actual_log_prob = dist.log_prob(x)
        for i in range(num_samples):
            expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
            self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_sample(self):
        alpha = torch.exp(torch.randn(3))
        self._check_sampler_sampler(Dirichlet(alpha),
                                    scipy.stats.dirichlet(alpha.numpy()),
                                    'Dirichlet(alpha={})'.format(list(alpha)),
                                    multivariate=True)

    def test_beta_shape(self):
        con1 = torch.randn(2, 3).exp().requires_grad_()
        con0 = torch.randn(2, 3).exp().requires_grad_()
        con1_1d = torch.randn(4).exp().requires_grad_()
        con0_1d = torch.randn(4).exp().requires_grad_()
        self.assertEqual(Beta(con1, con0).sample().size(), (2, 3))
        self.assertEqual(Beta(con1, con0).sample((5,)).size(), (5, 2, 3))
        self.assertEqual(Beta(con1_1d, con0_1d).sample().size(), (4,))
        self.assertEqual(Beta(con1_1d, con0_1d).sample((1,)).size(), (1, 4))
        self.assertEqual(Beta(0.1, 0.3).sample().size(), ())
        self.assertEqual(Beta(0.1, 0.3).sample((5,)).size(), (5,))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_log_prob(self):
        for _ in range(100):
            con1 = np.exp(np.random.normal())
            con0 = np.exp(np.random.normal())
            dist = Beta(con1, con0)
            x = dist.sample()
            actual_log_prob = dist.log_prob(x).sum()
            expected_log_prob = scipy.stats.beta.logpdf(x, con1, con0)
            self.assertAlmostEqual(float(actual_log_prob), float(expected_log_prob), places=3, allow_inf=True)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_sample(self):
        for con1, con0 in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
            self._check_sampler_sampler(Beta(con1, con0),
                                        scipy.stats.beta(con1, con0),
                                        'Beta(alpha={}, beta={})'.format(con1, con0))
        # Check that very small concentrations do not produce NANs or
        # out-of-support samples.
        for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
            x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
            self.assertTrue(np.isfinite(x) and x > 0,
                            'Invalid Beta.sample(): {}'.format(x))

    def test_beta_underflow(self):
        # For low concentration values the Gamma samples used to construct a
        # Beta sample can underflow in float32, producing spurious samples at
        # exactly 0 or 1; check that both dtypes stay inside the open interval.
        num_samples = 50000
        for dtype in [torch.float, torch.double]:
            conc = torch.tensor(1e-2, dtype=dtype)
            beta_samples = Beta(conc, conc).sample([num_samples])
            self.assertEqual((beta_samples == 0).sum(), 0)
            self.assertEqual((beta_samples == 1).sum(), 0)
            # assert support is concentrated around 0 and 1
            frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
            frac_ones = float((beta_samples > 0.9).sum()) / num_samples
            self.assertEqual(frac_zeros, 0.5, 0.05)
            self.assertEqual(frac_ones, 0.5, 0.05)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_beta_underflow_gpu(self):
        num_samples = 50000
        conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
        beta_samples = Beta(conc, conc).sample([num_samples])
        self.assertEqual((beta_samples == 0).sum(), 0)
        self.assertEqual((beta_samples == 1).sum(), 0)
        # assert support is concentrated around 0 and 1
        frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
        frac_ones = float((beta_samples > 0.9).sum()) / num_samples
        self.assertEqual(frac_zeros, 0.5, 0.12)
        self.assertEqual(frac_ones, 0.5, 0.12)

    def test_independent_shape(self):
        for Dist, params in EXAMPLES:
            for param in params:
                base_dist = Dist(**param)
                x = base_dist.sample()
                base_log_prob_shape = base_dist.log_prob(x).shape
                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                    indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                    indep_log_prob_shape = base_log_prob_shape[:len(base_log_prob_shape) - reinterpreted_batch_ndims]
                    self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
                    self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
                    self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
                    if indep_dist.has_rsample:
                        self.assertEqual(indep_dist.rsample().shape, base_dist.sample().shape)
                    try:
                        self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape)
                        self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
                    except NotImplementedError:
                        pass
                    try:
                        self.assertEqual(indep_dist.variance.shape, base_dist.variance.shape)
                    except NotImplementedError:
                        pass
                    try:
                        self.assertEqual(indep_dist.entropy().shape, indep_log_prob_shape)
                    except NotImplementedError:
                        pass

    def test_independent_expand(self):
        for Dist, params in EXAMPLES:
            for param in params:
                base_dist = Dist(**param)
                for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
                    for s in [torch.Size(), torch.Size((2,)), torch.Size((2, 3))]:
                        indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
                        expanded_shape = s + indep_dist.batch_shape
                        expanded = indep_dist.expand(expanded_shape)
                        expanded_sample = expanded.sample()
                        expected_shape = expanded_shape + indep_dist.event_shape
                        self.assertEqual(expanded_sample.shape, expected_shape)
                        self.assertEqual(expanded.log_prob(expanded_sample),
                                         indep_dist.log_prob(expanded_sample))
                        self.assertEqual(expanded.event_shape, indep_dist.event_shape)
                        self.assertEqual(expanded.batch_shape, expanded_shape)

    def test_cdf_icdf_inverse(self):
        # Tests the invertibility property: icdf(cdf(x)) == x
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                samples = dist.sample(sample_shape=(20,))
                try:
                    cdf = dist.cdf(samples)
                    actual = dist.icdf(cdf)
                except NotImplementedError:
                    continue
                rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
                self.assertLess(rel_error.max(), 1e-4, msg='\n'.join([
                    '{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)),
                    'x = {}'.format(samples),
                    'cdf(x) = {}'.format(cdf),
                    'icdf(cdf(x)) = {}'.format(actual),
                ]))

    def test_cdf_log_prob(self):
        # Tests if differentiating the CDF gives the PDF at the sampled value
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                samples = dist.sample()
                if samples.dtype.is_floating_point:
                    samples.requires_grad_()
                try:
                    cdfs = dist.cdf(samples)
                    pdfs = dist.log_prob(samples).exp()
                except NotImplementedError:
                    continue
                cdfs_derivative = grad(cdfs.sum(), [samples])[0]
                self.assertEqual(cdfs_derivative, pdfs, message='\n'.join([
                    '{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1, len(params)),
                    'x = {}'.format(samples),
                    'cdf = {}'.format(cdfs),
                    'pdf = {}'.format(pdfs),
                    'grad(cdf) = {}'.format(cdfs_derivative),
                ]))

    def test_valid_parameter_broadcasting(self):
        # Distribution parameters of different shapes should broadcast to a
        # common batch_shape, e.g. loc (5, 1) against scale (3,) gives (5, 3).
        valid_examples = [
            (StudentT(df=1., loc=torch.zeros(5, 1), scale=torch.ones(3)),
             (5, 3)),
        ]
        for dist, expected_size in valid_examples:
            actual_size = dist.sample().size()
            self.assertEqual(actual_size, expected_size,
                             '{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))

            sample_shape = torch.Size((2,))
            expected_size = sample_shape + expected_size
            actual_size = dist.sample(sample_shape).size()
            self.assertEqual(actual_size, expected_size,
                             '{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))

    def test_invalid_parameter_broadcasting(self):
        # Incompatible parameter shapes must raise rather than silently broadcast.
        invalid_examples = [
            # example pair: shapes (4,) and (3,) cannot broadcast together
            (Normal, {'loc': torch.zeros(4), 'scale': torch.ones(3)}),
        ]
        for dist, kwargs in invalid_examples:
            self.assertRaises(RuntimeError, dist, **kwargs)
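

# The rsample() gradient tests below rely on implicit differentiation of the
# CDF: if x = F^{-1}(u; theta) is a reparameterized sample with the uniform
# noise u held fixed, then along a level set of F,
#     dx/dtheta = -(dF/dtheta) / (dF/dx) = -(dF/dtheta) / pdf(x).
# Each test estimates dF/dtheta with a central finite difference of the scipy
# CDF and compares it against the gradient autograd computes through rsample().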
class TestRsample(TestCase):

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_gamma(self):
        num_samples = 100
        for alpha in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
            alphas = torch.tensor([alpha] * num_samples, dtype=torch.float, requires_grad=True)
            betas = alphas.new_ones(num_samples)
            x = Gamma(alphas, betas).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = alphas.grad[ind].numpy()
            # Compare with expected gradient dx/dalpha along a level set of the cdf.
            cdf = scipy.stats.gamma.cdf
            pdf = scipy.stats.gamma.pdf
            eps = 0.01 * alpha / (1.0 + alpha ** 0.5)
            cdf_alpha = (cdf(x, alpha + eps) - cdf(x, alpha - eps)) / (2 * eps)
            cdf_x = pdf(x, alpha)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(np.max(rel_error), 0.0005, '\n'.join([
                'Bad gradient dx/dalpha for x ~ Gamma({}, 1)'.format(alpha),
                'expected {}'.format(expected_grad),
                'actual {}'.format(actual_grad),
                'rel error {}'.format(rel_error),
                'max error {}'.format(rel_error.max()),
                'at alpha={}, x={}'.format(alpha, x[rel_error.argmax()]),
            ]))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_chi2(self):
        num_samples = 100
        for df in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
            dfs = torch.tensor([df] * num_samples, dtype=torch.float, requires_grad=True)
            x = Chi2(dfs).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = dfs.grad[ind].numpy()
            # Compare with expected gradient dx/ddf along a level set of the cdf.
            cdf = scipy.stats.chi2.cdf
            pdf = scipy.stats.chi2.pdf
            eps = 0.01 * df / (1.0 + df ** 0.5)
            cdf_df = (cdf(x, df + eps) - cdf(x, df - eps)) / (2 * eps)
            cdf_x = pdf(x, df)
            expected_grad = -cdf_df / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(np.max(rel_error), 0.001, '\n'.join([
                'Bad gradient dx/ddf for x ~ Chi2({})'.format(df),
                'expected {}'.format(expected_grad),
                'actual {}'.format(actual_grad),
                'rel error {}'.format(rel_error),
                'max error {}'.format(rel_error.max()),
            ]))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_dirichlet_on_diagonal(self):
        num_samples = 20
        grid = [1e-1, 1e0, 1e1]
        for a0, a1, a2 in product(grid, grid, grid):
            alphas = torch.tensor([[a0, a1, a2]] * num_samples, dtype=torch.float, requires_grad=True)
            x = Dirichlet(alphas).rsample()[:, 0]
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = alphas.grad[ind].numpy()[:, 0]
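            # The first coordinate of a Dirichlet([a0, a1, a2]) sample is
            # marginally Beta(a0, a1 + a2), so the implicit-gradient reference
            # can be computed from the Beta cdf and pdf.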
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            alpha, beta = a0, a1 + a2
            eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
            cdf_alpha = (cdf(x, alpha + eps, beta) - cdf(x, alpha - eps, beta)) / (2 * eps)
            cdf_x = pdf(x, alpha, beta)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(np.max(rel_error), 0.001, '\n'.join([
                'Bad gradient dx[0]/dalpha[0] for Dirichlet([{}, {}, {}])'.format(a0, a1, a2),
                'expected {}'.format(expected_grad),
                'actual {}'.format(actual_grad),
                'rel error {}'.format(rel_error),
                'max error {}'.format(rel_error.max()),
                'at x={}'.format(x[rel_error.argmax()]),
            ]))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_wrt_alpha(self):
        num_samples = 20
        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
        for con1, con0 in product(grid, grid):
            con1s = torch.tensor([con1] * num_samples, dtype=torch.float, requires_grad=True)
            con0s = con1s.new_tensor([con0] * num_samples)
            x = Beta(con1s, con0s).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = con1s.grad[ind].numpy()
            # Compare with expected gradient dx/dcon1 along a level set of the cdf.
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            eps = 0.01 * con1 / (1.0 + np.sqrt(con1))
            cdf_alpha = (cdf(x, con1 + eps, con0) - cdf(x, con1 - eps, con0)) / (2 * eps)
            cdf_x = pdf(x, con1, con0)
            expected_grad = -cdf_alpha / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(np.max(rel_error), 0.005, '\n'.join([
                'Bad gradient dx/dcon1 for x ~ Beta({}, {})'.format(con1, con0),
                'expected {}'.format(expected_grad),
                'actual {}'.format(actual_grad),
                'rel error {}'.format(rel_error),
                'max error {}'.format(rel_error.max()),
                'at x = {}'.format(x[rel_error.argmax()]),
            ]))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_beta_wrt_beta(self):
        num_samples = 20
        grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
        for con1, con0 in product(grid, grid):
            con0s = torch.tensor([con0] * num_samples, dtype=torch.float, requires_grad=True)
            con1s = con0s.new_tensor([con1] * num_samples)
            x = Beta(con1s, con0s).rsample()
            x.sum().backward()
            x, ind = x.sort()
            x = x.detach().numpy()
            actual_grad = con0s.grad[ind].numpy()
            # Compare with expected gradient dx/dcon0 along a level set of the cdf.
            cdf = scipy.stats.beta.cdf
            pdf = scipy.stats.beta.pdf
            eps = 0.01 * con0 / (1.0 + np.sqrt(con0))
            cdf_beta = (cdf(x, con1, con0 + eps) - cdf(x, con1, con0 - eps)) / (2 * eps)
            cdf_x = pdf(x, con1, con0)
            expected_grad = -cdf_beta / cdf_x
            rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
            self.assertLess(np.max(rel_error), 0.005, '\n'.join([
                'Bad gradient dx/dcon0 for x ~ Beta({}, {})'.format(con1, con0),
                'expected {}'.format(expected_grad),
                'actual {}'.format(actual_grad),
                'rel error {}'.format(rel_error),
                'max error {}'.format(rel_error.max()),
                'at x = {!r}'.format(x[rel_error.argmax()]),
            ]))

    def test_dirichlet_multivariate(self):
        alpha_crit = 0.25 * (5.0 ** 0.5 - 1.0)
        num_samples = 100000
        for shift in [-0.1, -0.05, -0.01, 0.0, 0.01, 0.05, 0.10]:
            alpha = alpha_crit + shift
            alpha = torch.tensor([alpha], dtype=torch.float, requires_grad=True)
            alpha_vec = torch.cat([alpha, alpha, alpha.new([1])])
            z = Dirichlet(alpha_vec.expand(num_samples, 3)).rsample()
            mean_z3 = 1.0 / (2.0 * alpha + 1.0)
            loss = torch.pow(z[:, 2] - mean_z3, 2.0).mean()
            actual_grad = grad(loss, [alpha])[0]
            # Compute the expected gradient in closed form.
            num = 1.0 - 2.0 * alpha - 4.0 * alpha**2
            den = (1.0 + alpha)**2 * (1.0 + 2.0 * alpha)**3
            expected_grad = num / den
            self.assertEqual(actual_grad, expected_grad, 0.002, '\n'.join([
                "alpha = alpha_c + %.2g" % shift,
                "expected_grad: %.5g" % expected_grad,
                "actual_grad: %.5g" % actual_grad,
                "error = %.2g" % torch.abs(expected_grad - actual_grad).max(),
            ]))
    def test_dirichlet_tangent_field(self):
        num_samples = 20
        alpha_grid = [0.5, 1.0, 2.0]

        # v = dx/dalpha[0] is the reparameterized gradient, aka the tangent field.
        def compute_v(x, alpha):
            return torch.stack([
                _Dirichlet_backward(x, alpha, torch.eye(3, 3)[i].expand_as(x))[:, 0]
                for i in range(3)
            ], dim=-1)

        for a1, a2, a3 in product(alpha_grid, alpha_grid, alpha_grid):
            alpha = torch.tensor([a1, a2, a3], requires_grad=True).expand(num_samples, 3)
            x = Dirichlet(alpha).rsample()
            dlogp_da = grad([Dirichlet(alpha).log_prob(x.detach()).sum()],
                            [alpha], retain_graph=True)[0][:, 0]
            dlogp_dx = grad([Dirichlet(alpha.detach()).log_prob(x).sum()],
                            [x], retain_graph=True)[0]
            v = torch.stack([grad([x[:, i].sum()], [alpha], retain_graph=True)[0][:, 0]
                             for i in range(3)], dim=-1)
            self.assertEqual(compute_v(x, alpha), v, message='Bug in compute_v() helper')
            # Estimate div(v) by central finite differences along two random
            # tangent directions of the simplex.
            dx = torch.randn(2, num_samples, 3)
            dx -= dx.mean(-1, True)  # project onto the tangent space of the simplex
            dx /= dx.norm(2, -1, True)
            eps = 1e-2 * x.min(-1, True)[0]  # avoid overstepping the boundary
            dv0 = (compute_v(x + eps * dx[0], alpha) - compute_v(x - eps * dx[0], alpha)) / (2 * eps)
            dv1 = (compute_v(x + eps * dx[1], alpha) - compute_v(x - eps * dx[1], alpha)) / (2 * eps)
            div_v = (dv0 * dx[0] + dv1 * dx[1]).sum(-1)
            # The continuity equation, divided through by p so that everything
            # stays in log space.
            error = dlogp_da + (dlogp_dx * v).sum(-1) + div_v
            self.assertLess(torch.abs(error).max(), 0.005, '\n'.join([
                'Dirichlet([{}, {}, {}]) gradient violates continuity equation:'.format(a1, a2, a3),
                'error = {}'.format(error),
            ]))


class TestDistributionShapes(TestCase):
    def setUp(self):
        super(TestCase, self).setUp()
        self.scalar_sample = 1
        self.tensor_sample_1 = torch.ones(3, 2)
        self.tensor_sample_2 = torch.ones(3, 2, 3)
        Distribution.set_default_validate_args(True)

    def tearDown(self):
        super(TestCase, self).tearDown()
        Distribution.set_default_validate_args(False)
    def test_entropy_shape(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(validate_args=False, **param)
                try:
                    actual_shape = dist.entropy().size()
                    expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
                    message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
                        Dist.__name__, i + 1, len(params), expected_shape, actual_shape)
                    self.assertEqual(actual_shape, expected_shape, message=message)
                except NotImplementedError:
                    continue

    def test_bernoulli_shape_scalar_params(self):
        bernoulli = Bernoulli(0.3)
        self.assertEqual(bernoulli._batch_shape, torch.Size())
        self.assertEqual(bernoulli._event_shape, torch.Size())
        self.assertEqual(bernoulli.sample().size(), torch.Size())
        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, bernoulli.log_prob, self.scalar_sample)
        self.assertEqual(bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_bernoulli_shape_tensor_params(self):
        bernoulli = Bernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(bernoulli._batch_shape, torch.Size((3, 2)))
        self.assertEqual(bernoulli._event_shape, torch.Size(()))
        self.assertEqual(bernoulli.sample().size(), torch.Size((3, 2)))
        self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, bernoulli.log_prob, self.tensor_sample_2)
        self.assertEqual(bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))

    def test_geometric_shape_scalar_params(self):
        geometric = Geometric(0.3)
        self.assertEqual(geometric._batch_shape, torch.Size())
        self.assertEqual(geometric._event_shape, torch.Size())
        self.assertEqual(geometric.sample().size(), torch.Size())
        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, geometric.log_prob, self.scalar_sample)
        self.assertEqual(geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(geometric.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_geometric_shape_tensor_params(self):
        geometric = Geometric(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(geometric._batch_shape, torch.Size((3, 2)))
        self.assertEqual(geometric._event_shape, torch.Size(()))
        self.assertEqual(geometric.sample().size(), torch.Size((3, 2)))
        self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, geometric.log_prob, self.tensor_sample_2)
        self.assertEqual(geometric.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))

    def test_beta_shape_scalar_params(self):
        dist = Beta(0.1, 0.1)
        self.assertEqual(dist._batch_shape, torch.Size())
        self.assertEqual(dist._event_shape, torch.Size())
        self.assertEqual(dist.sample().size(), torch.Size())
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_beta_shape_tensor_params(self):
        dist = Beta(torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
                    torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
        self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))

    def test_binomial_shape(self):
        dist = Binomial(10, torch.tensor([0.6, 0.3]))
        self.assertEqual(dist._batch_shape, torch.Size((2,)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((2,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)

    def test_binomial_shape_vectorized_n(self):
        # total_count vectorized over the batch; parameter values here are
        # assumed for illustration, chosen to broadcast to batch_shape (2, 3)
        dist = Binomial(torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1]))
        self.assertEqual(dist._batch_shape, torch.Size((2, 3)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((2, 3)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3)))

    def test_multinomial_shape(self):
        dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))

    def test_categorical_shape(self):
        # unbatched
        dist = Categorical(torch.tensor([0.6, 0.3, 0.1]))
        self.assertEqual(dist._batch_shape, torch.Size(()))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size())
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2,)))
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 1)))
        # batched
        dist = Categorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size(()))
        self.assertEqual(dist.sample().size(), torch.Size((3,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
        self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 3)))

    def test_one_hot_categorical_shape(self):
        # unbatched
        dist = OneHotCategorical(torch.tensor([0.6, 0.3, 0.1]))
        self.assertEqual(dist._batch_shape, torch.Size(()))
        self.assertEqual(dist._event_shape, torch.Size((3,)))
        self.assertEqual(dist.sample().size(), torch.Size((3,)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
        simplex_sample = self.tensor_sample_2 / self.tensor_sample_2.sum(-1, keepdim=True)
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 2,)))
        self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
        simplex_sample = torch.ones(3, 3) / 3
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
        # batched
        dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
        simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True)
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))
        simplex_sample = torch.ones(3, 1, 2) / 2
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))

    def test_cauchy_shape_scalar_params(self):
        cauchy = Cauchy(0, 1)
        self.assertEqual(cauchy._batch_shape, torch.Size())
        self.assertEqual(cauchy._event_shape, torch.Size())
        self.assertEqual(cauchy.sample().size(), torch.Size())
        self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, cauchy.log_prob, self.scalar_sample)
        self.assertEqual(cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(cauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_cauchy_shape_tensor_params(self):
        cauchy = Cauchy(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self.assertEqual(cauchy._batch_shape, torch.Size((2,)))
        self.assertEqual(cauchy._event_shape, torch.Size(()))
        self.assertEqual(cauchy.sample().size(), torch.Size((2,)))
        self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
        self.assertEqual(cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
        self.assertEqual(cauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_halfcauchy_shape_scalar_params(self):
        halfcauchy = HalfCauchy(1)
        self.assertEqual(halfcauchy._batch_shape, torch.Size())
        self.assertEqual(halfcauchy._event_shape, torch.Size())
        self.assertEqual(halfcauchy.sample().size(), torch.Size())
        self.assertEqual(halfcauchy.sample(torch.Size((3, 2))).size(),
                         torch.Size((3, 2)))
        self.assertRaises(ValueError, halfcauchy.log_prob, self.scalar_sample)
        self.assertEqual(halfcauchy.log_prob(self.tensor_sample_1).size(),
                         torch.Size((3, 2)))
        self.assertEqual(halfcauchy.log_prob(self.tensor_sample_2).size(),
                         torch.Size((3, 2, 3)))

    def test_halfcauchy_shape_tensor_params(self):
        halfcauchy = HalfCauchy(torch.tensor([1., 1.]))
        self.assertEqual(halfcauchy._batch_shape, torch.Size((2,)))
        self.assertEqual(halfcauchy._event_shape, torch.Size(()))
        self.assertEqual(halfcauchy.sample().size(), torch.Size((2,)))
        self.assertEqual(halfcauchy.sample(torch.Size((3, 2))).size(),
                         torch.Size((3, 2, 2)))
        self.assertEqual(halfcauchy.log_prob(self.tensor_sample_1).size(),
                         torch.Size((3, 2)))
        self.assertRaises(ValueError, halfcauchy.log_prob, self.tensor_sample_2)
        self.assertEqual(halfcauchy.log_prob(torch.ones(2, 1)).size(),
                         torch.Size((2, 2)))

    def test_dirichlet_shape(self):
        dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
        self.assertEqual(dist._batch_shape, torch.Size((3,)))
        self.assertEqual(dist._event_shape, torch.Size((2,)))
        self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
        self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
        simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True)
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
        simplex_sample = torch.ones(3, 1, 2)
        simplex_sample = simplex_sample / simplex_sample.sum(-1).unsqueeze(-1)
        self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))

    def test_gamma_shape_scalar_params(self):
        gamma = Gamma(1, 1)
        self.assertEqual(gamma._batch_shape, torch.Size())
        self.assertEqual(gamma._event_shape, torch.Size())
        self.assertEqual(gamma.sample().size(), torch.Size())
        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, gamma.log_prob, self.scalar_sample)
        self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(gamma.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_gamma_shape_tensor_params(self):
        gamma = Gamma(torch.tensor([1., 1.]), torch.tensor([1., 1.]))
        self.assertEqual(gamma._batch_shape, torch.Size((2,)))
        self.assertEqual(gamma._event_shape, torch.Size(()))
        self.assertEqual(gamma.sample().size(), torch.Size((2,)))
        self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
        self.assertEqual(gamma.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_chi2_shape_scalar_params(self):
        chi2 = Chi2(1)
        self.assertEqual(chi2._batch_shape, torch.Size())
        self.assertEqual(chi2._event_shape, torch.Size())
        self.assertEqual(chi2.sample().size(), torch.Size())
        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, chi2.log_prob, self.scalar_sample)
        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(chi2.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_chi2_shape_tensor_params(self):
        chi2 = Chi2(torch.tensor([1., 1.]))
        self.assertEqual(chi2._batch_shape, torch.Size((2,)))
        self.assertEqual(chi2._event_shape, torch.Size(()))
        self.assertEqual(chi2.sample().size(), torch.Size((2,)))
        self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
        self.assertEqual(chi2.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_studentT_shape_scalar_params(self):
        st = StudentT(1)
        self.assertEqual(st._batch_shape, torch.Size())
        self.assertEqual(st._event_shape, torch.Size())
        self.assertEqual(st.sample().size(), torch.Size())
        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, st.log_prob, self.scalar_sample)
        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(st.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_studentT_shape_tensor_params(self):
        st = StudentT(torch.tensor([1., 1.]))
        self.assertEqual(st._batch_shape, torch.Size((2,)))
        self.assertEqual(st._event_shape, torch.Size(()))
        self.assertEqual(st.sample().size(), torch.Size((2,)))
        self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
        self.assertEqual(st.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_pareto_shape_scalar_params(self):
        pareto = Pareto(1, 1)
        self.assertEqual(pareto._batch_shape, torch.Size())
        self.assertEqual(pareto._event_shape, torch.Size())
        self.assertEqual(pareto.sample().size(), torch.Size())
        self.assertEqual(pareto.sample((3, 2)).size(), torch.Size((3, 2)))
        # values must lie strictly above the scale, hence the +1 shift
        self.assertEqual(pareto.log_prob(self.tensor_sample_1 + 1).size(), torch.Size((3, 2)))
        self.assertEqual(pareto.log_prob(self.tensor_sample_2 + 1).size(), torch.Size((3, 2, 3)))

    def test_gumbel_shape_scalar_params(self):
        gumbel = Gumbel(1, 1)
        self.assertEqual(gumbel._batch_shape, torch.Size())
        self.assertEqual(gumbel._event_shape, torch.Size())
        self.assertEqual(gumbel.sample().size(), torch.Size())
        self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_weibull_scale_scalar_params(self):
        weibull = Weibull(1, 1)
        self.assertEqual(weibull._batch_shape, torch.Size())
        self.assertEqual(weibull._event_shape, torch.Size())
        self.assertEqual(weibull.sample().size(), torch.Size())
        self.assertEqual(weibull.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertEqual(weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_normal_shape_scalar_params(self):
        normal = Normal(0, 1)
        self.assertEqual(normal._batch_shape, torch.Size())
        self.assertEqual(normal._event_shape, torch.Size())
        self.assertEqual(normal.sample().size(), torch.Size())
        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, normal.log_prob, self.scalar_sample)
        self.assertEqual(normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(normal.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_normal_shape_tensor_params(self):
        normal = Normal(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self.assertEqual(normal._batch_shape, torch.Size((2,)))
        self.assertEqual(normal._event_shape, torch.Size(()))
        self.assertEqual(normal.sample().size(), torch.Size((2,)))
        self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, normal.log_prob, self.tensor_sample_2)
        self.assertEqual(normal.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_uniform_shape_scalar_params(self):
        uniform = Uniform(0, 1)
        self.assertEqual(uniform._batch_shape, torch.Size())
        self.assertEqual(uniform._event_shape, torch.Size())
        self.assertEqual(uniform.sample().size(), torch.Size())
        self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, uniform.log_prob, self.scalar_sample)
        self.assertEqual(uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(uniform.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_uniform_shape_tensor_params(self):
        uniform = Uniform(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self.assertEqual(uniform._batch_shape, torch.Size((2,)))
        self.assertEqual(uniform._event_shape, torch.Size(()))
        self.assertEqual(uniform.sample().size(), torch.Size((2,)))
        self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
        self.assertEqual(uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, uniform.log_prob, self.tensor_sample_2)
        self.assertEqual(uniform.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

    def test_exponential_shape_scalar_param(self):
        expon = Exponential(1.)
        self.assertEqual(expon._batch_shape, torch.Size())
        self.assertEqual(expon._event_shape, torch.Size())
        self.assertEqual(expon.sample().size(), torch.Size())
        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, expon.log_prob, self.scalar_sample)
        self.assertEqual(expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(expon.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_exponential_shape_tensor_param(self):
        expon = Exponential(torch.tensor([1., 1.]))
        self.assertEqual(expon._batch_shape, torch.Size((2,)))
        self.assertEqual(expon._event_shape, torch.Size(()))
        self.assertEqual(expon.sample().size(), torch.Size((2,)))
        self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, expon.log_prob, self.tensor_sample_2)
        self.assertEqual(expon.log_prob(torch.ones(2, 2)).size(), torch.Size((2, 2)))

    def test_laplace_shape_scalar_params(self):
        laplace = Laplace(0, 1)
        self.assertEqual(laplace._batch_shape, torch.Size())
        self.assertEqual(laplace._event_shape, torch.Size())
        self.assertEqual(laplace.sample().size(), torch.Size())
        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, laplace.log_prob, self.scalar_sample)
        self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertEqual(laplace.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

    def test_laplace_shape_tensor_params(self):
        laplace = Laplace(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
        self.assertEqual(laplace._batch_shape, torch.Size((2,)))
        self.assertEqual(laplace._event_shape, torch.Size(()))
        self.assertEqual(laplace.sample().size(), torch.Size((2,)))
        self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2)))
        self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
        self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))


class TestKL(TestCase):

    def setUp(self):
        super(TestKL, self).setUp()

        class Binomial30(Binomial):
            def __init__(self, probs):
                super(Binomial30, self).__init__(30, probs)

        # These pairs of distributions are initialized so that kl_divergence
        # broadcasts over a 4 x 4 grid of parameter combinations (see pairwise()).
        bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
        binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
        binomial_vectorized_count = (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
                                     Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])))
        beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        categorical = pairwise(Categorical, [[0.4, 0.3, 0.3],
                                             [0.2, 0.7, 0.1],
                                             [0.33, 0.33, 0.34],
                                             [0.2, 0.2, 0.6]])
        chi2 = pairwise(Chi2, [1.0, 2.0, 2.5, 5.0])
        dirichlet = pairwise(Dirichlet, [[0.1, 0.2, 0.7],
                                         [0.5, 0.4, 0.1],
                                         [0.33, 0.33, 0.34],
                                         [0.2, 0.2, 0.6]])
        exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
        gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
        gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
        laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
        lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
        onehotcategorical = pairwise(OneHotCategorical, [[0.4, 0.3, 0.3],
                                                         [0.2, 0.7, 0.1],
                                                         [0.33, 0.33, 0.34],
                                                         [0.2, 0.2, 0.6]])
        pareto = pairwise(Pareto, [2.5, 4.0, 2.5, 4.0], [2.25, 3.75, 2.25, 3.75])
        poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
        uniform_within_unit = pairwise(Uniform, [0.15, 0.95, 0.2, 0.8], [0.1, 0.9, 0.25, 0.75])
        uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
        uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4])
        uniform_pareto = pairwise(Uniform, [6.5, 8.5, 6.5, 8.5], [7.5, 7.5, 9.5, 9.5])

        # These tests should pass at much smaller precision, but that makes them
        # expensive; loosen the tolerance and cap the Monte Carlo sample budget.
        self.precision = 0.1
        self.max_samples = int(1e07)
        self.samples_per_batch = int(1e04)
        self.finite_examples = [
            (bernoulli, bernoulli),
            (bernoulli, poisson),
            (beta, exponential),
            (binomial30, binomial30),
            (binomial_vectorized_count, binomial_vectorized_count),
            (categorical, categorical),
            (chi2, exponential),
            (dirichlet, dirichlet),
            (exponential, chi2),
            (exponential, exponential),
            (exponential, gamma),
            (exponential, gumbel),
            (exponential, normal),
            (gamma, exponential),
            (halfnormal, halfnormal),
            (independent, independent),
            (lognormal, lognormal),
            (onehotcategorical, onehotcategorical),
            (pareto, exponential),
            (uniform_within_unit, beta),
            (uniform_positive, chi2),
            (uniform_positive, exponential),
            (uniform_positive, gamma),
            (uniform_real, gumbel),
            (uniform_real, normal),
            (uniform_pareto, pareto),
        ]

        # Pairs of distributions with disjoint or mismatched supports, for which
        # the KL divergence is infinite.
        self.infinite_examples = [
            (Bernoulli(0), Bernoulli(1)),
            (Bernoulli(1), Bernoulli(0)),
            (Beta(1, 2), Uniform(0.25, 1)),
            (Beta(1, 2), Uniform(0, 0.75)),
            (Beta(1, 2), Uniform(0.25, 0.75)),
            (Beta(1, 2), Pareto(1, 2)),
            (Binomial(31, 0.7), Binomial(30, 0.3)),
            (Chi2(1), Beta(2, 3)),
            (Chi2(1), Pareto(2, 3)),
            (Chi2(1), Uniform(-2, 3)),
            (Exponential(1), Beta(2, 3)),
            (Exponential(1), Pareto(2, 3)),
            (Exponential(1), Uniform(-2, 3)),
            (Gamma(1, 2), Beta(3, 4)),
            (Gamma(1, 2), Pareto(3, 4)),
            (Gamma(1, 2), Uniform(-3, 4)),
            (Gumbel(-1, 2), Beta(3, 4)),
            (Gumbel(-1, 2), Chi2(3)),
            (Gumbel(-1, 2), Exponential(3)),
            (Gumbel(-1, 2), Gamma(3, 4)),
            (Gumbel(-1, 2), Pareto(3, 4)),
            (Gumbel(-1, 2), Uniform(-3, 4)),
            (Laplace(-1, 2), Beta(3, 4)),
            (Laplace(-1, 2), Chi2(3)),
            (Laplace(-1, 2), Exponential(3)),
            (Laplace(-1, 2), Gamma(3, 4)),
            (Laplace(-1, 2), Pareto(3, 4)),
            (Laplace(-1, 2), Uniform(-3, 4)),
            (Normal(-1, 2), Beta(3, 4)),
            (Normal(-1, 2), Chi2(3)),
            (Normal(-1, 2), Exponential(3)),
            (Normal(-1, 2), Gamma(3, 4)),
            (Normal(-1, 2), Pareto(3, 4)),
            (Normal(-1, 2), Uniform(-3, 4)),
            (Pareto(2, 1), Chi2(3)),
            (Pareto(2, 1), Exponential(3)),
            (Pareto(2, 1), Gamma(3, 4)),
            (Pareto(1, 2), Normal(-3, 4)),
            (Pareto(1, 2), Pareto(3, 4)),
            (Poisson(2), Bernoulli(0.5)),
            (Poisson(2.3), Binomial(10, 0.2)),
            (Uniform(-1, 1), Beta(2, 2)),
            (Uniform(0, 2), Beta(3, 4)),
            (Uniform(-1, 2), Beta(3, 4)),
            (Uniform(-1, 2), Chi2(3)),
            (Uniform(-1, 2), Exponential(3)),
            (Uniform(-1, 2), Gamma(3, 4)),
            (Uniform(-1, 2), Pareto(3, 4)),
        ]
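
    # KL(p || q) = E_p[log p(x) - log q(x)], so a Monte Carlo estimate is the
    # sample mean of log p(x) - log q(x) over x ~ p. The test below keeps
    # enlarging the sample until the estimate agrees with the analytic
    # kl_divergence() to within self.precision, or the budget runs out.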
    def test_kl_monte_carlo(self):
        for (p, _), (_, q) in self.finite_examples:
            actual = kl_divergence(p, q)
            numerator = 0
            denominator = 0
            while denominator < self.max_samples:
                x = p.sample(sample_shape=(self.samples_per_batch,))
                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
                denominator += x.size(0)
                expected = numerator / denominator
                error = torch.abs(expected - actual) / (1 + expected)
                if error[error == error].max() < self.precision:
                    break
            self.assertLess(error[error == error].max(), self.precision, '\n'.join([
                'Incorrect KL({}, {}).'.format(type(p).__name__, type(q).__name__),
                'Expected ({} Monte Carlo samples): {}'.format(denominator, expected),
                'Actual (analytic): {}'.format(actual),
            ]))

    def test_kl_multivariate_normal(self):
        n = 5  # number of random instances to test
        for i in range(0, n):
            loc = [torch.randn(4) for _ in range(0, 2)]
            scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(4, 4))
                          for _ in range(0, 2)]
            p = MultivariateNormal(loc=loc[0], scale_tril=scale_tril[0])
            q = MultivariateNormal(loc=loc[1], scale_tril=scale_tril[1])
            actual = kl_divergence(p, q)
            numerator = 0
            denominator = 0
            while denominator < self.max_samples:
                x = p.sample(sample_shape=(self.samples_per_batch,))
                numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
                denominator += x.size(0)
                expected = numerator / denominator
                error = torch.abs(expected - actual) / (1 + expected)
                if error[error == error].max() < self.precision:
                    break
            self.assertLess(error[error == error].max(), self.precision, '\n'.join([
                'Incorrect KL(MultivariateNormal, MultivariateNormal) instance {}/{}'.format(i + 1, n),
                'Expected ({} Monte Carlo samples): {}'.format(denominator, expected),
                'Actual (analytic): {}'.format(actual),
            ]))

    def test_kl_multivariate_normal_batched(self):
        b = 7  # number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3))
                      for _ in range(0, 2)]
        expected_kl = torch.stack([
            kl_divergence(MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
                          MultivariateNormal(loc[1][i], scale_tril=scale_tril[1][i]))
            for i in range(0, b)])
        actual_kl = kl_divergence(MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
                                  MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_multivariate_normal_batched_broadcasted(self):
        b = 7  # number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
                      transform_to(constraints.lower_cholesky)(torch.randn(3, 3))]
        expected_kl = torch.stack([
            kl_divergence(MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
                          MultivariateNormal(loc[1][i], scale_tril=scale_tril[1]))
            for i in range(0, b)])
        actual_kl = kl_divergence(MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
                                  MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
        self.assertEqual(expected_kl, actual_kl)

    def test_kl_lowrank_multivariate_normal(self):
        n = 5  # number of random instances to test
        for i in range(0, n):
            loc = [torch.randn(4) for _ in range(0, 2)]
            cov_factor = [torch.randn(4, 3) for _ in range(0, 2)]
            cov_diag = [transform_to(constraints.positive)(torch.randn(4))
                        for _ in range(0, 2)]
            covariance_matrix = [cov_factor[i].matmul(cov_factor[i].t()) +
                                 cov_diag[i].diag() for i in range(0, 2)]
            p = LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0])
            q = LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1])
            p_full = MultivariateNormal(loc[0], covariance_matrix[0])
            q_full = MultivariateNormal(loc[1], covariance_matrix[1])
            expected = kl_divergence(p_full, q_full)

            actual_lowrank_lowrank = kl_divergence(p, q)
            actual_lowrank_full = kl_divergence(p, q_full)
            actual_full_lowrank = kl_divergence(p_full, q)

            error_lowrank_lowrank = torch.abs(actual_lowrank_lowrank - expected).max()
            self.assertLess(error_lowrank_lowrank, self.precision, '\n'.join([
                'Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n),
                'Expected (from KL MultivariateNormal): {}'.format(expected),
                'Actual (analytic): {}'.format(actual_lowrank_lowrank),
            ]))

            error_lowrank_full = torch.abs(actual_lowrank_full - expected).max()
            self.assertLess(error_lowrank_full, self.precision, '\n'.join([
                'Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {}/{}'.format(i + 1, n),
                'Expected (from KL MultivariateNormal): {}'.format(expected),
                'Actual (analytic): {}'.format(actual_lowrank_full),
            ]))

            error_full_lowrank = torch.abs(actual_full_lowrank - expected).max()
            self.assertLess(error_full_lowrank, self.precision, '\n'.join([
                'Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n),
                'Expected (from KL MultivariateNormal): {}'.format(expected),
                'Actual (analytic): {}'.format(actual_full_lowrank),
            ]))

    def test_kl_lowrank_multivariate_normal_batched(self):
        b = 7  # number of batches
        loc = [torch.randn(b, 3) for _ in range(0, 2)]
        cov_factor = [torch.randn(b, 3, 2) for _ in range(0, 2)]
        cov_diag = [transform_to(constraints.positive)(torch.randn(b, 3))
                    for _ in range(0, 2)]
        expected_kl = torch.stack([
            kl_divergence(LowRankMultivariateNormal(loc[0][i], cov_factor[0][i], cov_diag[0][i]),
                          LowRankMultivariateNormal(loc[1][i], cov_factor[1][i], cov_diag[1][i]))
            for i in range(0, b)])
        actual_kl = kl_divergence(LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]),
                                  LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1]))
        self.assertEqual(expected_kl, actual_kl)
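
    # For two members p, q of the same exponential family, KL(p || q) equals
    # the Bregman divergence of the log-normalizer A at the natural parameters:
    #     KL = A(eta_q) - A(eta_p) - <grad A(eta_p), eta_q - eta_p>.
    # _kl_expfamily_expfamily computes KL this way, so it cross-checks the
    # per-distribution closed forms registered with kl_divergence().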
    def test_kl_exponential_family(self):
        for (p, _), (_, q) in self.finite_examples:
            if type(p) == type(q) and issubclass(type(p), ExponentialFamily):
                actual = kl_divergence(p, q)
                expected = _kl_expfamily_expfamily(p, q)
                self.assertEqual(actual, expected, message='\n'.join([
                    'Incorrect KL({}, {}).'.format(type(p).__name__, type(q).__name__),
                    'Expected (using Bregman Divergence) {}'.format(expected),
                    'Actual (analytic) {}'.format(actual),
                    'max error = {}'.format(torch.abs(actual - expected).max())
                ]))

    def test_kl_infinite(self):
        for p, q in self.infinite_examples:
            self.assertTrue((kl_divergence(p, q) == inf).all(),
                            'Incorrect KL({}, {})'.format(type(p).__name__, type(q).__name__))

    def test_kl_edgecases(self):
        self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
        self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)

    def test_kl_shape(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    kl = kl_divergence(dist, dist)
                except NotImplementedError:
                    continue
                expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
                self.assertEqual(kl.shape, expected_shape, message='\n'.join([
                    '{} example {}/{}'.format(Dist.__name__, i + 1, len(params)),
                    'Expected {}'.format(expected_shape),
                    'Actual {}'.format(kl.shape),
                ]))
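
    # Entropy H(p) = -E_p[log p(x)]; the check below estimates it with the
    # sample mean of -log_prob over 60000 draws and compares against the
    # analytic .entropy(), skipping entries where the estimate diverges.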
    def test_entropy_monte_carlo(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    actual = dist.entropy()
                except NotImplementedError:
                    continue
                x = dist.sample(sample_shape=(60000,))
                expected = -dist.log_prob(x).mean(0)
                ignore = (expected == inf) | (expected == -inf)
                expected[ignore] = actual[ignore]
                self.assertEqual(actual, expected, prec=0.2, message='\n'.join([
                    '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
                    'Expected (monte carlo) {}'.format(expected),
                    'Actual (analytic) {}'.format(actual),
                    'max error = {}'.format(torch.abs(actual - expected).max()),
                ]))
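    # The Monte Carlo check relies on H(p) = -E_p[log p(x)], estimated by the
    # sample mean of -log_prob over 60000 draws.  The standard error of this
    # estimator shrinks like 1/sqrt(N), which is why the loose prec=0.2
    # tolerance keeps false alarms rare.  A minimal sketch of the same
    # estimator, outside the test harness:
    #
    #     d = Normal(0., 1.)
    #     mc = -d.log_prob(d.sample((60000,))).mean()
    #     # mc is close to d.entropy() == 0.5 * log(2 * pi * e)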
    def test_entropy_exponential_family(self):
        for Dist, params in EXAMPLES:
            if not issubclass(Dist, ExponentialFamily):
                continue
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    actual = dist.entropy()
                except NotImplementedError:
                    continue
                try:
                    expected = ExponentialFamily.entropy(dist)
                except NotImplementedError:
                    continue
                self.assertEqual(actual, expected, message='\n'.join([
                    '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
                    'Expected (Bregman Divergence) {}'.format(expected),
                    'Actual (analytic) {}'.format(actual),
                    'max error = {}'.format(torch.abs(actual - expected).max())
                ]))

class TestConstraints(TestCase):
    def test_params_contains(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                for name, value in param.items():
                    if isinstance(value, numbers.Number):
                        value = torch.tensor([value])
                    if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
                        # These distributions accept positive probs, but
                        # elsewhere we use a stricter simplex constraint.
                        value = value / value.sum(-1, True)
                    try:
                        constraint = dist.arg_constraints[name]
                    except KeyError:
                        continue  # ignore optional parameters
                    if is_dependent(constraint):
                        continue
                    message = '{} example {}/{} parameter {} = {}'.format(
                        Dist.__name__, i + 1, len(params), name, value)
                    self.assertTrue(constraint.check(value).all(), msg=message)
    def test_support_contains(self):
        for Dist, params in EXAMPLES:
            self.assertIsInstance(Dist.support, Constraint)
            for i, param in enumerate(params):
                dist = Dist(**param)
                value = dist.sample()
                constraint = dist.support
                message = '{} example {}/{} sample = {}'.format(
                    Dist.__name__, i + 1, len(params), value)
                self.assertTrue(constraint.check(value).all(), msg=message)

class TestNumericalStability(TestCase):
    def _test_pdf_score(self,
                        dist_class,
                        x,
                        expected_value,
                        probs=None,
                        logits=None,
                        expected_gradient=None,
                        prec=1e-5):
        if probs is not None:
            p = probs.detach().requires_grad_()
            dist = dist_class(p)
        else:
            p = logits.detach().requires_grad_()
            dist = dist_class(logits=p)
        log_pdf = dist.log_prob(x)
        log_pdf.sum().backward()
        self.assertEqual(log_pdf,
                         expected_value,
                         prec=prec,
                         message='Incorrect value for tensor type: {}. Expected = {}, Actual = {}'
                                 .format(type(x), expected_value, log_pdf))
        if expected_gradient is not None:
            self.assertEqual(p.grad,
                             expected_gradient,
                             prec=prec,
                             message='Incorrect gradient for tensor type: {}. Expected = {}, Actual = {}'
                                     .format(type(x), expected_gradient, p.grad))
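    # Typical usage of the helper above (a sketch; the concrete cases follow
    # in the tests below): pass either probs= or logits=, the point x at which
    # to evaluate log_prob, and the expected value/gradient, e.g.
    #
    #     self._test_pdf_score(dist_class=Bernoulli,
    #                          probs=torch.FloatTensor([0.5]),
    #                          x=torch.FloatTensor([1]),
    #                          expected_value=torch.FloatTensor([0.5]).log(),
    #                          expected_gradient=torch.FloatTensor([2]))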
    def test_bernoulli_gradient(self):
        for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
            self._test_pdf_score(dist_class=Bernoulli,
                                 probs=tensor_type([0]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([0]),
                                 expected_gradient=tensor_type([0]))

            self._test_pdf_score(dist_class=Bernoulli,
                                 probs=tensor_type([0]),
                                 x=tensor_type([1]),
                                 expected_value=tensor_type([torch.finfo(tensor_type([]).dtype).eps]).log(),
                                 expected_gradient=tensor_type([0]))

            self._test_pdf_score(dist_class=Bernoulli,
                                 probs=tensor_type([1e-4]),
                                 x=tensor_type([1]),
                                 expected_value=tensor_type([math.log(1e-4)]),
                                 expected_gradient=tensor_type([10000]))

            # Lower precision due to finite-precision evaluation of 1 / (1 - p)
            self._test_pdf_score(dist_class=Bernoulli,
                                 probs=tensor_type([1 - 1e-4]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([math.log(1e-4)]),
                                 expected_gradient=tensor_type([-10000]),
                                 prec=2)

            self._test_pdf_score(dist_class=Bernoulli,
                                 logits=tensor_type([math.log(9999)]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([math.log(1e-4)]),
                                 expected_gradient=tensor_type([-1]),
                                 prec=1e-3)
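    # The expected gradients above follow directly from
    # d/dp log Bernoulli(x; p) = x/p - (1 - x)/(1 - p):
    # at p = 1e-4 with x = 1 this is 1/1e-4 = 10000, and at p = 1 - 1e-4 with
    # x = 0 it is -1/1e-4 = -10000.  In logit space the score is
    # x - sigmoid(logits), which at logits = log(9999), x = 0 is about -1.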
    def test_bernoulli_with_logits_underflow(self):
        for tensor_type, lim in ([(torch.FloatTensor, -1e38),
                                  (torch.DoubleTensor, -1e308)]):
            self._test_pdf_score(dist_class=Bernoulli,
                                 logits=tensor_type([lim]),
                                 x=tensor_type([0]),
                                 expected_value=tensor_type([0]),
                                 expected_gradient=tensor_type([0]))
    def test_bernoulli_with_logits_overflow(self):
        for tensor_type, lim in ([(torch.FloatTensor, 1e38),
                                  (torch.DoubleTensor, 1e308)]):
            self._test_pdf_score(dist_class=Bernoulli,
                                 logits=tensor_type([lim]),
                                 x=tensor_type([1]),
                                 expected_value=tensor_type([0]),
                                 expected_gradient=tensor_type([0]))
    def test_categorical_log_prob(self):
        for dtype in ([torch.float, torch.double]):
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(p)
            log_pdf = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf.item(), 0)
    def test_categorical_log_prob_with_logits(self):
        for dtype in ([torch.float, torch.double]):
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            categorical = OneHotCategorical(logits=p)
            log_pdf_prob_1 = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = categorical.log_prob(torch.tensor([1, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf, allow_inf=True)
    def test_multinomial_log_prob(self):
        for dtype in ([torch.float, torch.double]):
            p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
            s = torch.tensor([0, 10], dtype=dtype)
            multinomial = Multinomial(10, p)
            log_pdf = multinomial.log_prob(s)
            self.assertEqual(log_pdf.item(), 0)
    def test_multinomial_log_prob_with_logits(self):
        for dtype in ([torch.float, torch.double]):
            p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
            multinomial = Multinomial(10, logits=p)
            log_pdf_prob_1 = multinomial.log_prob(torch.tensor([0, 10], dtype=dtype))
            self.assertEqual(log_pdf_prob_1.item(), 0)
            log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
            self.assertEqual(log_pdf_prob_0.item(), -inf, allow_inf=True)

class TestLazyLogitsInitialization(TestCase):
    def setUp(self):
        super(TestLazyLogitsInitialization, self).setUp()
        self.examples = [e for e in EXAMPLES if e.Dist in
                         (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)]
    def test_lazy_logits_initialization(self):
        for Dist, params in self.examples:
            param = params[0].copy()
            if 'probs' in param:
                probs = param.pop('probs')
                param['logits'] = probs_to_logits(probs)
                dist = Dist(**param)
                shape = (1,) if not dist.event_shape else dist.event_shape
                dist.log_prob(torch.ones(shape))
                message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
                self.assertFalse('probs' in vars(dist), msg=message)
                try:
                    dist.enumerate_support()
                except NotImplementedError:
                    pass
                self.assertFalse('probs' in vars(dist), msg=message)
                batch_shape, event_shape = dist.batch_shape, dist.event_shape
                self.assertFalse('probs' in vars(dist), msg=message)
    def test_lazy_probs_initialization(self):
        for Dist, params in self.examples:
            param = params[0].copy()
            if 'probs' in param:
                dist = Dist(**param)
                dist.sample()
                message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
                self.assertFalse('logits' in vars(dist), msg=message)
                try:
                    dist.enumerate_support()
                except NotImplementedError:
                    pass
                self.assertFalse('logits' in vars(dist), msg=message)
                batch_shape, event_shape = dist.batch_shape, dist.event_shape
                self.assertFalse('logits' in vars(dist), msg=message)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
class TestAgainstScipy(TestCase):
    def setUp(self):
        super(TestAgainstScipy, self).setUp()
        positive_var = torch.randn(20).exp()
        positive_var2 = torch.randn(20).exp()
        random_var = torch.randn(20)
        simplex_tensor = softmax(torch.randn(20), dim=-1)
        self.distribution_pairs = [
            (
                Bernoulli(simplex_tensor),
                scipy.stats.bernoulli(simplex_tensor)
            ),
            (
                Beta(positive_var, positive_var2),
                scipy.stats.beta(positive_var, positive_var2)
            ),
            (
                Binomial(10, simplex_tensor),
                scipy.stats.binom(10 * np.ones(simplex_tensor.shape), simplex_tensor.numpy())
            ),
            (
                Cauchy(random_var, positive_var),
                scipy.stats.cauchy(loc=random_var, scale=positive_var)
            ),
            (
                Dirichlet(positive_var),
                scipy.stats.dirichlet(positive_var)
            ),
            (
                Exponential(positive_var),
                scipy.stats.expon(scale=positive_var.reciprocal())
            ),
            (
                FisherSnedecor(positive_var, 4 + positive_var2),
                scipy.stats.f(positive_var, 4 + positive_var2)
            ),
            (
                Gamma(positive_var, positive_var2),
                scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal())
            ),
            (
                Geometric(simplex_tensor),
                scipy.stats.geom(simplex_tensor, loc=-1)
            ),
            (
                Gumbel(random_var, positive_var2),
                scipy.stats.gumbel_r(random_var, positive_var2)
            ),
            (
                HalfCauchy(positive_var),
                scipy.stats.halfcauchy(scale=positive_var)
            ),
            (
                HalfNormal(positive_var2),
                scipy.stats.halfnorm(scale=positive_var2)
            ),
            (
                Laplace(random_var, positive_var2),
                scipy.stats.laplace(random_var, positive_var2)
            ),
            (
                # scale clamped since tests fail the tolerance for scale > 3
                LogNormal(random_var, positive_var.clamp(max=3)),
                scipy.stats.lognorm(s=positive_var.clamp(max=3), scale=random_var.exp())
            ),
            (
                LowRankMultivariateNormal(random_var, torch.zeros(20, 1), positive_var2),
                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2))
            ),
            (
                Multinomial(10, simplex_tensor),
                scipy.stats.multinomial(10, simplex_tensor)
            ),
            (
                MultivariateNormal(random_var, torch.diag(positive_var2)),
                scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2))
            ),
            (
                Normal(random_var, positive_var2),
                scipy.stats.norm(random_var, positive_var2)
            ),
            (
                OneHotCategorical(simplex_tensor),
                scipy.stats.multinomial(1, simplex_tensor)
            ),
            (
                Pareto(positive_var, 2 + positive_var2),
                scipy.stats.pareto(2 + positive_var2, scale=positive_var)
            ),
            (
                Poisson(positive_var),
                scipy.stats.poisson(positive_var)
            ),
            (
                StudentT(2 + positive_var, random_var, positive_var2),
                scipy.stats.t(2 + positive_var, random_var, positive_var2)
            ),
            (
                Uniform(random_var, random_var + positive_var),
                scipy.stats.uniform(random_var, positive_var)
            ),
            (
                Weibull(positive_var[0], positive_var2[0]),
                scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])
            )
        ]
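    # Note on parameterizations: scipy uses scale parameters where torch uses
    # rates, hence the .reciprocal() calls above (e.g. Exponential(rate) vs
    # scipy.stats.expon(scale=1/rate)), and scipy.stats.geom counts trials
    # rather than failures, hence loc=-1 to shift it onto torch's support.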
    def test_mean(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
                # the mean of Cauchy and HalfCauchy is nan, skipping check
                continue
            elif isinstance(pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)):
                self.assertEqual(pytorch_dist.mean, scipy_dist.mean, allow_inf=True, message=pytorch_dist)
            else:
                self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)
    def test_variance_stddev(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
                # the standard deviation of Cauchy and HalfCauchy is nan, skipping check
                continue
            elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
                self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
                self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov()) ** 0.5, message=pytorch_dist)
            elif isinstance(pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)):
                self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov), message=pytorch_dist)
                self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov) ** 0.5, message=pytorch_dist)
            else:
                self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
                self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)
    def test_cdf(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            samples = pytorch_dist.sample((5,))
            try:
                cdf = pytorch_dist.cdf(samples)
            except NotImplementedError:
                continue
            self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)
    def test_icdf(self):
        for pytorch_dist, scipy_dist in self.distribution_pairs:
            samples = torch.rand((5,) + pytorch_dist.batch_shape)
            try:
                icdf = pytorch_dist.icdf(samples)
            except NotImplementedError:
                continue
            self.assertEqual(icdf, scipy_dist.ppf(samples), message=pytorch_dist)

class TestTransforms(TestCase):
    def setUp(self):
        super(TestTransforms, self).setUp()
        self.transforms = []
        transforms_by_cache_size = {}
        for cache_size in [0, 1]:
            transforms = [
                AbsTransform(cache_size=cache_size),
                ExpTransform(cache_size=cache_size),
                PowerTransform(exponent=2,
                               cache_size=cache_size),
                PowerTransform(exponent=torch.tensor([5.]).normal_(),
                               cache_size=cache_size),
                SigmoidTransform(cache_size=cache_size),
                AffineTransform(0, 1, cache_size=cache_size),
                AffineTransform(1, -2, cache_size=cache_size),
                AffineTransform(torch.randn(5),
                                torch.randn(5),
                                cache_size=cache_size),
                AffineTransform(torch.randn(4, 5),
                                torch.randn(4, 5),
                                cache_size=cache_size),
                SoftmaxTransform(cache_size=cache_size),
                StickBreakingTransform(cache_size=cache_size),
                LowerCholeskyTransform(cache_size=cache_size),
                ComposeTransform([
                    AffineTransform(torch.randn(4, 5),
                                    torch.randn(4, 5),
                                    cache_size=cache_size),
                ]),
                ComposeTransform([
                    AffineTransform(torch.randn(4, 5),
                                    torch.randn(4, 5),
                                    cache_size=cache_size),
                    ExpTransform(cache_size=cache_size),
                ]),
                ComposeTransform([
                    AffineTransform(0, 1, cache_size=cache_size),
                    AffineTransform(torch.randn(4, 5),
                                    torch.randn(4, 5),
                                    cache_size=cache_size),
                    AffineTransform(1, -2, cache_size=cache_size),
                    AffineTransform(torch.randn(4, 5),
                                    torch.randn(4, 5),
                                    cache_size=cache_size),
                ]),
            ]
            for t in transforms[:]:
                transforms.append(t.inv)
            transforms.append(identity_transform)
            self.transforms += transforms
            if cache_size == 0:
                self.unique_transforms = transforms[:]
    def _generate_data(self, transform):
        domain = transform.domain
        codomain = transform.codomain
        x = torch.empty(4, 5)
        if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
            x = torch.empty(6, 6)
            return x.normal_()
        elif domain is constraints.real:
            return x.normal_()
        elif domain is constraints.positive:
            return x.normal_().exp()
        elif domain is constraints.unit_interval:
            return x.uniform_()
        elif domain is constraints.simplex:
            x = x.normal_().exp()
            x /= x.sum(-1, True)
            return x
        raise ValueError('Unsupported domain: {}'.format(domain))
    def test_inv_inv(self):
        for t in self.transforms:
            self.assertTrue(t.inv.inv is t)
    def test_equality(self):
        transforms = self.unique_transforms
        for x, y in product(transforms, transforms):
            if x is y:
                self.assertTrue(x == y)
                self.assertFalse(x != y)
            else:
                self.assertFalse(x == y)
                self.assertTrue(x != y)

        self.assertTrue(identity_transform == identity_transform.inv)
        self.assertFalse(identity_transform != identity_transform.inv)
    def test_forward_inverse_cache(self):
        for transform in self.transforms:
            x = self._generate_data(transform).requires_grad_()
            try:
                y = transform(x)
            except NotImplementedError:
                continue
            x2 = transform.inv(y)  # should be implemented at least by caching
            y2 = transform(x2)  # should be implemented at least by caching
            if transform.bijective:
                # verify function inverse
                self.assertEqual(x2, x, message='\n'.join([
                    '{} t.inv(t(-)) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                ]))
            else:
                # verify weaker function pseudo-inverse
                self.assertEqual(y2, y, message='\n'.join([
                    '{} t(t.inv(t(-))) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                    'y2 = t(x2) = {}'.format(y2),
                ]))
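    # With cache_size=1 a transform memoizes its most recent (x, y) pair, so
    # t.inv(t(x)) returns the cached x exactly, even when the numeric inverse
    # would round-trip with error (or is not implemented at all).  A minimal
    # sketch of the caching behavior:
    #
    #     t = ExpTransform(cache_size=1)
    #     x = torch.randn(5)
    #     assert t.inv(t(x)) is x  # cache hit, exact identity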
    def test_forward_inverse_no_cache(self):
        for transform in self.transforms:
            x = self._generate_data(transform).requires_grad_()
            try:
                y = transform(x)
                x2 = transform.inv(y.clone())  # bypass cache
                y2 = transform(x2)
            except NotImplementedError:
                continue
            if transform.bijective:
                # verify function inverse
                self.assertEqual(x2, x, message='\n'.join([
                    '{} t.inv(t(-)) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                ]))
            else:
                # verify weaker function pseudo-inverse
                self.assertEqual(y2, y, message='\n'.join([
                    '{} t(t.inv(t(-))) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                    'y2 = t(x2) = {}'.format(y2),
                ]))
    def test_univariate_forward_jacobian(self):
        for transform in self.transforms:
            if transform.event_dim > 0:
                continue
            x = self._generate_data(transform).requires_grad_()
            try:
                y = transform(x)
                actual = transform.log_abs_det_jacobian(x, y)
            except NotImplementedError:
                continue
            expected = torch.abs(grad([y.sum()], [x])[0]).log()
            self.assertEqual(actual, expected, message='\n'.join([
                '{}.log_abs_det_jacobian() disagrees with autograd'.format(transform),
                'Expected: {}'.format(expected),
                'Actual: {}'.format(actual),
            ]))
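    # For an elementwise (event_dim == 0) transform y = t(x), the change of
    # variables term is log|dy/dx| evaluated pointwise, so autograd on y.sum()
    # recovers it exactly.  For instance:
    #
    #     t = ExpTransform()
    #     x = torch.randn(5, requires_grad=True)
    #     y = t(x)
    #     # log|d exp(x)/dx| = x, so both
    #     # t.log_abs_det_jacobian(x, y) and
    #     # torch.abs(grad([y.sum()], [x])[0]).log() equal x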
    def test_univariate_inverse_jacobian(self):
        for transform in self.transforms:
            if transform.event_dim > 0:
                continue
            y = self._generate_data(transform.inv).requires_grad_()
            try:
                x = transform.inv(y)
                actual = transform.log_abs_det_jacobian(x, y)
            except NotImplementedError:
                continue
            expected = -torch.abs(grad([x.sum()], [y])[0]).log()
            self.assertEqual(actual, expected, message='\n'.join([
                '{}.log_abs_det_jacobian() disagrees with .inv()'.format(transform),
                'Expected: {}'.format(expected),
                'Actual: {}'.format(actual),
            ]))
    def test_jacobian_shape(self):
        for transform in self.transforms:
            x = self._generate_data(transform)
            try:
                y = transform(x)
                actual = transform.log_abs_det_jacobian(x, y)
            except NotImplementedError:
                continue
            self.assertEqual(actual.shape, x.shape[:x.dim() - transform.event_dim])
    def test_transform_shapes(self):
        transform0 = ExpTransform()
        transform1 = SoftmaxTransform()
        transform2 = LowerCholeskyTransform()

        self.assertEqual(transform0.event_dim, 0)
        self.assertEqual(transform1.event_dim, 1)
        self.assertEqual(transform2.event_dim, 2)
        self.assertEqual(ComposeTransform([transform0, transform1]).event_dim, 1)
        self.assertEqual(ComposeTransform([transform0, transform2]).event_dim, 2)
        self.assertEqual(ComposeTransform([transform1, transform2]).event_dim, 2)
    def test_transformed_distribution_shapes(self):
        transform0 = ExpTransform()
        transform1 = SoftmaxTransform()
        transform2 = LowerCholeskyTransform()
        base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
        base_dist1 = Dirichlet(torch.ones(4, 4))
        base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))
        examples = [
            ((4, 4), (), base_dist0),
            ((4,), (4,), base_dist1),
            ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
            ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
            ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
            ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
            ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
            ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
            ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
            ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
            ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
            ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
            ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
            ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
            ((3, 4, 4), (), base_dist2),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
        ]
        for batch_shape, event_shape, dist in examples:
            self.assertEqual(dist.batch_shape, torch.Size(batch_shape))
            self.assertEqual(dist.event_shape, torch.Size(event_shape))
            x = dist.rsample()
            try:
                dist.log_prob(x)  # this should not crash
            except NotImplementedError:
                continue
    def test_jit_fwd(self):
        for transform in self.unique_transforms:
            x = self._generate_data(transform).requires_grad_()

            def f(x):
                return transform(x)

            try:
                traced_f = torch.jit.trace(f, (x,))
            except NotImplementedError:
                continue

            # check on fresh inputs
            x = self._generate_data(transform).requires_grad_()
            self.assertEqual(f(x), traced_f(x))

    def test_jit_inv(self):
        for transform in self.unique_transforms:
            y = self._generate_data(transform.inv).requires_grad_()

            def f(y):
                return transform.inv(y)

            try:
                traced_f = torch.jit.trace(f, (y,))
            except NotImplementedError:
                continue

            # check on fresh inputs
            y = self._generate_data(transform.inv).requires_grad_()
            self.assertEqual(f(y), traced_f(y))

    def test_jit_jacobian(self):
        for transform in self.unique_transforms:
            x = self._generate_data(transform).requires_grad_()

            def f(x):
                y = transform(x)
                return transform.log_abs_det_jacobian(x, y)

            try:
                traced_f = torch.jit.trace(f, (x,))
            except NotImplementedError:
                continue

            # check on fresh inputs
            x = self._generate_data(transform).requires_grad_()
            self.assertEqual(f(x), traced_f(x))

class TestConstraintRegistry(TestCase):
    def get_constraints(self, is_cuda=False):
        tensor = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
        return [
            constraints.real,
            constraints.positive,
            constraints.greater_than(tensor([-10., -2, 0, 2, 10])),
            constraints.greater_than(0),
            constraints.greater_than(2),
            constraints.greater_than(-2),
            constraints.greater_than_eq(0),
            constraints.greater_than_eq(2),
            constraints.greater_than_eq(-2),
            constraints.less_than(tensor([-10., -2, 0, 2, 10])),
            constraints.less_than(0),
            constraints.less_than(2),
            constraints.less_than(-2),
            constraints.unit_interval,
            constraints.interval(tensor([-4., -2, 0, 2, 4]),
                                 tensor([-3., 3, 1, 5, 5])),
            constraints.interval(-2, -1),
            constraints.interval(1, 2),
            constraints.half_open_interval(tensor([-4., -2, 0, 2, 4]),
                                           tensor([-3., 3, 1, 5, 5])),
            constraints.half_open_interval(-2, -1),
            constraints.half_open_interval(1, 2),
            constraints.simplex,
            constraints.lower_cholesky,
        ]
    def test_biject_to(self):
        for constraint in self.get_constraints():
            try:
                t = biject_to(constraint)
            except NotImplementedError:
                continue
            self.assertTrue(t.bijective, "biject_to({}) is not bijective".format(constraint))
            x = torch.randn(5, 5)
            y = t(x)
            self.assertTrue(constraint.check(y).all(), '\n'.join([
                "Failed to biject_to({})".format(constraint),
                "x = {}".format(x),
                "biject_to(...)(x) = {}".format(y),
            ]))
            x2 = t.inv(y)
            self.assertEqual(x, x2, message="Error in biject_to({}) inverse".format(constraint))

            j = t.log_abs_det_jacobian(x, y)
            self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim])
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_biject_to_cuda(self):
        for constraint in self.get_constraints(is_cuda=True):
            try:
                t = biject_to(constraint)
            except NotImplementedError:
                continue
            self.assertTrue(t.bijective, "biject_to({}) is not bijective".format(constraint))
            x = torch.randn(5, 5).cuda()
            y = t(x)
            self.assertTrue(constraint.check(y).all(), '\n'.join([
                "Failed to biject_to({})".format(constraint),
                "x = {}".format(x),
                "biject_to(...)(x) = {}".format(y),
            ]))
            x2 = t.inv(y)
            self.assertEqual(x, x2, message="Error in biject_to({}) inverse".format(constraint))

            j = t.log_abs_det_jacobian(x, y)
            self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim])
    def test_transform_to(self):
        for constraint in self.get_constraints():
            t = transform_to(constraint)
            x = torch.randn(5, 5)
            y = t(x)
            self.assertTrue(constraint.check(y).all(), "Failed to transform_to({})".format(constraint))
            x2 = t.inv(y)
            y2 = t(x2)
            self.assertEqual(y, y2, message="Error in transform_to({}) pseudoinverse".format(constraint))
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_transform_to_cuda(self):
        for constraint in self.get_constraints(is_cuda=True):
            t = transform_to(constraint)
            x = torch.randn(5, 5).cuda()
            y = t(x)
            self.assertTrue(constraint.check(y).all(), "Failed to transform_to({})".format(constraint))
            x2 = t.inv(y)
            y2 = t(x2)
            self.assertEqual(y, y2, message="Error in transform_to({}) pseudoinverse".format(constraint))

class TestValidation(TestCase):
    def setUp(self):
        super(TestValidation, self).setUp()
        Distribution.set_default_validate_args(True)

    def test_valid(self):
        for Dist, params in EXAMPLES:
            for param in params:
                Dist(validate_args=True, **param)

    @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
    def test_invalid(self):
        for Dist, params in BAD_EXAMPLES:
            for i, param in enumerate(params):
                try:
                    with self.assertRaises(ValueError):
                        Dist(validate_args=True, **param)
                except AssertionError:
                    fail_string = 'ValueError not raised for {} example {}/{}'
                    raise AssertionError(fail_string.format(Dist.__name__, i + 1, len(params)))

    def tearDown(self):
        super(TestValidation, self).tearDown()
        Distribution.set_default_validate_args(False)

class TestJit(TestCase):
    def _examples(self):
        for Dist, params in EXAMPLES:
            for param in params:
                keys = param.keys()
                values = tuple(param[key] for key in keys)
                if not all(isinstance(x, torch.Tensor) for x in values):
                    continue
                sample = Dist(**param).sample()
                yield Dist, keys, values, sample
    def _perturb_tensor(self, value, constraint):
        if isinstance(constraint, constraints._IntegerGreaterThan):
            return value + 1
        if isinstance(constraint, constraints._PositiveDefinite):
            return value + torch.eye(value.shape[-1])
        if value.dtype in [torch.float, torch.double]:
            transform = transform_to(constraint)
            delta = value.new(value.shape).normal_()
            return transform(transform.inv(value) + delta)
        if value.dtype == torch.long:
            result = value.clone()
            result[value == 0] = 1
            result[value == 1] = 0
            return result
        raise NotImplementedError
    def _perturb(self, Dist, keys, values, sample):
        with torch.no_grad():
            if Dist is Uniform:
                param = dict(zip(keys, values))
                param['low'] = param['low'] - torch.rand(param['low'].shape)
                param['high'] = param['high'] + torch.rand(param['high'].shape)
                values = [param[key] for key in keys]
            else:
                values = [self._perturb_tensor(value, Dist.arg_constraints.get(key, constraints.real))
                          for key, value in zip(keys, values)]
            param = dict(zip(keys, values))
            sample = Dist(**param).sample()
            return values, sample
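    # _perturb nudges each float parameter in its *unconstrained* space:
    # transform_to(constraint).inv maps the value to an unconstrained tensor,
    # Gaussian noise is added there, and the forward transform maps it back,
    # so the perturbed value still satisfies the constraint.  E.g. for a rate
    # parameter (constraints.positive) this amounts to exp(log(value) + noise),
    # which stays positive for any noise.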
    def test_sample(self):
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.sample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # sample() is nondeterministic; fork_rng() restores the RNG state
            # on exit, so the eager and traced runs see identical randomness
            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # sampling should insert at least one nondeterministic node
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(any(n.isNondeterministic() for n in traced_f.graph.nodes()))
    def test_rsample(self):
        for Dist, keys, values, sample in self._examples():
            if not Dist.has_rsample:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.rsample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # rsample() is nondeterministic; fork_rng() restores the RNG state
            # on exit, so the eager and traced runs see identical randomness
            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # sampling should insert at least one nondeterministic node
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(any(n.isNondeterministic() for n in traced_f.graph.nodes()))
    def test_log_prob(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME: traced functions produce incorrect results for these
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.log_prob(sample)

            traced_f = torch.jit.trace(f, (sample,) + values)

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(expected, actual,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
    def test_enumerate_support(self):
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.enumerate_support()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(expected, actual,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
    def test_mean(self):
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.mean

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            expected[expected == float('inf')] = 0.
            actual[actual == float('inf')] = 0.
            self.assertEqual(expected, actual, allow_inf=True,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
    def test_variance(self):
        for Dist, keys, values, sample in self._examples():
            if Dist in [Cauchy, HalfCauchy]:
                continue  # infinite variance

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.variance

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            expected[expected == float('inf')] = 0.
            actual[actual == float('inf')] = 0.
            self.assertEqual(expected, actual, allow_inf=True,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
    def test_entropy(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME: traced functions produce incorrect results for these
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.entropy()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(expected, actual, allow_inf=True,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
    def test_cdf(self):
        for Dist, keys, values, sample in self._examples():

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                cdf = dist.cdf(sample)
                return dist.icdf(cdf)

            try:
                traced_f = torch.jit.trace(f, (sample,) + values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(expected, actual, allow_inf=True,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

if __name__ == '__main__' and torch._C.has_lapack:
    run_tests()