3 #include <c10/util/ArrayRef.h> 4 #include <c10/util/Half.h> 5 #include <c10/util/Optional.h> 6 #include <c10/util/typeid.h> 16 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ 25 _(at::ComplexHalf,ComplexHalf,z) \ 26 _(std::complex<float>,ComplexFloat,z) \ 27 _(std::complex<double>,ComplexDouble,z) \ 33 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \ 42 _(std::complex<float>,ComplexFloat,z) \ 43 _(std::complex<double>,ComplexDouble,z) \ 46 #define AT_FORALL_SCALAR_TYPES(_) \ 56 #define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \ 65 enum class ScalarType : int8_t {
66 #define DEFINE_ENUM(_1,n,_2) \ 68 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
74 static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
75 #define DEFINE_CASE(ctype,name,_) \ 76 case ScalarType:: name : return caffe2::TypeMeta::Make<ctype>(); 79 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
80 case ScalarType::Undefined: return
caffe2::TypeMeta();
81 default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
87 #define DEFINE_IF(ctype, name, _) \ 88 if (dtype == caffe2::TypeMeta::Make<ctype>()) { \ 89 return {ScalarType::name}; \ 91 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
94 return {ScalarType::Undefined};
100 if (
auto scalar_type = tryTypeMetaToScalarType(dtype)) {
103 AT_ERROR(
"Unsupported TypeMeta in ATen: ", dtype,
" (please report this error)");
107 if (
auto mt = tryTypeMetaToScalarType(m)) {
117 #define DEFINE_CONSTANT(_,name,_2) \ 118 constexpr ScalarType k##name = ScalarType::name; 120 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CONSTANT)
121 #undef DEFINE_CONSTANT 123 static inline const char * toString(ScalarType t) {
124 #define DEFINE_CASE(_,name,_2) \ 125 case ScalarType:: name : return #name; 128 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
130 return "UNKNOWN_SCALAR";
135 static inline size_t elementSize(ScalarType t) {
136 #define CASE_ELEMENTSIZE_CASE(ctype,name,_2) \ 137 case ScalarType:: name : return sizeof(ctype); 140 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CASE_ELEMENTSIZE_CASE)
142 AT_ERROR("Unknown ScalarType");
144 #undef CASE_ELEMENTSIZE_CASE 147 static inline bool isIntegralType(ScalarType t) {
148 return (t == ScalarType::Byte ||
149 t == ScalarType::Char ||
150 t == ScalarType::Int ||
151 t == ScalarType::Long ||
152 t == ScalarType::Short);
155 static inline bool isFloatingType(ScalarType t) {
156 return (t == ScalarType::Double ||
157 t == ScalarType::Float ||
158 t == ScalarType::Half);
161 static inline bool isComplexType(ScalarType t) {
162 return (t == ScalarType::ComplexHalf ||
163 t == ScalarType::ComplexFloat ||
164 t == ScalarType::ComplexDouble);
167 static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
169 constexpr
auto u1 = ScalarType::Byte;
170 constexpr
auto i1 = ScalarType::Char;
171 constexpr
auto i2 = ScalarType::Short;
172 constexpr
auto i4 = ScalarType::Int;
173 constexpr
auto i8 = ScalarType::Long;
174 constexpr
auto f2 = ScalarType::Half;
175 constexpr
auto f4 = ScalarType::Float;
176 constexpr
auto f8 = ScalarType::Double;
177 constexpr
auto b1 = ScalarType::Bool;
178 constexpr
auto ud = ScalarType::Undefined;
179 if (a == ud || b == ud) {
180 return ScalarType::Undefined;
182 if (isComplexType(a) || isComplexType(b)) {
183 AT_ERROR(
"promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
185 static constexpr ScalarType _promoteTypesLookup
186 [
static_cast<int>(ScalarType::NumOptions)]
187 [static_cast<int>(ScalarType::NumOptions)] = {
189 { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
190 { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
191 { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
192 { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
193 { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
194 { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
195 { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
196 { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
197 { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
199 return _promoteTypesLookup[
static_cast<int>(a)][static_cast<int>(b)];
202 inline std::ostream& operator<<(
203 std::ostream& stream,
204 at::ScalarType scalar_type) {
205 return stream << toString(scalar_type);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
TensorOptions dtype(caffe2::TypeMeta dtype)
Convenience function that returns a TensorOptions object with the dtype set to the given one...