Caffe2 - C++ API
A deep learning, cross platform ML framework
jit_type.h
1 #pragma once
2 
3 #include <ATen/core/TensorMethods.h>
4 #include <ATen/core/Type.h>
5 #include <ATen/core/functional.h>
6 #include <ATen/core/interned_strings.h>
7 #include <ATen/core/ivalue.h>
8 #include <c10/util/TypeList.h>
9 #include <caffe2/core/common.h>
10 
11 #include <c10/util/Optional.h>
12 
13 #include <memory>
14 #include <iostream>
15 #include <type_traits>
16 
17 namespace torch {
18 namespace jit {
19 namespace script {
20 struct Module;
21 struct Method;
22 }
23 } // namespace jit
24 } // namespace torch
25 
26 namespace c10 {
27 
28 #define C10_FORALL_TYPES(_) \
29 _(TensorType) \
30 _(DimensionedTensorType) \
31 _(CompleteTensorType) \
32 _(AutogradZeroTensorType) \
33 _(TupleType) \
34 _(ListType) \
35 _(DictType) \
36 _(NumberType) \
37 _(FloatType) \
38 _(FutureType) \
39 _(IntType) \
40 _(NoneType) \
41 _(StringType) \
42 _(GeneratorType) \
43 _(BoolType) \
44 _(OptionalType) \
45 _(VarType) \
46 _(DeviceObjType) \
47 _(ClassType) \
48 
49 enum class TypeKind {
50 #define DEFINE_TYPE(T) T,
51  C10_FORALL_TYPES(DEFINE_TYPE)
52 #undef DEFINE_TYPE
53 };
54 
55 CAFFE2_API const char * typeKindToString(TypeKind kind);
56 
57 #define DEFINE_IS_SUBCLASS(_kind) \
58  bool isSubclass(const TypeKind kind) const override { \
59  return kind == TypeKind::_kind; \
60  }
61 
62 struct Type;
63 using TypePtr = std::shared_ptr<Type>;
64 
// Abstract base of every JIT type. Instances are handled through TypePtr
// (std::shared_ptr<Type>); cast() and withContained() call shared_from_this(),
// so they require the object to already be owned by a shared_ptr.
// Kind checks use the TypeKind tag plus the virtual isSubclass() instead of RTTI.
struct CAFFE2_API Type : std::enable_shared_from_this<Type> {
private:
  // Dynamic kind of this object. Set by the constructor and re-set by
  // sliceType() after slicing a copy.
  TypeKind kind_;
  // Copies *ptr into a new object whose dynamic type is exactly T, so a
  // subclass's extra fields are dropped.
  template<typename T>
  static std::shared_ptr<T> sliceType(std::shared_ptr<const T> ptr) {
    auto result = std::make_shared<typename std::remove_const<T>::type>(*ptr);
    // XXX: the line above will correctly slice the struct, and make its runtime
    // type exactly equal to T. However, kind_ is a field of Type, so it will simply
    // be copied, and we need to fix it in here to match the dynamic type.
    result->kind_ = T::Kind;
    return result;
  }

protected:
  // Only subclasses construct Types; each passes its own TypeKind.
  Type(TypeKind kind)
    : kind_(kind) {}

public:
  // Structural equality; every concrete type defines it.
  virtual bool operator==(const Type& rhs) const = 0;

  // subtyping relation. By default, we return true for the case
  // when the type is exactly equal or if this <: T where rhs = Optional[T]
  // (the default implementation is defined out of line).
  virtual bool isSubtypeOf(const TypePtr rhs) const;

  // If this class can be cast to the kind passed in
  // This removes the need for RTTI
  virtual bool isSubclass(const TypeKind kind) const = 0;

  // How this type will appear in FunctionSchema declarations
  virtual std::string str() const = 0;

  // How this type will appear as if it were a type annotation in Python
  // which is sometimes different than how it appears in declarations (e.g. int[] vs List[int])
  virtual std::string python_str() const {
    return str();
  }

  // The dynamic kind tag of this object.
  TypeKind kind() const {
    return kind_;
  }

  // False by default; overridden by types that may carry gradients.
  virtual bool requires_grad() const { return false; }

  // Dynamically cast this object to the subclass indicated by the
  // template variable, returning nullptr if the cast is invalid.
  // NOTE: if the cast succeeds, but the casted kind is not the
  // run-time kind of the type, we also slice the structure, so
  // that assignments of those types to values don't accidentally
  // inherit more detailed information from subclasses.
  // When no slicing is needed, the result aliases this object.
  template<typename T>
  std::shared_ptr<T> cast() {
    std::shared_ptr<T> r = nullptr;
    if (isSubclass(T::Kind)) {
      r = std::static_pointer_cast<T>(shared_from_this());
    }
    if (!r || T::Kind == kind()) {
      return r;
    } else {
      return sliceType<T>(r);
    }
  }
  // Const overload of cast(); same semantics, returns a pointer-to-const.
  template<typename T>
  std::shared_ptr<const T> cast() const {
    std::shared_ptr<const T> r = nullptr;
    if (isSubclass(T::Kind)) {
      r = std::static_pointer_cast<const T>(shared_from_this());
    }
    if (!r || T::Kind == kind()) {
      return r;
    } else {
      return sliceType<T>(r);
    }
  }
  // Like cast(), but AT_ASSERTs that the cast succeeded.
  template<typename T>
  std::shared_ptr<T> expect() {
    auto r = cast<T>();
    AT_ASSERT(r);
    return r;
  }
  // Const overload of expect().
  template<typename T>
  std::shared_ptr<const T> expect() const {
    auto r = cast<const T>();
    AT_ASSERT(r);
    return r;
  }
  virtual ~Type() = default;
  // True if this type (or anything it contains) mentions a type variable.
  virtual bool hasFreeVariables() const {
    return false;
  }
  // list of types this type contains, e.g. for a List then element type of a list
  // for a tuple, the types of the tuple elements
  virtual at::ArrayRef<TypePtr> containedTypes() const {
    return {};
  }
  // create a new version of this type, replacing its contained types with
  // contained_types. The count must match containedTypes(). If the new list
  // equals the current one (presumably element-wise TypePtr comparison via
  // ArrayRef::equals), this same object is returned instead of a new one.
  TypePtr withContained(std::vector<TypePtr> contained_types) {
    auto current_contained = containedTypes();
    AT_ASSERT(current_contained.size() == contained_types.size());
    if(current_contained.equals(contained_types)) {
      return shared_from_this();
    }
    return createWithContained(std::move(contained_types));
  }
  // per-type constructor, you only need to override this if the containedTypes()
  // is not empty. The default raises an error naming this type.
  virtual TypePtr createWithContained(std::vector<TypePtr> contained_types) const {
    AT_ERROR("type with contained types did not overload createWithContained: ", str());
  }
};
175 
176 inline bool operator!=(const Type & lhs, const Type & rhs) {
177  return !(lhs == rhs);
178 }
179 
180 // common base for all types that have a single sub element
181 // e.g. Future[T], Option[T], List[T]
182 template<TypeKind K, typename T>
183 struct SingleElementType : public Type {
184  static const TypeKind Kind = K;
185  TypePtr getElementType() const {
186  return elem;
187  }
188  bool hasFreeVariables() const override {
189  return has_free_variables_;
190  }
191  at::ArrayRef<TypePtr> containedTypes() const override {
192  return elem;
193  }
194  bool requires_grad() const override {
195  return elem->requires_grad();
196  }
197  bool operator==(const Type& rhs) const override {
198  if(auto rhs_ = rhs.cast<T>()) {
199  return *getElementType() == *rhs_->getElementType();
200  }
201  return false;
202  }
203 protected:
204  SingleElementType(TypePtr elem)
205  : Type(Kind)
206  , elem(std::move(elem))
207  , has_free_variables_(getElementType()->hasFreeVariables()) {}
208 private:
209  TypePtr elem;
210  bool has_free_variables_;
211 };
212 
213 
214 struct OptionalType;
215 using OptionalTypePtr = std::shared_ptr<OptionalType>;
216 // This type represents an optional type, for each element type.
217 // Optional[T] can accept both T and None(nullopt in C++)
218 // Subtype hierarchy for Optional:
219 // 1. Optional[T] isSubtypeOf Optional[R] iff T isSubtypeOf R
220 // 2. T isSubtypeOf Optional[R] if T isSubtypeOf R
221 // Note: NoneType is NOT a subtype of any optional.
222 // instead NoneType is convertable in schema matching to any Optional[T]
223 // it is handled this way because it is not possible to match None to Optional[T]
224 // and extract T. Intead, we always create a None constant instruction
225 // with a particular type: v: Optional[int] = None()
226 struct CAFFE2_API OptionalType: public SingleElementType<TypeKind::OptionalType, OptionalType> {
227  static OptionalTypePtr create(TypePtr element) {
228  return OptionalTypePtr(new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
229  }
230  DEFINE_IS_SUBCLASS(OptionalType);
231  bool isSubtypeOf(const TypePtr rhs) const override {
232  if(auto rhs_ = rhs->cast<OptionalType>()) {
233  return getElementType()->isSubtypeOf(rhs_->getElementType());
234  }
235  return false;
236  }
237 
238  std::string str() const override {
239  std::stringstream ss;
240  ss << getElementType()->str() << "?";
241  return ss.str();
242  }
243  std::string python_str() const override {
244  std::stringstream ss;
245  ss << "Optional[" << getElementType()->python_str() << "]";
246  return ss.str();
247  }
248 
249  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
250  AT_ASSERT(contained_types.size() == 1);
251  return create(contained_types[0]);
252  }
253 
254  // common cast Optional[Tensor] for undefined tensor type
255  static OptionalTypePtr ofTensor();
256 private:
257  OptionalType(TypePtr elem) : SingleElementType(elem) {}
258 };
259 
260 struct TensorType;
261 using TensorTypePtr = std::shared_ptr<TensorType>;
262 // This type represents a single Tensor, with an unknown shape.
263 // Subtype hierarchy for Tensor Types (TensorType as the base type):
264 // CompleteTensorType <: DimensionedTensorType <: TensorType
265 // AutogradZeroTensorType <: TensorType
266 struct CAFFE2_API TensorType : public Type {
267  static TensorTypePtr create() {
268  return TensorTypePtr(new TensorType()); // NOLINT(modernize-make-shared)
269  }
270  DEFINE_IS_SUBCLASS(TensorType);
271 
272  bool requires_grad() const override { return true; }
273 
274  bool operator==(const Type& rhs) const override {
275  return rhs.kind() == kind();
276  }
277  std::string str() const override {
278  return "Tensor";
279  }
280  static const TypeKind Kind = TypeKind::TensorType;
281  // global singleton
282  static TensorTypePtr get();
283 protected:
284  TensorType(TypeKind kind=TypeKind::TensorType)
285  : Type(kind) {}
286 };
287 
289 using AutogradZeroTensorTypePtr = std::shared_ptr<AutogradZeroTensorType>;
290 // This type represents an undefined tensor.
291 struct CAFFE2_API AutogradZeroTensorType : public TensorType {
292  static AutogradZeroTensorTypePtr create() {
293  return AutogradZeroTensorTypePtr(new AutogradZeroTensorType()); // NOLINT(modernize-make-shared)
294  }
295 
296  DEFINE_IS_SUBCLASS(AutogradZeroTensorType);
297 
298  bool requires_grad() const override { return false; }
299 
300  bool operator==(const Type& rhs) const override {
301  return rhs.kind() == kind();
302  }
303  bool isSubtypeOf(const TypePtr rhs) const override {
304  return rhs->kind() == TypeKind::TensorType ||
305  rhs->kind() == TypeKind::AutogradZeroTensorType ||
306  TensorType::isSubtypeOf(rhs);
307  }
308  std::string str() const override {
309  return "UndefinedTensor";
310  }
311 
312  static const TypeKind Kind = TypeKind::AutogradZeroTensorType;
313  // global singleton
314  static AutogradZeroTensorTypePtr get();
315 protected:
316  AutogradZeroTensorType(): TensorType(TypeKind::AutogradZeroTensorType) {}
317 };
318 
319 struct DimensionedTensorType;
320 using DimensionedTensorTypePtr = std::shared_ptr<DimensionedTensorType>;
321 // This type represents a single Tensor with a specific size
322 struct CAFFE2_API DimensionedTensorType : public TensorType {
323  template<typename ... T>
324  static DimensionedTensorTypePtr create( T&& ... all ) {
325  return DimensionedTensorTypePtr(new DimensionedTensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
326  }
327 
328  at::ScalarType scalarType() const { return scalar_type_; }
329  at::Device device() const { return device_; }
330  int64_t dim() const { return dim_; }
331  bool requires_grad() const override { return requires_grad_; }
332 
333  DimensionedTensorTypePtr toScalarType(at::ScalarType type){
334  auto t = DimensionedTensorType::create(*this);
335  t->scalar_type_ = type;
336  return t;
337  }
338  DimensionedTensorTypePtr withDim(size_t new_dim) {
339  auto t = DimensionedTensorType::create(*this);
340  t->dim_ = new_dim;
341  return t;
342  }
343  DimensionedTensorTypePtr withRequiresGrad(bool req) {
344  auto t = DimensionedTensorType::create(*this);
345  t->requires_grad_ = req;
346  return t;
347  }
348 
349  bool operator==(const Type& rhs) const override {
350  if (rhs.kind() != TypeKind::DimensionedTensorType)
351  return false;
352  auto rt = rhs.expect<DimensionedTensorType>();
353  return scalarType() == rt->scalarType() &&
354  device() == rt->device() &&
355  dim() == rt->dim();
356  }
357  bool isSubtypeOf(const TypePtr rhs) const override {
358  return rhs->kind() == TypeKind::TensorType ||
359  (rhs->kind() == TypeKind::DimensionedTensorType && Type::isSubtypeOf(rhs)) ||
360  TensorType::isSubtypeOf(rhs);
361  }
362  bool isSubclass(const TypeKind kind) const override {
363  return kind == TypeKind::TensorType ||
364  kind == TypeKind::DimensionedTensorType;
365  }
366  std::string str() const override {
367  // str is used for user-facing error messages, where we
368  // don't want to reveal underlying size information.
369  return "Tensor";
370  }
371 
372  static const TypeKind Kind = TypeKind::DimensionedTensorType;
373 
374 protected:
375  DimensionedTensorType(const at::Tensor& tensor, TypeKind kind=TypeKind::DimensionedTensorType)
376  : DimensionedTensorType(tensor.scalar_type(),
377  tensor.device(),
378  tensor.dim(),
379  tensor.is_variable() && tensor.requires_grad(),
380  kind) {}
381  DimensionedTensorType(at::ScalarType scalar_type, at::Device device, int64_t dim, bool requires_grad=true, TypeKind kind=TypeKind::DimensionedTensorType)
382  : TensorType(kind)
383  , scalar_type_(scalar_type)
384  , requires_grad_(at::isFloatingType(scalar_type) && requires_grad)
385  , device_(device)
386  , dim_(dim) {}
387 
388  at::ScalarType scalar_type_;
389  bool requires_grad_;
390  at::Device device_;
391  int64_t dim_;
392 };
393 
394 struct CompleteTensorType;
395 using CompleteTensorTypePtr = std::shared_ptr<CompleteTensorType>;
396 // This type represents a single Tensor with a specific size
397 struct CAFFE2_API CompleteTensorType : public DimensionedTensorType {
398  template<typename ... T>
399  static CompleteTensorTypePtr create( T&& ... all ) {
400  return CompleteTensorTypePtr(new CompleteTensorType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
401  }
402 
403  // overloaded create variadic template argument as it could not distinguish initializer list
404  static CompleteTensorTypePtr create(at::ScalarType scalar_type, at::Device device, at::IntArrayRef sizes) {
405  return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes)); // NOLINT(modernize-make-shared)
406  }
407  static CompleteTensorTypePtr create(at::ScalarType scalar_type, at::Device device, at::IntArrayRef sizes, at::IntArrayRef strides) {
408  return CompleteTensorTypePtr(new CompleteTensorType(scalar_type, device, sizes, strides)); // NOLINT(modernize-make-shared)
409  }
410 
411  const std::vector<int64_t>& sizes() const { return sizes_; }
412  const std::vector<int64_t>& strides() const { return strides_; }
413 
414  TypePtr withSizesStrides(at::IntArrayRef sizes, at::IntArrayRef strides) const {
415  return CompleteTensorType::create(scalar_type_, device_, sizes, strides);
416  }
417 
418  TypePtr withSizes(at::IntArrayRef sizes) const {
419  return withSizesStrides(sizes, CompleteTensorType::contiguousStridesOf(sizes));
420  }
421 
422  CompleteTensorTypePtr contiguous() const {
423  auto t = CompleteTensorType::create(*this);
424  t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_);
425  return t;
426  }
427 
428  CompleteTensorTypePtr toScalarType(at::ScalarType type){
429  auto t = CompleteTensorType::create(*this);
430  t->scalar_type_ = type;
431  return t;
432  }
433 
434  bool operator==(const Type& rhs) const override {
435  if(rhs.kind() != kind())
436  return false;
437  auto rt = rhs.expect<CompleteTensorType>();
438  return scalarType() == rt->scalarType() &&
439  sizes() == rt->sizes() &&
440  strides() == rt->strides() &&
441  device() == rt->device();
442  }
443  bool isSubtypeOf(const TypePtr rhs) const override {
444  if (rhs->kind() == TypeKind::DimensionedTensorType)
445  return *expect<DimensionedTensorType>() == *rhs;
446  return rhs->kind() == TypeKind::TensorType ||
447  TensorType::isSubtypeOf(rhs);
448  }
449  bool isSubclass(const TypeKind kind) const override {
450  return kind == TypeKind::TensorType ||
451  kind == TypeKind::DimensionedTensorType ||
452  kind == TypeKind::CompleteTensorType;
453  }
454  std::string str() const override {
455  // str is used for user-facing error messages, where we
456  // don't want to reveal underlying size information.
457  return "Tensor";
458  }
459  bool numel() const {
460  size_t prod = 1;
461  for(auto s : sizes()) {
462  prod *= s;
463  }
464  return prod;
465  }
466 
467  static const TypeKind Kind = TypeKind::CompleteTensorType;
468 
469  static TypePtr fromNumberType(TypePtr typ);
470  static TypePtr fromBoolType();
471 
472 private:
473  CompleteTensorType(const at::Tensor& tensor)
474  : DimensionedTensorType(tensor, TypeKind::CompleteTensorType)
475  , sizes_(tensor.sizes().vec())
476  , strides_(tensor.strides().vec()) {}
477  CompleteTensorType(at::ScalarType scalar_type, at::Device device, at::IntArrayRef sizes, bool requires_grad=true)
478  : CompleteTensorType(scalar_type, device, sizes, CompleteTensorType::contiguousStridesOf(sizes), requires_grad) {}
479  CompleteTensorType(at::ScalarType scalar_type, at::Device device, at::IntArrayRef sizes, at::IntArrayRef strides, bool requires_grad=true)
480  : DimensionedTensorType(scalar_type, device, sizes.size(), requires_grad, TypeKind::CompleteTensorType)
481  , sizes_(sizes.vec())
482  , strides_(strides.vec()) {}
483 
484  static std::vector<int64_t> contiguousStridesOf(at::IntArrayRef sizes) {
485  std::vector<int64_t> strides(sizes.size());
486  if(sizes.empty()) // zero-dim case
487  return strides;
488  strides.back() = 1;
489  for(size_t i = strides.size() - 1; i > 0; i--) {
490  strides[i-1] = strides[i] * sizes[i];
491  }
492  return strides;
493  }
494 
495  std::vector<int64_t> sizes_;
496  std::vector<int64_t> strides_;
497 };
498 
499 struct ListType;
500 using ListTypePtr = std::shared_ptr<ListType>;
501 struct CAFFE2_API ListType : public SingleElementType<TypeKind::ListType, ListType> {
502  // It's not exactly a singleton, but there should be exactly once instance of
503  // List[T] for every T
504  friend struct Type;
505  template<typename ... T>
506  static ListTypePtr create( T&& ... all ) {
507  return ListTypePtr(new ListType( std::forward<T>(all)... )); // NOLINT(modernize-make-shared)
508  }
509  DEFINE_IS_SUBCLASS(ListType);
510  std::string str() const override {
511  std::stringstream ss;
512  ss << getElementType()->str() << "[]";
513  return ss.str();
514  }
515  std::string python_str() const override {
516  std::stringstream ss;
517  ss << "List[" << getElementType()->python_str() << "]";
518  return ss.str();
519  }
520  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
521  return create(contained_types.at(0));
522  }
523  // common cast List[Tensor]
524  static ListTypePtr ofTensors();
525  static ListTypePtr ofInts();
526  static ListTypePtr ofFloats();
527  static ListTypePtr ofBools();
528 private:
529  ListType(TypePtr elem) : SingleElementType(elem) {}
530 };
531 
532 struct DictType;
533 using DictTypePtr = std::shared_ptr<DictType>;
534 struct CAFFE2_API DictType : public Type {
535  friend struct Type;
536  static const TypeKind Kind = TypeKind::DictType;
537 
538  static DictTypePtr create(TypePtr key, TypePtr value) {
539  switch (key->kind()) {
540  case TypeKind::IntType:
541  case TypeKind::FloatType:
542  case TypeKind::StringType:
543  return DictTypePtr(new DictType(key, value));
544  default:
545  AT_ERROR(
546  "Cannot create dict for key type '",
547  key->str(),
548  "', only int, float, and string keys are supported");
549  }
550  }
551 
552  std::string str() const override {
553  return python_str();
554  }
555 
556  std::string python_str() const override {
557  std::stringstream ss;
558  ss << "Dict[" << getKeyType()->python_str() << ", "
559  << getValueType()->python_str() << "]";
560  return ss.str();
561  }
562 
563  TypePtr createWithContained(
564  std::vector<TypePtr> contained_types) const override {
565  if (contained_types.size() != 2) {
566  throw std::runtime_error("Expected 2 contained types");
567  }
568  return create(contained_types.at(0), contained_types.at(1));
569  }
570 
571  TypePtr getKeyType() const {
572  return types.at(0);
573  }
574 
575  TypePtr getValueType() const {
576  return types.at(1);
577  }
578 
579  DEFINE_IS_SUBCLASS(DictType);
580  bool isSubtypeOf(const TypePtr rhs) const override {
581  if (auto dict_rhs = rhs->cast<DictType>()) {
582  return getKeyType()->isSubtypeOf(dict_rhs->getKeyType()) &&
583  getValueType()->isSubtypeOf(dict_rhs->getValueType());
584  }
585  return false;
586  }
587 
588  bool hasFreeVariables() const override {
589  return has_free_variables;
590  }
591 
592  at::ArrayRef<TypePtr> containedTypes() const override {
593  return types;
594  }
595 
596  bool requires_grad() const override {
597  return getValueType()->requires_grad() || getKeyType()->requires_grad();
598  }
599 
600  bool operator==(const Type& rhs) const override {
601  if (auto dict_rhs = rhs.cast<DictType>()) {
602  return *getKeyType() == *(dict_rhs->getKeyType()) &&
603  *getValueType() == *(dict_rhs->getValueType());
604  }
605  return false;
606  }
607 
608  private:
609  DictType(TypePtr key, TypePtr value)
610  : Type(TypeKind::DictType),
611  types({key, value}),
612  has_free_variables(
613  key->hasFreeVariables() || value->hasFreeVariables()) {}
614  std::vector<TypePtr> types;
615  bool has_free_variables;
616 };
617 
618 struct FutureType;
619 using FutureTypePtr = std::shared_ptr<FutureType>;
620 
621 struct CAFFE2_API FutureType : public SingleElementType<TypeKind::FutureType, FutureType> {
622  friend struct Type;
623  template<typename ... T>
624  static FutureTypePtr create(TypePtr elem) {
625  return FutureTypePtr(new FutureType(std::move(elem))); // NOLINT(modernize-make-shared)
626  }
627 
628  DEFINE_IS_SUBCLASS(FutureType);
629 
630  std::string str() const override {
631  std::stringstream ss;
632  ss << "Future(" << getElementType()->str() << ")";
633  return ss.str();
634  }
635  std::string python_str() const override {
636  std::stringstream ss;
637  ss << "Future[" << getElementType()->python_str() << "]";
638  return ss.str();
639  }
640  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
641  return create(contained_types.at(0));
642  }
643 private:
644  FutureType(TypePtr elem) : SingleElementType(elem) {}
645 };
646 
647 struct TupleType;
648 using TupleTypePtr = std::shared_ptr<TupleType>;
650 // This type represents a Tuple
651 struct CAFFE2_API TupleType : public Type {
652  static TupleTypePtr create(std::vector<TypePtr> types, OptNameList names=c10::nullopt) {
653  return TupleTypePtr(new TupleType(std::move(types), std::move(names))); // NOLINT(modernize-make-shared)
654  }
655  DEFINE_IS_SUBCLASS(TupleType);
656  at::ArrayRef<TypePtr> elements() const {
657  return elements_;
658  }
659  bool operator==(const Type& rhs) const override {
660  return compare(rhs, [](const TypePtr a, const TypePtr b) {
661  return *a == *b;
662  }) && names_ == rhs.expect<TupleType>()->names_;
663  // `compare` guarantees that rhs is always a TupleType, so the
664  // dynamic_cast above always success.
665  }
666  bool isSubtypeOf(const TypePtr rhs_) const override {
667  if (Type::isSubtypeOf(rhs_))
668  return true;
669  auto rhs = rhs_->cast<TupleType>();
670  if (!rhs)
671  return false;
672  // unnamed tuple is not a subtype of nametuple
673  if (!hasNames() && rhs->hasNames())
674  return false;
675  // namedtuple may be a subtype of unnamed tuple
676  bool names_match = !rhs->hasNames() || names() == rhs->names();
677  // co-variant rules for tuples
678  return names_match && compare(*rhs, [](const TypePtr a, const TypePtr b) {
679  return a->isSubtypeOf(b);
680  });
681  }
682  bool requires_grad() const override {
683  return std::any_of(elements_.begin(), elements_.end(),
684  [](const TypePtr& ptr) { return ptr->requires_grad(); });
685  }
686  std::string str() const override {
687  std::stringstream ss;
688  ss << "(";
689  for(size_t i = 0; i < elements().size(); ++i) {
690  if(i > 0)
691  ss << ", ";
692  ss << elements()[i]->str();
693  }
694  ss << ")";
695  return ss.str();
696  }
697  std::string python_str() const override {
698  std::stringstream ss;
699  ss << "Tuple[";
700  for(size_t i = 0; i < elements().size(); ++i) {
701  if(i > 0)
702  ss << ", ";
703  ss << elements()[i]->python_str();
704  }
705  ss << "]";
706  return ss.str();
707  }
708  bool hasFreeVariables() const override {
709  return has_free_variables_;
710  }
711  bool hasNames() const {
712  return names_.has_value();
713  }
714  const std::vector<std::string> &names() const {
715  return names_.value();
716  }
717 
718  at::ArrayRef<TypePtr> containedTypes() const override {
719  return elements_;
720  }
721  TypePtr createWithContained(std::vector<TypePtr> contained_types) const override {
722  return create(std::move(contained_types));
723  }
724 
725  static const TypeKind Kind = TypeKind::TupleType;
726 private:
727  TupleType(std::vector<TypePtr> elements_, OptNameList names)
728  : Type(TypeKind::TupleType)
729  , elements_(std::move(elements_))
730  , names_(std::move(names)) {
731  has_free_variables_ =
732  std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
733  return v->hasFreeVariables();
734  });
735  }
736 
737  bool compare(const Type& rhs, std::function<bool(const TypePtr, const TypePtr)> fn) const {
738  if(rhs.kind() != kind())
739  return false;
740  const auto & l_elements = elements();
741  const auto & r_elements = rhs.cast<TupleType>()->elements();
742  if(l_elements.size() != r_elements.size())
743  return false;
744  for(size_t i = 0; i < l_elements.size(); ++i) {
745  if(!fn(l_elements[i], r_elements[i]))
746  return false;
747  }
748  return true;
749  }
750 
751  std::vector<TypePtr> elements_;
752  bool has_free_variables_;
753  OptNameList names_;
754 };
755 
756 struct NumberType;
757 using NumberTypePtr = std::shared_ptr<NumberType>;
758 // This type represents a Python number
759 // Subtype hierarchy for Number Types (NumberType as the base type):
760 // IntType <: NumberType
761 // FloatType <: NumberType
762 struct CAFFE2_API NumberType : public Type {
763  static NumberTypePtr create() {
764  return NumberTypePtr(new NumberType()); // NOLINT(modernize-make-shared)
765  }
766  DEFINE_IS_SUBCLASS(NumberType);
767  bool operator==(const Type& rhs) const override {
768  return rhs.kind() == kind();
769  }
770  std::string str() const override {
771  return "Scalar"; // match what PythonArgParser says for clarity
772  }
773  std::string python_str() const override {
774  return "number"; // technically not a valid python type, but
775  // we need to use it when parsing back in annotations
776  // for implicit conversions
777  }
778  static const TypeKind Kind = TypeKind::NumberType;
779  // global singleton
780  static NumberTypePtr get();
781 protected:
782  NumberType(TypeKind kind=TypeKind::NumberType)
783  : Type(kind) {}
784 };
785 
786 struct FloatType;
787 using FloatTypePtr = std::shared_ptr<FloatType>;
788 // This type represents a Python float number
789 struct CAFFE2_API FloatType : public NumberType {
790  static FloatTypePtr create() {
791  return FloatTypePtr(new FloatType()); // NOLINT(modernize-make-shared)
792  }
793  DEFINE_IS_SUBCLASS(FloatType);
794  bool operator==(const Type& rhs) const override {
795  return rhs.kind() == kind();
796  }
797  std::string str() const override {
798  return "float";
799  }
800  std::string python_str() const override {
801  return "float";
802  }
803  bool isSubtypeOf(const TypePtr rhs) const override {
804  return rhs->kind() == TypeKind::NumberType ||
805  NumberType::isSubtypeOf(rhs);
806  }
807  static const TypeKind Kind = TypeKind::FloatType;
808  // global singleton
809  static FloatTypePtr get();
810 private:
811  FloatType()
812  : NumberType(TypeKind::FloatType) {}
813 };
814 
815 struct IntType;
816 using IntTypePtr = std::shared_ptr<IntType>;
817 // This type represents a Python int number
818 struct CAFFE2_API IntType : public NumberType {
819  static IntTypePtr create() {
820  return IntTypePtr(new IntType()); // NOLINT(modernize-make-shared)
821  }
822  DEFINE_IS_SUBCLASS(IntType);
823  bool operator==(const Type& rhs) const override {
824  return rhs.kind() == kind();
825  }
826  std::string str() const override {
827  return "int";
828  }
829  std::string python_str() const override {
830  return "int";
831  }
832  bool isSubtypeOf(const TypePtr rhs) const override {
833  return rhs->kind() == TypeKind::NumberType ||
834  NumberType::isSubtypeOf(rhs);
835  }
836  static const TypeKind Kind = TypeKind::IntType;
837  // global singleton
838  static IntTypePtr get();
839 private:
840  IntType()
841  : NumberType(TypeKind::IntType) {}
842 };
843 
844 struct BoolType;
845 using BoolTypePtr = std::shared_ptr<BoolType>;
846 // This node represents a Python bool value
847 struct CAFFE2_API BoolType : public Type {
848  static BoolTypePtr create( ) {
849  return BoolTypePtr(new BoolType());
850  }
851  DEFINE_IS_SUBCLASS(BoolType);
852  bool operator==(const Type& rhs) const override {
853  return rhs.kind() == kind();
854  }
855  std::string str() const override {
856  return "bool";
857  }
858  static const TypeKind Kind = TypeKind::BoolType;
859  // global singleton
860  static BoolTypePtr get();
861 private:
862  BoolType()
863  : Type(TypeKind::BoolType) {}
864 };
865 
866 struct StringType;
867 using StringTypePtr = std::shared_ptr<StringType>;
868 // This type represents a Python string
869 struct CAFFE2_API StringType : public Type {
870  static StringTypePtr create() {
871  return StringTypePtr(new StringType()); // NOLINT(modernize-make-shared)
872  }
873  DEFINE_IS_SUBCLASS(StringType);
874  bool operator==(const Type& rhs) const override {
875  return rhs.kind() == kind();
876  }
877  std::string str() const override {
878  return "string";
879  }
880  std::string python_str() const override {
881  return "str";
882  }
883  static const TypeKind Kind = TypeKind::StringType;
884  // global singleton
885  static StringTypePtr get();
886 private:
887  StringType()
888  : Type(TypeKind::StringType) {}
889 };
890 
891 struct NoneType;
892 using NoneTypePtr = std::shared_ptr<NoneType>;
893 // This type represents a Python None
894 struct CAFFE2_API NoneType : public Type {
895  static NoneTypePtr create() {
896  return NoneTypePtr(new NoneType()); // NOLINT(modernize-make-shared)
897  }
898  DEFINE_IS_SUBCLASS(NoneType);
899  bool operator==(const Type& rhs) const override {
900  return rhs.kind() == kind();
901  }
902  bool isSubtypeOf(const TypePtr rhs) const override {
903  return rhs->kind() == TypeKind::NoneType;
904  }
905  std::string str() const override {
906  return "None";
907  }
908  static const TypeKind Kind = TypeKind::NoneType;
909  // global singleton
910  static NoneTypePtr get();
911 private:
912  NoneType()
913  : Type(TypeKind::NoneType) {}
914 };
915 
916 struct GeneratorType;
917 using GeneratorTypePtr = std::shared_ptr<GeneratorType>;
918 // This type represents a Generator
919 struct CAFFE2_API GeneratorType : public Type {
920  static GeneratorTypePtr create() {
921  return GeneratorTypePtr(new GeneratorType()); // NOLINT(modernize-make-shared)
922  }
923  DEFINE_IS_SUBCLASS(GeneratorType);
924  bool operator==(const Type& rhs) const override {
925  return rhs.kind() == kind();
926  }
927  std::string str() const override {
928  return "Generator";
929  }
930  static const TypeKind Kind = TypeKind::GeneratorType;
931  // global singleton
932  static GeneratorTypePtr get();
933 private:
934  GeneratorType()
935  : Type(TypeKind::GeneratorType) {}
936 };
937 
938 struct DeviceObjType;
939 using DeviceObjTypePtr = std::shared_ptr<DeviceObjType>;
940 // This type represents a Generator
941 struct CAFFE2_API DeviceObjType : public Type {
942  static DeviceObjTypePtr create() {
943  return DeviceObjTypePtr(new DeviceObjType()); // NOLINT(modernize-make-shared)
944  }
945  DEFINE_IS_SUBCLASS(DeviceObjType);
946  bool operator==(const Type& rhs) const override {
947  return rhs.kind() == kind();
948  }
949  std::string str() const override {
950  return "Device";
951  }
952  static const TypeKind Kind = TypeKind::DeviceObjType;
953  // global singleton
954  static DeviceObjTypePtr get();
955 private:
956  DeviceObjType()
957  : Type(TypeKind::DeviceObjType) {}
958 };
959 
960 
961 struct VarType;
962 using VarTypePtr = std::shared_ptr<VarType>;
963 // This type represents a type variable, used in FunctionSchema
964 struct VarType : public Type {
965  static VarTypePtr create(std::string name_) {
966  return VarTypePtr(new VarType(std::move(name_)));
967  }
968  DEFINE_IS_SUBCLASS(VarType);
969  bool operator==(const Type& rhs) const override {
970  return rhs.kind() == kind();
971  }
972  std::string str() const override {
973  return name();
974  }
975  const std::string& name() const {
976  return name_;
977  }
978  bool hasFreeVariables() const override {
979  return true;
980  }
981  static const TypeKind Kind = TypeKind::VarType;
982 private:
983  VarType(std::string name_)
984  : Type(TypeKind::VarType), name_(std::move(name_)) {}
985  std::string name_;
986 };
987 
988 CAFFE2_API std::ostream& operator<<(std::ostream & out, const Type & t);
989 // what is the type, ignoring extra size/shape information?
990 // e.g. Tensor(2x3) -> Dynamic, and Tuple(Tensor(2x3),...) -> Tuple(Dynamic,...)
991 
992 inline TypePtr unshapedType(const TypePtr& type) {
993  if (type->kind() == TypeKind::DimensionedTensorType ||
994  type->kind() == TypeKind::CompleteTensorType) {
995  return TensorType::get();
996  }
997  return type->withContained(fmap(type->containedTypes(), unshapedType));
998 }
999 
1000 inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) {
1001  if (typ->isSubtypeOf(IntType::get())) {
1002  return CompleteTensorType::create(at::kLong, at::kCPU, {});
1003  } else if (typ->isSubtypeOf(FloatType::get())) {
1004  return CompleteTensorType::create(at::kFloat, at::kCPU, {});
1005  } else if (typ->isSubtypeOf(BoolType::get())) {
1006  return CompleteTensorType::create(at::kLong, at::kCPU, {});
1007  }
1008  AT_ERROR("unknown number type", typ->str());
1009 }
1010 
1011 inline TypePtr CompleteTensorType::fromBoolType() {
1012  return CompleteTensorType::create(at::kLong, at::kCPU, {});
1013 }
1014 
1015 // Attempt to find the correct supertype of t1 and t2. If none is found then
1016 // nullopt will be returned. If t1 == t2, or t1 is a type refinement of t2,
1017 // then t2 will be returned (and vice versa).
1018 // Two different tensortypes will return dynamic.
1019 // Currently we chose not to support returning a NumberType for a float & int
1020 // input because of a lack of operator support for NumberType
1021 CAFFE2_API c10::optional<TypePtr> unifyTypes(
1022  const TypePtr& t1,
1023  const TypePtr& t2);
1024 
1025 namespace detail {
1026 template <typename T> struct getTypePtr_ final {
1027  static_assert(guts::false_t<T>::value, "Type could not be converted to any of the known types.");
1028 };
1029 
1030 template<> struct getTypePtr_<at::Tensor> final {
1031  static TypePtr call() { return TensorType::get(); }
1032 };
1033 template<> struct getTypePtr_<double> final {
1034  static TypePtr call() { return FloatType::get(); }
1035 };
1036 template<> struct getTypePtr_<int64_t> final {
1037  static TypePtr call() { return IntType::get(); }
1038 };
1039 template<> struct getTypePtr_<bool> final {
1040  static TypePtr call() { return BoolType::get(); }
1041 };
1042 template<> struct getTypePtr_<at::Scalar> final {
1043  static TypePtr call() { return NumberType::get(); }
1044 };
1045 template<> struct getTypePtr_<std::string> final {
1046  static TypePtr call() { return StringType::get(); }
1047 };
1048 template<class T> struct getTypePtr_<std::vector<T>> final {
1049  static TypePtr call() {
1050  static auto type = ListType::create(getTypePtr_<T>::call());
1051  return type;
1052  }
1053 };
1054 template<class T> struct getTypePtr_<ArrayRef<T>> final {
1055  static TypePtr call() {
1056  static auto type = ListType::create(getTypePtr_<T>::call());
1057  return type;
1058  }
1059 };
1060 template <class K, class V>
1061 struct getTypePtr_<std::unordered_map<K, V>> final {
1062  static TypePtr call() {
1063  static auto type =
1064  DictType::create(getTypePtr_<K>::call(), getTypePtr_<V>::call());
1065  return type;
1066  }
1067 };
1068 template <class T>
1069 struct getTypePtr_<at::optional<T>> final {
1070  static TypePtr call() {
1071  static auto type = OptionalType::create(getTypePtr_<T>::call());
1072  return type;
1073  }
1074 };
1075 }
1076 template<class T> inline TypePtr getTypePtr() {
1077  // TODO: static_assert that a templated function exists, and throw a friendy
1078  // error message if not
1080 }
1081 
1082 CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value);
1083 CAFFE2_API TypePtr attemptToRecoverType(const IValue& input_ivalue);
1084 CAFFE2_API bool isSubvalueOf(const IValue& input_ivalue, TypePtr type);
1085 
1086 using TypeEnv = std::unordered_map<std::string, TypePtr>;
1088  c10::optional<TypePtr> type; // nullopt if there is no match
1089  std::string errMsg; // is there is no match, this contains the reason
1090 };
1091 
1092 CAFFE2_API MatchTypeReturn
1093 matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env);
1094 
1095 CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env);
1096 
1101 struct ClassType;
1102 using ClassTypePtr = std::shared_ptr<ClassType>;
1103 using ::torch::jit::script::Module;
1104 using ::torch::jit::script::Method;
1105 
1106 // This represents a class in TorchScript.
1107 struct CAFFE2_API ClassType : public Type {
1108  // Create a user type and register it globally.
1109  static ClassTypePtr create(
1110  const std::string& name,
1111  std::shared_ptr<Module> module);
1112  // returns nullptr if there is no type with that name
1113  static ClassTypePtr get(const std::string& name);
1114  // For testing: delete all registered types
1115  static void clearRegistry();
1116 
1117  DEFINE_IS_SUBCLASS(ClassType);
1118  bool operator==(const Type& rhs) const override {
1119  if (auto user_rhs = rhs.cast<ClassType>()) {
1120  return typename_ == user_rhs->typename_;
1121  }
1122  return false;
1123  }
1124 
1125  bool isSubtypeOf(const TypePtr rhs) const override {
1126  // XXX: We do not have inheritance implemented, only types that are the
1127  // same can subtype from each other.
1128  return *this == *rhs;
1129  }
1130  std::string str() const override {
1131  return std::string("ClassType<") + typename_ + ">";
1132  }
1133 
1134  std::string python_str() const override {
1135  return typename_;
1136  }
1137 
1138  TypePtr getAttribute(const std::string& name) const {
1139  AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
1140  size_t pos = 0;
1141  for (const auto& attr : attributeNames_) {
1142  if (name == attr) {
1143  break;
1144  }
1145  ++pos;
1146  }
1147 
1148  if (pos >= attributeNames_.size()) {
1149  return nullptr;
1150  }
1151  return attributeTypes_[pos];
1152  }
1153 
1154  Method* getMethod(const std::string& name) const;
1155  std::vector<Method*> methods() const;
1156 
1157  std::string name() const {
1158  return typename_;
1159  }
1160 
1161  size_t numAttributes() const {
1162  AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
1163  return attributeNames_.size();
1164  }
1165 
1166  // Attributes are stored in a specific slot at runtime for effiency.
1167  // When emitting instructions we specify the slot so that attribute access is
1168  // a constant lookup
1169  size_t getAttributeSlot(const std::string& name) const {
1170  AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
1171  size_t slot = 0;
1172  for (const auto& attr : attributeNames_) {
1173  if (name == attr) {
1174  return slot;
1175  }
1176  slot++;
1177  }
1178  throw std::runtime_error("Couldn't find attribute: " + name);
1179  }
1180 
1181  bool hasAttribute(const std::string& name) const {
1182  return std::find_if(
1183  attributeNames_.cbegin(),
1184  attributeNames_.cend(),
1185  [&](const std::string& attr) { return attr == name; }) !=
1186  attributeNames_.cend();
1187  }
1188 
1189  void addAttribute(const std::string& name, TypePtr type) {
1190  attributeNames_.push_back(name);
1191  attributeTypes_.push_back(type);
1192  }
1193 
1194  at::ArrayRef<TypePtr> containedTypes() const override {
1195  return attributeTypes_;
1196  }
1197 
1198  static const TypeKind Kind = TypeKind::ClassType;
1199 
1200  private:
1201  ClassType(std::string name, std::shared_ptr<Module> module)
1202  : Type(TypeKind::ClassType),
1203  typename_(std::move(name)),
1204  module_(std::move(module)) {}
1205 
1206  // Name of type (note that this has to be globally unique).
1207  std::string typename_;
1208 
1209  // Mapping of attribute names -> their type.
1210  // NOTE: this does not contain methods, which are stored in the module
1211  // TODO: once modules support arbitrary ivalue attributes, we don't need this
1212  // anymore.
1213  // TODO: This is better represented as an OrderedDict, but alas it is not yet
1214  // available from c10
1215  std::vector<std::string> attributeNames_;
1216  std::vector<TypePtr> attributeTypes_;
1217  // Holds method attributes
1218  std::shared_ptr<Module> module_;
1219 
1220 };
1221 } // namespace c10
AT_CPP14_CONSTEXPR const T & back() const
back - Get the last element.
Definition: ArrayRef.h:149
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Represents a compute device on which a tensor is located.
Definition: Device.h:30
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
bool is_variable() const noexcept
Returns true if the Tensor is actually a torch::autograd::Variable.
Device device() const
Returns a Tensor's device.
constexpr bool empty() const
empty - Check if the array is empty.
Definition: ArrayRef.h:129
Definition: jit_type.h:17
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...
Flush-To-Zero and Denormals-Are-Zero mode.
C10_NODISCARD TensorOptions requires_grad(c10::optional< bool > requires_grad) const noexcept
Sets the requires_grad property of the TensorOptions.