1 #include <torch/csrc/python_headers.h> 2 #include <torch/csrc/utils/tensor_dtypes.h> 3 #include <torch/csrc/Dtype.h> 4 #include <torch/csrc/DynamicTypes.h> 5 #include <torch/csrc/Exceptions.h> 6 #include <torch/csrc/autograd/generated/VariableType.h> 7 #include <torch/csrc/utils/tensor_types.h> 9 namespace torch {
namespace utils {
11 static std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
13 case at::ScalarType::Byte:
16 return std::make_pair(
"uint8",
"");
17 case at::ScalarType::Char:
20 return std::make_pair(
"int8",
"");
21 case at::ScalarType::Double:
22 return std::make_pair(
"float64",
"double");
23 case at::ScalarType::Float:
24 return std::make_pair(
"float32",
"float");
25 case at::ScalarType::Int:
26 return std::make_pair(
"int32",
"int");
27 case at::ScalarType::Long:
28 return std::make_pair(
"int64",
"long");
29 case at::ScalarType::Short:
30 return std::make_pair(
"int16",
"short");
31 case at::ScalarType::Half:
32 return std::make_pair(
"float16",
"half");
33 case at::ScalarType::ComplexHalf:
34 return std::make_pair(
"complex32",
"");
35 case at::ScalarType::ComplexFloat:
36 return std::make_pair(
"complex64",
"");
37 case at::ScalarType::ComplexDouble:
38 return std::make_pair(
"complex128",
"");
39 case at::ScalarType::Bool:
40 return std::make_pair(
"bool",
"");
42 throw std::runtime_error(
"Unimplemented scalar type");
46 void initializeDtypes() {
47 auto torch_module =
THPObjectPtr(PyImport_ImportModule(
"torch"));
50 #define DEFINE_SCALAR_TYPE(_1,n,_2) at::ScalarType::n, 52 at::ScalarType all_scalar_types[] = {
53 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)
56 for (at::ScalarType scalarType: all_scalar_types) {
57 std::string primary_name, legacy_name;
58 std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
59 std::string name = std::string(PyModule_GetName(torch_module.get())) +
'.' + primary_name;
60 PyObject *dtype = THPDtype_New(scalarType, name);
61 torch::registerDtypeObject((
THPDtype*)dtype, scalarType);
63 if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != 0) {
66 if (legacy_name !=
"") {
68 if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) != 0) {