import torch
import torch._C as _C
import torch.utils.hooks as hooks
import warnings

from collections import OrderedDict
def save_for_backward(self, *tensors):
    r"""Saves given tensors for a future call to :func:`~Function.backward`.

    **This should be called at most once, and only from inside the**
    :func:`forward` **method.**

    Later, saved tensors can be accessed through the :attr:`saved_tensors`
    attribute. Before returning them to the user, a check is made to ensure
    they weren't used in any in-place operation that modified their content.

    Arguments can also be ``None``.
    """
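# Example sketch: a hypothetical ``Mul`` Function (not defined in this module)
# that stores both inputs with ``ctx.save_for_backward`` and reads them back
# from ``ctx.saved_tensors`` in backward.
#
#     >>> class Mul(Function):
#     >>>     @staticmethod
#     >>>     def forward(ctx, a, b):
#     >>>         ctx.save_for_backward(a, b)
#     >>>         return a * b
#     >>>
#     >>>     @staticmethod
#     >>>     def backward(ctx, grad_output):
#     >>>         a, b = ctx.saved_tensors
#     >>>         return grad_output * b, grad_output * a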
def mark_dirty(self, *args):
    r"""Marks given tensors as modified in an in-place operation.

    **This should be called at most once, only from inside the**
    :func:`forward` **method, and all arguments should be inputs.**

    Every tensor that's been modified in-place in a call to :func:`forward`
    should be given to this function, to ensure correctness of our checks.
    It doesn't matter whether the function is called before or after
    modification.
    """

def mark_shared_storage(self, *pairs):
    warnings.warn(
        'mark_shared_storage is deprecated. '
        'Tensors with shared storages are automatically tracked. Note '
        'that calls to `set_()` are not tracked')
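# Example sketch: a hypothetical in-place Function (not defined in this module)
# that mutates its input and reports it via ``ctx.mark_dirty``. Note that the
# mutated tensor must not be a leaf that requires grad.
#
#     >>> class AddOneInplace(Function):
#     >>>     @staticmethod
#     >>>     def forward(ctx, x):
#     >>>         x.add_(1)          # modify the input in-place
#     >>>         ctx.mark_dirty(x)  # tell autograd that x was mutated
#     >>>         return x
#     >>>
#     >>>     @staticmethod
#     >>>     def backward(ctx, grad_output):
#     >>>         return grad_output  # d(x + 1)/dx == 1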
def mark_non_differentiable(self, *args):
    r"""Marks outputs as non-differentiable.

    **This should be called at most once, only from inside the**
    :func:`forward` **method, and all arguments should be outputs.**

    This will mark outputs as not requiring gradients, increasing the
    efficiency of backward computation. You still need to accept a gradient
    for each output in :meth:`~Function.backward`, but it's always going to
    be a zero tensor with the same shape as the shape of a corresponding
    output.

    This is used e.g. for indices returned from a max :class:`Function`.
    """
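# Example sketch: a hypothetical max-like Function (not defined in this module)
# that returns integer indices as a second output and marks them
# non-differentiable; backward still accepts a (zero) gradient for them.
#
#     >>> class MaxWithIndices(Function):
#     >>>     @staticmethod
#     >>>     def forward(ctx, x):
#     >>>         values, indices = x.max(dim=0)
#     >>>         ctx.mark_non_differentiable(indices)
#     >>>         ctx.save_for_backward(x, indices)
#     >>>         return values, indices
#     >>>
#     >>>     @staticmethod
#     >>>     def backward(ctx, grad_values, grad_indices):
#     >>>         x, indices = ctx.saved_tensors
#     >>>         grad_x = torch.zeros_like(x)
#     >>>         grad_x.scatter_(0, indices.unsqueeze(0), grad_values.unsqueeze(0))
#     >>>         return grad_x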
def _register_hook(backward_hooks, hook):
    if backward_hooks is None:
        backward_hooks = OrderedDict()
    handle = hooks.RemovableHandle(backward_hooks)
    backward_hooks[handle.id] = hook
    return backward_hooks, handle
def apply(self, *args):
    return self._forward_cls.backward(self, *args)
"""Function metaclass.

This metaclass sets up the following properties:
    _is_legacy: True if forward is not defined as a static method.
    _backward_cls: The Function class corresponding to the differentiated
        version of this function (which is generated on the fly by this
        metaclass).
"""

def __init__(cls, name, bases, attrs):
    for super_cls in cls.mro():
        forward = super_cls.__dict__.get('forward')
        if forward is not None:
            has_static_forward = (isinstance(forward, staticmethod) or
                                  isinstance(forward, classmethod))
            break

    setattr(cls, '_is_legacy', not has_static_forward)

    if not has_static_forward:
        return super(FunctionMeta, cls).__init__(name, bases, attrs)

    backward_fn = type(name + 'Backward', (BackwardCFunction,), {'_forward_cls': cls})
    setattr(cls, '_backward_cls', backward_fn)

    return super(FunctionMeta, cls).__init__(name, bases, attrs)
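# Example sketch: for a new-style subclass with a static ``forward``, the
# metaclass records that it is not legacy and generates a companion backward
# class. ``Square`` is a hypothetical illustration, not defined in this module.
#
#     >>> class Square(Function):
#     >>>     @staticmethod
#     >>>     def forward(ctx, x):
#     >>>         ctx.save_for_backward(x)
#     >>>         return x * x
#     >>>
#     >>>     @staticmethod
#     >>>     def backward(ctx, grad_output):
#     >>>         x, = ctx.saved_tensors
#     >>>         return 2 * x * grad_output
#     >>>
#     >>> Square._is_legacy
#     False
#     >>> Square._backward_cls.__name__
#     'SquareBackward'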
r"""Records operation history and defines formulas for differentiating ops.

Every operation performed on :class:`Tensor` s creates a new function
object, that performs the computation, and records that it happened.
The history is retained in the form of a DAG of functions, with edges
denoting data dependencies (``input <- output``). Then, when backward is
called, the graph is processed in the topological ordering, by calling
:func:`backward` methods of each :class:`Function` object, and passing
returned gradients on to the next :class:`Function` s.

Normally, the only way users interact with functions is by creating
subclasses and defining new operations. This is the recommended way of
extending torch.autograd.

Each function object is meant to be used only once (in the forward pass).

Examples::

    >>> class Exp(Function):
    >>>
    >>>     @staticmethod
    >>>     def forward(ctx, i):
    >>>         result = i.exp()
    >>>         ctx.save_for_backward(result)
    >>>         return result
    >>>
    >>>     @staticmethod
    >>>     def backward(ctx, grad_output):
    >>>         result, = ctx.saved_tensors
    >>>         return grad_output * result
"""

__call__ = _C._FunctionBase._do_forward
@staticmethod
def forward(ctx, *args, **kwargs):
    r"""Performs the operation.

    This function is to be overridden by all subclasses.

    It must accept a context ctx as the first argument, followed by any
    number of arguments (tensors or other types).

    The context can be used to store tensors that can then be retrieved
    during the backward pass.
    """
    raise NotImplementedError
@staticmethod
def backward(ctx, *grad_outputs):
    r"""Defines a formula for differentiating the operation.

    This function is to be overridden by all subclasses.

    It must accept a context :attr:`ctx` as the first argument, followed by
    as many outputs as :func:`forward` returned, and it should return as
    many tensors as there were inputs to :func:`forward`. Each argument is
    the gradient w.r.t. the given output, and each returned value should be
    the gradient w.r.t. the corresponding input.

    The context can be used to retrieve tensors saved during the forward
    pass. It also has an attribute :attr:`ctx.needs_input_grad`, a tuple
    of booleans representing whether each input needs gradient. E.g.,
    :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the
    first input to :func:`forward` needs gradient computed w.r.t. the
    output.
    """
    raise NotImplementedError
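# Example sketch: applying a custom Function and consulting
# ``ctx.needs_input_grad`` so gradients are only computed where required.
# ``ScaleAdd`` is a hypothetical illustration, not defined in this module.
#
#     >>> class ScaleAdd(Function):
#     >>>     @staticmethod
#     >>>     def forward(ctx, x, y, scale):
#     >>>         ctx.scale = scale
#     >>>         return scale * x + y
#     >>>
#     >>>     @staticmethod
#     >>>     def backward(ctx, grad_output):
#     >>>         grad_x = ctx.scale * grad_output if ctx.needs_input_grad[0] else None
#     >>>         grad_y = grad_output if ctx.needs_input_grad[1] else None
#     >>>         return grad_x, grad_y, None  # no gradient for the Python number `scale`
#     >>>
#     >>> x = torch.randn(3, requires_grad=True)
#     >>> y = torch.randn(3)
#     >>> ScaleAdd.apply(x, y, 2.0).sum().backward()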
def once_differentiable(fn):

    def wrapper(ctx, *args):
        with torch.no_grad():
            outputs = fn(ctx, *args)

        if not torch.is_grad_enabled():
            return outputs

        requires_grad = any(isinstance(arg, torch.Tensor) and arg.requires_grad
                            for arg in args)
        if not requires_grad:
            return outputs

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        err_fn = torch._C._functions.DelayedError(
            b"trying to differentiate twice a function that was marked "
            b"with @once_differentiable", len(outputs))

        def fake_requires_grad(var):
            if var is not None:
                var = var.detach()
                var.requires_grad = True
            return var

        return err_fn(*[fake_requires_grad(v) for v in outputs])
    return wrapper
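# Example sketch: decorating ``backward`` with ``@once_differentiable`` when the
# gradient is computed with non-autograd operations, so a second differentiation
# raises an error instead of silently returning wrong results. ``NumpyExp`` is a
# hypothetical illustration; it assumes numpy is imported as ``np`` and a CPU
# tensor input.
#
#     >>> class NumpyExp(Function):
#     >>>     @staticmethod
#     >>>     def forward(ctx, x):
#     >>>         result = torch.from_numpy(np.exp(x.detach().numpy()))
#     >>>         ctx.save_for_backward(result)
#     >>>         return result
#     >>>
#     >>>     @staticmethod
#     >>>     @once_differentiable
#     >>>     def backward(ctx, grad_output):
#     >>>         result, = ctx.saved_tensors
#     >>>         return grad_output * result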
def traceable(fn_cls):
    r"""Marks Function as traceable for the JIT.

    Traceable functions have additional restrictions - they can't pass any
    data-dependent values to backward (e.g. Prod passes the output, which
    makes it non-traceable), and their backward should be implemented
    entirely in terms of operations on autograd Tensors in all cases.

    DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE
    HANDLED WITH CARE (or can give incorrect results otherwise).
    """
    fn_cls.is_traceable = True
    return fn_cls
def __init__(self, inplace=False):
    super(InplaceFunction, self).__init__()
    self.inplace = inplace
def _nested_map(condition, fn, condition_msg=None):
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            return type(obj)(_map(x) for x in obj)
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _map
def _jit_unwrap_structured(obj):
    if hasattr(obj, "_jit_unwrap"):
        return obj._jit_unwrap()
    return obj
def _iter_filter(condition, allow_unknown=False, condition_msg=None,
                 conversion=None):
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                for var in _iter(o):
                    yield var
        elif allow_unknown:
            yield obj
        else:
            raise ValueError("Auto nesting doesn't know how to process "
                             "an input object of type " + torch.typename(obj) +
                             (". Accepted types: " + condition_msg +
                              ", or lists/tuples of them"
                              if condition_msg else ""))

    return _iter
def _unflatten(input, proto):
    # unflatten a flat sequence of tensors into the nested list/tuple
    # structure specified by proto
    def unflatten_helper(input, proto):
        res = []
        if hasattr(proto, "_jit_wrap"):
            return proto._jit_wrap(input)
        if not isinstance(proto, (list, tuple)):
            return input[0], input[1:]
        for e in proto:
            if e is None:
                res.append(e)
            else:
                res_e, input = unflatten_helper(input, e)
                res.append(res_e)
        return type(proto)(res), input

    return unflatten_helper(input, proto)[0]
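# Example sketch: how ``_unflatten`` rebuilds a nested structure from the flat
# tuple produced by an ``_iter_filter``-based flattener such as
# ``_iter_tensors`` (defined just below). Purely illustrative.
#
#     >>> a, b, c = torch.ones(1), torch.ones(2), torch.ones(3)
#     >>> nested = (a, [b, (c,)])
#     >>> flat = tuple(_iter_tensors(nested))   # (a, b, c)
#     >>> rebuilt = _unflatten(flat, nested)    # (a, [b, (c,)]) again
#     >>> rebuilt[1][1][0] is c
#     True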
_iter_jit_values = _iter_filter(lambda o: o is None or isinstance(o, torch._C.Value),
                                condition_msg="jit's Values or None")
_iter_tensors = _iter_filter(lambda x: isinstance(x, torch.Tensor), condition_msg="Tensors",
                             conversion=_jit_unwrap_structured)
_iter_tensors_permissive = _iter_filter(lambda x: isinstance(x, torch.Tensor),
                                        allow_unknown=True,
                                        condition_msg="Tensors (permissive)")
_iter_None_tensors = _iter_filter(lambda o: o is None or isinstance(o, torch.Tensor),
                                  condition_msg="Tensors or None")
_map_tensor_data = _nested_map(lambda x: isinstance(x, torch.Tensor), lambda o: o.data,
                               condition_msg="Tensors")
def _do_forward(self, *input):
    self._nested_input = input
    flat_input = tuple(_iter_tensors(input))
    flat_output = super(NestedIOFunction, self)._do_forward(*flat_input)
    nested_tensors = _unflatten(flat_output, self._nested_output)
    return nested_tensors

def _do_backward(self, gradients, retain_variables):
    self.retain_variables = retain_variables
    result = super(NestedIOFunction, self)._do_backward(gradients, retain_variables)
    if not retain_variables:
        del self._nested_output
        del self._to_save_nested
    return result

def backward(self, *gradients):
    nested_gradients = _unflatten(gradients, self._nested_output)
    result = self.backward_extended(*nested_gradients)
    return tuple(_iter_None_tensors(result))

__call__ = _do_forward

def forward(self, *args):
    nested_tensors = _map_tensor_data(self._nested_input)
    result = self.forward_extended(*nested_tensors)
    del self._nested_input
    self._nested_output = result
    return tuple(_iter_tensors(result))

def save_for_backward(self, *args):
    self.to_save = tuple(_iter_tensors(args))
    self._to_save_nested = args

@property
def saved_tensors(self):
    flat_tensors = super(NestedIOFunction, self).saved_tensors
    return _unflatten(flat_tensors, self._to_save_nested)

def mark_dirty(self, *args, **kwargs):
    self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

def mark_non_differentiable(self, *args, **kwargs):
    self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

def forward_extended(self, *input):
    raise NotImplementedError

def backward_extended(self, *grad_output):
    raise NotImplementedError
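# Example sketch: NestedIOFunction is the old-style interface for functions
# whose inputs and outputs are nested lists/tuples of Tensors; subclasses
# override ``forward_extended`` / ``backward_extended`` instead of
# ``forward`` / ``backward``. ``PairSum`` is a hypothetical illustration.
#
#     >>> class PairSum(NestedIOFunction):
#     >>>     def forward_extended(self, pair):
#     >>>         a, b = pair
#     >>>         self.save_for_backward(a, b)
#     >>>         return a + b
#     >>>
#     >>>     def backward_extended(self, grad_output):
#     >>>         return grad_output, grad_output
#     >>>
#     >>> out = PairSum()((a, b))  # instance is called once, old-style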