6 r"""Context-manager that disabled gradient calculation. 8 Disabling gradient calculation is useful for inference, when you are sure 9 that you will not call :meth:`Tensor.backward()`. It will reduce memory 10 consumption for computations that would otherwise have `requires_grad=True`. 11 In this mode, the result of every computation will have 12 `requires_grad=False`, even when the inputs have `requires_grad=True`. 14 Also functions as a decorator. 19 >>> x = torch.tensor([1], requires_grad=True) 20 >>> with torch.no_grad(): 32 self.
prev = torch.is_grad_enabled()
33 torch._C.set_grad_enabled(
False)
    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
    def __call__(self, func):
        @functools.wraps(func)
        def decorate_no_grad(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return decorate_no_grad
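

# Minimal usage sketch (illustrative only; not part of the module's public
# API). Shows `no_grad` as a context manager and as a decorator; the tensor
# value and the helper names here are arbitrary placeholders.
def _no_grad_usage_sketch():
    x = torch.tensor([1.0], requires_grad=True)

    # As a context manager: results computed inside do not track history.
    with no_grad():
        y = x * 2
    assert y.requires_grad is False

    # As a decorator: the wrapped function runs with gradients disabled.
    @no_grad()
    def doubler(t):
        return t * 2

    assert doubler(x).requires_grad is False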
48 r"""Context-manager that enables gradient calculation. 50 Enables gradient calculation inside a :class:`~no_grad` context. This has 51 no effect outside of :class:`~no_grad`. 53 Also functions as a decorator. 58 >>> x = torch.tensor([1], requires_grad=True) 59 >>> with torch.no_grad(): 60 ... with torch.enable_grad(): 66 >>> @torch.enable_grad() 69 >>> with torch.no_grad(): 76 self.
prev = torch.is_grad_enabled()
77 torch._C.set_grad_enabled(
True)
    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
    def __call__(self, func):
        @functools.wraps(func)
        def decorate_enable_grad(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return decorate_enable_grad
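

# Minimal usage sketch (illustrative only; not part of the module's public
# API). Shows that `enable_grad` re-enables gradient tracking only inside an
# enclosing `no_grad` block and has no effect once that block is left.
def _enable_grad_usage_sketch():
    x = torch.tensor([1.0], requires_grad=True)

    with no_grad():
        with enable_grad():
            y = x * 2  # tracked: enable_grad overrides the outer no_grad
        z = x * 3      # not tracked: back under the outer no_grad

    assert y.requires_grad is True
    assert z.requires_grad is False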
92 r"""Context-manager that sets gradient calculation to on or off. 94 ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`. 95 It can be used as a context-manager or as a function. 98 mode (bool): Flag whether to enable grad (``True``), or disable 99 (``False``). This can be used to conditionally enable 105 >>> x = torch.tensor([1], requires_grad=True) 107 >>> with torch.set_grad_enabled(is_train): 111 >>> torch.set_grad_enabled(True) 115 >>> torch.set_grad_enabled(False) 122 def __init__(self, mode):
123 self.
prev = torch.is_grad_enabled()
124 torch._C.set_grad_enabled(mode)
    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
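

# Minimal usage sketch (illustrative only; not part of the module's public
# API). Shows `set_grad_enabled` driven by a boolean flag, e.g. to share one
# code path between training and evaluation; `is_train` is a placeholder name.
def _set_grad_enabled_usage_sketch(is_train):
    x = torch.tensor([1.0], requires_grad=True)

    # As a context manager: gradient tracking follows the flag inside the block.
    with set_grad_enabled(is_train):
        y = x * 2
    assert y.requires_grad == is_train

    # As a plain call: flips the global grad mode until it is changed again.
    set_grad_enabled(False)
    z = x * 2
    assert z.requires_grad is False
    set_grad_enabled(True)  # restore the default grad mode for later code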