2 from collections
import OrderedDict
4 from itertools
import islice
8 from .module
import Module
def __init__(self, **kwargs):
    """Deprecated container that registers each keyword argument as a submodule.

    Args:
        **kwargs: name -> module pairs; each pair is handed to
            ``self.add_module`` in keyword order.

    A warning is emitted on every construction because ``nn.Container`` is
    deprecated in favour of subclassing ``nn.Module`` directly.
    """
    super(Container, self).__init__()
    # No warning category is passed, so Python's default (UserWarning) is
    # used; a DeprecationWarning would be hidden by the default filters.
    warnings.warn("nn.Container is deprecated. All of its functionality "
                  "is now implemented in nn.Module. Subclass that instead.")
    for key, value in kwargs.items():
        self.add_module(key, value)
23 r"""A sequential container. 24 Modules will be added to it in the order they are passed in the constructor. 25 Alternatively, an ordered dict of modules can also be passed in. 27 To make it easier to understand, here is a small example:: 29 # Example of using Sequential 30 model = nn.Sequential( 37 # Example of using Sequential with OrderedDict 38 model = nn.Sequential(OrderedDict([ 39 ('conv1', nn.Conv2d(1,20,5)), 41 ('conv2', nn.Conv2d(20,64,5)), 46 - Input: :math:`(*)` where `*` means, any number of additional 48 - Output: :math:`(*)`, same shape as the input 51 def __init__(self, *args):
52 super(Sequential, self).__init__()
53 if len(args) == 1
and isinstance(args[0], OrderedDict):
54 for key, module
in args[0].items():
55 self.add_module(key, module)
57 for idx, module
in enumerate(args):
58 self.add_module(str(idx), module)
60 def _get_item_by_idx(self, iterator, idx):
61 """Get the idx-th item of the iterator""" 63 idx = operator.index(idx)
64 if not -size <= idx < size:
65 raise IndexError(
'index {} is out of range'.format(idx))
67 return next(islice(iterator, idx,
None))
69 def __getitem__(self, idx):
70 if isinstance(idx, slice):
71 return self.__class__(OrderedDict(list(self._modules.items())[idx]))
75 def __setitem__(self, idx, module):
77 return setattr(self, key, module)
79 def __delitem__(self, idx):
80 if isinstance(idx, slice):
81 for key
in list(self._modules.keys())[idx]:
88 return len(self._modules)
91 keys = super(Sequential, self).__dir__()
92 keys = [key
for key
in keys
if not key.isdigit()]
95 def forward(self, input):
96 for module
in self._modules.values():
102 r"""Holds submodules in a list. 104 :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but 105 modules it contains are properly registered, and will be visible by all 106 :class:`~torch.nn.Module` methods. 109 modules (iterable, optional): an iterable of modules to add 113 class MyModule(nn.Module): 115 super(MyModule, self).__init__() 116 self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)]) 118 def forward(self, x): 119 # ModuleList can act as an iterable, or be indexed using ints 120 for i, l in enumerate(self.linears): 121 x = self.linears[i // 2](x) + l(x) 125 def __init__(self, modules=None):
126 super(ModuleList, self).__init__()
127 if modules
is not None:
130 def _get_abs_string_index(self, idx):
131 """Get the absolute index for the list of modules""" 132 idx = operator.index(idx)
133 if not (-len(self) <= idx < len(self)):
134 raise IndexError(
'index {} is out of range'.format(idx))
139 def __getitem__(self, idx):
140 if isinstance(idx, slice):
141 return self.__class__(list(self._modules.values())[idx])
145 def __setitem__(self, idx, module):
147 return setattr(self, str(idx), module)
149 def __delitem__(self, idx):
150 if isinstance(idx, slice):
151 for k
in range(len(self.
_modules))[idx]:
152 delattr(self, str(k))
156 str_indices = [str(i)
for i
in range(len(self.
_modules))]
157 self.
_modules = OrderedDict(list(zip(str_indices, self._modules.values())))
163 return iter(self._modules.values())
def __iadd__(self, modules):
    """Support ``self += modules`` by delegating to ``self.extend``.

    Returns whatever ``self.extend`` returns for ``modules``.
    """
    extended = self.extend(modules)
    return extended
