#ifndef CAFFE2_CORE_OPERATOR_H_
#define CAFFE2_CORE_OPERATOR_H_

#include "c10/macros/Macros.h"
#include "c10/util/Registry.h"
#include "caffe2/core/blob.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator_gradient.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"

#include <ATen/core/Tensor.h>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>

C10_DECLARE_bool(caffe2_operator_throw_if_fp_exceptions);

namespace caffe2 {
class CAFFE2_API OperatorBase;
typedef ObserverBase<OperatorBase> OperatorObserver;
class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
 public:
  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
  explicit OperatorBase(
      const c10::FunctionSchema& schema,
      std::vector<c10::IValue> inputs,
      std::vector<at::Tensor> outputs);

  virtual ~OperatorBase() noexcept {}

  // Return true if the operator was instantiated with OperatorDef.
  // New-style operators are instantiated with a FunctionSchema instead.
  bool isLegacyOperator() const {
    return !fn_schema_;
  }

  const c10::FunctionSchema& getFunctionSchema() const {
    CAFFE_ENFORCE(!isLegacyOperator());
    return *fn_schema_.get();
  }
  // Checks if the operator has an argument of the given name.
  bool HasArgument(const string& name) const {
    if (isLegacyOperator()) {
      CAFFE_ENFORCE(operator_def_, "operator_def was null!");
      return ArgumentHelper::HasArgument(*operator_def_, name);
    }
    return getFunctionSchema().argumentIndexWithName(name).has_value();
  }
  // Functions that deal with arguments. Basically, this allows us to map an
  // argument name to a specific type of argument that we are trying to access.
  template <typename T>
  inline T GetSingleArgument(const string& name, const T& default_value)
      const {
    if (isLegacyOperator()) {
      CAFFE_ENFORCE(operator_def_, "operator_def was null!");
      return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
          *operator_def_, name, default_value);
    }
    auto index = getFunctionSchema().argumentIndexWithName(name);
    CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
    const auto& value = newstyle_inputs_[index.value()];
    return value.template to<T>();
  }
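  // Illustrative sketch (not part of the original header): scalar arguments
  // are typically read once in the operator constructor. "MyOp" and its
  // "epsilon" argument below are hypothetical names for illustration only.
  //
  //   class MyOp final : public Operator<CPUContext> {
  //    public:
  //     template <class... Args>
  //     explicit MyOp(Args&&... args)
  //         : Operator<CPUContext>(std::forward<Args>(args)...),
  //           epsilon_(
  //               this->template GetSingleArgument<float>("epsilon", 1e-5f)) {}
  //
  //    private:
  //     float epsilon_;
  //   };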
  template <typename T>
  inline bool HasSingleArgumentOfType(const string& name) const {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
        *operator_def_, name);
  }
  template <typename T>
  inline vector<T> GetVectorFromIValueList(const c10::IValue& value) const {
    return value.template to<vector<T>>();
  }
  template <typename T>
  inline vector<T> GetRepeatedArgument(
      const string& name,
      const vector<T>& default_value = {}) const {
    if (isLegacyOperator()) {
      CAFFE_ENFORCE(operator_def_, "operator_def was null!");
      return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
          *operator_def_, name, default_value);
    }
    auto index = getFunctionSchema().argumentIndexWithName(name);
    CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
    const auto& value = newstyle_inputs_[index.value()];
    return GetVectorFromIValueList<T>(value);
  }
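  // Illustrative sketch (assumption, not from the original header): repeated
  // arguments come back as a vector; the "axes" argument name below is
  // hypothetical and used only for illustration.
  //
  //   std::vector<int> axes =
  //       this->template GetRepeatedArgument<int>("axes", {0, 1});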
  // Get the inputs and outputs as specific types.
  template <typename T>
  inline const T& Input(int idx) {
    static_assert(
        !std::is_same<T, Tensor>::value,
        "You should use Input<Tensor>(int, DeviceType) for "
        "Tensor.");
    DCHECK_LT(idx, inputs_.size());
    try {
      return inputs_.at(idx)->template Get<T>();
    } catch (::caffe2::EnforceNotMet& enf) {
      if (has_debug_def()) {
        enf.AppendMessage(".\nOffending Blob name: ");
        enf.AppendMessage(debug_def().input(idx));
        enf.AppendMessage(".\n");
      }
      throw enf;
    }
  }
  template <typename T>
  inline const T& Input(int idx, DeviceType type) {
    if (isLegacyOperator()) {
      static_assert(
          std::is_same<T, Tensor>::value,
          "Input(int, DeviceType) is only available for Tensor");
      DCHECK_LT(idx, inputs_.size());
      try {
        const auto& tensor = inputs_.at(idx)->template Get<T>();
        return tensor;
      } catch (::caffe2::EnforceNotMet& enf) {
        if (has_debug_def()) {
          enf.AppendMessage(".\nOffending Blob name: ");
          enf.AppendMessage(debug_def().input(idx));
          enf.AppendMessage(".\n");
        }
        throw enf;
      }
    }
    DCHECK_LT(0, newstyle_inputs_.size());
    IValue ival;
    if (newstyle_inputs_[0].isTensorList()) {
      // If the first input is a tensor list, we get input tensors by indexing
      // into that list; only tensors from that list are accessible as inputs.
      const auto& tensorList = newstyle_inputs_[0].toTensorListRef();
      DCHECK_LT(idx, tensorList.size());
      ival = tensorList[idx];
    } else {
      // Otherwise, we get input tensors by indexing into the inputs directly.
      DCHECK_LT(idx, newstyle_inputs_.size());
      ival = newstyle_inputs_[idx];
    }
    CAFFE_ENFORCE(
        ival.isTensor(),
        "Input(int, DeviceType) is only available for IValues that store Tensors");
    Tensor tensor = caffe2::Tensor(ival.toTensor());
    CAFFE_ENFORCE_EQ(tensor.GetDeviceType(), type);
    input_tensors_[idx] = std::move(tensor);
    return input_tensors_[idx];
  }
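  // Illustrative note (assumption): for a legacy operator this call reads a
  // Tensor out of the idx-th input Blob, e.g.
  //
  //   const Tensor& X = Input<Tensor>(0, CPU);
  //
  // while for a new-style (FunctionSchema) operator the same call unpacks the
  // Tensor from newstyle_inputs_, either directly or from a leading
  // TensorList IValue.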
  template <typename T>
  inline T* Output(int idx) {
    static_assert(
        !std::is_same<T, Tensor>::value,
        "You should use Output<Tensor>(int, DeviceType) for "
        "Tensor.");
    return outputs_.at(idx)->template GetMutable<T>();
  }
  template <typename T>
  inline T* Output(int idx, DeviceType type) {
    if (isLegacyOperator()) {
      static_assert(
          std::is_same<T, Tensor>::value,
          "Output(int, DeviceType) is only available for Tensor");
      // When you get a Tensor here it is not fully initialized.
      return BlobGetMutableTensor(outputs_.at(idx), type);
    }
    auto& output = newstyle_outputs_[idx];
    Tensor tensor = caffe2::Tensor(output);
    if (!tensor.defined() || tensor.GetDeviceType() != type) {
      // Fix tensor type
      tensor = Tensor(type);
      output = at::Tensor(std::move(tensor.getIntrusivePtr()));
    }
    output_tensors_[idx] = caffe2::Tensor(output);
    return &output_tensors_[idx];
  }
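  // Illustrative sketch (assumption): the returned Tensor* is typically
  // resized and filled inside RunOnDevice, e.g.
  //
  //   Tensor* Y = Output<Tensor>(0, CPU);
  //   Y->Resize(N, C);
  //   float* y_data = Y->template mutable_data<float>();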
  Tensor XOutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
    CAFFE_ENFORCE_WITH_CALLER(
        options.device_opt() != c10::nullopt,
        "device must be provided in option.");
    if (isLegacyOperator()) {
      return XBlobGetMutableTensor(outputs_.at(idx), dims, options);
    }
    return OutputTensor(idx, dims, options)->UnsafeSharedInstance();
  }
  void SetOutputTensor(int idx, Tensor tensor) {
    if (!isLegacyOperator()) {
      newstyle_outputs_[idx] = at::Tensor(tensor);
      // Also update the cached tensor.
      output_tensors_[idx] = std::move(tensor);
    } else {
      // Update the tensor in the workspace.
      BlobSetTensor(outputs_.at(idx), std::move(tensor));
    }
  }
  Tensor OutputTensorOrUndefined(int idx) {
    if (isLegacyOperator()) {
      return BlobGetTensorOrUndefined(*outputs_.at(idx));
    }
    return output_tensors_[idx].UnsafeSharedInstance();
  }
  Tensor* OutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
    if (isLegacyOperator()) {
      CAFFE_ENFORCE_WITH_CALLER(
          options.device_opt() != c10::nullopt,
          "device must be provided in options.");
      return BlobGetMutableTensor(outputs_.at(idx), dims, options);
    }
    auto& output = newstyle_outputs_[idx];
    Tensor tensor =
        GetSizedTensorWithOptions(caffe2::Tensor(output), dims, options);
    // Assign it back in case it changed.
    output = at::Tensor(std::move(tensor.getIntrusivePtr()));
    output_tensors_[idx] = caffe2::Tensor(output);
    return &output_tensors_[idx];
  }
  // Get the output Tensor of the operator and CopyFrom the given Tensor.
  Tensor* OutputTensorCopyFrom(
      int idx,
      at::TensorOptions options,
      const Tensor& src,
      bool async = false) {
    CAFFE_ENFORCE_WITH_CALLER(
        options.device_opt() != c10::nullopt,
        "device must be provided in options.");
    // Ideally we should always be able to infer the data type from src,
    // but some callers pass a different dtype in options.
    if (!options.has_dtype()) {
      options = options.dtype(src.dtype());
    }
    CAFFE_ENFORCE_WITH_CALLER(
        options.dtype() == src.dtype(),
        "We don't allow change of src data type in OutputTensorCopyFrom");
    Tensor* t = OutputTensor(idx, src.sizes(), options);
    t->CopyFrom(src, async);
    return t;
  }

  Tensor* OutputTensorAlias(int idx, const Tensor& src) {
    return BlobSetTensor(OutputBlob(idx), src.Alias());
  }
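  // Illustrative sketch (assumption): inside an Operator<Context> subclass,
  // copying an input through to an output versus aliasing it without a copy:
  //
  //   // copy (optionally async with respect to the host):
  //   OutputTensorCopyFrom(
  //       0, at::TensorOptions().device(CPU), Input(0), /* async */ true);
  //   // alias, no data copy:
  //   OutputTensorAlias(0, Input(0));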
  template <typename T>
  inline T* Output(int idx, T* allocated) {
    outputs_.at(idx)->Reset(allocated);
    return allocated;
  }
  inline const Blob& InputBlob(int idx) {
    return *inputs_.at(idx);
  }

  inline Blob* OutputBlob(int idx) {
    return outputs_.at(idx);
  }
  // Check whether output j is an alias of input i by comparing Blob pointers.
  inline bool IsInputOutputAlias(int i, int j) {
    return inputs_.at(i) == outputs_.at(j);
  }
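  // Illustrative note (assumption): in-place operators are wired up by giving
  // an input and an output the same blob name in the OperatorDef, which makes
  // this check true, e.g.
  //
  //   if (IsInputOutputAlias(0, 0)) { /* input 0 and output 0 share a Blob */ }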
  template <typename T>
  inline bool InputIsType(int idx) {
    static_assert(
        !std::is_same<T, Tensor>::value,
        "You should use InputIsTensorType(int, DeviceType) for "
        "Tensor.");
    return inputs_.at(idx)->template IsType<T>();
  }
  inline bool InputIsTensorType(int idx, DeviceType device_type) {
    return BlobIsTensorType(*inputs_.at(idx), device_type);
  }
  template <typename T>
  inline bool OutputIsType(int idx) {
    static_assert(
        !std::is_same<T, Tensor>::value,
        "You should use OutputIsTensorType(int, DeviceType) for "
        "Tensor.");
    return outputs_.at(idx)->template IsType<T>();
  }
  inline bool OutputIsTensorType(int idx, DeviceType type) {
    return BlobIsTensorType(*outputs_.at(idx), type);
  }
  inline int InputSize() const {
    return input_size_;
  }

  inline int OutputSize() const {
    if (isLegacyOperator()) {
      return outputs_.size();
    }
    return newstyle_outputs_.size();
  }
  inline const vector<const Blob*>& Inputs() const {
    return inputs_;
  }
  inline const vector<Blob*>& Outputs() {
    return outputs_;
  }
  vector<TensorShape> InputTensorShapes() const;
  virtual void WaitEvent(const Event& ev, int /*stream_id*/ = -1) {
    ev.Finish();
  }

  inline void Wait(const OperatorBase& other, int stream_id = -1) {
    if (!other.IsEventDisabled()) {
      WaitEvent(other.event(), stream_id);
    }
  }
  virtual void WaitEvents(
      const std::vector<const Event*>& events,
      int /*stream_id*/ = -1) {
    for (const auto& ev : events) {
      ev->Finish();
    }
  }
  virtual void Finish() {
    if (event_) {
      event_->Finish();
    }
  }
  virtual bool Run(int /*stream_id*/ = 0) {
    CAFFE_NOT_IMPLEMENTED;
  }
  virtual bool HasAsyncPart() const {
    return false;
  }

  virtual bool SupportsAsyncScheduling() const {
    return false;
  }
  virtual bool RunAsync(int stream_id = 0) {
    try {
      auto result = Run(stream_id);
      if (result) {
        if (HasAsyncPart()) {
          RecordEvent();
        } else {
          SetEventFinished();
        }
      } else {
        SetEventFinished(getErrorMsg().c_str());
      }
      return result;
    } catch (EnforceNotMet& err) {
      SetEventFinishedWithException(err.what());
      throw;
    } catch (const std::exception& err) {
      SetEventFinishedWithException(err.what());
      throw;
    } catch (...) {
      SetEventFinishedWithException(getErrorMsg().c_str());
      throw;
    }
  }
  virtual void AddRelatedBlobInfo(EnforceNotMet* err) {
    if (!has_debug_def()) {
      return;
    }
    bool found_input = false;
    if (err->caller() != nullptr) {
      for (size_t i = 0; i < inputs_.size(); i++) {
        if (inputs_[i]->GetRaw() == err->caller()) {
          found_input = true;
          err->AppendMessage(
              "\n** while accessing input: " + debug_def().input(i));
          break;
        }
      }
      for (size_t i = 0; i < outputs_.size(); i++) {
        if (outputs_[i]->GetRaw() == err->caller()) {
          if (found_input) {
            err->AppendMessage("\n OR ");
          }
          err->AppendMessage(
              "\n** while accessing output: " + debug_def().output(i));
          break;
        }
      }
    }
  }
  inline const OperatorDef& debug_def() const {
    CAFFE_ENFORCE(has_debug_def(), "operator_def was null!");
    return *operator_def_;
  }
  inline void set_debug_def(
      const std::shared_ptr<const OperatorDef>& operator_def) {
    operator_def_ = operator_def;
  }
  inline bool has_debug_def() const {
    return operator_def_ != nullptr;
  }
  void RecordLastFailedOpNetPosition() {
    if (net_position_ != kNoNetPositionSet) {
      VLOG(1) << "Operator with id " << net_position_ << " failed";
      operator_ws_->last_failed_op_net_position = net_position_;
    } else {
      VLOG(1) << "Failed operator doesn't have id set";
    }
  }
  int net_position() const {
    return net_position_;
  }
  void set_net_position(int idx) {
    net_position_ = idx;
  }
  const DeviceOption& device_option() const {
    return device_option_;
  }
  const Event& event() const {
    CAFFE_ENFORCE(event_, "Event is disabled");
    return *event_;
  }

  Event& event() {
    CAFFE_ENFORCE(event_, "Event is disabled");
    return *event_;
  }
  void DisableEvent() {
    event_ = nullptr;
  }
  bool IsEventDisabled() const {
    return !event_;
  }
  // Internal API invoked by observers; normal callers shouldn't invoke it.
  virtual void SyncDeviceBarrierForObservers() {
    CAFFE_NOT_IMPLEMENTED;
  }
  // Checks whether a stream is ready to execute new computation; used in
  // stream allocation optimization to skip a stream that is currently busy.
  // Depends on context and the operator's device; returns true by default.
  virtual bool IsStreamFree(int /*stream_id*/) const {
    return true;
  }
  const std::string& type() const {
    return type_;
  }
  void annotate_engine(const std::string& engine) {
    engine_ = engine;
  }

  const std::string& engine() const {
    return engine_;
  }
  std::vector<at::Tensor> move_newstyle_outputs() && {
    return std::move(newstyle_outputs_);
  }
 public:
  static const int kNoNetPositionSet = -1;
 private:
  Workspace* operator_ws_;
  std::shared_ptr<const OperatorDef> operator_def_;
  DeviceOption device_option_;
  std::string engine_;
  std::string type_;
  vector<const Blob*> inputs_;
  vector<Blob*> outputs_;

  std::unique_ptr<const c10::FunctionSchema> fn_schema_ = nullptr;
  vector<c10::IValue> newstyle_inputs_;
  vector<at::Tensor> newstyle_outputs_;
  // HACK: Output() historically returns Tensor*, so we keep the returned
  // Tensors alive in vectors owned by the operator.
  vector<caffe2::Tensor> input_tensors_;
  vector<caffe2::Tensor> output_tensors_;

  int input_size_;

  int net_position_{kNoNetPositionSet};
 protected:
  virtual void RecordEvent(const char* /*err_msg*/ = nullptr) {
    CAFFE_NOT_IMPLEMENTED;
  }
  void SetEventFinished(const char* err_msg = nullptr) {
    if (event_) {
      event_->SetFinished(err_msg);
    }
  }
  void SetEventFinishedWithException(const char* err_msg = nullptr) {
    if (event_) {
      event_->SetFinishedWithException(err_msg);
    }
  }
  std::string getErrorMsg() {
    if (has_debug_def()) {
      return "Error from operator: " + ProtoDebugString(debug_def());
    } else {
      return "Error from operator: no op def";
    }
  }
  // An event used by asynchronous execution.
  std::unique_ptr<Event> event_;

  C10_DISABLE_COPY_AND_ASSIGN(OperatorBase);
};
template <>
inline NetDef OperatorBase::GetSingleArgument<NetDef>(
    const std::string& name,
    const NetDef& default_value) const {
  if (isLegacyOperator()) {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::GetSingleArgument<OperatorDef, NetDef>(
        *operator_def_, name, default_value);
  }
  CAFFE_THROW("Cannot get NetDefs from IValue");
  return NetDef();
}
template <>
inline vector<int> OperatorBase::GetVectorFromIValueList<int>(
    const c10::IValue& value) const {
  const auto& vs = value.toIntListRef();
  vector<int> out;
  out.reserve(vs.size());
  for (const auto& v : vs) {
    out.emplace_back(v);
  }
  return out;
}
template <>
inline vector<float> OperatorBase::GetVectorFromIValueList<float>(
    const c10::IValue& value) const {
  // IValue stores floating-point lists as doubles; narrow them to float here.
  const auto& vs = value.toDoubleListRef();
  vector<float> out;
  out.reserve(vs.size());
  for (const auto& v : vs) {
    out.emplace_back(v);
  }
  return out;
}
template <>
inline vector<string> OperatorBase::GetVectorFromIValueList<string>(
    const c10::IValue& /*value*/) const {
  CAFFE_THROW("Cannot extract vector<string> from ivalue.");
  return {};
}
// OP_SINGLE_ARG provides a shorter initialization choice for member variables
// in class constructors. The first branch is a workaround for CUDA 9.2 with
// GCC 7.
#if defined(CUDART_VERSION) && CUDART_VERSION >= 9020 && __GNUC__ >= 7
#define OP_SINGLE_ARG(type, name, variable, default) \
  variable(this->template GetSingleArgument<type>(name, (default)))
#else
#define OP_SINGLE_ARG(type, name, variable, default) \
  variable(OperatorBase::GetSingleArgument<type>(name, (default)))
#endif

// INPUT_TAGS and OUTPUT_TAGS are optional features to name the indices of the
// operator's inputs and outputs, in order to avoid confusion. For example, for
// a fully convolutional layer that has input, weight and bias, you can define
// its input tags as:
//     INPUT_TAGS(INPUT, WEIGHT, BIAS);
// And in the code, instead of doing
//     auto& weight = Input(1);
// you can now do
//     auto& weight = Input(WEIGHT);
// to make it more clear.
#define INPUT_TAGS(first_input, ...) \
  enum _InputTags { first_input = 0, __VA_ARGS__ }
#define OUTPUT_TAGS(first_input, ...) \
  enum _OutputTags { first_input = 0, __VA_ARGS__ }

// Operator is the class that you usually want to derive from if your operator
// will run on different devices. You should then implement the RunOnDevice()
// function.
template <class Context>
class Operator : public OperatorBase {
 public:
  explicit Operator(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        context_(operator_def.device_option()) {
    // In the constructor, we switch to the device so that the child class
    // constructors will run on that device.
    context_.SwitchToDevice();
  }
  explicit Operator(
      const c10::FunctionSchema& fn_schema,
      std::vector<c10::IValue> inputs,
      std::vector<at::Tensor> outputs)
      : OperatorBase(fn_schema, std::move(inputs), std::move(outputs)) {
    // In the constructor, we switch to the device so that the child class
    // constructors will run on that device.
    context_.SwitchToDevice();
  }
  ~Operator() noexcept override {}
  /// Retrieve a non-owning reference to the input at position 'idx' for this
  /// operator. The optional 'type' parameter can be used to assert that the
  /// input is of the expected device type.
  inline const Tensor& Input(
      int idx,
      DeviceType type = Context::GetDeviceType()) {
    return OperatorBase::template Input<Tensor>(idx, type);
  }
  /// XOutput is a modernized version of Output which returns a Tensor rather
  /// than a Tensor* (the raw pointer is not safe to use after another Output
  /// call).
  Tensor XOutput(int idx, at::IntArrayRef dims, at::TensorOptions options) {
    // We'll default the device to the device of the current Operator Context.
    if (options.device_opt() == c10::nullopt) {
      return OperatorBase::XOutputTensor(
          idx, dims, options.device(context_.device()));
    }
    return OperatorBase::XOutputTensor(idx, dims, options);
  }
  Tensor* Output(int idx, at::IntArrayRef dims, at::TensorOptions options) {
    // We'll default the device to the device of the current Operator Context.
    if (options.device_opt() == c10::nullopt) {
      return OperatorBase::OutputTensor(
          idx, dims, options.device(context_.device()));
    }
    return OperatorBase::OutputTensor(idx, dims, options);
  }
  inline Tensor* Output(int idx, DeviceType type = Context::GetDeviceType()) {
    return OperatorBase::template Output<Tensor>(idx, type);
  }
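  // Illustrative sketch (not from the original header): inside RunOnDevice of
  // an Operator<Context> subclass, these accessors are typically used as:
  //
  //   const auto& X = Input(0);  // on the context's device type
  //   auto* Y = Output(0, X.sizes(), at::dtype<float>());
  //   math::Exp<float, Context>(
  //       X.numel(),
  //       X.template data<float>(),
  //       Y->template mutable_data<float>(),
  //       &context_);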
  Tensor* OutputTensorCopyFrom(
      int idx,
      at::TensorOptions options,
      const Tensor& src,
      bool async = false) {
    if (options.device_opt() == c10::nullopt) {
      return OperatorBase::OutputTensorCopyFrom(
          idx, options.device(context_.device()), src, async);
    }
    return OperatorBase::OutputTensorCopyFrom(idx, options, src, async);
  }
  void WaitEvent(const Event& ev, int stream_id = -1) final {
    if (stream_id >= 0) {
      context_.SwitchToDevice(stream_id);
    }
    context_.WaitEvent(ev);
  }
  void WaitEvents(const std::vector<const Event*>& events, int stream_id = -1)
      final {
    if (stream_id >= 0) {
      context_.SwitchToDevice(stream_id);
    }
    for (const auto& ev : events) {
      context_.WaitEvent(*ev);
    }
  }
  // The run function of Operator switches to the device, and then carries out
  // the actual computation with RunOnDevice(). You should implement
  // RunOnDevice instead of Run().
  bool Run(int stream_id = 0) final {
    try {
      StartAllObservers();

      context_.SwitchToDevice(stream_id);

      if (FLAGS_caffe2_operator_throw_if_fp_exceptions) {
        std::feclearexcept(FE_ALL_EXCEPT);
      }
      bool result = RunOnDevice();
      if (FLAGS_caffe2_operator_throw_if_fp_exceptions) {
        CAFFE_ENFORCE(
            !std::fetestexcept(FE_DIVBYZERO),
            "Division by zero floating point exception (FE_DIVBYZERO) reported.");
        CAFFE_ENFORCE(
            !std::fetestexcept(FE_INVALID),
            "Invalid floating point exception (FE_INVALID) reported.");
        CAFFE_ENFORCE(
            !std::fetestexcept(FE_OVERFLOW),
            "Overflow floating point exception (FE_OVERFLOW) reported.");
      }
      if (!result) {
        this->RecordLastFailedOpNetPosition();
      }
      context_.FinishDeviceComputation(); // throws on error

      StopAllObservers();
      return result;
    } catch (EnforceNotMet& err) {
      if (has_debug_def()) {
        err.AppendMessage(
            "Error from operator: \n" + ProtoDebugString(debug_def()));
        AddRelatedBlobInfo(&err);
      }
      this->RecordLastFailedOpNetPosition();
      StopAllObservers();
      throw;
    } catch (...) {
      this->RecordLastFailedOpNetPosition();
      StopAllObservers();
      throw;
    }
  }
  bool RunAsync(int stream_id = 0) final {
    try {
      StartAllObservers();

      context_.SwitchToDevice(stream_id);
      auto result = RunOnDevice();
      if (result) {
        if (HasAsyncPart()) {
          RecordEvent();
        } else {
          // Manually set a CPU operator's event status to finished,
          // unless this is an async CPU operator.
          SetEventFinished();
        }
      } else {
        SetEventFinished(getErrorMsg().c_str());
        this->RecordLastFailedOpNetPosition();
      }

      StopAllObservers();
      return result;
    } catch (EnforceNotMet& err) {
      if (has_debug_def()) {
        err.AppendMessage(
            "Error from operator: \n" + ProtoDebugString(debug_def()));
        AddRelatedBlobInfo(&err);
      }
      SetEventFinishedWithException(err.what());
      this->RecordLastFailedOpNetPosition();
      StopAllObservers();
      throw;
    } catch (const std::exception& err) {
      SetEventFinishedWithException(err.what());
      this->RecordLastFailedOpNetPosition();
      StopAllObservers();
      throw;
    } catch (...) {
      SetEventFinishedWithException(getErrorMsg().c_str());
      this->RecordLastFailedOpNetPosition();
      StopAllObservers();
      throw;
    }
  }
  bool IsStreamFree(int stream_id) const override {
    return context_.IsStreamFree(device_option(), stream_id);
  }
  virtual bool RunOnDevice() = 0;
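  // Illustrative sketch (assumption): a minimal derived operator. "CopyOp"
  // is a hypothetical name used only for illustration.
  //
  //   template <class Context>
  //   class CopyOp final : public Operator<Context> {
  //    public:
  //     USE_OPERATOR_CONTEXT_FUNCTIONS;
  //     USE_SIMPLE_CTOR_DTOR(CopyOp);
  //     bool RunOnDevice() override {
  //       const auto& X = Input(0);
  //       auto* Y = Output(0, X.sizes(), at::dtype(X.dtype()));
  //       context_.CopyItemsSameDevice(
  //           X.dtype(), X.numel(), X.raw_data(),
  //           Y->raw_mutable_data(X.dtype()));
  //       return true;
  //     }
  //   };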
  // Returns whether the operator has an async on-device part. CUDA operators
  // by default have an async part; CPU operators by default don't and are
  // finished after the RunOnDevice call.
  bool HasAsyncPart() const override {
    return context_.HasAsyncPartDefault();
  }
  // Returns whether the operator's RunAsync schedules the async on-device
  // part and can return without waiting for the result of the computation.
  bool SupportsAsyncScheduling() const override {
    return HasAsyncPart() && context_.SupportsAsyncScheduling();
  }
  void SyncDeviceBarrierForObservers() override {
    context_.FinishDeviceComputation();
  }
  const Context* getContext() const {
    return &context_;
  }
  Context* getContext() {
    return &context_;
  }
 protected:
  void RecordEvent(const char* err_msg = nullptr) final {
    if (event_) {
      context_.Record(event_.get(), err_msg);
    }
  }

  Context context_;
};
#define USE_OPERATOR_BASE_FUNCTIONS            \
  using OperatorBase::HasArgument;             \
  using OperatorBase::GetSingleArgument;       \
  using OperatorBase::HasSingleArgumentOfType; \
  using OperatorBase::GetRepeatedArgument;     \
  using OperatorBase::InputIsType;             \
  using OperatorBase::InputSize;               \
  using OperatorBase::Output;                  \
  using OperatorBase::Input;                   \
  using OperatorBase::OutputSize;              \
  using OperatorBase::IsInputOutputAlias;      \
  using OperatorBase::OutputTensorAlias

#define USE_OPERATOR_FUNCTIONS(context)      \
  USE_OPERATOR_BASE_FUNCTIONS;               \
  using Operator<context>::context_;         \
  using Operator<context>::Input;            \
  using Operator<context>::InputBlob;        \
  using Operator<context>::Output;           \
  using Operator<context>::OutputBlob;       \
  using Operator<context>::OutputTensorCopyFrom

#define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)

#define USE_SIMPLE_CTOR_DTOR(name)                        \
  template <class... Args>                                \
  explicit name(Args&&... args)                           \
      : Operator<Context>(std::forward<Args>(args)...) {} \
  virtual ~name() noexcept {}

// The helpers below support operators that dispatch on the runtime type or
// value of an input. USE_DISPATCH_HELPER grants DispatchHelper access to the
// operator's private DoRunWithType/DoRunWithValue implementations.
#define USE_DISPATCH_HELPER                           \
  template <typename FirstArg, typename... ExtraArgs> \
  friend struct DispatchHelper

template <int... Values>
struct FixedValues {};
template <typename... Types>
struct TensorTypes {};
// Special tag that can be listed in TensorTypes to denote that a special
// implementation in 'DoRunWithOtherType' needs to be called instead of
// failing. It should be the last item in the list, e.g.
// TensorTypes<float, double, GenericTensorImplementation>.
struct GenericTensorImplementation {};

// Same as TensorTypes but calls DoRunWithType2.
template <typename... Types>
struct TensorTypes2 {};
template <typename Sizes, typename... ExtraArgs>
struct DispatchHelper;
template <int FirstVal, int... Values, typename... ExtraArgs>
struct DispatchHelper<FixedValues<FirstVal, Values...>, ExtraArgs...> {
  template <typename Op>
  static bool call(Op* op, int value) {
    if (FirstVal == value) {
      return op->template DoRunWithValue<ExtraArgs..., FirstVal>();
    }
    return DispatchHelper<FixedValues<Values...>, ExtraArgs...>::
        template call<Op>(op, value);
  }
};
template <typename... ExtraArgs>
struct DispatchHelper<FixedValues<>, ExtraArgs...> {
  template <typename Op>
  static bool call(Op* op, int64_t /*size*/) {
    return op->template DoRunWithValue<ExtraArgs..., -1>();
  }
};
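// Illustrative sketch (assumption): FixedValues dispatch turns a runtime int
// into a compile-time template parameter. The operator declares
// USE_DISPATCH_HELPER plus a private DoRunWithValue; "ndim" below is a
// hypothetical runtime value used only for illustration.
//
//   bool RunOnDevice() override {
//     // calls DoRunWithValue<1>(), DoRunWithValue<4>(), or DoRunWithValue<-1>()
//     return DispatchHelper<FixedValues<1, 4>>::call(this, ndim);
//   }
//   template <int N>
//   bool DoRunWithValue();  // N == -1 is the fallback for unlisted values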
#define C10_DEFINE_TENSOR_TYPES_DISPATCHER(                                    \
    TensorTypes, DoRunWithType, DoRunWithOtherType)                            \
  template <typename FirstType, typename... Types, typename... ExtraArgs>      \
  struct DispatchHelper<TensorTypes<FirstType, Types...>, ExtraArgs...> {      \
    template <typename Op>                                                     \
    static bool call(Op* op, const TypeMeta& meta) {                           \
      static_assert(                                                           \
          !std::is_same<GenericTensorImplementation, FirstType>::value,        \
          "GenericTensorImplementation must be the last in TensorTypes list"); \
      if (meta.Match<FirstType>()) {                                           \
        return op->template DoRunWithType<ExtraArgs..., FirstType>();          \
      }                                                                        \
      return DispatchHelper<TensorTypes<Types...>, ExtraArgs...>::             \
          template call<Op>(op, meta);                                         \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Tensor& tensor) {                           \
      return call<Op>(op, tensor.dtype());                                     \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Blob& blob) {                               \
      return call<Op>(op, blob.meta());                                        \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename... ExtraArgs>                                             \
  struct DispatchHelper<TensorTypes<>, ExtraArgs...> {                         \
    template <typename Op>                                                     \
    static bool call(Op* /*op*/, const TypeMeta& meta) {                       \
      CAFFE_THROW("Unsupported type of tensor: ", meta.name());                \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Tensor& tensor) {                           \
      return call<Op>(op, tensor.dtype());                                     \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Blob& blob) {                               \
      return call<Op>(op, blob.meta());                                        \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename... ExtraArgs>                                             \
  struct DispatchHelper<                                                       \
      TensorTypes<GenericTensorImplementation>,                                \
      ExtraArgs...> {                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const TypeMeta&) {                                \
      return op->template DoRunWithOtherType<ExtraArgs...>();                  \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Tensor& tensor) {                           \
      return call<Op>(op, tensor.dtype());                                     \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Blob& blob) {                               \
      return call<Op>(op, blob.meta());                                        \
    }                                                                          \
  };
C10_DEFINE_TENSOR_TYPES_DISPATCHER(
    TensorTypes,
    DoRunWithType,
    DoRunWithOtherType)
C10_DEFINE_TENSOR_TYPES_DISPATCHER(
    TensorTypes2,
    DoRunWithType2,
    DoRunWithOtherType2)
#undef C10_DEFINE_TENSOR_TYPES_DISPATCHER
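// Illustrative sketch (assumption): TensorTypes dispatch selects a typed
// implementation from a runtime TypeMeta, Tensor, or Blob:
//
//   bool RunOnDevice() override {
//     return DispatchHelper<TensorTypes<float, int32_t>>::call(this, Input(0));
//   }
//   template <typename T>
//   bool DoRunWithType();  // instantiated for float and int32_t
//
// Listing GenericTensorImplementation last routes all remaining types to
// DoRunWithOtherType().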
// The device type registry. This works in two phases:
// (1) gDeviceTypeRegistry() maps device type values to the actual operator
//     registry function;
// (2) one can then call the operator registry function to further create the
//     operators.
typedef c10::Registry<
    std::string,
    std::unique_ptr<OperatorBase>,
    const OperatorDef&,
    Workspace*>
    OperatorRegistry;
typedef c10::Registry<
    std::string,
    std::unique_ptr<OperatorBase>,
    const OperatorDef&,
    Workspace*>* (*RegistryFunction)();
CAFFE2_API std::map<DeviceType, OperatorRegistry*>* gDeviceTypeRegistry();
struct CAFFE2_API DeviceTypeRegisterer {
  explicit DeviceTypeRegisterer(DeviceType type, RegistryFunction func) {
    if (gDeviceTypeRegistry()->count(type)) {
      std::cerr << "Device type " << DeviceTypeName(type)
                << " registered twice. This should not happen. Did you have "
                   "duplicated numbers assigned to different devices?";
      std::exit(1);
    }
    // Call the registry function to get the actual registry pointer.
    gDeviceTypeRegistry()->emplace(type, func());
  }
};
#define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \
  namespace {                                               \
  static DeviceTypeRegisterer C10_ANONYMOUS_VARIABLE(       \
      DeviceType)(type, &registry_function);                \
  }

// The operator registries for the individual device types.
C10_DECLARE_REGISTRY(
    CPUOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);
#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
  C10_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR(name, ...)                           \
  C10_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();  \
  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \
    CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();                \
  }                                                                \
  C10_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \
  C10_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)

// Gradient operators can be compiled out (e.g. on mobile) by defining
// CAFFE2_NO_GRADIENT_OPS.
#ifdef CAFFE2_NO_GRADIENT_OPS
#define REGISTER_CPU_GRADIENT_OPERATOR(...) /* No gradients. */
#else
#define REGISTER_CPU_GRADIENT_OPERATOR(...) \
  MACRO_EXPAND(REGISTER_CPU_OPERATOR(__VA_ARGS__))
#endif

C10_DECLARE_REGISTRY(
    CUDAOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);
#define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \
  C10_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CUDA_OPERATOR(name, ...)                           \
  C10_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();   \
  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CUDA##name() { \
    CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();                 \
  }                                                                 \
  C10_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_CUDA_OPERATOR_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \
  C10_REGISTER_CLASS(CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)

// Macros for cuDNN since we use it often.
#define REGISTER_CUDNN_OPERATOR(name, ...) \
  REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__)

// Macros for HIP operators.
C10_DECLARE_REGISTRY(
    HIPOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);
#define REGISTER_HIP_OPERATOR_CREATOR(key, ...) \
  C10_REGISTER_CREATOR(HIPOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_HIP_OPERATOR(name, ...)                           \
  C10_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();  \
  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_HIP##name() { \
    CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();                \
  }                                                                \
  C10_REGISTER_CLASS(HIPOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_HIP_OPERATOR_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(HIPOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_HIP_OPERATOR_WITH_ENGINE(name, engine, ...) \
  C10_REGISTER_CLASS(HIPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)

#define REGISTER_MIOPEN_OPERATOR(name, ...)                    \
  REGISTER_HIP_OPERATOR_WITH_ENGINE(name, MIOPEN, __VA_ARGS__) \
  REGISTER_HIP_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__) // Make CUDNN an alias of MIOPEN for HIP ops.

// StaticLinkingProtector is a helper class that ensures the Caffe2 library is
// linked correctly with whole archives (in the case of static linking). When
// CreateOperator is first called, it checks whether the operator registry is
// empty; an empty registry means the library was not linked properly.
struct StaticLinkingProtector {
  StaticLinkingProtector() {
    const int registered_ops = CPUOperatorRegistry()->Keys().size();
    // Note: this is a check failure instead of an exception, because if the
    // linking is wrong, Caffe2 won't be able to run properly anyway, so it's
    // better to fail loudly. If you are adding an operator and this check
    // fails for you, you have likely linked your test against a CPU-only
    // library.
    if (registered_ops == 0) {
      LOG(FATAL)
          << "You might have made a build error: the Caffe2 library does not seem "
             "to be linked with whole-static library option. To do so, use "
             "-Wl,-force_load (clang) or -Wl,--whole-archive (gcc) to link the "
             "Caffe2 library.";
    }
  }
};

// An exception that can be thrown by an operator constructor to notify that
// it does not support the given setting.
class CAFFE2_API UnsupportedOperatorFeature : public std::exception {
 public:
  UnsupportedOperatorFeature(const string& msg) : msg_(msg) {}
  const char* what() const noexcept override {
    return msg_.c_str();
  }

 private:
  string msg_;
};

// A helper macro that should ONLY be used in the operator constructor to
// check if needed features are met. If not, it throws an
// UnsupportedOperatorFeature exception with the given message.
#define OPERATOR_NEEDS_FEATURE(condition, ...)                 \
  if (!(condition)) {                                          \
    throw UnsupportedOperatorFeature(::c10::str(__VA_ARGS__)); \
  }

// Creates an operator with the given operator definition.
// Throws on error and never returns nullptr.
CAFFE2_API unique_ptr<OperatorBase> CreateOperator(
    const OperatorDef& operator_def,
    Workspace* ws,
    int net_position = OperatorBase::kNoNetPositionSet);
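// Illustrative sketch (assumption): registering and then creating an
// operator. "MyOp" and the schema call below are hypothetical, for
// illustration only.
//
//   REGISTER_CPU_OPERATOR(MyOp, MyOp<CPUContext>);
//   OPERATOR_SCHEMA(MyOp).NumInputs(1).NumOutputs(1);
//
//   OperatorDef def;
//   def.set_type("MyOp");
//   Workspace ws;
//   std::unique_ptr<OperatorBase> op = CreateOperator(def, &ws);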
CAFFE2_API const std::string OpRegistryKey(
    const std::string& op_type,
    const std::string& engine = "");
// Users can set the preferred engines as a list of engine names, in
// descending order of preference.
using EnginePrefType = std::vector<std::string>;
// {device_type -> {operator_name -> EnginePrefType}}
using PerOpEnginePrefType =
    CaffeMap<DeviceType, CaffeMap<std::string, EnginePrefType>>;
// {device_type -> EnginePrefType}
using GlobalEnginePrefType = CaffeMap<DeviceType, EnginePrefType>;
CAFFE2_API void SetPerOpEnginePref(
    const PerOpEnginePrefType& per_op_engine_pref);
CAFFE2_API void SetGlobalEnginePref(
    const GlobalEnginePrefType& global_engine_pref);
CAFFE2_API void SetEnginePref(
    const PerOpEnginePrefType& per_op_engine_pref,
    const GlobalEnginePrefType& global_engine_pref);
CAFFE2_API void SetOpEnginePref(
    const std::string& op_type,
    const CaffeMap<DeviceType, EnginePrefType>& op_pref);
CAFFE2_API TensorShape GetTensorShapeOfBlob(const Blob* b);
CAFFE2_API TensorShapes InferBlobShapesAndTypes(
    CaffeMap<string, TensorShape>& blob_desc,
    const vector<NetDef*>& nets);

CAFFE2_API TensorShapes InferBlobShapesAndTypesFromWorkspace(
    Workspace* ws,
    const vector<NetDef*>& nets);

CAFFE2_API TensorShapes InferBlobShapesAndTypesFromMap(
    const CaffeMap<std::string, std::vector<int64_t>>& blob_dimensions,
    const vector<NetDef*>& nets);

CAFFE2_API TensorShapes InferBlobShapesAndTypesFromMap(
    const CaffeMap<std::string, std::vector<int64_t>>& blob_dimensions,
    const CaffeMap<std::string, TensorProto_DataType>& blob_types,
    const vector<NetDef*>& nets);
CAFFE2_API std::map<string, std::pair<DeviceOption, DeviceOption>>
ValidateTensorDevices(OperatorBase& op, const OperatorDef& op_def);
// Get a set of registered operator names.
CAFFE2_API std::set<std::string> GetRegisteredOperators();
// Operator logging capabilities.
CAFFE2_API void SetOperatorLogger(
    std::function<void(const OperatorDef&)> tracer);
std::function<void(const OperatorDef&)> GetOperatorLogger();
} // namespace caffe2

#include "caffe2/core/c10_operator.h"

#endif // CAFFE2_CORE_OPERATOR_H_