4 #include <caffe2/core/operator.h> 5 #include <caffe2/proto/caffe2_pb.h> 10 IDEEPOperatorRegistry,
// Registration helper macros for IDEEP operators. Each one forwards to a
// C10_REGISTER_* macro with IDEEPOperatorRegistry as the target registry:
//   - REGISTER_IDEEP_OPERATOR_CREATOR(key, ...)   -> C10_REGISTER_CREATOR
//   - REGISTER_IDEEP_OPERATOR(name, ...)          -> C10_REGISTER_CLASS
//   - REGISTER_IDEEP_OPERATOR_STR(str_name, ...)  -> C10_REGISTER_TYPED_CLASS
//   - REGISTER_IDEEP_COMPARE_OPERATOR(Op) registers, via
//     REGISTER_IDEEP_OPERATOR, an IDEEPFallbackOp wrapping a
//     BinaryElementwiseOp over bool/int32_t/int64_t/float/double that uses
//     Op##Functor<CPUContext>.
//   - REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) registers under
//     the pasted key name##_ENGINE_##engine.
// NOTE(review): parts of the original macro bodies are missing from this
// excerpt; the trailing `context_(...)` text belongs to the IDEEPOperator
// constructor's initializer list, not to a macro.
15 #define REGISTER_IDEEP_OPERATOR_CREATOR(key, ...) \ 16 C10_REGISTER_CREATOR(IDEEPOperatorRegistry, key, __VA_ARGS__) 17 #define REGISTER_IDEEP_OPERATOR(name, ...) \ 18 C10_REGISTER_CLASS(IDEEPOperatorRegistry, name, __VA_ARGS__) 19 #define REGISTER_IDEEP_OPERATOR_STR(str_name, ...) \ 20 C10_REGISTER_TYPED_CLASS(IDEEPOperatorRegistry, str_name, __VA_ARGS__) 21 #define REGISTER_IDEEP_COMPARE_OPERATOR(Op) \ 22 REGISTER_IDEEP_OPERATOR( \ 24 IDEEPFallbackOp<BinaryElementwiseOp< \ 25 TensorTypes<bool, int32_t, int64_t, float, double>, \ 27 Op##Functor<CPUContext>, \ 30 #define REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) \ 31 C10_REGISTER_CLASS(IDEEPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__) 39 context_(operator_def.device_option()),
// Constructor fragment (the signature is not visible in this excerpt).
// order_ is built by converting the operator's "order" string argument with
// StringToStorageOrder; "NCHW" is passed as the second argument to
// GetSingleArgument (presumably the default when the argument is absent —
// confirm against OperatorBase).
40 order_(StringToStorageOrder(
41 OperatorBase::GetSingleArgument<string>(
"order",
"NCHW"))) {
// The operator requires NCHW storage order; for any other order the
// feature check is given the message "Unsupported storage order.".
42 OPERATOR_NEEDS_FEATURE(
43 order_ == StorageOrder::NCHW,
"Unsupported storage order.");
// Returns the input at position `index` as a const ideep::tensor reference,
// by forwarding to OperatorBase::Input<ideep::tensor>.
47 inline const ideep::tensor& Input(
int index) {
48 return OperatorBase::template Input<ideep::tensor>(index);
// Returns a pointer to the output at position `index` as an ideep::tensor,
// by forwarding to OperatorBase::Output<ideep::tensor>.
50 inline ideep::tensor* Output(
int index) {
51 return OperatorBase::template Output<ideep::tensor>(index);
// Entry point that executes the operator; declared `final`, so subclasses
// customize behavior through RunOnDevice() instead. The unnamed int
// parameter is not used in the visible code.
// NOTE(review): several original lines (the `try`, the first catch clause and
// the return statements) are missing from this excerpt.
57 bool Run(
int )
final {
// Delegate the actual computation to the subclass hook.
63 bool result = RunOnDevice();
// Attach this operator's description to the error object `err`
// (its declaration is not visible in this excerpt).
67 err.AppendMessage(getErrorMsg());
69 }
// IDEEP library errors are written to the error log with their message.
catch (ideep::error& e) {
70 LOG(ERROR) <<
"IDEEP error:" << e.message;
// Forwards `ev` to context_.WaitEvent. The unnamed int parameter is ignored.
78 void WaitEvent(
const Event& ev,
int )
final {
79 context_.WaitEvent(ev);
// Calls context_.WaitEvent on every event in `events`, in order. The pointers
// are dereferenced without a null check; the unnamed int parameter is ignored.
82 void WaitEvents(
const std::vector<const Event*>& events,
int )
84 for (
const auto& ev : events) {
85 context_.WaitEvent(*ev);
// Records this operator's event (event_) through context_.Record, passing
// along `err_msg`, which defaults to nullptr.
89 void RecordEvent(
const char* err_msg =
nullptr)
final {
91 context_.Record(event_.get(), err_msg);
// Pure virtual hook that concrete IDEEP operators must implement; Run() calls
// it and stores its bool result.
95 virtual bool RunOnDevice() = 0;
// Builds the text that is appended to operator errors: the debug definition
// rendered by ProtoDebugString when has_debug_def() is true, otherwise a fixed
// "no op def" message.
98 std::string getErrorMsg() {
99 if (has_debug_def()) {
100 return "Error from operator: " + ProtoDebugString(debug_def());
102 return "Error from operator: no op def";
// Convenience macros for classes derived from IDEEPOperator:
//   - USE_IDEEP_OPERATOR_FUNCTIONS() expands USE_OPERATOR_BASE_FUNCTIONS and
//     adds using-declarations for IDEEPOperator's Input, Output, order_ and
//     context_.
//   - USE_SIMPLE_IDEEP_CTOR_DTOR(name) declares a constructor
//     name(const OperatorDef&, Workspace*) that forwards both arguments to
//     IDEEPOperator. NOTE(review): the rest of this macro (presumably the
//     destructor) is not visible in this excerpt.
110 #define USE_IDEEP_OPERATOR_FUNCTIONS() \ 111 USE_OPERATOR_BASE_FUNCTIONS; \ 112 using IDEEPOperator::Input; \ 113 using IDEEPOperator::Output; \ 114 using IDEEPOperator::order_; \ 115 using IDEEPOperator::context_; 117 #define USE_SIMPLE_IDEEP_CTOR_DTOR(name) \ 118 name(const OperatorDef& operator_def, Workspace* ws) \ 119 : IDEEPOperator(operator_def, ws) {} \ The primary ATen error class.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...