1 #include <caffe2/ideep/ideep_utils.h> 7 USE_IDEEP_DEF_ALIASES();
8 USE_IDEEP_OPERATOR_FUNCTIONS();
12 is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
13 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5)),
14 momentum_(OperatorBase::GetSingleArgument<float>(
"momentum", 0.9)) {
16 (is_test_ && OutputSize() > OUTPUT)
17 || (!is_test_ && OutputSize() > SAVED_VAR));
18 CAFFE_ENFORCE_GT(epsilon_, 0);
19 CAFFE_ENFORCE_GE(momentum_, 0);
20 CAFFE_ENFORCE_LE(momentum_, 1);
24 bool RunOnDevice()
override {
25 const auto& X = Input(INPUT);
26 const auto& scale = Input(SCALE);
27 const auto& bias = Input(BIAS);
28 auto* Y = Output(OUTPUT);
30 DCHECK_EQ(scale.ndims(), 1);
31 DCHECK_EQ(bias.ndims(), 1);
32 DCHECK_EQ(scale.get_dim(0), X.get_dim(1));
33 DCHECK_EQ(bias.get_dim(0), X.get_dim(1));
36 const auto& est_mean = Input(EST_MEAN);
37 const auto& est_var = Input(EST_VAR);
38 ideep::batch_normalization_forward_inference::compute(
39 X, est_mean, est_var, scale, bias, *Y, epsilon_);
41 auto* saved_mean = Output(SAVED_MEAN);
42 auto* saved_var = Output(SAVED_VAR);
43 auto* running_mean = Output(RUNNING_MEAN);
44 auto* running_var = Output(RUNNING_VAR);
45 ideep::batch_normalization_forward_training::compute(
46 X, scale, bias, *Y, *saved_mean, *saved_var,
47 *running_mean, *running_var, momentum_, epsilon_);
58 INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR);
59 OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_VAR);
64 USE_IDEEP_DEF_ALIASES();
65 USE_IDEEP_OPERATOR_FUNCTIONS();
69 epsilon_(OperatorBase::GetSingleArgument<float>(
"epsilon", 1e-5)) {
70 CAFFE_ENFORCE(InputSize() > SAVED_VAR);
71 CAFFE_ENFORCE(OutputSize() > BIAS_GRAD);
75 bool RunOnDevice()
override {
76 const auto& X = Input(INPUT);
77 const auto& scale = Input(SCALE);
78 const auto& dY = Input(OUTPUT_GRAD);
79 const auto& saved_mean = Input(SAVED_MEAN);
80 const auto& saved_var = Input(SAVED_VAR);
81 auto* dX = Output(INPUT_GRAD);
82 auto* dscale = Output(SCALE_GRAD);
83 auto* dbias = Output(BIAS_GRAD);
85 ideep::batch_normalization_backward::compute(
86 X, saved_mean, saved_var, dY, scale,
87 *dX, *dscale, *dbias, epsilon_);
95 INPUT_TAGS(INPUT, SCALE, OUTPUT_GRAD, SAVED_MEAN, SAVED_VAR);
96 OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...