1 from numbers
import Number
12 Creates a Geometric distribution parameterized by :attr:`probs`, 13 where :attr:`probs` is the probability of success of Bernoulli trials. 14 It represents the probability that in :math:`k + 1` Bernoulli trials, the 15 first :math:`k` trials failed, before seeing a success. 17 Samples are non-negative integers [0, :math:`\inf`). 21 >>> m = Geometric(torch.tensor([0.3])) 22 >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0 26 probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1] 27 logits (Number, Tensor): the log-odds of sampling `1`. 29 arg_constraints = {
'probs': constraints.unit_interval,
30 'logits': constraints.real}
31 support = constraints.nonnegative_integer
33 def __init__(self, probs=None, logits=None, validate_args=None):
34 if (probs
is None) == (logits
is None):
35 raise ValueError(
"Either `probs` or `logits` must be specified, but not both.")
37 self.
probs, = broadcast_all(probs)
38 if not self.probs.gt(0).all():
39 raise ValueError(
'All elements of probs must be greater than 0')
41 self.
logits, = broadcast_all(logits)
42 probs_or_logits = probs
if probs
is not None else logits
43 if isinstance(probs_or_logits, Number):
44 batch_shape = torch.Size()
46 batch_shape = probs_or_logits.size()
47 super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
49 def expand(self, batch_shape, _instance=None):
51 batch_shape = torch.Size(batch_shape)
52 if 'probs' in self.__dict__:
53 new.probs = self.probs.expand(batch_shape)
55 new.logits = self.logits.expand(batch_shape)
56 super(Geometric, new).__init__(batch_shape, validate_args=
False)
62 return 1. / self.
probs - 1.
70 return probs_to_logits(self.
probs, is_binary=
True)
74 return logits_to_probs(self.
logits, is_binary=
True)
76 def sample(self, sample_shape=torch.Size()):
78 tiny = torch.finfo(self.probs.dtype).tiny
80 if torch._C._get_tracing_state():
82 u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
85 u = self.probs.new(shape).uniform_(tiny, 1)
86 return (u.log() / (-self.
probs).log1p()).floor()
88 def log_prob(self, value):
91 value, probs = broadcast_all(value, self.probs.clone())
92 probs[(probs == 1) & (value == 0)] = 0
93 return value * (-probs).log1p() + self.probs.log()
96 return binary_cross_entropy_with_logits(self.
logits, self.
probs, reduction=
'none') / self.
probs
def _get_checked_instance(self, cls, _instance=None)
def _extended_shape(self, sample_shape=torch.Size())
def _validate_sample(self, value)