Caffe2 - C++ API
A deep learning, cross-platform ML framework
ivalue.h
1 #pragma once
2 
3 #include <condition_variable>
4 #include <type_traits>
5 
6 #include <ATen/core/blob.h>
7 #include <ATen/core/interned_strings.h>
8 #include <c10/core/Scalar.h>
9 #include <c10/core/TensorImpl.h>
10 #include <c10/core/UndefinedTensorImpl.h>
11 #include <c10/util/intrusive_ptr.h>
12 
13 #include <ATen/core/Tensor.h>
14 
15 namespace c10 {
16 struct IValue;
17 
18 namespace ivalue {
19 
20 template <typename T>
21 using Shared = c10::intrusive_ptr<T>;
22 
23 // string
// Immutable, refcounted string value backing IValue's String tag.
struct CAFFE2_API ConstantString final : c10::intrusive_ptr_target {
 private:
  // Owned string data; set once at construction and never mutated.
  const std::string str_;
 public:
  ConstantString(std::string str)
  : str_(std::move(str)) {}
  // Factory that heap-allocates a refcounted ConstantString (defined in the .cpp).
  static c10::intrusive_ptr<ConstantString> create(std::string str_);
  // Read-only access to the underlying string.
  const std::string & string() const {
    return str_;
  }
  // Implicit conversion for use where a std::string is expected.
  operator const std::string & () const {
    return string();
  }
  CAFFE2_API friend std::ostream& operator<<(
      std::ostream& out,
      const ConstantString& v);
};
41 
// Refcounted, homogeneous list; the storage behind IValue's *List tags.
template <typename Elem>
struct CAFFE2_API List : c10::intrusive_ptr_target {
 private:
  std::vector<Elem> elements_;

 public:
  typedef Elem ElemType;

  List(std::vector<Elem> elements_) : elements_(std::move(elements_)) {}
  // Factory returning a refcounted List that takes ownership of `elements_`.
  static c10::intrusive_ptr<List<Elem>> create(std::vector<Elem> elements_) {
    return c10::make_intrusive<List<Elem>>(std::move(elements_));
  }
  // Read-only access for const lvalue Lists.
  const std::vector<Elem>& elements() const & {
    return elements_;
  }
  operator const std::vector<Elem>&() const {
    return elements();
  }

  // Mutable access for lvalue Lists.
  std::vector<Elem>& elements() & {
    return elements_;
  }
  operator std::vector<Elem>&() {
    return elements();
  }

  // Rvalue overload: steal the underlying vector out of an expiring List.
  std::vector<Elem>&& elements() && {
    return std::move(elements_);
  }
};
72 
// Hash functor for IValue dictionary keys; defined out-of-line at the
// bottom of this file (only Int/String/Double keys are hashable there).
struct DictHash {
  size_t operator()(const IValue& ivalue) const;
};
76 
// Equality functor for IValue dictionary keys; defined out-of-line at the
// bottom of this file.
struct DictEqualTo {
  bool operator()(const IValue& lhs, const IValue& rhs) const;
};
80 
81 using UnorderedMap = std::unordered_map<IValue, IValue, DictHash, DictEqualTo>;
82 
83 struct Future;
84 struct GenericDict;
85 
// A Tuple is a GenericList of IValues carrying its own tag (Tag::Tuple),
// so it can be distinguished from a homogeneous GenericList.
struct CAFFE2_API Tuple : public List<IValue> {
  using List<IValue>::List;
  static c10::intrusive_ptr<Tuple> create(std::vector<IValue> elements_) {
    return c10::make_intrusive<Tuple>(std::move(elements_));
  }
};
92 using IntList = List<int64_t>;
93 using TensorList = List<at::Tensor>;
94 using DoubleList = List<double>;
95 using BoolList = List<bool>;
96 using GenericList = List<IValue>;
97 
98 struct Object;
99 }
100 
101 // IValue is the generic tagged union used by the interpreter to hold
102 // all value types.
103 // It is a 16-byte object with an 8-byte payload and an 8-byte tag.
104 // The tag is currently 4 bytes to determine the type, and 1 byte
105 // to mark whether that type is a subtype of c10::intrusive_ptr_target and needs
106 // retain/release calls.
107 
108 #define TORCH_FORALL_TAGS(_) \
109  _(None) \
110  _(Tensor) \
111  _(Double) \
112  _(Int) \
113  _(Bool) \
114  _(Tuple) \
115  _(IntList) \
116  _(DoubleList) \
117  _(BoolList) \
118  _(String) \
119  _(TensorList) \
120  _(Blob) \
121  _(GenericList) \
122  _(GenericDict) \
123  _(Future) \
124  _(Device) \
125  _(Object)
126 
127 struct CAFFE2_API IValue final {
128  IValue()
129  : payload{0}
130  , tag(Tag::None)
131  , is_intrusive_ptr(false) {}
132  IValue(const IValue& rhs)
133  : payload(rhs.payload),
134  tag(rhs.tag),
135  is_intrusive_ptr(rhs.is_intrusive_ptr) {
136  if (is_intrusive_ptr) {
137  c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr);
138  }
139  }
140  IValue(IValue&& rhs) noexcept : IValue() {
141  swap(rhs);
142  }
143  ~IValue() {
144  if (is_intrusive_ptr) {
145  c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
146  }
147  }
148  IValue & operator=(IValue && rhs) & noexcept {
149  IValue(std::move(rhs)).swap(*this); // this also sets rhs to None
150  return *this;
151  }
152  IValue & operator=(IValue const & rhs) & {
153  IValue(rhs).swap(*this);
154  return *this;
155  }
156 
157  void dump() const;
158 
159  bool isAliasOf(const IValue& rhs) const {
160  if (this->tag != rhs.tag) {
161  // Trivially don't alias if the type is different
162  return false;
163  }
164 
165  if (!this->is_intrusive_ptr) {
166  // Primitive types don't alias anything
167  return false;
168  }
169 
170  AT_ASSERT(rhs.is_intrusive_ptr);
171 
172  // Tensors should be compared based on internal storage
173  if (this->isTensor()) {
174  const auto thisTensor = this->toTensor();
175  const auto rhsTensor = rhs.toTensor();
176  return thisTensor.is_alias_of(rhsTensor);
177  }
178 
179  // Other types can be compared by their ptr value
180  return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
181  }
182  void swap(IValue & rhs) noexcept {
183  std::swap(payload, rhs.payload);
184  std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
185  std::swap(tag, rhs.tag);
186  }
187 
188  // Accessors for subtypes are arranged together below
189  // While some of these accessors could be generated through templates,
190  // we prefer to write them manually for clarity
191 
192  // Tensor
193  IValue(at::Tensor t)
194  : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) {
195  // Note: the undefined tensor is not refcounted, so while it
196  // is tagged as a tensor, is_intrusive_ptr is set to false.
197  // This is not an optional optimization: our incref call
198  // *will not* do the right thing when called on an
199  // undefined tensor.
200  payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl();
201  }
202  bool isTensor() const { return Tag::Tensor == tag; }
203  at::Tensor toTensor() && {
204  AT_ASSERT(isTensor());
205  return at::Tensor(moveToIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
206  }
207  at::Tensor toTensor() const & {
208  AT_ASSERT(isTensor());
209  return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
210  }
211 
212  const IValue& toIValue() const {
213  return *this;
214  }
215  IValue& toIValue() {
216  return *this;
217  }
218 
219  IValue(intrusive_ptr<caffe2::Blob> blob)
220  : tag(Tag::Blob), is_intrusive_ptr(true) {
221  // TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
222  // and store it as a Tensor instead.
223  payload.as_intrusive_ptr = blob.release();
224  }
225  bool isBlob() const {
226  return Tag::Blob == tag;
227  }
229  AT_ASSERT(isBlob());
230  return moveToIntrusivePtr<caffe2::Blob>();
231  }
232  c10::intrusive_ptr<caffe2::Blob> toBlob() const & {
233  AT_ASSERT(isBlob());
234  return toIntrusivePtr<caffe2::Blob>();;
235  }
236 
237  // Tuple
239  bool isTuple() const { return Tag::Tuple == tag; }
240  c10::intrusive_ptr<ivalue::Tuple> toTuple() && {
241  AT_ASSERT(isTuple());
242  return moveToIntrusivePtr<ivalue::Tuple>();
243  }
244  c10::intrusive_ptr<ivalue::Tuple> toTuple() const & {
245  AT_ASSERT(isTuple());
246  return toIntrusivePtr<ivalue::Tuple>();
247  }
248 
249  // Double
250  IValue(double d)
251  : tag(Tag::Double), is_intrusive_ptr(false) {
252  payload.as_double = d;
253  }
254  bool isDouble() const { return Tag::Double == tag; }
255  double toDouble() const {
256  AT_ASSERT(isDouble());
257  return payload.as_double;
258  }
259 
260  // Future
262  bool isFuture() const { return Tag::Future == tag; }
263  c10::intrusive_ptr<ivalue::Future> toFuture() && {
264  AT_ASSERT(isFuture());
265  return moveToIntrusivePtr<ivalue::Future>();
266  }
267  c10::intrusive_ptr<ivalue::Future> toFuture() const & {
268  AT_ASSERT(isFuture());
269  return toIntrusivePtr<ivalue::Future>();
270  }
271 
272  // Int
273  IValue(int64_t i)
274  : tag(Tag::Int), is_intrusive_ptr(false) {
275  payload.as_int = i;
276  }
277 
278  // allow you to pass literals (3, 4) without ambiguity
279  IValue(int32_t i)
280  : IValue(static_cast<int64_t>(i)) {}
281 
282  bool isInt() const { return Tag::Int == tag; }
283 
284  int64_t toInt() const {
285  AT_ASSERT(isInt());
286  return payload.as_int;
287  }
288 
289  // Bool
290  IValue(bool b)
291  : tag(Tag::Bool), is_intrusive_ptr(false) {
292  payload.as_bool = b;
293  }
294  bool isBool() const { return Tag::Bool == tag; }
295  bool toBool() const {
296  AT_ASSERT(isBool());
297  return payload.as_bool;
298  }
299 
300  // IntList
302  IValue(std::vector<int64_t> v);
303  IValue(at::ArrayRef<int64_t> v)
304  : IValue(v.vec()) {}
305  bool isIntList() const { return Tag::IntList == tag; }
306  c10::intrusive_ptr<ivalue::IntList> toIntList() && {
307  AT_ASSERT(isIntList());
308  return moveToIntrusivePtr<ivalue::IntList>();
309  }
310  c10::intrusive_ptr<ivalue::IntList> toIntList() const & {
311  AT_ASSERT(isIntList());
312  return toIntrusivePtr<ivalue::IntList>();
313  }
314 
315  const std::vector<int64_t>& toIntListRef() const;
316  const std::vector<double>& toDoubleListRef() const;
317  const std::vector<bool>& toBoolListRef() const;
318  const std::vector<at::Tensor>& toTensorListRef() const;
319  const std::vector<IValue>& toGenericListRef() const;
320  const ivalue::UnorderedMap& toGenericDictRef() const;
321  const std::string& toStringRef() const;
322 
323  // ConstantString
325  IValue(std::string v);
326  bool isString() const { return Tag::String == tag; }
328  AT_ASSERT(isString());
329  return moveToIntrusivePtr<ivalue::ConstantString>();
330  }
331  c10::intrusive_ptr<ivalue::ConstantString> toString() const & {
332  AT_ASSERT(isString());
333  return toIntrusivePtr<ivalue::ConstantString>();
334  }
335 
336  // DoubleList
338  IValue(std::vector<double> v);
339  bool isDoubleList() const { return Tag::DoubleList == tag; }
340  c10::intrusive_ptr<ivalue::DoubleList> toDoubleList() && {
341  AT_ASSERT(isDoubleList());
342  return moveToIntrusivePtr<ivalue::DoubleList>();
343  }
344  c10::intrusive_ptr<ivalue::DoubleList> toDoubleList() const & {
345  AT_ASSERT(isDoubleList());
346  return toIntrusivePtr<ivalue::DoubleList>();
347  }
348 
349  // BoolList
351  IValue(std::vector<bool> v);
352  bool isBoolList() const { return Tag::BoolList == tag; }
353  c10::intrusive_ptr<ivalue::BoolList> toBoolList() && {
354  AT_ASSERT(isBoolList());
355  return moveToIntrusivePtr<ivalue::BoolList>();
356  }
357  c10::intrusive_ptr<ivalue::BoolList> toBoolList() const & {
358  AT_ASSERT(isBoolList());
359  return toIntrusivePtr<ivalue::BoolList>();
360  }
361 
362  //TensorList
364  IValue(std::vector<at::Tensor> v);
365  bool isTensorList() const { return Tag::TensorList == tag; }
366  c10::intrusive_ptr<ivalue::TensorList> toTensorList() && {
367  AT_ASSERT(isTensorList());
368  return moveToIntrusivePtr<ivalue::TensorList>();
369  }
370  c10::intrusive_ptr<ivalue::TensorList> toTensorList() const & {
371  AT_ASSERT(isTensorList());
372  return toIntrusivePtr<ivalue::TensorList>();
373  }
374 
375  //GenericList
377  IValue(std::vector<IValue> v);
378  bool isGenericList() const { return Tag::GenericList == tag; }
379  c10::intrusive_ptr<ivalue::GenericList> toGenericList() && {
380  AT_ASSERT(isGenericList());
381  return moveToIntrusivePtr<ivalue::GenericList>();
382  }
383  c10::intrusive_ptr<ivalue::GenericList> toGenericList() const & {
384  AT_ASSERT(isGenericList());
385  return toIntrusivePtr<ivalue::GenericList>();
386  }
387 
388  // GenericDict
390  IValue(ivalue::UnorderedMap v);
391  bool isGenericDict() const { return Tag::GenericDict == tag; }
392  c10::intrusive_ptr<ivalue::GenericDict> toGenericDict() && {
393  AT_ASSERT(isGenericDict());
394  return moveToIntrusivePtr<ivalue::GenericDict>();
395  }
396  c10::intrusive_ptr<ivalue::GenericDict> toGenericDict() const & {
397  AT_ASSERT(isGenericDict());
398  return toIntrusivePtr<ivalue::GenericDict>();
399  }
400 
401  // ClassType
403  bool isObject() const { return tag == Tag::Object; }
404  c10::intrusive_ptr<ivalue::Object> toObject() && {
405  AT_ASSERT(isObject());
406  return toIntrusivePtr<ivalue::Object>();
407  }
408  c10::intrusive_ptr<ivalue::Object> toObject() const & {
409  AT_ASSERT(isObject());
410  return toIntrusivePtr<ivalue::Object>();
411  }
412 
413  // None
414  bool isNone() const {
415  return Tag::None == tag;
416  }
417  std::string toNone() const {
418  AT_ASSERT(isNone());
419  return "None";
420  }
421  // Scalar, which gets encoded as either an Int or a Double
422  IValue(at::Scalar s)
423  : IValue() {
424  if(s.isFloatingPoint()) {
425  *this = s.toDouble();
426  } else {
427  *this = s.toLong();
428  }
429  }
430  bool isScalar() const {
431  return isDouble() || isInt();
432  }
433  at::Scalar toScalar() const {
434  if(isDouble())
435  return toDouble();
436  else if(isInt())
437  return toInt();
438  throw std::runtime_error("IValue is not a Scalar");
439  }
440 
441  // Device
442  IValue(c10::Device d)
443  : tag(Tag::Device), is_intrusive_ptr(false) {
444  payload.as_device.type = d.type();
445  payload.as_device.index = d.index();
446  }
447  bool isDevice() const { return Tag::Device == tag; }
448  c10::Device toDevice() const {
449  AT_ASSERT(isDevice());
450  return c10::Device(payload.as_device.type, payload.as_device.index);
451  }
452 
453  // ScalarType
454  at::ScalarType toScalarType() const {
455  return static_cast<at::ScalarType>(toInt());
456  }
457 
458  // Layout
459  at::Layout toLayout() const {
460  return static_cast<at::Layout>(toInt());
461  }
462 
463  // for debugging
464  std::string tagKind() const {
465  switch(tag) {
466  #define DEFINE_CASE(x) case Tag::x: return #x;
467  TORCH_FORALL_TAGS(DEFINE_CASE)
468  #undef DEFINE_CASE
469  }
470  return "Invalid Tag";
471  }
472 
473  // generic v.to<at::Tensor>() implementations
474  // that can be used in special functions like pop/push
475  // that use template meta-programming.
476  // prefer the directly named methods when you can,
477  // since they are simpler to understand
478 
479  // Note: if you get linker errors saying one of these is missing,
480  // change it to ... && = delete; and you will see better error messages for why
481  // However, we cannot commit this because some compiler versions barf on it.
482  template<typename T>
483  T to() &&;
484  template<typename T>
485  T to() const &;
486 
487  // ToOptional: convert a IValue to the Optional obj that accepts both T and None
488  template<typename T>
489  optional<T> toOptional();
490 
491  // this is a shallow comparison of two IValues to test the object identity
492  bool isSameIdentity(IValue& rhs);
493 
494  CAFFE2_API friend std::ostream& operator<<(
495  std::ostream& out,
496  const IValue& v);
497 
498  bool isPtrType() const {
499  return is_intrusive_ptr;
500  }
501 
502  private:
503  // NOTE: IValue tags are intentionally private. In the future we may encode
504  // this value different (e.g. using NaN boxing), and this would make it more
505  // costly to determine the tag for all types vs just determining if something
506  // is a particular type. Instead we want clients to use the `isX` methods when
507  // possible. If for perf. reasons you really, absolutely, must have a jump
508  // table, then we can revisit this.
509  enum class Tag : uint32_t {
510 #define DEFINE_TAG(x) x,
511  TORCH_FORALL_TAGS(DEFINE_TAG)
512 #undef DEFINE_TAG
513  };
514 
515  template<class T, class NullType = c10::detail::intrusive_target_default_null_type<T>>
516  c10::intrusive_ptr<T, NullType> moveToIntrusivePtr() {
517  auto t = c10::intrusive_ptr<T, NullType>::reclaim(static_cast<T*>(payload.as_intrusive_ptr));
518  clearToNone();
519  return t;
520  }
521  template<typename T, class NullType = c10::detail::intrusive_target_default_null_type<T>>
522  c10::intrusive_ptr<T, NullType> toIntrusivePtr() const {
523  auto r = c10::intrusive_ptr<T, NullType>::reclaim(static_cast<T*>(payload.as_intrusive_ptr));
524  auto p = r;
525  r.release();
526  return p;
527  }
528  void clearToNone() {
529  payload.as_int = 0;
530  tag = Tag::None;
531  is_intrusive_ptr = false;
532  }
533  union {
534  int64_t as_int;
535  double as_double;
536  bool as_bool;
537  c10::intrusive_ptr_target* as_intrusive_ptr;
538  struct {
539  DeviceType type;
540  DeviceIndex index;
541  } as_device;
542  } payload;
543  Tag tag;
544  bool is_intrusive_ptr;
545 };
546 
547 // Future
548 struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
549  private:
550  c10::intrusive_ptr<Future> intrusive_from_this() {
551  c10::raw::intrusive_ptr::incref(this); // we are creating a new pointer
552  // from a raw `this` pointer
553  // so we need to bump the refcount
554  // to account for this ownership
556  }
557 
558  public:
559  struct CAFFE2_API FutureError final : public std::exception {
560  FutureError(std::string&& error_msg_)
561  : error_msg(std::move(error_msg_)) {}
562 
563  FutureError() = default;
564 
565  const char* what() const noexcept override {
566  return error_msg.c_str();
567  }
568 
569  std::string error_msg;
570  };
571 
575  void wait() {
576  if (completed()) {
577  return;
578  }
579  std::condition_variable finished;
580  bool fired = false;
581 
582  // Add a callback to notify the current thread
583  // when the current future completes.
584  addCallback([&] {
585  std::unique_lock<std::mutex> lock(mutex_);
586  finished.notify_all();
587  fired = true;
588  });
589 
590  // The current thread will be blocked unless the above callback is fired.
591  std::unique_lock<std::mutex> lock(mutex_);
592  while (!fired) {
593  finished.wait(lock);
594  }
595 
596  AT_ASSERT(completed());
597  }
598 
602  void markCompleted(IValue value) {
603  {
604  // This is not to protect completed_ but to create a barrier
605  // from possible addCallback() calls
606  std::unique_lock<std::mutex> lock(mutex_);
607  AT_ASSERT(!completed());
608  completed_ = true;
609  value_ = std::move(value);
610  }
611 
612  fireCallbacks();
613  }
614 
615  void markCompleted(FutureError&& error_) {
616  {
617  // This is not to protect completed_ but to create a barrier
618  // from possible addCallback() calls
619  std::unique_lock<std::mutex> lock(mutex_);
620  AT_ASSERT(!completed());
621  completed_ = true;
622  has_error = true;
623  error = std::move(error_);
624  }
625 
626  fireCallbacks();
627  }
628 
629  // Get the result of the current future.
630  IValue value() {
631  std::unique_lock<std::mutex> lock(mutex_);
632  AT_ASSERT(completed());
633  if (has_error) {
634  throw error;
635  }
636  return value_;
637  }
638 
645  void addCallback(std::function<void(void)> callback) {
646  std::unique_lock<std::mutex> lock(mutex_);
647  if (completed()) {
648  lock.unlock();
649  callback();
650  return;
651  }
652  callbacks.push_back(callback);
653  }
654 
655  // Check if the current future has completed
656  bool completed() {
657  return completed_;
658  }
659 
660  CAFFE2_API friend std::ostream& operator<<(
661  std::ostream& out,
662  const Future& v);
663 
664  private:
665  void fireCallbacks() {
666  AT_ASSERT(completed());
667  // There is no need to protect callbacks with the lock.
668  // Once completed_ is set to true, no one can add new callback to the list.
669  for (auto& callback : callbacks) {
670  callback();
671  }
672  callbacks.clear();
673  }
674 
675  std::mutex mutex_;
676  IValue value_; // when finished the value
677  std::atomic_bool completed_ = {false}; // is this future complete
678  std::vector<std::function<void(void)>> callbacks;
679  bool has_error = false;
680  FutureError error;
681 };
682 
683 // User-defined object.
684 struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
685  public:
686  Object(Symbol name, size_t numSlots) : typename_(std::move(name)) {
687  slots_.resize(numSlots);
688  }
689 
690  static c10::intrusive_ptr<Object> create(
691  Symbol name,
692  size_t numSlots) {
693  return c10::make_intrusive<Object>(std::move(name), numSlots);
694  }
695 
696  void setSlot(size_t slot, IValue v) {
697  slots_[slot] = v;
698  }
699 
700  IValue getSlot(size_t slot) const {
701  return slots_.at(slot);
702  }
703 
704  Symbol name() const {
705  return typename_;
706  }
707 
708  private:
709  const Symbol typename_;
710  std::vector<IValue> slots_;
711 };
712 
// Refcounted IValue-to-IValue dictionary; the storage behind Tag::GenericDict.
struct C10_EXPORT ivalue::GenericDict : c10::intrusive_ptr_target {
 private:
  UnorderedMap elements_;

