Caffe2 - C++ API
A deep learning, cross-platform ML framework
tensor_dtypes.cpp
1 #include <torch/csrc/python_headers.h>
2 #include <torch/csrc/utils/tensor_dtypes.h>
3 #include <torch/csrc/Dtype.h>
4 #include <torch/csrc/DynamicTypes.h>
5 #include <torch/csrc/Exceptions.h>
6 #include <torch/csrc/autograd/generated/VariableType.h>
7 #include <torch/csrc/utils/tensor_types.h>
8 
9 namespace torch { namespace utils {
10 
11 static std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType) {
12  switch(scalarType) {
13  case at::ScalarType::Byte:
14  // no "byte" because byte is signed in numpy and we overload
15  // byte to mean bool often
16  return std::make_pair("uint8", "");
17  case at::ScalarType::Char:
18  // no "char" because it is not consistently signed or unsigned; we want
19  // to move to int8
20  return std::make_pair("int8", "");
21  case at::ScalarType::Double:
22  return std::make_pair("float64", "double");
23  case at::ScalarType::Float:
24  return std::make_pair("float32", "float");
25  case at::ScalarType::Int:
26  return std::make_pair("int32", "int");
27  case at::ScalarType::Long:
28  return std::make_pair("int64", "long");
29  case at::ScalarType::Short:
30  return std::make_pair("int16", "short");
31  case at::ScalarType::Half:
32  return std::make_pair("float16", "half");
33  case at::ScalarType::ComplexHalf:
34  return std::make_pair("complex32", "");
35  case at::ScalarType::ComplexFloat:
36  return std::make_pair("complex64", "");
37  case at::ScalarType::ComplexDouble:
38  return std::make_pair("complex128", "");
39  case at::ScalarType::Bool:
40  return std::make_pair("bool", "");
41  default:
42  throw std::runtime_error("Unimplemented scalar type");
43  }
44 }
45 
46 void initializeDtypes() {
47  auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
48  if (!torch_module) throw python_error();
49 
50 #define DEFINE_SCALAR_TYPE(_1,n,_2) at::ScalarType::n,
51 
52  at::ScalarType all_scalar_types[] = {
53  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_SCALAR_TYPE)
54  };
55 
56  for (at::ScalarType scalarType: all_scalar_types) {
57  std::string primary_name, legacy_name;
58  std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
59  std::string name = std::string(PyModule_GetName(torch_module.get())) + '.' + primary_name;
60  PyObject *dtype = THPDtype_New(scalarType, name);
61  torch::registerDtypeObject((THPDtype*)dtype, scalarType);
62  Py_INCREF(dtype);
63  if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != 0) {
64  throw python_error();
65  }
66  if (legacy_name != "") {
67  Py_INCREF(dtype);
68  if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) != 0) {
69  throw python_error();
70  }
71  }
72  }
73 }
74 
75 }} // namespace torch::utils
Definition: Dtype.h:9
Definition: jit_type.h:17