// PyTorch C++ frontend — torch/csrc/utils/tensor_types.cpp
// Conversions between at::Type and Python-style tensor type name strings
// (e.g. "torch.cuda.FloatTensor").
#include <Python.h>

#include <torch/csrc/utils/tensor_types.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/tensor/python_tensor.h>

#include <algorithm>
#include <mutex>
#include <sstream>
#include <unordered_map>
12 
13 using namespace at;
14 
15 namespace torch { namespace utils {
16 
17 static const char* backend_to_string(const at::Type& type) {
18  switch (type.backend()) {
19  case at::Backend::CPU: return "torch";
20  case at::Backend::CUDA: return "torch.cuda";
21  case at::Backend::SparseCPU: return "torch.sparse";
22  case at::Backend::SparseCUDA: return "torch.cuda.sparse";
23  default: AT_ERROR("Unimplemented backend ", type.backend());
24  }
25 }
26 
27 std::string type_to_string(const at::Type& type) {
28  std::ostringstream ss;
29  ss << backend_to_string(type) << "." << toString(type.scalarType()) << "Tensor";
30  return ss.str();
31 }
32 
33 at::Type& type_from_string(const std::string& str) {
34  static std::string cuda_prefix("torch.cuda.");
35  static std::once_flag cpu_once;
36  static std::once_flag cuda_once;
37  static std::unordered_map<std::string, Type*> cpu_map;
38  static std::unordered_map<std::string, Type*> cuda_map;
39 
40  const std::unordered_map<std::string, Type*>* map = nullptr;
41 
42  if (str == "torch.Tensor") {
43  return torch::tensors::get_default_tensor_type();
44  }
45 
46  if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin()).first == cuda_prefix.end()) {
47  // torch.cuda. is prefix of str
48  std::call_once(cuda_once, []() {
49  for (auto type : autograd::VariableType::allCUDATypes()) {
50  cuda_map.emplace(type_to_string(*type), type);
51  }
52  });
53  map = &cuda_map;
54  } else {
55  std::call_once(cpu_once, []() {
56  for (auto type : autograd::VariableType::allCPUTypes()) {
57  cpu_map.emplace(type_to_string(*type), type);
58  }
59  });
60  map = &cpu_map;
61  }
62 
63  auto it = map->find(str);
64  if (it == map->end()) {
65  throw ValueError("invalid type: '%s'", str.c_str());
66  }
67  return *it->second;
68 }
69 
70 std::vector<std::pair<Backend, ScalarType>> all_declared_types() {
71  std::vector<std::pair<Backend, ScalarType>> ret;
72  // can't easily iterate over enum classes
73  std::vector<Backend> backends = { Backend::CPU, Backend::CUDA, Backend::SparseCPU, Backend::SparseCUDA };
74  std::vector<ScalarType> scalar_types = { ScalarType::Byte, ScalarType::Char, ScalarType::Double, ScalarType::Float,
75  ScalarType::Int, ScalarType::Long, ScalarType::Short, ScalarType::Half, ScalarType::Bool};
76  for (auto& backend : backends) {
77  for (auto& scalar_type : scalar_types) {
78  // there are no sparse half or bool types.
79  if ((scalar_type == ScalarType::Half || scalar_type == ScalarType::Bool) && (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) {
80  continue;
81  }
82  ret.emplace_back(std::make_pair(backend, scalar_type));
83  }
84  }
85 
86  return ret;
87 }
88 
89 }} // namespace torch::utils