1 #include "caffe2/operators/order_switch_ops.h" 7 REGISTER_CPU_OPERATOR(
NHWC2NCHW, NHWC2NCHWOp<float, CPUContext>);
8 REGISTER_CPU_OPERATOR(
NCHW2NHWC, NCHW2NHWCOp<float, CPUContext>);
// NHWC2NCHW schema fragment: shape inference, operator doc text, and the
// input/output descriptions.
// NOTE(review): this listing is incomplete and does not compile as shown.
// The chain starts mid-expression: whatever `.TensorInferenceFunction` is
// called on (original lines 9-12) is not shown. Original lines 15, 22-25 and
// 30-31 are also absent. Restore from the full source rather than editing
// this fragment in place.
13 .TensorInferenceFunction([](
const OperatorDef& ,
14 const std::vector<TensorShape>& in) {
// Original line 15 is missing; it presumably opens a check call taking
// (actual rank, 3, message) that rejects inputs of rank < 3 -- confirm.
16 in[0].dims_size(), 3,
"Input for NHWC2NCHW must be >= 3 dimensional");
17 std::vector<TensorShape> out(1);
// Output dims: input dim 0 (N) first, then the last input dim (C), then
// input dims 1 .. rank-2 (the spatial dims H_1 ... H_k) in their original
// order, i.e. N H_1 ... H_k C -> N C H_1 ... H_k.
18 out[0].add_dims(in[0].dims(0));
19 out[0].add_dims(in[0].dims(in[0].dims_size() - 1));
20 for (
auto i = 1; i < in[0].dims_size() - 1; ++i) {
21 out[0].add_dims(in[0].dims(i));
26 The operator switches the order of data in a tensor from NHWC- sample index N, 27 height H, width W and channels C, to the NCHW order (this is for 2D images). 28 In general, this operator switches the order of data in a tensor from N H_1 ... 29 H_k C to N C H_1 ... H_k for k-dimensional features, and currently supports 32 .Input(0, "data",
"The input data (Tensor) in the NHWC order.")
33 .Output(0,
"output",
"The output tensor (Tensor) in the NCHW order.");
// NCHW2NHWC schema fragment: shape inference, operator doc text, and the
// input/output descriptions.
// NOTE(review): this listing is incomplete and does not compile as shown.
// The chain starts mid-expression: whatever `.TensorInferenceFunction` is
// called on (original lines 34-37) is not shown. Original lines 40, 46,
// 48-50 and 55-56 are also absent. Restore from the full source rather than
// editing this fragment in place.
38 .TensorInferenceFunction([](
const OperatorDef& ,
39 const std::vector<TensorShape>& in) {
// Original line 40 is missing; it presumably opens a check call taking
// (actual rank, 3, message) that rejects inputs of rank < 3 -- confirm.
41 in[0].dims_size(), 3,
"Input for NCHW2NHWC must be >= 3 dimensional");
42 std::vector<TensorShape> out(1);
// Output dims: input dim 0 (N) first, then input dims 2 .. rank-1 (the
// spatial dims H_1 ... H_k) in order, then input dim 1 (C) last,
// i.e. N C H_1 ... H_k -> N H_1 ... H_k C.
43 out[0].add_dims(in[0].dims(0));
44 for (
auto i = 2; i < in[0].dims_size(); ++i) {
45 out[0].add_dims(in[0].dims(i));
// Original line 46 (presumably the loop's closing brace) is missing here.
47 out[0].add_dims(in[0].dims(1));
51 The operator switches the order of data in a tensor from NCHW- sample index N, 52 channels C, height H and width W, to the NHWC order (this is for 2D images). 53 In general, this operator switches the order of data in a tensor from N C H_1 54 ... H_k to N H_1 ... H_k C for k-dimensional features, and currently supports 57 .Input(0, "data",
"The input data (Tensor) in the NCHW order.")
58 .Output(0,
"output",
"The output tensor (Tensor) in the NHWC order.");
// Gradient maker for NHWC2NCHW (the pairing is registered via
// REGISTER_GRADIENT(NHWC2NCHW, GetNHWC2NCHWGradient) in this file).
// GetGradientDefs returns the result of SingleGradientDef, passing {GO(0)}
// and {GI(0)} as its last two arguments.
// NOTE(review): this listing is incomplete. Original lines 66-67 are absent;
// they presumably carry the gradient op's type and name arguments to
// SingleGradientDef. Lines 70-72 (presumably the closing braces) are also
// absent. Confirm against the full source.
62 class GetNHWC2NCHWGradient :
public GradientMakerBase {
63 using GradientMakerBase::GradientMakerBase;
64 std::vector<OperatorDef> GetGradientDefs()
override {
// GO(0) / GI(0): presumably the gradient blob names for output 0 and
// input 0 (GradientMakerBase helpers, not shown here).
65 return SingleGradientDef(
68 std::vector<std::string>{GO(0)},
69 std::vector<std::string>{GI(0)});
// Gradient maker for NCHW2NHWC (the pairing is registered via
// REGISTER_GRADIENT(NCHW2NHWC, GetNCHW2NHWCGradient) in this file).
// GetGradientDefs returns the result of SingleGradientDef, passing {GO(0)}
// and {GI(0)} as its last two arguments.
// NOTE(review): this listing is incomplete. Original lines 77-78 are absent;
// they presumably carry the gradient op's type and name arguments to
// SingleGradientDef. Lines 81-85 (presumably including the closing braces)
// are also absent. Confirm against the full source.
73 class GetNCHW2NHWCGradient :
public GradientMakerBase {
74 using GradientMakerBase::GradientMakerBase;
75 std::vector<OperatorDef> GetGradientDefs()
override {
// GO(0) / GI(0): presumably the gradient blob names for output 0 and
// input 0 (GradientMakerBase helpers, not shown here).
76 return SingleGradientDef(
79 std::vector<std::string>{GO(0)},
80 std::vector<std::string>{GI(0)});
86 REGISTER_GRADIENT(
NHWC2NCHW, GetNHWC2NCHWGradient);
87 REGISTER_GRADIENT(
NCHW2NHWC, GetNCHW2NHWCGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...