1 #include <torch/extension.h> 3 #include <ATen/ExtensionBackendRegistration.h> 10 auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
15 return Tensor(std::move(tensor_impl));
20 return get_dtype_tensor(options.
dtype());
25 return get_dtype_tensor(a.
dtype());
30 return get_dtype_tensor(
self.dtype());
35 return get_dtype_tensor(
self.dtype());
40 const Tensor &
self,
const Tensor & target, int64_t reduction) {
42 return get_dtype_tensor(
self.dtype());
45 Tensor kl_div_backward_override(
46 const Tensor & grad_output,
51 return get_dtype_tensor(
self.dtype());
55 int64_t numel_override(
const Tensor &
self) {
60 return get_dtype_tensor(options.
dtype());
63 void init_msnpu_extension() {
64 register_extension_backend_op(
66 "zeros(IntArrayRef size, TensorOptions options) -> Tensor", &zeros_override);
67 register_extension_backend_op(
69 "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
70 register_extension_backend_op(
72 "sum(Tensor self) -> Tensor", &sum_override);
73 register_extension_backend_op(
75 "expand(Tensor self, IntArrayRef size, bool implicit) -> Tensor",
77 register_extension_backend_op(
79 "kl_div(Tensor self, Tensor target, int64_t reduction) -> Tensor",
81 register_extension_backend_op(
83 "kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction) -> Tensor",
84 &kl_div_backward_override);
85 register_extension_backend_op(
87 "numel(Tensor self) -> int64_t", &numel_override);
88 register_extension_backend_op(
90 "ones_like(Tensor self, TensorOptions options) -> Tensor",
97 static constexpr DeviceType static_type = DeviceType::MSNPU;
100 AT_ASSERT(t == DeviceType::MSNPU);
102 DeviceType
type()
const override {
103 return DeviceType::MSNPU;
106 AT_ASSERT(d.
type() == DeviceType::MSNPU);
107 AT_ASSERT(d.
index() == 0);
111 return Device(DeviceType::MSNPU, 0);
113 void setDevice(
Device d)
const override {
114 AT_ASSERT(d.
type() == DeviceType::MSNPU);
115 AT_ASSERT(d.
index() == 0);
117 void uncheckedSetDevice(
Device d)
const noexcept
override {
120 return Stream(Stream::DEFAULT,
Device(DeviceType::MSNPU, 0));
122 Stream exchangeStream(
Stream s)
const noexcept
override {
123 return Stream(Stream::DEFAULT,
Device(DeviceType::MSNPU, 0));
130 constexpr DeviceType MSNPUGuardImpl::static_type;
137 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
138 m.def(
"init_msnpu_extension", &init_msnpu_extension);
139 m.def(
"get_test_int", &get_test_int);
DeviceIndex deviceCount() const override
Get the number of devices.
A stream is a software mechanism used to synchronize launched kernels without requiring explicit sync...
Scalar represents a 0-dimensional tensor which contains a single element.
Device getDevice() const override
Get the current device.
caffe2::TypeMeta dtype() const noexcept
Returns a Tensor's dtype (TypeMeta). Defined in TensorMethods.h.
Represents a compute device on which a tensor is located.
int16_t DeviceIndex
An index representing a specific device; e.g., the 1 in GPU 1.
C10_NODISCARD TensorOptions dtype(c10::optional< caffe2::TypeMeta > dtype) const noexcept
Return a copy of TensorOptions with dtype set to the given one.
DeviceType type() const override
Return the type of device managed by this guard implementation.
Flush-To-Zero and Denormals-Are-Zero mode.
DeviceIndex index() const noexcept
Returns the optional index.
DeviceGuardImplInterface represents the virtual interface which provides functionality to provide an ...
DeviceType type() const noexcept
Returns the type of device this is.