init.py
import math
import random
import warnings

import torch


def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.
    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function (nn.functional name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu')
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
            # True/False are instances of int, hence the explicit bool check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


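# Illustrative sketch (editor's addition, not part of the original module):
# reproduces the gain table above using only this file's public API. The
# leaky_relu constant is sqrt(2 / (1 + 0.2**2)) = 5 / sqrt(13).
def _example_calculate_gain():
    assert calculate_gain('linear') == 1
    assert calculate_gain('tanh') == 5.0 / 3
    assert calculate_gain('relu') == math.sqrt(2.0)
    # default negative_slope is 0.01
    assert calculate_gain('leaky_relu') == math.sqrt(2.0 / (1 + 0.01 ** 2))
    # explicit slope of 0.2
    assert abs(calculate_gain('leaky_relu', 0.2) - 1.3867504905630728) < 1e-12

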
def uniform_(tensor, a=0, b=1):
    r"""Fills the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional torch.Tensor
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    with torch.no_grad():
        return tensor.uniform_(a, b)


def normal_(tensor, mean=0, std=1):
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std})`.

    Args:
        tensor: an n-dimensional torch.Tensor
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    with torch.no_grad():
        return tensor.normal_(mean, std)


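# Illustrative sketch (editor's addition): the fillers above mutate their
# argument in place and return that same tensor, so they compose with
# torch.empty(...) in one expression; no autograd history is recorded.
def _example_in_place_fill():
    t = torch.empty(3, 5)
    assert normal_(t, mean=0, std=1) is t  # returns its own argument
    u = uniform_(torch.empty(3, 5), a=-0.5, b=0.5)
    assert u.min().item() >= -0.5 and u.max().item() <= 0.5

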
def constant_(tensor, val):
    r"""Fills the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional torch.Tensor
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    with torch.no_grad():
        return tensor.fill_(val)


def ones_(tensor):
    r"""Fills the input Tensor with ones.

    Args:
        tensor: an n-dimensional torch.Tensor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    with torch.no_grad():
        return tensor.fill_(1)


def zeros_(tensor):
    r"""Fills the input Tensor with zeros.

    Args:
        tensor: an n-dimensional torch.Tensor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    with torch.no_grad():
        return tensor.zero_()


def eye_(tensor):
    r"""Fills the 2-dimensional input Tensor with the identity
    matrix. Preserves the identity of the inputs in Linear layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional torch.Tensor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    with torch.no_grad():
        torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
    return tensor


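# Illustrative sketch (editor's addition): with an eye_-initialized weight
# and no bias, a Linear layer passes its first min(out, in) inputs through
# unchanged, which is the "identity preserving" behavior described above.
def _example_eye():
    w = eye_(torch.empty(3, 5))
    x = torch.randn(5)
    y = w.mv(x)  # same as a bias-free Linear(5, 3) forward pass
    assert torch.allclose(y, x[:3])

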
def dirac_(tensor):
    r"""Fills the {3, 4, 5}-dimensional input Tensor with the Dirac
    delta function. Preserves the identity of the inputs in Convolutional
    layers, where as many input channels are preserved as possible.

    Args:
        tensor: a {3, 4, 5}-dimensional torch.Tensor

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()
    min_dim = min(sizes[0], sizes[1])
    with torch.no_grad():
        tensor.zero_()

        for d in range(min_dim):
            if dimensions == 3:  # Temporal convolution
                tensor[d, d, tensor.size(2) // 2] = 1
            elif dimensions == 4:  # Spatial convolution
                tensor[d, d, tensor.size(2) // 2, tensor.size(3) // 2] = 1
            else:  # Volumetric convolution
                tensor[d, d, tensor.size(2) // 2, tensor.size(3) // 2, tensor.size(4) // 2] = 1
    return tensor


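# Illustrative sketch (editor's addition): a dirac_-initialized conv kernel
# acts as an identity on the first min(out_channels, in_channels) channels,
# since only the center tap of the matching channel is 1. Padding keeps the
# spatial size so input and output can be compared directly.
def _example_dirac():
    import torch.nn.functional as F
    w = dirac_(torch.empty(4, 4, 3, 3))  # out=4, in=4, 3x3 kernel
    x = torch.randn(1, 4, 8, 8)
    y = F.conv2d(x, w, padding=1)        # padding=1 keeps the 8x8 size
    assert torch.allclose(y, x)

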
def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.ndimension()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        receptive_field_size = 1
        if tensor.dim() > 2:
            # number of elements in one kernel slice, e.g. kH * kW for Conv2d
            receptive_field_size = tensor[0][0].numel()
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out


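# Illustrative sketch (editor's addition): worked fan computation. For a
# Linear weight of shape (out, in), fan_in = in and fan_out = out; for a
# conv weight the receptive field size multiplies both.
def _example_fans():
    fan_in, fan_out = _calculate_fan_in_and_fan_out(torch.empty(10, 20))
    assert (fan_in, fan_out) == (20, 10)
    # Conv2d weight (out=8, in=4, kH=3, kW=3): receptive field size is 9
    fan_in, fan_out = _calculate_fan_in_and_fan_out(torch.empty(8, 4, 3, 3))
    assert (fan_in, fan_out) == (4 * 9, 8 * 9)

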
def xavier_uniform_(tensor, gain=1):
    r"""Fills the input Tensor with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional torch.Tensor
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-a, a)


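# Illustrative sketch (editor's addition): numeric check of the bound above.
# For a (3, 5) weight, fan_in + fan_out = 8, so a = gain * sqrt(6/8) ~= 0.866.
def _example_xavier_uniform():
    w = xavier_uniform_(torch.empty(3, 5), gain=1.0)
    a = math.sqrt(6.0 / (5 + 3))
    assert w.abs().max().item() <= a + 1e-6  # small slack for float rounding

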
def xavier_normal_(tensor, gain=1):
    r"""Fills the input Tensor with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std})` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional torch.Tensor
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    with torch.no_grad():
        return tensor.normal_(0, std)


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input Tensor with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan\_in}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional torch.Tensor
        a: the negative slope of the rectifier used after this layer (0 for ReLU
            by default)
        mode: either 'fan_in' (default) or 'fan_out'. Choosing 'fan_in'
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing 'fan_out' preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (nn.functional name),
            recommended to use only with 'relu' or 'leaky_relu' (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)


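# Illustrative sketch (editor's addition): the bound in the docstring follows
# from the leaky_relu gain sqrt(2 / (1 + a^2)) combined with
# bound = sqrt(3) * gain / sqrt(fan) = sqrt(6 / ((1 + a^2) * fan)).
def _example_kaiming_uniform():
    w = kaiming_uniform_(torch.empty(3, 5), a=0.1, mode='fan_in')
    bound = math.sqrt(6.0 / ((1 + 0.1 ** 2) * 5))  # fan_in = 5
    assert w.abs().max().item() <= bound + 1e-6    # slack for float rounding

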
def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input Tensor with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std})` where

    .. math::
        \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional torch.Tensor
        a: the negative slope of the rectifier used after this layer (0 for ReLU
            by default)
        mode: either 'fan_in' (default) or 'fan_out'. Choosing 'fan_in'
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing 'fan_out' preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (nn.functional name),
            recommended to use only with 'relu' or 'leaky_relu' (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)


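# Illustrative sketch (editor's addition): with mode='fan_out' the target
# std is gain / sqrt(fan_out); for a reasonably large tensor the empirical
# std should land close to it (the 0.01 tolerance is generous here).
def _example_kaiming_normal():
    w = kaiming_normal_(torch.empty(256, 128), mode='fan_out', nonlinearity='relu')
    target = math.sqrt(2.0) / math.sqrt(256)  # gain('relu') / sqrt(fan_out)
    assert abs(w.std().item() - target) < 0.01

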
def orthogonal_(tensor, gain=1):
    r"""Fills the input Tensor with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional torch.Tensor, where :math:`n \geq 2`
        gain: optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = tensor.size(0)
    cols = tensor[0].numel()
    flattened = tensor.new(rows, cols).normal_(0, 1)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor


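# Illustrative sketch (editor's addition): for a square matrix with gain=1,
# the rows of an orthogonal_-initialized weight are orthonormal, so
# w @ w.T recovers the identity up to float32 rounding.
def _example_orthogonal():
    w = orthogonal_(torch.empty(4, 4), gain=1)
    assert torch.allclose(w.mm(w.t()), torch.eye(4), atol=1e-5)

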
def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input Tensor as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: a 2-dimensional torch.Tensor
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor


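# Illustrative sketch (editor's addition): sparsity is applied per column,
# so every column ends up with at least ceil(sparsity * rows) zeros (the
# normal draws are nonzero with probability 1).
def _example_sparse():
    w = sparse_(torch.empty(10, 4), sparsity=0.25)
    zeros_per_col = (w == 0).sum(dim=0)
    assert (zeros_per_col >= int(math.ceil(0.25 * 10))).all()

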
# for backward compatibility
def _make_deprecate(meth):
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn("nn.init.{} is now deprecated in favor of nn.init.{}."
                      .format(old_name, new_name), stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = r"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    deprecated_init.__name__ = old_name
    return deprecated_init


uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
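

# Illustrative sketch (editor's addition): calling a deprecated alias emits
# a warning built by _make_deprecate and forwards to the trailing-underscore
# version, so the result is identical to calling, e.g., uniform_ directly.
def _example_deprecated_alias():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        w = uniform(torch.empty(2, 2))  # forwards to uniform_
    assert any("deprecated" in str(c.message) for c in caught)
    assert w.shape == (2, 2)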