#include <torch/csrc/distributed/c10d/ddp.h>

#include <torch/csrc/cuda/comm.h>
#include <torch/csrc/utils/tensor_flatten.h>

#include <torch/csrc/cuda/nccl.h>

#include <c10d/ProcessGroup.hpp>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAMultiStreamGuard.h>

namespace c10d {
namespace {

// Copies the per-replica outputs of torch::cuda::broadcast_coalesced back
// into the replica tensors, skipping replica 0 (the broadcast source).
void copyBroadcastTensorsToReplicas(
    const std::vector<std::vector<at::Tensor>>& broadcastTensors,
    std::vector<std::vector<at::Tensor>>& replicaData) {
  AT_ASSERT(replicaData.size() == broadcastTensors.size());
  for (size_t replica = 1; replica < replicaData.size(); ++replica) {
    AT_ASSERT(replicaData[replica].size() == broadcastTensors[replica].size());
    for (size_t tensor = 0; tensor < replicaData[replica].size(); ++tensor) {
      replicaData[replica][tensor].set_(broadcastTensors[replica][tensor]);
    }
  }
}

} // namespace
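// Groups `tensors` into buckets of at most `bucketSize` bytes (see
// torch::utils::take_tensors), so that each bucket can be flattened and
// communicated as a single contiguous buffer.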
std::vector<std::vector<at::Tensor>> bucketTensors(
    std::vector<at::Tensor>& tensors,
    int64_t bucketSize,
    bool fineGrained) {
  std::vector<std::vector<at::Tensor>> bucketedTensors;
  auto tensorGroups =
      torch::utils::take_tensors(tensors, bucketSize, fineGrained);

  bucketedTensors.reserve(tensorGroups.size());
  for (auto& tensorGroup : tensorGroups) {
    bucketedTensors.push_back(std::move(tensorGroup.tensors));
  }
  return bucketedTensors;
}
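// Broadcasts `tensors` from rank 0 to every other process in `processGroup`.
// The tensors are bucketed, each bucket is flattened into one contiguous
// buffer, broadcast asynchronously, and finally unflattened back into the
// original tensors.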
void distBroadcastCoalesced(
    ProcessGroup& processGroup,
    std::vector<at::Tensor>& tensors,
    int64_t bufferSize,
    bool fineGrained) {
  std::vector<std::vector<at::Tensor>> bucketedTensors =
      bucketTensors(tensors, bufferSize, fineGrained);
  // Each entry is a single-element vector because ProcessGroup::broadcast
  // operates on a vector of tensors, and the flattened buffer must stay
  // alive until the corresponding wait() returns.
  std::vector<std::vector<at::Tensor>> flatTensors;
  std::vector<std::shared_ptr<ProcessGroup::Work>> work;
  flatTensors.reserve(bucketedTensors.size());
  work.reserve(bucketedTensors.size());

  for (const auto& tensorBucket : bucketedTensors) {
    // Flatten each bucket into one contiguous tensor.
    flatTensors.push_back({torch::utils::flatten_dense_tensors(tensorBucket)});
    BroadcastOptions broadcastOptions;
    broadcastOptions.rootRank = 0;
    broadcastOptions.rootTensor = 0;
    // Enqueue the asynchronous broadcast and keep the Work handle so we can
    // wait for completion once all buckets have been enqueued.
    work.push_back(
        processGroup.broadcast(flatTensors.back(), broadcastOptions));
  }

  for (size_t bucket = 0; bucket < bucketedTensors.size(); ++bucket) {
    auto& tensors = bucketedTensors[bucket];
    work[bucket]->wait();

    // Unflatten the broadcast buffer back into the individual tensors.
    const auto synced =
        torch::utils::unflatten_dense_tensors(flatTensors[bucket][0], tensors);
    AT_ASSERT(synced.size() == tensors.size());
    for (size_t i = 0; i < synced.size(); ++i) {
      tensors[i].copy_(synced[i], /*non_blocking=*/true);
    }
  }
}
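// Synchronizes replicated module state: parameters are broadcast from the
// first local device to the other local replicas, while buffers (if
// requested) are first broadcast across processes through `processGroup`
// and then across local replicas.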
void syncParams(
    ProcessGroup& processGroup,
    std::vector<std::vector<at::Tensor>>& parameterData,
    std::vector<std::vector<at::Tensor>>& bufferData,
    const std::vector<int64_t>& devices,
    int64_t broadcastBucketSize,
    bool broadcastBuffers) {
  AT_ASSERT(!parameterData.empty());
  AT_ASSERT(!bufferData.empty());
  AT_ASSERT(!devices.empty());

  // Intra-node parameter sync: broadcast from the first device to the other
  // local replicas.
  if (devices.size() > 1) {
    auto result = torch::cuda::broadcast_coalesced(
        parameterData[0], devices, broadcastBucketSize);
    copyBroadcastTensorsToReplicas(result, parameterData);
  }

  if (broadcastBuffers && !bufferData[0].empty()) {
    // Inter-node buffer sync first (fineGrained defaults to false in the
    // header declaration).
    distBroadcastCoalesced(processGroup, bufferData[0], broadcastBucketSize);
    // Then intra-node buffer sync if there is more than one local device.
    if (devices.size() > 1) {
      auto result = torch::cuda::broadcast_coalesced(
          bufferData[0], devices, broadcastBucketSize);
      copyBroadcastTensorsToReplicas(result, bufferData);
    }
  }
}
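// Starts the reduction of one batch of per-device gradients: flattens each
// device's gradients on a dedicated worker stream, reduces them onto the
// first device via NCCL when more than one local device is used, divides by
// the process-group size, and enqueues an asynchronous allreduce across
// processes. The returned Work handle and coalesced tensor are later passed
// to syncReduction.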
std::tuple<std::shared_ptr<ProcessGroup::Work>, at::Tensor> queueReduction(
    ProcessGroup& processGroup,
    std::vector<std::vector<at::Tensor>>& gradsBatch,
    const std::vector<int64_t>& devices) {
  AT_ASSERT(!gradsBatch.empty());
  AT_ASSERT(!devices.empty());

  // Events used to make the worker streams wait on the streams that produced
  // the gradients.
  std::vector<at::cuda::CUDAEvent> events;
  events.resize(devices.size());

  // Use dedicated worker streams so the memory copies and the intra-node
  // reduce do not serialize with work already queued on the current streams.
  std::vector<at::cuda::CUDAStream> workerStreams;
  for (size_t devIdx = 0; devIdx < devices.size(); ++devIdx) {
    at::cuda::CUDAGuard deviceGuard(devices[devIdx]);
    events[devIdx].record();
    workerStreams.push_back(
        at::cuda::getStreamFromPool(false, devices[devIdx]));
    // Let the worker stream wait on the recorded event.
    events[devIdx].block(workerStreams.back());
  }

  // Make the worker streams current for the remainder of this function.
  at::cuda::CUDAMultiStreamGuard streamGuard(workerStreams);

  std::vector<at::Tensor> gradsBatchCoalesced;
  for (size_t devIdx = 0; devIdx < devices.size(); ++devIdx) {
    at::cuda::CUDAGuard deviceGuard(devices[devIdx]);
    gradsBatchCoalesced.push_back(
        torch::utils::flatten_dense_tensors(gradsBatch[devIdx]));
  }

  // Reduce the per-device flattened gradients onto the first device.
  if (devices.size() > 1) {
    torch::cuda::nccl::reduce(gradsBatchCoalesced, 0);
  }

  // Pre-divide so the allreduce sum yields an average over all processes.
  gradsBatchCoalesced[0] /= processGroup.getSize();

  std::vector<at::Tensor> allreduceInput = {gradsBatchCoalesced[0]};
  auto reductionWork = processGroup.allreduce(allreduceInput);

  return std::make_tuple(reductionWork, gradsBatchCoalesced[0]);
}
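// Completes a reduction started by queueReduction: waits for the allreduce,
// unflattens the reduced buffer back into the individual gradient tensors,
// and synchronizes the worker stream with the caller's stream.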
void syncReduction(
    std::shared_ptr<ProcessGroup::Work>& reductionWork,
    std::vector<at::Tensor>& gradsBatch,
    at::Tensor& gradsBatchCoalesced) {
  // Run the copies on a dedicated worker stream so they do not serialize
  // with work already queued on the current stream.
  at::cuda::CUDAStream workerStream = at::cuda::getStreamFromPool();
  at::cuda::CUDAStreamGuard cudaGuard(workerStream);

  // Wait for the inter-node allreduce to finish, then scatter the reduced
  // buffer back into the individual gradient tensors.
  reductionWork->wait();
  std::vector<at::Tensor> gradsReduced =
      torch::utils::unflatten_dense_tensors(gradsBatchCoalesced, gradsBatch);
  AT_ASSERT(gradsReduced.size() == gradsBatch.size());

  for (size_t i = 0; i < gradsReduced.size(); ++i) {
    gradsBatch[i].copy_(gradsReduced[i]);
  }

  // Make the original stream wait on the worker stream before the gradients
  // are consumed.
  at::cuda::CUDAEvent event;
  event.record(workerStream);
  event.block(cudaGuard.original_stream());
}

} // namespace c10d
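// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original file): how a caller,
// such as a data-parallel training wrapper, might sequence these helpers for
// a single bucket of gradients. `pg`, `parameterData`, `bufferData`,
// `gradsBatch`, and `devices` are assumed to be set up by the caller, and the
// bucket size below is only an example value.
//
//   // 1) Make every process and local replica start from identical state.
//   c10d::syncParams(
//       pg, parameterData, bufferData, devices,
//       /*broadcastBucketSize=*/25 * 1024 * 1024,
//       /*broadcastBuffers=*/true);
//
//   // 2) Kick off the asynchronous gradient reduction for this bucket.
//   std::shared_ptr<c10d::ProcessGroup::Work> work;
//   at::Tensor coalesced;
//   std::tie(work, coalesced) = c10d::queueReduction(pg, gradsBatch, devices);
//
//   // 3) Before the optimizer step, wait for the reduction and copy the
//   //    averaged gradients back into the first replica's tensors.
//   c10d::syncReduction(work, gradsBatch[0], coalesced);
// ---------------------------------------------------------------------------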