Caffe2 - C++ API
A deep-learning, cross-platform ML framework
elementwise_sum_relu_op.cc
1 
17 #include "caffe2/operators/utility_ops.h"
18 
19 namespace caffe2 {
20 
21 template <class Context>
22 class SumReluOp : public SumOp<Context> {
23  public:
24  USE_OPERATOR_CONTEXT_FUNCTIONS;
25  SumReluOp(const OperatorDef& operator_def, Workspace* ws)
26  : SumOp<Context>(operator_def, ws) {}
27 
28  template <typename T, typename M>
29  bool DoRunWithType() {
30  if (!SumOp<Context>::template DoRunWithType<T, M>()) {
31  return false;
32  }
33 
34  auto* output = Output(0);
35  T* output_data = output->template mutable_data<T>();
36  for (int i = 0; i < output->size(); ++i) {
37  output_data[i] = std::max(static_cast<T>(0), output_data[i]);
38  }
39  return true;
40  }
41 
42  bool RunOnDevice() override {
43  if (Input(0).template IsType<float>()) {
44  return DoRunWithType<float, float>();
45  } else if (Input(0).template IsType<int>()) {
46  return DoRunWithType<int, int>();
47  } else {
48  CAFFE_THROW(
49  "Sum operator only supports 32-bit float and ints, but",
50  " input was of type ",
51  Input(0).dtype().name());
52  }
53  }
54 };
55 
56 REGISTER_CPU_OPERATOR(SumRelu, SumReluOp<CPUContext>);
57 
58 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13