Caffe2 - Python API
A deep learning, cross-platform ML framework
clip_grad.py
import warnings
import torch
from torch._six import inf

def clip_grad_norm_(parameters, max_norm, norm_type=2):
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Only parameters that actually received a gradient participate.
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == inf:
        # Infinity norm: the largest absolute entry across all gradients.
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        # Accumulate sum(|g|^p) over all gradients, then take the p-th root.
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    # Rescale only when the total norm exceeds max_norm; the 1e-6 term
    # guards against division by zero.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
    return total_norm

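A minimal usage sketch (the model, optimizer, and data below are hypothetical placeholders): clip_grad_norm_ is typically called between loss.backward() and optimizer.step(), so the optimizer updates with the rescaled gradients.

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import SGD

model = nn.Linear(10, 1)                      # hypothetical model
optimizer = SGD(model.parameters(), lr=0.1)
model(torch.randn(4, 10)).sum().backward()    # populate p.grad

# Rescale all gradients together so their combined 2-norm is at most 1.0.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

Because the norm is taken over the concatenation of all gradients, two parameters with gradients [3., 0.] and [4.] have a combined 2-norm of 5, so with max_norm=1.0 every entry is scaled by roughly 1/5; the returned total_norm is the pre-clipping value (5 here).
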
def clip_grad_norm(parameters, max_norm, norm_type=2):
    r"""Clips gradient norm of an iterable of parameters.

    .. warning::
        This method is now deprecated in favor of
        :func:`torch.nn.utils.clip_grad_norm_`.
    """
    warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
                  "of torch.nn.utils.clip_grad_norm_.", stacklevel=2)
    return clip_grad_norm_(parameters, max_norm, norm_type)

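Migration is a one-character rename; a short sketch (with a hypothetical model, and assuming both names are re-exported from torch.nn.utils): the old spelling still works but emits the warning above, so new code should call the trailing-underscore variant directly.

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm, clip_grad_norm_

model = nn.Linear(4, 2)                       # hypothetical model
model(torch.randn(1, 4)).sum().backward()

clip_grad_norm(model.parameters(), max_norm=1.0)    # old name: works, but warns
clip_grad_norm_(model.parameters(), max_norm=1.0)   # preferred in-place variant
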
def clip_grad_value_(parameters, clip_value):
    r"""Clips gradient of an iterable of parameters at specified value.

    Gradients are modified in-place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients clipped
        clip_value (float or int): maximum allowed value of the gradients.
            The gradients are clipped in the range
            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)
    # Clamp each gradient element into [-clip_value, clip_value] in place.
    for p in filter(lambda p: p.grad is not None, parameters):
        p.grad.data.clamp_(min=-clip_value, max=clip_value)
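
A minimal sketch of value clipping (again with a hypothetical model): unlike clip_grad_norm_, which rescales all gradients by one shared factor, clip_grad_value_ clamps each gradient entry independently, so it can change the direction of the overall gradient vector.

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_

model = nn.Linear(10, 1)                      # hypothetical model
model(torch.randn(4, 10)).sum().backward()    # populate p.grad

# Clamp every gradient element into [-0.5, 0.5], in place.
clip_grad_value_(model.parameters(), clip_value=0.5)
assert all(p.grad.abs().max() <= 0.5 for p in model.parameters())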