Caffe2 - C++ API
A deep learning, cross-platform ML framework
selu_op.h
1 #ifndef CAFFE2_OPERATORS_SELU_OP_H_
2 #define CAFFE2_OPERATORS_SELU_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class SeluOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15 
16  SeluOp(const OperatorDef& operator_def, Workspace* ws)
17  : Operator<Context>(operator_def, ws) {
18  alpha_ = OperatorBase::GetSingleArgument<T>(
19  "alpha", 1.6732632423543772848170429916717f);
20  lambda_ = OperatorBase::GetSingleArgument<T>(
21  "scale", 1.0507009873554804934193349852946f);
22  // In the paper "scale" is named "lambda", but "lambda" is a reserved
23  // keyword in python
24  CAFFE_ENFORCE_GT(lambda_, 1.0);
25  }
26 
27  bool RunOnDevice() override;
28 
29  protected:
30  T alpha_;
31  T lambda_;
32 };
33 
34 template <typename T, class Context>
35 class SeluGradientOp final : public Operator<Context> {
36  public:
37  USE_OPERATOR_CONTEXT_FUNCTIONS;
38  SeluGradientOp(const OperatorDef& operator_def, Workspace* ws)
39  : Operator<Context>(operator_def, ws) {
40  alpha_ = OperatorBase::GetSingleArgument<T>(
41  "alpha", 1.6732632423543772848170429916717f);
42  lambda_ = OperatorBase::GetSingleArgument<T>(
43  "scale", 1.0507009873554804934193349852946f);
44  CAFFE_ENFORCE_GT(lambda_, 1.0);
45  }
46 
47  bool RunOnDevice() override;
48 
49  protected:
50  T alpha_;
51  T lambda_;
52 };
53 
54 } // namespace caffe2
55 
56 #endif // CAFFE2_OPERATORS_SELU_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.