Caffe2 - C++ API
A cross-platform deep learning framework
leaky_relu_op.h
1 #pragma once
2 
3 #include "caffe2/core/context.h"
4 #include "caffe2/core/logging.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <typename T, class Context>
10 class LeakyReluOp : public Operator<Context> {
11  public:
12  template <class... Args>
13  explicit LeakyReluOp(Args&&... args)
14  : Operator<Context>(std::forward<Args>(args)...), alpha_(0.01) {
15  if (HasArgument("alpha")) {
16  alpha_ =
17  static_cast<T>(this->template GetSingleArgument<float>("alpha", 0.01));
18  }
19  }
20 
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22 
23  bool RunOnDevice() override;
24 
25  protected:
26  T alpha_;
27 };
28 
29 template <typename T, class Context>
30 class LeakyReluGradientOp final : public Operator<Context> {
31  public:
32  template <class... Args>
33  explicit LeakyReluGradientOp(Args&&... args)
34  : Operator<Context>(std::forward<Args>(args)...), alpha_(0.01) {
35  if (HasArgument("alpha")) {
36  alpha_ =
37  static_cast<T>(this->template GetSingleArgument<float>("alpha", 0.01));
38  }
39  }
40 
41  USE_OPERATOR_CONTEXT_FUNCTIONS;
42 
43  bool RunOnDevice() override;
44 
45  protected:
46  T alpha_;
47 };
48 
49 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70