Caffe2 - Python API
A deep-learning, cross-platform ML framework
container.py
1 import warnings
2 from collections import OrderedDict
3 from torch._six import container_abcs
4 from itertools import islice
5 import operator
6 
7 import torch
8 from .module import Module
9 
10 
class Container(Module):
    """Deprecated container that registers keyword-argument submodules.

    Each ``name=module`` pair given to the constructor is registered through
    :meth:`add_module`, in the order the keywords were passed. A warning is
    emitted on every construction pointing users at :class:`Module`.
    """

    def __init__(self, **kwargs):
        super(Container, self).__init__()
        # A plain UserWarning is used on purpose: DeprecationWarning is
        # ignored by default, so it would never be seen.
        warnings.warn("nn.Container is deprecated. All of it's functionality "
                      "is now implemented in nn.Module. Subclass that instead.")
        for name in kwargs:
            self.add_module(name, kwargs[name])
20 
21 
class Sequential(Module):
    r"""A sequential container.
    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.

    To make it easier to understand, here is a small example::

        # Example of using Sequential
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Example of using Sequential with OrderedDict
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))

    Forward pass:
        The input is fed to the first module, each module's output is fed to
        the next one, and the last module's output is returned. An empty
        ``Sequential`` returns its input unchanged. The output shape is
        therefore determined by the contained modules.
    """

    def __init__(self, *args):
        super(Sequential, self).__init__()
        # A single OrderedDict argument supplies explicit names; otherwise the
        # positional modules are registered under the names "0", "1", ...
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx):
        """Get the idx-th item of the iterator.

        ``idx`` may be negative (counted from the end). ``len(self)`` is used
        as the length, so ``iterator`` is expected to be a view over
        ``self._modules``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            TypeError: if ``idx`` is not an integer (via ``operator.index``).
        """
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError('index {} is out of range'.format(idx))
        idx %= size
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx):
        """Return the module at ``idx``, or a new container for a slice.

        A slice keeps the original module names.
        """
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx, module):
        """Replace the module at position ``idx``, keeping its registered name."""
        key = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx):
        """Remove the module(s) at ``idx`` (an int or a slice).

        The remaining modules keep their existing names (no renumbering).
        """
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        """Iterate over the contained modules in registration order.

        Without an explicit ``__iter__``, Python falls back to calling
        ``__getitem__(0)``, ``__getitem__(1)``, ... and each of those walks the
        modules from the start, making a full iteration quadratic.
        """
        return iter(self._modules.values())

    def __dir__(self):
        # Drop the auto-generated numeric module names from dir() output.
        keys = super(Sequential, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def forward(self, input):
        """Chain ``input`` through every module in order and return the result."""
        for module in self._modules.values():
            input = module(input)
        return input
99 
100 
class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    Arguments:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    def __init__(self, modules=None):
        super(ModuleList, self).__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules.

        Converts a possibly negative integer ``idx`` into the string key used
        in ``self._modules``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError('index {} is out of range'.format(idx))
        if idx < 0:
            idx += len(self)
        return str(idx)

    def __getitem__(self, idx):
        """Return the module at ``idx``, or a new ModuleList for a slice."""
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx, module):
        """Replace the module at position ``idx``."""
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx):
        """Remove the module(s) at ``idx`` and renumber the rest from 0."""
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is being reconstructed with modules after deletion
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        return iter(self._modules.values())

    def __iadd__(self, modules):
        return self.extend(modules)

    def __dir__(self):
        # Drop the numeric module names from dir() output.
        keys = super(ModuleList, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index, module):
        r"""Insert a given module before a given index in the list.

        Mirrors :meth:`list.insert`: a negative ``index`` counts from the end,
        and indices past either end are clamped. This keeps the keys the
        contiguous strings ``"0"`` .. ``str(len(self) - 1)``.

        Arguments:
            index (int): index to insert.
            module (nn.Module): module to insert

        Raises:
            TypeError: if ``index`` is not an integer or ``module`` is neither
                a Module nor None. Validation happens before any mutation, so
                a failed call leaves the list unchanged.
        """
        index = operator.index(index)
        if module is not None and not isinstance(module, Module):
            raise TypeError("{} is not a Module subclass".format(
                torch.typename(module)))
        size = len(self._modules)
        # Normalize like list.insert, so the shift loop below never reads a
        # missing key such as "-1" and never leaves a gap in the numbering.
        if index < 0:
            index = max(0, index + size)
        else:
            index = min(index, size)
        for i in range(size, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self.add_module(str(index), module)

    def append(self, module):
        r"""Appends a given module to the end of the list.

        Arguments:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def extend(self, modules):
        r"""Appends modules from a Python iterable to the end of the list.

        Arguments:
            modules (iterable): iterable of modules to append

        Raises:
            TypeError: if ``modules`` is not iterable.
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleList.extend should be called with an "
                            "iterable, but got " + type(modules).__name__)
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self
206 
207 
class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged ``OrderedDict``
      or another :class:`~torch.nn.ModuleDict` (the argument to :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping.

    Arguments:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.choices = nn.ModuleDict({
                        'conv': nn.Conv2d(10, 10, 3),
                        'pool': nn.MaxPool2d(3)
                })
                self.activations = nn.ModuleDict([
                        ['lrelu', nn.LeakyReLU()],
                        ['prelu', nn.PReLU()]
                ])

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    def __init__(self, modules=None):
        super(ModuleDict, self).__init__()
        if modules is not None:
            self.update(modules)

    def __getitem__(self, key):
        return self._modules[key]

    def __setitem__(self, key, module):
        # Goes through add_module so the module is registered like an attribute.
        self.add_module(key, module)

    def __delitem__(self, key):
        del self._modules[key]

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        # Iterates over keys, like a dict.
        return iter(self._modules)

    def __contains__(self, key):
        return key in self._modules

    def clear(self):
        """Remove all items from the ModuleDict.
        """
        self._modules.clear()

    def pop(self, key):
        r"""Remove key from the ModuleDict and return its module.

        Arguments:
            key (string): key to pop from the ModuleDict

        Raises:
            KeyError: if ``key`` is not present.
        """
        v = self[key]
        del self[key]
        return v

    def keys(self):
        r"""Return an iterable of the ModuleDict keys.
        """
        return self._modules.keys()

    def items(self):
        r"""Return an iterable of the ModuleDict key/value pairs.
        """
        return self._modules.items()

    def values(self):
        r"""Return an iterable of the ModuleDict values.
        """
        return self._modules.values()

    def update(self, modules):
        r"""Update the :class:`~torch.nn.ModuleDict` with the key-value pairs from a
        mapping or an iterable, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.
            Other mappings are merged in sorted key order.

        Arguments:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)

        Raises:
            TypeError: if ``modules`` (or one of its elements, for a pair
                sequence) is not iterable.
            ValueError: if an element of a pair sequence does not have length 2.
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(modules).__name__)

        # ModuleDict is not a registered collections.abc.Mapping, so it must be
        # checked on its own before the Mapping test. Otherwise it would fall
        # through to the pair-sequence branch and iterate over its key strings.
        if isinstance(modules, (OrderedDict, ModuleDict)):
            for key, module in modules.items():
                self[key] = module
        elif isinstance(modules, container_abcs.Mapping):
            # Unordered mapping: merge in sorted key order.
            for key, module in sorted(modules.items()):
                self[key] = module
        else:
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError("ModuleDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is " +
                                    type(m).__name__)
                if not len(m) == 2:
                    raise ValueError("ModuleDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(m)) +
                                     "; 2 is required")
                self[m[0]] = m[1]
338 
339 
class ParameterList(Module):
    r"""Holds parameters in a list.

    :class:`~torch.nn.ParameterList` can be indexed like a regular Python
    list, but parameters it contains are properly registered, and will be
    visible by all :class:`~torch.nn.Module` methods.

    Arguments:
        parameters (iterable, optional): an iterable of :class:`~torch.nn.Parameter` to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

            def forward(self, x):
                # ParameterList can act as an iterable, or be indexed using ints
                for i, p in enumerate(self.params):
                    x = self.params[i // 2].mm(x) + p.mm(x)
                return x
    """

    def __init__(self, parameters=None):
        super(ParameterList, self).__init__()
        if parameters is not None:
            self += parameters

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of parameters.

        Converts a possibly negative integer ``idx`` into the string key used
        in ``self._parameters``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError('index {} is out of range'.format(idx))
        if idx < 0:
            idx += len(self)
        return str(idx)

    def __getitem__(self, idx):
        """Return the parameter at ``idx``, or a new ParameterList for a slice."""
        if isinstance(idx, slice):
            return self.__class__(list(self._parameters.values())[idx])
        else:
            return self._parameters[self._get_abs_string_index(idx)]

    def __setitem__(self, idx, param):
        """Replace the parameter at position ``idx``."""
        idx = self._get_abs_string_index(idx)
        return self.register_parameter(str(idx), param)

    def __len__(self):
        return len(self._parameters)

    def __iter__(self):
        return iter(self._parameters.values())

    def __iadd__(self, parameters):
        return self.extend(parameters)

    def __dir__(self):
        # Drop the numeric parameter names from dir() output.
        keys = super(ParameterList, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self, parameter):
        """Appends a given parameter at the end of the list.

        Arguments:
            parameter (nn.Parameter): parameter to append
        """
        self.register_parameter(str(len(self)), parameter)
        return self

    def extend(self, parameters):
        """Appends parameters from a Python iterable to the end of the list.

        Arguments:
            parameters (iterable): iterable of parameters to append

        Raises:
            TypeError: if ``parameters`` is not iterable.
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError("ParameterList.extend should be called with an "
                            "iterable, but got " + type(parameters).__name__)
        offset = len(self)
        for i, param in enumerate(parameters):
            self.register_parameter(str(offset + i), param)
        return self

    def extra_repr(self):
        """Describe each entry as ``(index): <type> of size AxB[ (GPU n)]``.

        ``None`` entries are shown as ``None`` instead of raising, since they
        have no size or device to report.
        """
        child_lines = []
        for k, p in self._parameters.items():
            if p is None:
                child_lines.append(' (' + str(k) + '): None')
                continue
            size_str = 'x'.join(str(size) for size in p.size())
            device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
            parastr = 'Parameter containing: [{} of size {}{}]'.format(
                torch.typename(p.data), size_str, device_str)
            child_lines.append(' (' + str(k) + '): ' + parastr)
        return '\n'.join(child_lines)
436 
437 
class ParameterDict(Module):
    r"""Holds parameters in a dictionary.

    ParameterDict can be indexed like a regular Python dictionary, but parameters it
    contains are properly registered, and will be visible by all Module methods.

    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ParameterDict.update`, the order of the merged ``OrderedDict``
      or another :class:`~torch.nn.ParameterDict` (the argument to
      :meth:`~torch.nn.ParameterDict.update`).

    Note that :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping.

    Arguments:
        parameters (iterable, optional): a mapping (dictionary) of
            (string : :class:`~torch.nn.Parameter`) or an iterable of key-value pairs
            of type (string, :class:`~torch.nn.Parameter`)

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.params = nn.ParameterDict({
                        'left': nn.Parameter(torch.randn(5, 10)),
                        'right': nn.Parameter(torch.randn(5, 10))
                })

            def forward(self, x, choice):
                x = self.params[choice].mm(x)
                return x
    """

    def __init__(self, parameters=None):
        super(ParameterDict, self).__init__()
        if parameters is not None:
            self.update(parameters)

    def __getitem__(self, key):
        return self._parameters[key]

    def __setitem__(self, key, parameter):
        # Goes through register_parameter so the entry is registered on the module.
        self.register_parameter(key, parameter)

    def __delitem__(self, key):
        del self._parameters[key]

    def __len__(self):
        return len(self._parameters)

    def __iter__(self):
        # Iterates over keys, like a dict.
        return iter(self._parameters.keys())

    def __contains__(self, key):
        return key in self._parameters

    def clear(self):
        """Remove all items from the ParameterDict.
        """
        self._parameters.clear()

    def pop(self, key):
        r"""Remove key from the ParameterDict and return its parameter.

        Arguments:
            key (string): key to pop from the ParameterDict

        Raises:
            KeyError: if ``key`` is not present.
        """
        v = self[key]
        del self[key]
        return v

    def keys(self):
        r"""Return an iterable of the ParameterDict keys.
        """
        return self._parameters.keys()

    def items(self):
        r"""Return an iterable of the ParameterDict key/value pairs.
        """
        return self._parameters.items()

    def values(self):
        r"""Return an iterable of the ParameterDict values.
        """
        return self._parameters.values()

    def update(self, parameters):
        r"""Update the :class:`~torch.nn.ParameterDict` with the key-value pairs from a
        mapping or an iterable, overwriting existing keys.

        .. note::
            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.
            Other mappings are merged in sorted key order.

        Arguments:
            parameters (iterable): a mapping (dictionary) from string to
                :class:`~torch.nn.Parameter`, or an iterable of
                key-value pairs of type (string, :class:`~torch.nn.Parameter`)

        Raises:
            TypeError: if ``parameters`` (or one of its elements, for a pair
                sequence) is not iterable.
            ValueError: if an element of a pair sequence does not have length 2.
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError("ParameterDict.update should be called with an "
                            "iterable of key/value pairs, but got " +
                            type(parameters).__name__)

        # ParameterDict is not a registered collections.abc.Mapping, so it
        # must be checked on its own before the Mapping test. Otherwise it would
        # fall through to the pair-sequence branch and iterate its key strings.
        if isinstance(parameters, (OrderedDict, ParameterDict)):
            for key, parameter in parameters.items():
                self[key] = parameter
        elif isinstance(parameters, container_abcs.Mapping):
            # Unordered mapping: merge in sorted key order.
            for key, parameter in sorted(parameters.items()):
                self[key] = parameter
        else:
            for j, p in enumerate(parameters):
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError("ParameterDict update sequence element "
                                    "#" + str(j) + " should be Iterable; is " +
                                    type(p).__name__)
                if not len(p) == 2:
                    raise ValueError("ParameterDict update sequence element "
                                     "#" + str(j) + " has length " + str(len(p)) +
                                     "; 2 is required")
                self[p[0]] = p[1]

    def extra_repr(self):
        """Describe each entry as ``(key): <type> of size AxB[ (GPU n)]``.

        ``None`` entries are shown as ``None`` instead of raising, since they
        have no size or device to report.
        """
        child_lines = []
        for k, p in self._parameters.items():
            if p is None:
                child_lines.append(' (' + k + '): None')
                continue
            size_str = 'x'.join(str(size) for size in p.size())
            device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
            parastr = 'Parameter containing: [{} of size {}{}]'.format(
                torch.typename(p.data), size_str, device_str)
            child_lines.append(' (' + k + '): ' + parastr)
        return '\n'.join(child_lines)
def _get_item_by_idx(self, iterator, idx)
Definition: container.py:60
def typename(o)
Define basic utilities.
Definition: __init__.py:94