Caffe2 - C++ API
A deep-learning, cross-platform ML framework
dropout_op.h
1 
17 #ifndef CAFFE2_OPERATORS_DROPOUT_OP_H_
18 #define CAFFE2_OPERATORS_DROPOUT_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class DropoutOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  DropoutOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  ratio_(OperatorBase::GetSingleArgument<float>("ratio", 0.5)),
34  is_test_(
35  OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
36  CAFFE_ENFORCE_GE(ratio_, 0);
37  CAFFE_ENFORCE_LT(ratio_, 1);
38  }
39 
40  bool RunOnDevice() override;
41 
42  protected:
43  float ratio_;
44  bool is_test_;
45  // Input: X; Output: Y, mask.
46 };
47 
48 template <typename T, class Context>
49 class DropoutGradientOp final : public Operator<Context> {
50  public:
51  USE_OPERATOR_CONTEXT_FUNCTIONS;
52  DropoutGradientOp(const OperatorDef& operator_def, Workspace* ws)
53  : Operator<Context>(operator_def, ws),
54  ratio_(OperatorBase::GetSingleArgument<float>("ratio", 0.5)),
55  is_test_(
56  OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
57  CAFFE_ENFORCE_GE(ratio_, 0);
58  CAFFE_ENFORCE_LT(ratio_, 1);
59  }
60 
61  bool RunOnDevice() override;
62 
63  protected:
64  float ratio_;
65  bool is_test_;
66  // Input: dY, mask; Output: dX
67 };
68 
69 } // namespace caffe2
70 
71 #endif // CAFFE2_OPERATORS_DROPOUT_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.