Caffe2 - C++ API
A deep learning, cross-platform ML framework
channel_shuffle_op.cc
1 #include "channel_shuffle_op.h"
2 
3 #include <array>
4 #include <string>
5 #include <vector>
6 
7 #ifdef CAFFE2_USE_MKL
8 #include <mkl.h>
9 #endif // CAFFE2_USE_MKL
10 
11 #include "caffe2/utils/math.h"
12 
13 namespace caffe2 {
14 
15 namespace {
16 
17 template <typename T>
18 void RunChannelShuffleNCHW(
19  const int N,
20  const int G,
21  const int K,
22  const int HxW,
23  const T* X,
24  T* Y,
25  CPUContext* context) {
26  const int stride = G * K * HxW;
27  for (int i = 0; i < N; ++i) {
28  if (G < K) {
29  for (int j = 0; j < G; ++j) {
30  math::CopyMatrix<T, CPUContext>(
31  K, HxW, X + j * K * HxW, HxW, Y + j * HxW, G * HxW, context);
32  }
33  } else {
34  for (int j = 0; j < K; ++j) {
35  math::CopyMatrix<T, CPUContext>(
36  G, HxW, X + j * HxW, K * HxW, Y + j * G * HxW, HxW, context);
37  }
38  }
39  X += stride;
40  Y += stride;
41  }
42 }
43 
44 template <typename T>
45 void RunChannelShuffleNHWC(
46  const int N,
47  const int G,
48  const int K,
49  const int HxW,
50  const T* X,
51  T* Y,
52  CPUContext* context) {
53  const std::array<std::int64_t, 2> dims = {G, K};
54  const std::array<std::int32_t, 2> axes = {1, 0};
55  const int M = N * HxW;
56  const int C = G * K;
57  for (int i = 0; i < M; ++i) {
58  math::Transpose<std::int64_t, T, CPUContext>(
59  2, dims.data(), axes.data(), X, Y, context);
60  X += C;
61  Y += C;
62  }
63 }
64 
65 } // namespace
66 
67 template <>
68 bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
69  const auto& X = Input(0);
70 
71  auto* Y = Output(0, X.sizes(), at::dtype<float>());
72  const int N = X.dim32(0);
73  const int C = X.dim32(1);
74  const int G = group_;
75  CAFFE_ENFORCE_EQ(C % G, 0);
76  const int K = C / G;
77  const int HxW = X.numel() / (N * C);
78  const float* X_data = X.data<float>();
79  float* Y_data = Y->mutable_data<float>();
80  RunChannelShuffleNCHW<float>(N, G, K, HxW, X_data, Y_data, &context_);
81  return true;
82 } // namespace caffe2
83 
84 template <>
85 bool ChannelShuffleOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
86  const auto& X = Input(0);
87 
88  auto* Y = Output(0, X.sizes(), at::dtype<float>());
89  const int ndim = X.dim();
90  const int N = X.dim32(0);
91  const int C = X.dim32(ndim - 1);
92  const int G = group_;
93  CAFFE_ENFORCE_EQ(C % G, 0);
94  const int K = C / G;
95  const int HxW = X.numel() / (N * C);
96  const float* X_data = X.data<float>();
97  float* Y_data = Y->mutable_data<float>();
98  RunChannelShuffleNHWC<float>(N, G, K, HxW, X_data, Y_data, &context_);
99  return true;
100 }
101 
102 template <>
103 bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
104  const auto& dY = Input(0);
105 
106  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
107  const int N = dY.dim32(0);
108  const int C = dY.dim32(1);
109  const int G = group_;
110  CAFFE_ENFORCE_EQ(C % G, 0);
111  const int K = C / G;
112  const int HxW = dY.numel() / (N * C);
113  const float* dY_data = dY.data<float>();
114  float* dX_data = dX->mutable_data<float>();
115  RunChannelShuffleNCHW<float>(N, K, G, HxW, dY_data, dX_data, &context_);
116  return true;
117 }
118 
119 template <>
120 bool ChannelShuffleGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
121  const auto& dY = Input(0);
122 
123  auto* dX = Output(0, dY.sizes(), at::dtype<float>());
124  const int ndim = dY.dim();
125  const int N = dY.dim32(0);
126  const int C = dY.dim32(ndim - 1);
127  const int G = group_;
128  CAFFE_ENFORCE_EQ(C % G, 0);
129  const int K = C / G;
130  const int HxW = dY.numel() / (N * C);
131  const float* dY_data = dY.data<float>();
132  float* dX_data = dX->mutable_data<float>();
133  RunChannelShuffleNHWC<float>(N, K, G, HxW, dY_data, dX_data, &context_);
134  return true;
135 }
136 
137 REGISTER_CPU_OPERATOR(ChannelShuffle, ChannelShuffleOp<float, CPUContext>);
138 REGISTER_CPU_GRADIENT_OPERATOR(
139  ChannelShuffleGradient,
140  ChannelShuffleGradientOp<float, CPUContext>);
141 
142 OPERATOR_SCHEMA(ChannelShuffle)
143  .IdenticalTypeAndShape()
144  .NumInputs(1)
145  .NumOutputs(1)
146  .InheritOnnxSchema();
147 GRADIENT_OPERATOR_SCHEMA(ChannelShuffleGradient)
148  .IdenticalTypeAndShape()
149  .NumInputs(1)
150  .NumOutputs(1);
151 
152 namespace {
153 
154 class GetChannelShuffleGradient : public GradientMakerBase {
155  using GradientMakerBase::GradientMakerBase;
156  std::vector<OperatorDef> GetGradientDefs() override {
157  return SingleGradientDef(
158  "ChannelShuffleGradient",
159  "",
160  std::vector<std::string>{GO(0)},
161  std::vector<std::string>{GI(0)});
162  }
163 };
164 
165 } // namespace
166 
167 REGISTER_GRADIENT(ChannelShuffle, GetChannelShuffleGradient);
168 
169 } // namespace caffe2
Definition: any.cpp:108
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64