Caffe2 - Python API
A deep learning, cross-platform ML framework
checkpoint.py
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import warnings


def detach_variable(inputs):
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)


def check_backward_validity(inputs):
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args):
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states):
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
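# --- Illustrative sketch (not part of the original file) ---
# The two helpers above are meant to be used as a pair: stash the per-device CUDA
# RNG states once, then restore them before re-running code that draws random
# numbers, so the replay sees the same randomness. A minimal demonstration,
# assuming at least one CUDA device is available:
#
#     x = torch.randn(4, device="cuda")
#     devices, states = get_device_states(x)
#     first = torch.rand_like(x)          # consumes the CUDA RNG
#     set_device_states(devices, states)  # rewind the RNG for x's device
#     again = torch.rand_like(x)
#     assert torch.equal(first, again)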


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads


def checkpoint(function, *args, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in the backward pass. It can be applied to any
    part of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backward pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        If the :attr:`function` invocation during backward does anything different
        from the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately this can't be
        detected.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in an LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state (bool, optional, default=True): if ``False``, omit
            stashing and restoring the RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    return CheckpointFunction.apply(function, preserve, *args)
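# --- Illustrative usage sketch (not part of the original file; names and sizes are made up) ---
# checkpoint() is typically wrapped around a sub-module whose activations are cheap
# to recompute but expensive to store:
#
#     import torch.nn as nn
#
#     body = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
#     head = nn.Linear(64, 10)
#
#     x = torch.randn(32, 64, requires_grad=True)
#     h = checkpoint(body, x)     # body's intermediate activations are not stored
#     loss = head(h).sum()
#     loss.backward()             # body is re-run here to rebuild its local graph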


def checkpoint_sequential(functions, segments, *inputs, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model into various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
        with :func:`torch.autograd.backward`.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or a list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        inputs: tuple of Tensors that are inputs to :attr:`functions`
        preserve_rng_state (bool, optional, default=True): if ``False``, omit
            stashing and restoring the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(*inputs):
            for j in range(start, end + 1):
                if isinstance(inputs, tuple):
                    inputs = functions[j](*inputs)
                else:
                    inputs = functions[j](inputs)
            return inputs
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last segment has to run with autograd enabled (it is not checkpointed)
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        inputs = checkpoint(run_function(start, end, functions), *inputs,
                            preserve_rng_state=preserve)
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
    return run_function(end + 1, len(functions) - 1, functions)(*inputs)
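# --- Illustrative usage sketch (not part of the original file; layer sizes are made up) ---
# Split a Sequential into two segments; the first is checkpointed, the last runs
# normally:
#
#     model = torch.nn.Sequential(
#         torch.nn.Linear(128, 128), torch.nn.ReLU(),
#         torch.nn.Linear(128, 128), torch.nn.ReLU(),
#         torch.nn.Linear(128, 10),
#     )
#     inp = torch.randn(16, 128, requires_grad=True)
#     out = checkpoint_sequential(model, 2, inp)
#     out.sum().backward()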