Caffe2 - Python API
A deep learning, cross-platform ML framework
grad_mode.py
1 import torch
2 import functools
3 
4 
class no_grad(object):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.
    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    Also functions as a decorator. Every call of the decorated function gets
    its own context, so nested or recursive calls restore the correct mode.

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """

    def __enter__(self):
        # Remember the mode active on entry so __exit__ can restore it.
        self.prev = torch.is_grad_enabled()
        # Same public entry point that __exit__ uses to restore the mode.
        torch.set_grad_enabled(False)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False  # never suppress exceptions raised inside the block

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_no_grad(*args, **kwargs):
            # Use a fresh instance per call: sharing ``self`` would let a
            # nested (e.g. recursive) call overwrite ``self.prev`` and leave
            # gradients disabled after the outermost call returns.
            with self.__class__():
                return func(*args, **kwargs)
        return decorate_no_grad
45 
46 
class enable_grad(object):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation inside a :class:`~no_grad` context. This has
    no effect outside of :class:`~no_grad`.

    Also functions as a decorator. Every call of the decorated function gets
    its own context, so nested or recursive calls restore the correct mode.

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
    """

    def __enter__(self):
        # Remember the mode active on entry so __exit__ can restore it.
        self.prev = torch.is_grad_enabled()
        # Same public entry point that __exit__ uses to restore the mode.
        torch.set_grad_enabled(True)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False  # never suppress exceptions raised inside the block

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_enable_grad(*args, **kwargs):
            # Use a fresh instance per call: sharing ``self`` would let a
            # nested (e.g. recursive) call overwrite ``self.prev`` and leave
            # gradients enabled after the outermost call returns.
            with self.__class__():
                return func(*args, **kwargs)
        return decorate_enable_grad
89 
90 
class set_grad_enabled(object):
    r"""Context-manager that sets gradient calculation to on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function. The new mode takes
    effect as soon as the object is constructed; leaving a ``with`` block
    restores the mode that was active at construction time.

    Arguments:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False
    """

    def __init__(self, mode):
        # Capture the current mode before switching so __exit__ can restore it.
        self.prev = torch.is_grad_enabled()
        # Applied eagerly (not in __enter__) so that a plain call such as
        # ``torch.set_grad_enabled(False)`` works without a ``with`` block.
        torch._C.set_grad_enabled(mode)

    def __enter__(self):
        # Nothing to do: the mode was already applied in __init__.
        pass

    def __exit__(self, *args):
        # Restore the mode captured at construction time.
        torch.set_grad_enabled(self.prev)
        return False  # never suppress exceptions raised inside the block