2 Spectral Normalization from https://arxiv.org/abs/1802.05957 20 def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
23 if n_power_iterations <= 0:
24 raise ValueError(
'Expected n_power_iterations to be positive, but ' 25 'got n_power_iterations={}'.format(n_power_iterations))
29 def reshape_weight_to_matrix(self, weight):
33 weight_mat = weight_mat.permute(self.
dim,
34 *[d
for d
in range(weight_mat.dim())
if d != self.
dim])
35 height = weight_mat.size(0)
36 return weight_mat.reshape(height, -1)
38 def compute_weight(self, module, do_power_iteration):
68 weight = getattr(module, self.
name +
'_orig')
69 u = getattr(module, self.
name +
'_u')
70 v = getattr(module, self.
name +
'_v')
73 if do_power_iteration:
79 v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.
eps, out=v)
80 u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.
eps, out=u)
86 sigma = torch.dot(u, torch.mv(weight_mat, v))
87 weight = weight / sigma
90 def remove(self, module):
93 delattr(module, self.
name)
94 delattr(module, self.
name +
'_u')
95 delattr(module, self.
name +
'_v')
96 delattr(module, self.
name +
'_orig')
97 module.register_parameter(self.
name, torch.nn.Parameter(weight.detach()))
99 def __call__(self, module, inputs):
100 setattr(module, self.
name, self.
compute_weight(module, do_power_iteration=module.training))
102 def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
106 v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1)
107 return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
110 def apply(module, name, n_power_iterations, dim, eps):
111 for k, hook
in module._forward_pre_hooks.items():
112 if isinstance(hook, SpectralNorm)
and hook.name == name:
113 raise RuntimeError(
"Cannot register two spectral_norm hooks on " 114 "the same parameter {}".format(name))
117 weight = module._parameters[name]
119 with torch.no_grad():
120 weight_mat = fn.reshape_weight_to_matrix(weight)
122 h, w = weight_mat.size()
124 u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
125 v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
127 delattr(module, fn.name)
128 module.register_parameter(fn.name +
"_orig", weight)
134 setattr(module, fn.name, weight.data)
135 module.register_buffer(fn.name +
"_u", u)
136 module.register_buffer(fn.name +
"_v", v)
138 module.register_forward_pre_hook(fn)
149 def __init__(self, fn):
160 def __call__(self, state_dict, prefix, local_metadata, strict,
161 missing_keys, unexpected_keys, error_msgs):
163 version = local_metadata.get(
'spectral_norm', {}).get(fn.name +
'.version',
None)
164 if version
is None or version < 1:
165 with torch.no_grad():
166 weight_orig = state_dict[prefix + fn.name +
'_orig']
167 weight = state_dict.pop(prefix + fn.name)
168 sigma = (weight_orig / weight).mean()
169 weight_mat = fn.reshape_weight_to_matrix(weight_orig)
170 u = state_dict[prefix + fn.name +
'_u']
171 v = fn._solve_v_and_rescale(weight_mat, u, sigma)
172 state_dict[prefix + fn.name +
'_v'] = v
179 def __init__(self, fn):
182 def __call__(self, module, state_dict, prefix, local_metadata):
183 if 'spectral_norm' not in local_metadata:
184 local_metadata[
'spectral_norm'] = {}
185 key = self.fn.name +
'.version' 186 if key
in local_metadata[
'spectral_norm']:
187 raise RuntimeError(
"Unexpected key in metadata['spectral_norm']: {}".format(key))
188 local_metadata[
'spectral_norm'][key] = self.fn._version
191 def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
192 r"""Applies spectral normalization to a parameter in the given module. 195 \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, 196 \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} 198 Spectral normalization stabilizes the training of discriminators (critics) 199 in Generative Adversarial Networks (GANs) by rescaling the weight tensor 200 with spectral norm :math:`\sigma` of the weight matrix calculated using 201 power iteration method. If the dimension of the weight tensor is greater 202 than 2, it is reshaped to 2D in power iteration method to get spectral 203 norm. This is implemented via a hook that calculates spectral norm and 204 rescales weight before every :meth:`~Module.forward` call. 206 See `Spectral Normalization for Generative Adversarial Networks`_ . 208 .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 211 module (nn.Module): containing module 212 name (str, optional): name of weight parameter 213 n_power_iterations (int, optional): number of power iterations to 214 calculate spectral norm 215 eps (float, optional): epsilon for numerical stability in 217 dim (int, optional): dimension corresponding to number of outputs, 218 the default is ``0``, except for modules that are instances of 219 ConvTranspose{1,2,3}d, when it is ``1`` 222 The original module with the spectral norm hook 226 >>> m = spectral_norm(nn.Linear(20, 40)) 228 Linear(in_features=20, out_features=40, bias=True) 229 >>> m.weight_u.size() 234 if isinstance(module, (torch.nn.ConvTranspose1d,
235 torch.nn.ConvTranspose2d,
236 torch.nn.ConvTranspose3d)):
240 SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
244 def remove_spectral_norm(module, name='weight'):
245 r"""Removes the spectral normalization reparameterization from a module. 248 module (Module): containing module 249 name (str, optional): name of weight parameter 252 >>> m = spectral_norm(nn.Linear(40, 10)) 253 >>> remove_spectral_norm(m) 255 for k, hook
in module._forward_pre_hooks.items():
256 if isinstance(hook, SpectralNorm)
and hook.name == name:
258 del module._forward_pre_hooks[k]
261 raise ValueError(
"spectral_norm of '{}' not found in {}".format(
def compute_weight(self, module, do_power_iteration)
def reshape_weight_to_matrix(self, weight)