Caffe2 - C++ API
A deep learning, cross platform ML framework
data_parallel.h
#pragma once

#include <torch/cuda.h>
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/autograd/functions/comm.h>
#ifdef USE_CUDA
#include <torch/csrc/cuda/comm.h>
#endif
#include <ATen/core/functional.h>

#include <ATen/Device.h>
#include <ATen/Parallel.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <vector>

namespace torch {
namespace nn {
namespace parallel {

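/// Replicates a module by deep-copying it (via `Module::clone`) onto each of
/// the given devices, returning one replica per device.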
template <typename ModuleType>
std::vector<std::shared_ptr<ModuleType>> replicate(
    const std::shared_ptr<ModuleType>& module,
    const std::vector<Device>& devices) {
  std::vector<std::shared_ptr<ModuleType>> replicas;
  replicas.reserve(devices.size());
  for (const auto& device : devices) {
    replicas.push_back(
        std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
  }
  return replicas;
}

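/// Overload of `replicate` for `ModuleHolder`s: replicates the underlying
/// module and re-wraps each replica in a `ModuleHolder`.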
template <typename ModuleType>
std::vector<ModuleHolder<ModuleType>> replicate(
    const ModuleHolder<ModuleType>& module,
    const std::vector<Device>& devices) {
  auto ptrs = replicate(module.ptr(), devices);
  return std::vector<ModuleHolder<ModuleType>>(ptrs.begin(), ptrs.end());
}

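/// Applies each module in `modules` to the input at the same position in
/// `inputs`, running the forward calls in parallel via `at::parallel_for`.
/// Each output is moved to the corresponding entry of `devices` if given,
/// otherwise to the device of its input. The first exception thrown by any
/// worker is captured and rethrown on the calling thread once all workers
/// have finished.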
template <typename ModuleType>
std::vector<Tensor> parallel_apply(
    std::vector<ModuleType>& modules,
    const std::vector<Tensor>& inputs,
    const optional<std::vector<Device>>& devices = nullopt) {
  AT_CHECK(
      modules.size() == inputs.size(), "Must have as many inputs as modules");
  if (devices) {
    AT_CHECK(
        modules.size() == devices->size(),
        "Must have as many devices as modules");
  }

  std::vector<Tensor> outputs(modules.size());
  std::mutex mutex;

  // std::exception_ptr can be passed between threads:
  // > An instance of std::exception_ptr may be passed to another function,
  // > possibly on another thread, where the exception may be rethrown [...].
  // https://en.cppreference.com/w/cpp/error/exception_ptr
  std::exception_ptr exception;

  at::parallel_for(
      /*begin=*/0,
      /*end=*/modules.size(),
      /*grain_size=*/1,
      [&modules, &inputs, &devices, &outputs, &mutex, &exception](
          int64_t index, int64_t stop) {
        for (; index < stop; ++index) {
          try {
            auto output = modules[index]->forward(inputs[index]);
            output =
                output.to(devices ? (*devices)[index] : inputs[index].device());
            std::lock_guard<std::mutex> lock(mutex);
            outputs[index] = output;
          } catch (...) {
            std::lock_guard<std::mutex> lock(mutex);
            if (!exception) {
              exception = std::current_exception();
            }
          }
        }
      });

  if (exception) {
    std::rethrow_exception(exception);
  }

  return outputs;
}

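/// Evaluates `module(input)` in a data-parallel fashion: the input is
/// scattered along `dim` across `devices`, the module is replicated onto each
/// device, every replica processes its chunk, and the per-device outputs are
/// gathered onto `output_device`. If `devices` is not supplied, all available
/// CUDA devices are used; if `output_device` is not supplied, it defaults to
/// the first device. With a single device, the module and input are simply
/// moved to that device and `forward` is called directly.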
template <typename ModuleType>
Tensor data_parallel(
    ModuleType module,
    Tensor input,
    optional<std::vector<Device>> devices = nullopt,
    optional<Device> output_device = nullopt,
    int64_t dim = 0) {
  if (!devices) {
    const auto device_count = torch::cuda::device_count();
    AT_CHECK(
        device_count > 0, "Expected at least one CUDA device to be available");
    devices = std::vector<Device>();
    devices->reserve(device_count);
    for (size_t index = 0; index < device_count; ++index) {
      devices->emplace_back(kCUDA, index);
    }
  }
  if (!output_device) {
    output_device = devices->front();
  }

  if (devices->size() == 1) {
    module->to(devices->front());
    input = input.to(devices->front());
    return module->forward(std::move(input)).to(*output_device);
  }

#ifdef USE_CUDA
  autograd::Scatter scatter(*devices, /*chunk_sizes=*/nullopt, dim);
  auto scattered_inputs = fmap<Tensor>(scatter.apply({std::move(input)}));

  auto replicas = replicate(module, *devices);
  auto outputs = parallel_apply(replicas, scattered_inputs, *devices);
  return autograd::Gather(*output_device, dim)
      .apply(fmap<autograd::Variable>(std::move(outputs)))
      .front();
#else
  AT_ERROR("data_parallel not supported without CUDA");
  return Tensor();
#endif
}

} // namespace parallel
} // namespace nn
} // namespace torch
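
As a usage sketch (not part of the header), the snippet below runs a forward pass through data_parallel. It assumes libtorch built with CUDA and at least two visible CUDA devices; the torch::nn::Linear module, the batch shape, and the device indices are illustrative assumptions, not part of the API. Omitting the devices argument would instead parallelize across all available CUDA devices.

// Minimal usage sketch. Assumes libtorch with CUDA support and at least two
// visible CUDA devices; module, shapes, and device indices are illustrative.
#include <torch/torch.h>
#include <torch/nn/parallel/data_parallel.h>

#include <iostream>
#include <vector>

int main() {
  // A small module to replicate across devices.
  torch::nn::Linear model(/*in_features=*/8, /*out_features=*/4);

  // A batch of 16 rows on the first CUDA device; data_parallel scatters it
  // along dim 0 across the requested devices.
  torch::Tensor input = torch::randn({16, 8}, torch::Device(torch::kCUDA, 0));

  std::vector<torch::Device> devices = {
      torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};

  // Scatter, replicate, run each replica, and gather the result on device 0.
  torch::Tensor output = torch::nn::parallel::data_parallel(
      model,
      input,
      devices,
      /*output_device=*/torch::Device(torch::kCUDA, 0),
      /*dim=*/0);

  std::cout << output.sizes() << std::endl; // Expected: [16, 4]
  return 0;
}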