Caffe2 - C++ API
A deep learning, cross platform ML framework
extension_backend_test.cpp
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/NativeFunctions.h>
5 #include <ATen/ExtensionBackendRegistration.h>
6 
7 using namespace at;
8 
// Records which override ran last: each MSNPU override below writes its own
// marker value (1, 2, 3) so the test can observe dispatch. Starts at 0.
static int test_int = 0;
10 
11 Tensor empty_override(IntArrayRef size, const TensorOptions & options) {
12  test_int = 1;
13  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
14  Storage(
15  caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false),
16  MSNPUTensorId(),
17  false);
18  return Tensor(std::move(tensor_impl));
19 }
20 
21 Tensor empty_like_override(const Tensor & self, const TensorOptions & options) {
22  test_int = 2;
23  return self;
24 }
25 
26 Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
27  test_int = 3;
28  return a;
29 }
30 
31 TEST(BackendExtensionTest, TestRegisterOp) {
32  EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU));
33  register_extension_backend_op(
34  Backend::MSNPU,
35  "empty(IntArrayRef size, TensorOptions options) -> Tensor", &empty_override);
36  Tensor a = empty({5, 5}, at::kMSNPU);
37  ASSERT_EQ(a.device().type(), at::kMSNPU);
38  ASSERT_EQ(a.device().index(), 1);
39  ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
40  ASSERT_EQ(test_int, 1);
41 
42  EXPECT_ANY_THROW(empty_like(a, at::kMSNPU));
43  register_extension_backend_op(
44  Backend::MSNPU,
45  "empty_like(Tensor self, TensorOptions options) -> Tensor", &empty_like_override);
46  Tensor b = empty_like(a, at::kMSNPU);
47  ASSERT_EQ(test_int, 2);
48 
49  EXPECT_ANY_THROW(add(a, b));
50  register_extension_backend_op(
51  Backend::MSNPU,
52  "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
53  add(a, b);
54  ASSERT_EQ(test_int, 3);
55 
56  // Ensure that non-MSNPU operator still works
57  Tensor d = empty({5, 5}, at::kCPU);
58  ASSERT_EQ(d.device().type(), at::kCPU);
59 
60  // Attempt to register on a schema that has already has a function
61  EXPECT_ANY_THROW(
62  register_extension_backend_op(
63  Backend::MSNPU,
64  "empty(IntArrayRef size, TensorOptions options) -> Tensor", &empty_override)
65  );
66 }
caffe2::TypeMeta dtype() const noexcept
Returns a Tensor's dtype (TypeMeta). Defined in TensorMethods.h.
Represents a compute device on which a tensor is located.
Definition: Device.h:30
Device device() const
Returns a Tensor's device.
Flush-To-Zero and Denormals-Are-Zero mode.
DeviceIndex index() const noexcept
Returns the optional index.
Definition: Device.h:70
DeviceType type() const noexcept
Returns the type of device this is.
Definition: Device.h:65