test_distributions.py
1 """
2 Note [Randomized statistical tests]
3 -----------------------------------
4 
5 This note describes how to maintain tests in this file as random sources
6 change. This file contains two types of randomized tests:
7 
8 1. The easier type of randomized test are tests that should always pass but are
9  initialized with random data. If these fail something is wrong, but it's
10  fine to use a fixed seed by inheriting from common.TestCase.
11 
12 2. The trickier tests are statistical tests. These tests explicitly call
13  set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
14  These statistical tests have a known positive failure rate
15  (we set failure_rate=1e-3 by default). We need to balance strength of these
16  tests with annoyance of false alarms. One way that works is to specifically
17  set seeds in each of the randomized tests. When a random generator
18  occasionally changes (as in #4312 vectorizing the Box-Muller sampler), some
19  of these statistical tests may (rarely) fail. If one fails in this case,
20  it's fine to increment the seed of the failing test (but you shouldn't need
21  to increment it more than once; otherwise something is probably actually
22  wrong).
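
   As a rough worked example of the default calibration (an illustration of
   the formula in _check_sampler_sampler below, not an exact guarantee): with
   num_samples=10000 per sampler, the pooled 20000 labels are split into 10
   bins of 2000, each bin mean has stddev = 2000 ** -0.5 ~= 0.022, and the
   acceptance threshold is 0.022 * erfinv(1 - 2 * 1e-3 / 10) ~= 0.022 * 2.63
   ~= 0.06.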
23 """

import math
import numbers
import unittest
from collections import namedtuple
from itertools import product
from random import shuffle

import torch
from torch._six import inf
from common_utils import TestCase, run_tests, set_rng_seed, TEST_WITH_UBSAN, load_tests, skipIfRocm
from common_cuda import TEST_CUDA
from torch.autograd import grad, gradcheck
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
                                 Cauchy, Chi2, Dirichlet, Distribution,
                                 Exponential, ExponentialFamily,
                                 FisherSnedecor, Gamma, Geometric, Gumbel,
                                 HalfCauchy, HalfNormal,
                                 Independent, Laplace, LogisticNormal,
                                 LogNormal, LowRankMultivariateNormal,
                                 Multinomial, MultivariateNormal,
                                 NegativeBinomial, Normal, OneHotCategorical, Pareto,
                                 Poisson, RelaxedBernoulli, RelaxedOneHotCategorical,
                                 StudentT, TransformedDistribution, Uniform,
                                 Weibull, constraints, kl_divergence)
from torch.distributions.constraint_registry import biject_to, transform_to
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.kl import _kl_expfamily_expfamily
from torch.distributions.transforms import (AbsTransform, AffineTransform,
                                            ComposeTransform, ExpTransform,
                                            LowerCholeskyTransform,
                                            PowerTransform, SigmoidTransform,
                                            SoftmaxTransform,
                                            StickBreakingTransform,
                                            identity_transform)
from torch.distributions.utils import probs_to_logits, lazy_property
from torch.nn.functional import softmax

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

TEST_NUMPY = True
try:
    import numpy as np
    import scipy.stats
    import scipy.special
except ImportError:
    TEST_NUMPY = False


def pairwise(Dist, *params):
    """
    Creates a pair of distributions `Dist` initialized to test each element
    of params with each other.
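
    For example (an illustrative, hypothetical call, not used directly by the
    tests below): pairwise(Gamma, [1.0, 2.5], [0.5, 1.0]) builds two Gamma
    distributions with batch_shape (2, 2) whose parameter grids are
    transposes of each other, so entry (i, j) pairs the i-th parameter of one
    distribution with the j-th parameter of the other.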
80  """
81  params1 = [torch.tensor([p] * len(p)) for p in params]
82  params2 = [p.transpose(0, 1) for p in params1]
83  return Dist(*params1), Dist(*params2)


def is_all_nan(tensor):
    """
    Checks if all entries of a tensor are NaN.
    """
    return (tensor != tensor).all()


# Register all distributions for generic tests.
Example = namedtuple('Example', ['Dist', 'params'])
EXAMPLES = [
    Example(Bernoulli, [
        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
        {'probs': torch.tensor([0.3], requires_grad=True)},
        {'probs': 0.3},
        {'logits': torch.tensor([0.], requires_grad=True)},
    ]),
    Example(Geometric, [
        {'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
        {'probs': torch.tensor([0.3], requires_grad=True)},
        {'probs': 0.3},
    ]),
    Example(Beta, [
        {
            'concentration1': torch.randn(2, 3).exp().requires_grad_(),
            'concentration0': torch.randn(2, 3).exp().requires_grad_(),
        },
        {
            'concentration1': torch.randn(4).exp().requires_grad_(),
            'concentration0': torch.randn(4).exp().requires_grad_(),
        },
    ]),
    Example(Categorical, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
    ]),
    Example(Binomial, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10])},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': torch.tensor([10, 8])},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
         'total_count': torch.tensor([[10., 8.], [5., 3.]])},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True),
         'total_count': torch.tensor(0.)},
    ]),
    Example(NegativeBinomial, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': 10},
        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': torch.tensor([10])},
        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True), 'total_count': torch.tensor([10, 8])},
        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
         'total_count': torch.tensor([[10., 8.], [5., 3.]])},
        {'probs': torch.tensor([[0.9, 0.0], [0.0, 0.9]], requires_grad=True),
         'total_count': torch.tensor(0.)},
    ]),
    Example(Multinomial, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True), 'total_count': 10},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True), 'total_count': 10},
    ]),
    Example(Cauchy, [
        {'loc': 0.0, 'scale': 1.0},
        {'loc': torch.tensor([0.0]), 'scale': 1.0},
        {'loc': torch.tensor([[0.0], [0.0]]),
         'scale': torch.tensor([[1.0], [1.0]])}
    ]),
    Example(Chi2, [
        {'df': torch.randn(2, 3).exp().requires_grad_()},
        {'df': torch.randn(1).exp().requires_grad_()},
    ]),
    Example(StudentT, [
        {'df': torch.randn(2, 3).exp().requires_grad_()},
        {'df': torch.randn(1).exp().requires_grad_()},
    ]),
    Example(Dirichlet, [
        {'concentration': torch.randn(2, 3).exp().requires_grad_()},
        {'concentration': torch.randn(4).exp().requires_grad_()},
    ]),
    Example(Exponential, [
        {'rate': torch.randn(5, 5).abs().requires_grad_()},
        {'rate': torch.randn(1).abs().requires_grad_()},
    ]),
    Example(FisherSnedecor, [
        {
            'df1': torch.randn(5, 5).abs().requires_grad_(),
            'df2': torch.randn(5, 5).abs().requires_grad_(),
        },
        {
            'df1': torch.randn(1).abs().requires_grad_(),
            'df2': torch.randn(1).abs().requires_grad_(),
        },
        {
            'df1': torch.tensor([1.0]),
            'df2': 1.0,
        }
    ]),
    Example(Gamma, [
        {
            'concentration': torch.randn(2, 3).exp().requires_grad_(),
            'rate': torch.randn(2, 3).exp().requires_grad_(),
        },
        {
            'concentration': torch.randn(1).exp().requires_grad_(),
            'rate': torch.randn(1).exp().requires_grad_(),
        },
    ]),
    Example(Gumbel, [
        {
            'loc': torch.randn(5, 5, requires_grad=True),
            'scale': torch.randn(5, 5).abs().requires_grad_(),
        },
        {
            'loc': torch.randn(1, requires_grad=True),
            'scale': torch.randn(1).abs().requires_grad_(),
        },
    ]),
    Example(HalfCauchy, [
        {'scale': 1.0},
        {'scale': torch.tensor([[1.0], [1.0]])}
    ]),
    Example(HalfNormal, [
        {'scale': torch.randn(5, 5).abs().requires_grad_()},
        {'scale': torch.randn(1).abs().requires_grad_()},
        {'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)}
    ]),
    Example(Independent, [
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 0,
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 1,
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 2,
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
                                        torch.randn(2, 3, 5).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 2,
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
                                        torch.randn(2, 3, 5).abs().requires_grad_()),
            'reinterpreted_batch_ndims': 3,
        },
    ]),
    Example(Laplace, [
        {
            'loc': torch.randn(5, 5, requires_grad=True),
            'scale': torch.randn(5, 5).abs().requires_grad_(),
        },
        {
            'loc': torch.randn(1, requires_grad=True),
            'scale': torch.randn(1).abs().requires_grad_(),
        },
        {
            'loc': torch.tensor([1.0, 0.0], requires_grad=True),
            'scale': torch.tensor([1e-5, 1e-5], requires_grad=True),
        },
    ]),
    Example(LogNormal, [
        {
            'loc': torch.randn(5, 5, requires_grad=True),
            'scale': torch.randn(5, 5).abs().requires_grad_(),
        },
        {
            'loc': torch.randn(1, requires_grad=True),
            'scale': torch.randn(1).abs().requires_grad_(),
        },
        {
            'loc': torch.tensor([1.0, 0.0], requires_grad=True),
            'scale': torch.tensor([1e-5, 1e-5], requires_grad=True),
        },
    ]),
    Example(LogisticNormal, [
        {
            'loc': torch.randn(5, 5).requires_grad_(),
            'scale': torch.randn(5, 5).abs().requires_grad_(),
        },
        {
            'loc': torch.randn(1).requires_grad_(),
            'scale': torch.randn(1).abs().requires_grad_(),
        },
        {
            'loc': torch.tensor([1.0, 0.0], requires_grad=True),
            'scale': torch.tensor([1e-5, 1e-5], requires_grad=True),
        },
    ]),
    Example(LowRankMultivariateNormal, [
        {
            'loc': torch.randn(5, 2, requires_grad=True),
            'cov_factor': torch.randn(5, 2, 1, requires_grad=True),
            'cov_diag': torch.tensor([2.0, 0.25], requires_grad=True),
        },
        {
            'loc': torch.randn(4, 3, requires_grad=True),
            'cov_factor': torch.randn(3, 2, requires_grad=True),
            'cov_diag': torch.tensor([5.0, 1.5, 3.], requires_grad=True),
        }
    ]),
    Example(MultivariateNormal, [
        {
            'loc': torch.randn(5, 2, requires_grad=True),
            'covariance_matrix': torch.tensor([[2.0, 0.3], [0.3, 0.25]], requires_grad=True),
        },
        {
            'loc': torch.randn(2, 3, requires_grad=True),
            'precision_matrix': torch.tensor([[2.0, 0.1, 0.0],
                                              [0.1, 0.25, 0.0],
                                              [0.0, 0.0, 0.3]], requires_grad=True),
        },
        {
            'loc': torch.randn(5, 3, 2, requires_grad=True),
            'scale_tril': torch.tensor([[[2.0, 0.0], [-0.5, 0.25]],
                                        [[2.0, 0.0], [0.3, 0.25]],
                                        [[5.0, 0.0], [-0.5, 1.5]]], requires_grad=True),
        },
        {
            'loc': torch.tensor([1.0, -1.0]),
            'covariance_matrix': torch.tensor([[5.0, -0.5], [-0.5, 1.5]]),
        },
    ]),
    Example(Normal, [
        {
            'loc': torch.randn(5, 5, requires_grad=True),
            'scale': torch.randn(5, 5).abs().requires_grad_(),
        },
        {
            'loc': torch.randn(1, requires_grad=True),
            'scale': torch.randn(1).abs().requires_grad_(),
        },
        {
            'loc': torch.tensor([1.0, 0.0], requires_grad=True),
            'scale': torch.tensor([1e-5, 1e-5], requires_grad=True),
        },
    ]),
    Example(OneHotCategorical, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)},
        {'logits': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
    ]),
    Example(Pareto, [
        {
            'scale': 1.0,
            'alpha': 1.0
        },
        {
            'scale': torch.randn(5, 5).abs().requires_grad_(),
            'alpha': torch.randn(5, 5).abs().requires_grad_()
        },
        {
            'scale': torch.tensor([1.0]),
            'alpha': 1.0
        }
    ]),
    Example(Poisson, [
        {
            'rate': torch.randn(5, 5).abs().requires_grad_(),
        },
        {
            'rate': torch.randn(3).abs().requires_grad_(),
        },
        {
            'rate': 0.2,
        }
    ]),
    Example(RelaxedBernoulli, [
        {
            'temperature': torch.tensor([0.5], requires_grad=True),
            'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True),
        },
        {
            'temperature': torch.tensor([2.0]),
            'probs': torch.tensor([0.3]),
        },
        {
            'temperature': torch.tensor([7.2]),
            'logits': torch.tensor([-2.0, 2.0, 1.0, 5.0])
        }
    ]),
    Example(RelaxedOneHotCategorical, [
        {
            'temperature': torch.tensor([0.5], requires_grad=True),
            'probs': torch.tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)
        },
        {
            'temperature': torch.tensor([2.0]),
            'probs': torch.tensor([[1.0, 0.0], [0.0, 1.0]])
        },
        {
            'temperature': torch.tensor([7.2]),
            'logits': torch.tensor([[-2.0, 2.0], [1.0, 5.0]])
        }
    ]),
    Example(TransformedDistribution, [
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'transforms': [],
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
                                        torch.randn(2, 3).abs().requires_grad_()),
            'transforms': ExpTransform(),
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
                                        torch.randn(2, 3, 5).abs().requires_grad_()),
            'transforms': [AffineTransform(torch.randn(3, 5), torch.randn(3, 5)),
                           ExpTransform()],
        },
        {
            'base_distribution': Normal(torch.randn(2, 3, 5, requires_grad=True),
                                        torch.randn(2, 3, 5).abs().requires_grad_()),
            'transforms': AffineTransform(1, 2),
        },
    ]),
    Example(Uniform, [
        {
            'low': torch.zeros(5, 5, requires_grad=True),
            'high': torch.ones(5, 5, requires_grad=True),
        },
        {
            'low': torch.zeros(1, requires_grad=True),
            'high': torch.ones(1, requires_grad=True),
        },
        {
            'low': torch.tensor([1.0, 1.0], requires_grad=True),
            'high': torch.tensor([2.0, 3.0], requires_grad=True),
        },
    ]),
    Example(Weibull, [
        {
            'scale': torch.randn(5, 5).abs().requires_grad_(),
            'concentration': torch.randn(1).abs().requires_grad_()
        }
    ])
]

BAD_EXAMPLES = [
    Example(Bernoulli, [
        {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
        {'probs': torch.tensor([-0.5], requires_grad=True)},
        {'probs': 1.00001},
    ]),
    Example(Beta, [
        {
            'concentration1': torch.tensor([0.0], requires_grad=True),
            'concentration0': torch.tensor([0.0], requires_grad=True),
        },
        {
            'concentration1': torch.tensor([-1.0], requires_grad=True),
            'concentration0': torch.tensor([-2.0], requires_grad=True),
        },
    ]),
    Example(Geometric, [
        {'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
        {'probs': torch.tensor([-0.3], requires_grad=True)},
        {'probs': 1.00000001},
    ]),
    Example(Categorical, [
        {'probs': torch.tensor([[-0.1, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[-1.0, 10.0], [0.0, -1.0]], requires_grad=True)},
    ]),
    Example(Binomial, [
        {'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True),
         'total_count': 10},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
         'total_count': 10},
    ]),
    Example(NegativeBinomial, [
        {'probs': torch.tensor([[-0.0000001, 0.2, 0.3], [0.5, 0.3, 0.2]], requires_grad=True),
         'total_count': 10},
        {'probs': torch.tensor([[1.0, 0.0], [0.0, 2.0]], requires_grad=True),
         'total_count': 10},
    ]),
    Example(Cauchy, [
        {'loc': 0.0, 'scale': -1.0},
        {'loc': torch.tensor([0.0]), 'scale': 0.0},
        {'loc': torch.tensor([[0.0], [-2.0]]),
         'scale': torch.tensor([[-0.000001], [1.0]])}
    ]),
    Example(Chi2, [
        {'df': torch.tensor([0.], requires_grad=True)},
        {'df': torch.tensor([-2.], requires_grad=True)},
    ]),
    Example(StudentT, [
        {'df': torch.tensor([0.], requires_grad=True)},
        {'df': torch.tensor([-2.], requires_grad=True)},
    ]),
    Example(Dirichlet, [
        {'concentration': torch.tensor([0.], requires_grad=True)},
        {'concentration': torch.tensor([-2.], requires_grad=True)}
    ]),
    Example(Exponential, [
        {'rate': torch.tensor([0., 0.], requires_grad=True)},
        {'rate': torch.tensor([-2.], requires_grad=True)}
    ]),
    Example(FisherSnedecor, [
        {
            'df1': torch.tensor([0., 0.], requires_grad=True),
            'df2': torch.tensor([-1., -100.], requires_grad=True),
        },
        {
            'df1': torch.tensor([1., 1.], requires_grad=True),
            'df2': torch.tensor([0., 0.], requires_grad=True),
        }
    ]),
    Example(Gamma, [
        {
            'concentration': torch.tensor([0., 0.], requires_grad=True),
            'rate': torch.tensor([-1., -100.], requires_grad=True),
        },
        {
            'concentration': torch.tensor([1., 1.], requires_grad=True),
            'rate': torch.tensor([0., 0.], requires_grad=True),
        }
    ]),
    Example(Gumbel, [
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'scale': torch.tensor([0., 1.], requires_grad=True),
        },
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'scale': torch.tensor([1., -1.], requires_grad=True),
        },
    ]),
    Example(HalfCauchy, [
        {'scale': -1.0},
        {'scale': 0.0},
        {'scale': torch.tensor([[-0.000001], [1.0]])}
    ]),
    Example(HalfNormal, [
        {'scale': torch.tensor([0., 1.], requires_grad=True)},
        {'scale': torch.tensor([1., -1.], requires_grad=True)},
    ]),
    Example(Laplace, [
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'scale': torch.tensor([0., 1.], requires_grad=True),
        },
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'scale': torch.tensor([1., -1.], requires_grad=True),
        },
    ]),
    Example(LogNormal, [
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'scale': torch.tensor([0., 1.], requires_grad=True),
        },
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'scale': torch.tensor([1., -1.], requires_grad=True),
        },
    ]),
    Example(MultivariateNormal, [
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'covariance_matrix': torch.tensor([[1.0, 0.0], [0.0, -2.0]], requires_grad=True),
        },
    ]),
    Example(Normal, [
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'scale': torch.tensor([0., 1.], requires_grad=True),
        },
        {
            'loc': torch.tensor([1., 1.], requires_grad=True),
            'scale': torch.tensor([1., -1.], requires_grad=True),
        },
        {
            'loc': torch.tensor([1.0, 0.0], requires_grad=True),
            'scale': torch.tensor([1e-5, -1e-5], requires_grad=True),
        },
    ]),
    Example(OneHotCategorical, [
        {'probs': torch.tensor([[0.1, 0.2, 0.3], [0.1, -10.0, 0.2]], requires_grad=True)},
        {'probs': torch.tensor([[0.0, 0.0], [0.0, 0.0]], requires_grad=True)},
    ]),
    Example(Pareto, [
        {
            'scale': 0.0,
            'alpha': 0.0
        },
        {
            'scale': torch.tensor([0.0, 0.0], requires_grad=True),
            'alpha': torch.tensor([-1e-5, 0.0], requires_grad=True)
        },
        {
            'scale': torch.tensor([1.0]),
            'alpha': -1.0
        }
    ]),
    Example(Poisson, [
        {
            'rate': torch.tensor([0.0], requires_grad=True),
        },
        {
            'rate': -1.0,
        }
    ]),
    Example(RelaxedBernoulli, [
        {
            'temperature': torch.tensor([1.5], requires_grad=True),
            'probs': torch.tensor([1.7, 0.2, 0.4], requires_grad=True),
        },
        {
            'temperature': torch.tensor([2.0]),
            'probs': torch.tensor([-1.0]),
        }
    ]),
    Example(RelaxedOneHotCategorical, [
        {
            'temperature': torch.tensor([0.5], requires_grad=True),
            'probs': torch.tensor([[-0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], requires_grad=True)
        },
        {
            'temperature': torch.tensor([2.0]),
            'probs': torch.tensor([[-1.0, 0.0], [-1.0, 1.1]])
        }
    ]),
    Example(TransformedDistribution, [
        {
            'base_distribution': Normal(0, 1),
            'transforms': lambda x: x,
        },
        {
            'base_distribution': Normal(0, 1),
            'transforms': [lambda x: x],
        },
    ]),
    Example(Uniform, [
        {
            'low': torch.tensor([2.0], requires_grad=True),
            'high': torch.tensor([2.0], requires_grad=True),
        },
        {
            'low': torch.tensor([0.0], requires_grad=True),
            'high': torch.tensor([0.0], requires_grad=True),
        },
        {
            'low': torch.tensor([1.0], requires_grad=True),
            'high': torch.tensor([0.0], requires_grad=True),
        }
    ]),
    Example(Weibull, [
        {
            'scale': torch.tensor([0.0], requires_grad=True),
            'concentration': torch.tensor([0.0], requires_grad=True)
        },
        {
            'scale': torch.tensor([1.0], requires_grad=True),
            'concentration': torch.tensor([-1.0], requires_grad=True)
        }
    ])
]


class TestDistributions(TestCase):
    _do_cuda_memory_leak_check = True

    def _gradcheck_log_prob(self, dist_ctor, ctor_params):
        # performs gradient checks on log_prob
        distribution = dist_ctor(*ctor_params)
        s = distribution.sample()
        if s.is_floating_point():
            s = s.detach().requires_grad_()

        expected_shape = distribution.batch_shape + distribution.event_shape
        self.assertEqual(s.size(), expected_shape)

        def apply_fn(s, *params):
            return dist_ctor(*params).log_prob(s)

        gradcheck(apply_fn, (s,) + tuple(ctor_params), raise_exception=True)

    def _check_log_prob(self, dist, assert_fn):
        # checks that the log_prob matches a reference function
        s = dist.sample()
        log_probs = dist.log_prob(s)
        log_probs_data_flat = log_probs.view(-1)
        s_data_flat = s.view(len(log_probs_data_flat), -1)
        for i, (val, log_prob) in enumerate(zip(s_data_flat, log_probs_data_flat)):
            assert_fn(i, val.squeeze(), log_prob)

    def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=False,
                               num_samples=10000, failure_rate=1e-3):
        # Checks that the .sample() method matches a reference function.
        torch_samples = torch_dist.sample((num_samples,)).squeeze()
        torch_samples = torch_samples.cpu().numpy()
        ref_samples = ref_dist.rvs(num_samples).astype(np.float64)
        if multivariate:
            # Project onto a random axis.
            axis = np.random.normal(size=torch_samples.shape[-1])
            axis /= np.linalg.norm(axis)
            torch_samples = np.dot(torch_samples, axis)
            ref_samples = np.dot(ref_samples, axis)
        samples = [(x, +1) for x in torch_samples] + [(x, -1) for x in ref_samples]
        shuffle(samples)  # necessary to prevent stable sort from making uneven bins for discrete
        samples.sort(key=lambda x: x[0])
        samples = np.array(samples)[:, 1]

        # Aggregate into bins filled with roughly zero-mean unit-variance RVs.
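        # Heuristic calibration: if the two samplers agree, the +/-1 labels
        # are roughly exchangeable within each sorted bin, each bin mean has
        # stddev ~= samples_per_bin ** -0.5, and the erfinv threshold below
        # keeps the overall false-alarm rate near failure_rate (see Note
        # [Randomized statistical tests]).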
        num_bins = 10
        samples_per_bin = len(samples) // num_bins
        bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
        stddev = samples_per_bin ** -0.5
        threshold = stddev * scipy.special.erfinv(1 - 2 * failure_rate / num_bins)
        message = '{}.sample() is biased:\n{}'.format(message, bins)
        for bias in bins:
            self.assertLess(-threshold, bias, message)
            self.assertLess(bias, threshold, message)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def _check_sampler_discrete(self, torch_dist, ref_dist, message,
                                num_samples=10000, failure_rate=1e-3):
        """Runs a chi-squared test on the support, but ignores the tail instead of combining bins."""
        torch_samples = torch_dist.sample((num_samples,)).squeeze()
        torch_samples = torch_samples.cpu().numpy()
        unique, counts = np.unique(torch_samples, return_counts=True)
        pmf = ref_dist.pmf(unique)
        msk = (counts > 5) & ((pmf * num_samples) > 5)
        self.assertGreater(pmf[msk].sum(), 0.9, "Distribution is too sparse for test; try increasing num_samples")
        chisq, p = scipy.stats.chisquare(counts[msk], pmf[msk] * num_samples)
        self.assertGreater(p, failure_rate, message)

    def _check_enumerate_support(self, dist, examples):
        for params, expected in examples:
            params = {k: torch.tensor(v) for k, v in params.items()}
            expected = torch.tensor(expected)
            d = dist(**params)
            actual = d.enumerate_support(expand=False)
            self.assertEqual(actual, expected)
            actual = d.enumerate_support(expand=True)
            expected_with_expand = expected.expand((-1,) + d.batch_shape + d.event_shape)
            self.assertEqual(actual, expected_with_expand)

    def test_repr(self):
        for Dist, params in EXAMPLES:
            for param in params:
                dist = Dist(**param)
                self.assertTrue(repr(dist).startswith(dist.__class__.__name__))

    def test_sample_detached(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                variable_params = [p for p in param.values() if getattr(p, 'requires_grad', False)]
                if not variable_params:
                    continue
                dist = Dist(**param)
                sample = dist.sample()
                self.assertFalse(sample.requires_grad,
                                 msg='{} example {}/{}, .sample() is not detached'.format(
                                     Dist.__name__, i + 1, len(params)))

    def test_rsample_requires_grad(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                if not any(getattr(p, 'requires_grad', False) for p in param.values()):
                    continue
                dist = Dist(**param)
                if not dist.has_rsample:
                    continue
                sample = dist.rsample()
                self.assertTrue(sample.requires_grad,
                                msg='{} example {}/{}, .rsample() does not require grad'.format(
                                    Dist.__name__, i + 1, len(params)))

    def test_enumerate_support_type(self):
        for Dist, params in EXAMPLES:
            for i, param in enumerate(params):
                dist = Dist(**param)
                try:
                    self.assertTrue(type(dist.sample()) is type(dist.enumerate_support()),
                                    msg=('{} example {}/{}, return type mismatch between ' +
                                         'sample and enumerate_support.').format(Dist.__name__, i + 1, len(params)))
                except NotImplementedError:
                    pass

    def test_lazy_property_grad(self):
        x = torch.randn(1, requires_grad=True)

        class Dummy(object):
            @lazy_property
            def y(self):
                return x + 1

        def test():
            x.grad = None
            Dummy().y.backward()
            self.assertEqual(x.grad, torch.ones(1))

        test()
        with torch.no_grad():
            test()

        mean = torch.randn(2)
        cov = torch.eye(2, requires_grad=True)
        distn = MultivariateNormal(mean, cov)
        with torch.no_grad():
            distn.scale_tril
        distn.scale_tril.sum().backward()
        self.assertIsNotNone(cov.grad)

    def test_has_examples(self):
        distributions_with_examples = {e.Dist for e in EXAMPLES}
        for Dist in globals().values():
            if isinstance(Dist, type) and issubclass(Dist, Distribution) \
                    and Dist is not Distribution and Dist is not ExponentialFamily:
                self.assertIn(Dist, distributions_with_examples,
                              "Please add {} to the EXAMPLES list in test_distributions.py".format(Dist.__name__))

    def test_distribution_expand(self):
        shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))]
        for Dist, params in EXAMPLES:
            for param in params:
                for shape in shapes:
                    d = Dist(**param)
                    expanded_shape = shape + d.batch_shape
                    original_shape = d.batch_shape + d.event_shape
                    expected_shape = shape + original_shape
                    expanded = d.expand(batch_shape=list(expanded_shape))
                    sample = expanded.sample()
                    actual_shape = expanded.sample().shape
                    self.assertEqual(expanded.__class__, d.__class__)
                    self.assertEqual(d.sample().shape, original_shape)
                    self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
                    self.assertEqual(actual_shape, expected_shape)
                    self.assertEqual(expanded.batch_shape, expanded_shape)
                    try:
                        self.assertEqual(expanded.mean,
                                         d.mean.expand(expanded_shape + d.event_shape),
                                         allow_inf=True)
                        self.assertEqual(expanded.variance,
                                         d.variance.expand(expanded_shape + d.event_shape),
                                         allow_inf=True)
                    except NotImplementedError:
                        pass

    def test_distribution_subclass_expand(self):
        expand_by = torch.Size((2,))
        for Dist, params in EXAMPLES:

            class SubClass(Dist):
                pass

            for param in params:
                d = SubClass(**param)
                expanded_shape = expand_by + d.batch_shape
                original_shape = d.batch_shape + d.event_shape
                expected_shape = expand_by + original_shape
                expanded = d.expand(batch_shape=expanded_shape)
                sample = expanded.sample()
                actual_shape = expanded.sample().shape
                self.assertEqual(expanded.__class__, d.__class__)
                self.assertEqual(d.sample().shape, original_shape)
                self.assertEqual(expanded.log_prob(sample), d.log_prob(sample))
                self.assertEqual(actual_shape, expected_shape)

    def test_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(Bernoulli(p).sample((8,)).size(), (8, 3))
        self.assertFalse(Bernoulli(p).sample().requires_grad)
        self.assertEqual(Bernoulli(r).sample((8,)).size(), (8,))
        self.assertEqual(Bernoulli(r).sample().size(), ())
        self.assertEqual(Bernoulli(r).sample((3, 2)).size(), (3, 2,))
        self.assertEqual(Bernoulli(s).sample().size(), ())
        self._gradcheck_log_prob(Bernoulli, (p,))

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx]
            self.assertEqual(log_prob, math.log(prob if val else 1 - prob))

        self._check_log_prob(Bernoulli(p), ref_log_prob)
        self._check_log_prob(Bernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
        self.assertRaises(NotImplementedError, Bernoulli(r).rsample)

        # check entropy computation
        self.assertEqual(Bernoulli(p).entropy(), torch.tensor([0.6108, 0.5004, 0.6730]), prec=1e-4)
        self.assertEqual(Bernoulli(torch.tensor([0.0])).entropy(), torch.tensor([0.0]))
        self.assertEqual(Bernoulli(s).entropy(), torch.tensor(0.6108), prec=1e-4)

    def test_bernoulli_enumerate_support(self):
        examples = [
            ({"probs": [0.1]}, [[0], [1]]),
            ({"probs": [0.1, 0.9]}, [[0], [1]]),
            ({"probs": [[0.1, 0.2], [0.3, 0.4]]}, [[[0]], [[1]]]),
        ]
        self._check_enumerate_support(Bernoulli, examples)

    def test_bernoulli_3d(self):
        p = torch.full((2, 3, 5), 0.5).requires_grad_()
        self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
        self.assertEqual(Bernoulli(p).sample(sample_shape=(2, 5)).size(),
                         (2, 5, 2, 3, 5))
        self.assertEqual(Bernoulli(p).sample((2,)).size(), (2, 2, 3, 5))

    def test_geometric(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        self.assertEqual(Geometric(p).sample((8,)).size(), (8, 3))
        self.assertEqual(Geometric(1).sample(), 0)
        self.assertEqual(Geometric(1).log_prob(torch.tensor(1.)), -inf, allow_inf=True)
        self.assertEqual(Geometric(1).log_prob(torch.tensor(0.)), 0)
        self.assertFalse(Geometric(p).sample().requires_grad)
        self.assertEqual(Geometric(r).sample((8,)).size(), (8,))
        self.assertEqual(Geometric(r).sample().size(), ())
        self.assertEqual(Geometric(r).sample((3, 2)).size(), (3, 2))
        self.assertEqual(Geometric(s).sample().size(), ())
        self._gradcheck_log_prob(Geometric, (p,))
        self.assertRaises(ValueError, lambda: Geometric(0))
        self.assertRaises(NotImplementedError, Geometric(r).rsample)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_geometric_log_prob_and_entropy(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        s = 0.3

        def ref_log_prob(idx, val, log_prob):
            prob = p[idx].detach()
            self.assertEqual(log_prob, scipy.stats.geom(prob, loc=-1).logpmf(val))

        self._check_log_prob(Geometric(p), ref_log_prob)
        self._check_log_prob(Geometric(logits=p.log() - (-p).log1p()), ref_log_prob)

        # check entropy computation
        self.assertEqual(Geometric(p).entropy(), scipy.stats.geom(p.detach().numpy(), loc=-1).entropy(), prec=1e-3)
        self.assertEqual(float(Geometric(s).entropy()), scipy.stats.geom(s, loc=-1).entropy().item(), prec=1e-3)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_geometric_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for prob in [0.01, 0.18, 0.8]:
            self._check_sampler_discrete(Geometric(prob),
                                         scipy.stats.geom(p=prob, loc=-1),
                                         'Geometric(prob={})'.format(prob))

    def test_binomial(self):
        p = torch.arange(0.05, 1, 0.1).requires_grad_()
        for total_count in [1, 2, 10]:
            self._gradcheck_log_prob(lambda p: Binomial(total_count, p), [p])
            self._gradcheck_log_prob(lambda p: Binomial(total_count, None, p.log()), [p])
        self.assertRaises(NotImplementedError, Binomial(10, p).rsample)
        self.assertRaises(NotImplementedError, Binomial(10, p).entropy)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_binomial_log_prob(self):
        probs = torch.arange(0.05, 1, 0.1)
        for total_count in [1, 2, 10]:

            def ref_log_prob(idx, x, log_prob):
                p = probs.view(-1)[idx].item()
                expected = scipy.stats.binom(total_count, p).logpmf(x)
                self.assertAlmostEqual(log_prob, expected, places=3)

            self._check_log_prob(Binomial(total_count, probs), ref_log_prob)
            logits = probs_to_logits(probs, is_binary=True)
            self._check_log_prob(Binomial(total_count, logits=logits), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_binomial_log_prob_vectorized_count(self):
        probs = torch.tensor([0.2, 0.7, 0.9])
        for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3., 9.])),
                                    (torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.]))]:
            log_prob = Binomial(total_count, probs).log_prob(sample)
            expected = scipy.stats.binom(total_count.cpu().numpy(), probs.cpu().numpy()).logpmf(sample)
            self.assertAlmostEqual(log_prob, expected, places=4)

    def test_binomial_enumerate_support(self):
        examples = [
            ({"probs": [0.1], "total_count": 2}, [[0], [1], [2]]),
            ({"probs": [0.1, 0.9], "total_count": 2}, [[0], [1], [2]]),
            ({"probs": [[0.1, 0.2], [0.3, 0.4]], "total_count": 3}, [[[0]], [[1]], [[2]], [[3]]]),
        ]
        self._check_enumerate_support(Binomial, examples)

    def test_binomial_extreme_vals(self):
        total_count = 100
        bin0 = Binomial(total_count, 0)
        self.assertEqual(bin0.sample(), 0)
        self.assertAlmostEqual(bin0.log_prob(torch.tensor([0.]))[0], 0, places=3)
        self.assertEqual(float(bin0.log_prob(torch.tensor([1.])).exp()), 0, allow_inf=True)
        bin1 = Binomial(total_count, 1)
        self.assertEqual(bin1.sample(), total_count)
        self.assertAlmostEqual(bin1.log_prob(torch.tensor([float(total_count)]))[0], 0, places=3)
        self.assertEqual(float(bin1.log_prob(torch.tensor([float(total_count - 1)])).exp()), 0, allow_inf=True)
        zero_counts = torch.zeros(torch.Size((2, 2)))
        bin2 = Binomial(zero_counts, 1)
        self.assertEqual(bin2.sample(), zero_counts)
        self.assertEqual(bin2.log_prob(zero_counts), zero_counts)

    def test_binomial_vectorized_count(self):
        set_rng_seed(0)
        total_count = torch.tensor([[4, 7], [3, 8]])
        bin0 = Binomial(total_count, torch.tensor(1.))
        self.assertEqual(bin0.sample(), total_count)
        bin1 = Binomial(total_count, torch.tensor(0.5))
        samples = bin1.sample(torch.Size((100000,)))
        self.assertTrue((samples <= total_count.type_as(samples)).all())
        self.assertEqual(samples.mean(dim=0), bin1.mean, prec=0.02)
        self.assertEqual(samples.var(dim=0), bin1.variance, prec=0.02)

    def test_negative_binomial(self):
        p = torch.arange(0.05, 1, 0.1).requires_grad_()
        for total_count in [1, 2, 10]:
            self._gradcheck_log_prob(lambda p: NegativeBinomial(total_count, p), [p])
            self._gradcheck_log_prob(lambda p: NegativeBinomial(total_count, None, p.log()), [p])
        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).rsample)
        self.assertRaises(NotImplementedError, NegativeBinomial(10, p).entropy)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_negative_binomial_log_prob(self):
        probs = torch.arange(0.05, 1, 0.1)
        for total_count in [1, 2, 10]:

            def ref_log_prob(idx, x, log_prob):
                p = probs.view(-1)[idx].item()
                expected = scipy.stats.nbinom(total_count, 1 - p).logpmf(x)
                self.assertAlmostEqual(log_prob, expected, places=3)

            self._check_log_prob(NegativeBinomial(total_count, probs), ref_log_prob)
            logits = probs_to_logits(probs, is_binary=True)
            self._check_log_prob(NegativeBinomial(total_count, logits=logits), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_negative_binomial_log_prob_vectorized_count(self):
        probs = torch.tensor([0.2, 0.7, 0.9])
        for total_count, sample in [(torch.tensor([10]), torch.tensor([7., 3., 9.])),
                                    (torch.tensor([1, 2, 10]), torch.tensor([0., 1., 9.]))]:
            log_prob = NegativeBinomial(total_count, probs).log_prob(sample)
            expected = scipy.stats.nbinom(total_count.cpu().numpy(), 1 - probs.cpu().numpy()).logpmf(sample)
            self.assertAlmostEqual(log_prob, expected, places=4)

    def test_multinomial_1d(self):
        total_count = 10
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertEqual(Multinomial(total_count, p).sample().size(), (3,))
        self.assertEqual(Multinomial(total_count, p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(Multinomial(total_count, p).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
        self.assertRaises(NotImplementedError, Multinomial(10, p).rsample)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_multinomial_1d_log_prob(self):
        total_count = 10
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        dist = Multinomial(total_count, probs=p)
        x = dist.sample()
        log_prob = dist.log_prob(x)
        expected = torch.tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
        self.assertEqual(log_prob, expected)

        dist = Multinomial(total_count, logits=p.log())
        x = dist.sample()
        log_prob = dist.log_prob(x)
        expected = torch.tensor(scipy.stats.multinomial.logpmf(x.numpy(), n=total_count, p=dist.probs.detach().numpy()))
        self.assertEqual(log_prob, expected)

    def test_multinomial_2d(self):
        total_count = 10
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(Multinomial(total_count, p).sample().size(), (2, 3))
        self.assertEqual(Multinomial(total_count, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(Multinomial(total_count, p).sample((6,)).size(), (6, 2, 3))
        set_rng_seed(0)
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
        self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])

        # sample check for extreme value of probs
        self.assertEqual(Multinomial(total_count, s).sample(),
                         torch.tensor([[total_count, 0], [0, total_count]]))

        # check entropy computation
        self.assertRaises(NotImplementedError, Multinomial(10, p).entropy)

    def test_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertTrue(is_all_nan(Categorical(p).mean))
        self.assertTrue(is_all_nan(Categorical(p).variance))
        self.assertEqual(Categorical(p).sample().size(), ())
        self.assertFalse(Categorical(p).sample().requires_grad)
        self.assertEqual(Categorical(p).sample((2, 2)).size(), (2, 2))
        self.assertEqual(Categorical(p).sample((1,)).size(), (1,))
        self._gradcheck_log_prob(Categorical, (p,))
        self.assertRaises(NotImplementedError, Categorical(p).rsample)

    def test_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(Categorical(p).mean.size(), (2,))
        self.assertEqual(Categorical(p).variance.size(), (2,))
        self.assertTrue(is_all_nan(Categorical(p).mean))
        self.assertTrue(is_all_nan(Categorical(p).variance))
        self.assertEqual(Categorical(p).sample().size(), (2,))
        self.assertEqual(Categorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2))
        self.assertEqual(Categorical(p).sample((6,)).size(), (6, 2))
        self._gradcheck_log_prob(Categorical, (p,))

        # sample check for extreme value of probs
        set_rng_seed(0)
        self.assertEqual(Categorical(s).sample(sample_shape=(2,)),
                         torch.tensor([[0, 1], [0, 1]]))

        def ref_log_prob(idx, val, log_prob):
            sample_prob = p[idx][val] / p[idx].sum()
            self.assertEqual(log_prob, math.log(sample_prob))

        self._check_log_prob(Categorical(p), ref_log_prob)
        self._check_log_prob(Categorical(logits=p.log()), ref_log_prob)

        # check entropy computation
        self.assertEqual(Categorical(p).entropy(), torch.tensor([1.0114, 1.0297]), prec=1e-4)
        self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0]))

    def test_categorical_enumerate_support(self):
        examples = [
            ({"probs": [0.1, 0.2, 0.7]}, [0, 1, 2]),
            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[0], [1]]),
        ]
        self._check_enumerate_support(Categorical, examples)

    def test_one_hot_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
        self.assertFalse(OneHotCategorical(p).sample().requires_grad)
        self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))
        self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)

    def test_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
        self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample((6,)).size(), (6, 2, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))

        dist = OneHotCategorical(p)
        x = dist.sample()
        self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))

    def test_one_hot_categorical_enumerate_support(self):
        examples = [
            ({"probs": [0.1, 0.2, 0.7]}, [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
            ({"probs": [[0.1, 0.9], [0.3, 0.7]]}, [[[1, 0]], [[0, 1]]]),
        ]
        self._check_enumerate_support(OneHotCategorical, examples)

    def test_poisson_shape(self):
        rate = torch.randn(2, 3).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()
        self.assertEqual(Poisson(rate).sample().size(), (2, 3))
        self.assertEqual(Poisson(rate).sample((7,)).size(), (7, 2, 3))
        self.assertEqual(Poisson(rate_1d).sample().size(), (1,))
        self.assertEqual(Poisson(rate_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Poisson(2.0).sample((2,)).size(), (2,))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_log_prob(self):
        rate = torch.randn(2, 3).abs().requires_grad_()
        rate_1d = torch.randn(1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            l = rate.view(-1)[idx].detach()
            expected = scipy.stats.poisson.logpmf(x, l)
            self.assertAlmostEqual(log_prob, expected, places=3)

        set_rng_seed(0)
        self._check_log_prob(Poisson(rate), ref_log_prob)
        self._gradcheck_log_prob(Poisson, (rate,))
        self._gradcheck_log_prob(Poisson, (rate_1d,))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_sample(self):
        set_rng_seed(1)  # see Note [Randomized statistical tests]
        for rate in [0.1, 1.0, 5.0]:
            self._check_sampler_discrete(Poisson(rate),
                                         scipy.stats.poisson(rate),
                                         'Poisson(lambda={})'.format(rate),
                                         failure_rate=1e-3)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_poisson_gpu_sample(self):
        set_rng_seed(1)
        for rate in [0.12, 0.9, 4.0]:
            self._check_sampler_discrete(Poisson(torch.tensor([rate]).cuda()),
                                         scipy.stats.poisson(rate),
                                         'Poisson(lambda={}, cuda)'.format(rate),
                                         failure_rate=1e-3)

    def test_relaxed_bernoulli(self):
        p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
        r = torch.tensor(0.3, requires_grad=True)
        s = 0.3
        temp = torch.tensor(0.67, requires_grad=True)
        self.assertEqual(RelaxedBernoulli(temp, p).sample((8,)).size(), (8, 3))
        self.assertFalse(RelaxedBernoulli(temp, p).sample().requires_grad)
        self.assertEqual(RelaxedBernoulli(temp, r).sample((8,)).size(), (8,))
        self.assertEqual(RelaxedBernoulli(temp, r).sample().size(), ())
        self.assertEqual(RelaxedBernoulli(temp, r).sample((3, 2)).size(), (3, 2,))
        self.assertEqual(RelaxedBernoulli(temp, s).sample().size(), ())
        self._gradcheck_log_prob(RelaxedBernoulli, (temp, p))
        self._gradcheck_log_prob(RelaxedBernoulli, (temp, r))

        # test that rsample doesn't fail
        s = RelaxedBernoulli(temp, p).rsample()
        s.backward(torch.ones_like(s))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_rounded_relaxed_bernoulli(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]

        class Rounded(object):
            def __init__(self, dist):
                self.dist = dist

            def sample(self, *args, **kwargs):
                return torch.round(self.dist.sample(*args, **kwargs))

        for probs, temp in product([0.1, 0.2, 0.8], [0.1, 1.0, 10.0]):
            self._check_sampler_discrete(Rounded(RelaxedBernoulli(temp, probs)),
                                         scipy.stats.bernoulli(probs),
                                         'Rounded(RelaxedBernoulli(temp={}, probs={}))'.format(temp, probs),
                                         failure_rate=1e-3)

        for probs in [0.001, 0.2, 0.999]:
            equal_probs = torch.tensor(0.5)
            dist = RelaxedBernoulli(1e10, probs)
            s = dist.rsample()
            self.assertEqual(equal_probs, s)

    def test_relaxed_one_hot_categorical_1d(self):
        p = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)
        temp = torch.tensor(0.67, requires_grad=True)
        self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample().size(), (3,))
        self.assertFalse(RelaxedOneHotCategorical(probs=p, temperature=temp).sample().requires_grad)
        self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((2, 2)).size(), (2, 2, 3))
        self.assertEqual(RelaxedOneHotCategorical(probs=p, temperature=temp).sample((1,)).size(), (1, 3))
        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp, p))

    def test_relaxed_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        temp = torch.tensor([3.0], requires_grad=True)
        # The lower the temperature, the more unstable the log_prob gradcheck is
        # w.r.t. the sample. Values below 0.25 empirically fail the default tol.
        temp_2 = torch.tensor([0.25], requires_grad=True)
        p = torch.tensor(probabilities, requires_grad=True)
        s = torch.tensor(probabilities_1, requires_grad=True)
        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample().size(), (2, 3))
        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(RelaxedOneHotCategorical(temp, p).sample((6,)).size(), (6, 2, 3))
        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp, p))
        self._gradcheck_log_prob(RelaxedOneHotCategorical, (temp_2, p))

    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
    def test_argmax_relaxed_categorical(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]

        class ArgMax(object):
            def __init__(self, dist):
                self.dist = dist

            def sample(self, *args, **kwargs):
                s = self.dist.sample(*args, **kwargs)
                _, idx = torch.max(s, -1)
                return idx

        class ScipyCategorical(object):
            def __init__(self, dist):
                self.dist = dist

            def pmf(self, samples):
                new_samples = np.zeros(samples.shape + self.dist.p.shape)
                new_samples[np.arange(samples.shape[0]), samples] = 1
                return self.dist.pmf(new_samples)

        for probs, temp in product([torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])], [0.1, 1.0, 10.0]):
            self._check_sampler_discrete(ArgMax(RelaxedOneHotCategorical(temp, probs)),
                                         ScipyCategorical(scipy.stats.multinomial(1, probs)),
                                         'ArgMax(RelaxedOneHotCategorical(temp={}, probs={}))'.format(temp, probs),
                                         failure_rate=1e-3)

        for probs in [torch.tensor([0.1, 0.9]), torch.tensor([0.2, 0.2, 0.6])]:
            equal_probs = torch.ones(probs.size()) / probs.size()[0]
            dist = RelaxedOneHotCategorical(1e10, probs)
            s = dist.rsample()
            self.assertEqual(equal_probs, s)

    def test_uniform(self):
        low = torch.zeros(5, 5, requires_grad=True)
        high = (torch.ones(5, 5) * 3).requires_grad_()
        low_1d = torch.zeros(1, requires_grad=True)
        high_1d = (torch.ones(1) * 3).requires_grad_()
        self.assertEqual(Uniform(low, high).sample().size(), (5, 5))
        self.assertEqual(Uniform(low, high).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Uniform(low_1d, high_1d).sample().size(), (1,))
        self.assertEqual(Uniform(low_1d, high_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

        # check log_prob computation when value outside range
        uniform = Uniform(low_1d, high_1d)
        above_high = torch.tensor([4.0])
        below_low = torch.tensor([-1.0])
        self.assertEqual(uniform.log_prob(above_high).item(), -inf, allow_inf=True)
        self.assertEqual(uniform.log_prob(below_low).item(), -inf, allow_inf=True)

        # check cdf computation when value outside range
        self.assertEqual(uniform.cdf(below_low).item(), 0)
        self.assertEqual(uniform.cdf(above_high).item(), 1)

        set_rng_seed(1)
        self._gradcheck_log_prob(Uniform, (low, high))
        self._gradcheck_log_prob(Uniform, (low, 1.0))
        self._gradcheck_log_prob(Uniform, (0.0, high))

        state = torch.get_rng_state()
        rand = low.new(low.size()).uniform_()
        torch.set_rng_state(state)
        u = Uniform(low, high).rsample()
        u.backward(torch.ones_like(u))
        self.assertEqual(low.grad, 1 - rand)
        self.assertEqual(high.grad, rand)
        low.grad.zero_()
        high.grad.zero_()

    def test_cauchy(self):
        loc = torch.zeros(5, 5, requires_grad=True)
        scale = torch.ones(5, 5, requires_grad=True)
        loc_1d = torch.zeros(1, requires_grad=True)
        scale_1d = torch.ones(1, requires_grad=True)
        self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
        self.assertEqual(Cauchy(loc_1d, scale_1d).variance, inf, allow_inf=True)
        self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
        self.assertEqual(Cauchy(loc, scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
        self.assertEqual(Cauchy(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(Cauchy(0.0, 1.0).sample((1,)).size(), (1,))

        set_rng_seed(1)
        self._gradcheck_log_prob(Cauchy, (loc, scale))
        self._gradcheck_log_prob(Cauchy, (loc, 1.0))
        self._gradcheck_log_prob(Cauchy, (0.0, scale))

        state = torch.get_rng_state()
        eps = loc.new(loc.size()).cauchy_()
        torch.set_rng_state(state)
        c = Cauchy(loc, scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(loc.grad, torch.ones_like(scale))
        self.assertEqual(scale.grad, eps)
        loc.grad.zero_()
        scale.grad.zero_()

    def test_halfcauchy(self):
        scale = torch.ones(5, 5, requires_grad=True)
        scale_1d = torch.ones(1, requires_grad=True)
        self.assertTrue(is_all_nan(HalfCauchy(scale_1d).mean))
        self.assertEqual(HalfCauchy(scale_1d).variance, inf, allow_inf=True)
        self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
        self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(HalfCauchy(scale_1d).sample().size(), (1,))
        self.assertEqual(HalfCauchy(scale_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(HalfCauchy(1.0).sample((1,)).size(), (1,))

        set_rng_seed(1)
        self._gradcheck_log_prob(HalfCauchy, (scale,))
        self._gradcheck_log_prob(HalfCauchy, (1.0,))

        state = torch.get_rng_state()
        eps = scale.new(scale.size()).cauchy_().abs_()
        torch.set_rng_state(state)
        c = HalfCauchy(scale).rsample()
        c.backward(torch.ones_like(c))
        self.assertEqual(scale.grad, eps)
        scale.grad.zero_()

    def test_halfnormal(self):
        std = torch.randn(5, 5).abs().requires_grad_()
        std_1d = torch.randn(1, requires_grad=True)
        std_delta = torch.tensor([1e-5, 1e-5])
        self.assertEqual(HalfNormal(std).sample().size(), (5, 5))
        self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5))
        self.assertEqual(HalfNormal(std_1d).sample((1,)).size(), (1, 1))
        self.assertEqual(HalfNormal(std_1d).sample().size(), (1,))
        self.assertEqual(HalfNormal(.6).sample((1,)).size(), (1,))
        self.assertEqual(HalfNormal(50.0).sample((1,)).size(), (1,))

        # sample check for extreme value of std
        set_rng_seed(1)
        self.assertEqual(HalfNormal(std_delta).sample(sample_shape=(1, 2)),
                         torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
                         prec=1e-4)

        self._gradcheck_log_prob(HalfNormal, (std,))
        self._gradcheck_log_prob(HalfNormal, (1.0,))

        # check .log_prob() can broadcast.
        dist = HalfNormal(torch.ones(2, 1, 4))
        log_prob = dist.log_prob(torch.ones(3, 1))
        self.assertEqual(log_prob.shape, (2, 3, 4))

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_halfnormal_logprob(self):
        std = torch.randn(5, 1).abs().requires_grad_()

        def ref_log_prob(idx, x, log_prob):
            s = std.view(-1)[idx].detach()
            expected = scipy.stats.halfnorm(scale=s).logpdf(x)
            self.assertAlmostEqual(log_prob, expected, places=3)

        self._check_log_prob(HalfNormal(std), ref_log_prob)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    def test_halfnormal_sample(self):
        set_rng_seed(0)  # see Note [Randomized statistical tests]
        for std in [0.1, 1.0, 10.0]:
            self._check_sampler_sampler(HalfNormal(std),
                                        scipy.stats.halfnorm(scale=std),
                                        'HalfNormal(scale={})'.format(std))
1427 
1428  def test_lognormal(self):
1429  mean = torch.randn(5, 5, requires_grad=True)
1430  std = torch.randn(5, 5).abs().requires_grad_()
1431  mean_1d = torch.randn(1, requires_grad=True)
1432  std_1d = torch.randn(1).abs().requires_grad_()
1433  mean_delta = torch.tensor([1.0, 0.0])
1434  std_delta = torch.tensor([1e-5, 1e-5])
1435  self.assertEqual(LogNormal(mean, std).sample().size(), (5, 5))
1436  self.assertEqual(LogNormal(mean, std).sample((7,)).size(), (7, 5, 5))
1437  self.assertEqual(LogNormal(mean_1d, std_1d).sample((1,)).size(), (1, 1))
1438  self.assertEqual(LogNormal(mean_1d, std_1d).sample().size(), (1,))
1439  self.assertEqual(LogNormal(0.2, .6).sample((1,)).size(), (1,))
1440  self.assertEqual(LogNormal(-0.7, 50.0).sample((1,)).size(), (1,))
1441 
1442  # sample check for extreme value of mean, std
1443  set_rng_seed(1)
1444  self.assertEqual(LogNormal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
1445  torch.tensor([[[math.exp(1), 1.0], [math.exp(1), 1.0]]]),
1446  prec=1e-4)
1447 
1448  self._gradcheck_log_prob(LogNormal, (mean, std))
1449  self._gradcheck_log_prob(LogNormal, (mean, 1.0))
1450  self._gradcheck_log_prob(LogNormal, (0.0, std))
1451 
1452  # check .log_prob() can broadcast.
1453  dist = LogNormal(torch.zeros(4), torch.ones(2, 1, 1))
1454  log_prob = dist.log_prob(torch.ones(3, 1))
1455  self.assertEqual(log_prob.shape, (2, 3, 4))
1456 
1457  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1458  def test_lognormal_logprob(self):
1459  mean = torch.randn(5, 1, requires_grad=True)
1460  std = torch.randn(5, 1).abs().requires_grad_()
1461 
1462  def ref_log_prob(idx, x, log_prob):
1463  m = mean.view(-1)[idx].detach()
1464  s = std.view(-1)[idx].detach()
1465  expected = scipy.stats.lognorm(s=s, scale=math.exp(m)).logpdf(x)
1466  self.assertAlmostEqual(log_prob, expected, places=3)
1467 
1468  self._check_log_prob(LogNormal(mean, std), ref_log_prob)
1469 
1470  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1471  def test_lognormal_sample(self):
1472  set_rng_seed(0) # see Note [Randomized statistical tests]
1473  for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
1474  self._check_sampler_sampler(LogNormal(mean, std),
1475  scipy.stats.lognorm(scale=math.exp(mean), s=std),
1476  'LogNormal(loc={}, scale={})'.format(mean, std))
1477 
1478  def test_logisticnormal(self):
1479  mean = torch.randn(5, 5).requires_grad_()
1480  std = torch.randn(5, 5).abs().requires_grad_()
1481  mean_1d = torch.randn(1).requires_grad_()
1482  std_1d = torch.randn(1).abs().requires_grad_()
1483  mean_delta = torch.tensor([1.0, 0.0])
1484  std_delta = torch.tensor([1e-5, 1e-5])
1485  self.assertEqual(LogisticNormal(mean, std).sample().size(), (5, 6))
1486  self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6))
1487  self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2))
1488  self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,))
1489  self.assertEqual(LogisticNormal(0.2, .6).sample((1,)).size(), (2,))
1490  self.assertEqual(LogisticNormal(-0.7, 50.0).sample((1,)).size(), (2,))
1491 
1492  # sample check for extreme value of mean, std
1493  set_rng_seed(1)
1494  self.assertEqual(LogisticNormal(mean_delta, std_delta).sample(),
1495  torch.tensor([math.exp(1) / (1. + 1. + math.exp(1)),
1496  1. / (1. + 1. + math.exp(1)),
1497  1. / (1. + 1. + math.exp(1))]),
1498  prec=1e-4)
1499 
1500  self._gradcheck_log_prob(LogisticNormal, (mean, std))
1501  self._gradcheck_log_prob(LogisticNormal, (mean, 1.0))
1502  self._gradcheck_log_prob(LogisticNormal, (0.0, std))
1503 
1504  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1505  def test_logisticnormal_logprob(self):
1506  mean = torch.randn(5, 7).requires_grad_()
1507  std = torch.randn(5, 7).abs().requires_grad_()
1508 
1509  # Smoke test for now
1510  # TODO: Once _check_log_prob works with multidimensional distributions,
1511  # add proper testing of the log probabilities.
1512  dist = LogisticNormal(mean, std)
1513  assert dist.log_prob(dist.sample()).detach().cpu().numpy().shape == (5,)
1514 
1515  def _get_logistic_normal_ref_sampler(self, base_dist):
1516 
1517  def _sampler(num_samples):
1518  x = base_dist.rvs(num_samples)
1519  offset = np.log((x.shape[-1] + 1) - np.ones_like(x).cumsum(-1))
1520  z = 1. / (1. + np.exp(offset - x))
1521  z_cumprod = np.cumprod(1 - z, axis=-1)
1522  y1 = np.pad(z, ((0, 0), (0, 1)), mode='constant', constant_values=1.)
1523  y2 = np.pad(z_cumprod, ((0, 0), (1, 0)), mode='constant', constant_values=1.)
1524  return y1 * y2
1525 
1526  return _sampler
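 # The reference sampler above maps a draw x from a (K-1)-dimensional
 # Gaussian onto the K-simplex via logistic stick-breaking:
 #   z_k = sigmoid(x_k - log(K - 1 - k)) for k = 0..K-2
 #   y_k = z_k * prod_{j<k} (1 - z_j),   y_{K-1} = prod_j (1 - z_j)
 # which mirrors (up to indexing conventions) the StickBreakingTransform
 # that LogisticNormal applies to its base Normal distribution.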
1527 
1528  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1529  def test_logisticnormal_sample(self):
1530  set_rng_seed(0) # see Note [Randomized statistical tests]
1531  means = map(np.asarray, [(-1.0, -1.0), (0.0, 0.0), (1.0, 1.0)])
1532  covs = map(np.diag, [(0.1, 0.1), (1.0, 1.0), (10.0, 10.0)])
1533  for mean, cov in product(means, covs):
1534  base_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
1535  ref_dist = scipy.stats.multivariate_normal(mean=mean, cov=cov)
1536  ref_dist.rvs = self._get_logistic_normal_ref_sampler(base_dist)
1537  mean_th = torch.tensor(mean)
1538  std_th = torch.tensor(np.sqrt(np.diag(cov)))
1539  self._check_sampler_sampler(
1540  LogisticNormal(mean_th, std_th), ref_dist,
1541  'LogisticNormal(loc={}, scale={})'.format(mean_th, std_th),
1542  multivariate=True)
1543 
1544  def test_normal(self):
1545  loc = torch.randn(5, 5, requires_grad=True)
1546  scale = torch.randn(5, 5).abs().requires_grad_()
1547  loc_1d = torch.randn(1, requires_grad=True)
1548  scale_1d = torch.randn(1).abs().requires_grad_()
1549  loc_delta = torch.tensor([1.0, 0.0])
1550  scale_delta = torch.tensor([1e-5, 1e-5])
1551  self.assertEqual(Normal(loc, scale).sample().size(), (5, 5))
1552  self.assertEqual(Normal(loc, scale).sample((7,)).size(), (7, 5, 5))
1553  self.assertEqual(Normal(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
1554  self.assertEqual(Normal(loc_1d, scale_1d).sample().size(), (1,))
1555  self.assertEqual(Normal(0.2, .6).sample((1,)).size(), (1,))
1556  self.assertEqual(Normal(-0.7, 50.0).sample((1,)).size(), (1,))
1557 
1558  # sample check for extreme value of mean, std
1559  set_rng_seed(1)
1560  self.assertEqual(Normal(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
1561  torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
1562  prec=1e-4)
1563 
1564  self._gradcheck_log_prob(Normal, (loc, scale))
1565  self._gradcheck_log_prob(Normal, (loc, 1.0))
1566  self._gradcheck_log_prob(Normal, (0.0, scale))
1567 
1568  state = torch.get_rng_state()
1569  eps = torch.normal(torch.zeros_like(loc), torch.ones_like(scale))
1570  torch.set_rng_state(state)
1571  z = Normal(loc, scale).rsample()
1572  z.backward(torch.ones_like(z))
1573  self.assertEqual(loc.grad, torch.ones_like(loc))
1574  self.assertEqual(scale.grad, eps)
1575  loc.grad.zero_()
1576  scale.grad.zero_()
1577  self.assertEqual(z.size(), (5, 5))
1578 
1579  def ref_log_prob(idx, x, log_prob):
1580  m = loc.view(-1)[idx]
1581  s = scale.view(-1)[idx]
1582  expected = (math.exp(-(x - m) ** 2 / (2 * s ** 2)) /
1583  math.sqrt(2 * math.pi * s ** 2))
1584  self.assertAlmostEqual(log_prob, math.log(expected), places=3)
1585 
1586  self._check_log_prob(Normal(loc, scale), ref_log_prob)
1587 
1588  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1589  def test_normal_sample(self):
1590  set_rng_seed(0) # see Note [Randomized statistical tests]
1591  for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
1592  self._check_sampler_sampler(Normal(loc, scale),
1593  scipy.stats.norm(loc=loc, scale=scale),
1594  'Normal(mean={}, std={})'.format(loc, scale))
1595 
1596  def test_lowrank_multivariate_normal_shape(self):
1597  mean = torch.randn(5, 3, requires_grad=True)
1598  mean_no_batch = torch.randn(3, requires_grad=True)
1599  mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)
1600 
1601  # construct PSD covariance
1602  cov_factor = torch.randn(3, 1, requires_grad=True)
1603  cov_diag = torch.randn(3).abs().requires_grad_()
1604 
1605  # construct batch of PSD covariances
1606  cov_factor_batched = torch.randn(6, 5, 3, 2, requires_grad=True)
1607  cov_diag_batched = torch.randn(6, 5, 3).abs().requires_grad_()
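 # cov_factor @ cov_factor.T + diag(cov_diag) is positive definite
 # whenever cov_diag > 0, which .abs() guarantees here. The low-rank
 # form also lets LowRankMultivariateNormal evaluate log_prob cheaply
 # via the Woodbury identity and the matrix determinant lemma.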
1608 
1609  # ensure that sample, batch, event shapes all handled correctly
1610  self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag)
1611  .sample().size(), (5, 3))
1612  self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
1613  .sample().size(), (3,))
1614  self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
1615  .sample().size(), (6, 5, 3))
1616  self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag)
1617  .sample((2,)).size(), (2, 5, 3))
1618  self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
1619  .sample((2,)).size(), (2, 3))
1620  self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
1621  .sample((2,)).size(), (2, 6, 5, 3))
1622  self.assertEqual(LowRankMultivariateNormal(mean, cov_factor, cov_diag)
1623  .sample((2, 7)).size(), (2, 7, 5, 3))
1624  self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor, cov_diag)
1625  .sample((2, 7)).size(), (2, 7, 3))
1626  self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor, cov_diag)
1627  .sample((2, 7)).size(), (2, 7, 6, 5, 3))
1628  self.assertEqual(LowRankMultivariateNormal(mean, cov_factor_batched, cov_diag_batched)
1629  .sample((2, 7)).size(), (2, 7, 6, 5, 3))
1630  self.assertEqual(LowRankMultivariateNormal(mean_no_batch, cov_factor_batched, cov_diag_batched)
1631  .sample((2, 7)).size(), (2, 7, 6, 5, 3))
1632  self.assertEqual(LowRankMultivariateNormal(mean_multi_batch, cov_factor_batched, cov_diag_batched)
1633  .sample((2, 7)).size(), (2, 7, 6, 5, 3))
1634 
1635  # check gradients
1636  self._gradcheck_log_prob(LowRankMultivariateNormal,
1637  (mean, cov_factor, cov_diag))
1638  self._gradcheck_log_prob(LowRankMultivariateNormal,
1639  (mean_multi_batch, cov_factor, cov_diag))
1640  self._gradcheck_log_prob(LowRankMultivariateNormal,
1641  (mean_multi_batch, cov_factor_batched, cov_diag_batched))
1642 
1643  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1644  def test_lowrank_multivariate_normal_log_prob(self):
1645  mean = torch.randn(3, requires_grad=True)
1646  cov_factor = torch.randn(3, 1, requires_grad=True)
1647  cov_diag = torch.randn(3).abs().requires_grad_()
1648  cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
1649 
1650  # check that logprob values match scipy logpdf,
1651  # and that covariance and scale_tril parameters are equivalent
1652  dist1 = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
1653  ref_dist = scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy())
1654 
1655  x = dist1.sample((10,))
1656  expected = ref_dist.logpdf(x.numpy())
1657 
1658  self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), places=3)
1659 
1660  # Double-check that batched versions behave the same as unbatched
1661  mean = torch.randn(5, 3, requires_grad=True)
1662  cov_factor = torch.randn(5, 3, 2, requires_grad=True)
1663  cov_diag = torch.randn(5, 3).abs().requires_grad_()
1664 
1665  dist_batched = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
1666  dist_unbatched = [LowRankMultivariateNormal(mean[i], cov_factor[i], cov_diag[i])
1667  for i in range(mean.size(0))]
1668 
1669  x = dist_batched.sample((10,))
1670  batched_prob = dist_batched.log_prob(x)
1671  unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()
1672 
1673  self.assertEqual(batched_prob.shape, unbatched_prob.shape)
1674  self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
1675 
1676  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1677  def test_lowrank_multivariate_normal_sample(self):
1678  set_rng_seed(0) # see Note [Randomized statistical tests]
1679  mean = torch.randn(5, requires_grad=True)
1680  cov_factor = torch.randn(5, 1, requires_grad=True)
1681  cov_diag = torch.randn(5).abs().requires_grad_()
1682  cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
1683 
1684  self._check_sampler_sampler(LowRankMultivariateNormal(mean, cov_factor, cov_diag),
1685  scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
1686  'LowRankMultivariateNormal(loc={}, cov_factor={}, cov_diag={})'
1687  .format(mean, cov_factor, cov_diag), multivariate=True)
1688 
1689  def test_lowrank_multivariate_normal_properties(self):
1690  loc = torch.randn(5)
1691  cov_factor = torch.randn(5, 2)
1692  cov_diag = torch.randn(5).abs()
1693  cov = cov_factor.matmul(cov_factor.t()) + cov_diag.diag()
1694  m1 = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
1695  m2 = MultivariateNormal(loc=loc, covariance_matrix=cov)
1696  self.assertEqual(m1.mean, m2.mean)
1697  self.assertEqual(m1.variance, m2.variance)
1698  self.assertEqual(m1.covariance_matrix, m2.covariance_matrix)
1699  self.assertEqual(m1.scale_tril, m2.scale_tril)
1700  self.assertEqual(m1.precision_matrix, m2.precision_matrix)
1701  self.assertEqual(m1.entropy(), m2.entropy())
1702 
1703  def test_lowrank_multivariate_normal_moments(self):
1704  set_rng_seed(0) # see Note [Randomized statistical tests]
1705  mean = torch.randn(5)
1706  cov_factor = torch.randn(5, 2)
1707  cov_diag = torch.randn(5).abs()
1708  d = LowRankMultivariateNormal(mean, cov_factor, cov_diag)
1709  samples = d.rsample((100000,))
1710  empirical_mean = samples.mean(0)
1711  self.assertEqual(d.mean, empirical_mean, prec=0.01)
1712  empirical_var = samples.var(0)
1713  self.assertEqual(d.variance, empirical_var, prec=0.02)
1714 
1715  def test_multivariate_normal_shape(self):
1716  mean = torch.randn(5, 3, requires_grad=True)
1717  mean_no_batch = torch.randn(3, requires_grad=True)
1718  mean_multi_batch = torch.randn(6, 5, 3, requires_grad=True)
1719 
1720  # construct PSD covariance
1721  tmp = torch.randn(3, 10)
1722  cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
1723  prec = cov.inverse().requires_grad_()
1724  scale_tril = torch.cholesky(cov, upper=False).requires_grad_()
1725 
1726  # construct batch of PSD covariances
1727  tmp = torch.randn(6, 5, 3, 10)
1728  cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
1729  prec_batched = [C.inverse() for C in cov_batched.view((-1, 3, 3))]
1730  prec_batched = torch.stack(prec_batched).view(cov_batched.shape)
1731  scale_tril_batched = [torch.cholesky(C, upper=False) for C in cov_batched.view((-1, 3, 3))]
1732  scale_tril_batched = torch.stack(scale_tril_batched).view(cov_batched.shape)
1733 
1734  # ensure that sample, batch, event shapes all handled correctly
1735  self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
1736  self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample().size(), (3,))
1737  self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample().size(), (6, 5, 3))
1738  self.assertEqual(MultivariateNormal(mean, cov).sample((2,)).size(), (2, 5, 3))
1739  self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2,)).size(), (2, 3))
1740  self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2,)).size(), (2, 6, 5, 3))
1741  self.assertEqual(MultivariateNormal(mean, cov).sample((2, 7)).size(), (2, 7, 5, 3))
1742  self.assertEqual(MultivariateNormal(mean_no_batch, cov).sample((2, 7)).size(), (2, 7, 3))
1743  self.assertEqual(MultivariateNormal(mean_multi_batch, cov).sample((2, 7)).size(), (2, 7, 6, 5, 3))
1744  self.assertEqual(MultivariateNormal(mean, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
1745  self.assertEqual(MultivariateNormal(mean_no_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
1746  self.assertEqual(MultivariateNormal(mean_multi_batch, cov_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
1747  self.assertEqual(MultivariateNormal(mean, precision_matrix=prec).sample((2, 7)).size(), (2, 7, 5, 3))
1748  self.assertEqual(MultivariateNormal(mean, precision_matrix=prec_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
1749  self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril).sample((2, 7)).size(), (2, 7, 5, 3))
1750  self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
1751 
1752  # check gradients
1753  self._gradcheck_log_prob(MultivariateNormal, (mean, cov))
1754  self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov))
1755  self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov_batched))
1756  self._gradcheck_log_prob(MultivariateNormal, (mean, None, prec))
1757  self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, prec_batched))
1758  self._gradcheck_log_prob(MultivariateNormal, (mean, None, None, scale_tril))
1759  self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, None, scale_tril_batched))
1760 
1761  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1762  def test_multivariate_normal_log_prob(self):
1763  mean = torch.randn(3, requires_grad=True)
1764  tmp = torch.randn(3, 10)
1765  cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
1766  prec = cov.inverse().requires_grad_()
1767  scale_tril = torch.cholesky(cov, upper=False).requires_grad_()
1768 
1769  # check that logprob values match scipy logpdf,
1770  # and that covariance and scale_tril parameters are equivalent
1771  dist1 = MultivariateNormal(mean, cov)
1772  dist2 = MultivariateNormal(mean, precision_matrix=prec)
1773  dist3 = MultivariateNormal(mean, scale_tril=scale_tril)
1774  ref_dist = scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy())
1775 
1776  x = dist1.sample((10,))
1777  expected = ref_dist.logpdf(x.numpy())
1778 
1779  self.assertAlmostEqual(0.0, np.mean((dist1.log_prob(x).detach().numpy() - expected)**2), places=3)
1780  self.assertAlmostEqual(0.0, np.mean((dist2.log_prob(x).detach().numpy() - expected)**2), places=3)
1781  self.assertAlmostEqual(0.0, np.mean((dist3.log_prob(x).detach().numpy() - expected)**2), places=3)
1782 
1783  # Double-check that batched versions behave the same as unbatched
1784  mean = torch.randn(5, 3, requires_grad=True)
1785  tmp = torch.randn(5, 3, 10)
1786  cov = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
1787 
1788  dist_batched = MultivariateNormal(mean, cov)
1789  dist_unbatched = [MultivariateNormal(mean[i], cov[i]) for i in range(mean.size(0))]
1790 
1791  x = dist_batched.sample((10,))
1792  batched_prob = dist_batched.log_prob(x)
1793  unbatched_prob = torch.stack([dist_unbatched[i].log_prob(x[:, i]) for i in range(5)]).t()
1794 
1795  self.assertEqual(batched_prob.shape, unbatched_prob.shape)
1796  self.assertAlmostEqual(0.0, (batched_prob - unbatched_prob).abs().max(), places=3)
1797 
1798  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1799  def test_multivariate_normal_sample(self):
1800  set_rng_seed(0) # see Note [Randomized statistical tests]
1801  mean = torch.randn(3, requires_grad=True)
1802  tmp = torch.randn(3, 10)
1803  cov = (torch.matmul(tmp, tmp.t()) / tmp.size(-1)).requires_grad_()
1804  prec = cov.inverse().requires_grad_()
1805  scale_tril = torch.cholesky(cov, upper=False).requires_grad_()
1806 
1807  self._check_sampler_sampler(MultivariateNormal(mean, cov),
1808  scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
1809  'MultivariateNormal(loc={}, cov={})'.format(mean, cov),
1810  multivariate=True)
1811  self._check_sampler_sampler(MultivariateNormal(mean, precision_matrix=prec),
1812  scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
1813  'MultivariateNormal(loc={}, prec={})'.format(mean, prec),
1814  multivariate=True)
1815  self._check_sampler_sampler(MultivariateNormal(mean, scale_tril=scale_tril),
1816  scipy.stats.multivariate_normal(mean.detach().numpy(), cov.detach().numpy()),
1817  'MultivariateNormal(loc={}, scale_tril={})'.format(mean, scale_tril),
1818  multivariate=True)
1819 
1820  def test_multivariate_normal_properties(self):
1821  loc = torch.randn(5)
1822  scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
1823  m = MultivariateNormal(loc=loc, scale_tril=scale_tril)
1824  self.assertEqual(m.covariance_matrix, m.scale_tril.mm(m.scale_tril.t()))
1825  self.assertEqual(m.covariance_matrix.mm(m.precision_matrix), torch.eye(m.event_shape[0]))
1826  self.assertEqual(m.scale_tril, torch.cholesky(m.covariance_matrix, upper=False))
1827 
1828  def test_multivariate_normal_moments(self):
1829  set_rng_seed(0) # see Note [Randomized statistical tests]
1830  mean = torch.randn(5)
1831  scale_tril = transform_to(constraints.lower_cholesky)(torch.randn(5, 5))
1832  d = MultivariateNormal(mean, scale_tril=scale_tril)
1833  samples = d.rsample((100000,))
1834  empirical_mean = samples.mean(0)
1835  self.assertEqual(d.mean, empirical_mean, prec=0.01)
1836  empirical_var = samples.var(0)
1837  self.assertEqual(d.variance, empirical_var, prec=0.05)
1838 
1839  def test_exponential(self):
1840  rate = torch.randn(5, 5).abs().requires_grad_()
1841  rate_1d = torch.randn(1).abs().requires_grad_()
1842  self.assertEqual(Exponential(rate).sample().size(), (5, 5))
1843  self.assertEqual(Exponential(rate).sample((7,)).size(), (7, 5, 5))
1844  self.assertEqual(Exponential(rate_1d).sample((1,)).size(), (1, 1))
1845  self.assertEqual(Exponential(rate_1d).sample().size(), (1,))
1846  self.assertEqual(Exponential(0.2).sample((1,)).size(), (1,))
1847  self.assertEqual(Exponential(50.0).sample((1,)).size(), (1,))
1848 
1849  self._gradcheck_log_prob(Exponential, (rate,))
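 # Exponential.rsample() is essentially eps / rate with eps a standard
 # exponential draw, so d(sample)/d(rate) = -eps / rate**2, which the
 # state-replay check below verifies exactly.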
1850  state = torch.get_rng_state()
1851  eps = rate.new(rate.size()).exponential_()
1852  torch.set_rng_state(state)
1853  z = Exponential(rate).rsample()
1854  z.backward(torch.ones_like(z))
1855  self.assertEqual(rate.grad, -eps / rate**2)
1856  rate.grad.zero_()
1857  self.assertEqual(z.size(), (5, 5))
1858 
1859  def ref_log_prob(idx, x, log_prob):
1860  m = rate.view(-1)[idx]
1861  expected = math.log(m) - m * x
1862  self.assertAlmostEqual(log_prob, expected, places=3)
1863 
1864  self._check_log_prob(Exponential(rate), ref_log_prob)
1865 
1866  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1867  def test_exponential_sample(self):
1868  set_rng_seed(1) # see Note [Randomized statistical tests]
1869  for rate in [1e-5, 1.0, 10.]:
1870  self._check_sampler_sampler(Exponential(rate),
1871  scipy.stats.expon(scale=1. / rate),
1872  'Exponential(rate={})'.format(rate))
1873 
1874  def test_laplace(self):
1875  loc = torch.randn(5, 5, requires_grad=True)
1876  scale = torch.randn(5, 5).abs().requires_grad_()
1877  loc_1d = torch.randn(1, requires_grad=True)
1878  scale_1d = torch.randn(1).abs().requires_grad_()
1879  loc_delta = torch.tensor([1.0, 0.0])
1880  scale_delta = torch.tensor([1e-5, 1e-5])
1881  self.assertEqual(Laplace(loc, scale).sample().size(), (5, 5))
1882  self.assertEqual(Laplace(loc, scale).sample((7,)).size(), (7, 5, 5))
1883  self.assertEqual(Laplace(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
1884  self.assertEqual(Laplace(loc_1d, scale_1d).sample().size(), (1,))
1885  self.assertEqual(Laplace(0.2, .6).sample((1,)).size(), (1,))
1886  self.assertEqual(Laplace(-0.7, 50.0).sample((1,)).size(), (1,))
1887 
1888  # sample check for extreme value of mean, std
1889  set_rng_seed(0)
1890  self.assertEqual(Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
1891  torch.tensor([[[1.0, 0.0], [1.0, 0.0]]]),
1892  prec=1e-4)
1893 
1894  self._gradcheck_log_prob(Laplace, (loc, scale))
1895  self._gradcheck_log_prob(Laplace, (loc, 1.0))
1896  self._gradcheck_log_prob(Laplace, (0.0, scale))
1897 
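 # Laplace.rsample() is (schematically) loc - scale * sign(u) * log1p(-2|u|)
 # for u ~ Uniform(-0.5, 0.5), so d(sample)/d(loc) = 1 and
 # d(sample)/d(scale) = -sign(u) * log1p(-2|u|), matching the checks below.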
1898  state = torch.get_rng_state()
1899  eps = torch.ones_like(loc).uniform_(-.5, .5)
1900  torch.set_rng_state(state)
1901  z = Laplace(loc, scale).rsample()
1902  z.backward(torch.ones_like(z))
1903  self.assertEqual(loc.grad, torch.ones_like(loc))
1904  self.assertEqual(scale.grad, -eps.sign() * torch.log1p(-2 * eps.abs()))
1905  loc.grad.zero_()
1906  scale.grad.zero_()
1907  self.assertEqual(z.size(), (5, 5))
1908 
1909  def ref_log_prob(idx, x, log_prob):
1910  m = loc.view(-1)[idx]
1911  s = scale.view(-1)[idx]
1912  expected = (-math.log(2 * s) - abs(x - m) / s)
1913  self.assertAlmostEqual(log_prob, expected, places=3)
1914 
1915  self._check_log_prob(Laplace(loc, scale), ref_log_prob)
1916 
1917  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1918  def test_laplace_sample(self):
1919  set_rng_seed(1) # see Note [Randomized statistical tests]
1920  for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
1921  self._check_sampler_sampler(Laplace(loc, scale),
1922  scipy.stats.laplace(loc=loc, scale=scale),
1923  'Laplace(loc={}, scale={})'.format(loc, scale))
1924 
1925  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1926  def test_gamma_shape(self):
1927  alpha = torch.randn(2, 3).exp().requires_grad_()
1928  beta = torch.randn(2, 3).exp().requires_grad_()
1929  alpha_1d = torch.randn(1).exp().requires_grad_()
1930  beta_1d = torch.randn(1).exp().requires_grad_()
1931  self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
1932  self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
1933  self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
1934  self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
1935  self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
1936  self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))
1937 
1938  def ref_log_prob(idx, x, log_prob):
1939  a = alpha.view(-1)[idx].detach()
1940  b = beta.view(-1)[idx].detach()
1941  expected = scipy.stats.gamma.logpdf(x, a, scale=1 / b)
1942  self.assertAlmostEqual(log_prob, expected, places=3)
1943 
1944  self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
1945 
1946  @unittest.skipIf(not TEST_CUDA, "CUDA not found")
1947  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1948  def test_gamma_gpu_shape(self):
1949  alpha = torch.randn(2, 3).cuda().exp().requires_grad_()
1950  beta = torch.randn(2, 3).cuda().exp().requires_grad_()
1951  alpha_1d = torch.randn(1).cuda().exp().requires_grad_()
1952  beta_1d = torch.randn(1).cuda().exp().requires_grad_()
1953  self.assertEqual(Gamma(alpha, beta).sample().size(), (2, 3))
1954  self.assertEqual(Gamma(alpha, beta).sample((5,)).size(), (5, 2, 3))
1955  self.assertEqual(Gamma(alpha_1d, beta_1d).sample((1,)).size(), (1, 1))
1956  self.assertEqual(Gamma(alpha_1d, beta_1d).sample().size(), (1,))
1957  self.assertEqual(Gamma(0.5, 0.5).sample().size(), ())
1958  self.assertEqual(Gamma(0.5, 0.5).sample((1,)).size(), (1,))
1959 
1960  def ref_log_prob(idx, x, log_prob):
1961  a = alpha.view(-1)[idx].detach().cpu()
1962  b = beta.view(-1)[idx].detach().cpu()
1963  expected = scipy.stats.gamma.logpdf(x.cpu(), a, scale=1 / b)
1964  self.assertAlmostEqual(log_prob, expected, places=3)
1965 
1966  self._check_log_prob(Gamma(alpha, beta), ref_log_prob)
1967 
1968  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1969  def test_gamma_sample(self):
1970  set_rng_seed(0) # see Note [Randomized statistical tests]
1971  for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
1972  self._check_sampler_sampler(Gamma(alpha, beta),
1973  scipy.stats.gamma(alpha, scale=1.0 / beta),
1974  'Gamma(concentration={}, rate={})'.format(alpha, beta))
1975 
1976  @unittest.skipIf(not TEST_CUDA, "CUDA not found")
1977  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
1978  @skipIfRocm
1979  def test_gamma_gpu_sample(self):
1980  set_rng_seed(0)
1981  for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
1982  a, b = torch.tensor([alpha]).cuda(), torch.tensor([beta]).cuda()
1983  self._check_sampler_sampler(Gamma(a, b),
1984  scipy.stats.gamma(alpha, scale=1.0 / beta),
1985  'Gamma(alpha={}, beta={})'.format(alpha, beta),
1986  failure_rate=1e-4)
1987 
1988  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
1989  def test_pareto(self):
1990  scale = torch.randn(2, 3).abs().requires_grad_()
1991  alpha = torch.randn(2, 3).abs().requires_grad_()
1992  scale_1d = torch.randn(1).abs().requires_grad_()
1993  alpha_1d = torch.randn(1).abs().requires_grad_()
1994  self.assertEqual(Pareto(scale_1d, 0.5).mean, inf, allow_inf=True)
1995  self.assertEqual(Pareto(scale_1d, 0.5).variance, inf, allow_inf=True)
1996  self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
1997  self.assertEqual(Pareto(scale, alpha).sample((5,)).size(), (5, 2, 3))
1998  self.assertEqual(Pareto(scale_1d, alpha_1d).sample((1,)).size(), (1, 1))
1999  self.assertEqual(Pareto(scale_1d, alpha_1d).sample().size(), (1,))
2000  self.assertEqual(Pareto(1.0, 1.0).sample().size(), ())
2001  self.assertEqual(Pareto(1.0, 1.0).sample((1,)).size(), (1,))
2002 
2003  def ref_log_prob(idx, x, log_prob):
2004  s = scale.view(-1)[idx].detach()
2005  a = alpha.view(-1)[idx].detach()
2006  expected = scipy.stats.pareto.logpdf(x, a, scale=s)
2007  self.assertAlmostEqual(log_prob, expected, places=3)
2008 
2009  self._check_log_prob(Pareto(scale, alpha), ref_log_prob)
2010 
2011  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2012  def test_pareto_sample(self):
2013  set_rng_seed(1) # see Note [Randomized statistical tests]
2014  for scale, alpha in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
2015  self._check_sampler_sampler(Pareto(scale, alpha),
2016  scipy.stats.pareto(alpha, scale=scale),
2017  'Pareto(scale={}, alpha={})'.format(scale, alpha))
2018 
2019  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2020  def test_gumbel(self):
2021  loc = torch.randn(2, 3, requires_grad=True)
2022  scale = torch.randn(2, 3).abs().requires_grad_()
2023  loc_1d = torch.randn(1, requires_grad=True)
2024  scale_1d = torch.randn(1).abs().requires_grad_()
2025  self.assertEqual(Gumbel(loc, scale).sample().size(), (2, 3))
2026  self.assertEqual(Gumbel(loc, scale).sample((5,)).size(), (5, 2, 3))
2027  self.assertEqual(Gumbel(loc_1d, scale_1d).sample().size(), (1,))
2028  self.assertEqual(Gumbel(loc_1d, scale_1d).sample((1,)).size(), (1, 1))
2029  self.assertEqual(Gumbel(1.0, 1.0).sample().size(), ())
2030  self.assertEqual(Gumbel(1.0, 1.0).sample((1,)).size(), (1,))
2031 
2032  def ref_log_prob(idx, x, log_prob):
2033  l = loc.view(-1)[idx].detach()
2034  s = scale.view(-1)[idx].detach()
2035  expected = scipy.stats.gumbel_r.logpdf(x, loc=l, scale=s)
2036  self.assertAlmostEqual(log_prob, expected, places=3)
2037 
2038  self._check_log_prob(Gumbel(loc, scale), ref_log_prob)
2039 
2040  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2041  def test_gumbel_sample(self):
2042  set_rng_seed(1) # see Note [Randomized statistical tests]
2043  for loc, scale in product([-5.0, -1.0, -0.1, 0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
2044  self._check_sampler_sampler(Gumbel(loc, scale),
2045  scipy.stats.gumbel_r(loc=loc, scale=scale),
2046  'Gumbel(loc={}, scale={})'.format(loc, scale))
2047 
2048  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2049  def test_fishersnedecor(self):
2050  df1 = torch.randn(2, 3).abs().requires_grad_()
2051  df2 = torch.randn(2, 3).abs().requires_grad_()
2052  df1_1d = torch.randn(1).abs()
2053  df2_1d = torch.randn(1).abs()
2054  self.assertTrue(is_all_nan(FisherSnedecor(1, 2).mean))
2055  self.assertTrue(is_all_nan(FisherSnedecor(1, 4).variance))
2056  self.assertEqual(FisherSnedecor(df1, df2).sample().size(), (2, 3))
2057  self.assertEqual(FisherSnedecor(df1, df2).sample((5,)).size(), (5, 2, 3))
2058  self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample().size(), (1,))
2059  self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample((1,)).size(), (1, 1))
2060  self.assertEqual(FisherSnedecor(1.0, 1.0).sample().size(), ())
2061  self.assertEqual(FisherSnedecor(1.0, 1.0).sample((1,)).size(), (1,))
2062 
2063  def ref_log_prob(idx, x, log_prob):
2064  f1 = df1.view(-1)[idx].detach()
2065  f2 = df2.view(-1)[idx].detach()
2066  expected = scipy.stats.f.logpdf(x, f1, f2)
2067  self.assertAlmostEqual(log_prob, expected, places=3)
2068 
2069  self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)
2070 
2071  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2072  def test_fishersnedecor_sample(self):
2073  set_rng_seed(1) # see Note [Randomized statistical tests]
2074  for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
2075  self._check_sampler_sampler(FisherSnedecor(df1, df2),
2076  scipy.stats.f(df1, df2),
2077  'FisherSnedecor(df1={}, df2={})'.format(df1, df2))
2078 
2079  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2080  def test_chi2_shape(self):
2081  df = torch.randn(2, 3).exp().requires_grad_()
2082  df_1d = torch.randn(1).exp().requires_grad_()
2083  self.assertEqual(Chi2(df).sample().size(), (2, 3))
2084  self.assertEqual(Chi2(df).sample((5,)).size(), (5, 2, 3))
2085  self.assertEqual(Chi2(df_1d).sample((1,)).size(), (1, 1))
2086  self.assertEqual(Chi2(df_1d).sample().size(), (1,))
2087  self.assertEqual(Chi2(torch.tensor(0.5, requires_grad=True)).sample().size(), ())
2088  self.assertEqual(Chi2(0.5).sample().size(), ())
2089  self.assertEqual(Chi2(0.5).sample((1,)).size(), (1,))
2090 
2091  def ref_log_prob(idx, x, log_prob):
2092  d = df.view(-1)[idx].detach()
2093  expected = scipy.stats.chi2.logpdf(x, d)
2094  self.assertAlmostEqual(log_prob, expected, places=3)
2095 
2096  self._check_log_prob(Chi2(df), ref_log_prob)
2097 
2098  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2099  def test_chi2_sample(self):
2100  set_rng_seed(0) # see Note [Randomized statistical tests]
2101  for df in [0.1, 1.0, 5.0]:
2102  self._check_sampler_sampler(Chi2(df),
2103  scipy.stats.chi2(df),
2104  'Chi2(df={})'.format(df))
2105 
2106  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2107  def test_studentT(self):
2108  df = torch.randn(2, 3).exp().requires_grad_()
2109  df_1d = torch.randn(1).exp().requires_grad_()
2110  self.assertTrue(is_all_nan(StudentT(1).mean))
2111  self.assertTrue(is_all_nan(StudentT(1).variance))
2112  self.assertEqual(StudentT(2).variance, inf, allow_inf=True)
2113  self.assertEqual(StudentT(df).sample().size(), (2, 3))
2114  self.assertEqual(StudentT(df).sample((5,)).size(), (5, 2, 3))
2115  self.assertEqual(StudentT(df_1d).sample((1,)).size(), (1, 1))
2116  self.assertEqual(StudentT(df_1d).sample().size(), (1,))
2117  self.assertEqual(StudentT(torch.tensor(0.5, requires_grad=True)).sample().size(), ())
2118  self.assertEqual(StudentT(0.5).sample().size(), ())
2119  self.assertEqual(StudentT(0.5).sample((1,)).size(), (1,))
2120 
2121  def ref_log_prob(idx, x, log_prob):
2122  d = df.view(-1)[idx].detach()
2123  expected = scipy.stats.t.logpdf(x, d)
2124  self.assertAlmostEqual(log_prob, expected, places=3)
2125 
2126  self._check_log_prob(StudentT(df), ref_log_prob)
2127 
2128  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2129  def test_studentT_sample(self):
2130  set_rng_seed(11) # see Note [Randomized statistical tests]
2131  for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
2132  self._check_sampler_sampler(StudentT(df=df, loc=loc, scale=scale),
2133  scipy.stats.t(df=df, loc=loc, scale=scale),
2134  'StudentT(df={}, loc={}, scale={})'.format(df, loc, scale))
2135 
2136  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
2137  def test_studentT_log_prob(self):
2138  set_rng_seed(0) # see Note [Randomized statistical tests]
2139  num_samples = 10
2140  for df, loc, scale in product([0.1, 1.0, 5.0, 10.0], [-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
2141  dist = StudentT(df=df, loc=loc, scale=scale)
2142  x = dist.sample((num_samples,))
2143  actual_log_prob = dist.log_prob(x)
2144  for i in range(num_samples):
2145  expected_log_prob = scipy.stats.t.logpdf(x[i], df=df, loc=loc, scale=scale)
2146  self.assertAlmostEqual(float(actual_log_prob[i]), float(expected_log_prob), places=3)
2147 
2148  def test_dirichlet_shape(self):
2149  alpha = torch.randn(2, 3).exp().requires_grad_()
2150  alpha_1d = torch.randn(4).exp().requires_grad_()
2151  self.assertEqual(Dirichlet(alpha).sample().size(), (2, 3))
2152  self.assertEqual(Dirichlet(alpha).sample((5,)).size(), (5, 2, 3))
2153  self.assertEqual(Dirichlet(alpha_1d).sample().size(), (4,))
2154  self.assertEqual(Dirichlet(alpha_1d).sample((1,)).size(), (1, 4))
2155 
2156  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2157  def test_dirichlet_log_prob(self):
2158  num_samples = 10
2159  alpha = torch.exp(torch.randn(5))
2160  dist = Dirichlet(alpha)
2161  x = dist.sample((num_samples,))
2162  actual_log_prob = dist.log_prob(x)
2163  for i in range(num_samples):
2164  expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
2165  self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)
2166 
2167  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2168  def test_dirichlet_sample(self):
2169  set_rng_seed(0) # see Note [Randomized statistical tests]
2170  alpha = torch.exp(torch.randn(3))
2171  self._check_sampler_sampler(Dirichlet(alpha),
2172  scipy.stats.dirichlet(alpha.numpy()),
2173  'Dirichlet(alpha={})'.format(list(alpha)),
2174  multivariate=True)
2175 
2176  def test_beta_shape(self):
2177  con1 = torch.randn(2, 3).exp().requires_grad_()
2178  con0 = torch.randn(2, 3).exp().requires_grad_()
2179  con1_1d = torch.randn(4).exp().requires_grad_()
2180  con0_1d = torch.randn(4).exp().requires_grad_()
2181  self.assertEqual(Beta(con1, con0).sample().size(), (2, 3))
2182  self.assertEqual(Beta(con1, con0).sample((5,)).size(), (5, 2, 3))
2183  self.assertEqual(Beta(con1_1d, con0_1d).sample().size(), (4,))
2184  self.assertEqual(Beta(con1_1d, con0_1d).sample((1,)).size(), (1, 4))
2185  self.assertEqual(Beta(0.1, 0.3).sample().size(), ())
2186  self.assertEqual(Beta(0.1, 0.3).sample((5,)).size(), (5,))
2187 
2188  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2189  def test_beta_log_prob(self):
2190  for _ in range(100):
2191  con1 = np.exp(np.random.normal())
2192  con0 = np.exp(np.random.normal())
2193  dist = Beta(con1, con0)
2194  x = dist.sample()
2195  actual_log_prob = dist.log_prob(x).sum()
2196  expected_log_prob = scipy.stats.beta.logpdf(x, con1, con0)
2197  self.assertAlmostEqual(float(actual_log_prob), float(expected_log_prob), places=3, allow_inf=True)
2198 
2199  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2200  def test_beta_sample(self):
2201  set_rng_seed(1) # see Note [Randomized statistical tests]
2202  for con1, con0 in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
2203  self._check_sampler_sampler(Beta(con1, con0),
2204  scipy.stats.beta(con1, con0),
2205  'Beta(alpha={}, beta={})'.format(con1, con0))
2206  # Check that small alphas do not cause NANs.
2207  for Tensor in [torch.FloatTensor, torch.DoubleTensor]:
2208  x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
2209  self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))
2210 
2211  def test_beta_underflow(self):
2212  # For low values of (alpha, beta), the gamma samples can underflow
2213  # with float32 and result in a spurious mode at 0.5. To prevent this,
2214  # torch._sample_dirichlet works with double precision for intermediate
2215  # calculations.
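 # Schematically, Beta(a, b) is sampled as x / (x + y) with
 # x ~ Gamma(a, 1) and y ~ Gamma(b, 1); for a = b = 1e-2 most of the
 # Gamma mass sits at values small enough to underflow float32, and a
 # degenerate 0 / 0 ratio would corrupt the samples (hence the double
 # precision intermediates mentioned above).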
2216  set_rng_seed(1)
2217  num_samples = 50000
2218  for dtype in [torch.float, torch.double]:
2219  conc = torch.tensor(1e-2, dtype=dtype)
2220  beta_samples = Beta(conc, conc).sample([num_samples])
2221  self.assertEqual((beta_samples == 0).sum(), 0)
2222  self.assertEqual((beta_samples == 1).sum(), 0)
2223  # assert support is concentrated around 0 and 1
2224  frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
2225  frac_ones = float((beta_samples > 0.9).sum()) / num_samples
2226  self.assertEqual(frac_zeros, 0.5, 0.05)
2227  self.assertEqual(frac_ones, 0.5, 0.05)
2228 
2229  @unittest.skipIf(not TEST_CUDA, "CUDA not found")
2230  def test_beta_underflow_gpu(self):
2231  set_rng_seed(1)
2232  num_samples = 50000
2233  conc = torch.tensor(1e-2, dtype=torch.float64).cuda()
2234  beta_samples = Beta(conc, conc).sample([num_samples])
2235  self.assertEqual((beta_samples == 0).sum(), 0)
2236  self.assertEqual((beta_samples == 1).sum(), 0)
2237  # assert support is concentrated around 0 and 1
2238  frac_zeros = float((beta_samples < 0.1).sum()) / num_samples
2239  frac_ones = float((beta_samples > 0.9).sum()) / num_samples
2240  # TODO: increase precision once imbalance on GPU is fixed.
2241  self.assertEqual(frac_zeros, 0.5, 0.12)
2242  self.assertEqual(frac_ones, 0.5, 0.12)
2243 
2244  def test_independent_shape(self):
2245  for Dist, params in EXAMPLES:
2246  for param in params:
2247  base_dist = Dist(**param)
2248  x = base_dist.sample()
2249  base_log_prob_shape = base_dist.log_prob(x).shape
2250  for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
2251  indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
2252  indep_log_prob_shape = base_log_prob_shape[:len(base_log_prob_shape) - reinterpreted_batch_ndims]
2253  self.assertEqual(indep_dist.log_prob(x).shape, indep_log_prob_shape)
2254  self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
2255  self.assertEqual(indep_dist.has_rsample, base_dist.has_rsample)
2256  if indep_dist.has_rsample:
2257  self.assertEqual(indep_dist.sample().shape, base_dist.sample().shape)
2258  try:
2259  self.assertEqual(indep_dist.enumerate_support().shape, base_dist.enumerate_support().shape)
2260  self.assertEqual(indep_dist.mean.shape, base_dist.mean.shape)
2261  except NotImplementedError:
2262  pass
2263  try:
2264  self.assertEqual(indep_dist.variance.shape, base_dist.variance.shape)
2265  except NotImplementedError:
2266  pass
2267  try:
2268  self.assertEqual(indep_dist.entropy().shape, indep_log_prob_shape)
2269  except NotImplementedError:
2270  pass
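 # For example (hypothetical): wrapping Normal(zeros(3, 4), ones(3, 4))
 # in Independent(..., 1) turns batch_shape (3, 4) into batch_shape (3,)
 # with event_shape (4,), so log_prob sums over the reinterpreted dim.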
2271 
2272  def test_independent_expand(self):
2273  for Dist, params in EXAMPLES:
2274  for param in params:
2275  base_dist = Dist(**param)
2276  for reinterpreted_batch_ndims in range(len(base_dist.batch_shape) + 1):
2277  for s in [torch.Size(), torch.Size((2,)), torch.Size((2, 3))]:
2278  indep_dist = Independent(base_dist, reinterpreted_batch_ndims)
2279  expanded_shape = s + indep_dist.batch_shape
2280  expanded = indep_dist.expand(expanded_shape)
2281  expanded_sample = expanded.sample()
2282  expected_shape = expanded_shape + indep_dist.event_shape
2283  self.assertEqual(expanded_sample.shape, expected_shape)
2284  self.assertEqual(expanded.log_prob(expanded_sample),
2285  indep_dist.log_prob(expanded_sample))
2286  self.assertEqual(expanded.event_shape, indep_dist.event_shape)
2287  self.assertEqual(expanded.batch_shape, expanded_shape)
2288 
2289  def test_cdf_icdf_inverse(self):
2290  # Tests the invertibility property on the distributions
2291  for Dist, params in EXAMPLES:
2292  for i, param in enumerate(params):
2293  dist = Dist(**param)
2294  samples = dist.sample(sample_shape=(20,))
2295  try:
2296  cdf = dist.cdf(samples)
2297  actual = dist.icdf(cdf)
2298  except NotImplementedError:
2299  continue
2300  rel_error = torch.abs(actual - samples) / (1e-10 + torch.abs(samples))
2301  self.assertLess(rel_error.max(), 1e-4, msg='\n'.join([
2302  '{} example {}/{}, icdf(cdf(x)) != x'.format(Dist.__name__, i + 1, len(params)),
2303  'x = {}'.format(samples),
2304  'cdf(x) = {}'.format(cdf),
2305  'icdf(cdf(x)) = {}'.format(actual),
2306  ]))
2307 
2308  def test_cdf_log_prob(self):
2309  # Tests if the differentiation of the CDF gives the PDF at a given value
2310  for Dist, params in EXAMPLES:
2311  for i, param in enumerate(params):
2312  dist = Dist(**param)
2313  samples = dist.sample()
2314  if samples.dtype.is_floating_point:
2315  samples.requires_grad_()
2316  try:
2317  cdfs = dist.cdf(samples)
2318  pdfs = dist.log_prob(samples).exp()
2319  except NotImplementedError:
2320  continue
2321  cdfs_derivative = grad(cdfs.sum(), [samples])[0] # this should not be wrapped in torch.abs()
2322  self.assertEqual(cdfs_derivative, pdfs, message='\n'.join([
2323  '{} example {}/{}, d(cdf)/dx != pdf(x)'.format(Dist.__name__, i + 1, len(params)),
2324  'x = {}'.format(samples),
2325  'cdf = {}'.format(cdfs),
2326  'pdf = {}'.format(pdfs),
2327  'grad(cdf) = {}'.format(cdfs_derivative),
2328  ]))
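 # The identity checked above is d/dx CDF(x) = pdf(x); autograd of
 # cdfs.sum() w.r.t. samples yields the elementwise derivative because
 # each cdf entry depends only on its own sample entry.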
2329 
2330  def test_valid_parameter_broadcasting(self):
2331  # Test correct broadcasting of parameter sizes for distributions that have multiple
2332  # parameters.
2333  # example type (distribution instance, expected sample shape)
2334  valid_examples = [
2335  (Normal(loc=torch.tensor([0., 0.]), scale=1),
2336  (2,)),
2337  (Normal(loc=0, scale=torch.tensor([1., 1.])),
2338  (2,)),
2339  (Normal(loc=torch.tensor([0., 0.]), scale=torch.tensor([1.])),
2340  (2,)),
2341  (Normal(loc=torch.tensor([0., 0.]), scale=torch.tensor([[1.], [1.]])),
2342  (2, 2)),
2343  (Normal(loc=torch.tensor([0., 0.]), scale=torch.tensor([[1.]])),
2344  (1, 2)),
2345  (Normal(loc=torch.tensor([0.]), scale=torch.tensor([[1.]])),
2346  (1, 1)),
2347  (FisherSnedecor(df1=torch.tensor([1., 1.]), df2=1),
2348  (2,)),
2349  (FisherSnedecor(df1=1, df2=torch.tensor([1., 1.])),
2350  (2,)),
2351  (FisherSnedecor(df1=torch.tensor([1., 1.]), df2=torch.tensor([1.])),
2352  (2,)),
2353  (FisherSnedecor(df1=torch.tensor([1., 1.]), df2=torch.tensor([[1.], [1.]])),
2354  (2, 2)),
2355  (FisherSnedecor(df1=torch.tensor([1., 1.]), df2=torch.tensor([[1.]])),
2356  (1, 2)),
2357  (FisherSnedecor(df1=torch.tensor([1.]), df2=torch.tensor([[1.]])),
2358  (1, 1)),
2359  (Gamma(concentration=torch.tensor([1., 1.]), rate=1),
2360  (2,)),
2361  (Gamma(concentration=1, rate=torch.tensor([1., 1.])),
2362  (2,)),
2363  (Gamma(concentration=torch.tensor([1., 1.]), rate=torch.tensor([[1.], [1.], [1.]])),
2364  (3, 2)),
2365  (Gamma(concentration=torch.tensor([1., 1.]), rate=torch.tensor([[1.], [1.]])),
2366  (2, 2)),
2367  (Gamma(concentration=torch.tensor([1., 1.]), rate=torch.tensor([[1.]])),
2368  (1, 2)),
2369  (Gamma(concentration=torch.tensor([1.]), rate=torch.tensor([[1.]])),
2370  (1, 1)),
2371  (Gumbel(loc=torch.tensor([0., 0.]), scale=1),
2372  (2,)),
2373  (Gumbel(loc=0, scale=torch.tensor([1., 1.])),
2374  (2,)),
2375  (Gumbel(loc=torch.tensor([0., 0.]), scale=torch.tensor([1.])),
2376  (2,)),
2377  (Gumbel(loc=torch.tensor([0., 0.]), scale=torch.tensor([[1.], [1.]])),
2378  (2, 2)),
2379  (Gumbel(loc=torch.tensor([0., 0.]), scale=torch.tensor([[1.]])),
2380  (1, 2)),
2381  (Gumbel(loc=torch.tensor([0.]), scale=torch.tensor([[1.]])),
2382  (1, 1)),
2383  (Laplace(loc=torch.tensor([0., 0.]), scale=1),
2384  (2,)),
2385  (Laplace(loc=0, scale=torch.tensor([1., 1.])),
2386  (2,)),
2387  (Laplace(loc=torch.tensor([0., 0.]), scale=torch.tensor([1.])),
2388  (2,)),
2389  (Laplace(loc=torch.tensor([0., 0.]), scale=torch.tensor([[1.], [1.]])),
2390  (2, 2)),
2391  (Laplace(loc=torch.tensor([0., 0.]), scale=torch.tensor([[1.]])),
2392  (1, 2)),
2393  (Laplace(loc=torch.tensor([0.]), scale=torch.tensor([[1.]])),
2394  (1, 1)),
2395  (Pareto(scale=torch.tensor([1., 1.]), alpha=1),
2396  (2,)),
2397  (Pareto(scale=1, alpha=torch.tensor([1., 1.])),
2398  (2,)),
2399  (Pareto(scale=torch.tensor([1., 1.]), alpha=torch.tensor([1.])),
2400  (2,)),
2401  (Pareto(scale=torch.tensor([1., 1.]), alpha=torch.tensor([[1.], [1.]])),
2402  (2, 2)),
2403  (Pareto(scale=torch.tensor([1., 1.]), alpha=torch.tensor([[1.]])),
2404  (1, 2)),
2405  (Pareto(scale=torch.tensor([1.]), alpha=torch.tensor([[1.]])),
2406  (1, 1)),
2407  (StudentT(df=torch.tensor([1., 1.]), loc=1),
2408  (2,)),
2409  (StudentT(df=1, scale=torch.tensor([1., 1.])),
2410  (2,)),
2411  (StudentT(df=torch.tensor([1., 1.]), loc=torch.tensor([1.])),
2412  (2,)),
2413  (StudentT(df=torch.tensor([1., 1.]), scale=torch.tensor([[1.], [1.]])),
2414  (2, 2)),
2415  (StudentT(df=torch.tensor([1., 1.]), loc=torch.tensor([[1.]])),
2416  (1, 2)),
2417  (StudentT(df=torch.tensor([1.]), scale=torch.tensor([[1.]])),
2418  (1, 1)),
2419  (StudentT(df=1., loc=torch.zeros(5, 1), scale=torch.ones(3)),
2420  (5, 3)),
2421  ]
2422 
2423  for dist, expected_size in valid_examples:
2424  actual_size = dist.sample().size()
2425  self.assertEqual(actual_size, expected_size,
2426  '{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))
2427 
2428  sample_shape = torch.Size((2,))
2429  expected_size = sample_shape + expected_size
2430  actual_size = dist.sample(sample_shape).size()
2431  self.assertEqual(actual_size, expected_size,
2432  '{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))
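 # Parameter shapes combine by standard broadcasting; e.g. in the cases
 # above, loc of shape (2,) with scale of shape (2, 1) broadcasts to a
 # batch_shape of (2, 2), which is also the shape of a single sample().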
2433 
2434  def test_invalid_parameter_broadcasting(self):
2435  # invalid broadcasting cases; should throw error
2436  # example type (distribution class, distribution params)
2437  invalid_examples = [
2438  (Normal, {
2439  'loc': torch.tensor([[0, 0]]),
2440  'scale': torch.tensor([1, 1, 1, 1])
2441  }),
2442  (Normal, {
2443  'loc': torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
2444  'scale': torch.tensor([1, 1])
2445  }),
2446  (FisherSnedecor, {
2447  'df1': torch.tensor([1, 1]),
2448  'df2': torch.tensor([1, 1, 1]),
2449  }),
2450  (Gumbel, {
2451  'loc': torch.tensor([[0, 0]]),
2452  'scale': torch.tensor([1, 1, 1, 1])
2453  }),
2454  (Gumbel, {
2455  'loc': torch.tensor([[[0, 0, 0], [0, 0, 0]]]),
2456  'scale': torch.tensor([1, 1])
2457  }),
2458  (Gamma, {
2459  'concentration': torch.tensor([0, 0]),
2460  'rate': torch.tensor([1, 1, 1])
2461  }),
2462  (Laplace, {
2463  'loc': torch.tensor([0, 0]),
2464  'scale': torch.tensor([1, 1, 1])
2465  }),
2466  (Pareto, {
2467  'scale': torch.tensor([1, 1]),
2468  'alpha': torch.tensor([1, 1, 1])
2469  }),
2470  (StudentT, {
2471  'df': torch.tensor([1, 1]),
2472  'scale': torch.tensor([1, 1, 1])
2473  }),
2474  (StudentT, {
2475  'df': torch.tensor([1, 1]),
2476  'loc': torch.tensor([1, 1, 1])
2477  })
2478  ]
2479 
2480  for dist, kwargs in invalid_examples:
2481  self.assertRaises(RuntimeError, dist, **kwargs)
2482 
2483 
2484 # These tests are only needed for a few distributions that implement custom
2485 # reparameterized gradients. Most .rsample() implementations simply rely on
2486 # the reparameterization trick and do not need to be tested for accuracy.
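# As a minimal sketch of the reparameterization trick (assuming the
# standard location-scale form, not any particular implementation):
#
#   eps = torch.randn(shape)        # parameter-free noise
#   x = loc + scale * eps           # differentiable w.r.t. loc, scale
#
# Gradients then flow through the deterministic transform, so no
# dedicated accuracy test is required. Gamma, Chi2, Dirichlet and Beta
# instead use implicit reparameterized gradients, tested below.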
2487 class TestRsample(TestCase):
2488  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2489  def test_gamma(self):
2490  num_samples = 100
2491  for alpha in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
2492  alphas = torch.tensor([alpha] * num_samples, dtype=torch.float, requires_grad=True)
2493  betas = alphas.new_ones(num_samples)
2494  x = Gamma(alphas, betas).rsample()
2495  x.sum().backward()
2496  x, ind = x.sort()
2497  x = x.detach().numpy()
2498  actual_grad = alphas.grad[ind].numpy()
2499  # Compare with expected gradient dx/dalpha along constant cdf(x,alpha).
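 # By the implicit function theorem, holding u = cdf(x, alpha) fixed:
 #   dx/dalpha = -(d cdf / d alpha) / (d cdf / dx)
 # and d cdf / dx is just the pdf; the numerator is estimated below by
 # a central finite difference. The chi2, Dirichlet and Beta tests in
 # this class apply the same recipe.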
2500  cdf = scipy.stats.gamma.cdf
2501  pdf = scipy.stats.gamma.pdf
2502  eps = 0.01 * alpha / (1.0 + alpha ** 0.5)
2503  cdf_alpha = (cdf(x, alpha + eps) - cdf(x, alpha - eps)) / (2 * eps)
2504  cdf_x = pdf(x, alpha)
2505  expected_grad = -cdf_alpha / cdf_x
2506  rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
2507  self.assertLess(np.max(rel_error), 0.0005, '\n'.join([
2508  'Bad gradient dx/alpha for x ~ Gamma({}, 1)'.format(alpha),
2509  'x {}'.format(x),
2510  'expected {}'.format(expected_grad),
2511  'actual {}'.format(actual_grad),
2512  'rel error {}'.format(rel_error),
2513  'max error {}'.format(rel_error.max()),
2514  'at alpha={}, x={}'.format(alpha, x[rel_error.argmax()]),
2515  ]))
2516 
2517  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2518  def test_chi2(self):
2519  num_samples = 100
2520  for df in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
2521  dfs = torch.tensor([df] * num_samples, dtype=torch.float, requires_grad=True)
2522  x = Chi2(dfs).rsample()
2523  x.sum().backward()
2524  x, ind = x.sort()
2525  x = x.detach().numpy()
2526  actual_grad = dfs.grad[ind].numpy()
2527  # Compare with expected gradient dx/ddf along constant cdf(x,df).
2528  cdf = scipy.stats.chi2.cdf
2529  pdf = scipy.stats.chi2.pdf
2530  eps = 0.01 * df / (1.0 + df ** 0.5)
2531  cdf_df = (cdf(x, df + eps) - cdf(x, df - eps)) / (2 * eps)
2532  cdf_x = pdf(x, df)
2533  expected_grad = -cdf_df / cdf_x
2534  rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
2535  self.assertLess(np.max(rel_error), 0.001, '\n'.join([
2536  'Bad gradient dx/ddf for x ~ Chi2({})'.format(df),
2537  'x {}'.format(x),
2538  'expected {}'.format(expected_grad),
2539  'actual {}'.format(actual_grad),
2540  'rel error {}'.format(rel_error),
2541  'max error {}'.format(rel_error.max()),
2542  ]))
2543 
2544  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2545  def test_dirichlet_on_diagonal(self):
2546  num_samples = 20
2547  grid = [1e-1, 1e0, 1e1]
2548  for a0, a1, a2 in product(grid, grid, grid):
2549  alphas = torch.tensor([[a0, a1, a2]] * num_samples, dtype=torch.float, requires_grad=True)
2550  x = Dirichlet(alphas).rsample()[:, 0]
2551  x.sum().backward()
2552  x, ind = x.sort()
2553  x = x.detach().numpy()
2554  actual_grad = alphas.grad[ind].numpy()[:, 0]
2555  # Compare with expected gradient dx/dalpha0 along constant cdf(x,alpha).
2556  # This reduces to a distribution Beta(alpha[0], alpha[1] + alpha[2]).
2557  cdf = scipy.stats.beta.cdf
2558  pdf = scipy.stats.beta.pdf
2559  alpha, beta = a0, a1 + a2
2560  eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
2561  cdf_alpha = (cdf(x, alpha + eps, beta) - cdf(x, alpha - eps, beta)) / (2 * eps)
2562  cdf_x = pdf(x, alpha, beta)
2563  expected_grad = -cdf_alpha / cdf_x
2564  rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
2565  self.assertLess(np.max(rel_error), 0.001, '\n'.join([
2566  'Bad gradient dx[0]/dalpha[0] for Dirichlet([{}, {}, {}])'.format(a0, a1, a2),
2567  'x {}'.format(x),
2568  'expected {}'.format(expected_grad),
2569  'actual {}'.format(actual_grad),
2570  'rel error {}'.format(rel_error),
2571  'max error {}'.format(rel_error.max()),
2572  'at x={}'.format(x[rel_error.argmax()]),
2573  ]))
2574 
2575  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2576  def test_beta_wrt_alpha(self):
2577  num_samples = 20
2578  grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
2579  for con1, con0 in product(grid, grid):
2580  con1s = torch.tensor([con1] * num_samples, dtype=torch.float, requires_grad=True)
2581  con0s = con1s.new_tensor([con0] * num_samples)
2582  x = Beta(con1s, con0s).rsample()
2583  x.sum().backward()
2584  x, ind = x.sort()
2585  x = x.detach().numpy()
2586  actual_grad = con1s.grad[ind].numpy()
2587  # Compare with expected gradient dx/dcon1 along constant cdf(x,con1,con0).
2588  cdf = scipy.stats.beta.cdf
2589  pdf = scipy.stats.beta.pdf
2590  eps = 0.01 * con1 / (1.0 + np.sqrt(con1))
2591  cdf_alpha = (cdf(x, con1 + eps, con0) - cdf(x, con1 - eps, con0)) / (2 * eps)
2592  cdf_x = pdf(x, con1, con0)
2593  expected_grad = -cdf_alpha / cdf_x
2594  rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
2595  self.assertLess(np.max(rel_error), 0.005, '\n'.join([
2596  'Bad gradient dx/dcon1 for x ~ Beta({}, {})'.format(con1, con0),
2597  'x {}'.format(x),
2598  'expected {}'.format(expected_grad),
2599  'actual {}'.format(actual_grad),
2600  'rel error {}'.format(rel_error),
2601  'max error {}'.format(rel_error.max()),
2602  'at x = {}'.format(x[rel_error.argmax()]),
2603  ]))
2604 
2605  @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
2606  def test_beta_wrt_beta(self):
2607  num_samples = 20
2608  grid = [1e-2, 1e-1, 1e0, 1e1, 1e2]
2609  for con1, con0 in product(grid, grid):
2610  con0s = torch.tensor([con0] * num_samples, dtype=torch.float, requires_grad=True)
2611  con1s = con0s.new_tensor([con1] * num_samples)
2612  x = Beta(con1s, con0s).rsample()
2613  x.sum().backward()
2614  x, ind = x.sort()
2615  x = x.detach().numpy()
2616  actual_grad = con0s.grad[ind].numpy()
2617  # Compare with expected gradient dx/dcon0 along constant cdf(x,con1,con0).
2618  cdf = scipy.stats.beta.cdf
2619  pdf = scipy.stats.beta.pdf
2620  eps = 0.01 * con0 / (1.0 + np.sqrt(con0))
2621  cdf_beta = (cdf(x, con1, con0 + eps) - cdf(x, con1, con0 - eps)) / (2 * eps)
2622  cdf_x = pdf(x, con1, con0)
2623  expected_grad = -cdf_beta / cdf_x
2624  rel_error = np.abs(actual_grad - expected_grad) / (expected_grad + 1e-30)
2625  self.assertLess(np.max(rel_error), 0.005, '\n'.join([
2626  'Bad gradient dx/dcon0 for x ~ Beta({}, {})'.format(con1, con0),
2627  'x {}'.format(x),
2628  'expected {}'.format(expected_grad),
2629  'actual {}'.format(actual_grad),
2630  'rel error {}'.format(rel_error),
2631  'max error {}'.format(rel_error.max()),
2632  'at x = {!r}'.format(x[rel_error.argmax()]),
2633  ]))
2634 
2635  def test_dirichlet_multivariate(self):
2636  alpha_crit = 0.25 * (5.0 ** 0.5 - 1.0)
2637  num_samples = 100000
2638  for shift in [-0.1, -0.05, -0.01, 0.0, 0.01, 0.05, 0.10]:
2639  alpha = alpha_crit + shift
2640  alpha = torch.tensor([alpha], dtype=torch.float, requires_grad=True)
2641  alpha_vec = torch.cat([alpha, alpha, alpha.new([1])])
2642  z = Dirichlet(alpha_vec.expand(num_samples, 3)).rsample()
2643  mean_z3 = 1.0 / (2.0 * alpha + 1.0)
2644  loss = torch.pow(z[:, 2] - mean_z3, 2.0).mean()
2645  actual_grad = grad(loss, [alpha])[0]
2646  # Compute expected gradient by hand.
2647  num = 1.0 - 2.0 * alpha - 4.0 * alpha**2
2648  den = (1.0 + alpha)**2 * (1.0 + 2.0 * alpha)**3
2649  expected_grad = num / den
2650  self.assertEqual(actual_grad, expected_grad, 0.002, '\n'.join([
2651  "alpha = alpha_c + %.2g" % shift,
2652  "expected_grad: %.5g" % expected_grad,
2653  "actual_grad: %.5g" % actual_grad,
2654  "error = %.2g" % torch.abs(expected_grad - actual_grad).max(),
2655  ]))
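    # Since mean_z3 is the analytic mean of z[:, 2] ~ Beta(1, 2 * alpha), the
    # loss estimates Var(z3), and the hand-derived gradient is d Var / d alpha.
    # A symbolic cross-check (a sketch assuming sympy is available):
    #     import sympy
    #     a = sympy.symbols('a', positive=True)
    #     var_z3 = 2 * a / ((2 * a + 1) ** 2 * (2 * a + 2))  # Var of Beta(1, 2a)
    #     num = 1 - 2 * a - 4 * a ** 2
    #     den = (1 + a) ** 2 * (1 + 2 * a) ** 3
    #     assert sympy.simplify(sympy.diff(var_z3, a) - num / den) == 0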
2656 
2657  def test_dirichlet_tangent_field(self):
2658  num_samples = 20
2659  alpha_grid = [0.5, 1.0, 2.0]
2660 
2661  # v = dx/dalpha[0] is the reparameterized gradient aka tangent field.
2662  def compute_v(x, alpha):
2663  return torch.stack([
2664  _Dirichlet_backward(x, alpha, torch.eye(3, 3)[i].expand_as(x))[:, 0]
2665  for i in range(3)
2666  ], dim=-1)
2667 
2668  for a1, a2, a3 in product(alpha_grid, alpha_grid, alpha_grid):
2669  alpha = torch.tensor([a1, a2, a3], requires_grad=True).expand(num_samples, 3)
2670  x = Dirichlet(alpha).rsample()
2671  dlogp_da = grad([Dirichlet(alpha).log_prob(x.detach()).sum()],
2672  [alpha], retain_graph=True)[0][:, 0]
2673  dlogp_dx = grad([Dirichlet(alpha.detach()).log_prob(x).sum()],
2674  [x], retain_graph=True)[0]
2675  v = torch.stack([grad([x[:, i].sum()], [alpha], retain_graph=True)[0][:, 0]
2676  for i in range(3)], dim=-1)
2677  # Compute remaining properties by finite difference.
2678  self.assertEqual(compute_v(x, alpha), v, message='Bug in compute_v() helper')
2679  # dx is an arbitrary orthonormal basis tangent to the simplex.
2680  dx = torch.tensor([[2., -1., -1.], [0., 1., -1.]])
2681  dx /= dx.norm(2, -1, True)
2682  eps = 1e-2 * x.min(-1, True)[0] # avoid boundary
2683  dv0 = (compute_v(x + eps * dx[0], alpha) - compute_v(x - eps * dx[0], alpha)) / (2 * eps)
2684  dv1 = (compute_v(x + eps * dx[1], alpha) - compute_v(x - eps * dx[1], alpha)) / (2 * eps)
2685  div_v = (dv0 * dx[0] + dv1 * dx[1]).sum(-1)
2686  # This is a modification of the standard continuity equation, using the product rule to allow
2687  # expression in terms of log_prob rather than the less numerically stable log_prob.exp().
2688  error = dlogp_da + (dlogp_dx * v).sum(-1) + div_v
2689  self.assertLess(torch.abs(error).max(), 0.005, '\n'.join([
2690  'Dirichlet([{}, {}, {}]) gradient violates continuity equation:'.format(a1, a2, a3),
2691  'error = {}'.format(error),
2692  ]))
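    # The identity checked above is the continuity equation of the sampler,
    #     \partial_\alpha p + \nabla_x . (p v) = 0,
    # divided through by p via the product rule, which gives the log-space form
    #     \partial_\alpha \log p + v . \nabla_x \log p + \nabla_x . v = 0,
    # i.e. exactly dlogp_da + (dlogp_dx * v).sum(-1) + div_v.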
2693 
2694 
2695 class TestDistributionShapes(TestCase):
2696  def setUp(self):
2697  super(TestDistributionShapes, self).setUp()
2698  self.scalar_sample = 1
2699  self.tensor_sample_1 = torch.ones(3, 2)
2700  self.tensor_sample_2 = torch.ones(3, 2, 3)
2701  Distribution.set_default_validate_args(True)
2702 
2703  def tearDown(self):
2704  super(TestDistributionShapes, self).tearDown()
2705  Distribution.set_default_validate_args(False)
2706 
2707  def test_entropy_shape(self):
2708  for Dist, params in EXAMPLES:
2709  for i, param in enumerate(params):
2710  dist = Dist(validate_args=False, **param)
2711  try:
2712  actual_shape = dist.entropy().size()
2713  expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
2714  message = '{} example {}/{}, shape mismatch. expected {}, actual {}'.format(
2715  Dist.__name__, i + 1, len(params), expected_shape, actual_shape)
2716  self.assertEqual(actual_shape, expected_shape, message=message)
2717  except NotImplementedError:
2718  continue
2719 
2720  def test_bernoulli_shape_scalar_params(self):
2721  bernoulli = Bernoulli(0.3)
2722  self.assertEqual(bernoulli._batch_shape, torch.Size())
2723  self.assertEqual(bernoulli._event_shape, torch.Size())
2724  self.assertEqual(bernoulli.sample().size(), torch.Size())
2725  self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
2726  self.assertRaises(ValueError, bernoulli.log_prob, self.scalar_sample)
2727  self.assertEqual(bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2728  self.assertEqual(bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2729 
2730  def test_bernoulli_shape_tensor_params(self):
2731  bernoulli = Bernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
2732  self.assertEqual(bernoulli._batch_shape, torch.Size((3, 2)))
2733  self.assertEqual(bernoulli._event_shape, torch.Size(()))
2734  self.assertEqual(bernoulli.sample().size(), torch.Size((3, 2)))
2735  self.assertEqual(bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
2736  self.assertEqual(bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2737  self.assertRaises(ValueError, bernoulli.log_prob, self.tensor_sample_2)
2738  self.assertEqual(bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
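    # log_prob broadcasts the value's shape against batch_shape with the usual
    # broadcasting rules, which is what the last assertion checks: a (3, 1, 1)
    # value against batch_shape (3, 2) yields (3, 3, 2). Illustration:
    #     (torch.zeros(3, 1, 1) + torch.zeros(3, 2)).shape  # torch.Size([3, 3, 2])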
2739 
2740  def test_geometric_shape_scalar_params(self):
2741  geometric = Geometric(0.3)
2742  self.assertEqual(geometric._batch_shape, torch.Size())
2743  self.assertEqual(geometric._event_shape, torch.Size())
2744  self.assertEqual(geometric.sample().size(), torch.Size())
2745  self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2)))
2746  self.assertRaises(ValueError, geometric.log_prob, self.scalar_sample)
2747  self.assertEqual(geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2748  self.assertEqual(geometric.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2749 
2750  def test_geometric_shape_tensor_params(self):
2751  geometric = Geometric(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
2752  self.assertEqual(geometric._batch_shape, torch.Size((3, 2)))
2753  self.assertEqual(geometric._event_shape, torch.Size(()))
2754  self.assertEqual(geometric.sample().size(), torch.Size((3, 2)))
2755  self.assertEqual(geometric.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
2756  self.assertEqual(geometric.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2757  self.assertRaises(ValueError, geometric.log_prob, self.tensor_sample_2)
2758  self.assertEqual(geometric.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
2759 
2760  def test_beta_shape_scalar_params(self):
2761  dist = Beta(0.1, 0.1)
2762  self.assertEqual(dist._batch_shape, torch.Size())
2763  self.assertEqual(dist._event_shape, torch.Size())
2764  self.assertEqual(dist.sample().size(), torch.Size())
2765  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2)))
2766  self.assertRaises(ValueError, dist.log_prob, self.scalar_sample)
2767  self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2768  self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2769 
2770  def test_beta_shape_tensor_params(self):
2771  dist = Beta(torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),
2772  torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]))
2773  self.assertEqual(dist._batch_shape, torch.Size((3, 2)))
2774  self.assertEqual(dist._event_shape, torch.Size(()))
2775  self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
2776  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
2777  self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2778  self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
2779  self.assertEqual(dist.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))
2780 
2781  def test_binomial_shape(self):
2782  dist = Binomial(10, torch.tensor([0.6, 0.3]))
2783  self.assertEqual(dist._batch_shape, torch.Size((2,)))
2784  self.assertEqual(dist._event_shape, torch.Size(()))
2785  self.assertEqual(dist.sample().size(), torch.Size((2,)))
2786  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2)))
2787  self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2788  self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
2789 
2790  def test_binomial_shape_vectorized_n(self):
2791  dist = Binomial(torch.tensor([[10, 3, 1], [4, 8, 4]]), torch.tensor([0.6, 0.3, 0.1]))
2792  self.assertEqual(dist._batch_shape, torch.Size((2, 3)))
2793  self.assertEqual(dist._event_shape, torch.Size(()))
2794  self.assertEqual(dist.sample().size(), torch.Size((2, 3)))
2795  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 2, 3)))
2796  self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2797  self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
2798 
2799  def test_multinomial_shape(self):
2800  dist = Multinomial(10, torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
2801  self.assertEqual(dist._batch_shape, torch.Size((3,)))
2802  self.assertEqual(dist._event_shape, torch.Size((2,)))
2803  self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
2804  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
2805  self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
2806  self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
2807  self.assertEqual(dist.log_prob(torch.ones(3, 1, 2)).size(), torch.Size((3, 3)))
2808 
2809  def test_categorical_shape(self):
2810  # unbatched
2811  dist = Categorical(torch.tensor([0.6, 0.3, 0.1]))
2812  self.assertEqual(dist._batch_shape, torch.Size(()))
2813  self.assertEqual(dist._event_shape, torch.Size(()))
2814  self.assertEqual(dist.sample().size(), torch.Size())
2815  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2,)))
2816  self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2817  self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2818  self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 1)))
2819  # batched
2820  dist = Categorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
2821  self.assertEqual(dist._batch_shape, torch.Size((3,)))
2822  self.assertEqual(dist._event_shape, torch.Size(()))
2823  self.assertEqual(dist.sample().size(), torch.Size((3,)))
2824  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3,)))
2825  self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
2826  self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2827  self.assertEqual(dist.log_prob(torch.ones(3, 1)).size(), torch.Size((3, 3)))
2828 
2829  def test_one_hot_categorical_shape(self):
2830  # unbatched
2831  dist = OneHotCategorical(torch.tensor([0.6, 0.3, 0.1]))
2832  self.assertEqual(dist._batch_shape, torch.Size(()))
2833  self.assertEqual(dist._event_shape, torch.Size((3,)))
2834  self.assertEqual(dist.sample().size(), torch.Size((3,)))
2835  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
2836  self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
2837  simplex_sample = self.tensor_sample_2 / self.tensor_sample_2.sum(-1, keepdim=True)
2838  self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 2,)))
2839  self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
2840  simplex_sample = torch.ones(3, 3) / 3
2841  self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
2842  # batched
2843  dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
2844  self.assertEqual(dist._batch_shape, torch.Size((3,)))
2845  self.assertEqual(dist._event_shape, torch.Size((2,)))
2846  self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
2847  self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
2848  simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True)
2849  self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
2850  self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
2851  self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))
2852  simplex_sample = torch.ones(3, 1, 2) / 2
2853  self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))
2854 
2855  def test_cauchy_shape_scalar_params(self):
2856  cauchy = Cauchy(0, 1)
2857  self.assertEqual(cauchy._batch_shape, torch.Size())
2858  self.assertEqual(cauchy._event_shape, torch.Size())
2859  self.assertEqual(cauchy.sample().size(), torch.Size())
2860  self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
2861  self.assertRaises(ValueError, cauchy.log_prob, self.scalar_sample)
2862  self.assertEqual(cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2863  self.assertEqual(cauchy.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2864 
2865  def test_cauchy_shape_tensor_params(self):
2866  cauchy = Cauchy(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
2867  self.assertEqual(cauchy._batch_shape, torch.Size((2,)))
2868  self.assertEqual(cauchy._event_shape, torch.Size(()))
2869  self.assertEqual(cauchy.sample().size(), torch.Size((2,)))
2870  self.assertEqual(cauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
2871  self.assertEqual(cauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2872  self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
2873  self.assertEqual(cauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
2874 
2875  def test_halfcauchy_shape_scalar_params(self):
2876  halfcauchy = HalfCauchy(1)
2877  self.assertEqual(halfcauchy._batch_shape, torch.Size())
2878  self.assertEqual(halfcauchy._event_shape, torch.Size())
2879  self.assertEqual(halfcauchy.sample().size(), torch.Size())
2880  self.assertEqual(halfcauchy.sample(torch.Size((3, 2))).size(),
2881  torch.Size((3, 2)))
2882  self.assertRaises(ValueError, halfcauchy.log_prob, self.scalar_sample)
2883  self.assertEqual(halfcauchy.log_prob(self.tensor_sample_1).size(),
2884  torch.Size((3, 2)))
2885  self.assertEqual(halfcauchy.log_prob(self.tensor_sample_2).size(),
2886  torch.Size((3, 2, 3)))
2887 
2888  def test_halfcauchy_shape_tensor_params(self):
2889  halfcauchy = HalfCauchy(torch.tensor([1., 1.]))
2890  self.assertEqual(halfcauchy._batch_shape, torch.Size((2,)))
2891  self.assertEqual(halfcauchy._event_shape, torch.Size(()))
2892  self.assertEqual(halfcauchy.sample().size(), torch.Size((2,)))
2893  self.assertEqual(halfcauchy.sample(torch.Size((3, 2))).size(),
2894  torch.Size((3, 2, 2)))
2895  self.assertEqual(halfcauchy.log_prob(self.tensor_sample_1).size(),
2896  torch.Size((3, 2)))
2897  self.assertRaises(ValueError, halfcauchy.log_prob, self.tensor_sample_2)
2898  self.assertEqual(halfcauchy.log_prob(torch.ones(2, 1)).size(),
2899  torch.Size((2, 2)))
2900 
2901  def test_dirichlet_shape(self):
2902  dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
2903  self.assertEqual(dist._batch_shape, torch.Size((3,)))
2904  self.assertEqual(dist._event_shape, torch.Size((2,)))
2905  self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
2906  self.assertEqual(dist.sample((5, 4)).size(), torch.Size((5, 4, 3, 2)))
2907  simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True)
2908  self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
2909  self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
2910  simplex_sample = torch.ones(3, 1, 2)
2911  simplex_sample = simplex_sample / simplex_sample.sum(-1).unsqueeze(-1)
2912  self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))
2913 
2914  def test_gamma_shape_scalar_params(self):
2915  gamma = Gamma(1, 1)
2916  self.assertEqual(gamma._batch_shape, torch.Size())
2917  self.assertEqual(gamma._event_shape, torch.Size())
2918  self.assertEqual(gamma.sample().size(), torch.Size())
2919  self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2)))
2920  self.assertRaises(ValueError, gamma.log_prob, self.scalar_sample)
2921  self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2922  self.assertEqual(gamma.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2923 
2924  def test_gamma_shape_tensor_params(self):
2925  gamma = Gamma(torch.tensor([1., 1.]), torch.tensor([1., 1.]))
2926  self.assertEqual(gamma._batch_shape, torch.Size((2,)))
2927  self.assertEqual(gamma._event_shape, torch.Size(()))
2928  self.assertEqual(gamma.sample().size(), torch.Size((2,)))
2929  self.assertEqual(gamma.sample((3, 2)).size(), torch.Size((3, 2, 2)))
2930  self.assertEqual(gamma.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2931  self.assertRaises(ValueError, gamma.log_prob, self.tensor_sample_2)
2932  self.assertEqual(gamma.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
2933 
2934  def test_chi2_shape_scalar_params(self):
2935  chi2 = Chi2(1)
2936  self.assertEqual(chi2._batch_shape, torch.Size())
2937  self.assertEqual(chi2._event_shape, torch.Size())
2938  self.assertEqual(chi2.sample().size(), torch.Size())
2939  self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2)))
2940  self.assertRaises(ValueError, chi2.log_prob, self.scalar_sample)
2941  self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2942  self.assertEqual(chi2.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2943 
2944  def test_chi2_shape_tensor_params(self):
2945  chi2 = Chi2(torch.tensor([1., 1.]))
2946  self.assertEqual(chi2._batch_shape, torch.Size((2,)))
2947  self.assertEqual(chi2._event_shape, torch.Size(()))
2948  self.assertEqual(chi2.sample().size(), torch.Size((2,)))
2949  self.assertEqual(chi2.sample((3, 2)).size(), torch.Size((3, 2, 2)))
2950  self.assertEqual(chi2.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2951  self.assertRaises(ValueError, chi2.log_prob, self.tensor_sample_2)
2952  self.assertEqual(chi2.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
2953 
2954  def test_studentT_shape_scalar_params(self):
2955  st = StudentT(1)
2956  self.assertEqual(st._batch_shape, torch.Size())
2957  self.assertEqual(st._event_shape, torch.Size())
2958  self.assertEqual(st.sample().size(), torch.Size())
2959  self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2)))
2960  self.assertRaises(ValueError, st.log_prob, self.scalar_sample)
2961  self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2962  self.assertEqual(st.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2963 
2964  def test_studentT_shape_tensor_params(self):
2965  st = StudentT(torch.tensor([1., 1.]))
2966  self.assertEqual(st._batch_shape, torch.Size((2,)))
2967  self.assertEqual(st._event_shape, torch.Size(()))
2968  self.assertEqual(st.sample().size(), torch.Size((2,)))
2969  self.assertEqual(st.sample((3, 2)).size(), torch.Size((3, 2, 2)))
2970  self.assertEqual(st.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2971  self.assertRaises(ValueError, st.log_prob, self.tensor_sample_2)
2972  self.assertEqual(st.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
2973 
2974  def test_pareto_shape_scalar_params(self):
2975  pareto = Pareto(1, 1)
2976  self.assertEqual(pareto._batch_shape, torch.Size())
2977  self.assertEqual(pareto._event_shape, torch.Size())
2978  self.assertEqual(pareto.sample().size(), torch.Size())
2979  self.assertEqual(pareto.sample((3, 2)).size(), torch.Size((3, 2)))
2980  self.assertEqual(pareto.log_prob(self.tensor_sample_1 + 1).size(), torch.Size((3, 2)))
2981  self.assertEqual(pareto.log_prob(self.tensor_sample_2 + 1).size(), torch.Size((3, 2, 3)))
2982 
2983  def test_gumbel_shape_scalar_params(self):
2984  gumbel = Gumbel(1, 1)
2985  self.assertEqual(gumbel._batch_shape, torch.Size())
2986  self.assertEqual(gumbel._event_shape, torch.Size())
2987  self.assertEqual(gumbel.sample().size(), torch.Size())
2988  self.assertEqual(gumbel.sample((3, 2)).size(), torch.Size((3, 2)))
2989  self.assertEqual(gumbel.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2990  self.assertEqual(gumbel.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
2991 
2992  def test_weibull_scale_scalar_params(self):
2993  weibull = Weibull(1, 1)
2994  self.assertEqual(weibull._batch_shape, torch.Size())
2995  self.assertEqual(weibull._event_shape, torch.Size())
2996  self.assertEqual(weibull.sample().size(), torch.Size())
2997  self.assertEqual(weibull.sample((3, 2)).size(), torch.Size((3, 2)))
2998  self.assertEqual(weibull.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
2999  self.assertEqual(weibull.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
3000 
3001  def test_normal_shape_scalar_params(self):
3002  normal = Normal(0, 1)
3003  self.assertEqual(normal._batch_shape, torch.Size())
3004  self.assertEqual(normal._event_shape, torch.Size())
3005  self.assertEqual(normal.sample().size(), torch.Size())
3006  self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2)))
3007  self.assertRaises(ValueError, normal.log_prob, self.scalar_sample)
3008  self.assertEqual(normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
3009  self.assertEqual(normal.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
3010 
3011  def test_normal_shape_tensor_params(self):
3012  normal = Normal(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
3013  self.assertEqual(normal._batch_shape, torch.Size((2,)))
3014  self.assertEqual(normal._event_shape, torch.Size(()))
3015  self.assertEqual(normal.sample().size(), torch.Size((2,)))
3016  self.assertEqual(normal.sample((3, 2)).size(), torch.Size((3, 2, 2)))
3017  self.assertEqual(normal.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
3018  self.assertRaises(ValueError, normal.log_prob, self.tensor_sample_2)
3019  self.assertEqual(normal.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
3020 
3021  def test_uniform_shape_scalar_params(self):
3022  uniform = Uniform(0, 1)
3023  self.assertEqual(uniform._batch_shape, torch.Size())
3024  self.assertEqual(uniform._event_shape, torch.Size())
3025  self.assertEqual(uniform.sample().size(), torch.Size())
3026  self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2)))
3027  self.assertRaises(ValueError, uniform.log_prob, self.scalar_sample)
3028  self.assertEqual(uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
3029  self.assertEqual(uniform.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
3030 
3031  def test_uniform_shape_tensor_params(self):
3032  uniform = Uniform(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
3033  self.assertEqual(uniform._batch_shape, torch.Size((2,)))
3034  self.assertEqual(uniform._event_shape, torch.Size(()))
3035  self.assertEqual(uniform.sample().size(), torch.Size((2,)))
3036  self.assertEqual(uniform.sample(torch.Size((3, 2))).size(), torch.Size((3, 2, 2)))
3037  self.assertEqual(uniform.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
3038  self.assertRaises(ValueError, uniform.log_prob, self.tensor_sample_2)
3039  self.assertEqual(uniform.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
3040 
3041  def test_exponential_shape_scalar_param(self):
3042  expon = Exponential(1.)
3043  self.assertEqual(expon._batch_shape, torch.Size())
3044  self.assertEqual(expon._event_shape, torch.Size())
3045  self.assertEqual(expon.sample().size(), torch.Size())
3046  self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2)))
3047  self.assertRaises(ValueError, expon.log_prob, self.scalar_sample)
3048  self.assertEqual(expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
3049  self.assertEqual(expon.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
3050 
3051  def test_exponential_shape_tensor_param(self):
3052  expon = Exponential(torch.tensor([1., 1.]))
3053  self.assertEqual(expon._batch_shape, torch.Size((2,)))
3054  self.assertEqual(expon._event_shape, torch.Size(()))
3055  self.assertEqual(expon.sample().size(), torch.Size((2,)))
3056  self.assertEqual(expon.sample((3, 2)).size(), torch.Size((3, 2, 2)))
3057  self.assertEqual(expon.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
3058  self.assertRaises(ValueError, expon.log_prob, self.tensor_sample_2)
3059  self.assertEqual(expon.log_prob(torch.ones(2, 2)).size(), torch.Size((2, 2)))
3060 
3061  def test_laplace_shape_scalar_params(self):
3062  laplace = Laplace(0, 1)
3063  self.assertEqual(laplace._batch_shape, torch.Size())
3064  self.assertEqual(laplace._event_shape, torch.Size())
3065  self.assertEqual(laplace.sample().size(), torch.Size())
3066  self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2)))
3067  self.assertRaises(ValueError, laplace.log_prob, self.scalar_sample)
3068  self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
3069  self.assertEqual(laplace.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
3070 
3071  def test_laplace_shape_tensor_params(self):
3072  laplace = Laplace(torch.tensor([0., 0.]), torch.tensor([1., 1.]))
3073  self.assertEqual(laplace._batch_shape, torch.Size((2,)))
3074  self.assertEqual(laplace._event_shape, torch.Size(()))
3075  self.assertEqual(laplace.sample().size(), torch.Size((2,)))
3076  self.assertEqual(laplace.sample((3, 2)).size(), torch.Size((3, 2, 2)))
3077  self.assertEqual(laplace.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
3078  self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
3079  self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))
3080 
3081 
3082 class TestKL(TestCase):
3083 
3084  def setUp(self):
3085 
3086  class Binomial30(Binomial):
3087  def __init__(self, probs):
3088  super(Binomial30, self).__init__(30, probs)
3089 
3090  # These are pairs of distributions with 4 x 4 parameters as specified.
3091  # The first of the pair, e.g. bernoulli[0], varies column-wise and the
3092  # second, e.g. bernoulli[1], varies row-wise; that way we test all param pairs.
3093  bernoulli = pairwise(Bernoulli, [0.1, 0.2, 0.6, 0.9])
3094  binomial30 = pairwise(Binomial30, [0.1, 0.2, 0.6, 0.9])
3095  binomial_vectorized_count = (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
3096  Binomial(torch.tensor([3, 4]), torch.tensor([0.5, 0.8])))
3097  beta = pairwise(Beta, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
3098  categorical = pairwise(Categorical, [[0.4, 0.3, 0.3],
3099  [0.2, 0.7, 0.1],
3100  [0.33, 0.33, 0.34],
3101  [0.2, 0.2, 0.6]])
3102  chi2 = pairwise(Chi2, [1.0, 2.0, 2.5, 5.0])
3103  dirichlet = pairwise(Dirichlet, [[0.1, 0.2, 0.7],
3104  [0.5, 0.4, 0.1],
3105  [0.33, 0.33, 0.34],
3106  [0.2, 0.2, 0.4]])
3107  exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
3108  gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
3109  gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
3110  halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
3111  laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
3112  lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
3113  normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
3114  independent = (Independent(normal[0], 1), Independent(normal[1], 1))
3115  onehotcategorical = pairwise(OneHotCategorical, [[0.4, 0.3, 0.3],
3116  [0.2, 0.7, 0.1],
3117  [0.33, 0.33, 0.34],
3118  [0.2, 0.2, 0.6]])
3119  pareto = pairwise(Pareto, [2.5, 4.0, 2.5, 4.0], [2.25, 3.75, 2.25, 3.75])
3120  poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
3121  uniform_within_unit = pairwise(Uniform, [0.15, 0.95, 0.2, 0.8], [0.1, 0.9, 0.25, 0.75])
3122  uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
3123  uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4])
3124  uniform_pareto = pairwise(Uniform, [6.5, 8.5, 6.5, 8.5], [7.5, 7.5, 9.5, 9.5])
3125 
3126  # These tests should pass with precision = 0.01, but that makes tests very expensive.
3127  # Instead, we test with precision = 0.1 and only test with higher precision locally
3128  # when adding a new KL implementation.
3129  # The following pairs are not tested due to very high variance of the Monte
3130  # Carlo estimator; their implementations have been reviewed with extra care:
3131  # - (pareto, normal)
3132  self.precision = 0.1 # Set this to 0.01 when testing a new KL implementation.
3133  self.max_samples = int(1e07) # Increase this when testing at smaller precision.
3134  self.samples_per_batch = int(1e04)
3135  self.finite_examples = [
3136  (bernoulli, bernoulli),
3137  (bernoulli, poisson),
3138  (beta, beta),
3139  (beta, chi2),
3140  (beta, exponential),
3141  (beta, gamma),
3142  (beta, normal),
3143  (binomial30, binomial30),
3144  (binomial_vectorized_count, binomial_vectorized_count),
3145  (categorical, categorical),
3146  (chi2, chi2),
3147  (chi2, exponential),
3148  (chi2, gamma),
3149  (chi2, normal),
3150  (dirichlet, dirichlet),
3151  (exponential, chi2),
3152  (exponential, exponential),
3153  (exponential, gamma),
3154  (exponential, gumbel),
3155  (exponential, normal),
3156  (gamma, chi2),
3157  (gamma, exponential),
3158  (gamma, gamma),
3159  (gamma, gumbel),
3160  (gamma, normal),
3161  (gumbel, gumbel),
3162  (gumbel, normal),
3163  (halfnormal, halfnormal),
3164  (independent, independent),
3165  (laplace, laplace),
3166  (lognormal, lognormal),
3167  (laplace, normal),
3168  (normal, gumbel),
3169  (normal, normal),
3170  (onehotcategorical, onehotcategorical),
3171  (pareto, chi2),
3172  (pareto, pareto),
3173  (pareto, exponential),
3174  (pareto, gamma),
3175  (poisson, poisson),
3176  (uniform_within_unit, beta),
3177  (uniform_positive, chi2),
3178  (uniform_positive, exponential),
3179  (uniform_positive, gamma),
3180  (uniform_real, gumbel),
3181  (uniform_real, normal),
3182  (uniform_pareto, pareto),
3183  ]
3184 
3185  self.infinite_examples = [
3186  (Bernoulli(0), Bernoulli(1)),
3187  (Bernoulli(1), Bernoulli(0)),
3188  (Categorical(torch.tensor([0.9, 0.1])), Categorical(torch.tensor([1., 0.]))),
3189  (Categorical(torch.tensor([[0.9, 0.1], [.9, .1]])), Categorical(torch.tensor([1., 0.]))),
3190  (Beta(1, 2), Uniform(0.25, 1)),
3191  (Beta(1, 2), Uniform(0, 0.75)),
3192  (Beta(1, 2), Uniform(0.25, 0.75)),
3193  (Beta(1, 2), Pareto(1, 2)),
3194  (Binomial(31, 0.7), Binomial(30, 0.3)),
3195  (Binomial(torch.tensor([3, 4]), torch.tensor([0.4, 0.6])),
3196  Binomial(torch.tensor([2, 3]), torch.tensor([0.5, 0.8]))),
3197  (Chi2(1), Beta(2, 3)),
3198  (Chi2(1), Pareto(2, 3)),
3199  (Chi2(1), Uniform(-2, 3)),
3200  (Exponential(1), Beta(2, 3)),
3201  (Exponential(1), Pareto(2, 3)),
3202  (Exponential(1), Uniform(-2, 3)),
3203  (Gamma(1, 2), Beta(3, 4)),
3204  (Gamma(1, 2), Pareto(3, 4)),
3205  (Gamma(1, 2), Uniform(-3, 4)),
3206  (Gumbel(-1, 2), Beta(3, 4)),
3207  (Gumbel(-1, 2), Chi2(3)),
3208  (Gumbel(-1, 2), Exponential(3)),
3209  (Gumbel(-1, 2), Gamma(3, 4)),
3210  (Gumbel(-1, 2), Pareto(3, 4)),
3211  (Gumbel(-1, 2), Uniform(-3, 4)),
3212  (Laplace(-1, 2), Beta(3, 4)),
3213  (Laplace(-1, 2), Chi2(3)),
3214  (Laplace(-1, 2), Exponential(3)),
3215  (Laplace(-1, 2), Gamma(3, 4)),
3216  (Laplace(-1, 2), Pareto(3, 4)),
3217  (Laplace(-1, 2), Uniform(-3, 4)),
3218  (Normal(-1, 2), Beta(3, 4)),
3219  (Normal(-1, 2), Chi2(3)),
3220  (Normal(-1, 2), Exponential(3)),
3221  (Normal(-1, 2), Gamma(3, 4)),
3222  (Normal(-1, 2), Pareto(3, 4)),
3223  (Normal(-1, 2), Uniform(-3, 4)),
3224  (Pareto(2, 1), Chi2(3)),
3225  (Pareto(2, 1), Exponential(3)),
3226  (Pareto(2, 1), Gamma(3, 4)),
3227  (Pareto(1, 2), Normal(-3, 4)),
3228  (Pareto(1, 2), Pareto(3, 4)),
3229  (Poisson(2), Bernoulli(0.5)),
3230  (Poisson(2.3), Binomial(10, 0.2)),
3231  (Uniform(-1, 1), Beta(2, 2)),
3232  (Uniform(0, 2), Beta(3, 4)),
3233  (Uniform(-1, 2), Beta(3, 4)),
3234  (Uniform(-1, 2), Chi2(3)),
3235  (Uniform(-1, 2), Exponential(3)),
3236  (Uniform(-1, 2), Gamma(3, 4)),
3237  (Uniform(-1, 2), Pareto(3, 4)),
3238  ]
3239 
3240  def test_kl_monte_carlo(self):
3241  set_rng_seed(0) # see Note [Randomized statistical tests]
3242  for (p, _), (_, q) in self.finite_examples:
3243  actual = kl_divergence(p, q)
3244  numerator = 0
3245  denominator = 0
3246  while denominator < self.max_samples:
3247  x = p.sample(sample_shape=(self.samples_per_batch,))
3248  numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
3249  denominator += x.size(0)
3250  expected = numerator / denominator
3251  error = torch.abs(expected - actual) / (1 + expected)
3252  if error[error == error].max() < self.precision:
3253  break
3254  self.assertLess(error[error == error].max(), self.precision, '\n'.join([
3255  'Incorrect KL({}, {}).'.format(type(p).__name__, type(q).__name__),
3256  'Expected ({} Monte Carlo samples): {}'.format(denominator, expected),
3257  'Actual (analytic): {}'.format(actual),
3258  ]))
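    # KL(p || q) = E_p[log p(x) - log q(x)], so the batched sample average above
    # is an unbiased Monte Carlo estimate of the analytic value. A minimal
    # standalone sketch (illustrative, outside the test harness):
    #     p, q = Normal(0., 1.), Normal(1., 2.)
    #     x = p.sample((100000,))
    #     mc_kl = (p.log_prob(x) - q.log_prob(x)).mean()
    #     # mc_kl approaches kl_divergence(p, q) as the sample count grows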
3259 
3260  # Multivariate normal has a separate Monte Carlo based test due to the requirement of random generation of
3261  # positive (semi) definite matrices. n is set to 5, but can be increased during testing.
3262  def test_kl_multivariate_normal(self):
3263  set_rng_seed(0) # see Note [Randomized statistical tests]
3264  n = 5 # Number of tests for multivariate_normal
3265  for i in range(0, n):
3266  loc = [torch.randn(4) for _ in range(0, 2)]
3267  scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(4, 4)) for _ in range(0, 2)]
3268  p = MultivariateNormal(loc=loc[0], scale_tril=scale_tril[0])
3269  q = MultivariateNormal(loc=loc[1], scale_tril=scale_tril[1])
3270  actual = kl_divergence(p, q)
3271  numerator = 0
3272  denominator = 0
3273  while denominator < self.max_samples:
3274  x = p.sample(sample_shape=(self.samples_per_batch,))
3275  numerator += (p.log_prob(x) - q.log_prob(x)).sum(0)
3276  denominator += x.size(0)
3277  expected = numerator / denominator
3278  error = torch.abs(expected - actual) / (1 + expected)
3279  if error[error == error].max() < self.precision:
3280  break
3281  self.assertLess(error[error == error].max(), self.precision, '\n'.join([
3282  'Incorrect KL(MultivariateNormal, MultivariateNormal) instance {}/{}'.format(i + 1, n),
3283  'Expected ({} Monte Carlo samples): {}'.format(denominator, expected),
3284  'Actual (analytic): {}'.format(actual),
3285  ]))
3286 
3287  def test_kl_multivariate_normal_batched(self):
3288  b = 7 # Number of batches
3289  loc = [torch.randn(b, 3) for _ in range(0, 2)]
3290  scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)) for _ in range(0, 2)]
3291  expected_kl = torch.stack([
3292  kl_divergence(MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
3293  MultivariateNormal(loc[1][i], scale_tril=scale_tril[1][i])) for i in range(0, b)])
3294  actual_kl = kl_divergence(MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
3295  MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
3296  self.assertEqual(expected_kl, actual_kl)
3297 
3298  def test_kl_multivariate_normal_batched_broadcasted(self):
3299  b = 7 # Number of batches
3300  loc = [torch.randn(b, 3) for _ in range(0, 2)]
3301  scale_tril = [transform_to(constraints.lower_cholesky)(torch.randn(b, 3, 3)),
3302  transform_to(constraints.lower_cholesky)(torch.randn(3, 3))]
3303  expected_kl = torch.stack([
3304  kl_divergence(MultivariateNormal(loc[0][i], scale_tril=scale_tril[0][i]),
3305  MultivariateNormal(loc[1][i], scale_tril=scale_tril[1])) for i in range(0, b)])
3306  actual_kl = kl_divergence(MultivariateNormal(loc[0], scale_tril=scale_tril[0]),
3307  MultivariateNormal(loc[1], scale_tril=scale_tril[1]))
3308  self.assertEqual(expected_kl, actual_kl)
3309 
3310  def test_kl_lowrank_multivariate_normal(self):
3311  set_rng_seed(0) # see Note [Randomized statistical tests]
3312  n = 5 # Number of tests for lowrank_multivariate_normal
3313  for i in range(0, n):
3314  loc = [torch.randn(4) for _ in range(0, 2)]
3315  cov_factor = [torch.randn(4, 3) for _ in range(0, 2)]
3316  cov_diag = [transform_to(constraints.positive)(torch.randn(4)) for _ in range(0, 2)]
3317  covariance_matrix = [cov_factor[i].matmul(cov_factor[i].t()) +
3318  cov_diag[i].diag() for i in range(0, 2)]
3319  p = LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0])
3320  q = LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1])
3321  p_full = MultivariateNormal(loc[0], covariance_matrix[0])
3322  q_full = MultivariateNormal(loc[1], covariance_matrix[1])
3323  expected = kl_divergence(p_full, q_full)
3324 
3325  actual_lowrank_lowrank = kl_divergence(p, q)
3326  actual_lowrank_full = kl_divergence(p, q_full)
3327  actual_full_lowrank = kl_divergence(p_full, q)
3328 
3329  error_lowrank_lowrank = torch.abs(actual_lowrank_lowrank - expected).max()
3330  self.assertLess(error_lowrank_lowrank, self.precision, '\n'.join([
3331  'Incorrect KL(LowRankMultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n),
3332  'Expected (from KL MultivariateNormal): {}'.format(expected),
3333  'Actual (analytic): {}'.format(actual_lowrank_lowrank),
3334  ]))
3335 
3336  error_lowrank_full = torch.abs(actual_lowrank_full - expected).max()
3337  self.assertLess(error_lowrank_full, self.precision, '\n'.join([
3338  'Incorrect KL(LowRankMultivariateNormal, MultivariateNormal) instance {}/{}'.format(i + 1, n),
3339  'Expected (from KL MultivariateNormal): {}'.format(expected),
3340  'Actual (analytic): {}'.format(actual_lowrank_full),
3341  ]))
3342 
3343  error_full_lowrank = torch.abs(actual_full_lowrank - expected).max()
3344  self.assertLess(error_full_lowrank, self.precision, '\n'.join([
3345  'Incorrect KL(MultivariateNormal, LowRankMultivariateNormal) instance {}/{}'.format(i + 1, n),
3346  'Expected (from KL MultivariateNormal): {}'.format(expected),
3347  'Actual (analytic): {}'.format(actual_full_lowrank),
3348  ]))
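    # LowRankMultivariateNormal parameterizes the covariance as
    # cov_factor @ cov_factor.T + diag(cov_diag); the dense matrix built above
    # makes MultivariateNormal's analytic KL an independent reference for all
    # three low-rank/full pairings. The construction in brief:
    #     W, d = torch.randn(4, 3), torch.rand(4) + 0.1
    #     cov = W.matmul(W.t()) + d.diag()  # dense covariance for p_full/q_full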
3349 
3350  def test_kl_lowrank_multivariate_normal_batched(self):
3351  b = 7 # Number of batches
3352  loc = [torch.randn(b, 3) for _ in range(0, 2)]
3353  cov_factor = [torch.randn(b, 3, 2) for _ in range(0, 2)]
3354  cov_diag = [transform_to(constraints.positive)(torch.randn(b, 3)) for _ in range(0, 2)]
3355  expected_kl = torch.stack([
3356  kl_divergence(LowRankMultivariateNormal(loc[0][i], cov_factor[0][i], cov_diag[0][i]),
3357  LowRankMultivariateNormal(loc[1][i], cov_factor[1][i], cov_diag[1][i]))
3358  for i in range(0, b)])
3359  actual_kl = kl_divergence(LowRankMultivariateNormal(loc[0], cov_factor[0], cov_diag[0]),
3360  LowRankMultivariateNormal(loc[1], cov_factor[1], cov_diag[1]))
3361  self.assertEqual(expected_kl, actual_kl)
3362 
3363  def test_kl_exponential_family(self):
3364  for (p, _), (_, q) in self.finite_examples:
3365  if type(p) == type(q) and issubclass(type(p), ExponentialFamily):
3366  actual = kl_divergence(p, q)
3367  expected = _kl_expfamily_expfamily(p, q)
3368  self.assertEqual(actual, expected, message='\n'.join([
3369  'Incorrect KL({}, {}).'.format(type(p).__name__, type(q).__name__),
3370  'Expected (using Bregman Divergence) {}'.format(expected),
3371  'Actual (analytic) {}'.format(actual),
3372  'max error = {}'.format(torch.abs(actual - expected).max())
3373  ]))
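    # For two members of the same exponential family with natural parameters
    # eta_p, eta_q and log-normalizer A, KL is the Bregman divergence of A:
    #     KL(p || q) = A(eta_q) - A(eta_p) - <grad A(eta_p), eta_q - eta_p>,
    # which _kl_expfamily_expfamily evaluates, giving a reference independent
    # of each distribution's hand-written KL formula.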
3374 
3375  def test_kl_infinite(self):
3376  for p, q in self.infinite_examples:
3377  self.assertTrue((kl_divergence(p, q) == inf).all(),
3378  'Incorrect KL({}, {})'.format(type(p).__name__, type(q).__name__))
3379 
3380  def test_kl_edgecases(self):
3381  self.assertEqual(kl_divergence(Bernoulli(0), Bernoulli(0)), 0)
3382  self.assertEqual(kl_divergence(Bernoulli(1), Bernoulli(1)), 0)
3383  self.assertEqual(kl_divergence(Categorical(torch.tensor([0., 1.])), Categorical(torch.tensor([0., 1.]))), 0)
3384 
3385  def test_kl_shape(self):
3386  for Dist, params in EXAMPLES:
3387  for i, param in enumerate(params):
3388  dist = Dist(**param)
3389  try:
3390  kl = kl_divergence(dist, dist)
3391  except NotImplementedError:
3392  continue
3393  expected_shape = dist.batch_shape if dist.batch_shape else torch.Size()
3394  self.assertEqual(kl.shape, expected_shape, message='\n'.join([
3395  '{} example {}/{}'.format(Dist.__name__, i + 1, len(params)),
3396  'Expected {}'.format(expected_shape),
3397  'Actual {}'.format(kl.shape),
3398  ]))
3399 
3400  def test_entropy_monte_carlo(self):
3401  set_rng_seed(0) # see Note [Randomized statistical tests]
3402  for Dist, params in EXAMPLES:
3403  for i, param in enumerate(params):
3404  dist = Dist(**param)
3405  try:
3406  actual = dist.entropy()
3407  except NotImplementedError:
3408  continue
3409  x = dist.sample(sample_shape=(60000,))
3410  expected = -dist.log_prob(x).mean(0)
3411  ignore = (expected == inf) | (expected == -inf)
3412  expected[ignore] = actual[ignore]
3413  self.assertEqual(actual, expected, prec=0.2, message='\n'.join([
3414  '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
3415  'Expected (monte carlo) {}'.format(expected),
3416  'Actual (analytic) {}'.format(actual),
3417  'max error = {}'.format(torch.abs(actual - expected).max()),
3418  ]))
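    # Same Monte Carlo idea as the KL tests: H(p) = -E_p[log p(x)], estimated
    # from 60000 samples and compared at the loose prec=0.2. Entries whose
    # empirical average is +/-inf are masked out, since the sample estimator is
    # meaningless there.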
3419 
3420  def test_entropy_exponential_family(self):
3421  for Dist, params in EXAMPLES:
3422  if not issubclass(Dist, ExponentialFamily):
3423  continue
3424  for i, param in enumerate(params):
3425  dist = Dist(**param)
3426  try:
3427  actual = dist.entropy()
3428  except NotImplementedError:
3429  continue
3430  try:
3431  expected = ExponentialFamily.entropy(dist)
3432  except NotImplementedError:
3433  continue
3434  self.assertEqual(actual, expected, message='\n'.join([
3435  '{} example {}/{}, incorrect .entropy().'.format(Dist.__name__, i + 1, len(params)),
3436  'Expected (Bregman Divergence) {}'.format(expected),
3437  'Actual (analytic) {}'.format(actual),
3438  'max error = {}'.format(torch.abs(actual - expected).max())
3439  ]))
3440 
3441 
3442 class TestConstraints(TestCase):
3443  def test_params_contains(self):
3444  for Dist, params in EXAMPLES:
3445  for i, param in enumerate(params):
3446  dist = Dist(**param)
3447  for name, value in param.items():
3448  if isinstance(value, numbers.Number):
3449  value = torch.tensor([value])
3450  if Dist in (Categorical, OneHotCategorical, Multinomial) and name == 'probs':
3451  # These distributions accept positive probs, but elsewhere we
3452  # use a stricter constraint to the simplex.
3453  value = value / value.sum(-1, True)
3454  try:
3455  constraint = dist.arg_constraints[name]
3456  except KeyError:
3457  continue # ignore optional parameters
3458 
3459  if is_dependent(constraint):
3460  continue
3461 
3462  message = '{} example {}/{} parameter {} = {}'.format(
3463  Dist.__name__, i + 1, len(params), name, value)
3464  self.assertTrue(constraint.check(value).all(), msg=message)
3465 
3466  def test_support_contains(self):
3467  for Dist, params in EXAMPLES:
3468  self.assertIsInstance(Dist.support, Constraint)
3469  for i, param in enumerate(params):
3470  dist = Dist(**param)
3471  value = dist.sample()
3472  constraint = dist.support
3473  message = '{} example {}/{} sample = {}'.format(
3474  Dist.__name__, i + 1, len(params), value)
3475  self.assertTrue(constraint.check(value).all(), msg=message)
3476 
3477 
3478 class TestNumericalStability(TestCase):
3479  def _test_pdf_score(self,
3480  dist_class,
3481  x,
3482  expected_value,
3483  probs=None,
3484  logits=None,
3485  expected_gradient=None,
3486  prec=1e-5):
3487  if probs is not None:
3488  p = probs.detach().requires_grad_()
3489  dist = dist_class(p)
3490  else:
3491  p = logits.detach().requires_grad_()
3492  dist = dist_class(logits=p)
3493  log_pdf = dist.log_prob(x)
3494  log_pdf.sum().backward()
3495  self.assertEqual(log_pdf,
3496  expected_value,
3497  prec=prec,
3498  message='Incorrect value for tensor type: {}. Expected = {}, Actual = {}'
3499  .format(type(x), expected_value, log_pdf))
3500  if expected_gradient is not None:
3501  self.assertEqual(p.grad,
3502  expected_gradient,
3503  prec=prec,
3504  message='Incorrect gradient for tensor type: {}. Expected = {}, Actual = {}'
3505  .format(type(x), expected_gradient, p.grad))
3506 
3507  def test_bernoulli_gradient(self):
3508  for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
3509  self._test_pdf_score(dist_class=Bernoulli,
3510  probs=tensor_type([0]),
3511  x=tensor_type([0]),
3512  expected_value=tensor_type([0]),
3513  expected_gradient=tensor_type([0]))
3514 
3515  self._test_pdf_score(dist_class=Bernoulli,
3516  probs=tensor_type([0]),
3517  x=tensor_type([1]),
3518  expected_value=tensor_type([torch.finfo(tensor_type([]).dtype).eps]).log(),
3519  expected_gradient=tensor_type([0]))
3520 
3521  self._test_pdf_score(dist_class=Bernoulli,
3522  probs=tensor_type([1e-4]),
3523  x=tensor_type([1]),
3524  expected_value=tensor_type([math.log(1e-4)]),
3525  expected_gradient=tensor_type([10000]))
3526 
3527  # Lower precision due to:
3528  # >>> 1 / (1 - torch.FloatTensor([0.9999]))
3529  # 9998.3408
3530  # [torch.FloatTensor of size 1]
3531  self._test_pdf_score(dist_class=Bernoulli,
3532  probs=tensor_type([1 - 1e-4]),
3533  x=tensor_type([0]),
3534  expected_value=tensor_type([math.log(1e-4)]),
3535  expected_gradient=tensor_type([-10000]),
3536  prec=2)
3537 
3538  self._test_pdf_score(dist_class=Bernoulli,
3539  logits=tensor_type([math.log(9999)]),
3540  x=tensor_type([0]),
3541  expected_value=tensor_type([math.log(1e-4)]),
3542  expected_gradient=tensor_type([-1]),
3543  prec=1e-3)
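    # With logits, log p(x=0) = -softplus(logits), whose derivative is
    # -sigmoid(logits) -> -1 as logits -> +inf, hence expected_gradient = -1
    # for logits = log(9999). A minimal check (illustrative):
    #     l = torch.tensor(math.log(9999.), requires_grad=True)
    #     Bernoulli(logits=l).log_prob(torch.tensor(0.)).backward()
    #     # l.grad is approximately -0.9999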
3544 
3545  def test_bernoulli_with_logits_underflow(self):
3546  for tensor_type, lim in ([(torch.FloatTensor, -1e38),
3547  (torch.DoubleTensor, -1e308)]):
3548  self._test_pdf_score(dist_class=Bernoulli,
3549  logits=tensor_type([lim]),
3550  x=tensor_type([0]),
3551  expected_value=tensor_type([0]),
3552  expected_gradient=tensor_type([0]))
3553 
3554  def test_bernoulli_with_logits_overflow(self):
3555  for tensor_type, lim in ([(torch.FloatTensor, 1e38),
3556  (torch.DoubleTensor, 1e308)]):
3557  self._test_pdf_score(dist_class=Bernoulli,
3558  logits=tensor_type([lim]),
3559  x=tensor_type([1]),
3560  expected_value=tensor_type([0]),
3561  expected_gradient=tensor_type([0]))
3562 
3563  def test_categorical_log_prob(self):
3564  for dtype in ([torch.float, torch.double]):
3565  p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
3566  categorical = OneHotCategorical(p)
3567  log_pdf = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
3568  self.assertEqual(log_pdf.item(), 0)
3569 
3570  def test_categorical_log_prob_with_logits(self):
3571  for dtype in ([torch.float, torch.double]):
3572  p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
3573  categorical = OneHotCategorical(logits=p)
3574  log_pdf_prob_1 = categorical.log_prob(torch.tensor([0, 1], dtype=dtype))
3575  self.assertEqual(log_pdf_prob_1.item(), 0)
3576  log_pdf_prob_0 = categorical.log_prob(torch.tensor([1, 0], dtype=dtype))
3577  self.assertEqual(log_pdf_prob_0.item(), -inf, allow_inf=True)
3578 
3579  def test_multinomial_log_prob(self):
3580  for dtype in ([torch.float, torch.double]):
3581  p = torch.tensor([0, 1], dtype=dtype, requires_grad=True)
3582  s = torch.tensor([0, 10], dtype=dtype)
3583  multinomial = Multinomial(10, p)
3584  log_pdf = multinomial.log_prob(s)
3585  self.assertEqual(log_pdf.item(), 0)
3586 
3587  def test_multinomial_log_prob_with_logits(self):
3588  for dtype in ([torch.float, torch.double]):
3589  p = torch.tensor([-inf, 0], dtype=dtype, requires_grad=True)
3590  multinomial = Multinomial(10, logits=p)
3591  log_pdf_prob_1 = multinomial.log_prob(torch.tensor([0, 10], dtype=dtype))
3592  self.assertEqual(log_pdf_prob_1.item(), 0)
3593  log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
3594  self.assertEqual(log_pdf_prob_0.item(), -inf, allow_inf=True)
3595 
3596 
3597 class TestLazyLogitsInitialization(TestCase):
3598  def setUp(self):
3599  self.examples = [e for e in EXAMPLES if e.Dist in
3600  (Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)]
3601 
3602  def test_lazy_logits_initialization(self):
3603  for Dist, params in self.examples:
3604  param = params[0]
3605  if 'probs' in param:
3606  probs = param.pop('probs')
3607  param['logits'] = probs_to_logits(probs)
3608  dist = Dist(**param)
3609  shape = (1,) if not dist.event_shape else dist.event_shape
3610  dist.log_prob(torch.ones(shape))
3611  message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
3612  self.assertFalse('probs' in vars(dist), msg=message)
3613  try:
3614  dist.enumerate_support()
3615  except NotImplementedError:
3616  pass
3617  self.assertFalse('probs' in vars(dist), msg=message)
3618  batch_shape, event_shape = dist.batch_shape, dist.event_shape
3619  self.assertFalse('probs' in vars(dist), msg=message)
3620 
3621  def test_lazy_probs_initialization(self):
3622  for Dist, params in self.examples:
3623  param = params[0]
3624  if 'probs' in param:
3625  dist = Dist(**param)
3626  dist.sample()
3627  message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
3628  self.assertFalse('logits' in vars(dist), msg=message)
3629  try:
3630  dist.enumerate_support()
3631  except NotImplementedError:
3632  pass
3633  self.assertFalse('logits' in vars(dist), msg=message)
3634  batch_shape, event_shape = dist.batch_shape, dist.event_shape
3635  self.assertFalse('logits' in vars(dist), msg=message)
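    # Both tests exercise lazy_property: whichever of probs/logits was not
    # passed to the constructor is computed on first access and cached, so it
    # must stay out of vars(dist) until something reads it. A minimal sketch of
    # the mechanism (illustrative, not the library's implementation):
    #     class lazy(object):
    #         def __init__(self, fn):
    #             self.fn = fn
    #         def __get__(self, obj, cls):
    #             value = self.fn(obj)
    #             setattr(obj, self.fn.__name__, value)  # cache in vars(obj)
    #             return value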
3636 
3637 
3638 @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
3639 class TestAgainstScipy(TestCase):
3640  def setUp(self):
3641  set_rng_seed(0)
3642  positive_var = torch.randn(20).exp()
3643  positive_var2 = torch.randn(20).exp()
3644  random_var = torch.randn(20)
3645  simplex_tensor = softmax(torch.randn(20), dim=-1)
3646  self.distribution_pairs = [
3647  (
3648  Bernoulli(simplex_tensor),
3649  scipy.stats.bernoulli(simplex_tensor)
3650  ),
3651  (
3652  Beta(positive_var, positive_var2),
3653  scipy.stats.beta(positive_var, positive_var2)
3654  ),
3655  (
3656  Binomial(10, simplex_tensor),
3657  scipy.stats.binom(10 * np.ones(simplex_tensor.shape), simplex_tensor.numpy())
3658  ),
3659  (
3660  Cauchy(random_var, positive_var),
3661  scipy.stats.cauchy(loc=random_var, scale=positive_var)
3662  ),
3663  (
3664  Dirichlet(positive_var),
3665  scipy.stats.dirichlet(positive_var)
3666  ),
3667  (
3668  Exponential(positive_var),
3669  scipy.stats.expon(scale=positive_var.reciprocal())
3670  ),
3671  (
3672  FisherSnedecor(positive_var, 4 + positive_var2), # var for df2<=4 is undefined
3673  scipy.stats.f(positive_var, 4 + positive_var2)
3674  ),
3675  (
3676  Gamma(positive_var, positive_var2),
3677  scipy.stats.gamma(positive_var, scale=positive_var2.reciprocal())
3678  ),
3679  (
3680  Geometric(simplex_tensor),
3681  scipy.stats.geom(simplex_tensor, loc=-1)
3682  ),
3683  (
3684  Gumbel(random_var, positive_var2),
3685  scipy.stats.gumbel_r(random_var, positive_var2)
3686  ),
3687  (
3688  HalfCauchy(positive_var),
3689  scipy.stats.halfcauchy(scale=positive_var)
3690  ),
3691  (
3692  HalfNormal(positive_var2),
3693  scipy.stats.halfnorm(scale=positive_var2)
3694  ),
3695  (
3696  Laplace(random_var, positive_var2),
3697  scipy.stats.laplace(random_var, positive_var2)
3698  ),
3699  (
3700  # Tests fail 1e-5 threshold if scale > 3
3701  LogNormal(random_var, positive_var.clamp(max=3)),
3702  scipy.stats.lognorm(s=positive_var.clamp(max=3), scale=random_var.exp())
3703  ),
3704  (
3705  LowRankMultivariateNormal(random_var, torch.zeros(20, 1), positive_var2),
3706  scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2))
3707  ),
3708  (
3709  Multinomial(10, simplex_tensor),
3710  scipy.stats.multinomial(10, simplex_tensor)
3711  ),
3712  (
3713  MultivariateNormal(random_var, torch.diag(positive_var2)),
3714  scipy.stats.multivariate_normal(random_var, torch.diag(positive_var2))
3715  ),
3716  (
3717  Normal(random_var, positive_var2),
3718  scipy.stats.norm(random_var, positive_var2)
3719  ),
3720  (
3721  OneHotCategorical(simplex_tensor),
3722  scipy.stats.multinomial(1, simplex_tensor)
3723  ),
3724  (
3725  Pareto(positive_var, 2 + positive_var2),
3726  scipy.stats.pareto(2 + positive_var2, scale=positive_var)
3727  ),
3728  (
3729  Poisson(positive_var),
3730  scipy.stats.poisson(positive_var)
3731  ),
3732  (
3733  StudentT(2 + positive_var, random_var, positive_var2),
3734  scipy.stats.t(2 + positive_var, random_var, positive_var2)
3735  ),
3736  (
3737  Uniform(random_var, random_var + positive_var),
3738  scipy.stats.uniform(random_var, positive_var)
3739  ),
3740  (
3741  Weibull(positive_var[0], positive_var2[0]), # scipy var for Weibull only supports scalars
3742  scipy.stats.weibull_min(c=positive_var2[0], scale=positive_var[0])
3743  )
3744  ]
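    # scipy's parameterizations differ from torch's in several places, hence
    # the translations above: rate becomes scale=1/rate (Exponential, Gamma),
    # Geometric counts failures rather than trials (loc=-1), lognorm takes
    # s=sigma with scale=exp(mu), and torch's Pareto(scale, alpha) maps to
    # scipy.stats.pareto(alpha, scale=scale).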
3745 
3746  def test_mean(self):
3747  for pytorch_dist, scipy_dist in self.distribution_pairs:
3748  if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
3749  # Cauchy, HalfCauchy distributions' mean is nan, skipping check
3750  continue
3751  elif isinstance(pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)):
3752  self.assertEqual(pytorch_dist.mean, scipy_dist.mean, allow_inf=True, message=pytorch_dist)
3753  else:
3754  self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)
3755 
3756  def test_variance_stddev(self):
3757  for pytorch_dist, scipy_dist in self.distribution_pairs:
3758  if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
3759  # Cauchy, HalfCauchy distributions' standard deviation is nan, skipping check
3760  continue
3761  elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
3762  self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
3763  self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov()) ** 0.5, message=pytorch_dist)
3764  elif isinstance(pytorch_dist, (LowRankMultivariateNormal, MultivariateNormal)):
3765  self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov), message=pytorch_dist)
3766  self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov) ** 0.5, message=pytorch_dist)
3767  else:
3768  self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
3769  self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)
3770 
3771  def test_cdf(self):
3772  for pytorch_dist, scipy_dist in self.distribution_pairs:
3773  samples = pytorch_dist.sample((5,))
3774  try:
3775  cdf = pytorch_dist.cdf(samples)
3776  except NotImplementedError:
3777  continue
3778  self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)
3779 
3780  def test_icdf(self):
3781  for pytorch_dist, scipy_dist in self.distribution_pairs:
3782  samples = torch.rand((5,) + pytorch_dist.batch_shape)
3783  try:
3784  icdf = pytorch_dist.icdf(samples)
3785  except NotImplementedError:
3786  continue
3787  self.assertEqual(icdf, scipy_dist.ppf(samples), message=pytorch_dist)
3788 
3789 
3790 class TestTransforms(TestCase):
3791  def setUp(self):
3792  self.transforms = []
3793  transforms_by_cache_size = {}
3794  for cache_size in [0, 1]:
3795  transforms = [
3796  AbsTransform(cache_size=cache_size),
3797  ExpTransform(cache_size=cache_size),
3798  PowerTransform(exponent=2,
3799  cache_size=cache_size),
3800  PowerTransform(exponent=torch.tensor(5.).normal_(),
3801  cache_size=cache_size),
3802  SigmoidTransform(cache_size=cache_size),
3803  AffineTransform(0, 1, cache_size=cache_size),
3804  AffineTransform(1, -2, cache_size=cache_size),
3805  AffineTransform(torch.randn(5),
3806  torch.randn(5),
3807  cache_size=cache_size),
3808  AffineTransform(torch.randn(4, 5),
3809  torch.randn(4, 5),
3810  cache_size=cache_size),
3811  SoftmaxTransform(cache_size=cache_size),
3812  StickBreakingTransform(cache_size=cache_size),
3813  LowerCholeskyTransform(cache_size=cache_size),
3814  ComposeTransform([
3815  AffineTransform(torch.randn(4, 5),
3816  torch.randn(4, 5),
3817  cache_size=cache_size),
3818  ]),
3819  ComposeTransform([
3820  AffineTransform(torch.randn(4, 5),
3821  torch.randn(4, 5),
3822  cache_size=cache_size),
3823  ExpTransform(cache_size=cache_size),
3824  ]),
3825  ComposeTransform([
3826  AffineTransform(0, 1, cache_size=cache_size),
3827  AffineTransform(torch.randn(4, 5),
3828  torch.randn(4, 5),
3829  cache_size=cache_size),
3830  AffineTransform(1, -2, cache_size=cache_size),
3831  AffineTransform(torch.randn(4, 5),
3832  torch.randn(4, 5),
3833  cache_size=cache_size),
3834  ]),
3835  ]
3836  for t in transforms[:]:
3837  transforms.append(t.inv)
3838  transforms.append(identity_transform)
3839  self.transforms += transforms
3840  if cache_size == 0:
3841  self.unique_transforms = transforms[:]
3842 
    def _generate_data(self, transform):
        domain = transform.domain
        codomain = transform.codomain
        x = torch.empty(4, 5)
        if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky:
            x = torch.empty(6, 6)
            x = x.normal_()
            return x
        elif domain is constraints.real:
            return x.normal_()
        elif domain is constraints.positive:
            return x.normal_().exp()
        elif domain is constraints.unit_interval:
            return x.uniform_()
        elif domain is constraints.simplex:
            x = x.normal_().exp()
            x /= x.sum(-1, keepdim=True)
            return x
        raise ValueError('Unsupported domain: {}'.format(domain))

    def test_inv_inv(self):
        for t in self.transforms:
            self.assertTrue(t.inv.inv is t)

    def test_equality(self):
        transforms = self.unique_transforms
        for x, y in product(transforms, transforms):
            if x is y:
                self.assertTrue(x == y)
                self.assertFalse(x != y)
            else:
                self.assertFalse(x == y)
                self.assertTrue(x != y)

        self.assertTrue(identity_transform == identity_transform.inv)
        self.assertFalse(identity_transform != identity_transform.inv)

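    # With cache_size=1, a transform memoizes its most recent (x, y) pair, so
    # transform.inv(y) can be answered from the cache even when no analytic
    # inverse is implemented. A minimal sketch (ExpTransform chosen purely
    # for illustration):
    #
    #   t = ExpTransform(cache_size=1)
    #   y = t(x)
    #   x2 = t.inv(y)  # should come straight from the cache, so x2 is x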
    def test_forward_inverse_cache(self):
        for transform in self.transforms:
            x = self._generate_data(transform).requires_grad_()
            try:
                y = transform(x)
            except NotImplementedError:
                continue
            x2 = transform.inv(y)  # should be implemented at least by caching
            y2 = transform(x2)  # should be implemented at least by caching
            if transform.bijective:
                # verify function inverse
                self.assertEqual(x2, x, message='\n'.join([
                    '{} t.inv(t(-)) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                ]))
            else:
                # verify weaker function pseudo-inverse
                self.assertEqual(y2, y, message='\n'.join([
                    '{} t(t.inv(t(-))) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                    'y2 = t(x2) = {}'.format(y2),
                ]))

    def test_forward_inverse_no_cache(self):
        for transform in self.transforms:
            x = self._generate_data(transform).requires_grad_()
            try:
                y = transform(x)
                x2 = transform.inv(y.clone())  # bypass cache
                y2 = transform(x2)
            except NotImplementedError:
                continue
            if transform.bijective:
                # verify function inverse
                self.assertEqual(x2, x, message='\n'.join([
                    '{} t.inv(t(-)) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                ]))
            else:
                # verify weaker function pseudo-inverse
                self.assertEqual(y2, y, message='\n'.join([
                    '{} t(t.inv(t(-))) error'.format(transform),
                    'x = {}'.format(x),
                    'y = t(x) = {}'.format(y),
                    'x2 = t.inv(y) = {}'.format(x2),
                    'y2 = t(x2) = {}'.format(y2),
                ]))

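    # For elementwise transforms (event_dim == 0) the Jacobian is diagonal,
    # so differentiating y.sum() with respect to x recovers dy/dx one element
    # at a time; log|dy/dx| can then be compared elementwise against
    # log_abs_det_jacobian().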
    def test_univariate_forward_jacobian(self):
        for transform in self.transforms:
            if transform.event_dim > 0:
                continue
            x = self._generate_data(transform).requires_grad_()
            try:
                y = transform(x)
                actual = transform.log_abs_det_jacobian(x, y)
            except NotImplementedError:
                continue
            expected = torch.abs(grad([y.sum()], [x])[0]).log()
            self.assertEqual(actual, expected, message='\n'.join([
                '{}.log_abs_det_jacobian() disagrees with autograd'.format(transform),
                'Expected: {}'.format(expected),
                'Actual: {}'.format(actual),
            ]))

    def test_univariate_inverse_jacobian(self):
        for transform in self.transforms:
            if transform.event_dim > 0:
                continue
            y = self._generate_data(transform.inv).requires_grad_()
            try:
                x = transform.inv(y)
                actual = transform.log_abs_det_jacobian(x, y)
            except NotImplementedError:
                continue
            expected = -torch.abs(grad([x.sum()], [y])[0]).log()
            self.assertEqual(actual, expected, message='\n'.join([
                '{}.log_abs_det_jacobian() disagrees with .inv()'.format(transform),
                'Expected: {}'.format(expected),
                'Actual: {}'.format(actual),
            ]))

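    # log_abs_det_jacobian() sums over the rightmost event_dim dimensions, so
    # its shape must equal the input shape with those dimensions removed.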
    def test_jacobian_shape(self):
        for transform in self.transforms:
            x = self._generate_data(transform)
            try:
                y = transform(x)
                actual = transform.log_abs_det_jacobian(x, y)
            except NotImplementedError:
                continue
            self.assertEqual(actual.shape, x.shape[:x.dim() - transform.event_dim])

    def test_transform_shapes(self):
        transform0 = ExpTransform()
        transform1 = SoftmaxTransform()
        transform2 = LowerCholeskyTransform()

        self.assertEqual(transform0.event_dim, 0)
        self.assertEqual(transform1.event_dim, 1)
        self.assertEqual(transform2.event_dim, 2)
        self.assertEqual(ComposeTransform([transform0, transform1]).event_dim, 1)
        self.assertEqual(ComposeTransform([transform0, transform2]).event_dim, 2)
        self.assertEqual(ComposeTransform([transform1, transform2]).event_dim, 2)

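    # A ComposeTransform takes the largest event_dim of its parts, and a
    # TransformedDistribution absorbs trailing batch dimensions into
    # event_shape as needed. For example, applying LowerCholeskyTransform
    # (event_dim=2) to a Normal with batch_shape (4, 4) should yield
    # batch_shape () and event_shape (4, 4), as the table below verifies.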
    def test_transformed_distribution_shapes(self):
        transform0 = ExpTransform()
        transform1 = SoftmaxTransform()
        transform2 = LowerCholeskyTransform()
        base_dist0 = Normal(torch.zeros(4, 4), torch.ones(4, 4))
        base_dist1 = Dirichlet(torch.ones(4, 4))
        base_dist2 = Normal(torch.zeros(3, 4, 4), torch.ones(3, 4, 4))
        examples = [
            ((4, 4), (), base_dist0),
            ((4,), (4,), base_dist1),
            ((4, 4), (), TransformedDistribution(base_dist0, [transform0])),
            ((4,), (4,), TransformedDistribution(base_dist0, [transform1])),
            ((4,), (4,), TransformedDistribution(base_dist0, [transform0, transform1])),
            ((), (4, 4), TransformedDistribution(base_dist0, [transform0, transform2])),
            ((4,), (4,), TransformedDistribution(base_dist0, [transform1, transform0])),
            ((), (4, 4), TransformedDistribution(base_dist0, [transform1, transform2])),
            ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform0])),
            ((), (4, 4), TransformedDistribution(base_dist0, [transform2, transform1])),
            ((4,), (4,), TransformedDistribution(base_dist1, [transform0])),
            ((4,), (4,), TransformedDistribution(base_dist1, [transform1])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform2])),
            ((4,), (4,), TransformedDistribution(base_dist1, [transform0, transform1])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform0, transform2])),
            ((4,), (4,), TransformedDistribution(base_dist1, [transform1, transform0])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform1, transform2])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform0])),
            ((), (4, 4), TransformedDistribution(base_dist1, [transform2, transform1])),
            ((3, 4, 4), (), base_dist2),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2])),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform0, transform2])),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform1, transform2])),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform0])),
            ((3,), (4, 4), TransformedDistribution(base_dist2, [transform2, transform1])),
        ]
        for batch_shape, event_shape, dist in examples:
            self.assertEqual(dist.batch_shape, batch_shape)
            self.assertEqual(dist.event_shape, event_shape)
            x = dist.rsample()
            try:
                dist.log_prob(x)  # this should not crash
            except NotImplementedError:
                continue

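    # The three JIT tests below trace each transform once with
    # torch.jit.trace and then replay the trace on freshly generated inputs;
    # a trace that silently baked in its first input would fail the
    # comparison against the eager result.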
    def test_jit_fwd(self):
        for transform in self.unique_transforms:
            x = self._generate_data(transform).requires_grad_()

            def f(x):
                return transform(x)

            try:
                traced_f = torch.jit.trace(f, (x,))
            except NotImplementedError:
                continue

            # check on different inputs
            x = self._generate_data(transform).requires_grad_()
            self.assertEqual(f(x), traced_f(x))

    def test_jit_inv(self):
        for transform in self.unique_transforms:
            y = self._generate_data(transform.inv).requires_grad_()

            def f(y):
                return transform.inv(y)

            try:
                traced_f = torch.jit.trace(f, (y,))
            except NotImplementedError:
                continue

            # check on different inputs
            y = self._generate_data(transform.inv).requires_grad_()
            self.assertEqual(f(y), traced_f(y))

    def test_jit_jacobian(self):
        for transform in self.unique_transforms:
            x = self._generate_data(transform).requires_grad_()

            def f(x):
                y = transform(x)
                return transform.log_abs_det_jacobian(x, y)

            try:
                traced_f = torch.jit.trace(f, (x,))
            except NotImplementedError:
                continue

            # check on different inputs
            x = self._generate_data(transform).requires_grad_()
            self.assertEqual(f(x), traced_f(x))


class TestConstraints(TestCase):
    def get_constraints(self, is_cuda=False):
        tensor = torch.cuda.DoubleTensor if is_cuda else torch.DoubleTensor
        return [
            constraints.real,
            constraints.positive,
            constraints.greater_than(tensor([-10., -2, 0, 2, 10])),
            constraints.greater_than(0),
            constraints.greater_than(2),
            constraints.greater_than(-2),
            constraints.greater_than_eq(0),
            constraints.greater_than_eq(2),
            constraints.greater_than_eq(-2),
            constraints.less_than(tensor([-10., -2, 0, 2, 10])),
            constraints.less_than(0),
            constraints.less_than(2),
            constraints.less_than(-2),
            constraints.unit_interval,
            constraints.interval(tensor([-4., -2, 0, 2, 4]),
                                 tensor([-3., 3, 1, 5, 5])),
            constraints.interval(-2, -1),
            constraints.interval(1, 2),
            constraints.half_open_interval(tensor([-4., -2, 0, 2, 4]),
                                           tensor([-3., 3, 1, 5, 5])),
            constraints.half_open_interval(-2, -1),
            constraints.half_open_interval(1, 2),
            constraints.simplex,
            constraints.lower_cholesky,
        ]

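    # biject_to(constraint) must return a true bijection from unconstrained
    # space onto the constrained set, whereas transform_to(constraint) only
    # needs a right inverse, i.e. t(t.inv(y)) == y. A minimal sketch
    # (positive reals chosen purely for illustration):
    #
    #   t = biject_to(constraints.positive)
    #   x = torch.randn(5)
    #   assert constraints.positive.check(t(x)).all()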
    def test_biject_to(self):
        for constraint in self.get_constraints():
            try:
                t = biject_to(constraint)
            except NotImplementedError:
                continue
            self.assertTrue(t.bijective, "biject_to({}) is not bijective".format(constraint))
            x = torch.randn(5, 5)
            y = t(x)
            self.assertTrue(constraint.check(y).all(), '\n'.join([
                "Failed to biject_to({})".format(constraint),
                "x = {}".format(x),
                "biject_to(...)(x) = {}".format(y),
            ]))
            x2 = t.inv(y)
            self.assertEqual(x, x2, message="Error in biject_to({}) inverse".format(constraint))

            j = t.log_abs_det_jacobian(x, y)
            self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim])

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_biject_to_cuda(self):
        for constraint in self.get_constraints(is_cuda=True):
            try:
                t = biject_to(constraint)
            except NotImplementedError:
                continue
            self.assertTrue(t.bijective, "biject_to({}) is not bijective".format(constraint))
            # x = torch.randn(5, 5, device="cuda")
            x = torch.randn(5, 5).cuda()
            y = t(x)
            self.assertTrue(constraint.check(y).all(), '\n'.join([
                "Failed to biject_to({})".format(constraint),
                "x = {}".format(x),
                "biject_to(...)(x) = {}".format(y),
            ]))
            x2 = t.inv(y)
            self.assertEqual(x, x2, message="Error in biject_to({}) inverse".format(constraint))

            j = t.log_abs_det_jacobian(x, y)
            self.assertEqual(j.shape, x.shape[:x.dim() - t.event_dim])

    def test_transform_to(self):
        for constraint in self.get_constraints():
            t = transform_to(constraint)
            x = torch.randn(5, 5)
            y = t(x)
            self.assertTrue(constraint.check(y).all(), "Failed to transform_to({})".format(constraint))
            x2 = t.inv(y)
            y2 = t(x2)
            self.assertEqual(y, y2, message="Error in transform_to({}) pseudoinverse".format(constraint))

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_transform_to_cuda(self):
        for constraint in self.get_constraints(is_cuda=True):
            t = transform_to(constraint)
            # x = torch.randn(5, 5, device="cuda")
            x = torch.randn(5, 5).cuda()
            y = t(x)
            self.assertTrue(constraint.check(y).all(), "Failed to transform_to({})".format(constraint))
            x2 = t.inv(y)
            y2 = t(x2)
            self.assertEqual(y, y2, message="Error in transform_to({}) pseudoinverse".format(constraint))


class TestValidation(TestCase):
    def setUp(self):
        super(TestValidation, self).setUp()
        Distribution.set_default_validate_args(True)

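    # setUp/tearDown toggle validation globally, so every constructor call in
    # these tests checks its arguments against Dist.arg_constraints; the
    # explicit validate_args=True below makes the intent local as well.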
    def test_valid(self):
        for Dist, params in EXAMPLES:
            for param in params:
                Dist(validate_args=True, **param)

    @unittest.skipIf(TEST_WITH_UBSAN, "division-by-zero error with UBSAN")
    def test_invalid(self):
        for Dist, params in BAD_EXAMPLES:
            for i, param in enumerate(params):
                try:
                    with self.assertRaises(ValueError):
                        Dist(validate_args=True, **param)
                except AssertionError:
                    fail_string = 'ValueError not raised for {} example {}/{}'
                    raise AssertionError(fail_string.format(Dist.__name__, i + 1, len(params)))

    def tearDown(self):
        super(TestValidation, self).tearDown()
        Distribution.set_default_validate_args(False)


class TestJit(TestCase):
    def _examples(self):
        for Dist, params in EXAMPLES:
            for param in params:
                keys = param.keys()
                values = tuple(param[key] for key in keys)
                if not all(isinstance(x, torch.Tensor) for x in values):
                    continue
                sample = Dist(**param).sample()
                yield Dist, keys, values, sample

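    # To re-check traced functions on fresh data, parameters are perturbed
    # inside their support: float tensors are pulled back to unconstrained
    # space with transform_to(constraint).inv, jittered with Gaussian noise,
    # and mapped forward again, so the perturbed value stays feasible.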
    def _perturb_tensor(self, value, constraint):
        if isinstance(constraint, constraints._IntegerGreaterThan):
            return value + 1
        if isinstance(constraint, constraints._PositiveDefinite):
            return value + torch.eye(value.shape[-1])
        if value.dtype in [torch.float, torch.double]:
            transform = transform_to(constraint)
            delta = value.new(value.shape).normal_()
            return transform(transform.inv(value) + delta)
        if value.dtype == torch.long:
            result = value.clone()
            result[value == 0] = 1
            result[value == 1] = 0
            return result
        raise NotImplementedError

    def _perturb(self, Dist, keys, values, sample):
        with torch.no_grad():
            if Dist is Uniform:
                param = dict(zip(keys, values))
                param['low'] = param['low'] - torch.rand(param['low'].shape)
                param['high'] = param['high'] + torch.rand(param['high'].shape)
                values = [param[key] for key in keys]
            else:
                values = [self._perturb_tensor(value, Dist.arg_constraints.get(key, constraints.real))
                          for key, value in zip(keys, values)]
        param = dict(zip(keys, values))
        sample = Dist(**param).sample()
        return values, sample

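    # Sampling is nondeterministic, so tracing uses check_trace=False; the
    # eager call is made under torch.random.fork_rng() (which restores the
    # RNG state on exit) so that the traced call replays from the identical
    # state and should produce the same sample.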
    def test_sample(self):
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.sample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
            ]
            if Dist in xfail:
                continue

            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(any(n.isNondeterministic() for n in traced_f.graph.nodes()))

    def test_rsample(self):
        for Dist, keys, values, sample in self._examples():
            if not Dist.has_rsample:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.rsample()

            traced_f = torch.jit.trace(f, values, check_trace=False)

            # FIXME Schema not found for node
            xfail = [
                Cauchy,  # aten::cauchy(Double(2,1), float, float, Generator)
                HalfCauchy,  # aten::cauchy(Double(2, 1), float, float, Generator)
            ]
            if Dist in xfail:
                continue

            with torch.random.fork_rng():
                sample = f(*values)
            traced_sample = traced_f(*values)
            self.assertEqual(sample, traced_sample)

            # FIXME no nondeterministic nodes found in trace
            xfail = [Beta, Dirichlet]
            if Dist not in xfail:
                self.assertTrue(any(n.isNondeterministic() for n in traced_f.graph.nodes()))

    def test_log_prob(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.log_prob(sample)

            traced_f = torch.jit.trace(f, (sample,) + values)

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(expected, actual,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_enumerate_support(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [Binomial]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.enumerate_support()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(expected, actual,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

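    # Some moments are legitimately infinite (e.g. the HalfCauchy mean), and
    # comparing infinities elementwise is unhelpful, so the next two tests
    # zero out infinite entries on both sides and compare only finite values.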
    def test_mean(self):
        for Dist, keys, values, sample in self._examples():

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.mean

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            expected[expected == float('inf')] = 0.
            actual[actual == float('inf')] = 0.
            self.assertEqual(expected, actual, allow_inf=True,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_variance(self):
        for Dist, keys, values, sample in self._examples():
            if Dist in [Cauchy, HalfCauchy]:
                continue  # infinite variance

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.variance

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            expected[expected == float('inf')] = 0.
            actual[actual == float('inf')] = 0.
            self.assertEqual(expected, actual, allow_inf=True,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_entropy(self):
        for Dist, keys, values, sample in self._examples():
            # FIXME traced functions produce incorrect results
            xfail = [LowRankMultivariateNormal, MultivariateNormal]
            if Dist in xfail:
                continue

            def f(*values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                return dist.entropy()

            try:
                traced_f = torch.jit.trace(f, values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(*values)
            actual = traced_f(*values)
            self.assertEqual(expected, actual, allow_inf=True,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))

    def test_cdf(self):
        for Dist, keys, values, sample in self._examples():

            def f(sample, *values):
                param = dict(zip(keys, values))
                dist = Dist(**param)
                cdf = dist.cdf(sample)
                return dist.icdf(cdf)

            try:
                traced_f = torch.jit.trace(f, (sample,) + values)
            except NotImplementedError:
                continue

            # check on different data
            values, sample = self._perturb(Dist, keys, values, sample)
            expected = f(sample, *values)
            actual = traced_f(sample, *values)
            self.assertEqual(expected, actual, allow_inf=True,
                             message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))


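# Several suites above rely on LAPACK routines (e.g. the Cholesky
# factorizations behind the MultivariateNormal cases), so the whole file is
# skipped when torch was built without LAPACK support.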
if __name__ == '__main__' and torch._C.has_lapack:
    run_tests()