169 keys = super(ModuleList, self).__dir__()
170 keys = [key
for key
in keys
if not key.isdigit()]
173 def insert(self, index, module):
174 r"""Insert a given module before a given index in the list. 177 index (int): index to insert. 178 module (nn.Module): module to insert 180 for i
in range(len(self.
_modules), index, -1):
184 def append(self, module):
185 r"""Appends a given module to the end of the list. 188 module (nn.Module): module to append 190 self.add_module(str(len(self)), module)
193 def extend(self, modules):
194 r"""Appends modules from a Python iterable to the end of the list. 197 modules (iterable): iterable of modules to append 199 if not isinstance(modules, container_abcs.Iterable):
200 raise TypeError(
"ModuleList.extend should be called with an " 201 "iterable, but got " + type(modules).__name__)
203 for i, module
in enumerate(modules):
204 self.add_module(str(offset + i), module)
209 r"""Holds submodules in a dictionary. 211 :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary, 212 but modules it contains are properly registered, and will be visible by all 213 :class:`~torch.nn.Module` methods. 215 :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects 217 * the order of insertion, and 219 * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged ``OrderedDict`` 220 or another :class:`~torch.nn.ModuleDict` (the argument to :meth:`~torch.nn.ModuleDict.update`). 222 Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping 223 types (e.g., Python's plain ``dict``) does not preserve the order of the 227 modules (iterable, optional): a mapping (dictionary) of (string: module) 228 or an iterable of key-value pairs of type (string, module) 232 class MyModule(nn.Module): 234 super(MyModule, self).__init__() 235 self.choices = nn.ModuleDict({ 236 'conv': nn.Conv2d(10, 10, 3), 237 'pool': nn.MaxPool2d(3) 239 self.activations = nn.ModuleDict([ 240 ['lrelu', nn.LeakyReLU()], 241 ['prelu', nn.PReLU()] 244 def forward(self, x, choice, act): 245 x = self.choices[choice](x) 246 x = self.activations[act](x) 250 def __init__(self, modules=None):
251 super(ModuleDict, self).__init__()
252 if modules
is not None:
def __getitem__(self, key):
    """Return the submodule registered under ``key``.

    Lookup is delegated to ``self._modules``, so a missing key raises
    whatever that mapping raises (``KeyError`` for a dict).
    """
    registry = self._modules
    return registry[key]
def __setitem__(self, key, module):
    """Register ``module`` under ``key`` through ``self.add_module``."""
    register = self.add_module
    register(key, module)
def __delitem__(self, key):
    """Remove the submodule stored under ``key`` from ``self._modules``."""
    registry = self._modules
    del registry[key]
265 return len(self._modules)
268 return iter(self._modules)
def __contains__(self, key):
    """Report whether a submodule is registered under ``key``."""
    registry = self._modules
    return key in registry
274 """Remove all items from the ModuleDict. 276 self._modules.clear()
279 r"""Remove key from the ModuleDict and return its module. 282 key (string): key to pop from the ModuleDict 289 r"""Return an iterable of the ModuleDict keys. 291 return self._modules.keys()
294 r"""Return an iterable of the ModuleDict key/value pairs. 296 return self._modules.items()
299 r"""Return an iterable of the ModuleDict values. 301 return self._modules.values()
303 def update(self, modules):
304 r"""Update the :class:`~torch.nn.ModuleDict` with the key-value pairs from a 305 mapping or an iterable, overwriting existing keys. 308 If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or 309 an iterable of key-value pairs, the order of new elements in it is preserved. 312 modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, 313 or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) 315 if not isinstance(modules, container_abcs.Iterable):
316 raise TypeError(
"ModuleDict.update should be called with an " 317 "iterable of key/value pairs, but got " +
318 type(modules).__name__)
320 if isinstance(modules, container_abcs.Mapping):
321 if isinstance(modules, (OrderedDict, ModuleDict)):
322 for key, module
in modules.items():
325 for key, module
in sorted(modules.items()):
328 for j, m
in enumerate(modules):
329 if not isinstance(m, container_abcs.Iterable):
330 raise TypeError(
"ModuleDict update sequence element " 331 "#" + str(j) +
" should be Iterable; is" +
334 raise ValueError(
"ModuleDict update sequence element " 335 "#" + str(j) +
" has length " + str(len(m)) +
341 r"""Holds parameters in a list. 343 :class:`~torch.nn.ParameterList` can be indexed like a regular Python 344 list, but parameters it contains are properly registered, and will be 345 visible by all :class:`~torch.nn.Module` methods. 348 parameters (iterable, optional): an iterable of :class:`~torch.nn.Parameter` to add 352 class MyModule(nn.Module): 354 super(MyModule, self).__init__() 355 self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)]) 357 def forward(self, x): 358 # ParameterList can act as an iterable, or be indexed using ints 359 for i, p in enumerate(self.params): 360 x = self.params[i // 2].mm(x) + p.mm(x) 364 def __init__(self, parameters=None):
365 super(ParameterList, self).__init__()
366 if parameters
is not None:
369 def _get_abs_string_index(self, idx):
370 """Get the absolute index for the list of modules""" 371 idx = operator.index(idx)
372 if not (-len(self) <= idx < len(self)):
373 raise IndexError(
'index {} is out of range'.format(idx))
378 def __getitem__(self, idx):
379 if isinstance(idx, slice):
380 return self.__class__(list(self._parameters.values())[idx])
383 return self._parameters[str(idx)]
385 def __setitem__(self, idx, param):
387 return self.register_parameter(str(idx), param)
390 return len(self._parameters)
393 return iter(self._parameters.values())
def __iadd__(self, parameters):
    """Support ``self += parameters`` by delegating to ``self.extend``.

    Returns whatever ``self.extend`` returns for ``parameters``.
    """
    extended = self.extend(parameters)
    return extended
399 keys = super(ParameterList, self).__dir__()
400 keys = [key
for key
in keys
if not key.isdigit()]
404 """Appends a given parameter at the end of the list. 407 parameter (nn.Parameter): parameter to append 409 self.register_parameter(str(len(self)), parameter)
413 """Appends parameters from a Python iterable to the end of the list. 416 parameters (iterable): iterable of parameters to append 418 if not isinstance(parameters, container_abcs.Iterable):
419 raise TypeError(
"ParameterList.extend should be called with an " 420 "iterable, but got " + type(parameters).__name__)
422 for i, param
in enumerate(parameters):
423 self.register_parameter(str(offset + i), param)
426 def extra_repr(self):
428 for k, p
in self._parameters.items():
429 size_str =
'x'.join(str(size)
for size
in p.size())
430 device_str =
'' if not p.is_cuda
else ' (GPU {})'.format(p.get_device())
431 parastr =
'Parameter containing: [{} of size {}{}]'.format(
433 child_lines.append(
' (' + str(k) +
'): ' + parastr)
434 tmpstr =
'\n'.join(child_lines)
439 r"""Holds parameters in a dictionary. 441 ParameterDict can be indexed like a regular Python dictionary, but parameters it 442 contains are properly registered, and will be visible by all Module methods. 444 :class:`~torch.nn.ParameterDict` is an **ordered** dictionary that respects 446 * the order of insertion, and 448 * in :meth:`~torch.nn.ParameterDict.update`, the order of the merged ``OrderedDict`` 449 or another :class:`~torch.nn.ParameterDict` (the argument to 450 :meth:`~torch.nn.ParameterDict.update`). 452 Note that :meth:`~torch.nn.ParameterDict.update` with other unordered mapping 453 types (e.g., Python's plain ``dict``) does not preserve the order of the 457 parameters (iterable, optional): a mapping (dictionary) of 458 (string : :class:`~torch.nn.Parameter`) or an iterable of key-value pairs 459 of type (string, :class:`~torch.nn.Parameter`) 463 class MyModule(nn.Module): 465 super(MyModule, self).__init__() 466 self.params = nn.ParameterDict({ 467 'left': nn.Parameter(torch.randn(5, 10)), 468 'right': nn.Parameter(torch.randn(5, 10)) 471 def forward(self, x, choice): 472 x = self.params[choice].mm(x) 476 def __init__(self, parameters=None):
477 super(ParameterDict, self).__init__()
478 if parameters
is not None:
def __getitem__(self, key):
    """Return the parameter registered under ``key``.

    Lookup is delegated to ``self._parameters``, so a missing key raises
    whatever that mapping raises (``KeyError`` for a dict).
    """
    store = self._parameters
    return store[key]
def __setitem__(self, key, parameter):
    """Register ``parameter`` under ``key`` through ``self.register_parameter``."""
    register = self.register_parameter
    register(key, parameter)
def __delitem__(self, key):
    """Remove the parameter stored under ``key`` from ``self._parameters``."""
    store = self._parameters
    del store[key]
491 return len(self._parameters)
494 return iter(self._parameters.keys())
def __contains__(self, key):
    """Report whether a parameter is registered under ``key``."""
    store = self._parameters
    return key in store
500 """Remove all items from the ParameterDict. 502 self._parameters.clear()
505 r"""Remove key from the ParameterDict and return its parameter. 508 key (string): key to pop from the ParameterDict 515 r"""Return an iterable of the ParameterDict keys. 517 return self._parameters.keys()
520 r"""Return an iterable of the ParameterDict key/value pairs. 522 return self._parameters.items()
525 r"""Return an iterable of the ParameterDict values. 527 return self._parameters.values()
529 def update(self, parameters):
530 r"""Update the :class:`~torch.nn.ParameterDict` with the key-value pairs from a 531 mapping or an iterable, overwriting existing keys. 534 If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or 535 an iterable of key-value pairs, the order of new elements in it is preserved. 538 parameters (iterable): a mapping (dictionary) from string to 539 :class:`~torch.nn.Parameter`, or an iterable of 540 key-value pairs of type (string, :class:`~torch.nn.Parameter`) 542 if not isinstance(parameters, container_abcs.Iterable):
543 raise TypeError(
"ParametersDict.update should be called with an " 544 "iterable of key/value pairs, but got " +
545 type(parameters).__name__)
547 if isinstance(parameters, container_abcs.Mapping):
548 if isinstance(parameters, (OrderedDict, ParameterDict)):
549 for key, parameter
in parameters.items():
550 self[key] = parameter
552 for key, parameter
in sorted(parameters.items()):
553 self[key] = parameter
555 for j, p
in enumerate(parameters):
556 if not isinstance(p, container_abcs.Iterable):
557 raise TypeError(
"ParameterDict update sequence element " 558 "#" + str(j) +
" should be Iterable; is" +
561 raise ValueError(
"ParameterDict update sequence element " 562 "#" + str(j) +
" has length " + str(len(p)) +
566 def extra_repr(self):
568 for k, p
in self._parameters.items():
569 size_str =
'x'.join(str(size)
for size
in p.size())
570 device_str =
'' if not p.is_cuda
else ' (GPU {})'.format(p.get_device())
571 parastr =
'Parameter containing: [{} of size {}{}]'.format(
573 child_lines.append(
' (' + k +
'): ' + parastr)
574 tmpstr =
'\n'.join(child_lines)
def _get_abs_string_index(self, idx)
def _get_item_by_idx(self, iterator, idx)
def append(self, parameter)
def extend(self, modules)
def update(self, parameters)
def update(self, modules)
def _get_abs_string_index(self, idx)
def typename(o)
Define basic utilities.
def extend(self, parameters)