2 from numbers
import Number
13 Creates a Student's t-distribution parameterized by degree of 14 freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`. 18 >>> m = StudentT(torch.tensor([2.0])) 19 >>> m.sample() # Student's t-distributed with degrees of freedom=2 23 df (float or Tensor): degrees of freedom 24 loc (float or Tensor): mean of the distribution 25 scale (float or Tensor): scale of the distribution 27 arg_constraints = {
'df': constraints.positive,
'loc': constraints.real,
'scale': constraints.positive}
28 support = constraints.real
40 m[self.df > 2] = self.
scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2)
41 m[(self.df <= 2) & (self.df > 1)] = inf
def __init__(self, df, loc=0., scale=1., validate_args=None):
    """Construct a Student's t-distribution.

    Args:
        df (float or Tensor): degrees of freedom; constrained to be positive.
        loc (float or Tensor): location parameter of the distribution.
        scale (float or Tensor): scale parameter; constrained to be positive.
        validate_args: passed through unchanged to the parent class
            constructor.
    """
    # Normalize the three parameters together via broadcast_all so they can
    # be combined elementwise; df's resulting size becomes the batch shape.
    df, loc, scale = broadcast_all(df, loc, scale)
    self.df = df
    self.loc = loc
    self.scale = scale
    # Chi-squared helper sharing the same degrees of freedom; rsample draws
    # from it to build the t-distributed sample.
    self._chi2 = Chi2(df)
    super(StudentT, self).__init__(df.size(), validate_args=validate_args)
51 def expand(self, batch_shape, _instance=None):
53 batch_shape = torch.Size(batch_shape)
54 new.df = self.df.expand(batch_shape)
55 new.loc = self.loc.expand(batch_shape)
56 new.scale = self.scale.expand(batch_shape)
57 new._chi2 = self._chi2.expand(batch_shape)
58 super(StudentT, new).__init__(batch_shape, validate_args=
False)
62 def rsample(self, sample_shape=torch.Size()):
71 X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
72 Z = self._chi2.rsample(sample_shape)
73 Y = X * torch.rsqrt(Z / self.df)
74 return self.loc + self.
scale * Y
76 def log_prob(self, value):
79 y = (value - self.loc) / self.
scale 80 Z = (self.scale.log() +
82 0.5 * math.log(math.pi) +
83 torch.lgamma(0.5 * self.df) -
84 torch.lgamma(0.5 * (self.df + 1.)))
85 return -0.5 * (self.df + 1.) * torch.log1p(y**2. / self.df) - Z
88 lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1))
89 return (self.scale.log() +
91 (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) +
92 0.5 * self.df.log() + lbeta)
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())
def _validate_sample(self, value)