Caffe2 - C++ API
A deep learning, cross platform ML framework
dropout_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 
3 namespace caffe2 {
4 
5 class IDEEPDropoutOp final : public IDEEPOperator {
6  public:
7  USE_IDEEP_DEF_ALIASES();
8  USE_IDEEP_OPERATOR_FUNCTIONS();
9 
10  IDEEPDropoutOp(const OperatorDef& operator_def, Workspace* ws)
11  : IDEEPOperator(operator_def, ws),
12  ratio_(OperatorBase::GetSingleArgument<float>("ratio", 0.5)),
13  is_test_(
14  OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
15  CAFFE_ENFORCE_GE(ratio_, 0);
16  CAFFE_ENFORCE_LT(ratio_, 1);
17  }
18  ~IDEEPDropoutOp() override {}
19 
20  bool RunOnDevice() override {
21  const auto& X = Input(INPUT);
22  auto* Y = Output(OUTPUT);
23 
24  if (is_test_) {
25  if (Y != &X) {
26  ideep::direct_copy::compute(X, *Y);
27  }
28  return true;
29  }
30 
31  auto* mask = Output(MASK);
32  ideep::dropout_forward::compute(X, ratio_, *Y, *mask);
33 
34  return true;
35  }
36 
37  private:
38  float ratio_;
39  bool is_test_;
40 
41  INPUT_TAGS(INPUT);
42  OUTPUT_TAGS(OUTPUT, MASK);
43 };
44 
45 class IDEEPDropoutGradientOp final : public IDEEPOperator {
46  public:
47  USE_IDEEP_DEF_ALIASES();
48  USE_IDEEP_OPERATOR_FUNCTIONS();
49 
50  IDEEPDropoutGradientOp(const OperatorDef& operator_def, Workspace* ws)
51  : IDEEPOperator(operator_def, ws),
52  ratio_(OperatorBase::GetSingleArgument<float>("ratio", 0.5)),
53  is_test_(
54  OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
55  CAFFE_ENFORCE_GE(ratio_, 0);
56  CAFFE_ENFORCE_LT(ratio_, 1);
57  }
58  ~IDEEPDropoutGradientOp() override {}
59 
60  bool RunOnDevice() override {
61  const auto& dY = Input(OUTPUT_GRAD);
62  auto* dX = Output(INPUT_GRAD);
63 
64  if (is_test_) {
65  if (dX != &dY) {
66  ideep::direct_copy::compute(dY, *dX);
67  }
68  return true;
69  }
70 
71  const auto& mask = Input(MASK);
72  ideep::dropout_backward::compute(mask, dY, *dX);
73 
74  return true;
75  }
76 
77  protected:
78  float ratio_;
79  bool is_test_;
80 
81  INPUT_TAGS(OUTPUT_GRAD , MASK);
82  OUTPUT_TAGS(INPUT_GRAD);
83 };
84 
85 REGISTER_IDEEP_OPERATOR(Dropout, IDEEPDropoutOp);
86 REGISTER_IDEEP_OPERATOR(DropoutGrad, IDEEPDropoutGradientOp);
87 
88 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13