spectral_norm.py
1 """
2 Spectral Normalization from https://arxiv.org/abs/1802.05957
3 """
4 import torch
5 from torch.nn.functional import normalize
6 from torch.nn.parameter import Parameter
7 
8 
9 class SpectralNorm(object):
10  # Invariant before and after each forward call:
11  # u = normalize(W @ v)
12  # NB: At initialization, this invariant is not enforced
13 
14  _version = 1
15  # At version 1:
16  # made `W` not a buffer,
17  # added `v` as a buffer, and
18  # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
19 
20  def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
21  self.name = name
22  self.dim = dim
23  if n_power_iterations <= 0:
24  raise ValueError('Expected n_power_iterations to be positive, but '
25  'got n_power_iterations={}'.format(n_power_iterations))
26  self.n_power_iterations = n_power_iterations
27  self.eps = eps
28 
    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim,
                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)
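    # Illustrative note (added; not part of the original file): with the default
    # dim=0, a Conv2d weight of shape (out_channels, in_channels, kH, kW) is
    # viewed as a matrix of shape (out_channels, in_channels * kH * kW), and the
    # spectral norm computed below is that matrix's largest singular value.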

    def compute_weight(self, module, do_power_iteration):
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on the correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in the `out=` kwarg has the correct shape, it
        #          will just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backpropagating through two forward passes, e.g., the common pattern
        #     in GAN training: loss = D(real) - D(fake). Otherwise, the autograd
        #     engine will complain that variables needed to do backward for the
        #     first forward (i.e., the `u` and `v` vectors) are changed in the
        #     second forward.
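        #     Added illustration (not part of the original comments): with the
        #     clones, a critic update such as
        #         d_real = D(x_real)   # forward #1 saves the cloned u, v
        #         d_fake = D(x_fake)   # forward #2 updates u, v in-place, then
        #                              # saves fresh clones of its own
        #         (d_real.mean() - d_fake.mean()).backward()
        #     backprops through both graphs safely; without the clones, the
        #     in-place power iteration in forward #2 would invalidate the tensors
        #     saved by forward #1.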
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone()
                    v = v.clone()

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs):
        setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at the top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
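        # Added note (illustration, not in the original source): the expression
        # below computes x = (W^T W)^+ W^T u, i.e. the least-squares solution of
        # W @ x ≈ u, and then rescales x so that u @ W @ v equals target_sigma.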
        v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign it, as it could be an nn.Parameter
        # and would get added as a parameter. Instead, we register weight.data
        # as a plain attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)

        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn
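    # Summary for readers (added note, not in the original source): after
    # `SpectralNorm.apply(module, name, ...)`, the module carries
    #   - `<name>_orig`: the unnormalized weight, registered as a Parameter,
    #   - `<name>_u`, `<name>_v`: power-iteration buffers,
    #   - `<name>`: a plain tensor attribute that the forward pre-hook
    #     overwrites with `<name>_orig / sigma` before every forward call.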


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    # For a state_dict with version None (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W_orig / W).
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        fn = self.fn
        version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            with torch.no_grad():
                weight_orig = state_dict[prefix + fn.name + '_orig']
                weight = state_dict.pop(prefix + fn.name)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[prefix + fn.name + '_u']
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[prefix + fn.name + '_v'] = v


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        if 'spectral_norm' not in local_metadata:
            local_metadata['spectral_norm'] = {}
        key = self.fn.name + '.version'
        if key in local_metadata['spectral_norm']:
            raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key))
        local_metadata['spectral_norm'][key] = self.fn._version


def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated with
    the power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in the power iteration method to get the
    spectral norm. This is implemented via a hook that calculates the spectral
    norm and rescales the weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))
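

# --------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# file). It relies only on the `spectral_norm` / `remove_spectral_norm`
# helpers defined above and on standard torch / torch.nn APIs.
if __name__ == '__main__':
    import torch.nn as nn

    # Wrap a linear layer: `weight` is now recomputed from `weight_orig`,
    # `weight_u`, and `weight_v` by the forward pre-hook on every call.
    m = spectral_norm(nn.Linear(20, 40))
    x = torch.randn(8, 20)

    # Each training-mode forward runs one power iteration, so the estimate
    # of the largest singular value improves over repeated calls.
    for _ in range(10):
        y = m(x)

    # The top singular value of the effective weight should be close to 1.
    print(torch.svd(m.weight)[1][0])

    # Bake the normalized weight back into a plain parameter and drop the
    # auxiliary buffers.
    remove_spectral_norm(m)
    print(hasattr(m, 'weight_orig'))  # expected: False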