Caffe2 - Python API
A deep learning, cross-platform ML framework
utils.py
1 from functools import update_wrapper
2 from numbers import Number
3 import math
4 import torch
5 import torch.nn.functional as F
6 
7 
# Fallback used by broadcast_all when no tensor is available to copy a
# dtype/device from: build a fresh tensor in the global default dtype.
def _default_promotion(v):
    """Return ``v`` as a new tensor whose dtype is ``torch.get_default_dtype()``."""
    default_dtype = torch.get_default_dtype()
    return torch.tensor(v, dtype=default_dtype)
11 
12 
def broadcast_all(*values):
    r"""
    Broadcast a mix of tensors and Python numbers against each other.

    Rules:

    - ``torch.Tensor`` inputs are broadcast as described in
      :ref:`broadcasting-semantics`.
    - ``numbers.Number`` inputs are first turned into tensors via
      ``new_tensor`` of the first *tensor* found in ``values`` (so they share
      its dtype and device). If ``values`` contains no tensor at all, numbers
      become scalar tensors of ``torch.get_default_dtype()``.

    Args:
        values (list of `numbers.Number` or `torch.*Tensor`)

    Returns:
        The result of ``torch.broadcast_tensors`` applied to the promoted values.

    Raises:
        ValueError: if any of the values is not a `numbers.Number` or
            `torch.*Tensor` instance
    """
    for candidate in values:
        if not (torch.is_tensor(candidate) or isinstance(candidate, Number)):
            raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')

    tensors = [v for v in values if torch.is_tensor(v)]
    if len(tensors) != len(values):
        # At least one plain number is present: promote it using the first
        # tensor as a template, or the default-dtype fallback if none exists.
        promote = tensors[0].new_tensor if tensors else _default_promotion
        values = [v if torch.is_tensor(v) else promote(v) for v in values]
    return torch.broadcast_tensors(*values)
39 
40 
def _standard_normal(shape, dtype, device):
    r"""
    Draw a tensor of standard normal samples (mean 0, std 1).

    Args:
        shape: size of the tensor to draw.
        dtype (torch.dtype): dtype of the result.
        device (torch.device): device of the result.
    """
    tracing = torch._C._get_tracing_state()
    if not tracing:
        return torch.empty(shape, dtype=dtype, device=device).normal_()
    # [JIT WORKAROUND] lack of support for .normal_() while tracing: build
    # explicit mean/std tensors and use the functional torch.normal instead.
    mean = torch.zeros(shape, dtype=dtype, device=device)
    std = torch.ones(shape, dtype=dtype, device=device)
    return torch.normal(mean, std)
47 
48 
def _sum_rightmost(value, dim):
    r"""
    Sum out the ``dim`` rightmost dimensions of ``value``.

    The trailing ``dim`` dimensions are flattened into one and summed, so the
    result has shape ``value.shape[:-dim]``. With ``dim == 0`` the input
    object itself is returned unchanged.

    Args:
        value (Tensor): A tensor with ``.dim()`` of at least ``dim``.
        dim (int): The number of rightmost dims to sum out.
    """
    if dim != 0:
        leading = value.shape[:-dim]
        flattened = value.reshape(leading + (-1,))
        value = flattened.sum(-1)
    return value
61 
62 
def logits_to_probs(logits, is_binary=False):
    r"""
    Convert logits into probabilities.

    Args:
        logits (Tensor): With ``is_binary=True`` each entry is an independent
            log-odds value. Otherwise the values along the last dimension are
            the (possibly unnormalized) log probabilities of the events.
        is_binary (bool): ``True`` applies an elementwise sigmoid; ``False``
            applies a softmax over the last dimension.
    """
    return torch.sigmoid(logits) if is_binary else F.softmax(logits, dim=-1)
73 
74 
def clamp_probs(probs):
    r"""
    Clamp probabilities into ``[eps, 1 - eps]``, where ``eps`` is the machine
    epsilon of ``probs.dtype``; this keeps ``log(p)`` and ``log1p(-p)`` finite.

    Args:
        probs (Tensor): floating-point tensor of probabilities
            (``torch.finfo`` needs a floating dtype).
    """
    tiny = torch.finfo(probs.dtype).eps
    upper = 1 - tiny
    return probs.clamp(min=tiny, max=upper)
78 
79 
def probs_to_logits(probs, is_binary=False):
    r"""
    Convert probabilities into logits.

    Inputs are first clamped away from 0 and 1 with ``clamp_probs`` so the
    logarithms below stay finite.

    Args:
        probs (Tensor): With ``is_binary=True``, each entry is the probability
            of occurrence of the event indexed by ``1``. Otherwise the values
            along the last dimension are the probabilities of occurrence of
            each of the events.
        is_binary (bool): ``True`` returns log-odds ``log(p) - log1p(-p)``;
            ``False`` returns ``log(p)``.
    """
    p = clamp_probs(probs)
    log_p = torch.log(p)
    if not is_binary:
        return log_p
    return log_p - torch.log1p(-p)
91 
92 
class lazy_property(object):
    r"""
    Decorator that turns a method into a lazily computed, cached instance
    attribute.

    This is a non-data descriptor (it defines only ``__get__``). On the first
    access through an instance, it calls the wrapped method and stores the
    result on that instance under the method's ``__name__``. Because instance
    attributes take precedence over non-data descriptors, later accesses return
    the stored value without calling the method again.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        # Copy __name__, __doc__, etc. from the wrapped method onto the descriptor.
        update_wrapper(self, wrapped)

    def __get__(self, instance, obj_type=None):
        if instance is None:
            # Looked up on the class itself: return the descriptor.
            return self
        method = self.wrapped
        # Grad mode is forced on so the cached result records an autograd graph
        # even if the first access happens inside a no-grad context.
        with torch.enable_grad():
            computed = method(instance)
        setattr(instance, method.__name__, computed)
        return computed
def is_tensor(obj)
Definition: __init__.py:114