 public:
  GenericDict(UnorderedMap elements_)
      : elements_(std::move(elements_)) {}
  // Factory returning a refcounted GenericDict that takes ownership of `elements_`.
  static c10::intrusive_ptr<GenericDict> create(
      UnorderedMap elements_) {
    return c10::make_intrusive<GenericDict>(std::move(elements_));
  }
  // Read-only access to the underlying map.
  const UnorderedMap& elements() const {
    return elements_;
  }
  operator const UnorderedMap&() const {
    return elements();
  }

  // Mutable access to the underlying map.
  UnorderedMap& elements() {
    return elements_;
  }
  operator UnorderedMap&() {
    return elements();
  }
};
738 
739 #undef TORCH_FORALL_TAGS
740 
741 namespace detail {
742 
745 };
746 using _guarded_unsigned_long = c10::guts::conditional_t<
747  std::is_same<unsigned long, uint32_t>::value ||
748  std::is_same<unsigned long, uint64_t>::value,
750  unsigned long>;
751 
752 } // namespace detail
753 
754 #define DEFINE_TO(type, method_name) \
755 template<> \
756 inline type IValue::to<type>() && { \
757  return std::move(*this).method_name(); \
758 } \
759 template<> \
760 inline type IValue::to<type>() const & { \
761  return this->method_name(); \
762 }
763 DEFINE_TO(at::Tensor, toTensor)
764 DEFINE_TO(c10::intrusive_ptr<ivalue::Tuple>, toTuple)
765 DEFINE_TO(float, toDouble)
766 DEFINE_TO(double, toDouble)
767 DEFINE_TO(unsigned char, toInt)
768 DEFINE_TO(signed char, toInt)
769 DEFINE_TO(unsigned short, toInt)
770 DEFINE_TO(short, toInt)
771 DEFINE_TO(int, toInt)
772 DEFINE_TO(uint32_t, toInt)
773 DEFINE_TO(uint64_t, toInt)
774 DEFINE_TO(detail::_guarded_unsigned_long, toInt)
775 DEFINE_TO(int64_t, toInt)
776 DEFINE_TO(bool, toBool)
777 DEFINE_TO(c10::intrusive_ptr<caffe2::Blob>, toBlob);
778 DEFINE_TO(c10::intrusive_ptr<ivalue::DoubleList>, toDoubleList)
779 DEFINE_TO(c10::intrusive_ptr<ivalue::IntList>, toIntList)
780 DEFINE_TO(c10::intrusive_ptr<ivalue::BoolList>, toBoolList)
781 DEFINE_TO(c10::intrusive_ptr<ivalue::TensorList>, toTensorList)
782 DEFINE_TO(c10::intrusive_ptr<ivalue::GenericList>, toGenericList)
783 DEFINE_TO(c10::intrusive_ptr<ivalue::GenericDict>, toGenericDict)
785 DEFINE_TO(c10::intrusive_ptr<ivalue::Object>, toObject)
786 DEFINE_TO(at::Scalar, toScalar)
787 DEFINE_TO(std::vector<int64_t>, toIntListRef)
788 DEFINE_TO(std::vector<double>, toDoubleListRef)
789 DEFINE_TO(std::vector<bool>, toBoolListRef)
790 DEFINE_TO(std::vector<at::Tensor>, toTensorListRef)
791 DEFINE_TO(std::vector<IValue>, toGenericListRef)
792 DEFINE_TO(std::string, toStringRef)
793 DEFINE_TO(c10::intrusive_ptr<ivalue::Future>, toFuture)
794 DEFINE_TO(IValue, toIValue)
795 DEFINE_TO(c10::Device, toDevice)
796 DEFINE_TO(at::ScalarType, toScalarType)
797 DEFINE_TO(at::Layout, toLayout)
798 
// Tag type used to pick the right generic_to overload via overload
// resolution; it carries T without ever constructing one.
template <typename T>
struct _fake_type {};
801 
802 template <typename Elem>
803 std::vector<Elem> generic_to(
804  const IValue* ivalue,
805  _fake_type<std::vector<Elem>>) {
806  return fmap(ivalue->toGenericListRef(), [](IValue item_ivalue) { return item_ivalue.to<Elem>(); });
807 }
808 
809 template <typename K, typename V>
810 std::unordered_map<K, V> generic_to(
811  const IValue* ivalue,
812  _fake_type<std::unordered_map<K, V>>) {
813  std::unordered_map<K, V> specialized_dict;
814 
815  for (auto item : ivalue->toGenericDictRef()) {
816  specialized_dict[item.first.to<K>()] = item.second.to<V>();
817  }
818 
819  return specialized_dict;
820 }
821 
// Fallback for T without a DEFINE_TO specialization above: dispatch to a
// generic_to overload selected through the _fake_type<T> tag.
template <typename T>
inline T IValue::to() && {
  return generic_to(this, _fake_type<T>{});
}
826 
// Const-lvalue counterpart of the generic to<T>() fallback above.
template <typename T>
inline T IValue::to() const& {
  return generic_to(this, _fake_type<T>{});
}
831 
832 // note: when adding a DEFINE_TO case here you should also add a
833 // toX method to IValue. These named methods are much more discoverable
834 // than the to templated function.
835 
// Out-of-line IValue constructors. Each intrusive_ptr overload takes
// ownership of the refcounted payload by release()-ing the incoming pointer
// into the union; the std::vector/std::string/UnorderedMap overloads first
// wrap the data in the matching refcounted ivalue type.
inline IValue::IValue(c10::intrusive_ptr<ivalue::Tuple> v)
: tag(Tag::Tuple), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}

