def _is_script_module(module):
    """Return True if ``module`` is a TorchScript ``torch.jit.ScriptModule``.

    ``torch.jit`` is imported lazily, inside the function, instead of at
    module import time.
    """
    import torch.jit
    return isinstance(module, torch.jit.ScriptModule)
def _init_script_module():
    """Create and return a new, empty ``torch.jit.ScriptModule``.

    A fresh instance is built through the ScriptModule constructor on every
    call. ``torch.jit`` is imported lazily inside the function.
    """
    import torch.jit
    return torch.jit.ScriptModule()
def _is_jit_enabled():
    """Return whether the TorchScript JIT is enabled, as a ``bool``.

    Older PyTorch releases expose the flag as ``torch.jit._enabled``. Newer
    releases keep it in ``torch.jit._state._enabled``, a proxy object that
    implements ``__bool__``. Both locations are checked so this helper works
    across versions, and the result is normalised to a plain ``bool``.
    """
    import torch.jit
    enabled = getattr(torch.jit, "_enabled", None)
    if enabled is None:
        # Newer layout: the flag lives in the private _state submodule.
        from torch.jit import _state
        enabled = _state._enabled
    return bool(enabled)
def _replicatable_module(module, memo=None):
    """Return True if ``module`` and all of its descendants can be replicated.

    A module tree cannot be replicated when some ``ScriptModule`` in it has a
    descendant that is not itself a ``ScriptModule`` (i.e. a plain Python
    module). When the JIT is disabled, every module is considered
    replicatable.

    Args:
        module: root of the module tree to check.
        memo: set of modules already visited. It is created on the first call
            and shared through the recursion so that each module is inspected
            at most once.

    Returns:
        bool: True if the tree can be replicated, False otherwise.
    """
    def descendant_modules(mod):
        # ``mod.modules()`` yields ``mod`` itself first; skip that element.
        gen = mod.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # Remember every module we have looked at.
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(descendant)
                   for descendant in descendant_modules(module))

    for child in module.children():
        # Any unreplicatable module makes this function return False at once,
        # so a module already in ``memo`` has been verified and can be skipped.
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True
def _build_param_dict(modules, module_copies, module_indices):
    """Map each ScriptModule-owned tensor to the replica slot it belongs to.

    Only ScriptModules in ``modules`` are considered. Their direct
    parameters and buffers (``recurse=False``) are recorded.

    Args:
        modules: iterable of original modules.
        module_copies: sequence of replicas, indexed via ``module_indices``.
        module_indices: mapping from each original module to the index of its
            replica in ``module_copies``.

    Returns:
        dict: maps each original parameter/buffer tensor to a
        ``(replica, attribute_name)`` tuple.
    """
    param_dict = {}
    for module in modules:
        if not _is_script_module(module):
            continue
        replica = module_copies[module_indices[module]]
        for name, param in module.named_parameters(recurse=False):
            param_dict[param] = (replica, name)
        for name, buffer in module.named_buffers(recurse=False):
            param_dict[buffer] = (replica, name)
    return param_dict
def _copy_scriptmodule_methods(modules, module_copies, module_indices):
    """Copy every compiled method of each ScriptModule onto its replica.

    For each method, the tensors it initially references
    (``method.initial_ivalues()``) are translated into ``(replica, name)``
    pairs. This lets ``replica._copy_method`` bind the copy to the replica's
    tensors instead of the original ones.

    Args:
        modules: list of original modules; position ``i`` corresponds to
            ``module_copies[i]``.
        module_copies: replicas for a single device, parallel to ``modules``.
        module_indices: mapping from original module to its index.

    Raises:
        KeyError: presumably if a method references a tensor that is not a
            direct parameter/buffer of some ScriptModule in ``modules`` —
            NOTE(review): confirm this cannot happen for valid inputs.
    """
    param_dict = _build_param_dict(modules, module_copies, module_indices)
    for i, module in enumerate(modules):
        if not _is_script_module(module):
            continue
        replica = module_copies[i]
        for method_name in module._method_names():
            method = module._get_method(method_name)
            param_list = [param_dict[param] for param in method.initial_ivalues()]
            replica._copy_method(method_name, param_list, module)
def _broadcast_coalesced_reshape(tensors, devices, detach=False):
    """Broadcast ``tensors`` to every device in ``devices``.

    Args:
        tensors: sequence of tensors to copy.
        devices: target devices.
        detach: if true, use ``comm.broadcast_coalesced`` directly. Otherwise
            use the ``Broadcast`` Function (via ``.apply``).

    Returns:
        list: one entry per device, each holding that device's copies of
        ``tensors`` in their original order. An empty list is returned when
        ``tensors`` is empty and ``detach`` is false.
    """
    if detach:
        return comm.broadcast_coalesced(tensors, devices)

    # Use len() rather than truthiness: ``tensors`` holds tensors, and an
    # explicit size check is unambiguous.
    if len(tensors) == 0:
        return []

    from ._functions import Broadcast
    # Broadcast.apply returns one flat sequence; split it into consecutive
    # chunks of len(tensors), one chunk per device.
    tensor_copies = Broadcast.apply(devices, *tensors)
    chunk = len(tensors)
    return [tensor_copies[i:i + chunk]
            for i in range(0, len(tensor_copies), chunk)]
def replicate(network, devices, detach=False):
    """Replicate ``network`` onto each device in ``devices``.

    Every module in ``network`` is shallow-copied once per device. The
    copies' children, parameters and buffers are then re-pointed at the
    broadcast copies that belong to the same device. Compiled ScriptModule
    methods are copied last.

    Args:
        network: root module to replicate.
        devices: device identifiers; each is normalised with
            ``_get_device_index(x, True)``.
        detach: if true, all parameters and buffers are broadcast through the
            detached path. Otherwise parameters, and buffers with
            ``requires_grad``, use the non-detached broadcast.

    Returns:
        list: the replica of ``network`` for each device, in ``devices`` order.

    Raises:
        RuntimeError: if a ScriptModule in ``network`` has a non-script
            (Python) module among its descendants.
    """
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    # Buffers that require grad follow ``detach``; every other buffer is
    # always broadcast with detach=True.
    buffers = list(network.buffers())
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    modules = list(network.modules())
    # module_copies[j][i] is device j's replica of modules[i].
    module_copies = [[] for _ in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules"}

    # Pass 1: create a shallow replica of every module for every device.
    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # ScriptModule replicas are built through the constructor;
                # only the non-container attributes are copied across.
                replica = _init_script_module()
                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr
                for key in keys:
                    replica.__dict__[key] = module.__dict__[key]
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                # Copy the containers so the re-pointing in pass 2 does not
                # mutate the original module's dicts.
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    # Pass 2: point each replica's children, parameters and buffers at the
    # copies that live on the same device.
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

    # modules[0] is the root; modules() presumably yields ``network`` first.
    return [module_copies[j][0] for j in range(num_replicas)]