Caffe2 - Python API
A deep learning, cross-platform ML framework
uniform.py
1 from numbers import Number
2 
3 import torch
4 from torch.distributions import constraints
5 from torch.distributions.distribution import Distribution
6 from torch.distributions.utils import broadcast_all
7 
8 
class Uniform(Distribution):
    r"""
    Generates uniformly distributed random samples from the half-open interval
    ``[low, high)``.

    Example::

        >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
        >>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
        tensor([ 2.3418])

    Args:
        low (float or Tensor): lower range (inclusive).
        high (float or Tensor): upper range (exclusive).
    """
    # TODO allow (loc,scale) parameterization to allow independent constraints.
    # ``low`` and ``high`` constrain each other (low < high), so neither can be
    # checked independently; the ordering is checked in ``__init__`` when
    # argument validation is enabled.
    arg_constraints = {'low': constraints.dependent, 'high': constraints.dependent}
    has_rsample = True

    @property
    def mean(self):
        """Midpoint of the interval, ``(low + high) / 2``."""
        return (self.high + self.low) / 2

    @property
    def stddev(self):
        """Standard deviation, ``(high - low) / sqrt(12)``."""
        return (self.high - self.low) / 12**0.5

    @property
    def variance(self):
        """Variance, ``(high - low)**2 / 12``."""
        return (self.high - self.low).pow(2) / 12

    def __init__(self, low, high, validate_args=None):
        """Build the distribution from (broadcastable) ``low``/``high`` bounds.

        Raises:
            ValueError: if validation is enabled and ``low >= high`` anywhere.
        """
        self.low, self.high = broadcast_all(low, high)

        # Two plain numbers describe a single scalar distribution; otherwise
        # the broadcast parameter shape becomes the batch shape.
        if isinstance(low, Number) and isinstance(high, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.low.size()
        super(Uniform, self).__init__(batch_shape, validate_args=validate_args)

        if self._validate_args and not torch.lt(self.low, self.high).all():
            raise ValueError("Uniform is not defined when low>= high")

    def expand(self, batch_shape, _instance=None):
        """Return a ``Uniform`` whose ``low``/``high`` are expanded to ``batch_shape``.

        The parameters were already checked when ``self`` was built, so the
        base initializer runs with validation off and the original
        validation flag is copied over afterwards.
        """
        new = self._get_checked_instance(Uniform, _instance)
        batch_shape = torch.Size(batch_shape)
        new.low = self.low.expand(batch_shape)
        new.high = self.high.expand(batch_shape)
        super(Uniform, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        """Interval constraint between ``low`` and ``high`` used to validate samples."""
        return constraints.interval(self.low, self.high)

    def rsample(self, sample_shape=torch.Size()):
        """Draw a reparameterized sample as ``low + U * (high - low)``.

        ``U`` comes from ``torch.rand``, so gradients flow back to
        ``low`` and ``high``.
        """
        shape = self._extended_shape(sample_shape)
        rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device)
        return self.low + rand * (self.high - self.low)

    def log_prob(self, value):
        """Log density: ``-log(high - low)`` on ``[low, high)``, ``-inf`` elsewhere.

        ``value`` is expected to be a Tensor (``.ge``/``.lt`` are called on it).
        """
        if self._validate_args:
            self._validate_sample(value)
        # Indicator of value in [low, high); log(0) yields -inf outside it.
        lb = value.ge(self.low).type_as(self.low)
        ub = value.lt(self.high).type_as(self.low)
        return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)

    def cdf(self, value):
        """Cumulative distribution ``(value - low) / (high - low)`` clamped to ``[0, 1]``."""
        if self._validate_args:
            self._validate_sample(value)
        result = (value - self.low) / (self.high - self.low)
        return result.clamp(min=0, max=1)

    def icdf(self, value):
        """Inverse CDF: map a probability ``value`` to ``low + value * (high - low)``.

        Raises:
            ValueError: if validation is enabled and ``value`` lies outside ``[0, 1]``.
        """
        if self._validate_args:
            # The argument is a probability, not a point of the support, so it
            # must be checked against [0, 1] rather than [low, high].
            prob = torch.as_tensor(value)
            if not ((prob >= 0) & (prob <= 1)).all():
                raise ValueError("The value argument to icdf must be within [0, 1]")
        result = value * (self.high - self.low) + self.low
        return result

    def entropy(self):
        """Differential entropy, ``log(high - low)``."""
        return torch.log(self.high - self.low)
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())