inline IValue::IValue(c10::intrusive_ptr<ivalue::IntList> v)
: tag(Tag::IntList), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(std::vector<int64_t> v)
: IValue(ivalue::IntList::create(std::move(v))) {}

inline IValue::IValue(c10::intrusive_ptr<ivalue::ConstantString> v)
: tag(Tag::String), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(std::string v)
: IValue(ivalue::ConstantString::create(std::move(v))) {}

inline IValue::IValue(c10::intrusive_ptr<ivalue::DoubleList> v)
: tag(Tag::DoubleList), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(std::vector<double> v)
: IValue(ivalue::DoubleList::create(std::move(v))) {}

inline IValue::IValue(c10::intrusive_ptr<ivalue::BoolList> v)
: tag(Tag::BoolList), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(std::vector<bool> v)
: IValue(ivalue::BoolList::create(std::move(v))) {}

inline IValue::IValue(c10::intrusive_ptr<ivalue::TensorList> v)
: tag(Tag::TensorList), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(std::vector<at::Tensor> v)
: IValue(ivalue::TensorList::create(std::move(v))) {}

inline IValue::IValue(c10::intrusive_ptr<ivalue::GenericList> v)
: tag(Tag::GenericList), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(std::vector<IValue> v)
: IValue(ivalue::GenericList::create(std::move(v))) {}

