#include <torch/csrc/utils/tensor_types.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/generated/VariableType.h>
#include <torch/csrc/tensor/python_tensor.h>

#include <algorithm>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace utils {
17 static const char* backend_to_string(
const at::Type& type) {
18 switch (type.backend()) {
19 case at::Backend::CPU:
return "torch";
20 case at::Backend::CUDA:
return "torch.cuda";
21 case at::Backend::SparseCPU:
return "torch.sparse";
22 case at::Backend::SparseCUDA:
return "torch.cuda.sparse";
23 default: AT_ERROR(
"Unimplemented backend ", type.backend());
27 std::string type_to_string(
const at::Type& type) {
28 std::ostringstream ss;
29 ss << backend_to_string(type) <<
"." << toString(type.scalarType()) <<
"Tensor";
33 at::Type& type_from_string(
const std::string& str) {
34 static std::string cuda_prefix(
"torch.cuda.");
35 static std::once_flag cpu_once;
36 static std::once_flag cuda_once;
37 static std::unordered_map<std::string, Type*> cpu_map;
38 static std::unordered_map<std::string, Type*> cuda_map;
40 const std::unordered_map<std::string, Type*>* map =
nullptr;
42 if (str ==
"torch.Tensor") {
43 return torch::tensors::get_default_tensor_type();
46 if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin()).first == cuda_prefix.end()) {
48 std::call_once(cuda_once, []() {
49 for (
auto type : autograd::VariableType::allCUDATypes()) {
50 cuda_map.emplace(type_to_string(*type), type);
55 std::call_once(cpu_once, []() {
56 for (
auto type : autograd::VariableType::allCPUTypes()) {
57 cpu_map.emplace(type_to_string(*type), type);
63 auto it = map->find(str);
64 if (it == map->end()) {
65 throw ValueError(
"invalid type: '%s'", str.c_str());
70 std::vector<std::pair<Backend, ScalarType>> all_declared_types() {
71 std::vector<std::pair<Backend, ScalarType>> ret;
73 std::vector<Backend> backends = { Backend::CPU, Backend::CUDA, Backend::SparseCPU, Backend::SparseCUDA };
74 std::vector<ScalarType> scalar_types = { ScalarType::Byte, ScalarType::Char, ScalarType::Double, ScalarType::Float,
75 ScalarType::Int, ScalarType::Long, ScalarType::Short, ScalarType::Half, ScalarType::Bool};
76 for (
auto& backend : backends) {
77 for (
auto& scalar_type : scalar_types) {
79 if ((scalar_type == ScalarType::Half || scalar_type == ScalarType::Bool) && (backend == Backend::SparseCUDA || backend == Backend::SparseCPU)) {
82 ret.emplace_back(std::make_pair(backend, scalar_type));
// NOTE(review): stray text fragment — "Flush-To-Zero and Denormals-Are-Zero
// mode." does not belong to any definition in this file; it presumably came
// from set_flush_denormal documentation elsewhere. Confirm and remove.