Caffe2 - C++ API
A deep learning, cross-platform ML framework
thresholded_relu_op.h
1 
17 #ifndef CAFFE2_OPERATORS_THRESHOLDED_RELU_OP_H_
18 #define CAFFE2_OPERATORS_THRESHOLDED_RELU_OP_H_
19 
20 #include "caffe2/core/common_omp.h"
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/logging.h"
23 #include "caffe2/core/operator.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class ThresholdedReluOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  ThresholdedReluOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws) {
33  alpha_ = OperatorBase::GetSingleArgument<T>("alpha", 1.0);
34  }
35 
36  bool RunOnDevice() override;
37 
38  protected:
39  T alpha_;
40 };
41 
42 template <typename T, class Context>
43 class ThresholdedReluGradientOp final : public Operator<Context> {
44  public:
45  USE_OPERATOR_CONTEXT_FUNCTIONS;
46  ThresholdedReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
47  : Operator<Context>(operator_def, ws) {
48  alpha_ = OperatorBase::GetSingleArgument<T>("alpha", 1.0);
49  }
50 
51  bool RunOnDevice() override;
52 
53  protected:
54  T alpha_;
55 };
56 
57 } // namespace caffe2
58 
59 #endif // CAFFE2_OPERATORS_THRESHOLDED_RELU_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.