inline IValue::IValue(c10::intrusive_ptr<ivalue::GenericDict> v)
: tag(Tag::GenericDict), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(ivalue::UnorderedMap v)
: IValue(ivalue::GenericDict::create(std::move(v))) {}

inline IValue::IValue(c10::intrusive_ptr<ivalue::Object> v)
: tag(Tag::Object), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
: tag(Tag::Future), is_intrusive_ptr(true) {
  payload.as_intrusive_ptr = v.release();
}
898 
// Reference accessors: return a direct reference into the refcounted
// list/dict/string object. The reference stays valid only as long as that
// underlying object is kept alive (e.g. by this IValue).
inline const std::vector<int64_t>& IValue::toIntListRef() const {
  return toIntList()->elements();
}

inline const std::vector<double>& IValue::toDoubleListRef() const {
  return toDoubleList()->elements();
}

inline const std::vector<at::Tensor>& IValue::toTensorListRef() const {
  return toTensorList()->elements();
}

inline const std::vector<bool>& IValue::toBoolListRef() const {
  return toBoolList()->elements();
}

inline const std::vector<IValue>& IValue::toGenericListRef() const {
  return toGenericList()->elements();
}

inline const c10::ivalue::UnorderedMap& IValue::
    toGenericDictRef() const {
  return toGenericDict()->elements();
}

