Caffe2 - Python API
A deep learning, cross-platform ML framework
replicate.py
1 import torch.cuda.comm as comm
2 from torch.cuda._utils import _get_device_index
3 
4 
5 def _is_script_module(module):
6  import torch.jit
7  return isinstance(module, torch.jit.ScriptModule)
8 
9 
10 def _init_script_module():
11  import torch.jit
12  return torch.jit.ScriptModule()
13 
14 
15 def _is_jit_enabled():
16  import torch.jit
17  return torch.jit._enabled
18 
19 
def _replicatable_module(module, memo=None):
    """Check whether ``module`` can be safely replicated.

    There are three kinds of module:

    1. python modules
    2. weak python modules (``nn.Module`` annotated by ``@weak_module``)
    3. ``ScriptModule``

    Currently a module cannot be replicated properly if the descendants of
    any ``ScriptModule`` contain a python module (kind 1 above). When the JIT
    is disabled, every module is reported as replicatable.

    Args:
        module: root of the module tree to check.
        memo: optional set of already-visited modules. It is updated in
            place; a new set is created when it is None.

    Returns:
        bool: True if the tree rooted at ``module`` can be replicated.
    """
    if not _is_jit_enabled():
        return True
    visited = set() if memo is None else memo

    # Remember every module we look at.
    visited.add(module)

    if _is_script_module(module):
        # ``modules()`` yields ``module`` itself first; drop it to keep only
        # the proper descendants.
        descendants = list(module.modules())[1:]
        visited.update(descendants)
        return all(_is_script_module(d) for d in descendants)

    # Any non-replicatable child makes ``all`` stop early with False, so
    # children that were already visited can be skipped safely. The filter
    # is evaluated lazily, so modules marked by earlier recursive calls are
    # skipped as well.
    return all(
        _replicatable_module(child, visited)
        for child in module.children()
        if child not in visited
    )
57 
58 
def _build_param_dict(modules, module_copies, module_indices):
    """Map tensors owned directly by ScriptModules to their replica slots.

    For every ScriptModule in ``modules``, each parameter and buffer it owns
    directly (``recurse=False``) is mapped to ``(replica, name)``.
    ``replica`` is the entry of ``module_copies`` at
    ``module_indices[module]``, and ``name`` is the tensor's attribute name.
    Non-script modules are skipped. Buffers are recorded after parameters,
    so a buffer that is also a parameter keeps the buffer entry.

    Returns:
        dict: tensor -> (replica, name).
    """
    mapping = {}
    for owner in (m for m in modules if _is_script_module(m)):
        target = module_copies[module_indices[owner]]
        owned = (list(owner.named_parameters(recurse=False))
                 + list(owner.named_buffers(recurse=False)))
        mapping.update((tensor, (target, attr)) for attr, tensor in owned)
    return mapping
70 
71 
def _copy_scriptmodule_methods(modules, module_copies, module_indices):
    """Copy every method of each ScriptModule in ``modules`` onto its replica.

    ``module_copies[i]`` is the replica of ``modules[i]``. For each method,
    the values returned by ``initial_ivalues()`` are translated into
    ``(replica, name)`` slots through ``_build_param_dict`` and passed to
    ``replica._copy_method`` together with the source module.
    """
    slot_of = _build_param_dict(modules, module_copies, module_indices)
    for idx, source in enumerate(modules):
        if _is_script_module(source):
            target = module_copies[idx]
            for method_name in source._method_names():
                initial = source._get_method(method_name).initial_ivalues()
                slots = [slot_of[value] for value in initial]
                target._copy_method(method_name, slots, source)
84 
85 
def _broadcast_coalesced_reshape(tensors, devices, detach=False):
    """Broadcast ``tensors`` to ``devices``, grouped into per-replica lists.

    With ``detach=True`` the result of
    ``comm.broadcast_coalesced(tensors, devices)`` is returned as is.
    Otherwise the autograd ``Broadcast`` function is used, and its flat
    output is cut into consecutive chunks of ``len(tensors)`` items.
    Presumably each chunk holds one device's copies; this depends on
    ``Broadcast``'s output order, which should be confirmed. In the
    non-detached case, an empty ``tensors`` yields ``[]``.
    """
    from ._functions import Broadcast
    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    # Use the autograd function to broadcast if not detach.
    chunk = len(tensors)
    if not chunk:
        return []
    flat = Broadcast.apply(devices, *tensors)
    return [flat[start:start + chunk] for start in range(0, len(flat), chunk)]
98 
99 
def replicate(network, devices, detach=False):
    """Create one replica of ``network`` for each entry of ``devices``.

    Every module in ``network`` is copied once per device. Non-script
    modules get a copy of their ``__dict__`` plus fresh ``_parameters`` /
    ``_buffers`` / ``_modules`` dicts. Script modules are re-created with
    ``_init_script_module`` and receive every other ``__dict__`` entry. The
    copies are then rewired so that each replica's submodules, parameters
    and buffers point at the copies belonging to the same device.

    Parameters, and buffers with ``requires_grad``, are broadcast with
    ``_broadcast_coalesced_reshape(..., detach)``. All other buffers are
    always broadcast with ``detach=True``. ScriptModule replicas also get
    their methods copied via ``_copy_scriptmodule_methods``.

    Args:
        network: module to replicate.
        devices: iterable of devices; each entry is normalized with
            ``_get_device_index(x, True)``.
        detach: if True, every tensor is broadcast with ``detach=True``.

    Returns:
        list: the root replica for each device, in ``devices`` order. The
        list is empty when ``devices`` is empty.

    Raises:
        RuntimeError: if a ScriptModule in ``network`` has a python module
            among its descendants.
    """
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "childrens of ScriptModule")

    # Materialize first: ``devices`` may be any iterable (e.g. a generator),
    # whose truthiness says nothing about emptiness.
    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)
    if num_replicas == 0:
        # No targets: do not hand the broadcast primitives an empty device
        # list; there is nothing to replicate onto.
        return []

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    # Buffers that require grad go through the (possibly autograd-aware)
    # broadcast governed by ``detach``; every other buffer is always
    # broadcast detached.
    buffers_rg = []
    buffers_not_rg = []
    for buf in network.buffers():
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    modules = list(network.modules())
    module_copies = [[] for _ in devices]
    module_indices = {}
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules"}

    # Pass 1: create an (unwired) copy of every module for every device.
    for i, module in enumerate(modules):
        module_indices[module] = i
        # Loop-invariant: the module's kind does not depend on the replica.
        is_script = _is_script_module(module)
        for j in range(num_replicas):
            if is_script:
                # we have to initialize ScriptModule properly so that
                # it works with pybind11; its slot dicts are not copied
                # from ``module`` and get filled in during pass 2.
                replica = _init_script_module()
                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr
                for key in keys:
                    replica.__dict__[key] = module.__dict__[key]
            else:
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                # Fresh slot dicts so that pass 2 never mutates the
                # original module's dicts.
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    # Pass 2: point each replica's submodules, parameters and buffers at the
    # copies that belong to the same replica index ``j``.
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    module_copies[j][i]._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    module_copies[j][i]._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    module_copies[j][i]._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    module_copies[j][i]._parameters[key] = param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    module_copies[j][i]._buffers[key] = None
            else:
                # Must mirror the partition rule used above.
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    module_copies[j][i]._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

    # ``modules[0]`` is ``network`` itself, so index 0 is each replica's root.
    return [copies[0] for copies in module_copies]