Caffe2 - Python API
A deep learning, cross-platform ML framework
function.py
import torch
import torch._C as _C
import torch.utils.hooks as hooks
from torch._six import with_metaclass
import functools
import warnings
from collections import OrderedDict


class _ContextMethodMixin(object):

    def save_for_backward(self, *tensors):
        r"""Saves given tensors for a future call to :func:`~Function.backward`.

        **This should be called at most once, and only from inside the**
        :func:`forward` **method.**

        Later, saved tensors can be accessed through the :attr:`saved_tensors`
        attribute. Before returning them to the user, a check is made to ensure
        they weren't used in any in-place operation that modified their content.

        Arguments can also be ``None``.
        """
        self.to_save = tensors

    def mark_dirty(self, *args):
        r"""Marks given tensors as modified in an in-place operation.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be inputs.**

        Every tensor that's been modified in-place in a call to :func:`forward`
        should be given to this function, to ensure correctness of our checks.
        It doesn't matter whether the function is called before or after
        modification.
        """
        self.dirty_tensors = args

    def mark_shared_storage(self, *pairs):
        warnings.warn(
            'mark_shared_storage is deprecated. '
            'Tensors with shared storages are automatically tracked. Note '
            'that calls to `set_()` are not tracked')

    def mark_non_differentiable(self, *args):
        r"""Marks outputs as non-differentiable.

        **This should be called at most once, only from inside the**
        :func:`forward` **method, and all arguments should be outputs.**

        This will mark outputs as not requiring gradients, increasing the
        efficiency of backward computation. You still need to accept a gradient
        for each output in :meth:`~Function.backward`, but it's always going to
        be a zero tensor with the same shape as the shape of a corresponding
        output.

        This is used e.g. for indices returned from a max :class:`Function`.
        """
        self.non_differentiable = args
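# --- Illustrative sketch (not part of the original file) ---
# A hypothetical Function showing how the ctx methods above are typically used:
# `save_for_backward` stashes tensors for the backward pass and
# `mark_non_differentiable` flags outputs (here, a boolean mask) that never
# receive a real gradient. Wrapped in a helper so nothing runs at import time.
def _example_ctx_methods():
    class MaskedRelu(Function):
        @staticmethod
        def forward(ctx, x):
            mask = x > 0
            ctx.save_for_backward(mask)
            ctx.mark_non_differentiable(mask)  # the mask output has no gradient
            return x.clamp(min=0), mask

        @staticmethod
        def backward(ctx, grad_output, grad_mask):
            mask, = ctx.saved_tensors
            # grad_mask is always a zero tensor, because `mask` was marked
            # non-differentiable in forward.
            return grad_output * mask.type_as(grad_output)

    x = torch.randn(4, requires_grad=True)
    out, mask = MaskedRelu.apply(x)
    out.sum().backward()
    assert not mask.requires_grad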
class _HookMixin(object):

    @staticmethod
    def _register_hook(backward_hooks, hook):
        if backward_hooks is None:
            backward_hooks = OrderedDict()
        handle = hooks.RemovableHandle(backward_hooks)
        backward_hooks[handle.id] = hook
        return backward_hooks, handle
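# --- Illustrative sketch (not part of the original file) ---
# _register_hook stores each hook in an OrderedDict keyed by a RemovableHandle's
# id; removing the handle deletes that entry again. A direct demonstration:
def _example_hook_handle():
    backward_hooks, handle = _HookMixin._register_hook(None, lambda grad: grad)
    assert handle.id in backward_hooks
    handle.remove()
    assert handle.id not in backward_hooks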
class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin):
    _is_legacy = False

    def apply(self, *args):
        return self._forward_cls.backward(self, *args)


class FunctionMeta(type):
    """Function metaclass.

    This metaclass sets up the following properties:
        _is_legacy: True if forward is not defined as a static method.
        _backward_cls: The Function class corresponding to the differentiated
            version of this function (which is generated on the fly by this
            metaclass).
    """

    def __init__(cls, name, bases, attrs):
        for super_cls in cls.mro():
            forward = super_cls.__dict__.get('forward')
            if forward is not None:
                has_static_forward = isinstance(forward, staticmethod) or isinstance(forward, classmethod)
                break

        setattr(cls, '_is_legacy', not has_static_forward)

        # old-style functions
        if not has_static_forward:
            return super(FunctionMeta, cls).__init__(name, bases, attrs)

        backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
        setattr(cls, '_backward_cls', backward_fn)

        return super(FunctionMeta, cls).__init__(name, bases, attrs)
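# --- Illustrative sketch (not part of the original file) ---
# For a new-style Function (static `forward`), FunctionMeta generates a
# companion `<Name>Backward` class on the fly and attaches it as `_backward_cls`;
# the grad_fn recorded for the outputs is an instance of that generated class.
# `Square` below is a hypothetical example class.
def _example_function_meta():
    class Square(Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            return 2 * x * grad_output

    assert Square._is_legacy is False
    assert Square._backward_cls.__name__ == 'SquareBackward'

    x = torch.randn(3, requires_grad=True)
    y = Square.apply(x)
    assert isinstance(y.grad_fn, Square._backward_cls)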
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):
    r"""Records operation history and defines formulas for differentiating ops.

    Every operation performed on :class:`Tensor` s creates a new function
    object that performs the computation and records that it happened.
    The history is retained in the form of a DAG of functions, with edges
    denoting data dependencies (``input <- output``). Then, when backward is
    called, the graph is processed in topological order, by calling
    :func:`backward` methods of each :class:`Function` object, and passing
    returned gradients on to the next :class:`Function` s.

    Normally, the only way users interact with functions is by creating
    subclasses and defining new operations. This is the recommended way of
    extending torch.autograd.

    Each function object is meant to be used only once (in the forward pass).

    Examples::

        >>> class Exp(Function):
        >>>
        >>>     @staticmethod
        >>>     def forward(ctx, i):
        >>>         result = i.exp()
        >>>         ctx.save_for_backward(result)
        >>>         return result
        >>>
        >>>     @staticmethod
        >>>     def backward(ctx, grad_output):
        >>>         result, = ctx.saved_tensors
        >>>         return grad_output * result
    """

    # only for backward compatibility
    __call__ = _C._FunctionBase._do_forward

    # for the tracer
    is_traceable = False

    @staticmethod
    def forward(ctx, *args, **kwargs):
        r"""Performs the operation.

        This function is to be overridden by all subclasses.

        It must accept a context :attr:`ctx` as the first argument, followed by
        any number of arguments (tensors or other types).

        The context can be used to store tensors that can then be retrieved
        during the backward pass.
        """
        raise NotImplementedError

    @staticmethod
    def backward(ctx, *grad_outputs):
        r"""Defines a formula for differentiating the operation.

        This function is to be overridden by all subclasses.

        It must accept a context :attr:`ctx` as the first argument, followed by
        as many outputs as :func:`forward` returned, and it should return as
        many tensors as there were inputs to :func:`forward`. Each argument is
        the gradient w.r.t. the given output, and each returned value should be
        the gradient w.r.t. the corresponding input.

        The context can be used to retrieve tensors saved during the forward
        pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple
        of booleans representing whether each input needs gradient. E.g.,
        :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
        first input to :func:`forward` needs gradient computed w.r.t. the
        output.
        """
        raise NotImplementedError
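# --- Illustrative usage sketch (not part of the original file) ---
# Exercising the `Exp` example from the docstring above: new-style Functions are
# invoked through `.apply`, and `torch.autograd.gradcheck` can verify the
# backward formula numerically (double-precision inputs are assumed).
def _example_apply():
    class Exp(Function):
        @staticmethod
        def forward(ctx, i):
            result = i.exp()
            ctx.save_for_backward(result)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            result, = ctx.saved_tensors
            return grad_output * result

    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    y = Exp.apply(x)
    y.sum().backward()
    assert torch.allclose(x.grad, x.exp())
    assert torch.autograd.gradcheck(Exp.apply, (x,))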
def once_differentiable(fn):

    @functools.wraps(fn)
    def wrapper(ctx, *args):
        with torch.no_grad():
            outputs = fn(ctx, *args)

        if not torch.is_grad_enabled():
            return outputs

        # If any of the inputs have requires_grad=True, we force the outputs
        # to have requires_grad=True but point to a grad_fn which throws an
        # error message during (double) back-propagation.
        # XXX: this is only an approximation of requires_grad - there's no way
        # to figure out if fn didn't use ctx.saved_tensors and as a result
        # some Tensors might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad
                            for arg in args)
        if not requires_grad:
            return outputs

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        err_fn = torch._C._functions.DelayedError(
            b"trying to differentiate twice a function that was marked "
            b"with @once_differentiable", len(outputs))

        # Create aliases of each output that has requires_grad=True. We need
        # at least one of the inputs to err_fn to require grad so that the
        # output will have a grad_fn.
        def fake_requires_grad(var):
            if var is not None:
                var = var.detach()
                var.requires_grad = True
            return var

        return err_fn(*[fake_requires_grad(v) for v in outputs])
    return wrapper
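# --- Illustrative sketch (not part of the original file) ---
# A backward wrapped in @once_differentiable runs under no_grad, so it cannot be
# differentiated again; when the incoming gradient requires grad, the
# DelayedError node installed above turns an attempted double backward into a
# RuntimeError. `MyRelu` is a hypothetical example class.
def _example_once_differentiable():
    class MyRelu(Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x.clamp(min=0)

        @staticmethod
        @once_differentiable
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            return grad_output * (x > 0).type_as(grad_output)

    x = torch.randn(5, requires_grad=True)
    y = MyRelu.apply(x)
    v = torch.ones_like(y, requires_grad=True)
    grad_x, = torch.autograd.grad(y, x, grad_outputs=v, create_graph=True)
    try:
        grad_x.sum().backward()  # second differentiation is rejected
    except RuntimeError:
        pass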
def traceable(fn_cls):
    r"""Marks Function as traceable for the JIT.

    Traceable functions have additional restrictions - they can't pass any
    data-dependent values to backward (e.g. Prod passes the output, which makes
    it non-traceable), and their backward should be implemented entirely in
    terms of operations on autograd Tensors in all cases.

    DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE HANDLED
    WITH CARE (or can give incorrect results otherwise).
    """
    fn_cls.is_traceable = True
    return fn_cls


class InplaceFunction(Function):

    def __init__(self, inplace=False):
        super(InplaceFunction, self).__init__()
        self.inplace = inplace
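# --- Illustrative sketch (not part of the original file) ---
# InplaceFunction simply records an `inplace` flag for old-style functions. A
# hypothetical new-style equivalent takes the flag as a forward argument and
# reports the mutation through `mark_dirty`:
def _example_inplace():
    class AddOne(Function):
        @staticmethod
        def forward(ctx, x, inplace=False):
            if inplace:
                ctx.mark_dirty(x)
                return x.add_(1)
            return x + 1

        @staticmethod
        def backward(ctx, grad_output):
            # d(x + 1)/dx == 1; the non-tensor `inplace` argument gets no gradient.
            return grad_output, None

    x = torch.randn(3, requires_grad=True)
    y = AddOne.apply(x * 1, True)  # mutate a non-leaf copy in place
    y.sum().backward()
    assert torch.allclose(x.grad, torch.ones_like(x))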
def _nested_map(condition, fn, condition_msg=None):
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            return type(obj)(_map(x) for x in obj)
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _map
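# --- Illustrative sketch (not part of the original file) ---
# _nested_map builds a mapper that applies `fn` to every element matching
# `condition`, recursing through arbitrarily nested lists/tuples and passing
# None through untouched:
def _example_nested_map():
    double_tensors = _nested_map(lambda x: isinstance(x, torch.Tensor),
                                 lambda t: t * 2,
                                 condition_msg="Tensors")
    nested = (torch.ones(2), [torch.ones(3), None])
    doubled = double_tensors(nested)
    assert isinstance(doubled, tuple) and isinstance(doubled[1], list)
    assert bool(doubled[1][0].eq(2).all())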
def _jit_unwrap_structured(obj):
    if hasattr(obj, "_jit_unwrap"):
        return obj._jit_unwrap()
    return obj


def _iter_filter(condition, allow_unknown=False, condition_msg=None,
                 conversion=None):
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                for var in _iter(o):
                    yield var
        elif allow_unknown:
            yield obj
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _iter
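# --- Illustrative sketch (not part of the original file) ---
# _iter_filter builds a generator that walks nested lists/tuples and yields only
# the elements satisfying `condition` (None is skipped). `_iter_tensors`,
# defined below, is the most common instantiation:
def _example_iter_tensors():
    nested = (torch.ones(1), [torch.zeros(2), (torch.ones(3), None)])
    flat = list(_iter_tensors(nested))
    assert [t.numel() for t in flat] == [1, 2, 3]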
def _unflatten(input, proto):
    # unflatten a list or tuple input into a nested list/tuple structure
    # specified by proto
    def unflatten_helper(input, proto):
        res = []
        if hasattr(proto, "_jit_wrap"):
            return proto._jit_wrap(input)
        if not isinstance(proto, (list, tuple)):
            return input[0], input[1:]
        for e in proto:
            if e is None:
                res.append(e)
            else:
                res_e, input = unflatten_helper(input, e)
                res.append(res_e)
        return type(proto)(res), input

    return unflatten_helper(input, proto)[0]
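# --- Illustrative sketch (not part of the original file) ---
# _unflatten is the inverse of the flattening above: given a flat sequence and a
# `proto` describing the nesting, it rebuilds the nested list/tuple structure:
def _example_unflatten():
    proto = (torch.zeros(1), [torch.zeros(2), torch.zeros(3)])
    flat = tuple(_iter_tensors(proto))
    rebuilt = _unflatten(flat, proto)
    assert isinstance(rebuilt, tuple) and isinstance(rebuilt[1], list)
    assert rebuilt[1][1].shape == proto[1][1].shape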
_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
                                condition_msg="jit's Values or None")
_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors",
                             conversion=_jit_unwrap_structured)
_iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor),
                                        allow_unknown=True,
                                        condition_msg="Tensors (permissive)")
_iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor),
                                  condition_msg="Tensors or None")
_map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data,
                               condition_msg="Tensors")
class NestedIOFunction(Function):

    def _do_forward(self, *input):
        self._nested_input = input
        flat_input = tuple(_iter_tensors(input))
        flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
        nested_output = self._nested_output
        nested_tensors = _unflatten(flat_output, self._nested_output)
        return nested_tensors

    def _do_backward(self, gradients, retain_variables):
        self.retain_variables = retain_variables
        result = super(NestedIOFunction, self)._do_backward(gradients, retain_variables)
        if not retain_variables:
            del self._nested_output
            del self._to_save_nested
        return result

    def backward(self, *gradients):
        nested_gradients = _unflatten(gradients, self._nested_output)
        result = self.backward_extended(*nested_gradients)
        return tuple(_iter_None_tensors(result))

    __call__ = _do_forward

    def forward(self, *args):
        nested_tensors = _map_tensor_data(self._nested_input)
        result = self.forward_extended(*nested_tensors)
        del self._nested_input
        self._nested_output = result
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args):
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        flat_tensors = super(NestedIOFunction, self).saved_tensors
        return _unflatten(flat_tensors, self._to_save_nested)

    def mark_dirty(self, *args, **kwargs):
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

    def mark_non_differentiable(self, *args, **kwargs):
        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

    def forward_extended(self, *input):
        raise NotImplementedError

    def backward_extended(self, *grad_output):
        raise NotImplementedError
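# --- Illustrative sketch (not part of the original file) ---
# NestedIOFunction is a legacy, instance-style Function whose forward/backward
# exchange nested structures of tensors. `PairScale` below is a hypothetical
# subclass: the flat plumbing above calls forward_extended/backward_extended
# with the nested views of the inputs, outputs, and gradients.
def _example_nested_io():
    class PairScale(NestedIOFunction):
        def forward_extended(self, pair):
            a, b = pair
            return (a * 2, b * 3)

        def backward_extended(self, grad_a, grad_b):
            # gradients arrive with the same nesting as the forward outputs
            return (grad_a * 2, grad_b * 3)

    a = torch.randn(2, requires_grad=True)
    b = torch.randn(2, requires_grad=True)
    out_a, out_b = PairScale()((a, b))
    (out_a.sum() + out_b.sum()).backward()
    assert torch.allclose(a.grad, torch.full_like(a, 2))
    assert torch.allclose(b.grad, torch.full_like(b, 3))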