Caffe2 - C++ API
A deep learning, cross platform ML framework
dropout_op.h
1 #ifndef CAFFE2_OPERATORS_DROPOUT_OP_H_
2 #define CAFFE2_OPERATORS_DROPOUT_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class DropoutOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  template <class... Args>
16  explicit DropoutOp(Args&&... args)
17  : Operator<Context>(std::forward<Args>(args)...),
18  ratio_(this->template GetSingleArgument<float>("ratio", 0.5)),
19  is_test_(
20  this->template GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
21  CAFFE_ENFORCE_GE(ratio_, 0);
22  CAFFE_ENFORCE_LT(ratio_, 1);
23  }
24 
25  bool RunOnDevice() override;
26 
27  protected:
28  float ratio_;
29  bool is_test_;
30  // Input: X; Output: Y, mask.
31 };
32 
33 template <typename T, class Context>
34 class DropoutGradientOp final : public Operator<Context> {
35  public:
36  USE_OPERATOR_CONTEXT_FUNCTIONS;
37  template <class... Args>
38  explicit DropoutGradientOp(Args&&... args)
39  : Operator<Context>(std::forward<Args>(args)...),
40  ratio_(this->template GetSingleArgument<float>("ratio", 0.5)),
41  is_test_(
42  this->template GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
43  CAFFE_ENFORCE_GE(ratio_, 0);
44  CAFFE_ENFORCE_LT(ratio_, 1);
45  }
46 
47  bool RunOnDevice() override;
48 
49  protected:
50  float ratio_;
51  bool is_test_;
52  // Input: dY, mask; Output: dX
53 };
54 
55 } // namespace caffe2
56 
57 #endif // CAFFE2_OPERATORS_DROPOUT_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13