#include "channel_shuffle_op.h"

#include <array>
#include <string>
#include <vector>

#ifdef CAFFE2_USE_MKL
#include "caffe2/mkl/operators/operator_fallback_mkl.h"
#endif // CAFFE2_USE_MKL

#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

template <typename T>
void RunChannelShuffleNCHW(
    const int N,
    const int G,
    const int K,
    const int HxW,
    const T* X,
    T* Y,
    CPUContext* context) {
  const int stride = G * K * HxW;
  for (int i = 0; i < N; ++i) {
    if (G < K) {
      for (int j = 0; j < G; ++j) {
        math::CopyMatrix<T, CPUContext>(
            K, HxW, X + j * K * HxW, HxW, Y + j * HxW, G * HxW, context);
      }
    } else {
      for (int j = 0; j < K; ++j) {
        math::CopyMatrix<T, CPUContext>(
            G, HxW, X + j * HxW, K * HxW, Y + j * G * HxW, HxW, context);
      }
    }
    X += stride;
    Y += stride;
  }
}

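// The NCHW kernel above realizes the channel shuffle Y[n][k * G + g] =
// X[n][g * K + k] as strided row copies: each CopyMatrix call scatters the K
// consecutive channels of one group into every G-th output channel. For
// example, with G = 2 and K = 3, the output reads the six input channels in
// the order (0, 3, 1, 4, 2, 5). The branch loops over whichever of G and K is
// smaller, so fewer (but larger) matrix copies are issued per image.
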
template <typename T>
void RunChannelShuffleNHWC(
    const int N,
    const int G,
    const int K,
    const int HxW,
    const T* X,
    T* Y,
    CPUContext* context) {
  const std::array<std::int64_t, 2> dims = {G, K};
  const std::array<std::int32_t, 2> axes = {1, 0};
  const int M = N * HxW;
  const int stride = G * K;
  for (int i = 0; i < M; ++i) {
    math::Transpose<std::int64_t, T, CPUContext>(
        2, dims.data(), axes.data(), X, Y, context);
    X += stride;
    Y += stride;
  }
}

} // namespace

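// In NHWC the channels of a single pixel are contiguous, so the kernel above
// treats each of the N * HxW pixels as a row-major G x K matrix and writes its
// K x G transpose. That is the same permutation as the NCHW path: the element
// at (group g, within-group channel k) moves to flat position k * G + g.
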
template <>
bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(0);

  auto* Y = Output(0, X.sizes(), at::dtype<float>());
  const int N = X.dim32(0);
  const int C = X.dim32(1);
  const int G = group_;
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  const int HxW = X.numel() / (N * C);
  const float* X_data = X.data<float>();
  float* Y_data = Y->mutable_data<float>();
  RunChannelShuffleNCHW<float>(N, G, K, HxW, X_data, Y_data, &context_);
  return true;
}

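// Note that HxW is computed as numel / (N * C), so all trailing (or, for
// NHWC, middle) spatial dimensions are flattened into one; the kernels only
// ever see their product, and inputs with any number of spatial axes are
// handled. Only the position of the channel axis differs between the NCHW and
// NHWC specializations.
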
template <>
bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(0);

  auto* Y = Output(0, X.sizes(), at::dtype<float>());
  const int ndim = X.dim();
  const int N = X.dim32(0);
  const int C = X.dim32(ndim - 1);
  const int G = group_;
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  const int HxW = X.numel() / (N * C);
  const float* X_data = X.data<float>();
  float* Y_data = Y->mutable_data<float>();
  RunChannelShuffleNHWC<float>(N, G, K, HxW, X_data, Y_data, &context_);
  return true;
}

template <>
bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
  const auto& dY = Input(0);

  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
  const int N = dY.dim32(0);
  const int C = dY.dim32(1);
  const int G = group_;
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  const int HxW = dY.numel() / (N * C);
  const float* dY_data = dY.data<float>();
  float* dX_data = dX->mutable_data<float>();
  RunChannelShuffleNCHW<float>(N, K, G, HxW, dY_data, dX_data, &context_);
  return true;
}

template <>
bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
  const auto& dY = Input(0);

  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
  const int ndim = dY.dim();
  const int N = dY.dim32(0);
  const int C = dY.dim32(ndim - 1);
  const int G = group_;
  CAFFE_ENFORCE_EQ(C % G, 0);
  const int K = C / G;
  const int HxW = dY.numel() / (N * C);
  const float* dY_data = dY.data<float>();
  float* dX_data = dX->mutable_data<float>();
  RunChannelShuffleNHWC<float>(N, K, G, HxW, dY_data, dX_data, &context_);
  return true;
}

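// Both gradient specializations above reuse the forward kernels with G and K
// swapped: shuffling with (K, G) applies the inverse channel permutation, so
// dX[n][g * K + k] receives dY[n][k * G + g].
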
REGISTER_CPU_OPERATOR(ChannelShuffle, ChannelShuffleOp<float, CPUContext>);
REGISTER_CPU_GRADIENT_OPERATOR(
    ChannelShuffleGradient,
    ChannelShuffleGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(ChannelShuffle)
    .IdenticalTypeAndShape()
    .NumInputs(1)
    .NumOutputs(1)
    .InheritOnnxSchema();

GRADIENT_OPERATOR_SCHEMA(ChannelShuffleGradient)
    .IdenticalTypeAndShape()
    .NumInputs(1)
    .NumOutputs(1);

namespace {

class GetChannelShuffleGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "ChannelShuffleGradient",
        "",
        std::vector<std::string>{GO(0)},
        std::vector<std::string>{GI(0)});
  }
};

} // namespace

REGISTER_GRADIENT(ChannelShuffle, GetChannelShuffleGradient);

} // namespace caffe2
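
// Minimal usage sketch (an illustrative assumption, not part of this file):
// the operator is normally built from an OperatorDef carrying "group" and
// "order" arguments, e.g. via the generic Caffe2 C++ helpers:
//
//   Workspace ws;
//   // ... feed a float tensor named "X" with shape (N, C, H, W) into ws ...
//   OperatorDef def;
//   def.set_type("ChannelShuffle");
//   def.add_input("X");
//   def.add_output("Y");
//   def.add_arg()->CopyFrom(MakeArgument("group", 4));
//   def.add_arg()->CopyFrom(MakeArgument("order", std::string("NCHW")));
//   auto op = CreateOperator(def, &ws);
//   op->Run();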