1 #include <ATen/core/jit_type.h> 7 std::ostream& operator<<(std::ostream & out,
const Type & t) {
// Streams a human-readable rendering of the JIT type `t` to `out`.
// NOTE(review): many lines of this function are elided in this excerpt, so
// only the branches visible here are described below.
// CompleteTensorType: prints the scalar type name, "(", then per-dimension data.
8 if(
auto value = t.cast<CompleteTensorType>()) {
9 out << toString(value->scalarType()) <<
"(";
10 auto& sizes = value->sizes();
11 auto& strides = value->strides();
// Sizes and strides must describe the same number of dimensions.
12 AT_ASSERT(sizes.size() == strides.size());
13 for (
size_t i = 0; i < sizes.size(); i++) {
// `expected` is the stride dimension i would have if it were packed directly
// against dimension i+1 (1 for the innermost dimension). What is printed on a
// mismatch is in elided lines; presumably the actual stride is shown.
20 int64_t expected = i + 1 < sizes.size() ? sizes[i+1]*strides[i+1] : 1;
21 if (strides[i] != expected) {
26 }
else if (
auto value = t.cast<DimensionedTensorType>()) {
// DimensionedTensorType: scalar type name, "(", then one entry per dimension
// (the per-dimension printing lines are elided).
27 out << toString(value->scalarType()) <<
"(";
28 for (int64_t i = 0; i < value->dim(); ++i) {
35 }
else if(t.kind() == TypeKind::ListType) {
// List: the element type is fetched; the line that prints it is elided.
36 auto prim = t.cast<ListType>()->getElementType();
38 }
else if (t.kind() == TypeKind::OptionalType) {
// Optional: the element type is fetched; the line that prints it is elided.
39 auto prim = t.cast<OptionalType>()->getElementType();
41 }
else if(t.kind() == TypeKind::FutureType) {
// Future: rendered as "Future[<element type>]".
42 auto elem = t.cast<FutureType>()->getElementType();
43 out <<
"Future[" << *elem <<
"]";
44 }
else if(
auto tup = t.cast<TupleType>()) {
// Tuple: each element type is streamed in order (bracket/separator lines
// are elided in this excerpt).
46 for(
size_t i = 0; i < tup->elements().size(); ++i) {
49 out << *(tup->elements()[i]);
// Singleton accessors. Each getter builds its type once, in a function-local
// static (initialization of which is thread-safe since C++11). The `return
// value;` lines are elided here; presumably each returns that cached instance.
58 TensorTypePtr TensorType::get() {
59 static auto value = TensorType::create();
62 AutogradZeroTensorTypePtr AutogradZeroTensorType::get() {
63 static auto value = AutogradZeroTensorType::create();
66 NumberTypePtr NumberType::get() {
67 static auto value = NumberType::create();
70 IntTypePtr IntType::get() {
71 static auto value = IntType::create();
74 FloatTypePtr FloatType::get() {
75 static auto value = FloatType::create();
78 BoolTypePtr BoolType::get() {
79 static auto value = BoolType::create();
82 NoneTypePtr NoneType::get() {
83 static auto value = NoneType::create();
86 GeneratorTypePtr GeneratorType::get() {
87 static auto value = GeneratorType::create();
90 StringTypePtr StringType::get() {
91 static auto value = StringType::create();
94 DeviceObjTypePtr DeviceObjType::get() {
95 static auto value = DeviceObjType::create();
// Cached container types built from the primitive singletons above:
// Optional[Tensor], List[Tensor], List[int], List[float], List[bool].
98 OptionalTypePtr OptionalType::ofTensor() {
99 static auto value = OptionalType::create(TensorType::get());
102 ListTypePtr ListType::ofTensors() {
103 static auto value = ListType::create(TensorType::get());
106 ListTypePtr ListType::ofInts() {
107 static auto value = ListType::create(IntType::get());
110 ListTypePtr ListType::ofFloats() {
111 static auto value = ListType::create(FloatType::get());
114 ListTypePtr ListType::ofBools() {
115 static auto value = ListType::create(BoolType::get());
// Infers a JIT type from an IValue's runtime tag alone.
// - Tensor              -> CompleteTensorType created from the tensor itself.
// - double/int/bool/str -> Float/Int/Bool/String singleton types.
// - int/tensor/bool/double lists -> the cached List[...] singletons.
// - Tuple               -> TupleType of the recursively inferred element types.
// - Device              -> DeviceObjType.
// Generic lists and dicts have no branch here (attemptToRecoverType handles
// them before falling back to this function). Values matching no visible
// branch reach the AT_ERROR at the end.
125 TypePtr incompleteInferTypeFrom(
const IValue& value) {
126 if (value.isTensor()) {
127 return CompleteTensorType::create(value.toTensor());
128 }
else if (value.isDouble()) {
// IValue doubles map to the JIT "Float" type.
129 return FloatType::get();
130 }
else if (value.isInt()) {
131 return IntType::get();
132 }
else if (value.isBool()) {
133 return BoolType::get();
134 }
else if (value.isString()) {
135 return StringType::get();
136 }
else if (value.isIntList()) {
137 return ListType::ofInts();
138 }
else if (value.isTensorList()) {
139 return ListType::ofTensors();
140 }
else if (value.isBoolList()) {
141 return ListType::ofBools();
142 }
else if (value.isDoubleList()) {
143 return ListType::ofFloats();
144 }
else if (value.isTuple()) {
145 return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom));
146 }
else if (value.isDevice()) {
147 return DeviceObjType::get();
// No visible branch matched (the lines in between are elided in this excerpt).
149 AT_ERROR(
"Type cannot be accurately recovered from this IValue.");
// Recovers a type for `ivalue`. Generic lists and dicts get special handling;
// every other value falls back to incompleteInferTypeFrom.
157 TypePtr attemptToRecoverType(
const IValue& ivalue) {
158 if (ivalue.isGenericList()) {
159 auto& ivalue_list = ivalue.toGenericListRef();
// An empty list has no element to inspect, so the element type is left as
// the type variable "t".
160 if (ivalue_list.size() == 0) {
161 return ListType::create(VarType::create(
"t"));
// NOTE(review): only element 0 is inspected, so a heterogeneous list is typed
// by its first element alone.
163 return ListType::create(attemptToRecoverType(ivalue_list[0]));
165 if (ivalue.isGenericDict()) {
166 const auto& dict = ivalue.toGenericDictRef();
// Empty dict: key and value are both the type variable "t" (same name).
167 if (dict.size() == 0) {
168 return DictType::create(VarType::create(
"t"), VarType::create(
"t"));
// Non-empty dict: types come from the entry at dict.begin() only.
170 auto item = dict.begin();
171 return DictType::create(
172 attemptToRecoverType(item->first), attemptToRecoverType(item->second));
174 return incompleteInferTypeFrom(ivalue);
// Returns whether `ivalue` is compatible with `type`. Tuples, generic lists and
// generic dicts are checked element by element. Any other value is checked by
// inferring its type and testing isSubtypeOf(type).
178 bool isSubvalueOf(
const IValue& ivalue, TypePtr type) {
179 if (ivalue.isTuple()) {
180 const auto& ivalue_elem = ivalue.toTuple()->elements();
181 auto tuple_type = type->cast<TupleType>();
// Wrong kind or arity. The early-return lines are elided (presumably false).
182 if (!tuple_type || tuple_type->elements().size() != ivalue_elem.size()) {
// NOTE(review): `auto` copies the element vector if elements() returns a
// reference; `const auto&` would avoid that — confirm the return type.
185 auto type_elem = tuple_type->elements();
186 bool is_subvalue =
true;
// Stops at the first element that is not a subvalue.
187 for (
size_t i = 0; i < type_elem.size() && is_subvalue; ++i) {
188 is_subvalue = isSubvalueOf(ivalue_elem[i], type_elem[i]);
// Generic list: every element must match the list's element type. std::all_of
// is true for an empty range, so an empty list matches any ListType. The check
// for a non-ListType `type` is in the elided lines.
192 if (ivalue.isGenericList()) {
193 auto list_type = type->cast<ListType>();
197 auto& ivalue_list = ivalue.toGenericListRef();
198 auto element_type = list_type->getElementType();
199 return std::all_of(ivalue_list.begin(), ivalue_list.end(), [&](
const IValue& list_elem) {
200 return isSubvalueOf(list_elem, element_type);
// Generic dict: every key and value must match the dict's key/value types.
// NOTE(review): unlike the tuple/list branches, which use cast(), this uses
// expect<DictType>(). That presumably throws when `type` is not a DictType
// instead of returning false — confirm this is intended.
203 if (ivalue.isGenericDict()) {
204 auto dict_type = type->expect<DictType>();
205 const auto& dict = ivalue.toGenericDictRef();
207 dict.begin(), dict.end(), [=](
const std::pair<IValue, IValue>& item) {
208 return isSubvalueOf(item.first, dict_type->getKeyType()) &&
209 isSubvalueOf(item.second, dict_type->getValueType());
// Non-container values: compare the inferred type against `type`.
212 return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type);
// Fragment; the function header is elided in this excerpt. It is presumably
// tryEitherIsTheSuperType(t1, t2), given the calls further below.
// It checks subtyping in both directions. The value returned by each branch
// is in elided lines.
216 if (t1->isSubtypeOf(t2)) {
218 }
else if (t2->isSubtypeOf(t1)) {
// Fragment of a type-unification routine; its header is elided. The recursive
// call below suggests it is unifyTypes(t1, t2), returning an optional TypePtr.
// Visible strategy, in order:
//  1. use tryEitherIsTheSuperType(t1, t2) when it yields a type;
//  2. two Tensor subtypes unify to the generic TensorType;
//  3. None combined with a non-None type T unifies to Optional[T];
//  4. two lists unify via tryEitherIsTheSuperType on their unshaped forms;
//  5. equal-length tuples unify element-wise (failure handling is elided).
227 if (
auto maybe_supertype = tryEitherIsTheSuperType(t1, t2)) {
228 return *maybe_supertype;
// NOTE(review): the trailing `;;` below is a harmless empty statement.
234 if (t1->isSubtypeOf(TensorType::get()) && t2->isSubtypeOf(TensorType::get())) {
235 return static_cast<TypePtr
>(TensorType::get());;
// Exactly one side is None: wrap the other side in Optional.
239 if (t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) {
240 return OptionalType::create(t2);
241 }
else if (t2->isSubtypeOf(NoneType::get()) && !t1->isSubtypeOf(NoneType::get())) {
242 return OptionalType::create(t1);
// Lists: compare the unshaped forms of both types.
246 if (t1->cast<ListType>() && t2->cast<ListType>()) {
252 auto unshaped_t1 = unshapedType(t1);
253 auto unshaped_t2 = unshapedType(t2);
254 return tryEitherIsTheSuperType(unshaped_t1, unshaped_t2);
255 }
else if(t1->cast<TupleType>() && t2->cast<TupleType>()) {
256 auto tuple1 = t1->cast<TupleType>();
257 auto tuple2 = t2->cast<TupleType>();
// Tuples of different arity: the handling lines are elided (presumably no
// unification).
258 if (tuple1->elements().size() != tuple2->elements().size()) {
261 std::vector<TypePtr> elements;
262 for (
size_t i = 0; i < tuple1->elements().size(); i++) {
263 if (
auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i))) {
264 elements.push_back(*elem);
269 return static_cast<TypePtr
>(TupleType::create(elements));
// Matches a `formal` type that may contain type variables against an `actual`
// type, recording variable bindings in `type_env`. On success, ret.type holds
// the formal type rebuilt from the matched components. On failure, ret.errMsg
// describes the mismatch. Several early-return lines are elided in this excerpt.
275 MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) {
277 if(!formal->hasFreeVariables()) {
// Nothing to bind. The handling lines for this case are elided.
282 if(
auto vt = formal->cast<VarType>()) {
// Type variable: the first occurrence binds it to `actual`. Later occurrences
// must unify with the existing binding, and the binding is widened to the
// unified type.
283 auto it = type_env.find(vt->name());
284 if(it == type_env.end()) {
285 type_env[vt->name()] = actual;
288 }
else if(
auto unified = unifyTypes(it->second, actual)) {
289 type_env[vt->name()] = *unified;
// The previous binding cannot be unified with `actual`: report both types.
293 std::stringstream ss;
294 ss <<
"type variable '" << vt->name() <<
"' previously matched to type " <<
295 it->second->str() <<
" is matched to type " << actual->str();
296 ret.errMsg = ss.str();
298 }
else if(
auto lt_formal = formal->cast<ListType>()) {
// List: recurse on element types and rebuild a ListType from the result.
299 if(
auto lt_actual = actual->cast<ListType>()) {
300 const auto innerType = matchTypeVariables(
301 lt_formal->getElementType(),
302 lt_actual->getElementType(),
304 if (!innerType.type) {
308 ret.type = ListType::create(*innerType.type);
311 std::stringstream ss;
312 ss <<
"cannot match a list to " << actual->str();
313 ret.errMsg = ss.str();
316 }
else if(
auto tp_formal = formal->cast<TupleType>()) {
// Tuple: arity must agree. Elements are then matched pairwise (the handling of
// a failed element match is elided).
317 if(
auto tp_actual = actual->cast<TupleType>()) {
318 if(tp_formal->elements().size() != tp_actual->elements().size()) {
319 ret.errMsg =
"cannot match tuples of mismatched size";
322 std::vector<TypePtr> elements;
323 for(
size_t i = 0; i < tp_formal->elements().size(); ++i) {
324 const auto result = matchTypeVariables(
325 tp_formal->elements()[i],
326 tp_actual->elements()[i],
331 elements.push_back(*result.type);
333 ret.type = TupleType::create(std::move(elements));
336 std::stringstream ss;
337 ss <<
"cannot match a tuple to " << actual->str();
338 ret.errMsg = ss.str();
341 }
else if (
auto lt_formal = formal->cast<FutureType>()) {
// Future: recurse on element types and rebuild a FutureType.
342 if (
auto lt_actual = actual->cast<FutureType>()) {
343 const auto innerType = matchTypeVariables(
344 lt_formal->getElementType(), lt_actual->getElementType(), type_env);
345 if (!innerType.type) {
348 ret.type = FutureType::create(*innerType.type);
351 std::stringstream ss;
352 ss <<
"cannot match a future to " << actual->str();
353 ret.errMsg = ss.str();
356 }
else if (
auto opt_formal = formal->cast<OptionalType>()) {
// Optional vs Optional: recurse on element types.
357 if (
auto opt_actual = actual->cast<OptionalType>()) {
358 const auto optionedType = matchTypeVariables(
359 opt_formal->getElementType(), opt_actual->getElementType(), type_env);
360 if (!optionedType.type) {
363 ret.type = OptionalType::create(*optionedType.type);
365 }
else if (!actual->isSubtypeOf(NoneType::get())) {
// A non-None actual is matched directly against the Optional's element type,
// and the result of that match is returned as-is.
370 return matchTypeVariables(opt_formal->getElementType(), actual, type_env);
// Remaining case: actual is None, which gives no information about T.
372 ret.errMsg =
"cannot match an Optional[T] to None, because there is no way to determine T from None.";
375 }
else if (
auto dict_formal = formal->cast<DictType>()) {
// Dict: match key types, then value types, and rebuild a DictType.
376 if (
auto dict_actual = actual->cast<DictType>()) {
377 auto key_type = matchTypeVariables(
378 dict_formal->getKeyType(),
379 dict_actual->getKeyType(),
382 if (!key_type.type) {
385 auto value_type = matchTypeVariables(
386 dict_formal->getValueType(),
387 dict_actual->getValueType(),
390 if (!value_type.type) {
393 ret.type = DictType::create(*key_type.type, *value_type.type);
396 std::stringstream ss;
397 ss <<
"cannot match a dict to " << actual->str();
398 ret.errMsg = ss.str();
// Any other container kind holding free variables is unsupported.
403 AT_ERROR(
"unhandled free variable container: ", formal->str());
// Substitutes bound type variables in `type` using `type_env`.
// - A type without free variables takes the elided early return (presumably
//   `type` itself).
// - A VarType must be bound in `type_env`; an unbound one fails the assertion
//   with a message about the schema's return type.
// - Any other type is rebuilt via withContained() from its recursively
//   evaluated contained types.
// NOTE(review): the visible lines only read `type_env`; a const reference may
// suffice — confirm against the elided lines and callers.
407 CAFFE2_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map<std::string, TypePtr>& type_env) {
408 if(!type->hasFreeVariables())
411 if(
auto vt = type->cast<VarType>()) {
412 auto it = type_env.find(vt->name());
413 AT_ASSERTM(it != type_env.end(),
"schema has unbound type variable '", vt->name(),
"' in its return type");
416 auto new_contained = fmap(type->containedTypes(), [&](TypePtr t) {
417 return evalTypeVariables(t, type_env);
419 return type->withContained(std::move(new_contained));
// Maps a TypeKind to its enumerator name. C10_FORALL_TYPES expands CASE_TYPE
// once per kind, producing `case TypeKind::X: return "X";` through stringizing.
// The surrounding switch and fallback lines are elided in this excerpt.
424 const char * typeKindToString(TypeKind kind) {
425 #define CASE_TYPE(T) case TypeKind::T: return #T; 427 C10_FORALL_TYPES(CASE_TYPE)
// Base subtype check. `this` is a subtype of Optional[T] whenever it is a
// subtype of T. Otherwise this check requires equality with `rhs` via
// operator==.
// NOTE(review): presumably refined by overrides in Type subclasses — confirm.
433 bool Type::isSubtypeOf(
const TypePtr rhs)
const {
434 if(
auto rhs_ = rhs->cast<OptionalType>()) {
435 return this->isSubtypeOf(rhs_->getElementType());
437 return *
this == *rhs;
// Registry mapping class names to ClassTypePtr. Every visible method takes
// mutex_ before touching reg_.
441 class ClassTypeRegistry {
// Records `type` under `name`. The lines after the lock is taken are elided.
443 void registerType(std::string name, ClassTypePtr type) {
444 std::lock_guard<std::mutex> g(mutex_);
// Returns the type registered under `name` when present. The not-found path
// is elided.
// NOTE(review): count() followed by at() does two lookups; a single find()
// would do one.
449 ClassTypePtr getType(
const std::string& name) {
450 std::lock_guard<std::mutex> g(mutex_);
451 if (reg_.count(name)) {
452 return reg_.at(name);
// Method header elided. It is presumably clear(), since
// ClassType::clearRegistry calls getRegistry().clear().
458 std::lock_guard<std::mutex> g(mutex_);
464 std::unordered_map<std::string, ClassTypePtr> reg_;
// Holds the single ClassTypeRegistry instance in a function-local static.
// The return line is elided; presumably it returns `r`.
467 ClassTypeRegistry& getRegistry() {
468 static ClassTypeRegistry r;
// Creates a ClassType named `name`, backed by `module`, and registers it in
// the shared registry under that name (a global side effect).
// NOTE(review): construction uses raw `new` inside ClassTypePtr, presumably
// because the constructor is not public (which rules out make_shared) —
// confirm. Whether re-registering an existing name overwrites the old entry
// depends on registerType's elided body.
473 ClassTypePtr ClassType::create(
474 const std::string& name,
475 std::shared_ptr<Module> module) {
476 auto ptr = ClassTypePtr(
new ClassType(name, std::move(module)));
477 getRegistry().registerType(name, ptr);
// Looks up a previously registered ClassType by name in the shared registry.
481 ClassTypePtr ClassType::get(
const std::string& name) {
482 return getRegistry().getType(name);
// Removes all entries from the shared ClassType registry.
485 void ClassType::clearRegistry() {
486 getRegistry().clear();
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...