Caffe2 - C++ API
A deep learning, cross-platform ML framework
selu_op.h
1 #ifndef CAFFE2_OPERATORS_SELU_OP_H_
2 #define CAFFE2_OPERATORS_SELU_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class SeluOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15 
16  template <class... Args>
17  explicit SeluOp(Args&&... args)
18  : Operator<Context>(std::forward<Args>(args)...) {
19  alpha_ = this->template GetSingleArgument<T>(
20  "alpha", 1.6732632423543772848170429916717f);
21  lambda_ = this->template GetSingleArgument<T>(
22  "scale", 1.0507009873554804934193349852946f);
23  // In the paper "scale" is named "lambda", but "lambda" is a reserved
24  // keyword in python
25  CAFFE_ENFORCE_GT(lambda_, 1.0);
26  }
27 
28  bool RunOnDevice() override;
29 
30  protected:
31  T alpha_;
32  T lambda_;
33 };
34 
35 template <typename T, class Context>
36 class SeluGradientOp final : public Operator<Context> {
37  public:
38  USE_OPERATOR_CONTEXT_FUNCTIONS;
39  template <class... Args>
40  explicit SeluGradientOp(Args&&... args)
41  : Operator<Context>(std::forward<Args>(args)...) {
42  alpha_ = this->template GetSingleArgument<T>(
43  "alpha", 1.6732632423543772848170429916717f);
44  lambda_ = this->template GetSingleArgument<T>(
45  "scale", 1.0507009873554804934193349852946f);
46  CAFFE_ENFORCE_GT(lambda_, 1.0);
47  }
48 
49  bool RunOnDevice() override;
50 
51  protected:
52  T alpha_;
53  T lambda_;
54 };
55 
56 } // namespace caffe2
57 
58 #endif // CAFFE2_OPERATORS_SELU_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13