Caffe2 - C++ API
A deep learning, cross platform ML framework
dropout_op.cc
1 
17 #include "caffe2/operators/dropout_op.h"
18 
19 namespace caffe2 {
20 
21 template <>
22 bool DropoutOp<float, CPUContext>::RunOnDevice() {
23  auto& X = Input(0);
24  auto* Y = Output(0);
25  Y->Resize(X.dims());
26  if (is_test_) {
27  if (Y != &X) {
28  context_.Copy<float, CPUContext, CPUContext>(
29  X.size(), X.data<float>(), Y->mutable_data<float>());
30  }
31  return true;
32  } else {
33  float scale = 1. / (1. - ratio_);
34  // mask=true means keep, and mask=false means not keep, so we will
35  // generate probability depending on 1-ratio.
36  std::bernoulli_distribution dist(1. - ratio_);
37  const float* Xdata = X.data<float>();
38  float* Ydata = Y->mutable_data<float>();
39  auto mask = Output(1);
40  mask->Resize(X.dims());
41  bool* mask_data = mask->mutable_data<bool>();
42  auto& gen = context_.RandGenerator();
43  for (int i = 0; i < X.size(); ++i) {
44  mask_data[i] = dist(gen);
45  Ydata[i] = Xdata[i] * scale * mask_data[i];
46  }
47  return true;
48  }
49 }
50 
51 template <>
52 bool DropoutGradientOp<float, CPUContext>::RunOnDevice() {
53  auto& dY = Input(0);
54  auto* dX = Output(0);
55  dX->Resize(dY.dims());
56  if (is_test_) {
57  if (dX != &dY) {
58  context_.Copy<float, CPUContext, CPUContext>(
59  dY.size(), dY.data<float>(), dX->mutable_data<float>());
60  }
61  return true;
62  } else {
63  auto& mask = Input(1);
64  CAFFE_ENFORCE_EQ(dY.size(), mask.size());
65  const float* dYdata = dY.data<float>();
66  const bool* mask_data = mask.data<bool>();
67  float* dXdata = dX->mutable_data<float>();
68  float scale = 1. / (1. - ratio_);
69  for (int i = 0; i < dY.size(); ++i) {
70  dXdata[i] = dYdata[i] * mask_data[i] * scale;
71  }
72  return true;
73  }
74 }
75 
76 REGISTER_CPU_OPERATOR(Dropout, DropoutOp<float, CPUContext>);
77 REGISTER_CPU_OPERATOR(DropoutGrad, DropoutGradientOp<float, CPUContext>);
78 
79 OPERATOR_SCHEMA(Dropout)
80  .NumInputs(1)
81  .NumOutputs(1, 2)
82  .AllowInplace({{0, 0}})
83  .TensorInferenceFunction([](const OperatorDef& def,
84  const vector<TensorShape>& in) {
85  CAFFE_ENFORCE_EQ(1, in.size());
86  vector<TensorShape> out;
87  ArgumentHelper argsHelper(def);
88  out.push_back(in[0]);
89  auto output_mask = !argsHelper.GetSingleArgument<bool>("is_test", 0);
90  if (output_mask) {
91  out.push_back(in[0]);
92  out[1].set_data_type(TensorProto_DataType_BOOL);
93  }
94  return out;
95  })
96  .SetDoc(R"DOC(
97 Dropout takes one input data (Tensor<float>) and produces two Tensor outputs,
98 output (Tensor<float>) and mask (Tensor<bool>). Depending on whether it is in
99 test mode or not, the output Y will either be a random dropout, or a simple
100 copy of the input. Note that our implementation of Dropout does scaling in
101 the training phase, so during testing nothing needs to be done.
102 )DOC")
103  .Arg("ratio", "(float, default 0.5) the ratio of random dropout")
104  .ArgIsTest(
105  "(int) if nonzero, run dropout in test mode where "
106  "the output is simply Y = X.")
107  .Input(0, "data", "The input data as Tensor.")
108  .Output(0, "output", "The output.")
109  .Output(
110  1,
111  "mask",
112  "The output mask. If is_test is nonzero, this output is not filled.");
114 OPERATOR_SCHEMA(DropoutGrad)
115  .NumInputs(1, 2)
116  .NumOutputs(1)
117  .AllowInplace({{0, 0}});
118 
119 class GetDropoutGradient : public GradientMakerBase {
120  using GradientMakerBase::GradientMakerBase;
121  vector<OperatorDef> GetGradientDefs() override {
122  ArgumentHelper argshelper(def_);
123  auto is_test = argshelper.GetSingleArgument<bool>("is_test", 0);
124  if (is_test) {
125  return SingleGradientDef(
126  "DropoutGrad", "", vector<string>{GO(0)}, vector<string>{GI(0)});
127  } else {
128  return SingleGradientDef(
129  "DropoutGrad",
130  "",
131  vector<string>{GO(0), O(1)},
132  vector<string>{GI(0)});
133  }
134  }
135 };
136 REGISTER_GRADIENT(Dropout, GetDropoutGradient);
137 } // namespace caffe2
A helper class to index into arguments.
Definition: proto_utils.h:198
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that creates a single operator def, which is usually the case for many simple operators.