r"""
Weight Normalization from https://arxiv.org/abs/1602.07868
"""
from torch.nn.parameter import Parameter
from torch import _weight_norm, norm_except_dim


class WeightNorm(object):
    def __init__(self, name, dim):
        self.name = name
        self.dim = dim

    def compute_weight(self, module):
        # Recombine the magnitude (g) and direction (v) parameters into the
        # full weight tensor: w = g * v / ||v||, with the norm taken over
        # every dimension except `self.dim`.
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name, dim):
        # Refuse to register two weight_norm hooks on the same parameter.
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError("Cannot register two weight_norm hooks on "
                                   "the same parameter {}".format(name))

        # dim=None (norm over the entire tensor) is represented internally as -1.
        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)
        weight = getattr(module, name)

        # remove w from the parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute the weight before every forward() call
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        # Fold g and v back into a single plain Parameter holding the
        # current recombined weight.
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        module.register_parameter(self.name, Parameter(weight.data))

    def __call__(self, module, inputs):
        # Forward pre-hook: recompute `module.<name>` from g and v so the
        # upcoming forward() uses the up-to-date weight.
        setattr(module, self.name, self.compute_weight(module))
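

# ---------------------------------------------------------------------------
# Reference sketch (not part of the original file): a pure-tensor-op
# approximation of what the internal `norm_except_dim` / `_weight_norm`
# calls above compute, assuming `norm_except_dim` takes the 2-norm over
# every dimension except `dim` and keeps singleton dims for broadcasting.
# The helper names `_norm_except_dim_sketch` and `_weight_norm_sketch` are
# ours, purely illustrative.
def _norm_except_dim_sketch(v, dim):
    # dim == -1 is treated as "norm over the whole tensor".
    if dim == -1:
        return v.norm(2)
    # Move `dim` to the front, flatten everything else, take a per-slice
    # 2-norm, then reshape so the result broadcasts against `v`.
    keep = v.size(dim)
    out_shape = [1] * v.dim()
    out_shape[dim] = keep
    return v.transpose(0, dim).reshape(keep, -1).norm(2, dim=1).reshape(out_shape)


def _weight_norm_sketch(v, g, dim):
    # w = g * v / ||v||, with the norm taken as above.
    return v * (g / _norm_except_dim_sketch(v, dim))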


def weight_norm(module, name='weight', dim=0):
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` (e.g. ``'weight'``) with two parameters: one specifying the
    magnitude (e.g. ``'weight_g'``) and one specifying the direction
    (e.g. ``'weight_v'``). Weight normalization is implemented via a hook that
    recomputes the weight tensor from the magnitude and direction before every
    :meth:`~Module.forward` call.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        Linear(in_features=20, out_features=40, bias=True)

    """
    WeightNorm.apply(module, name, dim)
    return module
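

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): after `weight_norm` is
# applied, the optimizer sees `weight_g` and `weight_v`, and `module.weight`
# is recomputed from them before every forward call. The helper name
# `_weight_norm_usage_sketch` is ours, purely illustrative.
def _weight_norm_usage_sketch():
    import torch
    import torch.nn as nn

    m = weight_norm(nn.Linear(20, 40), name='weight')

    # 'weight' itself is no longer a Parameter; it is rebuilt from g and v.
    params = dict(m.named_parameters())
    assert 'weight' not in params and 'weight_g' in params and 'weight_v' in params

    # With dim=0 the magnitude g holds one scalar per output row.
    assert m.weight_g.shape == (40, 1) and m.weight_v.shape == (40, 20)

    # Forward runs the pre-hook; the recombined weight then matches
    # g * v / ||v|| with a row-wise 2-norm.
    _ = m(torch.randn(3, 20))
    expected = m.weight_v * (m.weight_g / m.weight_v.norm(2, dim=1, keepdim=True))
    assert torch.allclose(m.weight, expected, atol=1e-6)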


def remove_weight_norm(module, name='weight'):
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example::

        >>> m = weight_norm(nn.Linear(20, 40))
        >>> remove_weight_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, WeightNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("weight_norm of '{}' not found in {}"
                     .format(name, module))
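

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): removing the hook folds
# g and v back into a single 'weight' Parameter holding the current
# recombined value, so the module behaves and serializes like a plain
# Linear again. The helper name `_remove_weight_norm_usage_sketch` is ours,
# purely illustrative.
def _remove_weight_norm_usage_sketch():
    import torch
    import torch.nn as nn

    m = weight_norm(nn.Linear(20, 40))
    before = m.weight.detach().clone()

    remove_weight_norm(m)
    params = dict(m.named_parameters())
    assert 'weight' in params
    assert 'weight_g' not in params and 'weight_v' not in params
    # The restored weight keeps the value it had under the reparameterization.
    assert torch.allclose(m.weight, before, atol=1e-6)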