Caffe2 - C++ API
A deep learning, cross platform ML framework
elementwise_sum_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 #include <caffe2/ideep/operators/operator_fallback_ideep.h>
3 #include "caffe2/operators/utility_ops.h"
4 #include "caffe2/operators/elementwise_add_op.h"
5 
6 namespace caffe2 {
7 
// IDEEP operator that sums its inputs element-wise.
//
// When every input is an ideep tensor (itensor) and all of them report the
// same dims, the result is computed with ideep primitives. In any other case
// the work is delegated to one of two wrapped fallback operators
// (fallback_add_ for exactly two inputs, fallback_sum_ otherwise).
8 class IDEEPSumOp final : public IDEEPOperator {
9  public:
10  USE_IDEEP_DEF_ALIASES();
11  USE_IDEEP_OPERATOR_FUNCTIONS();
  // NOTE(review): the embedded listing numbers jump from 11 to 15 here, and
  // FALLBACK_SUM / FALLBACK_ADD (used for the members below) are not declared
  // anywhere in this view -- presumably their declarations were lost when the
  // listing was extracted. Confirm against the original file.
15 
  // Builds the operator and both fallback operators from the same
  // operator definition and workspace.
16  IDEEPSumOp(const OperatorDef& operator_def, Workspace* ws)
17  : IDEEPOperator(operator_def, ws),
18  fallback_sum_(operator_def, ws),
19  fallback_add_(operator_def, ws) {}
20  ~IDEEPSumOp() override {}
21 
  // Sums all inputs into the single output.
  //
  // Returns true after the ideep path has run; on the fallback path the
  // return value is whatever the chosen fallback operator's Run(0) returns.
  // Fails the CAFFE_ENFORCE check if an input is neither an itensor nor a
  // CPU tensor.
22  bool RunOnDevice() override {
    // Dims recorded from the ideep inputs seen so far; stays empty until an
    // ideep input supplies non-empty dims.
23  itensor::dims input_dims;
24  bool fallback_to_cpu = false;
    // ideep inputs gathered for the multi-input sum below.
25  vector<itensor> inputs_itensor;
26 
27  // We only support element-wise sum for ideep tensors here.
28  // If a CPU tensor is detected in input list, we have to fallback
29  // to corresponding CPU operator.
30  for (int i = 0; i < InputSize(); ++i) {
31  if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
32  auto& tensor_ideep = Input(i);
33  if (input_dims.empty()) {
34  input_dims = tensor_ideep.get_dims();
35  } else if (input_dims != tensor_ideep.get_dims()) {
        // Dims differ from the recorded ones: stop scanning and let the
        // fallback operator handle this call.
36  fallback_to_cpu = true;
37  break;
38  }
39  inputs_itensor.emplace_back(tensor_ideep);
40  } else {
        // A non-ideep input must at least be a CPU tensor; anything else is
        // rejected by the enforce check.
41  CAFFE_ENFORCE(
42  BlobIsTensorType(OperatorBase::InputBlob(i), CPU),
43  "Expect cpu tensor if not itensor");
44  fallback_to_cpu = true;
45  break;
46  }
47  }
48 
    // ideep path: every input was an itensor and no dims mismatch was found.
49  if (!fallback_to_cpu) {
50  auto* Y = Output(OUTPUT);
51  if (InputSize() == 1) {
        // A single input needs no arithmetic: copy it to the output.
52  const auto& X = Input(INPUT0);
53  ideep::direct_copy::compute(X, *Y);
54  } else {
        // One scale of 1.0 per input is handed to ideep::sum::compute.
55  const vector<float> scales(InputSize(), 1.0);
56  ideep::sum::compute(scales, inputs_itensor, *Y);
57  }
58  return true;
59  }
60 
    // Fallback path: exactly two inputs go to fallback_add_, every other
    // input count goes to fallback_sum_.
    // NOTE(review): the 0 passed to Run is presumably a stream id -- confirm.
61  if (InputSize() == 2) {
62  return fallback_add_.Run(0);
63  }
64 
65  return fallback_sum_.Run(0);
66  }
67 
68  private:
  // Wrapped operators used whenever the ideep path cannot be taken.
69  FALLBACK_SUM fallback_sum_;
70  FALLBACK_ADD fallback_add_;
71 
  // Names used above with Input(INPUT0) and Output(OUTPUT).
72  INPUT_TAGS(INPUT0);
73  OUTPUT_TAGS(OUTPUT);
74 };
75 
// Register IDEEPSumOp under both the Sum and the Add operator names, so the
// same class serves either operator.
76 REGISTER_IDEEP_OPERATOR(Sum, IDEEPSumOp);
77 REGISTER_IDEEP_OPERATOR(Add, IDEEPSumOp);
78 
79 } // namespace caffe2
Definition: OpClasses.h:414
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
A templated class to allow one to wrap a CPU operator as an IDEEP operator.
Definition: OpClasses.h:659