3 #include <ATen/core/TensorMethods.h> 4 #include <ATen/core/Type.h> 5 #include <ATen/core/functional.h> 6 #include <ATen/core/interned_strings.h> 7 #include <ATen/core/ivalue.h> 8 #include <c10/util/TypeList.h> 9 #include <caffe2/core/common.h> 11 #include <c10/util/Optional.h> 15 #include <type_traits> 28 #define C10_FORALL_TYPES(_) \ 30 _(DimensionedTensorType) \ 31 _(CompleteTensorType) \ 32 _(AutogradZeroTensorType) \ 50 #define DEFINE_TYPE(T) T, 51 C10_FORALL_TYPES(DEFINE_TYPE)
55 CAFFE2_API
const char * typeKindToString(TypeKind kind);
57 #define DEFINE_IS_SUBCLASS(_kind) \ 58 bool isSubclass(const TypeKind kind) const override { \ 59 return kind == TypeKind::_kind; \ 63 using TypePtr = std::shared_ptr<Type>;
65 struct CAFFE2_API
Type : std::enable_shared_from_this<Type> {
69 static std::shared_ptr<T> sliceType(std::shared_ptr<const T> ptr) {
70 auto result = std::make_shared<typename std::remove_const<T>::type>(*ptr);
74 result->kind_ = T::Kind;
83 virtual bool operator==(
const Type& rhs)
const = 0;
87 virtual bool isSubtypeOf(
const TypePtr rhs)
const;
91 virtual bool isSubclass(
const TypeKind kind)
const = 0;
94 virtual std::string str()
const = 0;
98 virtual std::string python_str()
const {
102 TypeKind kind()
const {
115 std::shared_ptr<T> cast() {
116 std::shared_ptr<T> r =
nullptr;
117 if (isSubclass(T::Kind)) {
118 r = std::static_pointer_cast<
T>(shared_from_this());
120 if (!r || T::Kind == kind()) {
123 return sliceType<T>(r);
127 std::shared_ptr<const T> cast()
const {
128 std::shared_ptr<const T> r =
nullptr;
129 if (isSubclass(T::Kind)) {
130 r = std::static_pointer_cast<
const T>(shared_from_this());
132 if (!r || T::Kind == kind()) {
135 return sliceType<T>(r);
139 std::shared_ptr<T> expect() {
145 std::shared_ptr<const T> expect()
const {
146 auto r = cast<const T>();
150 virtual ~
Type() =
default;
151 virtual bool hasFreeVariables()
const {
161 TypePtr withContained(std::vector<TypePtr> contained_types) {
162 auto current_contained = containedTypes();
163 AT_ASSERT(current_contained.size() == contained_types.size());
164 if(current_contained.equals(contained_types)) {
165 return shared_from_this();
167 return createWithContained(std::move(contained_types));
171 virtual TypePtr createWithContained(std::vector<TypePtr> contained_types)
const {
172 AT_ERROR(
"type with contained types did not overload createWithContained: ", str());
176 inline bool operator!=(
const Type & lhs,
const Type & rhs) {
177 return !(lhs == rhs);
182 template<TypeKind K,
typename T>
184 static const TypeKind Kind = K;
185 TypePtr getElementType()
const {
188 bool hasFreeVariables()
const override {
189 return has_free_variables_;
197 bool operator==(
const Type& rhs)
const override {
198 if(
auto rhs_ = rhs.cast<
T>()) {
199 return *getElementType() == *rhs_->getElementType();
206 , elem(std::move(elem))
207 , has_free_variables_(getElementType()->hasFreeVariables()) {}
210 bool has_free_variables_;
215 using OptionalTypePtr = std::shared_ptr<OptionalType>;
227 static OptionalTypePtr create(TypePtr element) {
228 return OptionalTypePtr(
new OptionalType(std::move(element)));
231 bool isSubtypeOf(
const TypePtr rhs)
const override {
233 return getElementType()->isSubtypeOf(rhs_->getElementType());
238 std::string str()
const override {
239 std::stringstream ss;
240 ss << getElementType()->str() <<
"?";
243 std::string python_str()
const override {
244 std::stringstream ss;
245 ss <<
"Optional[" << getElementType()->python_str() <<
"]";
249 TypePtr createWithContained(std::vector<TypePtr> contained_types)
const override {
250 AT_ASSERT(contained_types.size() == 1);
251 return create(contained_types[0]);
255 static OptionalTypePtr ofTensor();
261 using TensorTypePtr = std::shared_ptr<TensorType>;
267 static TensorTypePtr create() {
274 bool operator==(
const Type& rhs)
const override {
275 return rhs.kind() == kind();
277 std::string str()
const override {
280 static const TypeKind Kind = TypeKind::TensorType;
282 static TensorTypePtr
get();
284 TensorType(TypeKind kind=TypeKind::TensorType)
289 using AutogradZeroTensorTypePtr = std::shared_ptr<AutogradZeroTensorType>;
292 static AutogradZeroTensorTypePtr create() {
300 bool operator==(
const Type& rhs)
const override {
301 return rhs.kind() == kind();
303 bool isSubtypeOf(
const TypePtr rhs)
const override {
304 return rhs->kind() == TypeKind::TensorType ||
305 rhs->kind() == TypeKind::AutogradZeroTensorType ||
306 TensorType::isSubtypeOf(rhs);
308 std::string str()
const override {
309 return "UndefinedTensor";
312 static const TypeKind Kind = TypeKind::AutogradZeroTensorType;
314 static AutogradZeroTensorTypePtr
get();
320 using DimensionedTensorTypePtr = std::shared_ptr<DimensionedTensorType>;
323 template<
typename ...
T>
324 static DimensionedTensorTypePtr create(
T&& ... all ) {
328 at::ScalarType scalarType()
const {
return scalar_type_; }
330 int64_t dim()
const {
return dim_; }
331 bool requires_grad()
const override {
return requires_grad_; }
333 DimensionedTensorTypePtr toScalarType(at::ScalarType type){
334 auto t = DimensionedTensorType::create(*
this);
335 t->scalar_type_ = type;
338 DimensionedTensorTypePtr withDim(
size_t new_dim) {
339 auto t = DimensionedTensorType::create(*
this);
343 DimensionedTensorTypePtr withRequiresGrad(
bool req) {
344 auto t = DimensionedTensorType::create(*
this);
345 t->requires_grad_ = req;
349 bool operator==(
const Type& rhs)
const override {
350 if (rhs.kind() != TypeKind::DimensionedTensorType)
353 return scalarType() == rt->scalarType() &&
354 device() == rt->device() &&
357 bool isSubtypeOf(
const TypePtr rhs)
const override {
358 return rhs->kind() == TypeKind::TensorType ||
359 (rhs->kind() == TypeKind::DimensionedTensorType && Type::isSubtypeOf(rhs)) ||
360 TensorType::isSubtypeOf(rhs);
362 bool isSubclass(
const TypeKind kind)
const override {
363 return kind == TypeKind::TensorType ||
364 kind == TypeKind::DimensionedTensorType;
366 std::string str()
const override {
372 static const TypeKind Kind = TypeKind::DimensionedTensorType;
383 , scalar_type_(scalar_type)
384 , requires_grad_(at::isFloatingType(scalar_type) &&
requires_grad)
388 at::ScalarType scalar_type_;
395 using CompleteTensorTypePtr = std::shared_ptr<CompleteTensorType>;
398 template<
typename ...
T>
399 static CompleteTensorTypePtr create(
T&& ... all ) {
408 return CompleteTensorTypePtr(
new CompleteTensorType(scalar_type, device, sizes, strides));
411 const std::vector<int64_t>& sizes()
const {
return sizes_; }
412 const std::vector<int64_t>& strides()
const {
return strides_; }
415 return CompleteTensorType::create(scalar_type_, device_, sizes, strides);
419 return withSizesStrides(sizes, CompleteTensorType::contiguousStridesOf(sizes));
422 CompleteTensorTypePtr contiguous()
const {
423 auto t = CompleteTensorType::create(*
this);
424 t->strides_ = CompleteTensorType::contiguousStridesOf(sizes_);
428 CompleteTensorTypePtr toScalarType(at::ScalarType type){
429 auto t = CompleteTensorType::create(*
this);
430 t->scalar_type_ = type;
434 bool operator==(
const Type& rhs)
const override {
435 if(rhs.kind() != kind())
438 return scalarType() == rt->scalarType() &&
439 sizes() == rt->sizes() &&
440 strides() == rt->strides() &&
441 device() == rt->device();
443 bool isSubtypeOf(
const TypePtr rhs)
const override {
444 if (rhs->kind() == TypeKind::DimensionedTensorType)
445 return *expect<DimensionedTensorType>() == *rhs;
446 return rhs->kind() == TypeKind::TensorType ||
447 TensorType::isSubtypeOf(rhs);
449 bool isSubclass(
const TypeKind kind)
const override {
450 return kind == TypeKind::TensorType ||
451 kind == TypeKind::DimensionedTensorType ||
452 kind == TypeKind::CompleteTensorType;
454 std::string str()
const override {
461 for(
auto s : sizes()) {
467 static const TypeKind Kind = TypeKind::CompleteTensorType;
469 static TypePtr fromNumberType(TypePtr typ);
470 static TypePtr fromBoolType();
475 , sizes_(tensor.sizes().vec())
476 , strides_(tensor.strides().vec()) {}
481 , sizes_(sizes.vec())
482 , strides_(strides.vec()) {}
484 static std::vector<int64_t> contiguousStridesOf(
at::IntArrayRef sizes) {
485 std::vector<int64_t> strides(sizes.
size());
489 for(
size_t i = strides.
size() - 1; i > 0; i--) {
490 strides[i-1] = strides[i] * sizes[i];
495 std::vector<int64_t> sizes_;
496 std::vector<int64_t> strides_;
500 using ListTypePtr = std::shared_ptr<ListType>;
505 template<
typename ...
T>
506 static ListTypePtr create(
T&& ... all ) {
507 return ListTypePtr(
new ListType( std::forward<T>(all)... ));
510 std::string str()
const override {
511 std::stringstream ss;
512 ss << getElementType()->str() <<
"[]";
515 std::string python_str()
const override {
516 std::stringstream ss;
517 ss <<
"List[" << getElementType()->python_str() <<
"]";
520 TypePtr createWithContained(std::vector<TypePtr> contained_types)
const override {
521 return create(contained_types.at(0));
524 static ListTypePtr ofTensors();
525 static ListTypePtr ofInts();
526 static ListTypePtr ofFloats();
527 static ListTypePtr ofBools();
533 using DictTypePtr = std::shared_ptr<DictType>;
536 static const TypeKind Kind = TypeKind::DictType;
538 static DictTypePtr create(TypePtr key, TypePtr value) {
539 switch (key->kind()) {
540 case TypeKind::IntType:
541 case TypeKind::FloatType:
542 case TypeKind::StringType:
543 return DictTypePtr(
new DictType(key, value));
546 "Cannot create dict for key type '",
548 "', only int, float, and string keys are supported");
552 std::string str()
const override {
556 std::string python_str()
const override {
557 std::stringstream ss;
558 ss <<
"Dict[" << getKeyType()->python_str() <<
", " 559 << getValueType()->python_str() <<
"]";
563 TypePtr createWithContained(
564 std::vector<TypePtr> contained_types)
const override {
565 if (contained_types.size() != 2) {
566 throw std::runtime_error(
"Expected 2 contained types");
568 return create(contained_types.at(0), contained_types.at(1));
571 TypePtr getKeyType()
const {
575 TypePtr getValueType()
const {
580 bool isSubtypeOf(
const TypePtr rhs)
const override {
581 if (
auto dict_rhs = rhs->cast<
DictType>()) {
582 return getKeyType()->isSubtypeOf(dict_rhs->getKeyType()) &&
583 getValueType()->isSubtypeOf(dict_rhs->getValueType());
588 bool hasFreeVariables()
const override {
589 return has_free_variables;
600 bool operator==(
const Type& rhs)
const override {
601 if (
auto dict_rhs = rhs.cast<
DictType>()) {
602 return *getKeyType() == *(dict_rhs->getKeyType()) &&
603 *getValueType() == *(dict_rhs->getValueType());
609 DictType(TypePtr key, TypePtr value)
610 :
Type(TypeKind::DictType),
613 key->hasFreeVariables() || value->hasFreeVariables()) {}
614 std::vector<TypePtr> types;
615 bool has_free_variables;
619 using FutureTypePtr = std::shared_ptr<FutureType>;
623 template<
typename ...
T>
624 static FutureTypePtr create(TypePtr elem) {
625 return FutureTypePtr(
new FutureType(std::move(elem)));
630 std::string str()
const override {
631 std::stringstream ss;
632 ss <<
"Future(" << getElementType()->str() <<
")";
635 std::string python_str()
const override {
636 std::stringstream ss;
637 ss <<
"Future[" << getElementType()->python_str() <<
"]";
640 TypePtr createWithContained(std::vector<TypePtr> contained_types)
const override {
641 return create(contained_types.at(0));
648 using TupleTypePtr = std::shared_ptr<TupleType>;
652 static TupleTypePtr create(std::vector<TypePtr> types,
OptNameList names=c10::nullopt) {
653 return TupleTypePtr(
new TupleType(std::move(types), std::move(names)));
659 bool operator==(
const Type& rhs)
const override {
660 return compare(rhs, [](
const TypePtr a,
const TypePtr b) {
662 }) && names_ == rhs.expect<
TupleType>()->names_;
666 bool isSubtypeOf(
const TypePtr rhs_)
const override {
667 if (Type::isSubtypeOf(rhs_))
673 if (!hasNames() && rhs->hasNames())
676 bool names_match = !rhs->hasNames() || names() == rhs->names();
678 return names_match && compare(*rhs, [](
const TypePtr a,
const TypePtr b) {
679 return a->isSubtypeOf(b);
683 return std::any_of(elements_.begin(), elements_.end(),
684 [](
const TypePtr& ptr) {
return ptr->requires_grad(); });
686 std::string str()
const override {
687 std::stringstream ss;
689 for(
size_t i = 0; i < elements().size(); ++i) {
692 ss << elements()[i]->str();
697 std::string python_str()
const override {
698 std::stringstream ss;
700 for(
size_t i = 0; i < elements().size(); ++i) {
703 ss << elements()[i]->python_str();
708 bool hasFreeVariables()
const override {
709 return has_free_variables_;
711 bool hasNames()
const {
712 return names_.has_value();
714 const std::vector<std::string> &names()
const {
715 return names_.value();
721 TypePtr createWithContained(std::vector<TypePtr> contained_types)
const override {
722 return create(std::move(contained_types));
725 static const TypeKind Kind = TypeKind::TupleType;
728 :
Type(TypeKind::TupleType)
729 , elements_(std::move(elements_))
730 , names_(std::move(names)) {
731 has_free_variables_ =
732 std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
733 return v->hasFreeVariables();
737 bool compare(
const Type& rhs, std::function<
bool(
const TypePtr,
const TypePtr)> fn)
const {
738 if(rhs.kind() != kind())
740 const auto & l_elements = elements();
741 const auto & r_elements = rhs.cast<
TupleType>()->elements();
742 if(l_elements.size() != r_elements.size())
744 for(
size_t i = 0; i < l_elements.size(); ++i) {
745 if(!fn(l_elements[i], r_elements[i]))
751 std::vector<TypePtr> elements_;
752 bool has_free_variables_;
757 using NumberTypePtr = std::shared_ptr<NumberType>;
763 static NumberTypePtr create() {
767 bool operator==(
const Type& rhs)
const override {
768 return rhs.kind() == kind();
770 std::string str()
const override {
773 std::string python_str()
const override {
778 static const TypeKind Kind = TypeKind::NumberType;
780 static NumberTypePtr
get();
782 NumberType(TypeKind kind=TypeKind::NumberType)
787 using FloatTypePtr = std::shared_ptr<FloatType>;
790 static FloatTypePtr create() {
794 bool operator==(
const Type& rhs)
const override {
795 return rhs.kind() == kind();
797 std::string str()
const override {
800 std::string python_str()
const override {
803 bool isSubtypeOf(
const TypePtr rhs)
const override {
804 return rhs->kind() == TypeKind::NumberType ||
805 NumberType::isSubtypeOf(rhs);
807 static const TypeKind Kind = TypeKind::FloatType;
809 static FloatTypePtr
get();
816 using IntTypePtr = std::shared_ptr<IntType>;
819 static IntTypePtr create() {
820 return IntTypePtr(
new IntType());
823 bool operator==(
const Type& rhs)
const override {
824 return rhs.kind() == kind();
826 std::string str()
const override {
829 std::string python_str()
const override {
832 bool isSubtypeOf(
const TypePtr rhs)
const override {
833 return rhs->kind() == TypeKind::NumberType ||
834 NumberType::isSubtypeOf(rhs);
836 static const TypeKind Kind = TypeKind::IntType;
838 static IntTypePtr
get();
845 using BoolTypePtr = std::shared_ptr<BoolType>;
848 static BoolTypePtr create( ) {
852 bool operator==(
const Type& rhs)
const override {
853 return rhs.kind() == kind();
855 std::string str()
const override {
858 static const TypeKind Kind = TypeKind::BoolType;
860 static BoolTypePtr
get();
863 :
Type(TypeKind::BoolType) {}
867 using StringTypePtr = std::shared_ptr<StringType>;
870 static StringTypePtr create() {
874 bool operator==(
const Type& rhs)
const override {
875 return rhs.kind() == kind();
877 std::string str()
const override {
880 std::string python_str()
const override {
883 static const TypeKind Kind = TypeKind::StringType;
885 static StringTypePtr
get();
888 :
Type(TypeKind::StringType) {}
892 using NoneTypePtr = std::shared_ptr<NoneType>;
895 static NoneTypePtr create() {
899 bool operator==(
const Type& rhs)
const override {
900 return rhs.kind() == kind();
902 bool isSubtypeOf(
const TypePtr rhs)
const override {
903 return rhs->kind() == TypeKind::NoneType;
905 std::string str()
const override {
908 static const TypeKind Kind = TypeKind::NoneType;
910 static NoneTypePtr
get();
913 :
Type(TypeKind::NoneType) {}
917 using GeneratorTypePtr = std::shared_ptr<GeneratorType>;
920 static GeneratorTypePtr create() {
924 bool operator==(
const Type& rhs)
const override {
925 return rhs.kind() == kind();
927 std::string str()
const override {
930 static const TypeKind Kind = TypeKind::GeneratorType;
932 static GeneratorTypePtr
get();
935 :
Type(TypeKind::GeneratorType) {}
939 using DeviceObjTypePtr = std::shared_ptr<DeviceObjType>;
942 static DeviceObjTypePtr create() {
946 bool operator==(
const Type& rhs)
const override {
947 return rhs.kind() == kind();
949 std::string str()
const override {
952 static const TypeKind Kind = TypeKind::DeviceObjType;
954 static DeviceObjTypePtr
get();
957 :
Type(TypeKind::DeviceObjType) {}
962 using VarTypePtr = std::shared_ptr<VarType>;
965 static VarTypePtr create(std::string name_) {
966 return VarTypePtr(
new VarType(std::move(name_)));
969 bool operator==(
const Type& rhs)
const override {
970 return rhs.kind() == kind();
972 std::string str()
const override {
975 const std::string& name()
const {
978 bool hasFreeVariables()
const override {
981 static const TypeKind Kind = TypeKind::VarType;
984 :
Type(TypeKind::VarType), name_(std::move(name_)) {}
988 CAFFE2_API std::ostream& operator<<(std::ostream & out,
const Type & t);
992 inline TypePtr unshapedType(
const TypePtr& type) {
993 if (type->kind() == TypeKind::DimensionedTensorType ||
994 type->kind() == TypeKind::CompleteTensorType) {
995 return TensorType::get();
997 return type->withContained(fmap(type->containedTypes(), unshapedType));
1000 inline TypePtr CompleteTensorType::fromNumberType(TypePtr typ) {
1001 if (typ->isSubtypeOf(IntType::get())) {
1002 return CompleteTensorType::create(at::kLong, at::kCPU, {});
1003 }
else if (typ->isSubtypeOf(FloatType::get())) {
1004 return CompleteTensorType::create(at::kFloat, at::kCPU, {});
1005 }
else if (typ->isSubtypeOf(BoolType::get())) {
1006 return CompleteTensorType::create(at::kLong, at::kCPU, {});
1008 AT_ERROR(
"unknown number type", typ->str());
1011 inline TypePtr CompleteTensorType::fromBoolType() {
1012 return CompleteTensorType::create(at::kLong, at::kCPU, {});
1031 static TypePtr call() {
return TensorType::get(); }
1034 static TypePtr call() {
return FloatType::get(); }
1037 static TypePtr call() {
return IntType::get(); }
1040 static TypePtr call() {
return BoolType::get(); }
1043 static TypePtr call() {
return NumberType::get(); }
1046 static TypePtr call() {
return StringType::get(); }
1049 static TypePtr call() {
1055 static TypePtr call() {
1060 template <
class K,
class V>
1062 static TypePtr call() {
1070 static TypePtr call() {
1076 template<
class T>
inline TypePtr getTypePtr() {
1082 CAFFE2_API TypePtr incompleteInferTypeFrom(
const IValue& value);
1083 CAFFE2_API TypePtr attemptToRecoverType(
const IValue& input_ivalue);
1084 CAFFE2_API
bool isSubvalueOf(
const IValue& input_ivalue, TypePtr type);
1086 using TypeEnv = std::unordered_map<std::string, TypePtr>;
1093 matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env);
1095 CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env);
1102 using ClassTypePtr = std::shared_ptr<ClassType>;
1103 using ::torch::jit::script::Module;
1104 using ::torch::jit::script::Method;
1109 static ClassTypePtr create(
1110 const std::string& name,
1111 std::shared_ptr<Module> module);
1113 static ClassTypePtr
get(
const std::string& name);
1115 static void clearRegistry();
1118 bool operator==(
const Type& rhs)
const override {
1119 if (
auto user_rhs = rhs.cast<
ClassType>()) {
1120 return typename_ == user_rhs->typename_;
1125 bool isSubtypeOf(
const TypePtr rhs)
const override {
1128 return *
this == *rhs;
1130 std::string str()
const override {
1131 return std::string(
"ClassType<") + typename_ +
">";
1134 std::string python_str()
const override {
1138 TypePtr getAttribute(
const std::string& name)
const {
1139 AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
1141 for (
const auto& attr : attributeNames_) {
1148 if (pos >= attributeNames_.size()) {
1151 return attributeTypes_[pos];
1154 Method* getMethod(
const std::string& name)
const;
1155 std::vector<Method*> methods()
const;
1157 std::string name()
const {
1161 size_t numAttributes()
const {
1162 AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
1163 return attributeNames_.size();
1169 size_t getAttributeSlot(
const std::string& name)
const {
1170 AT_ASSERT(attributeNames_.size() == attributeTypes_.size());
1172 for (
const auto& attr : attributeNames_) {
1178 throw std::runtime_error(
"Couldn't find attribute: " + name);
1181 bool hasAttribute(
const std::string& name)
const {
1182 return std::find_if(
1183 attributeNames_.cbegin(),
1184 attributeNames_.cend(),
1185 [&](
const std::string& attr) {
return attr == name; }) !=
1186 attributeNames_.cend();
1189 void addAttribute(
const std::string& name, TypePtr type) {
1190 attributeNames_.push_back(name);
1191 attributeTypes_.push_back(type);
1195 return attributeTypes_;
1198 static const TypeKind Kind = TypeKind::ClassType;
1201 ClassType(std::string name, std::shared_ptr<Module> module)
1202 :
Type(TypeKind::ClassType),
1203 typename_(std::move(name)),
1204 module_(std::move(module)) {}
1207 std::string typename_;
1215 std::vector<std::string> attributeNames_;
1216 std::vector<TypePtr> attributeTypes_;
1218 std::shared_ptr<Module> module_;
AT_CPP14_CONSTEXPR const T & back() const
back - Get the last element.
Scalar represents a 0-dimensional tensor which contains a single element.
Represents a a compute device on which a tensor is located.
constexpr size_t size() const
size - Get the array size.
bool is_variable() const noexcept
Returns true if the Tensor is actually a torch::autograd::Variable.
Device device() const
Returns a Tensor's device.
constexpr bool empty() const
empty - Check if the array is empty.
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
TensorOptions requires_grad(bool requires_grad=true)
Convenience function that returns a TensorOptions object with the requires_grad set to the given one...
Flush-To-Zero and Denormals-Are-Zero mode.
C10_NODISCARD TensorOptions requires_grad(c10::optional< bool > requires_grad) const noexcept
Sets the requires_grad property of the TensorOptions.