1 #include <caffe2/ideep/ideep_utils.h> 7 USE_IDEEP_DEF_ALIASES();
8 USE_IDEEP_OPERATOR_FUNCTIONS();
12 ratio_(OperatorBase::GetSingleArgument<float>(
"ratio", 0.5)),
14 OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
15 CAFFE_ENFORCE_GE(ratio_, 0);
16 CAFFE_ENFORCE_LT(ratio_, 1);
20 bool RunOnDevice()
override {
21 const auto& X = Input(INPUT);
22 auto* Y = Output(OUTPUT);
26 ideep::direct_copy::compute(X, *Y);
31 auto* mask = Output(MASK);
32 ideep::dropout_forward::compute(X, ratio_, *Y, *mask);
42 OUTPUT_TAGS(OUTPUT, MASK);
47 USE_IDEEP_DEF_ALIASES();
48 USE_IDEEP_OPERATOR_FUNCTIONS();
52 ratio_(OperatorBase::GetSingleArgument<float>(
"ratio", 0.5)),
54 OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)) {
55 CAFFE_ENFORCE_GE(ratio_, 0);
56 CAFFE_ENFORCE_LT(ratio_, 1);
60 bool RunOnDevice()
override {
61 const auto& dY = Input(OUTPUT_GRAD);
62 auto* dX = Output(INPUT_GRAD);
66 ideep::direct_copy::compute(dY, *dX);
71 const auto& mask = Input(MASK);
72 ideep::dropout_backward::compute(mask, dY, *dX);
81 INPUT_TAGS(OUTPUT_GRAD , MASK);
82 OUTPUT_TAGS(INPUT_GRAD);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...