import math
import warnings

import torch


def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.
    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu')
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d',
                  'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence the explicit bool check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def uniform_(tensor, a=0, b=1):
    r"""Fills the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    with torch.no_grad():
        return tensor.uniform_(a, b)


def normal_(tensor, mean=0, std=1):
    r"""Fills the input Tensor with values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std})`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    with torch.no_grad():
        return tensor.normal_(mean, std)


def constant_(tensor, val):
    r"""Fills the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    with torch.no_grad():
        return tensor.fill_(val)
100 r"""Fills the input Tensor with ones`. 103 tensor: an n-dimensional `torch.Tensor` 106 >>> w = torch.empty(3, 5) 109 with torch.no_grad():
110 return tensor.fill_(1)
114 r"""Fills the input Tensor with zeros`. 117 tensor: an n-dimensional `torch.Tensor` 120 >>> w = torch.empty(3, 5) 121 >>> nn.init.zeros_(w) 123 with torch.no_grad():
124 return tensor.zero_()
128 r"""Fills the 2-dimensional input `Tensor` with the identity 129 matrix. Preserves the identity of the inputs in `Linear` layers, where as 130 many inputs are preserved as possible. 133 tensor: a 2-dimensional `torch.Tensor` 136 >>> w = torch.empty(3, 5) 139 if tensor.ndimension() != 2:
140 raise ValueError(
"Only tensors with 2 dimensions are supported")
142 with torch.no_grad():
143 torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
148 r"""Fills the {3, 4, 5}-dimensional input `Tensor` with the Dirac 149 delta function. Preserves the identity of the inputs in `Convolutional` 150 layers, where as many input channels are preserved as possible. 153 tensor: a {3, 4, 5}-dimensional `torch.Tensor` 156 >>> w = torch.empty(3, 16, 5, 5) 157 >>> nn.init.dirac_(w) 159 dimensions = tensor.ndimension()
160 if dimensions
not in [3, 4, 5]:
161 raise ValueError(
"Only tensors with 3, 4, or 5 dimensions are supported")
163 sizes = tensor.size()
164 min_dim = min(sizes[0], sizes[1])
165 with torch.no_grad():
168 for d
in range(min_dim):
170 tensor[d, d, tensor.size(2) // 2] = 1
171 elif dimensions == 4:
172 tensor[d, d, tensor.size(2) // 2, tensor.size(3) // 2] = 1
174 tensor[d, d, tensor.size(2) // 2, tensor.size(3) // 2, tensor.size(4) // 2] = 1
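
# For example, for a 4-D conv weight w of shape (out_channels, in_channels, kH, kW),
# dirac_(w) zeroes w and sets w[d, d, kH // 2, kW // 2] = 1 for each
# d < min(out_channels, in_channels), so each of those input channels is passed
# through unchanged to the corresponding output channel.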


def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.ndimension()
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
    else:
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        receptive_field_size = 1
        if tensor.dim() > 2:
            receptive_field_size = tensor[0][0].numel()
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out
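
# E.g. a Conv2d weight of shape (16, 8, 3, 3) has a receptive field of 3 * 3 = 9
# elements, giving fan_in = 8 * 9 = 72 and fan_out = 16 * 9 = 144; a Linear
# weight of shape (3, 5) gives fan_in = 5 and fan_out = 3.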


def xavier_uniform_(tensor, gain=1):
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    a = math.sqrt(3.0) * std  # calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-a, a)


def xavier_normal_(tensor, gain=1):
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std})` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    with torch.no_grad():
        return tensor.normal_(0, std)
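
# Worked example for a Linear weight of shape (3, 5) with gain = 1:
# fan_in = 5 and fan_out = 3, so xavier_normal_ uses std = sqrt(2 / 8) = 0.5 and
# xavier_uniform_ samples from U(-a, a) with a = sqrt(3) * 0.5 ~= 0.866.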


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out


def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{6}{(1 + a^2) \times \text{fan\_in}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (0 for ReLU
            by default)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
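
# Worked example: for a (3, 5) weight with mode='fan_in' and nonlinearity='relu',
# fan = 5 and gain = sqrt(2), so std = sqrt(2 / 5) ~= 0.632 and the samples are
# drawn from U(-bound, bound) with bound = sqrt(3) * std ~= 1.095.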


def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std})` where

    .. math::
        \text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (0 for ReLU
            by default)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)


def orthogonal_(tensor, gain=1):
    r"""Fills the input `Tensor` with a (semi) orthogonal matrix, as
    described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    rows = tensor.size(0)
    cols = tensor[0].numel()
    flattened = tensor.new(rows, cols).normal_(0, 1)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.qr(flattened)
    # Make Q uniform by sign-correcting with the diagonal of R
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
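
# Note: for a (3, 5) weight, rows (3) < cols (5), so the flattened matrix is
# transposed before the QR factorization and q is transposed back afterwards,
# yielding a semi-orthogonal matrix whose rows are orthonormal (then scaled by gain).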


def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
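
# E.g. sparse_(torch.empty(3, 5), sparsity=0.1) first fills the tensor from a
# normal distribution with std = 0.01, then zeroes ceil(0.1 * 3) = 1 randomly
# chosen row index in each of the 5 columns.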


def _make_deprecate(meth):
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn("nn.init.{} is now deprecated in favor of nn.init.{}."
                      .format(old_name, new_name), stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = r"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    deprecated_init.__name__ = old_name
    return deprecated_init


uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
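
# Calling one of the old names, e.g. nn.init.xavier_uniform(w), emits a
# deprecation warning via warnings.warn and then forwards to the
# trailing-underscore version (nn.init.xavier_uniform_).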