#pragma once

#include <torch/cuda.h>
#include <torch/nn/module.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/autograd/functions/comm.h>
#ifdef USE_CUDA
#include <torch/csrc/cuda/comm.h>
#endif
#include <ATen/core/functional.h>

#include <ATen/Device.h>
#include <ATen/Parallel.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <exception>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

namespace torch {
namespace nn {
namespace parallel {

/// Replicates a module on each of the given devices by deep-copying it via
/// `clone()`. The module must inherit from `nn::Cloneable` or define its own
/// `clone()` method.
template <typename ModuleType>
std::vector<std::shared_ptr<ModuleType>> replicate(
    const std::shared_ptr<ModuleType>& module,
    const std::vector<Device>& devices) {
  std::vector<std::shared_ptr<ModuleType>> replicas;
  replicas.reserve(devices.size());
  for (const auto& device : devices) {
    replicas.push_back(
        std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
  }
  return replicas;
}
/// Overload of `replicate()` for `ModuleHolder`s, so that wrapped modules
/// such as `Linear` can be replicated directly.
template <typename ModuleType>
std::vector<ModuleHolder<ModuleType>> replicate(
    const ModuleHolder<ModuleType>& module,
    const std::vector<Device>& devices) {
  auto ptrs = replicate(module.ptr(), devices);
  return std::vector<ModuleHolder<ModuleType>>(ptrs.begin(), ptrs.end());
}
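// Usage sketch (illustrative only, not part of the original header; `net` and
// the device list are placeholder values):
//
//   torch::nn::Linear net(3, 4);
//   std::vector<torch::Device> devices = {
//       torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)};
//   // Each replica is an independent deep copy living on its own device.
//   auto replicas = torch::nn::parallel::replicate(net, devices);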
/// Applies each module in `modules` to the corresponding tensor in `inputs`
/// via `at::parallel_for`, moving each output to the matching device (or to
/// the input's device when `devices` is not given). The first exception
/// caught by any worker is rethrown after all iterations have finished.
template <typename ModuleType>
std::vector<Tensor> parallel_apply(
    std::vector<ModuleType>& modules,
    const std::vector<Tensor>& inputs,
    const optional<std::vector<Device>>& devices = nullopt) {
  TORCH_CHECK(
      modules.size() == inputs.size(), "Must have as many inputs as modules");
  if (devices) {
    TORCH_CHECK(
        modules.size() == devices->size(),
        "Must have as many devices as modules");
  }

  std::vector<Tensor> outputs(modules.size());
  std::mutex mutex;
  std::exception_ptr exception;

  at::parallel_for(
      /*begin=*/0, /*end=*/modules.size(), /*grain_size=*/1,
      [&modules, &inputs, &devices, &outputs, &mutex, &exception](
          int64_t index, int64_t stop) {
        for (; index < stop; ++index) {
          try {
            auto output = modules[index]->forward(inputs[index]);
            output = output.to(
                devices ? (*devices)[index] : inputs[index].device());
            std::lock_guard<std::mutex> lock(mutex);
            outputs[index] = output;
          } catch (...) {
            std::lock_guard<std::mutex> lock(mutex);
            if (!exception) {
              exception = std::current_exception();
            }
          }
        }
      });

  if (exception) {
    std::rethrow_exception(exception);
  }
  return outputs;
}
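// Usage sketch (illustrative only; `net`, `x0`, `x1` and the device list are
// placeholder values):
//
//   auto replicas = replicate(net, devices);
//   std::vector<torch::Tensor> inputs = {
//       x0.to(devices[0]), x1.to(devices[1])};
//   // outputs[i] holds replicas[i]->forward(inputs[i]), left on devices[i].
//   auto outputs = parallel_apply(replicas, inputs, devices);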
/// Evaluates `module(input)` in parallel across the given `devices`. If
/// `devices` is not supplied, the call is parallelized across all available
/// CUDA devices. The combined output is placed on `output_device`, which
/// defaults to the first entry of `devices`. The steps are: scatter the input
/// along `dim`, replicate the module, apply each replica to its chunk, and
/// gather the per-device outputs.
template <typename ModuleType>
Tensor data_parallel(
    ModuleType module,
    Tensor input,
    optional<std::vector<Device>> devices = nullopt,
    optional<Device> output_device = nullopt,
    int64_t dim = 0) {
  if (!devices) {
    const auto device_count = torch::cuda::device_count();
    TORCH_CHECK(
        device_count > 0, "Expected at least one CUDA device to be available");
    devices = std::vector<Device>();
    devices->reserve(device_count);
    for (size_t index = 0; index < device_count; ++index) {
      devices->emplace_back(kCUDA, index);
    }
  }
  if (!output_device) {
    output_device = devices->front();
  }

  // Fast path: with a single device there is nothing to scatter or gather.
  if (devices->size() == 1) {
    module->to(devices->front());
    input = input.to(devices->front());
    return module->forward(std::move(input)).to(*output_device);
  }
#ifdef USE_CUDA
  autograd::Scatter scatter(*devices, /*chunk_sizes=*/nullopt, dim);
  auto scattered_inputs = fmap<Tensor>(scatter.apply({std::move(input)}));

  auto replicas = replicate(module, *devices);
  auto outputs = parallel_apply(replicas, scattered_inputs, *devices);
  return autograd::Gather(*output_device, dim)
      .apply(fmap<autograd::Variable>(std::move(outputs)))
      .front();
#else
  AT_ERROR("data_parallel not supported without CUDA");
  return Tensor();
#endif
}

} // namespace parallel
} // namespace nn
} // namespace torch