1 #ifndef CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_ 2 #define CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
typename T,
class Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
14 template <
class... Args>
17 epsilon_(this->
template GetSingleArgument<T>(
"epsilon", 1e-5f)),
18 order_(StringToStorageOrder(
19 this->
template GetSingleArgument<string>(
"order",
"NCHW"))) {
20 CAFFE_ENFORCE(epsilon_ >= 0,
"Must pass a nonnegative epsilon.");
26 case StorageOrder::NHWC:
27 return RunOnDeviceWithOrderNHWC();
28 case StorageOrder::NCHW:
29 return RunOnDeviceWithOrderNCHW();
31 CAFFE_THROW(
"Unknown storage order: ", order_);
35 bool RunOnDeviceWithOrderNHWC();
36 bool RunOnDeviceWithOrderNCHW();
44 Tensor mean_{Context::GetDeviceType()};
45 Tensor inv_stdev_{Context::GetDeviceType()};
47 INPUT_TAGS(INPUT, SCALE, BIAS);
48 OUTPUT_TAGS(OUTPUT, MEAN, INV_STDEV);
51 template <
typename T,
class Context>
54 USE_OPERATOR_CONTEXT_FUNCTIONS;
55 template <
class... Args>
58 epsilon_(this->
template GetSingleArgument<T>(
"epsilon", 1e-5f)),
59 order_(StringToStorageOrder(
60 this->
template GetSingleArgument<string>(
"order",
"NCHW"))) {
61 CAFFE_ENFORCE(epsilon_ >= 0,
"Must pass a nonnegative epsilon.");
63 ~InstanceNormGradientOp() {}
67 case StorageOrder::NHWC:
68 return RunOnDeviceWithOrderNHWC();
69 case StorageOrder::NCHW:
70 return RunOnDeviceWithOrderNCHW();
72 CAFFE_THROW(
"Unknown storage order: ", order_);
76 bool RunOnDeviceWithOrderNHWC();
77 bool RunOnDeviceWithOrderNCHW();
89 INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, INV_STDEV);
90 OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
95 #endif // CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...