inline const std::string& IValue::toStringRef() const {
  return toString()->string();
}
927 
928 template<typename T>
929 inline optional<T> IValue::toOptional() {
930  if (this->isNone()) {
931  return nullopt;
932  }
933  return this->to<T>();
934 }
935 
// Shallow identity comparison (Python-style `is`), not deep equality.
inline bool IValue::isSameIdentity(IValue& rhs) {
  // We choose to not use memcmp for payload check due to potential random padding characters on union type

  // Semantics:
  // 1. None is None, False is False, and True is True are all true
  // 2. If it is a tensor type, we need to take undefined tensor into account
  // 3. Undefined_tensor is None and vice versa should be true
  // 4. If it is a reference type (i.e. is_intrusive_ptr), then is is True when the pointed-to object is the same.
  // 5. False for all other comparisons.
  if (this->isNone() && rhs.isNone()) {
    return true;
  } else if (this->isBool() && rhs.isBool()) {
    // for bool type, do equality check
    return this->toBool() == rhs.toBool();
  } else if (this->isTensor() && rhs.isTensor()) {
    // for tensor type, just check the as_intrusive_ptr since is_intrusive_ptr is false for undefined tensor
    return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
  } else if (this->isTensor() && rhs.isNone()) {
    // special case: undefined tensor and None are the same identity
    return !this->is_intrusive_ptr;
  } else if (this->isNone() && rhs.isTensor()) {
    // special case: undefined tensor and None are the same identity
    return !rhs.is_intrusive_ptr;
  } else {
    // for objects holding in IValue, do shallow compare on pointer address to testify the identity
    return this->is_intrusive_ptr && rhs.is_intrusive_ptr
        && this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
  }
}
965 } // namespace c10
966 
967 inline size_t at::ivalue::DictHash::operator()(
968  const c10::IValue& ivalue) const {
969  if (ivalue.isInt()) {
970  return std::hash<int>()(ivalue.toInt());
971  } else if (ivalue.isString()) {
972  return std::hash<std::string>()(ivalue.toStringRef());
973  } else if (ivalue.isDouble()) {
974  return std::hash<double>()(ivalue.toDouble());
975  } else {
976  throw std::runtime_error("Can't hash IValues with this tag");
977  }
978 }
979 
// Key equality for IValue dicts; mirrors DictHash's supported key types.
inline bool at::ivalue::DictEqualTo::operator()(
    const c10::IValue& lhs,
    const c10::IValue& rhs) const {
  // NOTE(review): dispatch is on lhs's tag only — if rhs holds a different
  // tag, the corresponding rhs.to*() accessor trips its AT_ASSERT instead of
  // returning false. Presumably every dict holds same-typed keys; confirm
  // against callers before relying on mixed-type lookups.
  if (lhs.isInt()) {
    return lhs.toInt() == rhs.toInt();
  } else if (lhs.isString()) {
    return lhs.toStringRef() == rhs.toStringRef();
  } else if (lhs.isDouble()) {
    return lhs.toDouble() == rhs.toDouble();
  } else {
    throw std::runtime_error("Can't compare IValues with this tag");
  }
}
void markCompleted(IValue value)
Explicitly mark the future as completed with the output value.
Definition: ivalue.h:602
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Represents a compute device on which a tensor is located.
Definition: Device.h:30
int16_t DeviceIndex
An index representing a specific device; e.g., the 1 in GPU 1.
Definition: Device.h:18
intrusive_ptr<T> is an alternative to shared_ptr<T> that has better performance because it does the r...
Definition: intrusive_ptr.h:35
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
TTarget * release() noexcept
Returns an owning (!) pointer to the underlying object and makes the intrusive_ptr instance invalid...
DeviceIndex index() const noexcept
Returns the optional index.
Definition: Device.h:70
void wait()
Wait on the future until it completes.
Definition: ivalue.h:575
void addCallback(std::function< void(void)> callback)
Add a callback to the future.
Definition: ivalue.h:645
static intrusive_ptr reclaim(TTarget *owning_ptr)
Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes over ownership.
DeviceType type() const noexcept
Returns the type of device this is.
Definition: Device.h:65