Caffe2 - C++ API
A deep-learning, cross-platform ML framework
mkl_operator.h
1 
17 #ifndef CAFFE2_UTILS_MKL_OPERATOR_H_
18 #define CAFFE2_UTILS_MKL_OPERATOR_H_
19 
20 #include "caffe2/core/operator.h"
21 #include "caffe2/mkl/utils/mkl_dnn_cppwrapper.h"
22 #include "caffe2/mkl/utils/mkl_memory.h"
23 #include "caffe2/proto/caffe2.pb.h"
24 
25 namespace caffe2 {
26 
27 CAFFE_DECLARE_REGISTRY(
28  MKLOperatorRegistry,
29  OperatorBase,
30  const OperatorDef&,
31  Workspace*);
32 #define REGISTER_MKL_OPERATOR_CREATOR(key, ...) \
33  CAFFE_REGISTER_CREATOR(MKLOperatorRegistry, key, __VA_ARGS__)
34 #define REGISTER_MKL_OPERATOR(name, ...) \
35  CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name, __VA_ARGS__)
36 #define REGISTER_MKL_OPERATOR_STR(str_name, ...) \
37  CAFFE_REGISTER_TYPED_CLASS(MKLOperatorRegistry, str_name, __VA_ARGS__)
38 
39 #define REGISTER_MKL_OPERATOR_WITH_ENGINE(name, engine, ...) \
40  CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
41 
42 namespace mkl {
43 // MKLOperator is the base scaffolding of the operators that uses MKLDNN. It
44 // provides a few operators that are useful to MKLDNN specific implementations.
45 template <typename T>
46 class MKLOperator : public OperatorBase {
47  public:
48  explicit MKLOperator(const OperatorDef& operator_def, Workspace* ws)
49  : OperatorBase(operator_def, ws),
50  context_(operator_def.device_option()) {}
51  virtual ~MKLOperator() {}
52 
53  inline const MKLMemory<T>& Input(int idx) {
54  return OperatorBase::template Input<MKLMemory<T>>(idx);
55  }
56  inline MKLMemory<T>* Output(int idx) {
57  return OperatorBase::template Output<MKLMemory<T>>(idx);
58  }
59 
60  // The run function of Operator switches to the device, and then carries out
61  // the actual computation with RunOnDevice(). You should implement RunOnDevice
62  // instead of Run().
63  bool Run(int /* unused */ /*stream_id*/) final {
64  // Since MKLDNN does not need to do SwithToDevice and
65  // FinishDeviceComputation,
66  // it is always just a re-route to RunOnDevice().
67  try {
68  return RunOnDevice();
69  } catch (EnforceNotMet& err) {
70  err.AppendMessage(getErrorMsg());
71  throw;
72  }
73  }
74 
75  // Waits for a previous event. Note that to properly wait and run
76  // asynchronously, WaitEvent, RunAsync and Record should all be executed
77  // on the same CPU thread.
78  void WaitEvent(const Event& ev, int /* unused */) final {
79  context_.WaitEvent(ev);
80  }
81 
82  void WaitEvents(const std::vector<const Event*>& events, int /* unused */)
83  final {
84  for (const auto& ev : events) {
85  context_.WaitEvent(*ev);
86  }
87  }
88 
89  void RecordEvent(const char* err_msg = nullptr) final {
90  if (event_) {
91  context_.Record(event_.get(), err_msg);
92  }
93  }
94 
95  virtual bool RunOnDevice() = 0;
96 
97  inline void ExecutePrimitive() {
98  MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
99  }
100 
101  protected:
102  std::string getErrorMsg() {
103  if (has_debug_def()) {
104  return "Error from operator: " + ProtoDebugString(debug_def());
105  } else {
106  return "Error from operator: no op def";
107  }
108  }
109 
110  MKLContext context_;
111  // The primitive used in the operator.
112  PrimitiveWrapper<T> primitive_;
113  // Size cache for all the input sizes.
114  vector<vector<TIndex>> input_size_cache_;
115  // An internal MKLMemory buffer. This is usually handy when we have a
116  // single output from the operator. If your operator has multiple outputs
117  // then you should allocate your own buffer.
118  MKLMemory<T> buffer_;
119  // The resources vector that we will need to use;
120  void* resources_[dnnResourceNumber];
121 };
122 } // namespace mkl
123 
// Re-exports MKLOperator<DType> members into a derived class template.
// MKLOperator<DType> is a dependent base, so unqualified lookup would
// otherwise not find these names. The macro expands to using-declarations
// (plus USE_OPERATOR_BASE_FUNCTIONS). Place it inside the derived class body
// and terminate it with a semicolon, because the final declaration has none.
// context_ is not among the re-exported members.
#define USE_MKLOPERATOR_FUNCTIONS(DType)                                 \
  USE_OPERATOR_BASE_FUNCTIONS;                                           \
  /* using override */ using MKLOperator<DType>::Input;                  \
  /* using override */ using MKLOperator<DType>::Output;                 \
  /* using override */ using MKLOperator<DType>::ExecutePrimitive;       \
  /* using override */ using MKLOperator<DType>::primitive_;             \
  /* using override */ using MKLOperator<DType>::input_size_cache_;      \
  /* using override */ using MKLOperator<DType>::buffer_;                \
  /* using override */ using MKLOperator<DType>::resources_
133 
// Emits the boilerplate constructor and destructor for an operator class
// `OpClass` derived from MKLOperator<DType>. The constructor forwards the
// OperatorDef and Workspace* to the base class. The destructor is virtual and
// has an empty body.
#define USE_SIMPLE_MKL_CTOR_DTOR(OpClass, DType)          \
  OpClass(const OperatorDef& operator_def, Workspace* ws) \
      : MKLOperator<DType>(operator_def, ws) {}           \
  virtual ~OpClass() {}
138 
139 } // namespace caffe2
140 
141 #endif // CAFFE2_UTILS_MKL_OPERATOR_H_
The MKL Context, which is largely the same as the CPUContext.
Definition: mkl_context.h:36
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
A wrapper around an opaque MKL internal resource that has certain layouts and conversion primitives s...
Definition: mkl_memory.h:153