// Interop tests between Caffe2 CUDA tensors and ATen (PyTorch) CUDA tensors.
//
// cuda_get (fragment at the end of the include line below): copies a single
// element of type T from `addr` into a local `result` with cudaMemcpy using
// cudaMemcpyDefault, so the copy direction is inferred from the pointer values.
// CUDA_ENFORCE presumably aborts/throws on a non-success cudaError_t.
// NOTE(review): the template header, `result` declaration and return
// statement of cuda_get are not visible in this extract.
1 #include "gtest/gtest.h" 4 #include <ATen/cuda/CUDAContext.h> 5 #include <caffe2/core/init.h> 6 #include <caffe2/core/operator.h> 7 #include <caffe2/core/context_gpu.h> 8 #include <caffe2/utils/math.h> 14 CUDA_ENFORCE(cudaMemcpy(&result, addr,
sizeof(
T), cudaMemcpyDefault));
// cuda_set: writes a single element `value` of type T to `addr` (typically a
// device pointer in these tests) with one blocking cudaMemcpy using
// cudaMemcpyDefault. Intended only for spot checks; each call is a separate
// transfer.
19 void cuda_set(
T* addr,
T value) {
20 CUDA_ENFORCE(cudaMemcpy(addr, &value,
sizeof(
T), cudaMemcpyDefault));
// CUDACaffe2ToPytorch.SimpleLegacy: a Caffe2 CUDA tensor sized with
// Resize(4, 4) and filled with 777 via caffe2::math::Set<int64_t> must be
// visible through the ATen tensor `at_tensor` as a CUDA tensor whose 16 int64
// elements all equal 777. Unlike CUDACaffe2ToPytorch.Simple, which allocates
// with caffe2::empty, this variant sizes the tensor after construction.
// Returns early (so the test passes trivially) when CUDA is unavailable.
// NOTE(review): the declarations of `c2_tensor`, `context` and `at_tensor`
// and the closing braces are not visible in this extract.
23 TEST(CUDACaffe2ToPytorch, SimpleLegacy) {
24 if (!at::cuda::is_available())
return;
26 c2_tensor.Resize(4, 4);
27 auto data = c2_tensor.mutable_data<int64_t>();
// Fill all 4 * 4 = 16 elements on the device.
30 caffe2::math::Set<int64_t>(16, 777, data, &context);
// NOTE(review): `&at_tensor.type()` is the address of a reference and can
// never be nullptr, so this assertion is always true and checks nothing.
33 ASSERT_TRUE(&at_tensor.type() !=
nullptr);
34 ASSERT_TRUE(at_tensor.is_cuda());
// Copy to host memory so the elements can be dereferenced directly.
36 auto at_cpu = at_tensor.cpu();
37 auto it = at_cpu.data<int64_t>();
38 for (int64_t i = 0; i < 16; i++) {
39 ASSERT_EQ(it[i], 777);
// CUDACaffe2ToPytorch.Simple: a 4x4 int64 Caffe2 tensor allocated on CUDA
// with caffe2::empty and filled with 777 via caffe2::math::Set must be
// visible through the ATen tensor `at_tensor` as a CUDA tensor whose 16
// elements all equal 777. Returns early when CUDA is unavailable.
// NOTE(review): the variable receiving caffe2::empty (presumably
// `c2_tensor`), the `context` and `at_tensor` declarations, and the closing
// braces are not visible in this extract.
43 TEST(CUDACaffe2ToPytorch, Simple) {
44 if (!at::cuda::is_available())
return;
46 caffe2::empty({4, 4}, at::dtype<int64_t>().device(caffe2::CUDA));
47 auto data = c2_tensor.mutable_data<int64_t>();
// Fill all 4 * 4 = 16 elements on the device.
50 caffe2::math::Set<int64_t>(16, 777, data, &context);
// NOTE(review): the address of the reference returned by type() can never be
// nullptr, so this assertion is always true and checks nothing.
53 ASSERT_TRUE(&at_tensor.type() !=
nullptr);
54 ASSERT_TRUE(at_tensor.is_cuda());
// Copy to host memory so the elements can be dereferenced directly.
56 auto at_cpu = at_tensor.cpu();
57 auto it = at_cpu.data<int64_t>();
58 for (int64_t i = 0; i < 16; i++) {
59 ASSERT_EQ(it[i], 777);
// CUDACaffe2ToPytorch.Op: a 3x3 int64 Caffe2 CUDA tensor filled with 111
// must be usable by an ATen operator. at::sum over the 9 elements must
// yield 9 * 111 = 999. Returns early when CUDA is unavailable.
// NOTE(review): the variable receiving caffe2::empty, the `context` and
// `at_tensor` declarations, and the closing brace are not visible here.
63 TEST(CUDACaffe2ToPytorch, Op) {
64 if (!at::cuda::is_available())
return;
66 caffe2::empty({3, 3}, at::dtype<int64_t>().device(caffe2::CUDA));
67 auto data = c2_tensor.mutable_data<int64_t>();
70 caffe2::math::Set<int64_t>(9, 111, data, &context);
73 ASSERT_TRUE(at_tensor.is_cuda());
// item<int64_t>() brings the scalar reduction result back to the host.
75 ASSERT_EQ(at::sum(at_tensor).item<int64_t>(), 999);
// CUDAPytorchToCaffe2.Op: three 5x5 float ATen CUDA tensors of ones are
// exposed to a Caffe2 workspace and consumed by a Caffe2 op running with a
// CUDA device option. The 25-element result must live on CUDA, hold 3.0 in
// every element, and sum to 75 when viewed back as an ATen tensor.
// Returns early when CUDA is unavailable.
// NOTE(review): the op type, its inputs/outputs, the creation of blobs other
// than "c", and the definitions of `result` / `at_result` are not visible in
// this extract. The expected 3.0 = 1 + 1 + 1 suggests an element-wise sum of
// the three inputs; confirm against the full file.
78 TEST(CUDAPytorchToCaffe2, Op) {
79 if (!at::cuda::is_available())
return;
83 auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat).device(at::kCUDA));
84 auto at_tensor_b = at::ones({5, 5}, at::dtype(at::kFloat).device(at::kCUDA));
85 auto at_tensor_c = at::ones({5, 5}, at::dtype(at::kFloat).device(at::kCUDA));
// Blob "c" holds an alias of the Caffe2 view of the ATen tensor.
93 BlobSetTensor(workspace.
CreateBlob(
"c"), c2_tensor_from_aten.Alias());
97 auto op = net.add_op();
// Run the op on the CUDA device.
103 op->mutable_device_option()->set_device_type(caffe2::PROTO_CUDA);
106 workspace.RunNetOnce(net);
109 ASSERT_EQ(result.GetDeviceType(), caffe2::CUDA);
// `data` is a device pointer; each element is read back with its own
// blocking cuda_get copy.
111 auto data = result.data<
float>();
112 for (int64_t i = 0; i < 25; i++) {
113 ASSERT_EQ(cuda_get(data + i), 3.0);
116 ASSERT_TRUE(at_result.is_cuda());
117 ASSERT_EQ(at::sum(at_result).item<float>(), 75);
// CUDAPytorchToCaffe2.SharedStorageWrite: a 5x5 ATen CUDA tensor and its
// flat view({25}) share storage. A value written through one Caffe2 tensor
// must be observed through the other Caffe2 tensor and through both ATen
// tensors. Returns early when CUDA is unavailable.
// NOTE(review): the construction of `c2_tensor_a` / `c2_tensor_b` is not
// visible here. They presumably wrap `at_tensor_a` / `at_tensor_b`; confirm.
120 TEST(CUDAPytorchToCaffe2, SharedStorageWrite) {
121 if (!at::cuda::is_available())
return;
122 auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat).device(at::kCUDA));
123 auto at_tensor_b = at_tensor_a.view({25});
// Element offset 1 is [0][1] in the 5x5 tensor and [1] in the flat view.
129 cuda_set<float>(c2_tensor_a.mutable_data<
float>() + 1, 123);
130 ASSERT_EQ(cuda_get(c2_tensor_b.mutable_data<
float>() + 1), 123);
131 ASSERT_EQ(at_tensor_a[0][1].item().to<float>(), 123);
132 ASSERT_EQ(at_tensor_b[1].item().to<float>(), 123);
// CUDAPytorchToCaffe2.MutualResizes: resizes applied on either side must stay
// visible to the other side. After the ATen tensor is resized with resize_()
// (to 4x4, then 6x6), and after the Caffe2 tensor is resized with
// Resize(7, 7), writes made through the Caffe2 tensor's data pointer must
// appear in `at_tensor`. The reported sizes must match on both sides.
// Returns early when CUDA is unavailable.
// NOTE(review): the construction of `c2_tensor` (presumably wrapping
// `at_tensor`) is not visible in this extract.
135 TEST(CUDAPytorchToCaffe2, MutualResizes) {
136 if (!at::cuda::is_available())
return;
137 auto at_tensor = at::ones({5, 5}, at::dtype(at::kFloat).device(at::kCUDA));
// Each write below targets offset k, which the assertions read back as
// at_tensor[0][k].
142 cuda_set<float>(c2_tensor.mutable_data<
float>(), 123);
143 ASSERT_EQ(at_tensor[0][0].item().to<float>(), 123);
// Shrink via ATen, then write through Caffe2.
146 at_tensor.resize_({4, 4});
147 cuda_set<float>(c2_tensor.mutable_data<
float>() + 1, 234);
148 ASSERT_EQ(at_tensor[0][1].item().to<float>(), 234);
// Grow via ATen; the Caffe2 tensor must observe the new 6x6 shape.
151 at_tensor.resize_({6, 6});
152 cuda_set<float>(c2_tensor.mutable_data<
float>() + 2, 345);
153 ASSERT_EQ(at_tensor[0][2].item().to<float>(), 345);
154 ASSERT_EQ(c2_tensor.sizes()[0], 6);
155 ASSERT_EQ(c2_tensor.sizes()[1], 6);
// Grow via Caffe2; the ATen tensor must observe the new 7x7 shape.
159 c2_tensor.Resize(7, 7);
160 cuda_set<float>(c2_tensor.mutable_data<
float>() + 3, 456);
161 ASSERT_EQ(at_tensor[0][3].item().to<float>(), 456);
162 ASSERT_EQ(at_tensor.sizes()[0], 7);
163 ASSERT_EQ(at_tensor.sizes()[1], 7);
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
const T & Get() const
Gets the const reference of the stored object.