Caffe2 - C++ API
A deep-learning, cross-platform ML framework
leaky_relu_op.h
1 
17 #pragma once
18 
19 #include "caffe2/core/context.h"
20 #include "caffe2/core/logging.h"
21 #include "caffe2/core/operator.h"
22 
23 namespace caffe2 {
24 
25 template <typename T, class Context>
26 class LeakyReluOp : public Operator<Context> {
27  public:
28  LeakyReluOp(const OperatorDef& operator_def, Workspace* ws)
29  : Operator<Context>(operator_def, ws), alpha_(0.01) {
30  if (HasArgument("alpha")) {
31  alpha_ =
32  static_cast<T>(OperatorBase::GetSingleArgument<float>("alpha", 0.01));
33  }
34  }
35 
36  USE_OPERATOR_CONTEXT_FUNCTIONS;
37 
38  bool RunOnDevice() override;
39 
40  protected:
41  T alpha_;
42 };
43 
44 template <typename T, class Context>
45 class LeakyReluGradientOp final : public Operator<Context> {
46  public:
47  LeakyReluGradientOp(const OperatorDef& operator_def, Workspace* ws)
48  : Operator<Context>(operator_def, ws), alpha_(0.01) {
49  if (HasArgument("alpha")) {
50  alpha_ =
51  static_cast<T>(OperatorBase::GetSingleArgument<float>("alpha", 0.01));
52  }
53  }
54 
55  USE_OPERATOR_CONTEXT_FUNCTIONS;
56 
57  bool RunOnDevice() override;
58 
59  protected:
60  T alpha_;
61 };
62 
63 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52