Caffe2 - C++ API
A deep learning, cross-platform ML framework
instance_norm_gradient_op.cc
#include "caffe2/operators/instance_norm_op.h"
#include "caffe2/utils/eigen_utils.h"

namespace caffe2 {

template <typename T, typename Context>
bool InstanceNormGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& input = Input(INPUT);
  const auto& scale = Input(SCALE);
  const auto& bias = Input(BIAS);
  const auto& output_grad = Input(OUTPUT_GRAD);
  const auto& mean = InputSize() >= 5 ? Input(MEAN) : mean_;
  const auto& inv_stdev = InputSize() >= 6 ? Input(INV_STDEV) : inv_stdev_;

  CAFFE_ENFORCE_EQ(4, input.dim());
  const int N = input.dim32(0);
  const int H = input.dim32(1);
  const int W = input.dim32(2);
  const int C = input.dim32(3);
  CAFFE_ENFORCE_EQ(1, scale.dim());
  CAFFE_ENFORCE_EQ(C, scale.dim32(0));
  CAFFE_ENFORCE_EQ(1, bias.dim());
  CAFFE_ENFORCE_EQ(C, bias.dim32(0));
  CAFFE_ENFORCE_EQ(4, output_grad.dim());
  CAFFE_ENFORCE_EQ(N, output_grad.dim32(0));
  CAFFE_ENFORCE_EQ(H, output_grad.dim32(1));
  CAFFE_ENFORCE_EQ(W, output_grad.dim32(2));
  CAFFE_ENFORCE_EQ(C, output_grad.dim32(3));
  auto input_grad = Output(INPUT_GRAD, input.sizes(), at::dtype<T>());
  auto scale_grad = Output(SCALE_GRAD, scale.sizes(), at::dtype<T>());
  auto bias_grad = Output(BIAS_GRAD, bias.sizes(), at::dtype<T>());

  ConstEigenVectorArrayMap<T> scale_arr(scale.template data<T>(), C);
  ConstEigenVectorArrayMap<T> bias_arr(bias.template data<T>(), C);
  EigenVectorArrayMap<T> scale_grad_arr(
      scale_grad->template mutable_data<T>(), C);
  EigenVectorArrayMap<T> bias_grad_arr(
      bias_grad->template mutable_data<T>(), C);

  // Resize before we get into the per-instance loop
  if (InputSize() < 5) {
    ReinitializeTensor(
        &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
  }
  if (InputSize() < 6) {
    ReinitializeTensor(
        &inv_stdev_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
  }
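  // Note: MEAN and INV_STDEV are optional inputs (the 5th and 6th). When they
  // are absent, the statistics are recomputed below into the member buffers
  // mean_ and inv_stdev_ that were just resized, and the `mean` and
  // `inv_stdev` references above alias those same buffers.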

  // Loop over each instance, using Eigen blocks to extract out
  // a chunk of channels at a time.
  for (int n = 0; n < N; ++n) {
    // All Eigen mats and arrs in here are per-instance.
    ConstEigenArrayMap<T> input_mat(
        input.template data<T>() + n * C * H * W, C, H * W);
    ConstEigenArrayMap<T> output_grad_mat(
        output_grad.template data<T>() + n * C * H * W, C, H * W);
    EigenArrayMap<T> input_grad_mat(
        input_grad->template mutable_data<T>() + n * C * H * W, C, H * W);

    // Compute mean if it wasn't passed in
    if (InputSize() < 5) {
      EigenVectorArrayMap<T> mean_mutable_arr(
          mean_.template mutable_data<T>() + n * C, C);
      mean_mutable_arr = input_mat.rowwise().mean();
    }
    CAFFE_ENFORCE_EQ(2, mean.dim());
    CAFFE_ENFORCE_EQ(N, mean.dim32(0));
    CAFFE_ENFORCE_EQ(C, mean.dim32(1));
    ConstEigenVectorArrayMap<T> mean_arr(mean.template data<T>() + n * C, C);

    // subtract mean
    input_grad_mat = input_mat.colwise() - mean_arr;

    // Compute 1 / stdev if it wasn't passed in
    if (InputSize() < 6) {
      EigenVectorArrayMap<T> inv_stdev_mutable_arr(
          inv_stdev_.template mutable_data<T>() + n * C, C);

      // Square the diffs along each channel and take the mean to get var
      inv_stdev_mutable_arr = input_grad_mat.pow(2).rowwise().mean();
      // sqrt to get stdev and take the inverse
      inv_stdev_mutable_arr =
          (inv_stdev_mutable_arr + epsilon_).sqrt().inverse();
    }
    CAFFE_ENFORCE_EQ(2, inv_stdev.dim());
    CAFFE_ENFORCE_EQ(N, inv_stdev.dim32(0));
    CAFFE_ENFORCE_EQ(C, inv_stdev.dim32(1));

    ConstEigenVectorArrayMap<T> inv_stdev_arr(
        inv_stdev.template data<T>() + n * C, C);

    // for each channel
    // dl/dbias = sum_j dl/dy_j
    auto bias_grad_delta = output_grad_mat.rowwise().sum();
    if (n == 0) {
      bias_grad_arr = bias_grad_delta;
    } else {
      bias_grad_arr += bias_grad_delta;
    }
    // for each channel
    // dl/dscale = sum_j dl/dy_j (x_j - mu) / stdev
    auto scale_grad_delta =
        ((input_grad_mat.colwise() * inv_stdev_arr) * output_grad_mat)
            .rowwise()
            .sum();
    if (n == 0) {
      scale_grad_arr = scale_grad_delta;
    } else {
      scale_grad_arr += scale_grad_delta;
    }
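    // scale and bias are shared across all N instances, so their gradients
    // accumulate over the instance loop; the n == 0 branches initialize the
    // accumulators in place of a separate setZero() pass.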

    // dl/dx_j = this gross thing
    // Derived the gradient and manually massaged it to minimize extra storage
    // and the number of vectorized calls. Verified it with the autograd
    // package in Python.
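    // (Added note.) Restating the steps below in one place, with g_j = dl/dy_j
    // and xhat_j = (x_j - mu) / stdev, the closed form computed here is the
    // usual normalization-layer input gradient:
    //   dl/dx_j = s / stdev * (g_j - mean_k(g_k) - xhat_j * mean_k(g_k * xhat_k))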

    // a = -1/(HW) sum_j dl/dy_j * (x_j - mu) / stdev^3
    const auto temp = (inv_stdev_arr.pow(3) *
                       (input_grad_mat * output_grad_mat).rowwise().mean() *
                       -1).eval();
    // b_j = a * (x_j - mu)
    input_grad_mat.colwise() *= temp;

    // c_j = b_j + dl/dy_j / stdev
    input_grad_mat += output_grad_mat.colwise() * inv_stdev_arr;

    // dl/dx_j = s * (c_j - mean(c_j))
    const auto result_mean = input_grad_mat.rowwise().mean().eval();
    input_grad_mat.colwise() -= result_mean;
    input_grad_mat.colwise() *= scale_arr;
  }

  return true;
}

template <typename T, typename Context>
bool InstanceNormGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& input = Input(INPUT);
  const auto& scale = Input(SCALE);
  const auto& bias = Input(BIAS);
  const auto& output_grad = Input(OUTPUT_GRAD);
  const auto& mean = InputSize() >= 5 ? Input(MEAN) : mean_;
  const auto& inv_stdev = InputSize() >= 6 ? Input(INV_STDEV) : inv_stdev_;

  CAFFE_ENFORCE_EQ(4, input.dim());
  const int N = input.dim32(0);
  const int C = input.dim32(1);
  const int H = input.dim32(2);
  const int W = input.dim32(3);
  CAFFE_ENFORCE_EQ(1, scale.dim());
  CAFFE_ENFORCE_EQ(C, scale.dim32(0));
  CAFFE_ENFORCE_EQ(1, bias.dim());
  CAFFE_ENFORCE_EQ(C, bias.dim32(0));
  CAFFE_ENFORCE_EQ(4, output_grad.dim());
  CAFFE_ENFORCE_EQ(N, output_grad.dim32(0));
  CAFFE_ENFORCE_EQ(C, output_grad.dim32(1));
  CAFFE_ENFORCE_EQ(H, output_grad.dim32(2));
  CAFFE_ENFORCE_EQ(W, output_grad.dim32(3));
  auto input_grad = Output(INPUT_GRAD, input.sizes(), at::dtype<T>());
  auto scale_grad = Output(SCALE_GRAD, scale.sizes(), at::dtype<T>());
  auto bias_grad = Output(BIAS_GRAD, bias.sizes(), at::dtype<T>());

  ConstEigenArrayMap<T> input_mat(input.template data<T>(), H * W, N * C);
  ConstEigenVectorArrayMap<T> scale_arr(scale.template data<T>(), C);
  ConstEigenVectorArrayMap<T> bias_arr(bias.template data<T>(), C);
  ConstEigenArrayMap<T> output_grad_mat(
      output_grad.template data<T>(), H * W, N * C);

  EigenArrayMap<T> input_grad_mat(
      input_grad->template mutable_data<T>(), H * W, N * C);
  EigenVectorArrayMap<T> scale_grad_arr(
      scale_grad->template mutable_data<T>(), C);
  EigenVectorArrayMap<T> bias_grad_arr(
      bias_grad->template mutable_data<T>(), C);

  // Compute mean if it wasn't passed in
  if (InputSize() < 5) {
    ReinitializeTensor(
        &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    EigenVectorArrayMap<T> mean_mutable_arr(
        mean_.template mutable_data<T>(), N * C);
    mean_mutable_arr = input_mat.colwise().mean();
  }
  CAFFE_ENFORCE_EQ(2, mean.dim());
  CAFFE_ENFORCE_EQ(N, mean.dim32(0));
  CAFFE_ENFORCE_EQ(C, mean.dim32(1));
  ConstEigenVectorArrayMap<T> mean_arr(mean.template data<T>(), N * C);

  // subtract mean
  input_grad_mat = input_mat.rowwise() - mean_arr.transpose();

  // compute 1 / stdev if not passed in
  if (InputSize() < 6) {
    ReinitializeTensor(
        &inv_stdev_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    EigenVectorArrayMap<T> inv_stdev_mutable_arr(
        inv_stdev_.template mutable_data<T>(), N * C);

    // Square the diffs along each column and take the mean to get var
    inv_stdev_mutable_arr = input_grad_mat.pow(2).colwise().mean();
    // sqrt to get stdev and then invert
    inv_stdev_mutable_arr = (inv_stdev_mutable_arr + epsilon_).sqrt().inverse();
  }
  CAFFE_ENFORCE_EQ(2, inv_stdev.dim());
  CAFFE_ENFORCE_EQ(N, inv_stdev.dim32(0));
  CAFFE_ENFORCE_EQ(C, inv_stdev.dim32(1));

  ConstEigenVectorArrayMap<T> inv_stdev_arr(
      inv_stdev.template data<T>(), N * C);

  // See the comments in the NHWC version about these gradients. The scale and
  // bias grads are computed about the same way, but the input grads no longer
  // slice out one example at a time and instead vectorize across all N * C
  // feature maps.

  // scale and bias gradients
  scale_grad_arr.setZero();
  bias_grad_arr.setZero();
  for (int n = 0; n < N; ++n) {
    scale_grad_arr += ((input_grad_mat.rowwise() * inv_stdev_arr.transpose()) *
                       output_grad_mat)
                          .block(0, n * C, H * W, C)
                          .colwise()
                          .sum();
    bias_grad_arr += output_grad_mat.block(0, n * C, H * W, C).colwise().sum();
  }
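  // The matrices above are laid out as (H * W) x (N * C), one column per
  // (n, c) feature map, so .block(0, n * C, H * W, C) selects the C columns
  // belonging to instance n; the colwise() sums then reduce each feature map
  // to a single per-channel entry.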

  // input gradient
  const auto temp = ((inv_stdev_arr.pow(3).transpose() *
                      (input_grad_mat * output_grad_mat).colwise().mean()) *
                     -1).eval();
  input_grad_mat.rowwise() *= temp;

  input_grad_mat += output_grad_mat.rowwise() * inv_stdev_arr.transpose();

  const auto result_mean = input_grad_mat.colwise().mean().eval();
  input_grad_mat.rowwise() -= result_mean;

  for (int n = 0; n < N; ++n) {
    input_grad_mat.block(0, n * C, H * W, C).rowwise() *= scale_arr.transpose();
  }

  return true;
}

class GetInstanceNormGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    vector<string> inputs{I(0), I(1), I(2), GO(0)};
    if (def_.output_size() >= 2) {
      inputs.push_back(O(1));
    }
    if (def_.output_size() >= 3) {
      inputs.push_back(O(2));
    }
    return SingleGradientDef(
        "InstanceNormGradient",
        "",
        inputs,
        vector<string>{GI(0), GI(1), GI(2)});
  }
};
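// In GetGradientDefs above, I(0..2) are the forward op's input, scale, and
// bias, and GO(0) is the gradient flowing into its output. O(1) and O(2) are
// the forward op's saved mean and inv_stdev outputs; when the forward op
// produced them, they are forwarded so the gradient op can skip recomputing
// the statistics (hence NumInputs(4, 6) in the schema below).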

REGISTER_CPU_OPERATOR(
    InstanceNormGradient,
    InstanceNormGradientOp<float, CPUContext>);

OPERATOR_SCHEMA(InstanceNormGradient).NumInputs(4, 6).NumOutputs(3);

REGISTER_GRADIENT(InstanceNorm, GetInstanceNormGradient);
} // namespace caffe2
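For context, here is a minimal sketch of driving this operator directly from C++. It is not part of the file above: the blob names are hypothetical, and it assumes the standard Caffe2 workspace and proto helpers (Workspace, CreateOperatorDef, CreateOperator, BlobGetMutableTensor) as they exist in the mainline CPU build.

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"
#include "caffe2/utils/proto_utils.h"

namespace {

// Create a CPU float tensor blob of the given shape, filled with ones.
void FillOnes(caffe2::Workspace* ws, const std::string& name,
              const std::vector<int64_t>& dims) {
  auto* tensor =
      caffe2::BlobGetMutableTensor(ws->CreateBlob(name), caffe2::CPU);
  tensor->Resize(dims);
  float* data = tensor->mutable_data<float>();
  std::fill(data, data + tensor->numel(), 1.0f);
}

} // namespace

int main() {
  const int64_t N = 2, C = 4, H = 3, W = 3; // the op's default order is NCHW
  caffe2::Workspace ws;
  FillOnes(&ws, "input", {N, C, H, W});
  FillOnes(&ws, "scale", {C});
  FillOnes(&ws, "bias", {C});
  FillOnes(&ws, "output_grad", {N, C, H, W});

  // Only the four mandatory inputs; MEAN and INV_STDEV are recomputed inside.
  const caffe2::OperatorDef def = caffe2::CreateOperatorDef(
      "InstanceNormGradient",
      "",
      std::vector<std::string>{"input", "scale", "bias", "output_grad"},
      std::vector<std::string>{"input_grad", "scale_grad", "bias_grad"});
  const auto op = caffe2::CreateOperator(def, &ws);
  CAFFE_ENFORCE(op->Run());
  return 0;
}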