Caffe2 - C++ API
A deep learning, cross-platform ML framework
msnpu_extension.cpp
1 #include <torch/extension.h>
2 
3 #include <ATen/ExtensionBackendRegistration.h>
4 
5 using namespace at;
6 
// Records which override ran most recently (0 = zeros, 1 = add, 2 = sum,
// 3 = kl_div, 4 = kl_div_backward); read back from Python via get_test_int().
static int test_int;
8 
9 Tensor get_dtype_tensor(caffe2::TypeMeta dtype) {
10  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
11  Storage(
12  dtype, 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)), nullptr, false),
13  MSNPUTensorId(),
14  false);
15  return Tensor(std::move(tensor_impl));
16 }
17 
18 Tensor zeros_override(IntArrayRef size, const TensorOptions & options) {
19  test_int = 0;
20  return get_dtype_tensor(options.dtype());
21 }
22 
23 Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
24  test_int = 1;
25  return get_dtype_tensor(a.dtype());
26 }
27 
28 Tensor sum_override(const Tensor & self) {
29  test_int = 2;
30  return get_dtype_tensor(self.dtype());
31 }
32 
33 // needed for sum backwards
34 Tensor expand_override(const Tensor & self, IntArrayRef size, bool implicit) {
35  return get_dtype_tensor(self.dtype());
36 }
37 
38 
39 Tensor kl_div_override(
40  const Tensor & self, const Tensor & target, int64_t reduction) {
41  test_int = 3;
42  return get_dtype_tensor(self.dtype());
43 }
44 
45 Tensor kl_div_backward_override(
46  const Tensor & grad_output,
47  const Tensor & self,
48  const Tensor & target,
49  int64_t reduction) {
50  test_int = 4;
51  return get_dtype_tensor(self.dtype());
52 }
53 
54 // numel and ones_like are needed for autograd backwards
55 int64_t numel_override(const Tensor & self) {
56  return 1;
57 }
58 
59 Tensor ones_like_override(const Tensor & self, const TensorOptions & options) {
60  return get_dtype_tensor(options.dtype());
61 }
62 
// Registers all stub kernels above with the extension-backend dispatcher
// under Backend::MSNPU. Exposed to Python (see PYBIND11_MODULE below) and
// must run before any MSNPU op is dispatched. Each schema string must match
// the signature the dispatcher expects for that op.
void init_msnpu_extension() {
  register_extension_backend_op(
    Backend::MSNPU,
    "zeros(IntArrayRef size, TensorOptions options) -> Tensor", &zeros_override);
  register_extension_backend_op(
    Backend::MSNPU,
    "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
  register_extension_backend_op(
    Backend::MSNPU,
    "sum(Tensor self) -> Tensor", &sum_override);
  register_extension_backend_op(
    Backend::MSNPU,
    "expand(Tensor self, IntArrayRef size, bool implicit) -> Tensor",
    &expand_override);
  register_extension_backend_op(
    Backend::MSNPU,
    "kl_div(Tensor self, Tensor target, int64_t reduction) -> Tensor",
    &kl_div_override);
  register_extension_backend_op(
    Backend::MSNPU,
    "kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction) -> Tensor",
    &kl_div_backward_override);
  register_extension_backend_op(
    Backend::MSNPU,
    "numel(Tensor self) -> int64_t", &numel_override);
  register_extension_backend_op(
    Backend::MSNPU,
    "ones_like(Tensor self, TensorOptions options) -> Tensor",
    &ones_like_override);
}
93 
94 // TODO: Extend this to exercise multi-device setting. In that case,
95 // we need to add a thread local variable to track the current device.
97  static constexpr DeviceType static_type = DeviceType::MSNPU;
98  MSNPUGuardImpl() {}
99  MSNPUGuardImpl(DeviceType t) {
100  AT_ASSERT(t == DeviceType::MSNPU);
101  }
102  DeviceType type() const override {
103  return DeviceType::MSNPU;
104  }
105  Device exchangeDevice(Device d) const override {
106  AT_ASSERT(d.type() == DeviceType::MSNPU);
107  AT_ASSERT(d.index() == 0);
108  return d;
109  }
110  Device getDevice() const override {
111  return Device(DeviceType::MSNPU, 0);
112  }
113  void setDevice(Device d) const override {
114  AT_ASSERT(d.type() == DeviceType::MSNPU);
115  AT_ASSERT(d.index() == 0);
116  }
117  void uncheckedSetDevice(Device d) const noexcept override {
118  }
119  Stream getStream(Device d) const noexcept override {
120  return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
121  }
122  Stream exchangeStream(Stream s) const noexcept override {
123  return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
124  }
125  DeviceIndex deviceCount() const override {
126  return 1;
127  }
128 };
129 
// Out-of-class definition for the static member (required for ODR-use
// under pre-C++17 rules).
constexpr DeviceType MSNPUGuardImpl::static_type;
// Hook the guard implementation into c10's per-device-type registry.
C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
132 
133 int get_test_int() {
134  return test_int;
135 }
136 
// Python bindings: expose the backend-registration entry point and the
// test probe used by the Python-side unit tests.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("init_msnpu_extension", &init_msnpu_extension);
  m.def("get_test_int", &get_test_int);
}
DeviceIndex deviceCount() const override
Get the number of devices.
A stream is a software mechanism used to synchronize launched kernels without requiring explicit sync...
Definition: Stream.h:57
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Device getDevice() const override
Get the current device.
caffe2::TypeMeta dtype() const noexcept
Returns a Tensor's dtype (TypeMeta). Defined in TensorMethods.h.
Represents a compute device on which a tensor is located.
Definition: Device.h:30
int16_t DeviceIndex
An index representing a specific device; e.g., the 1 in GPU 1.
Definition: Device.h:18
C10_NODISCARD TensorOptions dtype(c10::optional< caffe2::TypeMeta > dtype) const noexcept
Return a copy of TensorOptions with dtype set to the given one.
DeviceType type() const override
Return the type of device managed by this guard implementation.
Flush-To-Zero and Denormals-Are-Zero mode.
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324
DeviceIndex index() const noexcept
Returns the optional index.
Definition: Device.h:70
DeviceGuardImplInterface represents the virtual interface which provides functionality to provide an ...
DeviceType type() const noexcept
Returns the type of device this is.
Definition: Device.h:65