Caffe2 - C++ API
A deep learning, cross-platform ML framework
type.cpp
1 #include <ATen/core/jit_type.h>
2 
3 #include <iostream>
4 
5 namespace c10 {
6 
7 std::ostream& operator<<(std::ostream & out, const Type & t) {
8  if(auto value = t.cast<CompleteTensorType>()) {
9  out << toString(value->scalarType()) << "(";
10  auto& sizes = value->sizes();
11  auto& strides = value->strides();
12  AT_ASSERT(sizes.size() == strides.size());
13  for (size_t i = 0; i < sizes.size(); i++) {
14  if (i > 0) {
15  out << ", ";
16  }
17  // TODO: figure out a good way to output strides, or
18  // add a "debug" printing mode which adds the extra stuff
19  out << sizes[i]; // << "%" << strides[i];
20  int64_t expected = i + 1 < sizes.size() ? sizes[i+1]*strides[i+1] : 1;
21  if (strides[i] != expected) {
22  out << "!"; //mark non-contiguous
23  }
24  }
25  out << ")";
26  } else if (auto value = t.cast<DimensionedTensorType>()) {
27  out << toString(value->scalarType()) << "(";
28  for (int64_t i = 0; i < value->dim(); ++i) {
29  if (i > 0) {
30  out << ", ";
31  }
32  out << "*";
33  }
34  out << ")";
35  } else if(t.kind() == TypeKind::ListType) {
36  auto prim = t.cast<ListType>()->getElementType();
37  out << *prim << "[]";
38  } else if (t.kind() == TypeKind::OptionalType) {
39  auto prim = t.cast<OptionalType>()->getElementType();
40  out << *prim << "?";
41  } else if(t.kind() == TypeKind::FutureType) {
42  auto elem = t.cast<FutureType>()->getElementType();
43  out << "Future[" << *elem << "]";
44  } else if(auto tup = t.cast<TupleType>()) {
45  out << "(";
46  for(size_t i = 0; i < tup->elements().size(); ++i) {
47  if(i > 0)
48  out << ", ";
49  out << *(tup->elements()[i]);
50  }
51  out << ")";
52  } else {
53  out << t.str();
54  }
55  return out;
56 }
57 
58 TensorTypePtr TensorType::get() {
59  static auto value = TensorType::create();
60  return value;
61 }
62 AutogradZeroTensorTypePtr AutogradZeroTensorType::get() {
63  static auto value = AutogradZeroTensorType::create();
64  return value;
65 }
66 NumberTypePtr NumberType::get() {
67  static auto value = NumberType::create();
68  return value;
69 }
70 IntTypePtr IntType::get() {
71  static auto value = IntType::create();
72  return value;
73 }
74 FloatTypePtr FloatType::get() {
75  static auto value = FloatType::create();
76  return value;
77 }
78 BoolTypePtr BoolType::get() {
79  static auto value = BoolType::create();
80  return value;
81 }
82 NoneTypePtr NoneType::get() {
83  static auto value = NoneType::create();
84  return value;
85 }
86 GeneratorTypePtr GeneratorType::get() {
87  static auto value = GeneratorType::create();
88  return value;
89 }
90 StringTypePtr StringType::get() {
91  static auto value = StringType::create();
92  return value;
93 }
94 DeviceObjTypePtr DeviceObjType::get() {
95  static auto value = DeviceObjType::create();
96  return value;
97 }
98 OptionalTypePtr OptionalType::ofTensor() {
99  static auto value = OptionalType::create(TensorType::get());
100  return value;
101 }
102 ListTypePtr ListType::ofTensors() {
103  static auto value = ListType::create(TensorType::get());
104  return value;
105 }
106 ListTypePtr ListType::ofInts() {
107  static auto value = ListType::create(IntType::get());
108  return value;
109 }
110 ListTypePtr ListType::ofFloats() {
111  static auto value = ListType::create(FloatType::get());
112  return value;
113 }
114 ListTypePtr ListType::ofBools() {
115  static auto value = ListType::create(BoolType::get());
116  return value;
117 }
118 
119 // why incomplete? You cannot completely recover a type from
120 // an IValue, List[List[int]] and List[List[Tensor]] will both
121 // become ivalue.isGenericList() and cannot be recovered.
122 // The only appropriate place to use this is where you know that
123 // you are only dealing with a subset of objects where you can recover
124 // the type, like in the tracer.
125 TypePtr incompleteInferTypeFrom(const IValue& value) {
126  if (value.isTensor()) {
127  return CompleteTensorType::create(value.toTensor());
128  } else if (value.isDouble()) {
129  return FloatType::get();
130  } else if (value.isInt()) {
131  return IntType::get();
132  } else if (value.isBool()) {
133  return BoolType::get();
134  } else if (value.isString()) {
135  return StringType::get();
136  } else if (value.isIntList()) {
137  return ListType::ofInts();
138  } else if (value.isTensorList()) {
139  return ListType::ofTensors();
140  } else if (value.isBoolList()) {
141  return ListType::ofBools();
142  } else if (value.isDoubleList()) {
143  return ListType::ofFloats();
144  } else if (value.isTuple()) {
145  return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom));
146  } else if (value.isDevice()) {
147  return DeviceObjType::get();
148  }
149  AT_ERROR("Type cannot be accurately recovered from this IValue.");
150 }
151 
152 // This attempts to recover the type from an IValue, including nested Generic
153 // Lists. It only examines the first element (the first of the iterator in the
154 // case of a dict) of each generic container,
155 // and if a generic container is empty returns typevar as the base element.
156 // XXX: only used for better error messages, should not be used elsewhere
157 TypePtr attemptToRecoverType(const IValue& ivalue) {
158  if (ivalue.isGenericList()) {
159  auto& ivalue_list = ivalue.toGenericListRef();
160  if (ivalue_list.size() == 0) {
161  return ListType::create(VarType::create("t"));
162  }
163  return ListType::create(attemptToRecoverType(ivalue_list[0]));
164  }
165  if (ivalue.isGenericDict()) {
166  const auto& dict = ivalue.toGenericDictRef();
167  if (dict.size() == 0) {
168  return DictType::create(VarType::create("t"), VarType::create("t"));
169  }
170  auto item = dict.begin();
171  return DictType::create(
172  attemptToRecoverType(item->first), attemptToRecoverType(item->second));
173  }
174  return incompleteInferTypeFrom(ivalue);
175 }
176 
177 // Checks if input_ivalue is a subvalue of type.
178 bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
179  if (ivalue.isTuple()) {
180  const auto& ivalue_elem = ivalue.toTuple()->elements();
181  auto tuple_type = type->cast<TupleType>();
182  if (!tuple_type || tuple_type->elements().size() != ivalue_elem.size()) {
183  return false;
184  }
185  auto type_elem = tuple_type->elements();
186  bool is_subvalue = true;
187  for (size_t i = 0; i < type_elem.size() && is_subvalue; ++i) {
188  is_subvalue = isSubvalueOf(ivalue_elem[i], type_elem[i]);
189  }
190  return is_subvalue;
191  }
192  if (ivalue.isGenericList()) {
193  auto list_type = type->cast<ListType>();
194  if (!list_type) {
195  return false;
196  }
197  auto& ivalue_list = ivalue.toGenericListRef();
198  auto element_type = list_type->getElementType();
199  return std::all_of(ivalue_list.begin(), ivalue_list.end(), [&](const IValue& list_elem) {
200  return isSubvalueOf(list_elem, element_type);
201  });
202  }
203  if (ivalue.isGenericDict()) {
204  auto dict_type = type->expect<DictType>();
205  const auto& dict = ivalue.toGenericDictRef();
206  return std::all_of(
207  dict.begin(), dict.end(), [=](const std::pair<IValue, IValue>& item) {
208  return isSubvalueOf(item.first, dict_type->getKeyType()) &&
209  isSubvalueOf(item.second, dict_type->getValueType());
210  });
211  }
212  return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type);
213 }
214 
215 c10::optional<TypePtr> tryEitherIsTheSuperType(const TypePtr& t1, const TypePtr& t2) {
216  if (t1->isSubtypeOf(t2)) {
217  return t2;
218  } else if (t2->isSubtypeOf(t1)) {
219  return t1;
220  } else {
221  return c10::nullopt;
222  }
223 }
224 
225 c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
226  //cases that t1 == t2, or t1 is a type refinement of t2 and vice versa
227  if (auto maybe_supertype = tryEitherIsTheSuperType(t1, t2)) {
228  return *maybe_supertype;
229  }
230 
231  // NB: we do not return NumberType because there is not currently enough
232  // operator support for it
233 
234  if (t1->isSubtypeOf(TensorType::get()) && t2->isSubtypeOf(TensorType::get())) {
235  return static_cast<TypePtr>(TensorType::get());;
236  }
237 
238  // if t1 is None and t2 is a concrete type, return Optional[t2] and vice versa
239  if (t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) {
240  return OptionalType::create(t2);
241  } else if (t2->isSubtypeOf(NoneType::get()) && !t1->isSubtypeOf(NoneType::get())) {
242  return OptionalType::create(t1);
243  }
244 
245  //types which contain other types
246  if (t1->cast<ListType>() && t2->cast<ListType>()) {
247  // because we have runtime specializations of lists, e.g. int[] = std::vector<int64_t>
248  // int?[] = std::vector<IValue> we don't allow type coercion,
249  // since t1 & t2 may have different runtime representations.
250 
251  // allow Lists of different tensor types
252  auto unshaped_t1 = unshapedType(t1);
253  auto unshaped_t2 = unshapedType(t2);
254  return tryEitherIsTheSuperType(unshaped_t1, unshaped_t2);
255  } else if(t1->cast<TupleType>() && t2->cast<TupleType>()) {
256  auto tuple1 = t1->cast<TupleType>();
257  auto tuple2 = t2->cast<TupleType>();
258  if (tuple1->elements().size() != tuple2->elements().size()) {
259  return c10::nullopt;
260  }
261  std::vector<TypePtr> elements;
262  for (size_t i = 0; i < tuple1->elements().size(); i++) {
263  if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i))) {
264  elements.push_back(*elem);
265  } else {
266  return c10::nullopt;
267  }
268  }
269  return static_cast<TypePtr>(TupleType::create(elements));
270  }
271 
272  return c10::nullopt;
273 }
274 
275 MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) {
276  MatchTypeReturn ret;
277  if(!formal->hasFreeVariables()) {
278  ret.type = formal;
279  return ret;
280  }
281 
282  if(auto vt = formal->cast<VarType>()) {
283  auto it = type_env.find(vt->name());
284  if(it == type_env.end()) {
285  type_env[vt->name()] = actual;
286  ret.type = actual;
287  return ret;
288  } else if(auto unified = unifyTypes(it->second, actual)) {
289  type_env[vt->name()] = *unified;
290  ret.type = *unified;
291  return ret;
292  }
293  std::stringstream ss;
294  ss << "type variable '" << vt->name() <<"' previously matched to type " <<
295  it->second->str() << " is matched to type " << actual->str();
296  ret.errMsg = ss.str();
297  return ret;
298  } else if(auto lt_formal = formal->cast<ListType>()) {
299  if(auto lt_actual = actual->cast<ListType>()) {
300  const auto innerType = matchTypeVariables(
301  lt_formal->getElementType(),
302  lt_actual->getElementType(),
303  type_env);
304  if (!innerType.type) {
305  // propagate the errMsg onward
306  return innerType;
307  }
308  ret.type = ListType::create(*innerType.type);
309  return ret;
310  } else {
311  std::stringstream ss;
312  ss << "cannot match a list to " << actual->str();
313  ret.errMsg = ss.str();
314  return ret;
315  }
316  } else if(auto tp_formal = formal->cast<TupleType>()) {
317  if(auto tp_actual = actual->cast<TupleType>()) {
318  if(tp_formal->elements().size() != tp_actual->elements().size()) {
319  ret.errMsg = "cannot match tuples of mismatched size";
320  return ret;
321  }
322  std::vector<TypePtr> elements;
323  for(size_t i = 0; i < tp_formal->elements().size(); ++i) {
324  const auto result = matchTypeVariables(
325  tp_formal->elements()[i],
326  tp_actual->elements()[i],
327  type_env);
328  if (!result.type) {
329  return result;
330  }
331  elements.push_back(*result.type);
332  }
333  ret.type = TupleType::create(std::move(elements));
334  return ret;
335  } else {
336  std::stringstream ss;
337  ss << "cannot match a tuple to " << actual->str();
338  ret.errMsg = ss.str();
339  return ret;
340  }
341  } else if (auto lt_formal = formal->cast<FutureType>()) {
342  if (auto lt_actual = actual->cast<FutureType>()) {
343  const auto innerType = matchTypeVariables(
344  lt_formal->getElementType(), lt_actual->getElementType(), type_env);
345  if (!innerType.type) {
346  return innerType;
347  }
348  ret.type = FutureType::create(*innerType.type);
349  return ret;
350  } else {
351  std::stringstream ss;
352  ss << "cannot match a future to " << actual->str();
353  ret.errMsg = ss.str();
354  return ret;
355  }
356  } else if (auto opt_formal = formal->cast<OptionalType>()) {
357  if (auto opt_actual = actual->cast<OptionalType>()) {
358  const auto optionedType = matchTypeVariables(
359  opt_formal->getElementType(), opt_actual->getElementType(), type_env);
360  if (!optionedType.type) {
361  return optionedType;
362  }
363  ret.type = OptionalType::create(*optionedType.type);
364  return ret;
365  } else if (!actual->isSubtypeOf(NoneType::get())) {
366  // If the actual type is a non-optional, allow matching to the formal if
367  // its element type matches the actual.
368  // Don't match None because it is already an optional (but one of
369  // unknown type).
370  return matchTypeVariables(opt_formal->getElementType(), actual, type_env);
371  } else {
372  ret.errMsg = "cannot match an Optional[T] to None, because there is no way to determine T from None.";
373  return ret;
374  }
375  } else if (auto dict_formal = formal->cast<DictType>()) {
376  if (auto dict_actual = actual->cast<DictType>()) {
377  auto key_type = matchTypeVariables(
378  dict_formal->getKeyType(),
379  dict_actual->getKeyType(),
380  type_env
381  );
382  if (!key_type.type) {
383  return key_type;
384  }
385  auto value_type = matchTypeVariables(
386  dict_formal->getValueType(),
387  dict_actual->getValueType(),
388  type_env
389  );
390  if (!value_type.type) {
391  return value_type;
392  }
393  ret.type = DictType::create(*key_type.type, *value_type.type);
394  return ret;
395  } else {
396  std::stringstream ss;
397  ss << "cannot match a dict to " << actual->str();
398  ret.errMsg = ss.str();
399  return ret;
400  }
401  }
402 
403  AT_ERROR("unhandled free variable container: ", formal->str());
404 }
405 
406 // change return types like List[List[t]] into List[List[int]]
407 CAFFE2_API TypePtr evalTypeVariables(TypePtr type, std::unordered_map<std::string, TypePtr>& type_env) {
408  if(!type->hasFreeVariables())
409  return type;
410 
411  if(auto vt = type->cast<VarType>()) {
412  auto it = type_env.find(vt->name());
413  AT_ASSERTM(it != type_env.end(), "schema has unbound type variable '", vt->name(), "' in its return type");
414  return it->second;
415  } else {
416  auto new_contained = fmap(type->containedTypes(), [&](TypePtr t) {
417  return evalTypeVariables(t, type_env);
418  });
419  return type->withContained(std::move(new_contained));
420  }
421 }
422 
423 
424 const char * typeKindToString(TypeKind kind) {
425 #define CASE_TYPE(T) case TypeKind::T: return #T;
426  switch(kind) {
427  C10_FORALL_TYPES(CASE_TYPE)
428  }
429 #undef CASE_TYPE
430  return "";
431 }
432 
433 bool Type::isSubtypeOf(const TypePtr rhs) const {
434  if(auto rhs_ = rhs->cast<OptionalType>()) {
435  return this->isSubtypeOf(rhs_->getElementType());
436  }
437  return *this == *rhs;
438 }
439 
440 namespace {
441 class ClassTypeRegistry {
442  public:
443  void registerType(std::string name, ClassTypePtr type) {
444  std::lock_guard<std::mutex> g(mutex_);
445  // TODO: new type registrations will override the old ones. Is this safe?
446  reg_[name] = type;
447  }
448 
449  ClassTypePtr getType(const std::string& name) {
450  std::lock_guard<std::mutex> g(mutex_);
451  if (reg_.count(name)) {
452  return reg_.at(name);
453  }
454  return nullptr;
455  }
456 
457  void clear() {
458  std::lock_guard<std::mutex> g(mutex_);
459  reg_.clear();
460  }
461 
462  private:
463  std::mutex mutex_;
464  std::unordered_map<std::string, ClassTypePtr> reg_;
465 };
466 
467 ClassTypeRegistry& getRegistry() {
468  static ClassTypeRegistry r;
469  return r;
470 }
471 } // namespace
472 
473 ClassTypePtr ClassType::create(
474  const std::string& name,
475  std::shared_ptr<Module> module) {
476  auto ptr = ClassTypePtr(new ClassType(name, std::move(module)));
477  getRegistry().registerType(name, ptr);
478  return ptr;
479 }
480 
481 ClassTypePtr ClassType::get(const std::string& name) {
482  return getRegistry().getType(name);
483 }
484 
485 void ClassType::clearRegistry() {
486  getRegistry().clear();
487 }
488 
489 } // namespace c10
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7