1 #include "caffe2/operators/instance_norm_op.h" 2 #include "caffe2/utils/eigen_utils.h" 6 template <
typename T,
typename Context>
7 bool InstanceNormGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
8 const auto& input = Input(INPUT);
9 const auto& scale = Input(SCALE);
10 const auto& bias = Input(BIAS);
11 const auto& output_grad = Input(OUTPUT_GRAD);
12 const auto& mean = InputSize() >= 5 ? Input(MEAN) : mean_;
13 const auto& inv_stdev = InputSize() >= 6 ? Input(INV_STDEV) : inv_stdev_;
15 CAFFE_ENFORCE_EQ(4, input.dim());
16 const int N = input.dim32(0);
17 const int H = input.dim32(1);
18 const int W = input.dim32(2);
19 const int C = input.dim32(3);
20 CAFFE_ENFORCE_EQ(1, scale.dim());
21 CAFFE_ENFORCE_EQ(C, scale.dim32(0));
22 CAFFE_ENFORCE_EQ(1, bias.dim());
23 CAFFE_ENFORCE_EQ(C, bias.dim32(0));
24 CAFFE_ENFORCE_EQ(4, output_grad.dim());
25 CAFFE_ENFORCE_EQ(N, output_grad.dim32(0));
26 CAFFE_ENFORCE_EQ(H, output_grad.dim32(1));
27 CAFFE_ENFORCE_EQ(W, output_grad.dim32(2));
28 CAFFE_ENFORCE_EQ(C, output_grad.dim32(3));
29 auto input_grad = Output(INPUT_GRAD, input.sizes(), at::dtype<T>());
30 auto scale_grad = Output(SCALE_GRAD, scale.sizes(), at::dtype<T>());
31 auto bias_grad = Output(BIAS_GRAD, bias.sizes(), at::dtype<T>());
33 ConstEigenVectorArrayMap<T> scale_arr(scale.template data<T>(), C);
34 ConstEigenVectorArrayMap<T> bias_arr(bias.template data<T>(), C);
35 EigenVectorArrayMap<T> scale_grad_arr(
36 scale_grad->template mutable_data<T>(), C);
37 EigenVectorArrayMap<T> bias_grad_arr(
38 bias_grad->template mutable_data<T>(), C);
41 if (InputSize() < 5) {
43 &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
45 if (InputSize() < 6) {
47 &inv_stdev_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
52 for (
int n = 0; n < N; ++n) {
54 ConstEigenArrayMap<T> input_mat(
55 input.template data<T>() + n * C * H * W, C, H * W);
56 ConstEigenArrayMap<T> output_grad_mat(
57 output_grad.template data<T>() + n * C * H * W, C, H * W);
58 EigenArrayMap<T> input_grad_mat(
59 input_grad->template mutable_data<T>() + n * C * H * W, C, H * W);
62 if (InputSize() < 5) {
63 EigenVectorArrayMap<T> mean_mutable_arr(
64 mean_.template mutable_data<T>() + n * C, C);
65 mean_mutable_arr = input_mat.rowwise().mean();
67 CAFFE_ENFORCE_EQ(2, mean.dim());
68 CAFFE_ENFORCE_EQ(N, mean.dim32(0));
69 CAFFE_ENFORCE_EQ(C, mean.dim32(1));
70 ConstEigenVectorArrayMap<T> mean_arr(mean.template data<T>() + n * C, C);
73 input_grad_mat = input_mat.colwise() - mean_arr;
76 if (InputSize() < 6) {
77 EigenVectorArrayMap<T> inv_stdev_mutable_arr(
78 inv_stdev_.template mutable_data<T>() + n * C, C);
81 inv_stdev_mutable_arr = input_grad_mat.pow(2).rowwise().mean();
83 inv_stdev_mutable_arr =
84 (inv_stdev_mutable_arr + epsilon_).sqrt().inverse();
86 CAFFE_ENFORCE_EQ(2, inv_stdev.dim());
87 CAFFE_ENFORCE_EQ(N, inv_stdev.dim32(0));
88 CAFFE_ENFORCE_EQ(C, inv_stdev.dim32(1));
90 ConstEigenVectorArrayMap<T> inv_stdev_arr(
91 inv_stdev.template data<T>() + n * C, C);
95 auto bias_grad_delta = output_grad_mat.rowwise().sum();
97 bias_grad_arr = bias_grad_delta;
99 bias_grad_arr += bias_grad_delta;
103 auto scale_grad_delta =
104 ((input_grad_mat.colwise() * inv_stdev_arr) * output_grad_mat)
108 scale_grad_arr = scale_grad_delta;
110 scale_grad_arr += scale_grad_delta;
119 const auto temp = (inv_stdev_arr.pow(3) *
120 (input_grad_mat * output_grad_mat).rowwise().mean() *
123 input_grad_mat.colwise() *= temp;
126 input_grad_mat += output_grad_mat.colwise() * inv_stdev_arr;
129 const auto result_mean = input_grad_mat.rowwise().mean().eval();
130 input_grad_mat.colwise() -= result_mean;
131 input_grad_mat.colwise() *= scale_arr;
137 template <
typename T,
typename Context>
138 bool InstanceNormGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
139 const auto& input = Input(INPUT);
140 const auto& scale = Input(SCALE);
141 const auto& bias = Input(BIAS);
142 const auto& output_grad = Input(OUTPUT_GRAD);
143 const auto& mean = InputSize() >= 5 ? Input(MEAN) : mean_;
144 const auto& inv_stdev = InputSize() >= 6 ? Input(INV_STDEV) : inv_stdev_;
146 CAFFE_ENFORCE_EQ(4, input.dim());
147 const int N = input.dim32(0);
148 const int C = input.dim32(1);
149 const int H = input.dim32(2);
150 const int W = input.dim32(3);
151 CAFFE_ENFORCE_EQ(1, scale.dim());
152 CAFFE_ENFORCE_EQ(C, scale.dim32(0));
153 CAFFE_ENFORCE_EQ(1, bias.dim());
154 CAFFE_ENFORCE_EQ(C, bias.dim32(0));
155 CAFFE_ENFORCE_EQ(4, output_grad.dim());
156 CAFFE_ENFORCE_EQ(N, output_grad.dim32(0));
157 CAFFE_ENFORCE_EQ(C, output_grad.dim32(1));
158 CAFFE_ENFORCE_EQ(H, output_grad.dim32(2));
159 CAFFE_ENFORCE_EQ(W, output_grad.dim32(3));
160 auto input_grad = Output(INPUT_GRAD, input.sizes(), at::dtype<T>());
161 auto scale_grad = Output(SCALE_GRAD, scale.sizes(), at::dtype<T>());
162 auto bias_grad = Output(BIAS_GRAD, bias.sizes(), at::dtype<T>());
164 ConstEigenArrayMap<T> input_mat(input.template data<T>(), H * W, N * C);
165 ConstEigenVectorArrayMap<T> scale_arr(scale.template data<T>(), C);
166 ConstEigenVectorArrayMap<T> bias_arr(bias.template data<T>(), C);
167 ConstEigenArrayMap<T> output_grad_mat(
168 output_grad.template data<T>(), H * W, N * C);
170 EigenArrayMap<T> input_grad_mat(
171 input_grad->template mutable_data<T>(), H * W, N * C);
172 EigenVectorArrayMap<T> scale_grad_arr(
173 scale_grad->template mutable_data<T>(), C);
174 EigenVectorArrayMap<T> bias_grad_arr(
175 bias_grad->template mutable_data<T>(), C);
178 if (InputSize() < 5) {
180 &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
181 EigenVectorArrayMap<T> mean_mutable_arr(
182 mean_.template mutable_data<T>(), N * C);
183 mean_mutable_arr = input_mat.colwise().mean();
185 CAFFE_ENFORCE_EQ(2, mean.dim());
186 CAFFE_ENFORCE_EQ(N, mean.dim32(0));
187 CAFFE_ENFORCE_EQ(C, mean.dim32(1));
188 ConstEigenVectorArrayMap<T> mean_arr(mean.template data<T>(), N * C);
191 input_grad_mat = input_mat.rowwise() - mean_arr.transpose();
194 if (InputSize() < 6) {
196 &inv_stdev_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
197 EigenVectorArrayMap<T> inv_stdev_mutable_arr(
198 inv_stdev_.template mutable_data<T>(), N * C);
201 inv_stdev_mutable_arr = input_grad_mat.pow(2).colwise().mean();
203 inv_stdev_mutable_arr = (inv_stdev_mutable_arr + epsilon_).sqrt().inverse();
205 CAFFE_ENFORCE_EQ(2, inv_stdev.dim());
206 CAFFE_ENFORCE_EQ(N, inv_stdev.dim32(0));
207 CAFFE_ENFORCE_EQ(C, inv_stdev.dim32(1));
209 ConstEigenVectorArrayMap<T> inv_stdev_arr(
210 inv_stdev.template data<T>(), N * C);
217 scale_grad_arr.setZero();
218 bias_grad_arr.setZero();
219 for (
int n = 0; n < N; ++n) {
220 scale_grad_arr += ((input_grad_mat.rowwise() * inv_stdev_arr.transpose()) *
222 .block(0, n * C, H * W, C)
225 bias_grad_arr += output_grad_mat.block(0, n * C, H * W, C).colwise().sum();
229 const auto temp = ((inv_stdev_arr.pow(3).transpose() *
230 (input_grad_mat * output_grad_mat).colwise().mean()) *
232 input_grad_mat.rowwise() *= temp;
234 input_grad_mat += output_grad_mat.rowwise() * inv_stdev_arr.transpose();
236 const auto result_mean = input_grad_mat.colwise().mean().eval();
237 input_grad_mat.rowwise() -= result_mean;
239 for (
int n = 0; n < N; ++n) {
240 input_grad_mat.block(0, n * C, H * W, C).rowwise() *= scale_arr.transpose();
247 using GradientMakerBase::GradientMakerBase;
248 vector<OperatorDef> GetGradientDefs()
override {
249 vector<string> inputs{I(0), I(1), I(2), GO(0)};
250 if (def_.output_size() >= 2) {
251 inputs.push_back(O(1));
253 if (def_.output_size() >= 3) {
254 inputs.push_back(O(2));
257 "InstanceNormGradient",
260 vector<string>{GI(0), GI(1), GI(2)});
264 REGISTER_CPU_OPERATOR(
265 InstanceNormGradient,
268 OPERATOR_SCHEMA(InstanceNormGradient).NumInputs(4, 6).NumOutputs(3);
// NOTE(review): the following text is IDE hover-documentation residue from an
// extraction/paste; preserved as comments so the file remains compilable:
//   void ReinitializeTensor(Tensor* tensor, at::IntArrayRef dims,
//       at::TensorOptions options) — reinitializes a Tensor to the given dims
//       and options if necessary; does nothing if the tensor already matches.
//   (globals) — a global dictionary that holds information about what Caffe2
//       modules have been loaded in the current process.
//   static vector<OperatorDef> SingleGradientDef(const Args&... args) — a
//       helper to create one single operator def, the usual case for many
//       gradient makers.