Caffe2 - Python API
A deep learning, cross-platform ML framework
module.py
1 from collections import OrderedDict
2 import functools
3 import itertools
4 
5 import torch
6 from ..backends.thnn import backend as thnn_backend
7 from ..parameter import Parameter
8 import torch.utils.hooks as hooks
9 
10 
11 def _addindent(s_, numSpaces):
12  s = s_.split('\n')
13  # don't do anything for single-line stuff
14  if len(s) == 1:
15  return s_
16  first = s.pop(0)
17  s = [(numSpaces * ' ') + line for line in s]
18  s = '\n'.join(s)
19  s = first + '\n' + s
20  return s
21 
22 
23 class Module(object):
24  r"""Base class for all neural network modules.
25 
26  Your models should also subclass this class.
27 
28  Modules can also contain other Modules, allowing you to nest them in
29  a tree structure. You can assign the submodules as regular attributes::
30 
31  import torch.nn as nn
32  import torch.nn.functional as F
33 
34  class Model(nn.Module):
35  def __init__(self):
36  super(Model, self).__init__()
37  self.conv1 = nn.Conv2d(1, 20, 5)
38  self.conv2 = nn.Conv2d(20, 20, 5)
39 
40  def forward(self, x):
41  x = F.relu(self.conv1(x))
42  return F.relu(self.conv2(x))
43 
44  Submodules assigned in this way will be registered, and will have their
45  parameters converted too when you call :meth:`to`, etc.
46  """
47 
48  dump_patches = False
49 
50  r"""This allows better BC support for :meth:`load_state_dict`. In
51  :meth:`state_dict`, the version number will be saved in the attribute
52  `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
53  dictionary with keys that follow the naming convention of state dict. See
54  ``_load_from_state_dict`` on how to use this information in loading.
55 
56  If new parameters/buffers are added/removed from a module, this number shall
57  be bumped, and the module's `_load_from_state_dict` method can compare the
58  version number and do appropriate changes if the state dict is from before
59  the change."""
60  _version = 1
61 
62  def __init__(self):
63  self._backend = thnn_backend
64  self._parameters = OrderedDict()
65  self._buffers = OrderedDict()
66  self._backward_hooks = OrderedDict()
67  self._forward_hooks = OrderedDict()
68  self._forward_pre_hooks = OrderedDict()
69  self._state_dict_hooks = OrderedDict()
70  self._load_state_dict_pre_hooks = OrderedDict()
71  self._modules = OrderedDict()
72  self.training = True
73 
74  def forward(self, *input):
75  r"""Defines the computation performed at every call.
76 
77  Should be overridden by all subclasses.
78 
79  .. note::
80  Although the recipe for the forward pass needs to be defined within
81  this function, one should call the :class:`Module` instance itself
82  instead of this method directly, since the former takes care of running the
83  registered hooks while the latter silently ignores them.
84  """
85  raise NotImplementedError
86 
87  def register_buffer(self, name, tensor):
88  r"""Adds a persistent buffer to the module.
89 
90  This is typically used to register a buffer that should not be
91  considered a model parameter. For example, BatchNorm's ``running_mean``
92  is not a parameter, but is part of the persistent state.
93 
94  Buffers can be accessed as attributes using the given names.
95 
96  Args:
97  name (string): name of the buffer. The buffer can be accessed
98  from this module using the given name
99  tensor (Tensor): buffer to be registered.
100 
101  Example::
102 
103  >>> self.register_buffer('running_mean', torch.zeros(num_features))
104 
105  """
106  if '_buffers' not in self.__dict__:
107  raise AttributeError(
108  "cannot assign buffer before Module.__init__() call")
109  elif not isinstance(name, torch._six.string_classes):
110  raise TypeError("buffer name should be a string. "
111  "Got {}".format(torch.typename(name)))
112  elif '.' in name:
113  raise KeyError("buffer name can't contain \".\"")
114  elif name == '':
115  raise KeyError("buffer name can't be empty string \"\"")
116  elif hasattr(self, name) and name not in self._buffers:
117  raise KeyError("attribute '{}' already exists".format(name))
118  elif tensor is not None and not isinstance(tensor, torch.Tensor):
119  raise TypeError("cannot assign '{}' object to buffer '{}' "
120  "(torch Tensor or None required)"
121  .format(torch.typename(tensor), name))
122  else:
123  self._buffers[name] = tensor
124 
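# Illustrative usage sketch (not part of the original module.py): registering a
# persistent buffer. The class name `RunningMean` and the feature size are
# hypothetical; the pattern mirrors how BatchNorm keeps its running statistics.
import torch
import torch.nn as nn

class RunningMean(nn.Module):
    def __init__(self, num_features, momentum=0.1):
        super(RunningMean, self).__init__()
        self.momentum = momentum
        # a buffer is saved in state_dict() and moved by .to()/.cuda(),
        # but it is not a Parameter and receives no gradient updates
        self.register_buffer('mean', torch.zeros(num_features))

    def forward(self, x):
        if self.training:
            self.mean = (1 - self.momentum) * self.mean + self.momentum * x.mean(0)
        return x - self.mean

m = RunningMean(4)
print(list(m.parameters()))    # [] -- buffers are not parameters
print(m.state_dict().keys())   # odict_keys(['mean'])
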
125  def register_parameter(self, name, param):
126  r"""Adds a parameter to the module.
127 
128  The parameter can be accessed as an attribute using the given name.
129 
130  Args:
131  name (string): name of the parameter. The parameter can be accessed
132  from this module using the given name
133  param (Parameter): parameter to be added to the module.
134  """
135  if '_parameters' not in self.__dict__:
136  raise AttributeError(
137  "cannot assign parameter before Module.__init__() call")
138 
139  elif not isinstance(name, torch._six.string_classes):
140  raise TypeError("parameter name should be a string. "
141  "Got {}".format(torch.typename(name)))
142  elif '.' in name:
143  raise KeyError("parameter name can't contain \".\"")
144  elif name == '':
145  raise KeyError("parameter name can't be empty string \"\"")
146  elif hasattr(self, name) and name not in self._parameters:
147  raise KeyError("attribute '{}' already exists".format(name))
148 
149  if param is None:
150  self._parameters[name] = None
151  elif not isinstance(param, Parameter):
152  raise TypeError("cannot assign '{}' object to parameter '{}' "
153  "(torch.nn.Parameter or None required)"
154  .format(torch.typename(param), name))
155  elif param.grad_fn:
156  raise ValueError(
157  "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
158  "parameters must be created explicitly. To express '{0}' "
159  "as a function of another Tensor, compute the value in "
160  "the forward() method.".format(name))
161  else:
162  self._parameters[name] = param
163 
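# Illustrative usage sketch (not part of the original module.py): the two
# equivalent ways to register a parameter. `MyLinear` and its sizes are
# hypothetical.
import torch
import torch.nn as nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyLinear, self).__init__()
        # assigning a Parameter attribute goes through __setattr__,
        # which calls register_parameter() under the hood
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        # explicit registration; None is allowed and records "no bias"
        self.register_parameter('bias', None)

    def forward(self, input):
        return input.matmul(self.weight.t())

m = MyLinear(3, 2)
print([name for name, _ in m.named_parameters()])  # ['weight']
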
164  def add_module(self, name, module):
165  r"""Adds a child module to the current module.
166 
167  The module can be accessed as an attribute using the given name.
168 
169  Args:
170  name (string): name of the child module. The child module can be
171  accessed from this module using the given name
172  module (Module): child module to be added to the module.
173  """
174  if not isinstance(module, Module) and module is not None:
175  raise TypeError("{} is not a Module subclass".format(
176  torch.typename(module)))
177  elif not isinstance(name, torch._six.string_classes):
178  raise TypeError("module name should be a string. Got {}".format(
179  torch.typename(name)))
180  elif hasattr(self, name) and name not in self._modules:
181  raise KeyError("attribute '{}' already exists".format(name))
182  elif '.' in name:
183  raise KeyError("module name can't contain \".\"")
184  elif name == '':
185  raise KeyError("module name can't be empty string \"\"")
186  self._modules[name] = module
187 
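# Illustrative usage sketch (not part of the original module.py): add_module()
# is handy when child names are built at runtime. `Stack` and the layer names
# are hypothetical.
import torch.nn as nn

class Stack(nn.Module):
    def __init__(self, num_layers, width):
        super(Stack, self).__init__()
        for i in range(num_layers):
            # equivalent to setattr(self, 'layer%d' % i, nn.Linear(width, width))
            self.add_module('layer%d' % i, nn.Linear(width, width))

    def forward(self, x):
        for child in self.children():
            x = child(x)
        return x

net = Stack(3, 8)
print([name for name, _ in net.named_children()])  # ['layer0', 'layer1', 'layer2']
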
188  def _apply(self, fn):
189  for module in self.children():
190  module._apply(fn)
191 
192  for param in self._parameters.values():
193  if param is not None:
194  # Tensors stored in modules are graph leaves, and we don't
195  # want to create copy nodes, so we have to unpack the data.
196  param.data = fn(param.data)
197  if param._grad is not None:
198  param._grad.data = fn(param._grad.data)
199 
200  for key, buf in self._buffers.items():
201  if buf is not None:
202  self._buffers[key] = fn(buf)
203 
204  return self
205 
206  def apply(self, fn):
207  r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
208  as well as self. Typical use includes initializing the parameters of a model
209  (see also :ref:`torch-nn-init`).
210 
211  Args:
212  fn (:class:`Module` -> None): function to be applied to each submodule
213 
214  Returns:
215  Module: self
216 
217  Example::
218 
219  >>> def init_weights(m):
220  print(m)
221  if type(m) == nn.Linear:
222  m.weight.data.fill_(1.0)
223  print(m.weight)
224 
225  >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
226  >>> net.apply(init_weights)
227  Linear(in_features=2, out_features=2, bias=True)
228  Parameter containing:
229  tensor([[ 1., 1.],
230  [ 1., 1.]])
231  Linear(in_features=2, out_features=2, bias=True)
232  Parameter containing:
233  tensor([[ 1., 1.],
234  [ 1., 1.]])
235  Sequential(
236  (0): Linear(in_features=2, out_features=2, bias=True)
237  (1): Linear(in_features=2, out_features=2, bias=True)
238  )
239  Sequential(
240  (0): Linear(in_features=2, out_features=2, bias=True)
241  (1): Linear(in_features=2, out_features=2, bias=True)
242  )
243  """
244  for module in self.children():
245  module.apply(fn)
246  fn(self)
247  return self
248 
249  def cuda(self, device=None):
250  r"""Moves all model parameters and buffers to the GPU.
251 
252  This also makes associated parameters and buffers different objects. So
253  it should be called before constructing the optimizer if the module will
254  live on the GPU while being optimized.
255 
256  Arguments:
257  device (int, optional): if specified, all parameters will be
258  copied to that device
259 
260  Returns:
261  Module: self
262  """
263  return self._apply(lambda t: t.cuda(device))
264 
265  def cpu(self):
266  r"""Moves all model parameters and buffers to the CPU.
267 
268  Returns:
269  Module: self
270  """
271  return self._apply(lambda t: t.cpu())
272 
273  def type(self, dst_type):
274  r"""Casts all parameters and buffers to :attr:`dst_type`.
275 
276  Arguments:
277  dst_type (type or string): the desired type
278 
279  Returns:
280  Module: self
281  """
282  return self._apply(lambda t: t.type(dst_type))
283 
284  def float(self):
285  r"""Casts all floating point parameters and buffers to float datatype.
286 
287  Returns:
288  Module: self
289  """
290  return self._apply(lambda t: t.float() if t.is_floating_point() else t)
291 
292  def double(self):
293  r"""Casts all floating point parameters and buffers to ``double`` datatype.
294 
295  Returns:
296  Module: self
297  """
298  return self._apply(lambda t: t.double() if t.is_floating_point() else t)
299 
300  def half(self):
301  r"""Casts all floating point parameters and buffers to ``half`` datatype.
302 
303  Returns:
304  Module: self
305  """
306  return self._apply(lambda t: t.half() if t.is_floating_point() else t)
307 
308  def to(self, *args, **kwargs):
309  r"""Moves and/or casts the parameters and buffers.
310 
311  This can be called as
312 
313  .. function:: to(device=None, dtype=None, non_blocking=False)
314 
315  .. function:: to(dtype, non_blocking=False)
316 
317  .. function:: to(tensor, non_blocking=False)
318 
319  Its signature is similar to :meth:`torch.Tensor.to`, but it only accepts
320  a floating point desired :attr:`dtype`. In addition, this method will
321  only cast the floating point parameters and buffers to :attr:`dtype`
322  (if given). The integral parameters and buffers will be moved to
323  :attr:`device`, if that is given, but with dtypes unchanged. When
324  :attr:`non_blocking` is set, it tries to convert/move asynchronously
325  with respect to the host if possible, e.g., moving CPU Tensors with
326  pinned memory to CUDA devices.
327 
328  See below for examples.
329 
330  .. note::
331  This method modifies the module in-place.
332 
333  Args:
334  device (:class:`torch.device`): the desired device of the parameters
335  and buffers in this module
336  dtype (:class:`torch.dtype`): the desired floating point type of
337  the floating point parameters and buffers in this module
338  tensor (torch.Tensor): Tensor whose dtype and device are the desired
339  dtype and device for all parameters and buffers in this module
340 
341  Returns:
342  Module: self
343 
344  Example::
345 
346  >>> linear = nn.Linear(2, 2)
347  >>> linear.weight
348  Parameter containing:
349  tensor([[ 0.1913, -0.3420],
350  [-0.5113, -0.2325]])
351  >>> linear.to(torch.double)
352  Linear(in_features=2, out_features=2, bias=True)
353  >>> linear.weight
354  Parameter containing:
355  tensor([[ 0.1913, -0.3420],
356  [-0.5113, -0.2325]], dtype=torch.float64)
357  >>> gpu1 = torch.device("cuda:1")
358  >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
359  Linear(in_features=2, out_features=2, bias=True)
360  >>> linear.weight
361  Parameter containing:
362  tensor([[ 0.1914, -0.3420],
363  [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
364  >>> cpu = torch.device("cpu")
365  >>> linear.to(cpu)
366  Linear(in_features=2, out_features=2, bias=True)
367  >>> linear.weight
368  Parameter containing:
369  tensor([[ 0.1914, -0.3420],
370  [-0.5112, -0.2324]], dtype=torch.float16)
371 
372  """
373 
374  device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)
375 
376  if dtype is not None:
377  if not dtype.is_floating_point:
378  raise TypeError('nn.Module.to only accepts floating point '
379  'dtypes, but got desired dtype={}'.format(dtype))
380 
381  def convert(t):
382  return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
383 
384  return self._apply(convert)
385 
386  def register_backward_hook(self, hook):
387  r"""Registers a backward hook on the module.
388 
389  The hook will be called every time the gradients with respect to module
390  inputs are computed. The hook should have the following signature::
391 
392  hook(module, grad_input, grad_output) -> Tensor or None
393 
394  The :attr:`grad_input` and :attr:`grad_output` may be tuples if the
395  module has multiple inputs or outputs. The hook should not modify its
396  arguments, but it can optionally return a new gradient with respect to
397  input that will be used in place of :attr:`grad_input` in subsequent
398  computations.
399 
400  Returns:
401  :class:`torch.utils.hooks.RemovableHandle`:
402  a handle that can be used to remove the added hook by calling
403  ``handle.remove()``
404 
405  .. warning ::
406 
407  The current implementation will not exhibit the behavior described above
408  for a complex :class:`Module` that performs many operations.
409  In some failure cases, :attr:`grad_input` and :attr:`grad_output` will only
410  contain the gradients for a subset of the inputs and outputs.
411  For such :class:`Module`, you should use :func:`torch.Tensor.register_hook`
412  directly on a specific input or output to get the required gradients.
413 
414  """
415  handle = hooks.RemovableHandle(self._backward_hooks)
416  self._backward_hooks[handle.id] = hook
417  return handle
418 
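# Illustrative usage sketch (not part of the original module.py): inspecting
# gradients with a backward hook on a single, simple module, where the caveat
# above does not apply. The function name `print_grad_shapes` is hypothetical.
import torch
import torch.nn as nn

def print_grad_shapes(module, grad_input, grad_output):
    # grad_input/grad_output are tuples; returning None keeps them unchanged
    print('grad_output shapes:', [g.shape for g in grad_output if g is not None])

linear = nn.Linear(4, 2)
handle = linear.register_backward_hook(print_grad_shapes)
out = linear(torch.randn(3, 4))
out.sum().backward()   # triggers the hook
handle.remove()        # detach the hook when it is no longer needed
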
419  def register_forward_pre_hook(self, hook):
420  r"""Registers a forward pre-hook on the module.
421 
422  The hook will be called every time before :func:`forward` is invoked.
423  It should have the following signature::
424 
425  hook(module, input) -> None
426 
427  The hook should not modify the input.
428 
429  Returns:
430  :class:`torch.utils.hooks.RemovableHandle`:
431  a handle that can be used to remove the added hook by calling
432  ``handle.remove()``
433  """
434  handle = hooks.RemovableHandle(self._forward_pre_hooks)
435  self._forward_pre_hooks[handle.id] = hook
436  return handle
437 
438  def register_forward_hook(self, hook):
439  r"""Registers a forward hook on the module.
440 
441  The hook will be called every time after :func:`forward` has computed an output.
442  It should have the following signature::
443 
444  hook(module, input, output) -> None
445 
446  The hook should not modify the input or output.
447 
448  Returns:
449  :class:`torch.utils.hooks.RemovableHandle`:
450  a handle that can be used to remove the added hook by calling
451  ``handle.remove()``
452  """
453  handle = hooks.RemovableHandle(self._forward_hooks)
454  self._forward_hooks[handle.id] = hook
455  return handle
456 
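# Illustrative usage sketch (not part of the original module.py): capturing an
# intermediate activation with a forward hook and logging inputs with a forward
# pre-hook. The names `activations`, `save_activation` and `log_input` are hypothetical.
import torch
import torch.nn as nn

activations = {}

def save_activation(module, input, output):
    activations['relu_out'] = output.detach()

def log_input(module, input):
    # input is always a tuple of the positional arguments passed to forward()
    print('about to run with input of shape', input[0].shape)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
h1 = model[1].register_forward_hook(save_activation)
h2 = model[0].register_forward_pre_hook(log_input)
model(torch.randn(2, 4))
print(activations['relu_out'].shape)  # torch.Size([2, 8])
h1.remove()
h2.remove()
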
457  def _tracing_name(self, tracing_state):
458  if not tracing_state._traced_module_stack:
459  return None
460  module = tracing_state._traced_module_stack[-1]
461  for name, child in module.named_children():
462  if child is self:
463  return name
464  return None
465 
466  def _slow_forward(self, *input, **kwargs):
467  tracing_state = torch._C._get_tracing_state()
468  if not tracing_state:
469  return self.forward(*input, **kwargs)
470  if not hasattr(tracing_state, '_traced_module_stack'):
471  tracing_state._traced_module_stack = []
472  name = self._tracing_name(tracing_state)
473  if name:
474  tracing_state.push_scope('%s[%s]' % (self._get_name(), name))
475  else:
476  tracing_state.push_scope(self._get_name())
477  tracing_state._traced_module_stack.append(self)
478  try:
479  result = self.forward(*input, **kwargs)
480  finally:
481  tracing_state.pop_scope()
482  tracing_state._traced_module_stack.pop()
483  return result
484 
485  def __call__(self, *input, **kwargs):
486  for hook in self._forward_pre_hooks.values():
487  hook(self, input)
488  if torch._C._get_tracing_state():
489  result = self._slow_forward(*input, **kwargs)
490  else:
491  result = self.forward(*input, **kwargs)
492  for hook in self._forward_hooks.values():
493  hook_result = hook(self, input, result)
494  if hook_result is not None:
495  raise RuntimeError(
496  "forward hooks should never return any values, but '{}' "
497  "didn't return None".format(hook))
498  if len(self._backward_hooks) > 0:
499  var = result
500  while not isinstance(var, torch.Tensor):
501  if isinstance(var, dict):
502  var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
503  else:
504  var = var[0]
505  grad_fn = var.grad_fn
506  if grad_fn is not None:
507  for hook in self._backward_hooks.values():
508  wrapper = functools.partial(hook, self)
509  functools.update_wrapper(wrapper, hook)
510  grad_fn.register_hook(wrapper)
511  return result
512 
513  def __setstate__(self, state):
514  self.__dict__.update(state)
515  # Support loading old checkpoints that don't have the following attrs:
516  if '_forward_pre_hooks' not in self.__dict__:
517  self._forward_pre_hooks = OrderedDict()
518  if '_state_dict_hooks' not in self.__dict__:
519  self._state_dict_hooks = OrderedDict()
520  if '_load_state_dict_pre_hooks' not in self.__dict__:
521  self._load_state_dict_pre_hooks = OrderedDict()
522 
523  def __getattr__(self, name):
524  if '_parameters' in self.__dict__:
525  _parameters = self.__dict__['_parameters']
526  if name in _parameters:
527  return _parameters[name]
528  if '_buffers' in self.__dict__:
529  _buffers = self.__dict__['_buffers']
530  if name in _buffers:
531  return _buffers[name]
532  if '_modules' in self.__dict__:
533  modules = self.__dict__['_modules']
534  if name in modules:
535  return modules[name]
536  raise AttributeError("'{}' object has no attribute '{}'".format(
537  type(self).__name__, name))
538 
539  def __setattr__(self, name, value):
540  def remove_from(*dicts):
541  for d in dicts:
542  if name in d:
543  del d[name]
544 
545  params = self.__dict__.get('_parameters')
546  if isinstance(value, Parameter):
547  if params is None:
548  raise AttributeError(
549  "cannot assign parameters before Module.__init__() call")
550  remove_from(self.__dict__, self._buffers, self._modules)
551  self.register_parameter(name, value)
552  elif params is not None and name in params:
553  if value is not None:
554  raise TypeError("cannot assign '{}' as parameter '{}' "
555  "(torch.nn.Parameter or None expected)"
556  .format(torch.typename(value), name))
557  self.register_parameter(name, value)
558  else:
559  modules = self.__dict__.get('_modules')
560  if isinstance(value, Module):
561  if modules is None:
562  raise AttributeError(
563  "cannot assign module before Module.__init__() call")
564  remove_from(self.__dict__, self._parameters, self._buffers)
565  modules[name] = value
566  elif modules is not None and name in modules:
567  if value is not None:
568  raise TypeError("cannot assign '{}' as child module '{}' "
569  "(torch.nn.Module or None expected)"
570  .format(torch.typename(value), name))
571  modules[name] = value
572  else:
573  buffers = self.__dict__.get('_buffers')
574  if buffers is not None and name in buffers:
575  if value is not None and not isinstance(value, torch.Tensor):
576  raise TypeError("cannot assign '{}' as buffer '{}' "
577  "(torch.Tensor or None expected)"
578  .format(torch.typename(value), name))
579  buffers[name] = value
580  else:
581  object.__setattr__(self, name, value)
582 
583  def __delattr__(self, name):
584  if name in self._parameters:
585  del self._parameters[name]
586  elif name in self._buffers:
587  del self._buffers[name]
588  elif name in self._modules:
589  del self._modules[name]
590  else:
591  object.__delattr__(self, name)
592 
593  def _register_state_dict_hook(self, hook):
594  r"""These hooks will be called with arguments: `self`, `state_dict`,
595  `prefix`, `local_metadata`, after the `state_dict` of `self` is set.
596  Note that only parameters and buffers of `self` or its children are
597  guaranteed to exist in `state_dict`. The hooks may modify `state_dict`
598  in place or return a new one.
599  """
600  handle = hooks.RemovableHandle(self._state_dict_hooks)
601  self._state_dict_hooks[handle.id] = hook
602  return handle
603 
604  def state_dict(self, destination=None, prefix='', keep_vars=False):
605  r"""Returns a dictionary containing the whole state of the module.
606 
607  Both parameters and persistent buffers (e.g. running averages) are
608  included. Keys are corresponding parameter and buffer names.
609 
610  Returns:
611  dict:
612  a dictionary containing the whole state of the module
613 
614  Example::
615 
616  >>> module.state_dict().keys()
617  ['bias', 'weight']
618 
619  """
620  if destination is None:
621  destination = OrderedDict()
622  destination._metadata = OrderedDict()
623  destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
624  for name, param in self._parameters.items():
625  if param is not None:
626  destination[prefix + name] = param if keep_vars else param.data
627  for name, buf in self._buffers.items():
628  if buf is not None:
629  destination[prefix + name] = buf if keep_vars else buf.data
630  for name, module in self._modules.items():
631  if module is not None:
632  module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
633  for hook in self._state_dict_hooks.values():
634  hook_result = hook(self, destination, prefix, local_metadata)
635  if hook_result is not None:
636  destination = hook_result
637  return destination
638 
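# Illustrative usage sketch (not part of the original module.py): keys in the
# returned dict are prefixed with the submodule path, mirroring attribute
# access. The `Wrapper` class is hypothetical.
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 3), nn.Linear(3, 1))
print(list(net.state_dict().keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias']

class Wrapper(nn.Module):
    def __init__(self):
        super(Wrapper, self).__init__()
        self.body = net

w = Wrapper()
print(list(w.state_dict().keys()))
# ['body.0.weight', 'body.0.bias', 'body.1.weight', 'body.1.bias']
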
639  def _register_load_state_dict_pre_hook(self, hook):
640  r"""These hooks will be called with arguments: `state_dict`, `prefix`,
641  `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`,
642  `error_msgs`, before loading `state_dict` into `self`. These arguments
643  are exactly the same as those of `_load_from_state_dict`.
644  """
645  handle = hooks.RemovableHandle(self._load_state_dict_pre_hooks)
646  self._load_state_dict_pre_hooks[handle.id] = hook
647  return handle
648 
649  def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
650  missing_keys, unexpected_keys, error_msgs):
651  r"""Copies parameters and buffers from :attr:`state_dict` into only
652  this module, but not its descendants. This is called on every submodule
653  in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
654  module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
655  For state dicts without metadata, :attr:`local_metadata` is empty.
656  Subclasses can achieve class-specific backward compatible loading using
657  the version number at `local_metadata.get("version", None)`.
658 
659  .. note::
660  :attr:`state_dict` is not the same object as the input
661  :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
662  it can be modified.
663 
664  Arguments:
665  state_dict (dict): a dict containing parameters and
666  persistent buffers.
667  prefix (str): the prefix for parameters and buffers used in this
668  module
669  local_metadata (dict): a dict containing the metadata for this module,
670  e.g. the version number saved by :meth:`state_dict`.
671  strict (bool): whether to strictly enforce that the keys in
672  :attr:`state_dict` with :attr:`prefix` match the names of
673  parameters and buffers in this module
674  missing_keys (list of str): if ``strict=True``, add missing keys to
675  this list
676  unexpected_keys (list of str): if ``strict=True``, add unexpected
677  keys to this list
678  error_msgs (list of str): error messages should be added to this
679  list, and will be reported together in
680  :meth:`~torch.nn.Module.load_state_dict`
681  """
682  for hook in self._load_state_dict_pre_hooks.values():
683  hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
684 
685  local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
686  local_state = {k: v.data for k, v in local_name_params if v is not None}
687 
688  for name, param in local_state.items():
689  key = prefix + name
690  if key in state_dict:
691  input_param = state_dict[key]
692 
693  # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
694  if len(param.shape) == 0 and len(input_param.shape) == 1:
695  input_param = input_param[0]
696 
697  if input_param.shape != param.shape:
698  # local shape should match the one in checkpoint
699  error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
700  'the shape in current model is {}.'
701  .format(key, input_param.shape, param.shape))
702  continue
703 
704  if isinstance(input_param, Parameter):
705  # backwards compatibility for serialized parameters
706  input_param = input_param.data
707  try:
708  param.copy_(input_param)
709  except Exception:
710  error_msgs.append('While copying the parameter named "{}", '
711  'whose dimensions in the model are {} and '
712  'whose dimensions in the checkpoint are {}.'
713  .format(key, param.size(), input_param.size()))
714  elif strict:
715  missing_keys.append(key)
716 
717  if strict:
718  for key in state_dict.keys():
719  if key.startswith(prefix):
720  input_name = key[len(prefix):]
721  input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
722  if input_name not in self._modules and input_name not in local_state:
723  unexpected_keys.append(key)
724 
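# Illustrative sketch (not part of the original module.py): a subclass using
# `local_metadata` and `_version` for backward-compatible loading. The buffer
# rename (`running_avg` -> `mean`) and the class name are hypothetical.
import torch
import torch.nn as nn

class Normalize(nn.Module):
    _version = 2  # bumped when the buffer was renamed

    def __init__(self, num_features):
        super(Normalize, self).__init__()
        self.register_buffer('mean', torch.zeros(num_features))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        if version is None or version < 2:
            # checkpoints written before version 2 used the old key name
            old_key, new_key = prefix + 'running_avg', prefix + 'mean'
            if old_key in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)
        super(Normalize, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)
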
725  def load_state_dict(self, state_dict, strict=True):
726  r"""Copies parameters and buffers from :attr:`state_dict` into
727  this module and its descendants. If :attr:`strict` is ``True``, then
728  the keys of :attr:`state_dict` must exactly match the keys returned
729  by this module's :meth:`~torch.nn.Module.state_dict` function.
730 
731  Arguments:
732  state_dict (dict): a dict containing parameters and
733  persistent buffers.
734  strict (bool, optional): whether to strictly enforce that the keys
735  in :attr:`state_dict` match the keys returned by this module's
736  :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
737  """
738  missing_keys = []
739  unexpected_keys = []
740  error_msgs = []
741 
742  # copy state_dict so _load_from_state_dict can modify it
743  metadata = getattr(state_dict, '_metadata', None)
744  state_dict = state_dict.copy()
745  if metadata is not None:
746  state_dict._metadata = metadata
747 
748  def load(module, prefix=''):
749  local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
750  module._load_from_state_dict(
751  state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
752  for name, child in module._modules.items():
753  if child is not None:
754  load(child, prefix + name + '.')
755 
756  load(self)
757 
758  if strict:
759  error_msg = ''
760  if len(unexpected_keys) > 0:
761  error_msgs.insert(
762  0, 'Unexpected key(s) in state_dict: {}. '.format(
763  ', '.join('"{}"'.format(k) for k in unexpected_keys)))
764  if len(missing_keys) > 0:
765  error_msgs.insert(
766  0, 'Missing key(s) in state_dict: {}. '.format(
767  ', '.join('"{}"'.format(k) for k in missing_keys)))
768 
769  if len(error_msgs) > 0:
770  raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
771  self.__class__.__name__, "\n\t".join(error_msgs)))
772 
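# Illustrative usage sketch (not part of the original module.py): the usual
# save/restore round trip; 'checkpoint.pth' is a hypothetical path.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
torch.save(model.state_dict(), 'checkpoint.pth')

restored = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
restored.load_state_dict(torch.load('checkpoint.pth'))  # strict by default
# strict=False tolerates missing/unexpected keys, e.g. for partial fine-tuning
restored.load_state_dict(torch.load('checkpoint.pth'), strict=False)
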
773  def _named_members(self, get_members_fn, prefix='', recurse=True):
774  r"""Helper method for yielding various names + members of modules."""
775  memo = set()
776  modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
777  for module_prefix, module in modules:
778  members = get_members_fn(module)
779  for k, v in members:
780  if v is None or v in memo:
781  continue
782  memo.add(v)
783  name = module_prefix + ('.' if module_prefix else '') + k
784  yield name, v
785 
786  def parameters(self, recurse=True):
787  r"""Returns an iterator over module parameters.
788 
789  This is typically passed to an optimizer.
790 
791  Args:
792  recurse (bool): if True, then yields parameters of this module
793  and all submodules. Otherwise, yields only parameters that
794  are direct members of this module.
795 
796  Yields:
797  Parameter: module parameter
798 
799  Example::
800 
801  >>> for param in model.parameters():
802  >>> print(type(param.data), param.size())
803  <class 'torch.FloatTensor'> (20L,)
804  <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
805 
806  """
807  for name, param in self.named_parameters(recurse=recurse):
808  yield param
809 
810  def named_parameters(self, prefix='', recurse=True):
811  r"""Returns an iterator over module parameters, yielding both the
812  name of the parameter as well as the parameter itself.
813 
814  Args:
815  prefix (str): prefix to prepend to all parameter names.
816  recurse (bool): if True, then yields parameters of this module
817  and all submodules. Otherwise, yields only parameters that
818  are direct members of this module.
819 
820  Yields:
821  (string, Parameter): Tuple containing the name and parameter
822 
823  Example::
824 
825  >>> for name, param in self.named_parameters():
826  >>> if name in ['bias']:
827  >>> print(param.size())
828 
829  """
830  gen = self._named_members(
831  lambda module: module._parameters.items(),
832  prefix=prefix, recurse=recurse)
833  for elem in gen:
834  yield elem
835 
836  def buffers(self, recurse=True):
837  r"""Returns an iterator over module buffers.
838 
839  Args:
840  recurse (bool): if True, then yields buffers of this module
841  and all submodules. Otherwise, yields only buffers that
842  are direct members of this module.
843 
844  Yields:
845  torch.Tensor: module buffer
846 
847  Example::
848 
849  >>> for buf in model.buffers():
850  >>> print(type(buf.data), buf.size())
851  <class 'torch.FloatTensor'> (20L,)
852  <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
853 
854  """
855  for name, buf in self.named_buffers(recurse=recurse):
856  yield buf
857 
858  def named_buffers(self, prefix='', recurse=True):
859  r"""Returns an iterator over module buffers, yielding both the
860  name of the buffer as well as the buffer itself.
861 
862  Args:
863  prefix (str): prefix to prepend to all buffer names.
864  recurse (bool): if True, then yields buffers of this module
865  and all submodules. Otherwise, yields only buffers that
866  are direct members of this module.
867 
868  Yields:
869  (string, torch.Tensor): Tuple containing the name and buffer
870 
871  Example::
872 
873  >>> for name, buf in self.named_buffers():
874  >>> if name in ['running_var']:
875  >>> print(buf.size())
876 
877  """
878  gen = self._named_members(
879  lambda module: module._buffers.items(),
880  prefix=prefix, recurse=recurse)
881  for elem in gen:
882  yield elem
883 
884  def children(self):
885  r"""Returns an iterator over immediate children modules.
886 
887  Yields:
888  Module: a child module
889  """
890  for name, module in self.named_children():
891  yield module
892 
893  def named_children(self):
894  r"""Returns an iterator over immediate children modules, yielding both
895  the name of the module as well as the module itself.
896 
897  Yields:
898  (string, Module): Tuple containing a name and child module
899 
900  Example::
901 
902  >>> for name, module in model.named_children():
903  >>> if name in ['conv4', 'conv5']:
904  >>> print(module)
905 
906  """
907  memo = set()
908  for name, module in self._modules.items():
909  if module is not None and module not in memo:
910  memo.add(module)
911  yield name, module
912 
913  def modules(self):
914  r"""Returns an iterator over all modules in the network.
915 
916  Yields:
917  Module: a module in the network
918 
919  Note:
920  Duplicate modules are returned only once. In the following
921  example, ``l`` will be returned only once.
922 
923  Example::
924 
925  >>> l = nn.Linear(2, 2)
926  >>> net = nn.Sequential(l, l)
927  >>> for idx, m in enumerate(net.modules()):
928  print(idx, '->', m)
929 
930  0 -> Sequential(
931  (0): Linear(in_features=2, out_features=2, bias=True)
932  (1): Linear(in_features=2, out_features=2, bias=True)
933  )
934  1 -> Linear(in_features=2, out_features=2, bias=True)
935 
936  """
937  for name, module in self.named_modules():
938  yield module
939 
940  def named_modules(self, memo=None, prefix=''):
941  r"""Returns an iterator over all modules in the network, yielding
942  both the name of the module as well as the module itself.
943 
944  Yields:
945  (string, Module): Tuple of name and module
946 
947  Note:
948  Duplicate modules are returned only once. In the following
949  example, ``l`` will be returned only once.
950 
951  Example::
952 
953  >>> l = nn.Linear(2, 2)
954  >>> net = nn.Sequential(l, l)
955  >>> for idx, m in enumerate(net.named_modules()):
956  print(idx, '->', m)
957 
958  0 -> ('', Sequential(
959  (0): Linear(in_features=2, out_features=2, bias=True)
960  (1): Linear(in_features=2, out_features=2, bias=True)
961  ))
962  1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
963 
964  """
965 
966  if memo is None:
967  memo = set()
968  if self not in memo:
969  memo.add(self)
970  yield prefix, self
971  for name, module in self._modules.items():
972  if module is None:
973  continue
974  submodule_prefix = prefix + ('.' if prefix else '') + name
975  for m in module.named_modules(memo, submodule_prefix):
976  yield m
977 
978  def train(self, mode=True):
979  r"""Sets the module in training mode.
980 
981  This has an effect only on certain modules. See the documentation of
982  particular modules for details of their behavior in training/evaluation
983  mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
984  etc.
985 
986  Returns:
987  Module: self
988  """
989  self.training = mode
990  for module in self.children():
991  module.train(mode)
992  return self
993 
994  def eval(self):
995  r"""Sets the module in evaluation mode.
996 
997  This has an effect only on certain modules. See the documentation of
998  particular modules for details of their behavior in training/evaluation
999  mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
1000  etc.
1001  """
1002  return self.train(False)
1003 
1004  def zero_grad(self):
1005  r"""Sets gradients of all model parameters to zero."""
1006  for p in self.parameters():
1007  if p.grad is not None:
1008  p.grad.detach_()
1009  p.grad.zero_()
1010 
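# Illustrative sketch (not part of the original module.py): where train(),
# eval() and zero_grad() fit in a typical loop. The model, data, loss_fn and
# optimizer below are hypothetical.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5), nn.Linear(4, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

model.train()               # enable Dropout/BatchNorm training behavior
for _ in range(10):
    model.zero_grad()       # or optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

model.eval()                # switch Dropout/BatchNorm to evaluation behavior
with torch.no_grad():
    prediction = model(x)
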
1011  def share_memory(self):
1012  return self._apply(lambda t: t.share_memory_())
1013 
1014  def _get_name(self):
1015  return self.__class__.__name__
1016 
1017  def extra_repr(self):
1018  r"""Set the extra representation of the module
1019 
1020  To print customized extra information, you should reimplement
1021  this method in your own modules. Both single-line and multi-line
1022  strings are acceptable.
1023  """
1024  return ''
1025 
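# Illustrative usage sketch (not part of the original module.py): customizing
# the printed representation via extra_repr(). `Scale` is a hypothetical module.
import torch
import torch.nn as nn

class Scale(nn.Module):
    def __init__(self, factor):
        super(Scale, self).__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor

    def extra_repr(self):
        # picked up by __repr__ below and printed inside the parentheses
        return 'factor={}'.format(self.factor)

print(Scale(2.0))   # Scale(factor=2.0)
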
1026  def __repr__(self):
1027  # We treat the extra repr like the sub-module, one item per line
1028  extra_lines = []
1029  extra_repr = self.extra_repr()
1030  # empty string will be split into list ['']
1031  if extra_repr:
1032  extra_lines = extra_repr.split('\n')
1033  child_lines = []
1034  for key, module in self._modules.items():
1035  mod_str = repr(module)
1036  mod_str = _addindent(mod_str, 2)
1037  child_lines.append('(' + key + '): ' + mod_str)
1038  lines = extra_lines + child_lines
1039 
1040  main_str = self._get_name() + '('
1041  if lines:
1042  # simple one-liner info, which most builtin Modules will use
1043  if len(extra_lines) == 1 and not child_lines:
1044  main_str += extra_lines[0]
1045  else:
1046  main_str += '\n ' + '\n '.join(lines) + '\n'
1047 
1048  main_str += ')'
1049  return main_str
1050 
1051  def __dir__(self):
1052  module_attrs = dir(self.__class__)
1053  attrs = list(self.__dict__.keys())
1054  parameters = list(self._parameters.keys())
1055  modules = list(self._modules.keys())
1056  buffers = list(self._buffers.keys())
1057  keys = module_attrs + attrs + parameters + modules + buffers
1058 
1059  # Eliminate attrs that are not legal Python variable names
1060  keys = [key for key in keys if not key[0].isdigit()]
1061 
1062  return sorted(keys)