Caffe2 - Python API
A deep learning, cross-platform ML framework
init.py
import math
import random
import warnings

import torch


def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.
    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky ReLU        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu')
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif (not isinstance(param, bool) and isinstance(param, int)) or isinstance(param, float):
            # True/False are instances of int, hence the bool check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
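
# Usage sketch (not part of the original module): the gains returned above,
# spelled out for a few nonlinearities. `param` is the negative slope of the
# leaky ReLU and defaults to 0.01.
def _demo_calculate_gain():
    assert calculate_gain('linear') == 1
    assert calculate_gain('tanh') == 5.0 / 3
    assert calculate_gain('relu') == math.sqrt(2.0)
    # Default negative slope 0.01: gain = sqrt(2 / (1 + 0.01 ** 2))
    assert calculate_gain('leaky_relu') == math.sqrt(2.0 / (1 + 0.01 ** 2))
    assert calculate_gain('leaky_relu', 0.2) == math.sqrt(2.0 / (1 + 0.2 ** 2))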


def uniform_(tensor, a=0, b=1):
    r"""Fills the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    with torch.no_grad():
        return tensor.uniform_(a, b)


def normal_(tensor, mean=0, std=1):
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    with torch.no_grad():
        return tensor.normal_(mean, std)


def constant_(tensor, val):
    r"""Fills the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    with torch.no_grad():
        return tensor.fill_(val)


def ones_(tensor):
    r"""Fills the input Tensor with ones.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    with torch.no_grad():
        return tensor.fill_(1)


def zeros_(tensor):
    r"""Fills the input Tensor with zeros.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    with torch.no_grad():
        return tensor.zero_()
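
# Usage sketch (not part of the original module): each fill helper above
# mutates its argument in place inside `torch.no_grad()` and returns the
# same tensor, so calls compose directly with `torch.empty`.
def _demo_basic_fills():
    w = torch.empty(3, 5)
    assert uniform_(w, a=-1, b=1) is w   # in-place: the same object comes back
    normal_(w, mean=0, std=0.02)
    constant_(w, 0.3)
    assert ones_(w).sum().item() == 15
    assert zeros_(w).sum().item() == 0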


def eye_(tensor):
    r"""Fills the 2-dimensional input `Tensor` with the identity
    matrix. Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    with torch.no_grad():
        torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
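
# Usage sketch (not part of the original module): an `eye_`-initialized
# weight used as a bias-free Linear map passes the first
# min(out_features, in_features) inputs through unchanged.
def _demo_eye():
    w = torch.empty(3, 5)
    eye_(w)
    x = torch.randn(5)
    assert torch.allclose(w @ x, x[:3])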


def dirac_(tensor):
    r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac
    delta function. Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()
    min_dim = min(sizes[0], sizes[1])
    with torch.no_grad():
        tensor.zero_()

        for d in range(min_dim):
            if dimensions == 3:  # Temporal convolution
                tensor[d, d, tensor.size(2) // 2] = 1
            elif dimensions == 4:  # Spatial convolution
                tensor[d, d, tensor.size(2) // 2, tensor.size(3) // 2] = 1
            else:  # Volumetric convolution
                tensor[d, d, tensor.size(2) // 2, tensor.size(3) // 2, tensor.size(4) // 2] = 1
    return tensor
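
# Usage sketch (not part of the original module): a `dirac_`-initialized
# Conv2d with "same" padding and no bias acts as the identity on the first
# min(out_channels, in_channels) channels.
def _demo_dirac():
    import torch.nn as nn
    conv = nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False)
    dirac_(conv.weight)
    x = torch.randn(1, 16, 8, 8)
    assert torch.allclose(conv(x), x)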


def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.ndimension()
    if dimensions < 2:
        raise ValueError("Fan in and fan out cannot be computed for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        receptive_field_size = 1
        if tensor.dim() > 2:
            receptive_field_size = tensor[0][0].numel()
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out
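
# Sketch (not part of the original module): fan computation for a Linear
# weight of shape (out_features, in_features) and a Conv2d weight of shape
# (out_channels, in_channels, kH, kW), where the receptive field size is
# kH * kW.
def _demo_fan_in_and_fan_out():
    assert _calculate_fan_in_and_fan_out(torch.empty(10, 20)) == (20, 10)
    # Conv2d: fan_in = 4 * 3 * 3 = 36, fan_out = 8 * 3 * 3 = 72
    assert _calculate_fan_in_and_fan_out(torch.empty(8, 4, 3, 3)) == (36, 72)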


def xavier_uniform_(tensor, gain=1):
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-a, a)
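
# Sketch (not part of the original module): all samples stay inside the
# bound a = gain * sqrt(6 / (fan_in + fan_out)), and the empirical standard
# deviation lands near gain * sqrt(2 / (fan_in + fan_out)).
def _demo_xavier_uniform():
    w = torch.empty(256, 128)          # fan_in = 128, fan_out = 256
    xavier_uniform_(w)
    bound = math.sqrt(6.0 / (128 + 256))
    assert w.abs().max().item() <= bound
    expected_std = math.sqrt(2.0 / (128 + 256))
    assert abs(w.std().item() - expected_std) < 0.05 * expected_std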


def xavier_normal_(tensor, gain=1):
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    with torch.no_grad():
        return tensor.normal_(0, std)
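
# Sketch (not part of the original module): combining `calculate_gain` with
# `xavier_normal_`, as the `xavier_uniform_` docstring suggests; the gain
# simply scales the standard deviation.
def _demo_xavier_normal():
    w = torch.empty(256, 128)
    xavier_normal_(w, gain=calculate_gain('tanh'))
    expected_std = (5.0 / 3) * math.sqrt(2.0 / (128 + 256))
    assert abs(w.std().item() - expected_std) < 0.05 * expected_std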


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan\_in}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (0 for ReLU
            by default)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backward pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
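
# Sketch (not part of the original module): with nonlinearity='relu' the
# gain is sqrt(2), so the bound reduces to sqrt(6 / fan_in).
def _demo_kaiming_uniform():
    w = torch.empty(256, 128)          # fan_in = 128 under mode='fan_in'
    kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    assert w.abs().max().item() <= math.sqrt(6.0 / 128)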


def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (0 for ReLU
            by default)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backward pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)
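
# Sketch (not part of the original module): the property the He scheme is
# built for. With mode='fan_in' and a ReLU, the second moment of
# relu(x @ W.t()) stays close to the variance of the unit-variance input.
def _demo_kaiming_normal():
    w = torch.empty(512, 512)
    kaiming_normal_(w, mode='fan_in', nonlinearity='relu')
    x = torch.randn(1024, 512)
    y = torch.relu(x @ w.t())
    assert abs((y ** 2).mean().item() - 1.0) < 0.1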


def orthogonal_(tensor, gain=1):
    r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = tensor.size(0)
    cols = tensor[0].numel()
    flattened = tensor.new(rows, cols).normal_(0, 1)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
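
# Sketch (not part of the original module): with rows < cols the rows of the
# result are orthonormal, so w @ w.t() is the identity (up to float error).
def _demo_orthogonal():
    w = torch.empty(3, 5)
    orthogonal_(w)
    assert torch.allclose(w @ w.t(), torch.eye(3), atol=1e-5)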


def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: the fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
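
# Sketch (not part of the original module): every column receives exactly
# ceil(sparsity * rows) explicit zeros; the remaining entries are sampled
# from a normal distribution and are nonzero in practice.
def _demo_sparse():
    w = torch.empty(10, 4)
    sparse_(w, sparsity=0.25)
    zeros_per_col = (w == 0).sum(dim=0)
    assert (zeros_per_col == int(math.ceil(0.25 * 10))).all()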


# for backward compatibility
def _make_deprecate(meth):
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn("nn.init.{} is now deprecated in favor of nn.init.{}."
                      .format(old_name, new_name), stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = r"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    deprecated_init.__name__ = old_name
    return deprecated_init


uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
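
# Sketch (not part of the original module): the old, underscore-free names
# still work but warn on each call, pointing at the in-place replacements.
def _demo_deprecated_alias():
    w = torch.empty(3, 5)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        uniform(w)   # dispatches to uniform_
    assert "deprecated in favor of nn.init.uniform_" in str(caught[0].message)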