Caffe2 - C++ API
A deep-learning, cross-platform ML framework
relu_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 
3 namespace caffe2 {
4 
5 class IDEEPReluOp final : public IDEEPOperator {
6  public:
7  USE_IDEEP_DEF_ALIASES();
8  USE_IDEEP_OPERATOR_FUNCTIONS();
9 
10  IDEEPReluOp(const OperatorDef& operator_def, Workspace* ws)
11  : IDEEPOperator(operator_def, ws), alpha_(0.0) {
12  // Figure out the Relu descriptor.
13  if (operator_def.type().substr(0, 4) == "Relu") {
14  alpha_ = 0.0;
15  } else if (operator_def.type().substr(0, 9) == "LeakyRelu") {
16  if (HasArgument("alpha")) {
17  alpha_ = static_cast<float>(
18  OperatorBase::GetSingleArgument<float>("alpha", 0.01));
19  }
20  } else {
21  LOG(FATAL) << "Unsupported Relu method: " << operator_def.type();
22  }
23  }
24  ~IDEEPReluOp() override {}
25 
26  bool RunOnDevice() override {
27  const auto& X = Input(INPUT);
28  auto* Y = Output(OUTPUT);
29 
30  ideep::eltwise_forward::compute(
31  X, *Y, ialgo::eltwise_relu, iprop::forward_training, alpha_);
32 
33  return true;
34  }
35 
36  private:
37  float alpha_;
38 
39  INPUT_TAGS(INPUT);
40  OUTPUT_TAGS(OUTPUT);
41 };
42 
43 class IDEEPReluGradientOp final : public IDEEPOperator {
44  public:
45  USE_IDEEP_DEF_ALIASES();
46  USE_IDEEP_OPERATOR_FUNCTIONS();
47 
48  IDEEPReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
49  : IDEEPOperator(operator_def, ws), alpha_(0.0) {
50  // Figure out the Relu descriptor.
51  if (operator_def.type().substr(0, 12) == "ReluGradient") {
52  alpha_ = 0.0;
53  } else if (operator_def.type().substr(0, 17) == "LeakyReluGradient") {
54  if (HasArgument("alpha")) {
55  alpha_ = static_cast<float>(
56  OperatorBase::GetSingleArgument<float>("alpha", 0.01));
57  }
58  } else {
59  LOG(FATAL) << "Unsupported Relu method: " << operator_def.type();
60  }
61  }
62  ~IDEEPReluGradientOp() override {}
63 
64  bool RunOnDevice() override {
65  const auto& Y = Input(OUTPUT);
66  const auto& dY = Input(OUTPUT_GRAD);
67  auto* dX = Output(INPUT_GRAD);
68 
69  ideep::eltwise_backward::compute(Y, dY, *dX, ialgo::eltwise_relu, alpha_);
70 
71  return true;
72  }
73 
74  private:
75  float alpha_;
76 
77  INPUT_TAGS(OUTPUT, OUTPUT_GRAD);
78  OUTPUT_TAGS(INPUT_GRAD);
79 };
80 
81 REGISTER_IDEEP_OPERATOR(Relu, IDEEPReluOp);
82 REGISTER_IDEEP_OPERATOR(ReluGradient, IDEEPReluGradientOp);
83 
84 REGISTER_IDEEP_OPERATOR(LeakyRelu, IDEEPReluOp);
85 REGISTER_IDEEP_OPERATOR(LeakyReluGradient, IDEEPReluGradientOp);
86 
87 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: OpClasses.h:2
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70