#include <gtest/gtest.h>

#include <torch/csrc/autograd/functions/comm.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/parallel/data_parallel.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <test/cpp/api/support.h>

#include <memory>
#include <vector>

using namespace torch::autograd;
using namespace torch::nn;

struct ParallelTest : torch::test::SeedingFixture {};

TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) {
  Scatter scatter(
      {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});

  auto input = torch::ones(10, torch::requires_grad(true));
  auto output = scatter.apply({input});

  ASSERT_EQ(output.size(), 2);
  ASSERT_EQ(output[0].size(0), 5);
  ASSERT_EQ(output[1].size(0), 5);

  ASSERT_TRUE(torch::cat({output[0].to(torch::kCPU), output[1].to(torch::kCPU)})
                  .allclose(input));

  torch::Tensor sum = output[0].to({torch::kCUDA, 1}) + output[1];
  sum.backward(torch::ones_like(sum));

  ASSERT_TRUE(input.grad().defined());
  ASSERT_TRUE(input.grad().device().is_cpu());
  ASSERT_EQ(input.grad().sum().item<int32_t>(), 10);
}
TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) {
  Gather gather(torch::Device(torch::kCUDA, 1));

  auto a = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 0));
  auto b = torch::ones(5, torch::requires_grad(true).device(torch::kCUDA, 1));

  auto outputs = gather.apply({a, b});
  ASSERT_EQ(outputs.size(), 1);
  torch::Tensor output = outputs.front();

  ASSERT_EQ(output.size(0), 10);

  auto chunks = output.chunk(2);
  ASSERT_TRUE(chunks[0].to({torch::kCUDA, 0}).allclose(a));
  ASSERT_TRUE(chunks[1].allclose(b));

  output.backward(torch::ones_like(output));

  ASSERT_TRUE(a.grad().defined());
  ASSERT_EQ(a.grad().sum().item<int32_t>(), 5);

  ASSERT_TRUE(b.grad().defined());
  ASSERT_EQ(b.grad().sum().item<int32_t>(), 5);
}
TEST_F(ParallelTest, Replicate_MultiCUDA) {
  Linear linear(3, 4);
  auto replicas = parallel::replicate(
      linear,
      {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
  ASSERT_EQ(replicas.size(), 2);

  auto original_parameters = linear->parameters();

  auto replica1_parameters = replicas[0]->parameters();
  for (auto& parameter : replica1_parameters) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 0));
  }
  replicas[0]->to(torch::kCPU);
  ASSERT_EQ(replica1_parameters.size(), original_parameters.size());
  for (size_t i = 0; i < original_parameters.size(); ++i) {
    // Replicated parameters match in value but do not share storage.
    ASSERT_TRUE(replica1_parameters[i].allclose(original_parameters[i]));
    ASSERT_TRUE(
        replica1_parameters[i].data<float>() !=
        original_parameters[i].data<float>());
  }

  auto replica2_parameters = replicas[1]->parameters();
  for (auto& parameter : replica2_parameters) {
    ASSERT_EQ(parameter.device(), torch::Device(torch::kCUDA, 1));
  }
  replicas[1]->to(torch::kCPU);
  ASSERT_EQ(replica2_parameters.size(), original_parameters.size());
  for (size_t i = 0; i < original_parameters.size(); ++i) {
    ASSERT_TRUE(replica2_parameters[i].allclose(original_parameters[i]));
    ASSERT_TRUE(
        replica2_parameters[i].data<float>() !=
        original_parameters[i].data<float>());
  }
}
TEST_F(ParallelTest, ParallelApply_MultiCUDA) {
  Linear a(3, 4);

  Linear b(std::dynamic_pointer_cast<LinearImpl>(a->clone()));
  b->to({torch::kCUDA, 0});

  Linear c(std::dynamic_pointer_cast<LinearImpl>(a->clone()));
  c->to({torch::kCUDA, 1});

  std::vector<Linear> modules = {a, b, c};
  std::vector<torch::Tensor> inputs = {
      torch::ones({2, 3}),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 0})),
      torch::ones({2, 3}, torch::device({torch::kCUDA, 1}))};

  auto outputs = parallel::parallel_apply(modules, inputs);

  ASSERT_EQ(outputs.size(), 3);
  ASSERT_TRUE(outputs[0].device().is_cpu());

  ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0));
  ASSERT_TRUE(outputs[1].to(torch::kCPU).allclose(outputs[0]));

  ASSERT_EQ(outputs[2].device(), torch::Device(torch::kCUDA, 1));
  ASSERT_TRUE(outputs[2].to(torch::kCPU).allclose(outputs[0]));
}
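// Illustrative sketch (not one of the tests above): how replicate() and
// parallel_apply() compose by hand, assuming two CUDA devices are available.
// replicate() copies the module onto each device; parallel_apply() then runs
// one replica per input shard, each on its own device. The helper function
// name is arbitrary and used only for this example.
std::vector<torch::Tensor> replicate_and_apply_example() {
  Linear model(3, 4);
  auto replicas = parallel::replicate(
      model,
      {torch::Device(torch::kCUDA, 0), torch::Device(torch::kCUDA, 1)});
  std::vector<torch::Tensor> shards = {
      torch::randn({2, 3}, torch::device({torch::kCUDA, 0})),
      torch::randn({2, 3}, torch::device({torch::kCUDA, 1}))};
  // One output tensor per replica, each still on that replica's device.
  return parallel::parallel_apply(replicas, shards);
}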
TEST_F(ParallelTest, ParallelApplyWithDifferentOutputDevice_MultiCUDA) {
  struct M : torch::nn::Module {
    torch::Tensor forward(torch::Tensor input) {
      return torch::ones(5, torch::kInt32);
    }
  };

  std::vector<std::shared_ptr<M>> modules = {
      std::make_shared<M>(), std::make_shared<M>(), std::make_shared<M>()};
  std::vector<torch::Tensor> inputs = {
      torch::empty({}), torch::empty({}), torch::empty({})};
  std::vector<torch::Device> devices = {
      {torch::kCUDA, 1}, {torch::kCUDA, 0}, {torch::kCPU}};

  auto outputs = parallel::parallel_apply(modules, inputs, devices);

  ASSERT_EQ(outputs.size(), 3);
  ASSERT_TRUE(outputs[0].device().is_cuda());
  ASSERT_EQ(outputs[0].device(), torch::Device(torch::kCUDA, 1));

  ASSERT_TRUE(outputs[1].device().is_cuda());
  ASSERT_EQ(outputs[1].device(), torch::Device(torch::kCUDA, 0));

  ASSERT_TRUE(outputs[2].device().is_cpu());
}
TEST_F(ParallelTest, ParallelApplyRethrowsException_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      throw std::runtime_error("Badness!");
    }
  };

  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  ASSERT_THROWS_WITH(parallel::data_parallel(m, input), "Badness!");
}
TEST_F(
    ParallelTest,
    DataParallelPlacesTheOutputOnTheRequestedDevice_MultiCUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      return torch::ones(3);
    }
  };
  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  {
    // Let data_parallel choose the devices, but request the output on CUDA:1.
    auto output = parallel::data_parallel(
        m, input, /*devices=*/torch::nullopt,
        /*output_device=*/torch::Device(torch::kCUDA, 1));
    ASSERT_TRUE(output.defined());
    ASSERT_TRUE(output.device().is_cuda());
    ASSERT_EQ(output.device().index(), 1);
  }
  {
    // Single-device case (no scatter/gather); the output is still moved.
    auto output = parallel::data_parallel(
        m, input,
        /*devices=*/std::vector<torch::Device>{torch::Device(torch::kCUDA, 0)},
        /*output_device=*/torch::Device(torch::kCUDA, 1));
    ASSERT_TRUE(output.defined());
    ASSERT_TRUE(output.device().is_cuda());
    ASSERT_EQ(output.device().index(), 1);
  }
}
TEST_F(ParallelTest, DataParallelUsesAllAvailableCUDADevices_CUDA) {
  struct M : torch::nn::Cloneable<M> {
    void reset() override {}
    torch::Tensor forward(torch::Tensor input) {
      // Each replica reports the index of the device its input was placed on.
      return torch::full(
          {1}, static_cast<int64_t>(input.device().index()), torch::kInt32);
    }
  };

  auto m = std::make_shared<M>();
  auto input = torch::ones({10, 3});
  auto output = parallel::data_parallel(m, input);

  const auto device_count = torch::cuda::device_count();
  ASSERT_EQ(output.numel(), device_count);
  for (size_t i = 0; i < device_count; ++i) {
    ASSERT_EQ(output[i].item<int32_t>(), i);
  }
}
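// Illustrative sketch (not one of the tests above): the typical end-to-end
// call, assuming at least one CUDA device. data_parallel() scatters the batch
// across the available GPUs, replicates the module, runs the replicas in
// parallel, and gathers the per-replica outputs onto one device. The helper
// function name is arbitrary and used only for this example.
torch::Tensor data_parallel_example() {
  auto model = std::make_shared<LinearImpl>(3, 4);
  auto batch = torch::randn({8, 3});
  // With no devices/output_device arguments, all visible CUDA devices are
  // used and the result is gathered onto the first of them.
  return parallel::data_parallel(model, batch);
}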