Caffe2 - C++ API
A deep learning, cross platform ML framework
ideep_operator.h
1 #pragma once
2 
3 #include <ideep.hpp>
4 #include <caffe2/core/operator.h>
5 #include <caffe2/proto/caffe2_pb.h>
6 
7 namespace caffe2 {
8 
9 C10_DECLARE_REGISTRY(
10  IDEEPOperatorRegistry,
11  OperatorBase,
12  const OperatorDef&,
13  Workspace*);
14 
// Registration helpers for IDEEP operators. Every macro below registers into
// IDEEPOperatorRegistry; they differ only in how the registry key is formed.

// Registers a creator function under the identifier `key`.
#define REGISTER_IDEEP_OPERATOR_CREATOR(key, ...) \
  C10_REGISTER_CREATOR(IDEEPOperatorRegistry, key, __VA_ARGS__)

// Registers an operator class under the identifier `name`.
#define REGISTER_IDEEP_OPERATOR(name, ...) \
  C10_REGISTER_CLASS(IDEEPOperatorRegistry, name, __VA_ARGS__)

// Registers an operator class under an already-quoted string key.
#define REGISTER_IDEEP_OPERATOR_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(IDEEPOperatorRegistry, str_name, __VA_ARGS__)

// Registers an operator class for a specific engine. The key is formed by
// token pasting: `name` + `_ENGINE_` + `engine` (e.g. Conv_ENGINE_FOO).
#define REGISTER_IDEEP_OPERATOR_WITH_ENGINE(name, engine, ...) \
  REGISTER_IDEEP_OPERATOR(name##_ENGINE_##engine, __VA_ARGS__)

// Registers comparison operator `Op` as an IDEEPFallbackOp wrapping the CPU
// BinaryElementwiseOp over TensorTypes<bool, int32_t, int64_t, float, double>,
// using Op##Functor<CPUContext> with the FixedType<bool> output policy.
#define REGISTER_IDEEP_COMPARE_OPERATOR(Op)                   \
  REGISTER_IDEEP_OPERATOR(                                    \
      Op,                                                     \
      IDEEPFallbackOp<BinaryElementwiseOp<                    \
          TensorTypes<bool, int32_t, int64_t, float, double>, \
          CPUContext,                                         \
          Op##Functor<CPUContext>,                            \
          FixedType<bool>>>)
32 
33 // IDEEPOperator is the base scaffolding of the operators that uses IDEEP. It
34 // provides a few operators that are useful to IDEEP specific implementations.
35 class IDEEPOperator : public OperatorBase {
36  public:
37  explicit IDEEPOperator(const OperatorDef& operator_def, Workspace* ws)
38  : OperatorBase(operator_def, ws),
39  context_(operator_def.device_option()),
40  order_(StringToStorageOrder(
41  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
42  OPERATOR_NEEDS_FEATURE(
43  order_ == StorageOrder::NCHW, "Unsupported storage order.");
44  }
45  virtual ~IDEEPOperator() {}
46 
47  inline const ideep::tensor& Input(int index) {
48  return OperatorBase::template Input<ideep::tensor>(index);
49  }
50  inline ideep::tensor* Output(int index) {
51  return OperatorBase::template Output<ideep::tensor>(index);
52  }
53 
54  // The run function of Operator switches to the device, and then carries out
55  // the actual computation with RunOnDevice(). You should implement RunOnDevice
56  // instead of Run().
57  bool Run(int /* unused */ /*stream_id*/) final {
58  // Since IDEEP does not need to do SwithToDevice and
59  // FinishDeviceComputation,
60  // it is always just a re-route to RunOnDevice().
61  try {
62  StartAllObservers();
63  bool result = RunOnDevice();
64  StopAllObservers();
65  return result;
66  } catch (EnforceNotMet& err) {
67  err.AppendMessage(getErrorMsg());
68  throw;
69  } catch (ideep::error& e) {
70  LOG(ERROR) << "IDEEP error:" << e.message;
71  throw;
72  }
73  }
74 
75  // Waits for a previous event. Note that to properly wait and run
76  // asynchronously, WaitEvent, RunAsync and Record should all be executed
77  // on the same CPU thread.
78  void WaitEvent(const Event& ev, int /* unused */) final {
79  context_.WaitEvent(ev);
80  }
81 
82  void WaitEvents(const std::vector<const Event*>& events, int /* unused */)
83  final {
84  for (const auto& ev : events) {
85  context_.WaitEvent(*ev);
86  }
87  }
88 
89  void RecordEvent(const char* err_msg = nullptr) final {
90  if (event_) {
91  context_.Record(event_.get(), err_msg);
92  }
93  }
94 
95  virtual bool RunOnDevice() = 0;
96 
97  protected:
98  std::string getErrorMsg() {
99  if (has_debug_def()) {
100  return "Error from operator: " + ProtoDebugString(debug_def());
101  } else {
102  return "Error from operator: no op def";
103  }
104  }
105 
106  IDEEPContext context_;
107  StorageOrder order_;
108 };
109 
// Brings base-class members into an IDEEP operator subclass: everything from
// USE_OPERATOR_BASE_FUNCTIONS, plus IDEEPOperator's Input/Output accessors and
// its order_/context_ members (these hide the same-named OperatorBase ones).
#define USE_IDEEP_OPERATOR_FUNCTIONS() \
  USE_OPERATOR_BASE_FUNCTIONS;         \
  using IDEEPOperator::Input;          \
  using IDEEPOperator::Output;         \
  using IDEEPOperator::order_;         \
  using IDEEPOperator::context_;

// Defines, for class `name`, a constructor that forwards its
// (OperatorDef, Workspace*) arguments to IDEEPOperator, and an empty virtual
// destructor.
#define USE_SIMPLE_IDEEP_CTOR_DTOR(name)               \
  name(const OperatorDef& operator_def, Workspace* ws) \
      : IDEEPOperator(operator_def, ws) {}             \
  virtual ~name() {}
121 
122 } // namespace caffe2
The primary ATen error class.
Definition: Exception.h:27
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13