Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator.h
#ifndef CAFFE2_CORE_OPERATOR_H_
#define CAFFE2_CORE_OPERATOR_H_

#include <array>
#include <climits>
#include <cstddef>
#include <exception>
#include <typeinfo>
#include <vector>

#include "caffe2/core/blob.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/observer.h"
#include "caffe2/core/operator_gradient.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/core/registry.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

class OperatorBase;
typedef ObserverBase<OperatorBase> OperatorObserver;

class OperatorBase : public Observable<OperatorBase> {
 public:
  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
  virtual ~OperatorBase() noexcept {}

  /**
   * @brief Checks if the operator has an argument of the given name.
   */
  inline bool HasArgument(const string& name) const {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::HasArgument(*operator_def_, name);
  }

  // Functions that deal with arguments. Basically, these map an argument
  // name to the specific type of argument we are trying to access.
  template <typename T>
  inline T GetSingleArgument(const string& name, const T& default_value) const {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
        *operator_def_, name, default_value);
  }
  template <typename T>
  inline bool HasSingleArgumentOfType(const string& name) const {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
        *operator_def_, name);
  }
  template <typename T>
  inline vector<T> GetRepeatedArgument(
      const string& name,
      const vector<T>& default_value = {}) const {
    CAFFE_ENFORCE(operator_def_, "operator_def was null!");
    return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
        *operator_def_, name, default_value);
  }
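
  // For example (an illustrative sketch; the argument names "scale" and
  // "shape" are hypothetical, not part of this header):
  //
  //   float scale = GetSingleArgument<float>("scale", 1.0f);
  //   vector<int> shape = GetRepeatedArgument<int>("shape");
  //   bool has_order = HasSingleArgumentOfType<string>("order");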

  // Get the inputs and outputs as specific types.
  template <typename T>
  inline const T& Input(int idx) {
    DCHECK_LT(idx, inputs_.size());
    try {
      return inputs_.at(idx)->template Get<T>();
    } catch (::caffe2::EnforceNotMet& enf) {
      if (has_debug_def()) {
        enf.AppendMessage(".\nOffending Blob name: ");
        enf.AppendMessage(debug_def().input(idx));
        enf.AppendMessage(".\n");
      }
      throw enf;
    }
  }

  template <typename T>
  inline T* Output(int idx) {
    return outputs_.at(idx)->template GetMutable<T>();
  }

  inline const Blob& InputBlob(int idx) {
    return *inputs_.at(idx);
  }

  inline Blob* OutputBlob(int idx) {
    return outputs_.at(idx);
  }

  template <typename T>
  inline bool InputIsType(int idx) {
    return inputs_.at(idx)->template IsType<T>();
  }

  template <typename T>
  inline bool OutputIsType(int idx) {
    return outputs_.at(idx)->template IsType<T>();
  }

  inline int InputSize() { return inputs_.size(); }
  inline int OutputSize() { return outputs_.size(); }
  inline const vector<const Blob*>& Inputs() const { return inputs_; }
  inline const vector<Blob*>& Outputs() { return outputs_; }
  vector<TensorShape> InputTensorShapes();

  virtual void WaitEvent(const Event& ev, int stream_id = -1) {
    ev.Finish();
  }

  inline void Wait(const OperatorBase& other, int stream_id = -1) {
    WaitEvent(other.event(), stream_id);
  }

  virtual void WaitEvents(
      const std::vector<const Event*>& events,
      int stream_id = -1) {
    for (const auto& ev : events) {
      ev->Finish();
    }
  }

  virtual void Finish() {
    if (event_) {
      event_->Finish();
    }
  }

  virtual bool Run(int /* unused */ /*stream_id*/ = 0) {
    CAFFE_NOT_IMPLEMENTED;
  }

  virtual bool HasAsyncPart() const {
    return false;
  }

  virtual bool SupportsAsyncScheduling() const {
    return false;
  }

  // RunAsync, if implemented by the specific operators, will schedule the
  // computation on the corresponding context and record the event in its
  // event_ member object. If the specific operator does not support RunAsync,
  // it will simply fall back to synchronous execution.
  virtual bool RunAsync(int stream_id = 0) {
    return Run(stream_id);
  }

  virtual void AddRelatedBlobInfo(EnforceNotMet* err) {
    if (!has_debug_def()) {
      return;
    }

    bool found_input = false;
    if (err->caller() != nullptr) {
      for (int i = 0; i < inputs_.size(); i++) {
        if (inputs_[i]->GetRaw() == err->caller()) {
          found_input = true;
          err->AppendMessage(
              "\n** while accessing input: " + debug_def().input(i));
          break;
        }
      }
      for (int i = 0; i < outputs_.size(); i++) {
        if (outputs_[i]->GetRaw() == err->caller()) {
          if (found_input) {
            err->AppendMessage("\n OR ");
          }
          err->AppendMessage(
              "\n** while accessing output: " + debug_def().output(i));
          break;
        }
      }
    }
  }

  inline const OperatorDef& debug_def() const {
    CAFFE_ENFORCE(has_debug_def(), "operator_def was null!");
    return *operator_def_;
  }

  inline void set_debug_def(
      const std::shared_ptr<const OperatorDef>& operator_def) {
    operator_def_ = operator_def;
  }

  inline bool has_debug_def() const {
    return operator_def_ != nullptr;
  }

 public:
  void RecordLastFailedOpNetPosition() {
    if (net_position_ != kNoNetPositionSet) {
      VLOG(1) << "Operator with id " << net_position_ << " failed";
      operator_ws_->last_failed_op_net_position = net_position_;
    } else {
      VLOG(1) << "Failed operator doesn't have id set";
    }
  }

  int net_position() const {
    return net_position_;
  }

  void set_net_position(int idx) {
    net_position_ = idx;
  }

  const DeviceOption& device_option() const {
    return device_option_;
  }

  const Event& event() const {
    CAFFE_ENFORCE(event_, "Event is disabled");
    return *event_;
  }

  Event& event() {
    CAFFE_ENFORCE(event_, "Event is disabled");
    return *event_;
  }

  void ResetEvent() {
    if (event_) {
      event_->Reset();
    }
  }

  void DisableEvent() {
    event_ = nullptr;
  }

  bool IsEventDisabled() const {
    return !event_;
  }

  // Checks whether the stream is ready to execute new computation; used in
  // stream allocation optimization to skip streams that are currently busy.
  // Depends on the context and the operator's device; returns true by default.
  virtual bool IsStreamFree(int /* unused */) const {
    return true;
  }

  const std::string& type() {
    CAFFE_ENFORCE(operator_def_.get() != nullptr);
    return operator_def_->type();
  }

  void annotate_engine(const std::string& engine) {
    engine_ = engine;
  }

  const std::string& engine() const {
    return engine_;
  }

 public:
  static constexpr int kNoNetPositionSet = -1;

 private:
  Workspace* operator_ws_;
  std::shared_ptr<const OperatorDef> operator_def_;
  DeviceOption device_option_;
  std::string engine_;
  vector<const Blob*> inputs_;
  vector<Blob*> outputs_;

  int net_position_{kNoNetPositionSet};

 protected:
  virtual void RecordEvent(const char* err_msg = nullptr) {
    CAFFE_NOT_IMPLEMENTED;
  }

  // An event used by asynchronous execution.
  std::unique_ptr<Event> event_;

  DISABLE_COPY_AND_ASSIGN(OperatorBase);
};

// If your operator does not need any specialized constructor or destructor,
// you can simply use this to save two lines of code.
#define USE_SIMPLE_BASE_CTOR_DTOR(name)                 \
  name(const OperatorDef& operator_def, Workspace* ws)  \
      : OperatorBase(operator_def, ws) {}               \
  virtual ~name() noexcept {}

// OP_SINGLE_ARG provides a shorter way to initialize member variables in a
// class constructor's initializer list.
#define OP_SINGLE_ARG(type, name, variable, default) \
  variable(OperatorBase::GetSingleArgument<type>(name, (default)))
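
// For example, a constructor can read an "axis" argument in its initializer
// list (a sketch; MyOp and the "axis" argument are hypothetical):
//
//   MyOp(const OperatorDef& operator_def, Workspace* ws)
//       : Operator<Context>(operator_def, ws),
//         OP_SINGLE_ARG(int, "axis", axis_, 1) {}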

// INPUT_TAGS and OUTPUT_TAGS are optional features to name the indices of the
// operator's inputs and outputs, in order to avoid confusion. For example, for
// a fully-connected layer that has input, weight and bias, you can define its
// input tags as:
//     INPUT_TAGS(INPUT, WEIGHT, BIAS);
// And in the code, instead of doing
//     auto& weight = Input(1);
// you can now do
//     auto& weight = Input(WEIGHT);
// to make it more clear.
#define INPUT_TAGS(first_input, ...) \
  enum _InputTags { first_input = 0, __VA_ARGS__ }
#define OUTPUT_TAGS(first_input, ...) \
  enum _OutputTags { first_input = 0, __VA_ARGS__ }

// Operator is the class that you usually want to derive from if your operator
// will run on different devices. You should then implement the RunOnDevice()
// function.
template <class Context>
class Operator : public OperatorBase {
 public:
  explicit Operator(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws), context_(operator_def.device_option()) {
    // In the constructor, we switch to the device so that the child class
    // constructors will run on that device.
    context_.SwitchToDevice(0);
  }
  ~Operator() noexcept override {}

  inline const Tensor<Context>& Input(int idx) {
    return OperatorBase::template Input<Tensor<Context>>(idx);
  }
  inline Tensor<Context>* Output(int idx) {
    return OperatorBase::template Output<Tensor<Context>>(idx);
  }

  void WaitEvent(const Event& ev, int stream_id = -1) final {
    if (stream_id >= 0) {
      context_.SwitchToDevice(stream_id);
    }
    context_.WaitEvent(ev);
  }

  void WaitEvents(const std::vector<const Event*>& events, int stream_id = -1)
      final {
    if (stream_id >= 0) {
      context_.SwitchToDevice(stream_id);
    }
    for (const auto& ev : events) {
      context_.WaitEvent(*ev);
    }
  }

  // The Run function of Operator switches to the device, and then carries out
  // the actual computation with RunOnDevice(). You should implement
  // RunOnDevice() instead of Run().
  // Note: Run does not update the operator's event and can be used only with
  // non-async executors that do not rely on events.
  bool Run(int stream_id = 0) final {
    try {
      StartAllObservers();

      context_.SwitchToDevice(stream_id);
      bool result = RunOnDevice();
      if (!result) {
        this->RecordLastFailedOpNetPosition();
      }
      context_.FinishDeviceComputation(); // throws on error

      StopAllObservers();

      return result;
    } catch (EnforceNotMet& err) {
      if (has_debug_def()) {
        err.AppendMessage(
            "Error from operator: \n" + ProtoDebugString(debug_def()));
        AddRelatedBlobInfo(&err);
      }
      this->RecordLastFailedOpNetPosition();
      throw;
    } catch (...) {
      this->RecordLastFailedOpNetPosition();
      throw;
    }
  }

  bool RunAsync(int stream_id = 0) final {
    try {
      context_.SwitchToDevice(stream_id);
      auto result = RunOnDevice();
      if (result) {
        if (HasAsyncPart()) {
          RecordEvent();
        } else {
          // Manually set the CPU operator's event status to finished,
          // unless this is an async CPU operator.
          event().SetFinished();
        }
      } else {
        event().SetFinished(getErrorMsg().c_str());
        this->RecordLastFailedOpNetPosition();
      }
      return result;
    } catch (EnforceNotMet& err) {
      if (has_debug_def()) {
        err.AppendMessage(
            "Error from operator: \n" + ProtoDebugString(debug_def()));
        AddRelatedBlobInfo(&err);
      }
      event().SetFinished(err.what());
      this->RecordLastFailedOpNetPosition();
      throw;
    } catch (const std::exception& err) {
      event().SetFinished(err.what());
      this->RecordLastFailedOpNetPosition();
      throw;
    } catch (...) {
      event().SetFinished(getErrorMsg().c_str());
      this->RecordLastFailedOpNetPosition();
      throw;
    }
  }

  bool IsStreamFree(int stream_id) const override {
    return context_.IsStreamFree(device_option(), stream_id);
  }

  virtual bool RunOnDevice() = 0;

  // Returns whether the operator has an async on-device part.
  // CUDA operators by default have async parts; CPU operators by default
  // don't have async parts and are finished after the RunOnDevice call.
  // Events of operators that don't have async parts are automatically set
  // to the finished state by RunAsync.
  // Defaults to the value from the context (true for CUDA, false for CPU).
  // Override in case of async CPU operators.
  bool HasAsyncPart() const override {
    return context_.HasAsyncPartDefault();
  }

  // Returns whether the operator's RunOnDevice schedules an async on-device
  // part and can be run without waiting for the parent operator's async part
  // to be finished on the same device.
  // Note: when true, RunOnDevice must not access the content of the input
  // blobs as they might not be computed yet.
  // Note: when true, the operator's device needs to support async scheduling:
  //  - supports the concept of streams: async ops scheduled on the same
  //    stream are guaranteed to be executed in the order they were scheduled
  //  - provides non-blocking cross-device/cross-stream synchronization
  //    primitives
  //
  // By default, assume an op with an async part can be scheduled
  // asynchronously if the device supports async scheduling.
  bool SupportsAsyncScheduling() const override {
    return HasAsyncPart() && context_.SupportsAsyncScheduling();
  }

 protected:
  void RecordEvent(const char* err_msg = nullptr) final {
    if (event_) {
      context_.Record(event_.get(), err_msg);
    }
  }

  std::string getErrorMsg() {
    if (has_debug_def()) {
      return "Error from operator: " + ProtoDebugString(debug_def());
    } else {
      return "Error from operator: no op def";
    }
  }

  Context context_;
};

#define USE_OPERATOR_BASE_FUNCTIONS                                 \
  /* using override */ using OperatorBase::HasArgument;             \
  /* using override */ using OperatorBase::GetSingleArgument;       \
  /* using override */ using OperatorBase::HasSingleArgumentOfType; \
  /* using override */ using OperatorBase::GetRepeatedArgument;     \
  /* using override */ using OperatorBase::InputIsType;             \
  /* using override */ using OperatorBase::InputSize;               \
  /* using override */ using OperatorBase::OutputSize

#define USE_OPERATOR_FUNCTIONS(context)                    \
  USE_OPERATOR_BASE_FUNCTIONS;                             \
  /* using override */ using Operator<context>::context_;  \
  /* using override */ using Operator<context>::Input;     \
  /* using override */ using Operator<context>::InputBlob; \
  /* using override */ using Operator<context>::Output;    \
  /* using override */ using Operator<context>::OutputBlob

#define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)

#define USE_SIMPLE_CTOR_DTOR(name)                      \
  name(const OperatorDef& operator_def, Workspace* ws)  \
      : Operator<Context>(operator_def, ws) {}          \
  virtual ~name() noexcept {}
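
// A minimal sketch of a derived operator, assuming a hypothetical ScaleOp
// that multiplies its input by a float "scale" argument. It illustrates the
// RunOnDevice() contract and the convenience macros above; math::Scale is
// from caffe2/utils/math.h.
//
//   template <class Context>
//   class ScaleOp final : public Operator<Context> {
//    public:
//     USE_OPERATOR_CONTEXT_FUNCTIONS;
//     ScaleOp(const OperatorDef& operator_def, Workspace* ws)
//         : Operator<Context>(operator_def, ws),
//           OP_SINGLE_ARG(float, "scale", scale_, 1.0f) {}
//
//     bool RunOnDevice() override {
//       const auto& X = Input(0);
//       auto* Y = Output(0);
//       Y->ResizeLike(X);
//       math::Scale<float, Context>(
//           X.size(),
//           scale_,
//           X.template data<float>(),
//           Y->template mutable_data<float>(),
//           &context_);
//       return true;
//     }
//
//    private:
//     float scale_;
//   };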

// Helpers to implement runtime op polymorphism. Often it's convenient to make
// an op work on different input types (e.g. i32 vs i64 indices) or to
// special-case it for a particular input size (e.g. ScatterWeightedSum for a
// block size of 1 doesn't need to call Eigen).
//
// DispatchHelper provides compile-time generation of nested "if" statements,
// e.g. `DispatchHelper<FixedValues<1, 4>>::call(this, block_size);`
// unrolls into:
//   if (block_size == 1) {
//     return DoRunWithValue<1>();
//   } else if (block_size == 4) {
//     return DoRunWithValue<4>();
//   } else {
//     return DoRunWithValue<-1>();
//   }
//
// The DoRunWithValue implementation can use its template arguments to do "if"
// statements, or proxy to functions in math.h which often provide fixed-size
// implementations.
//
// Similarly, `DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this,
// Input(0))` provides branching based on the type of the first input and
// calls DoRunWithType.
//
// Note that the same instance of the Op class is used; the method, not the
// class, is templated. We might consider adding static class-level
// polymorphism later.
//
// The convenience macro USE_DISPATCH_HELPER is provided for declaring
// friendship in case DoRunWithValue or DoRunWithType are declared non-public.

#define USE_DISPATCH_HELPER                           \
  template <typename FirstArg, typename... ExtraArgs> \
  friend struct DispatchHelper

template <int... Values>
struct FixedValues {};

template <typename... Types>
struct TensorTypes {};
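
// A usage sketch: a hypothetical gather-style op dispatching on the type of
// its (assumed) INDICES input, with DoRunWithType as the templated method:
//
//   bool RunOnDevice() override {
//     return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
//         this, Input(INDICES));
//   }
//
//   template <typename TInd>
//   bool DoRunWithType() {
//     // ... implementation parameterized on the index type TInd ...
//     return true;
//   }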

// Special tag that can be listed in TensorTypes to denote that a special
// implementation in 'RunWithOtherType' needs to be called instead of failing.
// Obviously this needs to be the last item in the list, e.g.
//   TensorTypes<float, double, GenericTensorImplementation>
struct GenericTensorImplementation {};

// Same as TensorTypes but calls DoRunWithType2.
template <typename... Types>
struct TensorTypes2 {};

template <typename Sizes, typename... ExtraArgs>
struct DispatchHelper;

template <int FirstVal, int... Values, typename... ExtraArgs>
struct DispatchHelper<FixedValues<FirstVal, Values...>, ExtraArgs...> {
  template <typename Op>
  static bool call(Op* op, int value) {
    if (FirstVal == value) {
      return op->template DoRunWithValue<ExtraArgs..., FirstVal>();
    }
    return DispatchHelper<FixedValues<Values...>, ExtraArgs...>::template call<
        Op>(op, value);
  }
};

template <typename... ExtraArgs>
struct DispatchHelper<FixedValues<>, ExtraArgs...> {
  template <typename Op>
  static bool call(Op* op, TIndex /*size*/) {
    return op->template DoRunWithValue<ExtraArgs..., -1>();
  }
};

#define CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(                                 \
    TensorTypes, DoRunWithType, DoRunWithOtherType)                            \
  template <typename FirstType, typename... Types, typename... ExtraArgs>      \
  struct DispatchHelper<TensorTypes<FirstType, Types...>, ExtraArgs...> {      \
    template <typename Op>                                                     \
    static bool call(Op* op, const TypeMeta& meta) {                           \
      static_assert(                                                           \
          !std::is_same<GenericTensorImplementation, FirstType>::value,        \
          "GenericTensorImplementation must be the last in TensorTypes list"); \
      if (meta.Match<FirstType>()) {                                           \
        return op->template DoRunWithType<ExtraArgs..., FirstType>();          \
      }                                                                        \
      return DispatchHelper<TensorTypes<Types...>, ExtraArgs...>::             \
          template call<Op>(op, meta);                                         \
    }                                                                          \
    template <typename Op, typename Context>                                   \
    static bool call(Op* op, const Tensor<Context>& tensor) {                  \
      return call<Op>(op, tensor.meta());                                      \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Blob& blob) {                               \
      return call<Op>(op, blob.meta());                                        \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename... ExtraArgs>                                             \
  struct DispatchHelper<TensorTypes<>, ExtraArgs...> {                         \
    template <typename Op>                                                     \
    static bool call(Op* /* unused */, const TypeMeta& meta) {                 \
      CAFFE_THROW("Unsupported type of tensor: ", meta.name());                \
    }                                                                          \
    template <typename Op, typename Context>                                   \
    static bool call(Op* op, const Tensor<Context>& tensor) {                  \
      return call<Op>(op, tensor.meta());                                      \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Blob& blob) {                               \
      return call<Op>(op, blob.meta());                                        \
    }                                                                          \
  };                                                                           \
                                                                               \
  template <typename... ExtraArgs>                                             \
  struct DispatchHelper<                                                       \
      TensorTypes<GenericTensorImplementation>,                                \
      ExtraArgs...> {                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const TypeMeta&) {                                \
      return op->template DoRunWithOtherType<ExtraArgs...>();                  \
    }                                                                          \
    template <typename Op, typename Context>                                   \
    static bool call(Op* op, const Tensor<Context>& tensor) {                  \
      return call<Op>(op, tensor.meta());                                      \
    }                                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const Blob& blob) {                               \
      return call<Op>(op, blob.meta());                                        \
    }                                                                          \
  };
CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
    TensorTypes,
    DoRunWithType,
    DoRunWithOtherType)
CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
    TensorTypes2,
    DoRunWithType2,
    DoRunWithOtherType2)
#undef CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER

// The device type registry. This works in two phases:
// (1) gDeviceTypeRegistry() maps the device type values to the actual
//     operator registry function;
// (2) one can then call the operator registry function to further create the
//     operators.
typedef Registry<
    std::string,
    std::unique_ptr<OperatorBase>,
    const OperatorDef&,
    Workspace*>
    OperatorRegistry;
typedef Registry<
    std::string,
    std::unique_ptr<OperatorBase>,
    const OperatorDef&,
    Workspace*>* (*RegistryFunction)();
std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry();

struct DeviceTypeRegisterer {
  explicit DeviceTypeRegisterer(int32_t type, RegistryFunction func) {
    if (gDeviceTypeRegistry()->count(type)) {
      std::cerr << "Device type " << type
                << " registered twice. This should not happen. Did you have "
                   "duplicated numbers assigned to different devices?";
      std::exit(1);
    }
    // Calling the registry function to get the actual registry pointer.
    gDeviceTypeRegistry()->emplace(type, func());
  }
};

#define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \
  namespace {                                               \
  static DeviceTypeRegisterer CAFFE_ANONYMOUS_VARIABLE(     \
      DeviceType)(type, &registry_function);                \
  }
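
// For instance, the CPU registry is hooked up (in operator.cc) roughly as:
//
//   CAFFE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry);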

// The operator registry. Since we are not expecting a great number of devices,
// we will simply have an if-then type command and allocate the actual
// generation to device-specific registerers.
// Note that although we have CUDA and CUDNN here, the registerers themselves
// do not depend on specific cuda or cudnn libraries. This means that we will
// be able to compile it even when there is no cuda available - we simply do
// not link any cuda or cudnn operators.
CAFFE_DECLARE_REGISTRY(
    CPUOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);
#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
  CAFFE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR(name, ...)                           \
  extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();      \
  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \
    CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();                \
  }                                                                \
  CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR_STR(str_name, ...) \
  CAFFE_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \
  CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
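
// Typical usage in an operator's .cc file (a sketch; MyReluOp is
// hypothetical, and an OPERATOR_SCHEMA(MyRelu) declaration is expected to
// provide CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_MyRelu):
//
//   REGISTER_CPU_OPERATOR(MyRelu, MyReluOp<CPUContext>);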

CAFFE_DECLARE_REGISTRY(
    CUDAOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);
#define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \
  CAFFE_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CUDA_OPERATOR(name, ...)                           \
  extern void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();       \
  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CUDA##name() { \
    CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name();                 \
  }                                                                 \
  CAFFE_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_CUDA_OPERATOR_STR(str_name, ...) \
  CAFFE_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \
  CAFFE_REGISTER_CLASS(                                       \
      CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)

// Macros for cuDNN, since we use it often.
#define REGISTER_CUDNN_OPERATOR(name, ...) \
  REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__)

// StaticLinkingProtector is a helper class that ensures that the Caffe2
// library is linked correctly with whole archives (in the case of static
// linking). What happens is that when CreateOperator is called for the first
// time, it instantiates a StaticLinkingProtector object to check if the
// operator registry is empty. If it is empty, this means that we are not
// properly linking the library.
//
// You should not need to use this class.
struct StaticLinkingProtector {
  StaticLinkingProtector() {
    const int registered_ops = CPUOperatorRegistry()->Keys().size();
    // Note: this is a check failure instead of an exception, because if
    // the linking is wrong, Caffe2 won't be able to run properly anyway,
    // so it's better to fail loudly.
    // If Caffe2 is properly linked with whole archive, there should be more
    // than zero registered ops.
    if (registered_ops == 0) {
      LOG(FATAL)
          << "You might have made a build error: the Caffe2 library does not "
             "seem to be linked with the whole-static library option. To do "
             "so, use -Wl,-force_load (clang) or -Wl,--whole-archive (gcc) to "
             "link the Caffe2 library.";
    }
  }
};

// An exception that can be thrown by an operator constructor to signal that
// it does not support the given setting. This is usually used for specific
// engines that only implement a subset of the features required by the
// original operator schema.
// TODO(jiayq): make the exception message more feature-complete.
class UnsupportedOperatorFeature : public std::exception {
 public:
  UnsupportedOperatorFeature(const string& msg) : msg_(msg) {}
  const char* what() const noexcept override {
    return msg_.c_str();
  }

 private:
  string msg_;
};

// A helper macro that should ONLY be used in the operator constructor to check
// if needed features are met. If not, it throws the UnsupportedOperatorFeature
// exception with the given message.
#define OPERATOR_NEEDS_FEATURE(condition, ...)                           \
  if (!(condition)) {                                                    \
    throw UnsupportedOperatorFeature(::caffe2::MakeString(__VA_ARGS__)); \
  }
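
// For example, inside an engine-specific constructor (a sketch; order_ is a
// hypothetical member):
//
//   OPERATOR_NEEDS_FEATURE(
//       order_ == StorageOrder::NCHW, "This engine only supports NCHW.");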

// Creates an operator with the given operator definition.
// Throws on error and never returns nullptr.
unique_ptr<OperatorBase> CreateOperator(
    const OperatorDef& operator_def,
    Workspace* ws,
    int net_position = OperatorBase::kNoNetPositionSet);
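
// A minimal usage sketch, assuming a registered "Relu" operator and a CPU
// tensor blob "X" already present in the workspace:
//
//   Workspace ws;
//   ws.CreateBlob("X")->GetMutable<TensorCPU>();
//   OperatorDef def;
//   def.set_type("Relu");
//   def.add_input("X");
//   def.add_output("Y");
//   unique_ptr<OperatorBase> op = CreateOperator(def, &ws);
//   op->Run();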

const std::string OpRegistryKey(
    const std::string& op_type,
    const std::string& engine = "");

// Users can set the preferred engines as a list of engine names, in
// descending order of preference.
using EnginePrefType = std::vector<std::string>;
// {device_type -> {operator_name -> EnginePrefType}}
using PerOpEnginePrefType =
    CaffeMap<int, CaffeMap<std::string, EnginePrefType>>;
// {device_type -> EnginePrefType}
using GlobalEnginePrefType = CaffeMap<int, EnginePrefType>;
void SetPerOpEnginePref(const PerOpEnginePrefType& per_op_engine_pref);
void SetGlobalEnginePref(const GlobalEnginePrefType& global_engine_pref);
void SetEnginePref(
    const PerOpEnginePrefType& per_op_engine_pref,
    const GlobalEnginePrefType& global_engine_pref);
void SetOpEnginePref(
    const std::string& op_type,
    const CaffeMap<int, EnginePrefType>& op_pref);
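
// For example, to prefer the CUDNN engine for all Conv operators on CUDA
// (a sketch; the DeviceType enum values come from caffe2.pb.h):
//
//   SetOpEnginePref("Conv", {{DeviceType::CUDA, EnginePrefType{"CUDNN"}}});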

TensorShape GetTensorShapeOfBlob(const Blob* b);

TensorShapes InferBlobShapesAndTypesFromWorkspace(
    Workspace* ws,
    const vector<std::unique_ptr<NetDef>>& nets);

TensorShapes InferBlobShapesAndTypesFromMap(
    const CaffeMap<std::string, std::vector<TIndex>>& blob_dimensions,
    const vector<std::unique_ptr<NetDef>>& nets);

std::map<string, std::pair<DeviceOption, DeviceOption>> ValidateTensorDevices(
    OperatorBase& op,
    const OperatorDef& op_def);

} // namespace caffe2

#endif // CAFFE2_CORE_OPERATOR_H_