Caffe2 - Python API
A deep learning, cross-platform ML framework
weight_norm.py
1 r"""
2 Weight Normalization from https://arxiv.org/abs/1602.07868
3 """
4 from torch.nn.parameter import Parameter
5 from torch import _weight_norm, norm_except_dim
6 
7 
8 class WeightNorm(object):
9  def __init__(self, name, dim):
10  if dim is None:
11  dim = -1
12  self.name = name
13  self.dim = dim
14 
15  def compute_weight(self, module):
16  g = getattr(module, self.name + '_g')
17  v = getattr(module, self.name + '_v')
18  return _weight_norm(v, g, self.dim)
19 
20  @staticmethod
21  def apply(module, name, dim):
22  for k, hook in module._forward_pre_hooks.items():
23  if isinstance(hook, WeightNorm) and hook.name == name:
24  raise RuntimeError("Cannot register two weight_norm hooks on "
25  "the same parameter {}".format(name))
26 
27  if dim is None:
28  dim = -1
29 
30  fn = WeightNorm(name, dim)
31 
32  weight = getattr(module, name)
33 
34  # remove w from parameter list
35  del module._parameters[name]
36 
37  # add g and v as new parameters and express w as g/||v|| * v
38  module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
39  module.register_parameter(name + '_v', Parameter(weight.data))
40  setattr(module, name, fn.compute_weight(module))
41 
42  # recompute weight before every forward()
43  module.register_forward_pre_hook(fn)
44 
45  return fn
46 
47  def remove(self, module):
48  weight = self.compute_weight(module)
49  delattr(module, self.name)
50  del module._parameters[self.name + '_g']
51  del module._parameters[self.name + '_v']
52  module.register_parameter(self.name, Parameter(weight.data))
53 
54  def __call__(self, module, inputs):
55  setattr(module, self.name, self.compute_weight(module))
56 
57 
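# Hedged reference sketch (not part of the original source): the helper below
# spells out what ``compute_weight`` is assumed to do, namely that
# ``_weight_norm(v, g, dim)`` matches ``v * (g / norm_except_dim(v, 2, dim))``,
# i.e. w = g * v / ||v|| with the 2-norm reduced over every dimension except
# ``dim`` (dim=-1 reduces over the whole tensor). The helper name is made up
# for illustration only.
def _compute_weight_reference(v, g, dim):
    # norm_except_dim keeps ``dim`` and reduces the 2-norm over all other
    # dimensions, so the result broadcasts back against v.
    return v * (g / norm_except_dim(v, 2, dim))
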
def weight_norm(module, name='weight', dim=0):
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the magnitude
    (e.g. ``'weight_g'``) and one specifying the direction (e.g. ``'weight_v'``).
    Weight normalization is implemented via a hook that recomputes the weight
    tensor from the magnitude and direction before every :meth:`~Module.forward`
    call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_g.size()
        torch.Size([40, 1])
        >>> m.weight_v.size()
        torch.Size([40, 20])

    """
    WeightNorm.apply(module, name, dim)
    return module
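
# Hedged usage sketch (not in the original file): a quick numerical check of the
# reparameterization on an ``nn.Linear``. With the default ``dim=0`` the norm of
# the (out_features, in_features) weight is taken per output row, so ``weight_g``
# has shape (out_features, 1). The function name below is made up for this demo.
def _weight_norm_demo():
    import torch
    import torch.nn as nn

    m = weight_norm(nn.Linear(20, 40))          # replaces m.weight with g and v
    w = m.weight_g * m.weight_v / m.weight_v.norm(dim=1, keepdim=True)
    assert torch.allclose(m.weight, w)          # w == g * v / ||v||, row-wise

    # dim=None normalizes by the norm of the whole tensor, so g is a scalar.
    m2 = weight_norm(nn.Linear(20, 40), dim=None)
    assert m2.weight_g.numel() == 1
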

def remove_weight_norm(module, name='weight'):
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
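
# Hedged sketch (not in the original file): removing the reparameterization folds
# g and v back into a single plain ``weight`` parameter holding the current
# value, and drops the auxiliary parameters. The function name is made up for
# this demo.
def _remove_weight_norm_demo():
    import torch.nn as nn

    m = weight_norm(nn.Linear(20, 40))
    assert hasattr(m, 'weight_g') and hasattr(m, 'weight_v')

    remove_weight_norm(m)
    # weight is an ordinary Parameter again; weight_g and weight_v are gone.
    assert isinstance(m.weight, nn.Parameter)
    assert not hasattr(m, 'weight_g') and not hasattr(m, 'weight_v')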