Caffe2 - C++ API
A deep learning, cross-platform ML framework
complex_registration_extension.cpp
1 #include <torch/extension.h>
2 
3 #include <ATen/CPUFloatType.h>
4 #include <ATen/Type.h>
5 #include <ATen/core/VariableHooksInterface.h>
6 #include <ATen/detail/ComplexHooksInterface.h>
7 
8 #include <c10/core/Allocator.h>
9 #include <ATen/CPUGenerator.h>
10 #include <ATen/DeviceGuard.h>
11 #include <ATen/NativeFunctions.h>
12 #include <ATen/Utils.h>
13 #include <ATen/WrapDimUtils.h>
14 #include <c10/util/Half.h>
15 #include <c10/core/TensorImpl.h>
16 #include <c10/core/UndefinedTensorImpl.h>
17 #include <c10/util/Optional.h>
18 
19 #include <cstddef>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 
24 #include <ATen/Config.h>
25 
26 namespace at {
27 
31  CPUTensorId(),
32  /*is_variable=*/false,
33  /*is_undefined=*/false) {}
34 
35  ScalarType scalarType() const override;
36  caffe2::TypeMeta typeMeta() const override;
37  Backend backend() const override;
38  const char* toString() const override;
39  TypeID ID() const override;
40 
// Override of empty() for this type: creates a tensor by forwarding to the
// CPU factory at::native::empty_cpu.
//   size    - forwarded unchanged to empty_cpu.
//   options - supplies the device for the guard and is forwarded to empty_cpu.
// Returns whatever at::native::empty_cpu returns for (size, options).
41  Tensor empty(IntArrayRef size, const TensorOptions & options) const override {
42  // Delegate to the appropriate cpu tensor factory
// The guard is constructed from options.device() and lives until this
// function returns, so the empty_cpu call below runs while it is active.
43  const DeviceGuard device_guard(options.device());
44  return at::native::empty_cpu(/* actuals */ size, options);
45  }
46 };
47 
// Override that registers a newly allocated CPUComplexFloatType with the
// given context under the (Backend::CPU, ScalarType::ComplexFloat) pair.
//   context - the Context whose registerType() is called; it is dereferenced
//             without a null check.
// NOTE(review): the object is created with a raw `new` and never deleted
// here, so registerType presumably takes ownership of it - confirm against
// Context::registerType.
50  void registerComplexTypes(Context* context) const override {
51  context->registerType(
52  Backend::CPU, ScalarType::ComplexFloat, new CPUComplexFloatType());
53  }
54 };
55 
// Reports the scalar type handled by this type object: always
// ScalarType::ComplexFloat.
56 ScalarType CPUComplexFloatType::scalarType() const {
57  return ScalarType::ComplexFloat;
58 }
59 
// Returns the caffe2::TypeMeta obtained by passing ScalarType::ComplexFloat
// through scalarTypeToTypeMeta, i.e. the same scalar type that scalarType()
// reports.
60 caffe2::TypeMeta CPUComplexFloatType::typeMeta() const {
61  return scalarTypeToTypeMeta(ScalarType::ComplexFloat);
62 }
63 
// Reports the backend of this type object: always Backend::CPU, matching the
// backend it is registered under in registerComplexTypes.
64 Backend CPUComplexFloatType::backend() const {
65  return Backend::CPU;
66 }
67 
// Returns the name of this type as a string literal ("CPUComplexFloatType").
// The pointer refers to a literal, so the caller must not free or modify it.
68 const char* CPUComplexFloatType::toString() const {
69  return "CPUComplexFloatType";
70 }
71 
// Returns the TypeID enumerator for this type: always TypeID::CPUComplexFloat.
72 TypeID CPUComplexFloatType::ID() const {
73  return TypeID::CPUComplexFloat;
74 }
75 
// Hands the ComplexHooks struct to the REGISTER_COMPLEX_HOOKS macro.
// NOTE(review): presumably this installs ComplexHooks as the implementation
// behind ATen's complex hooks interface when this library is loaded, so that
// registerComplexTypes gets invoked - confirm against the macro definition
// in ATen/detail/ComplexHooksInterface.h.
76 REGISTER_COMPLEX_HOOKS(ComplexHooks);
77 
78 } // namespace at
79 
// Defines the Python extension module named by TORCH_EXTENSION_NAME. The body
// is intentionally empty: nothing is bound on `m`.
// NOTE(review): presumably the module exists only so that importing it loads
// this library and runs the hook registration above - confirm with the
// code that loads this extension.
80 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { }
C10_NODISCARD TensorOptions device(c10::optional< Device > device) const noexcept
Return a copy of TensorOptions with device set to the given one, or cleared if device is nullopt...
Backend
This legacy enum class defines the set of backends supported by old school, code generated Type-based...
Definition: Backend.h:23
RAII guard that sets a certain default device in its constructor, and changes it back to the device t...
Definition: DeviceGuard.h:19
Flush-To-Zero and Denormals-Are-Zero mode.
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324