Caffe2 - C++ API
A deep learning, cross-platform ML framework
ScalarType.h
1 #pragma once
2 
3 #include <c10/util/ArrayRef.h>
4 #include <c10/util/Half.h>
5 #include <c10/util/Optional.h>
6 #include <c10/util/typeid.h>
7 
8 #include <cstdint>
9 #include <iostream>
10 #include <complex>
11 
12 namespace c10 {
13 
14 // NB: Order matters for this macro; it is relied upon in
15 // _promoteTypesLookup and the serialization format.
// X-macro enumerating every concrete ScalarType. `_` is invoked once per
// entry as _(ctype, Name, kind): ctype is the C++ element type, Name is the
// ScalarType enumerator, and kind is a one-letter tag (i / d / z; presumably
// integral / floating / complex — confirm against consumers; note Bool is
// tagged `i`). The /* N */ comments record the integer value each entry
// receives in `enum class ScalarType`, which is generated from this list.
// Appending is safe; reordering or inserting changes those values.
16 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
17 _(uint8_t,Byte,i) /* 0 */ \
18 _(int8_t,Char,i) /* 1 */ \
19 _(int16_t,Short,i) /* 2 */ \
20 _(int,Int,i) /* 3 */ \
21 _(int64_t,Long,i) /* 4 */ \
22 _(at::Half,Half,d) /* 5 */ \
23 _(float,Float,d) /* 6 */ \
24 _(double,Double,d) /* 7 */ \
25 _(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
26 _(std::complex<float>,ComplexFloat,z) /* 9 */ \
27 _(std::complex<double>,ComplexDouble,z) /* 10 */ \
28 _(bool,Bool,i) /* 11 */
29 
30 // If you want to support ComplexHalf for real, replace occurrences
31 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
32 // beware: convert() doesn't work for all the conversions you need...
// Same (ctype, Name, kind) triples and relative order as
// AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, minus the ComplexHalf entry.
33 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
34 _(uint8_t,Byte,i) \
35 _(int8_t,Char,i) \
36 _(int16_t,Short,i) \
37 _(int,Int,i) \
38 _(int64_t,Long,i) \
39 _(at::Half,Half,d) \
40 _(float,Float,d) \
41 _(double,Double,d) \
42 _(std::complex<float>,ComplexFloat,z) \
43 _(std::complex<double>,ComplexDouble,z) \
44 _(bool,Bool,i)
45 
// X-macro over the real numeric ScalarTypes only: the integer types plus
// Half/Float/Double. Complex types and Bool are NOT included. Entries use the
// same (ctype, Name, kind) triples as AT_FORALL_SCALAR_TYPES_WITH_COMPLEX.
46 #define AT_FORALL_SCALAR_TYPES(_) \
47 _(uint8_t,Byte,i) \
48 _(int8_t,Char,i) \
49 _(int16_t,Short,i) \
50 _(int,Int,i) \
51 _(int64_t,Long,i) \
52 _(at::Half,Half,d) \
53 _(float,Float,d) \
54 _(double,Double,d)
55 
// Same as AT_FORALL_SCALAR_TYPES but without the Half entry: integer types
// plus Float and Double.
56 #define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
57 _(uint8_t,Byte,i) \
58 _(int8_t,Char,i) \
59 _(int16_t,Short,i) \
60 _(int,Int,i) \
61 _(int64_t,Long,i) \
62 _(float,Float,d) \
63 _(double,Double,d)
64 
// Element-type tag, stored in an int8_t. One enumerator is generated per entry
// of AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, in that order (Byte = 0 ... Bool = 11).
// They are followed by Undefined (no element type) and NumOptions. NumOptions is
// a count sentinel rather than a real type; it is used as a dimension when sizing
// per-type lookup tables. The integer values are relied upon (see the NB note on
// AT_FORALL_SCALAR_TYPES_WITH_COMPLEX), so do not reorder.
65 enum class ScalarType : int8_t {
66 #define DEFINE_ENUM(_1,n,_2) \
67  n,
68  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
69 #undef DEFINE_ENUM
70  Undefined,
71  NumOptions
72 };
73 
74 static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
75 #define DEFINE_CASE(ctype,name,_) \
76  case ScalarType:: name : return caffe2::TypeMeta::Make<ctype>();
77 
78  switch(scalar_type) {
79  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
80  case ScalarType::Undefined: return caffe2::TypeMeta();
81  default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
82  }
83 #undef DEFINE_CASE
84 }
85 
86 static inline c10::optional<ScalarType> tryTypeMetaToScalarType(caffe2::TypeMeta dtype) {
87 #define DEFINE_IF(ctype, name, _) \
88  if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
89  return {ScalarType::name}; \
90  }
91  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
92 #undef DEFINE_IF
93  if (dtype == caffe2::TypeMeta()) {
94  return {ScalarType::Undefined};
95  }
96  return c10::nullopt;
97 }
98 
99 static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
100  if (auto scalar_type = tryTypeMetaToScalarType(dtype)) {
101  return *scalar_type;
102  }
103  AT_ERROR("Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
104 }
105 
106 static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
107  if (auto mt = tryTypeMetaToScalarType(m)) {
108  return (*mt) == t;
109  }
110  return false;
111 }
112 
113 static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
114  return t == m;
115 }
116 
117 #define DEFINE_CONSTANT(_,name,_2) \
118 constexpr ScalarType k##name = ScalarType::name;
119 
120 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CONSTANT)
121 #undef DEFINE_CONSTANT
122 
123 static inline const char * toString(ScalarType t) {
124 #define DEFINE_CASE(_,name,_2) \
125  case ScalarType:: name : return #name;
126 
127  switch(t) {
128  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
129  default:
130  return "UNKNOWN_SCALAR";
131  }
132 #undef DEFINE_CASE
133 }
134 
135 static inline size_t elementSize(ScalarType t) {
136 #define CASE_ELEMENTSIZE_CASE(ctype,name,_2) \
137  case ScalarType:: name : return sizeof(ctype);
138 
139  switch(t) {
140  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CASE_ELEMENTSIZE_CASE)
141  default:
142  AT_ERROR("Unknown ScalarType");
143  }
144 #undef CASE_ELEMENTSIZE_CASE
145 }
146 
147 static inline bool isIntegralType(ScalarType t) {
148  return (t == ScalarType::Byte ||
149  t == ScalarType::Char ||
150  t == ScalarType::Int ||
151  t == ScalarType::Long ||
152  t == ScalarType::Short);
153 }
154 
155 static inline bool isFloatingType(ScalarType t) {
156  return (t == ScalarType::Double ||
157  t == ScalarType::Float ||
158  t == ScalarType::Half);
159 }
160 
161 static inline bool isComplexType(ScalarType t) {
162  return (t == ScalarType::ComplexHalf ||
163  t == ScalarType::ComplexFloat ||
164  t == ScalarType::ComplexDouble);
165 }
166 
167 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
168  // This is generated according to NumPy's promote_types
169  constexpr auto u1 = ScalarType::Byte;
170  constexpr auto i1 = ScalarType::Char;
171  constexpr auto i2 = ScalarType::Short;
172  constexpr auto i4 = ScalarType::Int;
173  constexpr auto i8 = ScalarType::Long;
174  constexpr auto f2 = ScalarType::Half;
175  constexpr auto f4 = ScalarType::Float;
176  constexpr auto f8 = ScalarType::Double;
177  constexpr auto b1 = ScalarType::Bool;
178  constexpr auto ud = ScalarType::Undefined;
179  if (a == ud || b == ud) {
180  return ScalarType::Undefined;
181  }
182  if (isComplexType(a) || isComplexType(b)) {
183  AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
184  }
185  static constexpr ScalarType _promoteTypesLookup
186  [static_cast<int>(ScalarType::NumOptions)]
187  [static_cast<int>(ScalarType::NumOptions)] = {
188  /* u1 i1 i2 i4 i8 f2 f4 f8 b1 */
189  /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
190  /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
191  /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
192  /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
193  /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
194  /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
195  /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
196  /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
197  /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
198  };
199  return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
200 }
201 
202 inline std::ostream& operator<<(
203  std::ostream& stream,
204  at::ScalarType scalar_type) {
205  return stream << toString(scalar_type);
206 }
207 
208 } // namespace c10
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324
TensorOptions dtype(caffe2::TypeMeta dtype)
Convenience function that returns a TensorOptions object with the dtype set to the given one...