from collections import OrderedDict
import functools
import itertools

import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
import torch.utils.hooks as hooks
def _addindent(s_, numSpaces):
    s = s_.split('\n')
    # don't do anything for single-line strings
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s
24 r"""Base class for all neural network modules. 26 Your models should also subclass this class. 28 Modules can also contain other Modules, allowing to nest them in 29 a tree structure. You can assign the submodules as regular attributes:: 32 import torch.nn.functional as F 34 class Model(nn.Module): 36 super(Model, self).__init__() 37 self.conv1 = nn.Conv2d(1, 20, 5) 38 self.conv2 = nn.Conv2d(20, 20, 5) 41 x = F.relu(self.conv1(x)) 42 return F.relu(self.conv2(x)) 44 Submodules assigned in this way will be registered, and will have their 45 parameters converted too when you call :meth:`to`, etc. 50 r"""This allows better BC support for :meth:`load_state_dict`. In 51 :meth:`state_dict`, the version number will be saved as in the attribute 52 `_metadata` of the returned state dict, and thus pickled. `_metadata` is a 53 dictionary with keys that follow the naming convention of state dict. See 54 ``_load_from_state_dict`` on how to use this information in loading. 56 If new parameters/buffers are added/removed from a module, this number shall 57 be bumped, and the module's `_load_from_state_dict` method can compare the 58 version number and do appropriate changes if the state dict is from before 74 def forward(self, *input):
75 r"""Defines the computation performed at every call. 77 Should be overridden by all subclasses. 80 Although the recipe for forward pass needs to be defined within 81 this function, one should call the :class:`Module` instance afterwards 82 instead of this since the former takes care of running the 83 registered hooks while the latter silently ignores them. 85 raise NotImplementedError
    def register_buffer(self, name, tensor):
        r"""Adds a persistent buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the module's persistent state.

        Buffers can be accessed as attributes using the given names.

        Args:
            name (string): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor): buffer to be registered.

        Example::

            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if '_buffers' not in self.__dict__:
            raise AttributeError(
                "cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("buffer name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("buffer name can't contain \".\"")
        elif name == '':
            raise KeyError("buffer name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError("attribute '{}' already exists".format(name))
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError("cannot assign '{}' object to buffer '{}' "
                            "(torch Tensor or None required)"
                            .format(torch.typename(tensor), name))
        else:
            self._buffers[name] = tensor

    def register_parameter(self, name, param):
        r"""Adds a parameter to the module.

        The parameter can be accessed as an attribute using the given name.

        Args:
            name (string): name of the parameter. The parameter can be accessed
                from this module using the given name
            param (Parameter): parameter to be added to the module.
        """
        if '_parameters' not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call")
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("parameter name should be a string. "
                            "Got {}".format(torch.typename(name)))
        elif '.' in name:
            raise KeyError("parameter name can't contain \".\"")
        elif name == '':
            raise KeyError("parameter name can't be empty string \"\"")
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError("attribute '{}' already exists".format(name))

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError("cannot assign '{}' object to parameter '{}' "
                            "(torch.nn.Parameter or None required)"
                            .format(torch.typename(param), name))
        elif param.grad_fn:
            raise ValueError(
                "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                "parameters must be created explicitly. To express '{0}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method.".format(name))
        else:
            self._parameters[name] = param
    def add_module(self, name, module):
        r"""Adds a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (string): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module): child module to be added to the module.
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        elif not isinstance(name, torch._six.string_classes):
            raise TypeError("module name should be a string. Got {}".format(
                torch.typename(name)))
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError("attribute '{}' already exists".format(name))
        elif '.' in name:
            raise KeyError("module name can't contain \".\"")
        elif name == '':
            raise KeyError("module name can't be empty string \"\"")
        self._modules[name] = module
    def _apply(self, fn):
        for module in self.children():
            module._apply(fn)

        for param in self._parameters.values():
            if param is not None:
                # Tensors stored in modules are graph leaves, and we don't
                # want to create copy nodes, so we have to unpack the data.
                param.data = fn(param.data)
                if param._grad is not None:
                    param._grad.data = fn(param._grad.data)

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def apply(self, fn):
        r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
        as well as self. Typical use includes initializing the parameters of a model
        (see also :ref:`torch-nn-init`).

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self

        Example::

            >>> def init_weights(m):
            >>>     print(m)
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.data.fill_(1.0)
            >>>         print(m.weight)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[ 1.,  1.],
                    [ 1.,  1.]])
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[ 1.,  1.],
                    [ 1.,  1.]])
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )

        """
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

    def cuda(self, device=None):
250 r"""Moves all model parameters and buffers to the GPU. 252 This also makes associated parameters and buffers different objects. So 253 it should be called before constructing optimizer if the module will 254 live on GPU while being optimized. 257 device (int, optional): if specified, all parameters will be 258 copied to that device 263 return self.
_apply(
lambda t: t.cuda(device))
266 r"""Moves all model parameters and buffers to the CPU. 271 return self.
_apply(
lambda t: t.cpu())
273 def type(self, dst_type):
274 r"""Casts all parameters and buffers to :attr:`dst_type`. 277 dst_type (type or string): the desired type 282 return self.
_apply(
lambda t: t.type(dst_type))
285 r"""Casts all floating point parameters and buffers to float datatype. 290 return self.
_apply(
lambda t: t.float()
if t.is_floating_point()
else t)
293 r"""Casts all floating point parameters and buffers to ``double`` datatype. 298 return self.
_apply(
lambda t: t.double()
if t.is_floating_point()
else t)
301 r"""Casts all floating point parameters and buffers to ``half`` datatype. 306 return self.
_apply(
lambda t: t.half()
if t.is_floating_point()
else t)
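    # Usage sketch (illustrative): because .cuda()/.to() rebuild the parameter
    # storage, construct the optimizer only after the move, as the cuda()
    # docstring above advises. `MyModel` is a hypothetical nn.Module subclass.
    #
    #     model = MyModel()
    #     if torch.cuda.is_available():
    #         model = model.cuda()                     # or model.to('cuda')
    #     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)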
    def to(self, *args, **kwargs):
        r"""Moves and/or casts the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)

        .. function:: to(dtype, non_blocking=False)

        .. function:: to(tensor, non_blocking=False)

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point desired :attr:`dtype` s. In addition, this method will
        only cast the floating point parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        See below for examples.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point type of
                the floating point parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module

        Returns:
            Module: self

        Example::

            >>> linear = nn.Linear(2, 2)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]])
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]], dtype=torch.float64)
            >>> gpu1 = torch.device("cuda:1")
            >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
            >>> cpu = torch.device("cpu")
            >>> linear.to(cpu)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16)

        """
        device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)

        if dtype is not None:
            if not dtype.is_floating_point:
                raise TypeError('nn.Module.to only accepts floating point '
                                'dtypes, but got desired dtype={}'.format(dtype))

        def convert(t):
            return t.to(device, dtype if t.is_floating_point() else None, non_blocking)

        return self._apply(convert)
    def register_backward_hook(self, hook):
        r"""Registers a backward hook on the module.

        The hook will be called every time the gradients with respect to module
        inputs are computed. The hook should have the following signature::

            hook(module, grad_input, grad_output) -> Tensor or None

        The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
        module has multiple inputs or outputs. The hook should not modify its
        arguments, but it can optionally return a new gradient with respect to
        input that will be used in place of :attr:`grad_input` in subsequent
        computations.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        .. warning ::

            The current implementation will not have the presented behavior
            for complex :class:`Module` that perform many operations.
            In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
            contain the gradients for a subset of the inputs and outputs.
            For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
            directly on a specific input or output to get the required gradients.

        """
        handle = hooks.RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

    def register_forward_pre_hook(self, hook):
        r"""Registers a forward pre-hook on the module.

        The hook will be called every time before :func:`forward` is invoked.
        It should have the following signature::

            hook(module, input) -> None

        The hook should not modify the input.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_pre_hooks)
        self._forward_pre_hooks[handle.id] = hook
        return handle

    def register_forward_hook(self, hook):
        r"""Registers a forward hook on the module.

        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::

            hook(module, input, output) -> None

        The hook should not modify the input or output.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

    def _tracing_name(self, tracing_state):
        if not tracing_state._traced_module_stack:
            return None
        module = tracing_state._traced_module_stack[-1]
        for name, child in module.named_children():
            if child is self:
                return name
        return None
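    # Usage sketch (illustrative) for register_forward_hook above: capturing an
    # intermediate activation without changing the model's forward(). The names
    # `model`, `conv1`, and `x` are hypothetical.
    #
    #     activations = {}
    #     def save_activation(module, input, output):
    #         activations['feat'] = output.detach()
    #     handle = model.conv1.register_forward_hook(save_activation)
    #     model(x)                 # the hook fires after conv1's forward
    #     handle.remove()          # detach the hook when no longer needed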
    def _slow_forward(self, *input, **kwargs):
        tracing_state = torch._C._get_tracing_state()
        if not tracing_state:
            return self.forward(*input, **kwargs)
        if not hasattr(tracing_state, '_traced_module_stack'):
            tracing_state._traced_module_stack = []
        name = self._tracing_name(tracing_state)
        if name:
            tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
        else:
            tracing_state.push_scope(self._get_name())
        tracing_state._traced_module_stack.append(self)
        try:
            result = self.forward(*input, **kwargs)
        finally:
            tracing_state.pop_scope()
            tracing_state._traced_module_stack.pop()
        return result
    def __call__(self, *input, **kwargs):
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}' "
                    "didn't return None".format(hook))
        if len(self._backward_hooks) > 0:
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in self._backward_hooks.values():
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result
    def __setstate__(self, state):
        self.__dict__.update(state)
        # Support loading old checkpoints that don't have the following attrs:
        if '_forward_pre_hooks' not in self.__dict__:
            self._forward_pre_hooks = OrderedDict()
        if '_state_dict_hooks' not in self.__dict__:
            self._state_dict_hooks = OrderedDict()
        if '_load_state_dict_pre_hooks' not in self.__dict__:
            self._load_state_dict_pre_hooks = OrderedDict()
    def __getattr__(self, name):
        if '_parameters' in self.__dict__:
            _parameters = self.__dict__['_parameters']
            if name in _parameters:
                return _parameters[name]
        if '_buffers' in self.__dict__:
            _buffers = self.__dict__['_buffers']
            if name in _buffers:
                return _buffers[name]
        if '_modules' in self.__dict__:
            modules = self.__dict__['_modules']
            if name in modules:
                return modules[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))
    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)
    def __delattr__(self, name):
        if name in self._parameters:
            del self._parameters[name]
        elif name in self._buffers:
            del self._buffers[name]
        elif name in self._modules:
            del self._modules[name]
        else:
            object.__delattr__(self, name)

    def _register_state_dict_hook(self, hook):
        r"""These hooks will be called with arguments: `self`, `state_dict`,
        `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
        Note that only parameters and buffers of `self` or its children are
        guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
        inplace or return a new one.
        """
        handle = hooks.RemovableHandle(self._state_dict_hooks)
        self._state_dict_hooks[handle.id] = hook
        return handle

    def state_dict(self, destination=None, prefix='', keep_vars=False):
605 r"""Returns a dictionary containing a whole state of the module. 607 Both parameters and persistent buffers (e.g. running averages) are 608 included. Keys are corresponding parameter and buffer names. 612 a dictionary containing a whole state of the module 616 >>> module.state_dict().keys() 620 if destination
is None:
621 destination = OrderedDict()
622 destination._metadata = OrderedDict()
623 destination._metadata[prefix[:-1]] = local_metadata = dict(version=self.
_version)
624 for name, param
in self._parameters.items():
625 if param
is not None:
626 destination[prefix + name] = param
if keep_vars
else param.data
627 for name, buf
in self._buffers.items():
629 destination[prefix + name] = buf
if keep_vars
else buf.data
630 for name, module
in self._modules.items():
631 if module
is not None:
632 module.state_dict(destination, prefix + name +
'.', keep_vars=keep_vars)
633 for hook
in self._state_dict_hooks.values():
634 hook_result = hook(self, destination, prefix, local_metadata)
635 if hook_result
is not None:
636 destination = hook_result
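    # Usage sketch (illustrative): state_dict() is the standard way to
    # checkpoint a module; the file name and `MyModel` below are hypothetical.
    #
    #     torch.save(model.state_dict(), 'checkpoint.pt')
    #     ...
    #     model = MyModel()
    #     model.load_state_dict(torch.load('checkpoint.pt'))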
    def _register_load_state_dict_pre_hook(self, hook):
        r"""These hooks will be called with arguments: `state_dict`, `prefix`,
        `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
        `error_msgs`, before loading `state_dict` into `self`. These arguments
        are exactly the same as those of `_load_from_state_dict`.
        """
        handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
        self._load_state_dict_pre_hooks[handle.id] = hook
        return handle

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in the input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`, so
            it can be modified.

        Arguments:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
        local_state = {k: v.data for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]

                # Backward compatibility: loading 1-dim tensors as 0-dim parameters
                if len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                      'the shape in current model is {}.'
                                      .format(key, input_param.shape, param.shape))
                    continue

                if isinstance(input_param, Parameter):
                    # backwards compatibility for serialized parameters
                    input_param = input_param.data
                try:
                    param.copy_(input_param)
                except Exception:
                    error_msgs.append('While copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
                                      'whose dimensions in the checkpoint are {}.'
                                      .format(key, param.size(), input_param.size()))
            elif strict:
                missing_keys.append(key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix):
                    input_name = key[len(prefix):]
                    input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)
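    # Usage sketch (illustrative): a subclass doing version-aware loading, as the
    # docstring above suggests via `local_metadata.get("version", None)`. The key
    # rename below is hypothetical.
    #
    #     class MyLayer(nn.Module):
    #         _version = 2
    #
    #         def _load_from_state_dict(self, state_dict, prefix, local_metadata,
    #                                   strict, missing_keys, unexpected_keys, error_msgs):
    #             if local_metadata.get('version', None) is None:
    #                 # version-1 checkpoints stored the weight under an old name
    #                 old_key, new_key = prefix + 'gamma', prefix + 'weight'
    #                 if old_key in state_dict:
    #                     state_dict[new_key] = state_dict.pop(old_key)
    #             super(MyLayer, self)._load_from_state_dict(
    #                 state_dict, prefix, local_metadata, strict,
    #                 missing_keys, unexpected_keys, error_msgs)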
    def load_state_dict(self, state_dict, strict=True):
        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.

        Arguments:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
        """
        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(self)

        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0, 'Unexpected key(s) in state_dict: {}. '.format(
                        ', '.join('"{}"'.format(k) for k in unexpected_keys)))
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0, 'Missing key(s) in state_dict: {}. '.format(
                        ', '.join('"{}"'.format(k) for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                               self.__class__.__name__, "\n\t".join(error_msgs)))
    def _named_members(self, get_members_fn, prefix='', recurse=True):
        r"""Helper method for yielding various names + members of modules."""
        memo = set()
        modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
        for module_prefix, module in modules:
            members = get_members_fn(module)
            for k, v in members:
                if v is None or v in memo:
                    continue
                memo.add(v)
                name = module_prefix + ('.' if module_prefix else '') + k
                yield name, v

    def parameters(self, recurse=True):
        r"""Returns an iterator over module parameters.

        This is typically passed to an optimizer.

        Args:
            recurse (bool): if True, then yields parameters of this module
                and all submodules. Otherwise, yields only parameters that
                are direct members of this module.

        Yields:
            Parameter: module parameter

        Example::

            >>> for param in model.parameters():
            >>>     print(type(param.data), param.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

        """
        for name, param in self.named_parameters(recurse=recurse):
            yield param

    def named_parameters(self, prefix='', recurse=True):
811 r"""Returns an iterator over module parameters, yielding both the 812 name of the parameter as well as the parameter itself. 815 prefix (str): prefix to prepend to all parameter names. 816 recurse (bool): if True, then yields parameters of this module 817 and all submodules. Otherwise, yields only parameters that 818 are direct members of this module. 821 (string, Parameter): Tuple containing the name and parameter 825 >>> for name, param in self.named_parameters(): 826 >>> if name in ['bias']: 827 >>> print(param.size()) 831 lambda module: module._parameters.items(),
832 prefix=prefix, recurse=recurse)
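    # Usage sketch (illustrative): named_parameters() makes it easy to build
    # optimizer parameter groups, e.g. excluding biases from weight decay.
    # `model` is a hypothetical nn.Module instance.
    #
    #     decay, no_decay = [], []
    #     for name, param in model.named_parameters():
    #         (no_decay if name.endswith('bias') else decay).append(param)
    #     optimizer = torch.optim.SGD(
    #         [{'params': decay, 'weight_decay': 1e-4},
    #          {'params': no_decay, 'weight_decay': 0.0}], lr=0.1)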
    def buffers(self, recurse=True):
        r"""Returns an iterator over module buffers.

        Args:
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            torch.Tensor: module buffer

        Example::

            >>> for buf in model.buffers():
            >>>     print(type(buf.data), buf.size())
            <class 'torch.FloatTensor'> (20L,)
            <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

        """
        for name, buf in self.named_buffers(recurse=recurse):
            yield buf

    def named_buffers(self, prefix='', recurse=True):
        r"""Returns an iterator over module buffers, yielding both the
        name of the buffer as well as the buffer itself.

        Args:
            prefix (str): prefix to prepend to all buffer names.
            recurse (bool): if True, then yields buffers of this module
                and all submodules. Otherwise, yields only buffers that
                are direct members of this module.

        Yields:
            (string, torch.Tensor): Tuple containing the name and buffer

        Example::

            >>> for name, buf in self.named_buffers():
            >>>     if name in ['running_var']:
            >>>         print(buf.size())

        """
        gen = self._named_members(
            lambda module: module._buffers.items(),
            prefix=prefix, recurse=recurse)
        for elem in gen:
            yield elem
885 r"""Returns an iterator over immediate children modules. 888 Module: a child module 893 def named_children(self):
894 r"""Returns an iterator over immediate children modules, yielding both 895 the name of the module as well as the module itself. 898 (string, Module): Tuple containing a name and child module 902 >>> for name, module in model.named_children(): 903 >>> if name in ['conv4', 'conv5']: 908 for name, module
in self._modules.items():
909 if module
is not None and module
not in memo:
914 r"""Returns an iterator over all modules in the network. 917 Module: a module in the network 920 Duplicate modules are returned only once. In the following 921 example, ``l`` will be returned only once. 925 >>> l = nn.Linear(2, 2) 926 >>> net = nn.Sequential(l, l) 927 >>> for idx, m in enumerate(net.modules()): 931 (0): Linear(in_features=2, out_features=2, bias=True) 932 (1): Linear(in_features=2, out_features=2, bias=True) 934 1 -> Linear(in_features=2, out_features=2, bias=True) 940 def named_modules(self, memo=None, prefix=''):
941 r"""Returns an iterator over all modules in the network, yielding 942 both the name of the module as well as the module itself. 945 (string, Module): Tuple of name and module 948 Duplicate modules are returned only once. In the following 949 example, ``l`` will be returned only once. 953 >>> l = nn.Linear(2, 2) 954 >>> net = nn.Sequential(l, l) 955 >>> for idx, m in enumerate(net.named_modules()): 958 0 -> ('', Sequential( 959 (0): Linear(in_features=2, out_features=2, bias=True) 960 (1): Linear(in_features=2, out_features=2, bias=True) 962 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) 971 for name, module
in self._modules.items():
974 submodule_prefix = prefix + (
'.' if prefix
else '') + name
975 for m
in module.named_modules(memo, submodule_prefix):
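    # Usage sketch (illustrative): named_modules()/named_children() are handy for
    # selecting a subtree, e.g. freezing a backbone by name. The names `model`
    # and 'backbone' are hypothetical.
    #
    #     for name, child in model.named_children():
    #         if name == 'backbone':
    #             for p in child.parameters():
    #                 p.requires_grad = False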
    def train(self, mode=True):
        r"""Sets the module in training mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        r"""Sets the module in evaluation mode.

        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Returns:
            Module: self
        """
        return self.train(False)
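    # Usage sketch (illustrative): evaluation typically pairs eval() with
    # torch.no_grad() so dropout/batch-norm behave deterministically and no
    # autograd state is recorded. `model` and `inputs` are hypothetical.
    #
    #     model.eval()
    #     with torch.no_grad():
    #         predictions = model(inputs)
    #     model.train()   # switch back before resuming training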
    def zero_grad(self):
        r"""Sets gradients of all model parameters to zero."""
        for p in self.parameters():
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
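    # Usage sketch (illustrative): zero_grad() is called once per optimization
    # step so gradients from the previous step don't accumulate. `loader`,
    # `criterion`, and `optimizer` are hypothetical.
    #
    #     for inputs, targets in loader:
    #         model.zero_grad()               # or optimizer.zero_grad()
    #         loss = criterion(model(inputs), targets)
    #         loss.backward()
    #         optimizer.step()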
    def share_memory(self):
        return self._apply(lambda t: t.share_memory_())

    def _get_name(self):
        return self.__class__.__name__
    def extra_repr(self):
        r"""Set the extra representation of the module.

        To print customized extra information, you should re-implement
        this method in your own modules. Both single-line and multi-line
        strings are acceptable.
        """
        return ''

    def __repr__(self):
        # We treat the extra repr like a sub-module, one item per line
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)
            mod_str = _addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str

    def __dir__(self):
        module_attrs = dir(self.__class__)
        attrs = list(self.__dict__.keys())
        parameters = list(self._parameters.keys())
        modules = list(self._modules.keys())
        buffers = list(self._buffers.keys())
        keys = module_attrs + attrs + parameters + modules + buffers

        # Eliminate attrs that are not legal Python variable names
        keys = [key for key in keys if not key[0].isdigit()]

        return sorted(keys)
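    # Usage sketch (illustrative): overriding extra_repr so that repr(module)
    # shows the module's configuration, similar to what nn.Linear does.
    # `MyAffine` is a hypothetical module.
    #
    #     class MyAffine(nn.Module):
    #         def __init__(self, num_features):
    #             super(MyAffine, self).__init__()
    #             self.num_features = num_features
    #
    #         def extra_repr(self):
    #             return 'num_features={}'.format(self.num_features)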