Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator.h
1 #ifndef CAFFE2_CORE_OPERATOR_H_
2 #define CAFFE2_CORE_OPERATOR_H_
3 
4 #include <array>
5 #include <cfenv>
6 #include <climits>
7 #include <cstddef>
8 #include <exception>
9 #include <set>
10 #include <typeinfo>
11 #include <vector>
12 
13 #include "c10/macros/Macros.h"
14 #include "c10/util/Registry.h"
15 #include "caffe2/core/blob.h"
16 #include "caffe2/core/common.h"
17 #include "caffe2/core/net.h"
18 #include "caffe2/core/observer.h"
19 #include "caffe2/core/operator_gradient.h"
20 #include "caffe2/core/operator_schema.h"
21 #include "caffe2/core/tensor.h"
22 #include "caffe2/core/types.h"
23 #include "caffe2/core/workspace.h"
24 #include "caffe2/proto/caffe2_pb.h"
25 #include "caffe2/utils/proto_utils.h"
26 
27 #include <ATen/core/Tensor.h>
28 #include <ATen/core/function_schema.h>
29 #include <ATen/core/ivalue.h>
30 
31 C10_DECLARE_bool(caffe2_operator_throw_if_fp_exceptions);
32 
33 namespace caffe2 {
34 
35 class CAFFE2_API OperatorBase;
36 typedef ObserverBase<OperatorBase> OperatorObserver;
37 
38 class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
39  public:
40  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
41 
42  /*
43  * Notes: All output ivalues must be tensors. The input ivalue list must start
44  * with all tensors ("inputs" in caffe2 terminology),
45  * followed by non-tensors ("arguments" in caffe2 terminology).
46  * Alternatively, inputs can be one tensor list ivalue followed by non-tensors
47  * to represent operators with a variable number of inputs.
48  */
49  explicit OperatorBase(
50  const c10::FunctionSchema& schema,
51  std::vector<c10::IValue> inputs,
52  std::vector<at::Tensor> outputs);
53 
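 // Illustrative sketch (not part of the original header): one way a caller
 // might arrange the ivalues described in the note above for a new-style
 // operator. The tensors, the scalar argument, and SomeOperator are
 // placeholders, not names defined by this header.
 //
 //   std::vector<c10::IValue> inputs{
 //       c10::IValue(input_tensor),  // tensor "inputs" come first...
 //       c10::IValue(1.5)            // ...followed by non-tensor "arguments"
 //   };
 //   std::vector<at::Tensor> outputs{output_tensor};
 //   SomeOperator op(schema, std::move(inputs), std::move(outputs));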
54  virtual ~OperatorBase() noexcept {}
55 
59  bool isLegacyOperator() const {
60  return !fn_schema_;
61  }
62 
63  const c10::FunctionSchema& getFunctionSchema() const {
64  CAFFE_ENFORCE(!isLegacyOperator());
65  return *fn_schema_.get();
66  }
67 
70  inline bool HasArgument(const string& name) const {
71  if (isLegacyOperator()) {
72  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
73  return ArgumentHelper::HasArgument(*operator_def_, name);
74  }
75  return getFunctionSchema().argumentIndexWithName(name).has_value();
76  }
77 
78  // Functions that deal with arguments. Basically, this allows us to map an
79  // argument name to a specific type of argument that we are trying to access.
80  template <typename T>
81  inline T GetSingleArgument(const string& name, const T& default_value) const {
82  if (isLegacyOperator()) {
83  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
84  return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
85  *operator_def_, name, default_value);
86  }
87  auto index = getFunctionSchema().argumentIndexWithName(name);
88  CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
89  const auto& value = newstyle_inputs_[index.value()];
90  return value.template to<T>();
91  }
92 
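 // Illustrative sketch (not part of the original header): reading scalar
 // arguments inside an operator constructor; the argument names "alpha" and
 // "axis" are assumed for illustration only.
 //
 //   float alpha = this->template GetSingleArgument<float>("alpha", 1.0f);
 //   int axis = this->HasArgument("axis")
 //       ? this->template GetSingleArgument<int>("axis", 1)
 //       : 1;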
93  template <typename T>
94  inline bool HasSingleArgumentOfType(const string& name) const {
95  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
96  return ArgumentHelper::HasSingleArgumentOfType<OperatorDef, T>(
97  *operator_def_, name);
98  }
99  template <typename T>
100  inline vector<T> GetVectorFromIValueList(const c10::IValue& value) const {
101  return value.template to<vector<T>>();
102  }
103 
104  template <typename T>
105  inline vector<T> GetRepeatedArgument(
106  const string& name,
107  const vector<T>& default_value = {}) const {
108  if (isLegacyOperator()) {
109  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
110  return ArgumentHelper::GetRepeatedArgument<OperatorDef, T>(
111  *operator_def_, name, default_value);
112  }
113  auto index = getFunctionSchema().argumentIndexWithName(name);
114  CAFFE_ENFORCE(index.has_value(), "Couldn't get index for argument!", name);
115  const auto& value = newstyle_inputs_[index.value()];
116  return GetVectorFromIValueList<T>(value);
117  }
118 
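 // Illustrative sketch (not part of the original header): reading a repeated
 // (list) argument; the argument name "pads" is assumed for illustration only.
 //
 //   std::vector<int> pads =
 //       this->template GetRepeatedArgument<int>("pads", {0, 0, 0, 0});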
119  // Get the inputs and outputs as specific types.
120  template <typename T>
121  inline const T& Input(int idx) {
122  static_assert(
123  !std::is_same<T, Tensor>::value,
124  "You should use Input<Tensor>(int, DeviceType) for "
125  "Tensor.");
126  DCHECK_LT(idx, inputs_.size());
127  try {
128  return inputs_.at(idx)->template Get<T>();
129  } catch (::caffe2::EnforceNotMet& enf) {
130  if (has_debug_def()) {
131  enf.AppendMessage(".\nOffending Blob name: ");
132  enf.AppendMessage(debug_def().input(idx));
133  enf.AppendMessage(".\n");
134  }
135  throw enf;
136  }
137  }
138 
139  // TODO(jerryzh): Remove template
140  // and the type argument?
141  // This is to keep the API changes minimal and make refactoring
142  // a bit easier
143  template <typename T>
144  inline const T& Input(int idx, DeviceType type) {
145  if (isLegacyOperator()) {
146  static_assert(
147  std::is_same<T, Tensor>::value,
148  "Input(int, DeviceType) is only available for Tensor");
149  DCHECK_LT(idx, inputs_.size());
150  try {
151  // TODO(jerryzh): We'll need to check device type in Get<T>() later
152  // Get<T>() -> Get<T>(type)
153  const auto& tensor = inputs_.at(idx)->template Get<T>();
154  return tensor;
155  } catch (::caffe2::EnforceNotMet& enf) {
156  if (has_debug_def()) {
157  enf.AppendMessage(".\nOffending Blob name: ");
158  enf.AppendMessage(debug_def().input(idx));
159  enf.AppendMessage(".\n");
160  }
161  throw enf;
162  }
163  }
164  DCHECK_LT(0, newstyle_inputs_.size());
165  IValue ival;
166  if (newstyle_inputs_[0].isTensorList()) {
167  // if the first input is a tensor list, we get input tensors by indexing into that list.
168  // currently, this means that only tensors from that list are accessible as inputs.
169  // any hypothetical input tensors that come after the list are not accessible.
170  const auto& tensorList = newstyle_inputs_[0].toTensorListRef();
171  DCHECK_LT(idx, tensorList.size());
172  ival = tensorList[idx];
173  } else {
174  // if the first input is not a tensor list, we get input tensors by indexing into the inputs.
175  DCHECK_LT(idx, newstyle_inputs_.size());
176  ival = newstyle_inputs_[idx];
177  }
178  CAFFE_ENFORCE(
179  ival.isTensor(),
180  "Input(int, DeviceType) is only available for IValues that store Tensors");
181  Tensor tensor = caffe2::Tensor(ival.toTensor());
182  CAFFE_ENFORCE_EQ(tensor.GetDeviceType(), type);
183  input_tensors_[idx] = std::move(tensor);
184  return input_tensors_[idx];
185  }
186 
187  template <typename T>
188  inline T* Output(int idx) {
189  static_assert(
190  !std::is_same<T, Tensor>::value,
191  "You should use Output<Tensor>(int, DeviceType) for "
192  "Tensor.");
193  return outputs_.at(idx)->template GetMutable<T>();
194  }
195 
196  // TODO(jerryzh): Remove this template
197  template <typename T>
198  inline T* Output(int idx, DeviceType type) {
199  if (isLegacyOperator()) {
200  static_assert(
201  std::is_same<T, Tensor>::value,
202  "Output(int, DeviceType) is only available for Tensor");
203  // When you get a Tensor here it is not fully initialized
204  return BlobGetMutableTensor(outputs_.at(idx), type);
205  }
206  auto& output = newstyle_outputs_[idx];
207  Tensor tensor = caffe2::Tensor(output);
208  if (!tensor.defined() || tensor.GetDeviceType() != type) {
209  // Fix tensor type
210  tensor = Tensor(type);
211  output = at::Tensor(std::move(tensor.getIntrusivePtr()));
212  }
213  output_tensors_[idx] = caffe2::Tensor(output);
214  return &output_tensors_[idx];
215  }
216 
217  inline Tensor
218  XOutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
219  CAFFE_ENFORCE_WITH_CALLER(
220  options.device_opt() != c10::nullopt,
221  "device must be provided in option.");
222  if (isLegacyOperator()) {
223  return XBlobGetMutableTensor(outputs_.at(idx), dims, options);
224  }
225 
226  return OutputTensor(idx, dims, options)->UnsafeSharedInstance();
227  }
228 
229  void SetOutputTensor(int idx, Tensor tensor) {
230  if (!isLegacyOperator()) {
231  newstyle_outputs_[idx] = at::Tensor(tensor);
232 
233  // also update the tensor in the hack
234  output_tensors_[idx] = std::move(tensor);
235  } else {
236  // update the tensor in the workspace
237  BlobSetTensor(outputs_.at(idx), std::move(tensor));
238  }
239  }
240 
241  Tensor OutputTensorOrUndefined(int idx) {
242  if (isLegacyOperator()) {
243  return BlobGetTensorOrUndefined(*outputs_.at(idx));
244  }
245  return output_tensors_[idx].UnsafeSharedInstance();
246  }
247 
248  inline Tensor*
249  OutputTensor(int idx, at::IntArrayRef dims, at::TensorOptions options) {
250  if (isLegacyOperator()) {
251  CAFFE_ENFORCE_WITH_CALLER(
252  options.device_opt() != c10::nullopt,
253  "device must be provided in options.");
254  return BlobGetMutableTensor(outputs_.at(idx), dims, options);
255  }
256  auto& output = newstyle_outputs_[idx];
257  Tensor tensor =
258  GetSizedTensorWithOptions(caffe2::Tensor(output), dims, options);
259  // assign it back in case it changed
260  output = at::Tensor(std::move(tensor.getIntrusivePtr()));
261 
262  output_tensors_[idx] = caffe2::Tensor(output);
263  return &output_tensors_[idx];
264  }
265 
266  // Get output Tensor of the operator and CopyFrom the given Tensor
267  Tensor* OutputTensorCopyFrom(
268  int idx,
269  at::TensorOptions options,
270  const Tensor& src,
271  bool async = false) {
272  CAFFE_ENFORCE_WITH_CALLER(
273  options.device_opt() != c10::nullopt,
274  "device must be provided in options.");
 275  // Output Tensor will always have the same data type as `src`
276  if (!options.has_dtype()) {
277  options = options.dtype(src.dtype());
278  }
279  CAFFE_ENFORCE_WITH_CALLER(
280  options.dtype() == src.dtype(),
281  "We don't allow change of src data type in OutputTensorCopyFrom");
282  Tensor* t = OutputTensor(idx, src.sizes(), options);
283  t->CopyFrom(src, async);
284  return t;
285  }
286 
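 // Illustrative sketch (not part of the original header): allocating and
 // filling outputs from inside an operator; "X" stands for some input tensor.
 //
 //   // Allocate output 0 with X's shape, as float, on the CPU:
 //   Tensor* Y = OutputTensor(0, X.sizes(), at::dtype<float>().device(at::kCPU));
 //   // Make output 1 a (possibly async) copy of X:
 //   Tensor* Ycopy = OutputTensorCopyFrom(1, at::device(at::kCPU), X);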
287  Tensor* OutputTensorAlias(int idx, const Tensor& src) {
288  return BlobSetTensor(OutputBlob(idx),
289  src.Alias());
290  }
291 
292 
293  template <typename T>
294  inline T* Output(int idx, T* allocated) {
295  outputs_.at(idx)->Reset(allocated);
296  return allocated;
297  }
298 
299  inline const Blob& InputBlob(int idx) {
300  return *inputs_.at(idx);
301  }
302 
303  inline Blob* OutputBlob(int idx) {
304  return outputs_.at(idx);
305  }
306 
 307  // Check whether output j is an alias of input i by comparing Blob pointers.
 308  // Note that this does not check whether the two Blobs point to the same Tensor, or
 309  // whether the Tensor pointers point to the same TensorImpl, or whether the Storages alias.
310  inline bool IsInputOutputAlias(int i, int j) {
311  return inputs_.at(i) == outputs_.at(j);
312  }
313 
314  template <typename T>
315  inline bool InputIsType(int idx) {
316  static_assert(
317  !std::is_same<T, Tensor>::value,
318  "You should use InputIsTensorType(int, DeviceType) for "
319  "Tensor.");
320  return inputs_.at(idx)->template IsType<T>();
321  }
322 
323  inline bool InputIsTensorType(int idx, DeviceType device_type) {
324  return BlobIsTensorType(*inputs_.at(idx), device_type);
325  }
326 
327  template <typename T>
328  inline bool OutputIsType(int idx) {
329  static_assert(
330  !std::is_same<T, Tensor>::value,
331  "You should use OutputIsTensorType(int, DeviceType) for "
332  "Tensor.");
333  return outputs_.at(idx)->template IsType<T>();
334  }
335 
336  inline bool OutputIsTensorType(int idx, DeviceType type) {
337  return BlobIsTensorType(*outputs_.at(idx), type);
338  }
339 
340  inline int InputSize() const {
341  return input_size_;
342  }
343 
344  inline int OutputSize() const {
345  if (isLegacyOperator()) {
346  return outputs_.size();
347  }
348  return newstyle_outputs_.size();
349  }
350  inline const vector<const Blob*>& Inputs() const { return inputs_; }
351  inline const vector<Blob*>& Outputs() { return outputs_; }
352  vector<TensorShape> InputTensorShapes() const;
353 
354  virtual void WaitEvent(const Event& ev, int /*stream_id */ = -1) {
355  ev.Finish();
356  }
357 
358  inline void Wait(const OperatorBase& other, int stream_id = -1) {
359  if (!other.IsEventDisabled()) {
360  WaitEvent(other.event(), stream_id);
361  }
362  }
363 
364  virtual void WaitEvents(
365  const std::vector<const Event*>& events,
366  int /*stream_id*/ = -1) {
367  for (const auto& ev : events) {
368  ev->Finish();
369  }
370  }
371 
372  virtual void Finish() {
373  if (event_) {
374  event_->Finish();
375  }
376  }
377 
378  virtual bool Run(int /* unused */ /*stream_id*/ = 0) {
379  CAFFE_NOT_IMPLEMENTED;
380  }
381 
382  virtual bool HasAsyncPart() const {
383  return false;
384  }
385 
386  virtual bool SupportsAsyncScheduling() const {
387  return false;
388  }
389 
 390  // RunAsync, if implemented by the specific operators, will schedule the
391  // computation on the corresponding context and record the event in its
392  // event_ member object. If the specific operator does not support RunAsync,
393  // it will simply be synchronous as a fallback.
394  virtual bool RunAsync(int stream_id = 0) {
395  try {
396  auto result = Run(stream_id);
397  if (result) {
398  if (HasAsyncPart()) {
399  RecordEvent();
400  } else {
401  SetEventFinished();
402  }
403  } else {
404  SetEventFinished(getErrorMsg().c_str());
405  }
406  return result;
407  } catch (EnforceNotMet& err) {
408  SetEventFinishedWithException(err.what());
409  throw;
410  } catch (const std::exception& err) {
411  SetEventFinishedWithException(err.what());
412  throw;
413  } catch (...) {
414  SetEventFinishedWithException(getErrorMsg().c_str());
415  throw;
416  }
417  }
418 
419  virtual void AddRelatedBlobInfo(EnforceNotMet* err) {
420  if (!has_debug_def()) {
421  return;
422  }
423 
 424  bool found_input = false;
425  if (err->caller() != nullptr) {
426  for (size_t i = 0; i < inputs_.size(); i++) {
427  if (inputs_[i]->GetRaw() == err->caller()) {
428  found_input = true;
429  err->AppendMessage(
430  "\n** while accessing input: " + debug_def().input(i));
431  break;
432  }
433  }
434  for (size_t i = 0; i < outputs_.size(); i++) {
435  if (outputs_[i]->GetRaw() == err->caller()) {
436  if (found_input) {
437  err->AppendMessage("\n OR ");
438  }
439  err->AppendMessage(
440  "\n** while accessing output: " + debug_def().output(i));
441  break;
442  }
443  }
444  }
445  }
446 
447  inline const OperatorDef& debug_def() const {
448  CAFFE_ENFORCE(has_debug_def(), "operator_def was null!");
449  return *operator_def_;
450  }
451 
452  inline void set_debug_def(
453  const std::shared_ptr<const OperatorDef>& operator_def) {
454  operator_def_ = operator_def;
455  }
456 
457  inline bool has_debug_def() const {
458  return operator_def_ != nullptr;
459  }
460 
461  public:
462  void RecordLastFailedOpNetPosition() {
463  if (net_position_ != kNoNetPositionSet) {
464  VLOG(1) << "Operator with id " << net_position_ << " failed";
465  operator_ws_->last_failed_op_net_position = net_position_;
466  } else {
467  VLOG(1) << "Failed operator doesn't have id set";
468  }
469  }
470 
471  int net_position() const {
472  return net_position_;
473  }
474 
475  void set_net_position(int idx) {
476  net_position_ = idx;
477  }
478 
479  const DeviceOption& device_option() const {
480  return device_option_;
481  }
482 
483  const Event& event() const {
484  CAFFE_ENFORCE(event_, "Event is disabled");
485  return *event_;
486  }
487 
488  Event& event() {
489  CAFFE_ENFORCE(event_, "Event is disabled");
490  return *event_;
491  }
492 
493  void ResetEvent() {
494  if (event_) {
495  event_->Reset();
496  }
497  }
498 
499  void DisableEvent() {
500  event_ = nullptr;
501  }
502 
503  bool IsEventDisabled() const {
504  return !event_;
505  }
506 
507  // Internal API invoked by observers. Normal callers shouldn't invoke it.
508  virtual void SyncDeviceBarrierForObservers() {
509  CAFFE_NOT_IMPLEMENTED;
510  }
511 
 512  // Checks whether a stream is ready to execute new computation;
 513  // used in stream allocation optimization to skip streams that are currently
 514  // busy. Depends on the context and the operator's device; returns true by default.
515  virtual bool IsStreamFree(int /* unused */) const {
516  return true;
517  }
518 
519  const std::string& type() const {
520  return type_;
521  }
522 
523  void annotate_engine(const std::string& engine) {
524  engine_ = engine;
525  }
526 
527  const std::string& engine() const {
528  return engine_;
529  }
530 
531  void SetExecutorHelper(ExecutorHelper* helper) {
532  helper_ = helper;
533  }
534 
535  ExecutorHelper* GetExecutorHelper() const {
536  return helper_;
537  }
538 
539  std::vector<at::Tensor> move_newstyle_outputs() && {
540  return std::move(newstyle_outputs_);
541  }
542 
543  public:
544  static const int kNoNetPositionSet = -1;
545 
546  private:
547  Workspace* operator_ws_;
548  std::shared_ptr<const OperatorDef> operator_def_;
549  DeviceOption device_option_;
550  std::string engine_;
551  std::string type_;
552  vector<const Blob*> inputs_;
553  vector<Blob*> outputs_;
 554  // Preferably use c10::optional, but nvcc doesn't work
555  std::unique_ptr<const c10::FunctionSchema> fn_schema_ = nullptr;
556  vector<c10::IValue> newstyle_inputs_;
557  vector<at::Tensor> newstyle_outputs_;
558  // HACK
559  // We preserve the fact that Output() returns Tensor*
560  // by storing Tensor in a vector owned by the
561  // operator.
562  vector<caffe2::Tensor> input_tensors_;
563  vector<caffe2::Tensor> output_tensors_;
564 
565  int input_size_;
566 
567  int net_position_{kNoNetPositionSet};
568 
569  ExecutorHelper* helper_ = nullptr;
570 
571  protected:
572  virtual void RecordEvent(const char* /*err_msg*/ = nullptr) {
573  CAFFE_NOT_IMPLEMENTED;
574  }
575 
576  void SetEventFinished(const char* err_msg = nullptr) {
577  if (event_) {
578  event_->SetFinished(err_msg);
579  }
580  }
581 
582  void SetEventFinishedWithException(const char* err_msg = nullptr) {
583  if (event_) {
584  event_->SetFinishedWithException(err_msg);
585  }
586  }
587 
588  std::string getErrorMsg() {
589  if (has_debug_def()) {
590  return "Error from operator: " + ProtoDebugString(debug_def());
591  } else {
592  return "Error from operator: no op def";
593  }
594  }
595 
596  // An event used by asynchronous execution.
597  std::unique_ptr<Event> event_;
598 
599  C10_DISABLE_COPY_AND_ASSIGN(OperatorBase);
600 };
601 
602 template <>
603 inline NetDef OperatorBase::GetSingleArgument<NetDef>(
604  const std::string& name,
605  const NetDef& default_value) const {
606  if (isLegacyOperator()) {
607  CAFFE_ENFORCE(operator_def_, "operator_def was null!");
608  return ArgumentHelper::GetSingleArgument<OperatorDef, NetDef>(
609  *operator_def_, name, default_value);
610  }
611  CAFFE_THROW("Cannot get NetDefs from IValue");
612  return NetDef();
613 }
614 
615 template <>
616 inline vector<int> OperatorBase::GetVectorFromIValueList<int>(
617  const c10::IValue& value) const {
618  const auto& vs = value.toIntListRef();
619  vector<int> out;
620  out.reserve(vs.size());
621  for (const auto& v : vs) {
622  out.emplace_back(v);
623  }
624  return out;
625 }
626 
627 template <>
628 inline vector<float> OperatorBase::GetVectorFromIValueList<float>(
629  const c10::IValue& value) const {
630  const auto& vs = value.toDoubleListRef();
631  vector<float> out;
632  out.reserve(vs.size());
633  for (const auto& v : vs) {
634  out.emplace_back(v);
635  }
636  return out;
637 }
638 
639 template <>
640 inline vector<string> OperatorBase::GetVectorFromIValueList<string>(
641  const c10::IValue& value) const {
642  CAFFE_THROW("Cannot extract vector<string> from ivalue.");
643  vector<string> out;
644  return out;
645 }
646 
 647 // OP_SINGLE_ARG provides a shorter syntax for initializing member variables
 648 // in the class constructors.
649 // This is a workaround for CUDA9.2 and GCC7
650 #if defined(CUDART_VERSION) && CUDART_VERSION >= 9020 && __GNUC__ >= 7
651 #define OP_SINGLE_ARG(type, name, variable, default) \
652  variable(this->template GetSingleArgument<type>(name, (default)))
653 #else
654 #define OP_SINGLE_ARG(type, name, variable, default) \
655  variable(OperatorBase::GetSingleArgument<type>(name, (default)))
656 #endif
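 // Illustrative sketch (not part of the original header): OP_SINGLE_ARG used in
 // a constructor initializer list. The operator class and the "epsilon"
 // argument are assumed; Operator<Context> is declared further below.
 //
 //   template <class Context>
 //   class MyNormOp final : public Operator<Context> {
 //    public:
 //     template <class... Args>
 //     explicit MyNormOp(Args&&... args)
 //         : Operator<Context>(std::forward<Args>(args)...),
 //           OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5f) {}
 //
 //    private:
 //     float epsilon_;
 //   };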
657 
658 // INPUT_TAGS and OUTPUT_TAGS are optional features to name the indices of the
659 // operator's inputs and outputs, in order to avoid confusion. For example, for
 660 // a fully-connected layer that has input, weight and bias, you can define its
661 // input tags as:
662 // INPUT_TAGS(INPUT, WEIGHT, BIAS);
663 // And in the code, instead of doing
664 // auto& weight = Input(1);
665 // you can now do
666 // auto& weight = Input(WEIGHT);
667 // to make it more clear.
668 #define INPUT_TAGS(first_input, ...) \
669  enum _InputTags { first_input = 0, __VA_ARGS__ }
670 #define OUTPUT_TAGS(first_input, ...) \
671  enum _OutputTags { first_input = 0, __VA_ARGS__ }
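 // Illustrative sketch (not part of the original header), extending the example
 // above for a hypothetical fully-connected operator:
 //
 //   // inside the operator class body:
 //   INPUT_TAGS(INPUT, WEIGHT, BIAS);
 //   OUTPUT_TAGS(OUTPUT);
 //
 //   // then, inside RunOnDevice():
 //   const auto& W = Input(WEIGHT);  // equivalent to Input(1)
 //   auto* Y = Output(OUTPUT);       // equivalent to Output(0)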
672 
 673 // Operator is the class that you usually want to derive from, if your operator will
674 // run on different devices. You should then implement the RunOnDevice()
675 // function.
676 template <class Context>
677 class Operator : public OperatorBase {
678  public:
679  explicit Operator(const OperatorDef& operator_def, Workspace* ws)
680  : OperatorBase(operator_def, ws), context_(operator_def.device_option()) {
681  // In the constructor, we switch to the device so that the child class
682  // constructors will run on that device.
683  context_.SwitchToDevice();
684  }
685  explicit Operator(
686  const c10::FunctionSchema& fn_schema,
687  std::vector<c10::IValue> inputs,
688  std::vector<at::Tensor> outputs)
689  : OperatorBase(fn_schema, std::move(inputs), std::move(outputs)) {
690  // In the constructor, we switch to the device so that the child class
691  // constructors will run on that device.
692  context_.SwitchToDevice();
693  }
694  ~Operator() noexcept override {}
695 
702  inline const Tensor& Input(
703  int idx,
704  DeviceType type = Context::GetDeviceType()) {
705  return OperatorBase::template Input<Tensor>(idx, type);
706  }
707 
 711  Tensor XOutput(int idx, at::IntArrayRef dims, at::TensorOptions options) {
 712  // We'll default device to the device of the current Operator Context
713  if (options.device_opt() == c10::nullopt) {
714  return OperatorBase::XOutputTensor(
715  idx, dims, options.device(context_.device()));
716  }
717  return OperatorBase::XOutputTensor(idx, dims, options);
718  }
719 
768  Tensor* Output(int idx, at::IntArrayRef dims, at::TensorOptions options) {
769  // We'll default device to the device of the current Operator Context
770  if (options.device_opt() == c10::nullopt) {
771  return OperatorBase::OutputTensor(
772  idx, dims, options.device(context_.device()));
773  }
774  return OperatorBase::OutputTensor(idx, dims, options);
775  }
776 
779  inline Tensor* Output(int idx, DeviceType type = Context::GetDeviceType()) {
780  return OperatorBase::template Output<Tensor>(idx, type);
781  }
782 
790  Tensor* OutputTensorCopyFrom(
791  int idx,
792  at::TensorOptions options,
793  const Tensor& src,
794  bool async = false) {
795  if (options.device_opt() == c10::nullopt) {
796  return OperatorBase::OutputTensorCopyFrom(
797  idx, options.device(context_.device()), src, async);
798  }
799  return OperatorBase::OutputTensorCopyFrom(idx, options, src, async);
800  }
801 
802  void WaitEvent(const Event& ev, int stream_id = -1) final {
803  if (stream_id >= 0) {
804  context_.SwitchToDevice(stream_id);
805  }
806  context_.WaitEvent(ev);
807  }
808 
809  void WaitEvents(const std::vector<const Event*>& events, int stream_id = -1)
810  final {
811  if (stream_id >= 0) {
812  context_.SwitchToDevice(stream_id);
813  }
814  for (const auto& ev : events) {
815  context_.WaitEvent(*ev);
816  }
817  }
818 
 819  // The Run() function of Operator switches to the device, and then carries out
820  // the actual computation with RunOnDevice(). You should implement RunOnDevice
821  // instead of Run().
822  // Note: Run does not update operator's event and can be used only with
823  // non-async executors that do not rely on events
824  bool Run(int stream_id = 0) final {
825  try {
826  StartAllObservers();
827 
828  context_.SwitchToDevice(stream_id);
829 
830  if (FLAGS_caffe2_operator_throw_if_fp_exceptions) {
831  std::feclearexcept(FE_ALL_EXCEPT);
832  }
833  bool result = RunOnDevice();
834  if (FLAGS_caffe2_operator_throw_if_fp_exceptions) {
835  CAFFE_ENFORCE(
836  !std::fetestexcept(FE_DIVBYZERO),
837  "Division by zero floating point exception (FE_DIVBYZERO) reported.");
838  CAFFE_ENFORCE(
839  !std::fetestexcept(FE_INVALID),
840  "Invalid floating point exception (FE_INVALID) reported.");
841  CAFFE_ENFORCE(
842  !std::fetestexcept(FE_OVERFLOW),
843  "Overflow floating point exception (FE_OVERFLOW) reported.");
844  }
845  if (!result) {
846  this->RecordLastFailedOpNetPosition();
847  }
848  context_.FinishDeviceComputation(); // throws on error
849 
850  StopAllObservers();
851 
852  return result;
853  } catch (EnforceNotMet& err) {
854  if (has_debug_def()) {
855  err.AppendMessage(
856  "Error from operator: \n" + ProtoDebugString(debug_def()));
857  AddRelatedBlobInfo(&err);
858  }
859  this->RecordLastFailedOpNetPosition();
860  StopAllObservers();
861  throw;
862  } catch (...) {
863  this->RecordLastFailedOpNetPosition();
864  StopAllObservers();
865  throw;
866  }
867  }
868 
869  bool RunAsync(int stream_id = 0) final {
870  try {
871  StartAllObservers();
872 
873  context_.SwitchToDevice(stream_id);
874  auto result = RunOnDevice();
875  if (result) {
876  if (HasAsyncPart()) {
877  RecordEvent();
878  } else {
879  // Manually set CPU operator's event status to finished,
880  // unless this is an async CPU operator
881  SetEventFinished();
882  }
883  } else {
884  SetEventFinished(getErrorMsg().c_str());
885  this->RecordLastFailedOpNetPosition();
886  }
887 
888  StopAllObservers();
889 
890  return result;
891  } catch (EnforceNotMet& err) {
892  if (has_debug_def()) {
893  err.AppendMessage(
894  "Error from operator: \n" + ProtoDebugString(debug_def()));
895  AddRelatedBlobInfo(&err);
896  }
897  SetEventFinishedWithException(err.what());
898  this->RecordLastFailedOpNetPosition();
899  StopAllObservers();
900  throw;
901  } catch (const std::exception& err) {
902  SetEventFinishedWithException(err.what());
903  this->RecordLastFailedOpNetPosition();
904  StopAllObservers();
905  throw;
906  } catch (...) {
907  SetEventFinishedWithException(getErrorMsg().c_str());
908  this->RecordLastFailedOpNetPosition();
909  StopAllObservers();
910  throw;
911  }
912  }
913 
914  bool IsStreamFree(int stream_id) const override {
915  return context_.IsStreamFree(device_option(), stream_id);
916  }
917 
918  virtual bool RunOnDevice() = 0;
919 
920  // Returns whether operator has async on device part.
921  // CUDA operators by default have async parts, CPU operators by default
922  // don't have async parts and are finished after RunOnDevice call.
923  // Events of operators that don't have async parts are automatically set
924  // to finished state by RunAsync.
925  // Defaulting to the value from context (true for CUDA, false for CPU).
 927  // Override in case of async CPU operators.
927  // Async CPU operators are expected to catch all exceptions in async parts
928  // and set Event to finished/failed state with Event::SetFinished or
929  // SetFinishedWithException call.
930  bool HasAsyncPart() const override {
931  return context_.HasAsyncPartDefault();
932  }
933 
934  // Returns whether operator's RunOnDevice schedules async on device part and
935  // can be run without waiting for parent operator's async part to be finished
936  // on the same device.
937  // Note: when true, RunOnDevice must not access the content of the input blobs
938  // as they might not be computed yet
939  // Note: when true, operator's device needs to support async scheduling:
940  // - supports concept of streams: async ops scheduled on the same stream are
941  // guaranteed to be executed in the same order they were scheduled
942  // - provides non-blocking cross device/cross stream synchronization
943  // primitives
944  //
945  // By default, assuming an op with an async part can be scheduled
946  // asynchronously if device supports async scheduling
947  bool SupportsAsyncScheduling() const override {
948  return HasAsyncPart() && context_.SupportsAsyncScheduling();
949  }
950 
951  void SyncDeviceBarrierForObservers() override {
952  context_.FinishDeviceComputation();
953  }
954 
955  const Context* getContext() const {
956  return &context_;
957  }
958  Context* getContext() {
959  return &context_;
960  }
961 
962  protected:
963  void RecordEvent(const char* err_msg = nullptr) final {
964  if (event_) {
965  context_.Record(event_.get(), err_msg);
966  }
967  }
968 
969  Context context_;
970 };
971 
972 #define USE_OPERATOR_BASE_FUNCTIONS \
973  /* using override */ using OperatorBase::HasArgument; \
974  /* using override */ using OperatorBase::GetSingleArgument; \
975  /* using override */ using OperatorBase::HasSingleArgumentOfType; \
976  /* using override */ using OperatorBase::GetRepeatedArgument; \
977  /* using override */ using OperatorBase::InputIsType; \
978  /* using override */ using OperatorBase::InputSize; \
979  /* using override */ using OperatorBase::Output; \
980  /* using override */ using OperatorBase::Input; \
981  /* using override */ using OperatorBase::OutputSize; \
982  /* using override */ using OperatorBase::IsInputOutputAlias; \
983  /* using override */ using OperatorBase::OutputTensorAlias
984 
985 #define USE_OPERATOR_FUNCTIONS(context) \
986  USE_OPERATOR_BASE_FUNCTIONS; \
987  /* using override */ using Operator<context>::context_; \
988  /* using override */ using Operator<context>::Input; \
989  /* using override */ using Operator<context>::InputBlob; \
990  /* using override */ using Operator<context>::Output; \
991  /* using override */ using Operator<context>::OutputBlob; \
992  /* using override */ using Operator<context>::OutputTensorCopyFrom
993 
994 #define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)
995 
996 #define USE_SIMPLE_CTOR_DTOR(name) \
997  template<class... Args> explicit name(Args&&... args) \
998  : Operator<Context>(std::forward<Args>(args)...) {} \
999  virtual ~name() noexcept {}
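 // Illustrative sketch (not part of the original header): a minimal CPU
 // operator put together from the pieces above. The class name ExampleScaleOp,
 // the "alpha" argument, and CPUContext (from caffe2/core/context.h) are
 // assumptions for illustration, not definitions made by this header.
 //
 //   class ExampleScaleOp final : public Operator<CPUContext> {
 //    public:
 //     USE_OPERATOR_FUNCTIONS(CPUContext);
 //
 //     template <class... Args>
 //     explicit ExampleScaleOp(Args&&... args)
 //         : Operator<CPUContext>(std::forward<Args>(args)...),
 //           alpha_(this->GetSingleArgument<float>("alpha", 1.0f)) {}
 //
 //     bool RunOnDevice() override {
 //       const auto& X = Input(0);   // caffe2::Tensor on the CPU
 //       auto* Y = Output(0, X.sizes(), at::dtype<float>());
 //       const float* x = X.data<float>();
 //       float* y = Y->mutable_data<float>();
 //       for (int64_t i = 0; i < X.numel(); ++i) {
 //         y[i] = alpha_ * x[i];
 //       }
 //       return true;
 //     }
 //
 //    private:
 //     float alpha_;
 //   };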
1000 
1001 // Helpers to implement runtime op polymorphism. Often it's convenient to make
1002 // an op work on different input types (e.g. i32 vs i64 indices) or special-case
1003 // it for particular input size (e.g. ScatterWeightedSum for block size of 1
1004 // doesn't need to call Eigen).
1005 //
1006 // DispatchHelper provides compile-time generation of nested "if" statements,
1007 // e.g. `DispatchHelper<FixedValues<1, 4>>::call(this, block_size);`
1008 // unrolls into:
1009 // if (block_size == 1) {
1010 // return DoRunWithValue<1>();
 1011 // } else if (block_size == 4) {
1012 // return DoRunWithValue<4>();
1013 // } else {
1014 // return DoRunWithValue<-1>();
 1015 // }
1016 //
1017 // DoRunWithValue implementation can use template arguments to do "if"
1018 // statements
1019 // or proxy to functions in math.h which often provide fixed size
1020 // implementation.
1021 //
 1022 // Similarly, `DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(0))`
 1023 // provides branching based on the type of the first input and calls DoRunWithType.
1024 //
 1025 // Note that the same instance of the Op class is used, since it is the method,
 1026 // not the class, that is templated. We might consider adding static class-level polymorphism later.
1027 //
1028 // Convenient macro USE_DISPATCH_HELPER is provided for declaring friendship in
1029 // case DoRunWithValue or DoRunWithType are declared non-public.
1030 
1031 #define USE_DISPATCH_HELPER \
1032  template <typename FirstArg, typename... ExtraArgs> \
1033  friend struct DispatchHelper
1034 
1035 template <int... Values>
1036 struct FixedValues {};
1037 
1038 template <typename... Types>
1039 struct TensorTypes {};
1040 
1041 // Special tag that can be listed in TensorTypes to denote that a special
 1042 // implementation in 'DoRunWithOtherType' needs to be called instead of failing.
 1043 // Obviously this needs to be the last item in the list, e.g.
 1044 // TensorTypes<float, double, GenericTensorImplementation>
 1045 struct GenericTensorImplementation {};
 1046 
1047 // Same as TensorTypes but call DoRunWithType2
1048 template <typename... Types>
1049 struct TensorTypes2 {};
1050 
 1051 template <typename Sizes, typename... ExtraArgs>
 1052 struct DispatchHelper;
 1053 
1054 template <int FirstVal, int... Values, typename... ExtraArgs>
1055 struct DispatchHelper<FixedValues<FirstVal, Values...>, ExtraArgs...> {
1056  template <typename Op>
1057  static bool call(Op* op, int value) {
1058  if (FirstVal == value) {
1059  return op->template DoRunWithValue<ExtraArgs..., FirstVal>();
1060  }
1061  return DispatchHelper<FixedValues<Values...>, ExtraArgs...>::template call<
1062  Op>(op, value);
1063  }
1064 };
1065 
1066 template <typename... ExtraArgs>
1067 struct DispatchHelper<FixedValues<>, ExtraArgs...> {
1068  template <typename Op>
1069  static bool call(Op* op, int64_t /*size*/) {
1070  return op->template DoRunWithValue<ExtraArgs..., -1>();
1071  }
1072 };
1073 
1074 #define C10_DEFINE_TENSOR_TYPES_DISPATCHER( \
1075  TensorTypes, DoRunWithType, DoRunWithOtherType) \
1076  template <typename FirstType, typename... Types, typename... ExtraArgs> \
1077  struct DispatchHelper<TensorTypes<FirstType, Types...>, ExtraArgs...> { \
1078  template <typename Op> \
1079  static bool call(Op* op, const TypeMeta& meta) { \
1080  static_assert( \
1081  !std::is_same<GenericTensorImplementation, FirstType>::value, \
1082  "GenericTensorImplementation must be the last in TensorTypes list"); \
1083  if (meta.Match<FirstType>()) { \
1084  return op->template DoRunWithType<ExtraArgs..., FirstType>(); \
1085  } \
1086  return DispatchHelper<TensorTypes<Types...>, ExtraArgs...>:: \
1087  template call<Op>(op, meta); \
1088  } \
1089  template <typename Op> \
1090  static bool call(Op* op, const Tensor& tensor) { \
1091  return call<Op>(op, tensor.dtype()); \
1092  } \
1093  template <typename Op> \
1094  static bool call(Op* op, const Blob& blob) { \
1095  return call<Op>(op, blob.meta()); \
1096  } \
1097  }; \
1098  \
1099  template <typename... ExtraArgs> \
1100  struct DispatchHelper<TensorTypes<>, ExtraArgs...> { \
1101  template <typename Op> \
1102  static bool call(Op* /* unused */, const TypeMeta& meta) { \
1103  CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \
1104  } \
1105  template <typename Op> \
1106  static bool call(Op* op, const Tensor& tensor) { \
1107  return call<Op>(op, tensor.dtype()); \
1108  } \
1109  template <typename Op> \
1110  static bool call(Op* op, const Blob& blob) { \
1111  return call<Op>(op, blob.meta()); \
1112  } \
1113  }; \
1114  \
1115  template <typename... ExtraArgs> \
1116  struct DispatchHelper< \
1117  TensorTypes<GenericTensorImplementation>, \
1118  ExtraArgs...> { \
1119  template <typename Op> \
1120  static bool call(Op* op, const TypeMeta&) { \
1121  return op->template DoRunWithOtherType<ExtraArgs...>(); \
1122  } \
1123  template <typename Op> \
1124  static bool call(Op* op, const Tensor& tensor) { \
1125  return call<Op>(op, tensor.dtype()); \
1126  } \
1127  template <typename Op> \
1128  static bool call(Op* op, const Blob& blob) { \
1129  return call<Op>(op, blob.meta()); \
1130  } \
1131  };
1132 C10_DEFINE_TENSOR_TYPES_DISPATCHER(
1133  TensorTypes,
1134  DoRunWithType,
1135  DoRunWithOtherType)
1136 C10_DEFINE_TENSOR_TYPES_DISPATCHER(
1137  TensorTypes2,
1138  DoRunWithType2,
1139  DoRunWithOtherType2)
1140 #undef C10_DEFINE_TENSOR_TYPES_DISPATCHER
1141 
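 // Illustrative sketch (not part of the original header): type dispatch as
 // described in the DispatchHelper comment above. The operator class and its
 // inputs are assumptions for illustration.
 //
 //   template <class Context>
 //   class MyGatherOp final : public Operator<Context> {
 //    public:
 //     USE_OPERATOR_CONTEXT_FUNCTIONS;
 //     USE_SIMPLE_CTOR_DTOR(MyGatherOp);
 //     USE_DISPATCH_HELPER;
 //
 //     bool RunOnDevice() override {
 //       // Branch on the dtype of input 1 (e.g. int32 vs int64 indices).
 //       return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(1));
 //     }
 //
 //    private:
 //     template <typename TInd>
 //     bool DoRunWithType() {
 //       // TInd is int32_t or int64_t here.
 //       return true;
 //     }
 //   };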
1142 // The device type registry. This works in two phases:
 1143 // (1) gDeviceTypeRegistry() maps the device type values to the actual operator
1144 // registry function.
1145 // (2) Then, one can call the operator registry function to further create the
1146 // operators.
1147 typedef c10::Registry<
1148  std::string,
1149  std::unique_ptr<OperatorBase>,
1150  const OperatorDef&,
1151  Workspace*>
1152  OperatorRegistry;
1153 typedef c10::Registry<
1154  std::string,
1155  std::unique_ptr<OperatorBase>,
1156  const OperatorDef&,
1157  Workspace*>* (*RegistryFunction)();
1158 CAFFE2_API std::map<DeviceType, OperatorRegistry*>* gDeviceTypeRegistry();
1159 
1160 struct CAFFE2_API DeviceTypeRegisterer {
1161  explicit DeviceTypeRegisterer(DeviceType type, RegistryFunction func) {
1162  if (gDeviceTypeRegistry()->count(type)) {
1163  std::cerr << "Device type " << DeviceTypeName(type)
1164  << "registered twice. This should not happen. Did you have "
1165  "duplicated numbers assigned to different devices?";
1166  std::exit(1);
1167  }
1168  // Calling the registry function to get the actual registry pointer.
1169  gDeviceTypeRegistry()->emplace(type, func());
1170  }
1171 };
1172 
1173 #define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \
1174  namespace { \
1175  static DeviceTypeRegisterer C10_ANONYMOUS_VARIABLE( \
1176  DeviceType)(type, &registry_function); \
1177  }
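 // Illustrative sketch (not part of the original header): a device backend ties
 // its registry to its device type roughly like this (the real invocations live
 // in the corresponding .cc files, so treat the exact spelling as an assumption):
 //
 //   CAFFE_REGISTER_DEVICE_TYPE(CPU, CPUOperatorRegistry);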
1178 
1179 // The operator registry. Since we are not expecting a great number of devices,
1180 // we will simply have an if-then type command and allocate the actual
1181 // generation to device-specific registerers.
1182 // Note that although we have CUDA and CUDNN here, the registerers themselves do
1183 // not depend on specific cuda or cudnn libraries. This means that we will be
1184 // able to compile it even when there is no cuda available - we simply do not
1185 // link any cuda or cudnn operators.
1186 C10_DECLARE_REGISTRY(
1187  CPUOperatorRegistry,
1188  OperatorBase,
1189  const OperatorDef&,
1190  Workspace*);
1191 #define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
1192  C10_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
1193 #define REGISTER_CPU_OPERATOR(name, ...) \
1194  C10_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
1195  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CPU##name() { \
1196  CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
1197  } \
1198  C10_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
1199 #define REGISTER_CPU_OPERATOR_STR(str_name, ...) \
1200  C10_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__)
1201 
1202 #define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \
1203  C10_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
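 // Illustrative sketch (not part of the original header): registering the
 // hypothetical ExampleScaleOp from the earlier sketch in a .cc file.
 // OPERATOR_SCHEMA comes from operator_schema.h; the engine name FAST is
 // likewise an assumption.
 //
 //   REGISTER_CPU_OPERATOR(ExampleScale, ExampleScaleOp);
 //   OPERATOR_SCHEMA(ExampleScale).NumInputs(1).NumOutputs(1);
 //
 //   // An alternative engine-specific implementation:
 //   REGISTER_CPU_OPERATOR_WITH_ENGINE(ExampleScale, FAST, ExampleScaleFastOp);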
1204 
1205 // Use these macros to register gradient operators. They can be automatically
1206 // excluded from builds that don't need them (e.g., mobile).
1207 #ifdef CAFFE2_NO_GRADIENT_OPS
1208 #define REGISTER_CPU_GRADIENT_OPERATOR(...) /* No gradients. */
1209 #else
1210 #define REGISTER_CPU_GRADIENT_OPERATOR(...) \
1211  MACRO_EXPAND(REGISTER_CPU_OPERATOR(__VA_ARGS__))
1212 #endif
1213 
1214 C10_DECLARE_REGISTRY(
1215  CUDAOperatorRegistry,
1216  OperatorBase,
1217  const OperatorDef&,
1218  Workspace*);
1219 #define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \
1220  C10_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__)
1221 #define REGISTER_CUDA_OPERATOR(name, ...) \
1222  C10_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
1223  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_CUDA##name() { \
1224  CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
1225  } \
1226  C10_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__)
1227 #define REGISTER_CUDA_OPERATOR_STR(str_name, ...) \
1228  C10_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__)
1229 
1230 #define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \
1231  C10_REGISTER_CLASS(CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
1232 
1233 // Macros for cudnn since we use it often
1234 #define REGISTER_CUDNN_OPERATOR(name, ...) \
1235  REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__)
1236 
1237 // Macros for HIP operators
1238 C10_DECLARE_REGISTRY(
1239  HIPOperatorRegistry,
1240  OperatorBase,
1241  const OperatorDef&,
1242  Workspace*);
1243 #define REGISTER_HIP_OPERATOR_CREATOR(key, ...) \
1244  C10_REGISTER_CREATOR(HIPOperatorRegistry, key, __VA_ARGS__)
1245 #define REGISTER_HIP_OPERATOR(name, ...) \
1246  C10_IMPORT void CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
1247  static void CAFFE2_UNUSED CAFFE_ANONYMOUS_VARIABLE_HIP##name() { \
1248  CAFFE2_PLEASE_ADD_OPERATOR_SCHEMA_FOR_##name(); \
1249  } \
1250  C10_REGISTER_CLASS(HIPOperatorRegistry, name, __VA_ARGS__)
1251 #define REGISTER_HIP_OPERATOR_STR(str_name, ...) \
1252  C10_REGISTER_TYPED_CLASS(HIPOperatorRegistry, str_name, __VA_ARGS__)
1253 
1254 #define REGISTER_HIP_OPERATOR_WITH_ENGINE(name, engine, ...) \
1255  C10_REGISTER_CLASS(HIPOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
1256 
1257 #define REGISTER_MIOPEN_OPERATOR(name, ...) \
1258  REGISTER_HIP_OPERATOR_WITH_ENGINE(name, MIOPEN, __VA_ARGS__) \
1259  REGISTER_HIP_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__) // Make CUDNN an alias of MIOPEN for HIP ops
1260 
1261 // StaticLinkingProtector is a helper class that ensures that the Caffe2
1262 // library is linked correctly with whole archives (in the case of static
1263 // linking). What happens is that when CreateOperator is called for the first
 1264 // time, it instantiates a StaticLinkingProtector object to check if the
1265 // operator registry is empty. If it is empty, this means that we are not
1266 // properly linking the library.
1267 //
1268 // You should not need to use this class.
 1269 struct StaticLinkingProtector {
 1270  StaticLinkingProtector() {
 1271  const int registered_ops = CPUOperatorRegistry()->Keys().size();
1272  // Note: this is a check failure instead of an exception, because if
1273  // the linking is wrong, Caffe2 won't be able to run properly anyway,
1274  // so it's better to fail loud.
1275  // If Caffe2 is properly linked with whole archive, there should be more
1276  // than zero registered ops.
1277  if (registered_ops == 0) {
1278  LOG(FATAL) <<
1279  "You might have made a build error: the Caffe2 library does not seem "
1280  "to be linked with whole-static library option. To do so, use "
1281  "-Wl,-force_load (clang) or -Wl,--whole-archive (gcc) to link the "
1282  "Caffe2 library.";
1283  }
1284  }
1285 };
1286 
1287 // An exception that can be thrown by an operator constructor that notifies
 1288 // that it does not support the given setting. This is usually used for
1289 // specific engines that only implement a subset of the features required by
1290 // the original operator schema.
1291 // TODO(jiayq): make more feature-complete exception message.
1292 class CAFFE2_API UnsupportedOperatorFeature : public std::exception {
1293  public:
1294  UnsupportedOperatorFeature(const string& msg) : msg_(msg) {}
1295  const char* what() const noexcept override {
1296  return msg_.c_str();
1297  }
1298 
1299  private:
1300  string msg_;
1301 };
1302 
1303 // A helper macro that should ONLY be used in the operator constructor to check
1304 // if needed features are met. If not, throws the UnsupportedOperatorFeature
1305 // exception with the given message.
1306 #define OPERATOR_NEEDS_FEATURE(condition, ...) \
1307  if (!(condition)) { \
1308  throw UnsupportedOperatorFeature(::c10::str(__VA_ARGS__)); \
1309  }
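 // Illustrative sketch (not part of the original header): an engine-specific
 // constructor rejecting settings it does not implement. The operator class and
 // the "order" argument are assumptions for illustration.
 //
 //   explicit ExampleConvFastOp(const OperatorDef& def, Workspace* ws)
 //       : Operator<CPUContext>(def, ws) {
 //     OPERATOR_NEEDS_FEATURE(
 //         this->GetSingleArgument<string>("order", "NCHW") == "NCHW",
 //         "This engine only supports the NCHW storage order.");
 //   }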
1310 
1311 // Creates an operator with the given operator definition.
1312 // Throws on error and never returns nullptr
1313 CAFFE2_API unique_ptr<OperatorBase> CreateOperator(
1314  const OperatorDef& operator_def,
1315  Workspace* ws,
1316  int net_position = OperatorBase::kNoNetPositionSet);
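 // Illustrative sketch (not part of the original header): building an
 // OperatorDef by hand and running the resulting operator. The operator type
 // "Relu" and the blob names are assumptions; in practice the input blob would
 // be fed with a real tensor before Run() is called.
 //
 //   Workspace ws;
 //   ws.CreateBlob("X");
 //
 //   OperatorDef def;
 //   def.set_type("Relu");
 //   def.add_input("X");
 //   def.add_output("Y");
 //
 //   std::unique_ptr<OperatorBase> op = CreateOperator(def, &ws);
 //   op->Run();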
1317 
1318 CAFFE2_API const std::string OpRegistryKey(
1319  const std::string& op_type,
1320  const std::string& engine = "");
1321 
1322 // User can set the preferred engines as a list of engine names, in
1323 // descending order of preference.
1324 using EnginePrefType = std::vector<std::string>;
1325 // {device_type -> {operator_name -> EnginePrefType}}
1326 using PerOpEnginePrefType =
1327  CaffeMap<DeviceType, CaffeMap<std::string, EnginePrefType>>;
1328 // {device_type -> EnginePrefType}
1329 using GlobalEnginePrefType = CaffeMap<DeviceType, EnginePrefType>;
1330 CAFFE2_API void SetPerOpEnginePref(const PerOpEnginePrefType& per_op_engine_pref);
1331 CAFFE2_API void SetGlobalEnginePref(const GlobalEnginePrefType& global_engine_pref);
1332 CAFFE2_API void SetEnginePref(
1333  const PerOpEnginePrefType& per_op_engine_pref,
1334  const GlobalEnginePrefType& global_engine_pref);
1335 CAFFE2_API void SetOpEnginePref(
1336  const std::string& op_type,
1337  const CaffeMap<DeviceType, EnginePrefType>& op_pref);
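 // Illustrative sketch (not part of the original header): expressing engine
 // preferences at runtime. The operator type "Conv" and the engine name
 // "CUDNN" are assumptions for illustration.
 //
 //   // Prefer CUDNN for CUDA Conv ops, falling back to the default engine:
 //   SetOpEnginePref("Conv", {{DeviceType::CUDA, EnginePrefType{"CUDNN"}}});
 //
 //   // Or set a global, per-device preference:
 //   SetGlobalEnginePref({{DeviceType::CUDA, EnginePrefType{"CUDNN"}}});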
1338 
1339 CAFFE2_API TensorShape GetTensorShapeOfBlob(const Blob* b);
1340 
1341 CAFFE2_API TensorShapes InferBlobShapesAndTypes(
1342  CaffeMap<string, TensorShape>& blob_desc,
1343  const vector<NetDef*>& nets);
1344 
1345 CAFFE2_API TensorShapes InferBlobShapesAndTypesFromWorkspace(
1346  Workspace* ws,
1347  const vector<NetDef*>& nets);
1348 
1349 CAFFE2_API TensorShapes InferBlobShapesAndTypesFromMap(
1350  const CaffeMap<std::string, std::vector<int64_t>>& blob_dimensions,
1351  const vector<NetDef*>& nets);
1352 
1353 CAFFE2_API TensorShapes InferBlobShapesAndTypesFromMap(
1354  const CaffeMap<std::string, std::vector<int64_t>>& blob_dimensions,
1355  const CaffeMap<std::string, TensorProto_DataType>& blob_types,
1356  const vector<NetDef*>& nets);
1357 
1358 CAFFE2_API std::map<string, std::pair<DeviceOption, DeviceOption>> ValidateTensorDevices(
1359  OperatorBase& op,
1360  const OperatorDef& op_def);
1361 
1362 // Get a set of registered operator names
1363 CAFFE2_API std::set<std::string> GetRegisteredOperators();
1364 
1365 // Operator logging capabilities
1366 CAFFE2_API void SetOperatorLogger(std::function<void(const OperatorDef&)> tracer);
1367 std::function<void(const OperatorDef&)> GetOperatorLogger();
1368 
1369 } // namespace caffe2
1370 
1371 #include "caffe2/core/c10_operator.h"
1372 
1373 #endif // CAFFE2_CORE_OPERATOR_H_