3 #include <condition_variable> 6 #include <ATen/core/blob.h> 7 #include <ATen/core/interned_strings.h> 8 #include <c10/core/Scalar.h> 9 #include <c10/core/TensorImpl.h> 10 #include <c10/core/UndefinedTensorImpl.h> 11 #include <c10/util/intrusive_ptr.h> 13 #include <ATen/core/Tensor.h> 26 const std::string str_;
29 : str_(std::move(str)) {}
31 const std::string & string()
const {
34 operator const std::string & ()
const {
37 CAFFE2_API
friend std::ostream& operator<<(
42 template <
typename Elem>
45 std::vector<Elem> elements_;
48 typedef Elem ElemType;
50 List(std::vector<Elem> elements_) : elements_(std::move(elements_)) {}
52 return c10::make_intrusive<List<Elem>>(std::move(elements_));
54 const std::vector<Elem>& elements()
const & {
57 operator const std::vector<Elem>&()
const {
61 std::vector<Elem>& elements() & {
64 operator std::vector<Elem>&() {
68 std::vector<Elem>&& elements() && {
69 return std::move(elements_);
78 bool operator()(
const IValue& lhs,
const IValue& rhs)
const;
81 using UnorderedMap = std::unordered_map<IValue, IValue, DictHash, DictEqualTo>;
89 return c10::make_intrusive<Tuple>(std::move(elements_));
108 #define TORCH_FORALL_TAGS(_) \ 131 , is_intrusive_ptr(
false) {}
132 IValue(
const IValue& rhs)
133 : payload(rhs.payload),
135 is_intrusive_ptr(rhs.is_intrusive_ptr) {
136 if (is_intrusive_ptr) {
137 c10::raw::intrusive_ptr::incref(payload.as_intrusive_ptr);
140 IValue(IValue&& rhs) noexcept : IValue() {
144 if (is_intrusive_ptr) {
145 c10::raw::intrusive_ptr::decref(payload.as_intrusive_ptr);
148 IValue & operator=(IValue && rhs) & noexcept {
149 IValue(std::move(rhs)).swap(*
this);
152 IValue & operator=(IValue
const & rhs) & {
153 IValue(rhs).swap(*
this);
159 bool isAliasOf(
const IValue& rhs)
const {
160 if (this->tag != rhs.tag) {
165 if (!this->is_intrusive_ptr) {
170 AT_ASSERT(rhs.is_intrusive_ptr);
173 if (this->isTensor()) {
174 const auto thisTensor = this->toTensor();
175 const auto rhsTensor = rhs.toTensor();
176 return thisTensor.is_alias_of(rhsTensor);
180 return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
182 void swap(IValue & rhs) noexcept {
183 std::swap(payload, rhs.payload);
184 std::swap(is_intrusive_ptr, rhs.is_intrusive_ptr);
185 std::swap(tag, rhs.tag);
194 : tag(Tag::Tensor), is_intrusive_ptr(t.defined()) {
200 payload.as_intrusive_ptr = t.unsafeReleaseTensorImpl();
202 bool isTensor()
const {
return Tag::Tensor == tag; }
204 AT_ASSERT(isTensor());
205 return at::Tensor(moveToIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
208 AT_ASSERT(isTensor());
209 return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
212 const IValue& toIValue()
const {
220 : tag(Tag::Blob), is_intrusive_ptr(
true) {
223 payload.as_intrusive_ptr = blob.
release();
225 bool isBlob()
const {
226 return Tag::Blob == tag;
230 return moveToIntrusivePtr<caffe2::Blob>();
234 return toIntrusivePtr<caffe2::Blob>();;
239 bool isTuple()
const {
return Tag::Tuple == tag; }
241 AT_ASSERT(isTuple());
242 return moveToIntrusivePtr<ivalue::Tuple>();
245 AT_ASSERT(isTuple());
246 return toIntrusivePtr<ivalue::Tuple>();
251 : tag(Tag::Double), is_intrusive_ptr(
false) {
252 payload.as_double = d;
254 bool isDouble()
const {
return Tag::Double == tag; }
255 double toDouble()
const {
256 AT_ASSERT(isDouble());
257 return payload.as_double;
262 bool isFuture()
const {
return Tag::Future == tag; }
264 AT_ASSERT(isFuture());
265 return moveToIntrusivePtr<ivalue::Future>();
268 AT_ASSERT(isFuture());
269 return toIntrusivePtr<ivalue::Future>();
274 : tag(Tag::Int), is_intrusive_ptr(
false) {
280 : IValue(static_cast<int64_t>(i)) {}
282 bool isInt()
const {
return Tag::Int == tag; }
284 int64_t toInt()
const {
286 return payload.as_int;
291 : tag(Tag::Bool), is_intrusive_ptr(
false) {
294 bool isBool()
const {
return Tag::Bool == tag; }
295 bool toBool()
const {
297 return payload.as_bool;
302 IValue(std::vector<int64_t> v);
305 bool isIntList()
const {
return Tag::IntList == tag; }
307 AT_ASSERT(isIntList());
308 return moveToIntrusivePtr<ivalue::IntList>();
311 AT_ASSERT(isIntList());
312 return toIntrusivePtr<ivalue::IntList>();
315 const std::vector<int64_t>& toIntListRef()
const;
316 const std::vector<double>& toDoubleListRef()
const;
317 const std::vector<bool>& toBoolListRef()
const;
318 const std::vector<at::Tensor>& toTensorListRef()
const;
319 const std::vector<IValue>& toGenericListRef()
const;
320 const ivalue::UnorderedMap& toGenericDictRef()
const;
321 const std::string& toStringRef()
const;
325 IValue(std::string v);
326 bool isString()
const {
return Tag::String == tag; }
328 AT_ASSERT(isString());
329 return moveToIntrusivePtr<ivalue::ConstantString>();
332 AT_ASSERT(isString());
333 return toIntrusivePtr<ivalue::ConstantString>();
338 IValue(std::vector<double> v);
339 bool isDoubleList()
const {
return Tag::DoubleList == tag; }
341 AT_ASSERT(isDoubleList());
342 return moveToIntrusivePtr<ivalue::DoubleList>();
345 AT_ASSERT(isDoubleList());
346 return toIntrusivePtr<ivalue::DoubleList>();
351 IValue(std::vector<bool> v);
352 bool isBoolList()
const {
return Tag::BoolList == tag; }
354 AT_ASSERT(isBoolList());
355 return moveToIntrusivePtr<ivalue::BoolList>();
358 AT_ASSERT(isBoolList());
359 return toIntrusivePtr<ivalue::BoolList>();
364 IValue(std::vector<at::Tensor> v);
365 bool isTensorList()
const {
return Tag::TensorList == tag; }
367 AT_ASSERT(isTensorList());
368 return moveToIntrusivePtr<ivalue::TensorList>();
371 AT_ASSERT(isTensorList());
372 return toIntrusivePtr<ivalue::TensorList>();
377 IValue(std::vector<IValue> v);
378 bool isGenericList()
const {
return Tag::GenericList == tag; }
380 AT_ASSERT(isGenericList());
381 return moveToIntrusivePtr<ivalue::GenericList>();
384 AT_ASSERT(isGenericList());
385 return toIntrusivePtr<ivalue::GenericList>();
390 IValue(ivalue::UnorderedMap v);
391 bool isGenericDict()
const {
return Tag::GenericDict == tag; }
393 AT_ASSERT(isGenericDict());
394 return moveToIntrusivePtr<ivalue::GenericDict>();
397 AT_ASSERT(isGenericDict());
398 return toIntrusivePtr<ivalue::GenericDict>();
403 bool isObject()
const {
return tag == Tag::Object; }
405 AT_ASSERT(isObject());
406 return toIntrusivePtr<ivalue::Object>();
409 AT_ASSERT(isObject());
410 return toIntrusivePtr<ivalue::Object>();
414 bool isNone()
const {
415 return Tag::None == tag;
417 std::string toNone()
const {
424 if(s.isFloatingPoint()) {
425 *
this = s.toDouble();
430 bool isScalar()
const {
431 return isDouble() || isInt();
438 throw std::runtime_error(
"IValue is not a Scalar");
443 : tag(Tag::Device), is_intrusive_ptr(
false) {
444 payload.as_device.type = d.
type();
445 payload.as_device.index = d.
index();
447 bool isDevice()
const {
return Tag::Device == tag; }
449 AT_ASSERT(isDevice());
450 return c10::Device(payload.as_device.type, payload.as_device.index);
454 at::ScalarType toScalarType()
const {
455 return static_cast<at::ScalarType
>(toInt());
459 at::Layout toLayout()
const {
460 return static_cast<at::Layout
>(toInt());
464 std::string tagKind()
const {
466 #define DEFINE_CASE(x) case Tag::x: return #x; 467 TORCH_FORALL_TAGS(DEFINE_CASE)
470 return "Invalid Tag";
492 bool isSameIdentity(IValue& rhs);
494 CAFFE2_API
friend std::ostream& operator<<(
498 bool isPtrType()
const {
499 return is_intrusive_ptr;
509 enum class Tag : uint32_t {
510 #define DEFINE_TAG(x) x, 511 TORCH_FORALL_TAGS(DEFINE_TAG)
515 template<
class T,
class NullType = c10::detail::
intrusive_target_default_null_type<T>>
521 template<
typename T,
class NullType = c10::detail::
intrusive_target_default_null_type<T>>
531 is_intrusive_ptr =
false;
544 bool is_intrusive_ptr;
551 c10::raw::intrusive_ptr::incref(
this);
561 : error_msg(std::move(error_msg_)) {}
563 FutureError() =
default;
565 const char* what()
const noexcept
override {
566 return error_msg.c_str();
569 std::string error_msg;
579 std::condition_variable finished;
585 std::unique_lock<std::mutex> lock(mutex_);
586 finished.notify_all();
591 std::unique_lock<std::mutex> lock(mutex_);
596 AT_ASSERT(completed());
606 std::unique_lock<std::mutex> lock(mutex_);
607 AT_ASSERT(!completed());
609 value_ = std::move(value);
619 std::unique_lock<std::mutex> lock(mutex_);
620 AT_ASSERT(!completed());
623 error = std::move(error_);
631 std::unique_lock<std::mutex> lock(mutex_);
632 AT_ASSERT(completed());
646 std::unique_lock<std::mutex> lock(mutex_);
652 callbacks.push_back(callback);
660 CAFFE2_API
friend std::ostream& operator<<(
665 void fireCallbacks() {
666 AT_ASSERT(completed());
669 for (
auto& callback : callbacks) {
677 std::atomic_bool completed_ = {
false};
678 std::vector<std::function<void(void)>> callbacks;
679 bool has_error =
false;
686 Object(
Symbol name,
size_t numSlots) : typename_(std::move(name)) {
687 slots_.resize(numSlots);
693 return c10::make_intrusive<Object>(std::move(name), numSlots);
696 void setSlot(
size_t slot,
IValue v) {
700 IValue getSlot(
size_t slot)
const {
701 return slots_.at(slot);
710 std::vector<IValue> slots_;
715 UnorderedMap elements_;
718 GenericDict(UnorderedMap elements_)
719 : elements_(std::move(elements_)) {}
721 UnorderedMap elements_) {
722 return c10::make_intrusive<GenericDict>(std::move(elements_));
724 const UnorderedMap& elements()
const {
727 operator const UnorderedMap&()
const {
731 UnorderedMap& elements() {
734 operator UnorderedMap&() {
739 #undef TORCH_FORALL_TAGS 746 using _guarded_unsigned_long = c10::guts::conditional_t<
747 std::is_same<unsigned long, uint32_t>::value ||
748 std::is_same<unsigned long, uint64_t>::value,
754 #define DEFINE_TO(type, method_name) \ 756 inline type IValue::to<type>() && { \ 757 return std::move(*this).method_name(); \ 760 inline type IValue::to<type>() const & { \ 761 return this->method_name(); \ 765 DEFINE_TO(
float, toDouble)
766 DEFINE_TO(
double, toDouble)
767 DEFINE_TO(
unsigned char, toInt)
768 DEFINE_TO(
signed char, toInt)
769 DEFINE_TO(
unsigned short, toInt)
770 DEFINE_TO(
short, toInt)
771 DEFINE_TO(
int, toInt)
772 DEFINE_TO(uint32_t, toInt)
773 DEFINE_TO(uint64_t, toInt)
774 DEFINE_TO(detail::_guarded_unsigned_long, toInt)
775 DEFINE_TO(int64_t, toInt)
776 DEFINE_TO(
bool, toBool)
787 DEFINE_TO(std::vector<int64_t>, toIntListRef)
788 DEFINE_TO(std::vector<double>, toDoubleListRef)
789 DEFINE_TO(std::vector<bool>, toBoolListRef)
790 DEFINE_TO(std::vector<at::Tensor>, toTensorListRef)
791 DEFINE_TO(std::vector<IValue>, toGenericListRef)
792 DEFINE_TO(std::string, toStringRef)
794 DEFINE_TO(
IValue, toIValue)
796 DEFINE_TO(at::ScalarType, toScalarType)
797 DEFINE_TO(at::Layout, toLayout)
799 template <
typename T>
802 template <
typename Elem>
803 std::vector<Elem> generic_to(
806 return fmap(ivalue->toGenericListRef(), [](
IValue item_ivalue) {
return item_ivalue.to<Elem>(); });
809 template <
typename K,
typename V>
810 std::unordered_map<K, V> generic_to(
813 std::unordered_map<K, V> specialized_dict;
815 for (
auto item : ivalue->toGenericDictRef()) {
816 specialized_dict[item.first.to<K>()] = item.second.to<V>();
819 return specialized_dict;
822 template <
typename T>
823 inline T IValue::to() && {
827 template <
typename T>
828 inline T IValue::to()
const& {
837 : tag(Tag::Tuple), is_intrusive_ptr(
true) {
838 payload.as_intrusive_ptr = v.
release();
842 : tag(Tag::IntList), is_intrusive_ptr(
true) {
843 payload.as_intrusive_ptr = v.
release();
845 inline IValue::IValue(std::vector<int64_t> v)
846 :
IValue(ivalue::IntList::create(std::move(v))) {}
849 : tag(Tag::String), is_intrusive_ptr(
true) {
850 payload.as_intrusive_ptr = v.
release();
852 inline IValue::IValue(std::string v)
853 :
IValue(ivalue::ConstantString::create(std::move(v))) {}
856 : tag(Tag::DoubleList), is_intrusive_ptr(
true) {
857 payload.as_intrusive_ptr = v.
release();
859 inline IValue::IValue(std::vector<double> v)
860 :
IValue(ivalue::DoubleList::create(std::move(v))) {}
863 : tag(Tag::BoolList), is_intrusive_ptr(
true) {
864 payload.as_intrusive_ptr = v.
release();
866 inline IValue::IValue(std::vector<bool> v)
867 :
IValue(ivalue::BoolList::create(std::move(v))) {}
870 : tag(Tag::TensorList), is_intrusive_ptr(
true) {
871 payload.as_intrusive_ptr = v.
release();
873 inline IValue::IValue(std::vector<at::Tensor> v)
874 :
IValue(ivalue::TensorList::create(std::move(v))) {}
877 : tag(Tag::GenericList), is_intrusive_ptr(
true) {
878 payload.as_intrusive_ptr = v.
release();
880 inline IValue::IValue(std::vector<IValue> v)
881 :
IValue(ivalue::GenericList::create(std::move(v))) {}
884 : tag(Tag::GenericDict), is_intrusive_ptr(
true) {
885 payload.as_intrusive_ptr = v.
release();
887 inline IValue::IValue(ivalue::UnorderedMap v)
888 :
IValue(ivalue::GenericDict::create(std::move(v))) {}
891 : tag(Tag::Object), is_intrusive_ptr(
true) {
892 payload.as_intrusive_ptr = v.
release();
895 : tag(Tag::Future), is_intrusive_ptr(
true) {
896 payload.as_intrusive_ptr = v.
release();
899 inline const std::vector<int64_t>& IValue::toIntListRef()
const {
900 return toIntList()->elements();
903 inline const std::vector<double>& IValue::toDoubleListRef()
const {
904 return toDoubleList()->elements();
907 inline const std::vector<at::Tensor>& IValue::toTensorListRef()
const {
908 return toTensorList()->elements();
911 inline const std::vector<bool>& IValue::toBoolListRef()
const {
912 return toBoolList()->elements();
915 inline const std::vector<IValue>& IValue::toGenericListRef()
const {
916 return toGenericList()->elements();
919 inline const c10::ivalue::UnorderedMap& IValue::
920 toGenericDictRef()
const {
921 return toGenericDict()->elements();
924 inline const std::string& IValue::toStringRef()
const {
925 return toString()->string();
930 if (this->isNone()) {
933 return this->to<T>();
936 inline bool IValue::isSameIdentity(
IValue& rhs) {
945 if (this->isNone() && rhs.isNone()) {
947 }
else if (this->isBool() && rhs.isBool()) {
949 return this->toBool() == rhs.toBool();
950 }
else if (this->isTensor() && rhs.isTensor()) {
952 return this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
953 }
else if (this->isTensor() && rhs.isNone()) {
955 return !this->is_intrusive_ptr;
956 }
else if (this->isNone() && rhs.isTensor()) {
958 return !rhs.is_intrusive_ptr;
961 return this->is_intrusive_ptr && rhs.is_intrusive_ptr
962 && this->payload.as_intrusive_ptr == rhs.payload.as_intrusive_ptr;
967 inline size_t at::ivalue::DictHash::operator()(
969 if (ivalue.isInt()) {
970 return std::hash<int>()(ivalue.toInt());
971 }
else if (ivalue.isString()) {
972 return std::hash<std::string>()(ivalue.toStringRef());
973 }
else if (ivalue.isDouble()) {
974 return std::hash<double>()(ivalue.toDouble());
976 throw std::runtime_error(
"Can't hash IValues with this tag");
980 inline bool at::ivalue::DictEqualTo::operator()(
984 return lhs.toInt() == rhs.toInt();
985 }
else if (lhs.isString()) {
986 return lhs.toStringRef() == rhs.toStringRef();
987 }
else if (lhs.isDouble()) {
988 return lhs.toDouble() == rhs.toDouble();
990 throw std::runtime_error(
"Can't compare IValues with this tag");
void markCompleted(IValue value)
Explicitly mark the future as completed with the output value.
Scalar represents a 0-dimensional tensor which contains a single element.
Represents a compute device on which a tensor is located.
int16_t DeviceIndex
An index representing a specific device; e.g., the 1 in GPU 1.
intrusive_ptr<T> is an alternative to shared_ptr<T> that has better performance because it does the r...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
TTarget * release() noexcept
Returns an owning (!) pointer to the underlying object and makes the intrusive_ptr instance invalid...
DeviceIndex index() const noexcept
Returns the optional index.
void wait()
Wait on the future until it completes.
void addCallback(std::function< void(void)> callback)
Add a callback to the future.
static intrusive_ptr reclaim(TTarget *owning_ptr)
Takes an owning pointer to TTarget* and creates an intrusive_ptr that takes over ownership.
DeviceType type() const noexcept
Returns the type of device this is.