from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import warnings


def detach_variable(inputs):
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            # Detach from the current graph, but keep the requires_grad flag so
            # that the recomputed forward pass can produce gradients for it.
            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: " + type(inputs).__name__)


def check_backward_validity(inputs):
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


def get_device_states(*args):
    # Stash the RNG state for the device of every CUDA tensor argument.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states):
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
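

# Illustrative sketch (not part of the original module): the two helpers above
# stash and restore per-device CUDA RNG state so a random op can be replayed
# bit-for-bit. This hypothetical example assumes at least one CUDA device.
def _example_rng_replay():
    x = torch.empty(4, device='cuda')
    devices, states = get_device_states(x)   # stash RNG state for x's device
    first = torch.rand_like(x)               # advances the CUDA RNG
    set_device_states(devices, states)       # rewind to the stashed state
    second = torch.rand_like(x)              # replays the same random numbers
    return torch.equal(first, second)        # True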


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the CUDA context; only stash GPU RNG
            # states if CUDA was already in use during this forward pass.
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding RNG state, mimic the state that was present at
        # forward time, and restore the surrounding state when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            # Re-run the forward pass on detached copies of the saved inputs,
            # this time recording the autograd graph.
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # Backprop through the recomputed graph, then hand the input gradients
        # back to the outer graph (None for run_function and preserve_rng_state).
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads
106 r"""Checkpoint a model or part of the model 108 Checkpointing works by trading compute for memory. Rather than storing all 109 intermediate activations of the entire computation graph for computing 110 backward, the checkpointed part does **not** save intermediate activations, 111 and instead recomputes them in backward pass. It can be applied on any part 114 Specifically, in the forward pass, :attr:`function` will run in 115 :func:`torch.no_grad` manner, i.e., not storing the intermediate 116 activations. Instead, the forward pass saves the inputs tuple and the 117 :attr:`function` parameter. In the backwards pass, the saved inputs and 118 :attr:`function` is retreived, and the forward pass is computed on 119 :attr:`function` again, now tracking the intermediate activations, and then 120 the gradients are calculated using these activation values. 123 Checkpointing doesn't work with :func:`torch.autograd.grad`, but only 124 with :func:`torch.autograd.backward`. 127 If :attr:`function` invocation during backward does anything different 128 than the one during forward, e.g., due to some global variable, the 129 checkpointed version won't be equivalent, and unfortunately it can't be 133 At least one of the inputs needs to have :code:`requires_grad=True` if 134 grads are needed for model inputs, otherwise the checkpointed part of the 135 model won't have gradients. 138 function: describes what to run in the forward pass of the model or 139 part of the model. It should also know how to handle the inputs 140 passed as the tuple. For example, in LSTM, if user passes 141 ``(activation, hidden)``, :attr:`function` should correctly use the 142 first input as ``activation`` and the second input as ``hidden`` 143 preserve_rng_state(bool, optional, default=True): Omit stashing and restoring 144 the RNG state during each checkpoint. 145 args: tuple containing inputs to the :attr:`function` 148 Output of running :attr:`function` on :attr:`*args` 151 preserve = kwargs.pop(
'preserve_rng_state',
True)
153 raise ValueError(
"Unexpected keyword arguments: " +
",".join(arg
for arg
in kwargs))
155 return CheckpointFunction.apply(function, preserve, *args)
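

# Illustrative sketch (not part of the original module): checkpointing the middle
# block of a small model. The module and tensor names below are hypothetical.
def _example_checkpoint_usage():
    import torch.nn as nn
    block1 = nn.Linear(16, 32)
    block2 = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
    block3 = nn.Linear(32, 4)

    x = torch.randn(8, 16, requires_grad=True)
    h = block1(x)
    # block2's intermediate activations are not stored here; they are recomputed
    # when backward reaches this point.
    h = checkpoint(block2, h)
    loss = block3(h).sum()
    loss.backward()   # use .backward(); checkpointing is incompatible with .grad()
    return x.grad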


def checkpoint_sequential(functions, segments, *inputs, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of
        the model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        inputs: tuple of Tensors that are inputs to :attr:`functions`
        preserve_rng_state(bool, optional, default=True): If ``False``, omit
            stashing and restoring the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Accept preserve_rng_state as a keyword-only argument in a way that also
    # works with Python 2.7 (which cannot mix *args with named keyword args).
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(*inputs):
            # Run modules start..end, unpacking tuple inputs as positional args.
            for j in range(start, end + 1):
                if isinstance(inputs, tuple):
                    inputs = functions[j](*inputs)
                else:
                    inputs = functions[j](inputs)
            return inputs
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # Checkpoint all but the last segment; the last one runs normally
    # (see the docstring above).
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs,
                            preserve_rng_state=preserve)
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)
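

# Illustrative sketch (not part of the original module): splitting a sequential
# model into two checkpointed chunks. The names below are hypothetical.
def _example_checkpoint_sequential_usage():
    import torch.nn as nn
    model = nn.Sequential(
        nn.Linear(16, 32), nn.ReLU(),
        nn.Linear(32, 32), nn.ReLU(),
        nn.Linear(32, 4),
    )
    input_var = torch.randn(8, 16, requires_grad=True)
    # Run the five modules as two segments; only the segment boundary tensors
    # (and the un-checkpointed last segment) keep their activations.
    out = checkpoint_sequential(model, 2, input_var)
    out.sum().backward()
    return input_var.grad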