1 #include <caffe2/ideep/operators/conv_pool_base_op.h> 7 USE_IDEEP_DEF_ALIASES();
8 USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
13 OperatorBase::GetSingleArgument<int>(
"training_mode", 1)) {
15 (dilation_h() == 1) && (dilation_w() == 1),
16 "Pooling op does not support dilation right now.");
17 if (!global_pooling_) {
19 pad_t() < kernel_h() && pad_b() < kernel_h() &&
20 pad_l() < kernel_w() && pad_r() < kernel_w(),
21 "Pad should be smaller than kernel.");
24 if (operator_def.type().substr(0, 7) ==
"MaxPool") {
25 algo_ = ialgo::pooling_max;
26 }
else if (operator_def.type().substr(0, 11) ==
"AveragePool") {
27 algo_ = ialgo::pooling_avg;
29 LOG(FATAL) <<
"Unsupported pooling method: " << operator_def.type();
34 bool RunOnDeviceWithOrderNCHW()
override {
35 auto& X = Input(INPUT);
36 auto* Y = Output(OUTPUT);
37 auto Y_dims = CalcOutputDims(X, X.get_dim(1));
38 mkldnn::prop_kind pk = training_mode_ ?
39 mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_inference;
41 ideep::pooling_forward::compute(X, Y_dims, *Y,
42 stride_, kernel_, pad_tl(), pad_br(), algo_, pk);
57 USE_IDEEP_DEF_ALIASES();
58 USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();
63 (dilation_h() == 1) && (dilation_w() == 1),
64 "Pooling op does not support dilation right now.");
65 if (!global_pooling_) {
67 pad_t() < kernel_h() && pad_b() < kernel_h() &&
68 pad_l() < kernel_w() && pad_r() < kernel_w(),
69 "Pad should be smaller than kernel.");
72 if (operator_def.type().substr(0, 15) ==
"MaxPoolGradient") {
73 algo_ = ialgo::pooling_max;
74 }
else if (operator_def.type().substr(0, 19) ==
"AveragePoolGradient") {
75 algo_ = ialgo::pooling_avg;
77 LOG(FATAL) <<
"Unsupported pooling method: " << operator_def.type();
82 bool RunOnDeviceWithOrderNCHW()
override {
83 const auto& X = Input(INPUT);
84 const auto& Y = Input(OUTPUT);
85 const auto& dY = Input(OUTPUT_GRAD);
86 auto* dX = Output(INPUT_GRAD);
88 ideep::pooling_backward::compute(dY, Y, X, *dX,
89 stride_, kernel_, pad_tl(), pad_br(), algo_);
97 INPUT_TAGS(INPUT, OUTPUT, OUTPUT_GRAD);
98 OUTPUT_TAGS(INPUT_GRAD);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...