Caffe2 - C++ API
A deep learning, cross-platform ML framework
locally_connected_op_impl.h
// locally_connected_op_impl.h is the templated implementation of the
// locally_connected_op.h file.

#ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_
#define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_

#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/operators/locally_connected_op.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

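// Forward pass for the NCHW storage order: validates the input/filter/bias
// shapes against the group and kernel settings, fills in the column and
// output buffer shapes, and dispatches to RunOnDeviceWithOrderNCHWImpl.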
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(
      shape.C == filter.dim32(image_ndim + 1) * group_,
      "Locally Connected op: input channels does not match: "
      "# of input channels ",
      shape.C,
      " is not equal to kernel channels * group:",
      filter.dim32(image_ndim + 1),
      "*",
      group_);
  CAFFE_ENFORCE_EQ(
      shape.M % group_,
      0,
      "The number of output channels is not divisible by group.");

  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.X_dims.assign(X.sizes().cbegin() + 1, X.sizes().cend());
  shape.kernel_size = shape.C / group_ * kernel_dims_size;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.dim(), image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE_EQ(bias.dim32(i), output_image_dims[i]);
    }
    CAFFE_ENFORCE_EQ(bias.dim32(image_ndim), shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);

  return true;
}

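// Forward pass for the NHWC storage order; only the 2-D case is supported in
// this layout. Shape checking mirrors the NCHW path before dispatching to
// RunOnDeviceWithOrderNHWCImpl.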
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.X_dims = {X.dim32(1), X.dim32(2), X.dim32(3)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1), kernel_h());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 2), kernel_w());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 3), shape.C);
  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }

  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.dim(), image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE_EQ(bias.dim32(i), output_image_dims[i]);
    }
    CAFFE_ENFORCE_EQ(bias.dim32(image_ndim), shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);

  return true;
}

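// NCHW forward kernel: unfolds each image with im2col, transposes the column
// buffer so that output locations (and groups) form the batch dimension, and
// runs one GEMM per output location via GemmStridedBatched, since a locally
// connected layer uses an independent filter at every output position.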
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* Y_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();

  // Unfold each image (one group at a time) into the column buffer.
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2Col<T, Context, StorageOrder::NCHW>(
            shape.C / group_,
            shape.X_dims[1],
            shape.X_dims[2],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2ColNd<T, Context, StorageOrder::NCHW>(
            kernel_.size(),
            shape.C * shape.input_image_size,
            column_stride,
            shape.X_dims.data(),
            shape.column_slice_dims.data(),
            kernel_.data(),
            stride_.data(),
            dilation_.data(),
            pads_.data(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  // Rearrange the column buffer so that output locations become the leading
  // (batch) dimension of the batched GEMM below.
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  // One small GEMM per output location (and group): the per-location filter
  // slice [M/group, kernel_size] times the columns [kernel_size, N].
  math::GemmStridedBatched(
      CblasNoTrans,
      CblasNoTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.N,
      shape.kernel_size,
      1.0f,
      filter_data,
      shape.M / group_ * shape.kernel_size,
      column_transposed_buffer->template data<T>(),
      shape.kernel_size * shape.N,
      0.0f,
      Y_transposed_buffer_data,
      shape.M / group_ * shape.N,
      &context_);
  // Broadcast-add the per-location bias over the batch dimension.
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1,
        1.0,
        bias_data,
        bias_multiplier_.template data<T>(),
        1.0,
        Y_transposed_buffer_data,
        &context_);
  }
  // Transpose the result back to NCHW order.
  math::Transpose(
      shape.Y_transposed_dims.size(),
      shape.Y_transposed_dims.data(),
      shape.Y_axes.data(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
}

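// NHWC forward kernel: same im2col + per-location batched GEMM scheme as the
// NCHW path, but the bias is added after the final transpose, directly on the
// NHWC-ordered output.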
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* Y_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();
  // Unfold each image into the column buffer (NHWC im2col).
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2Col<T, Context, StorageOrder::NHWC>(
        shape.C,
        shape.X_dims[0],
        shape.X_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        X_data + image_id * input_stride,
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  // One GEMM per output location: columns [N, kernel_size] times the
  // per-location filter [M, kernel_size] transposed.
  math::GemmStridedBatched(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size,
      shape.N,
      shape.M,
      shape.kernel_size,
      1.0f,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      filter_data,
      shape.kernel_size * shape.M,
      0.0f,
      Y_transposed_buffer_data,
      shape.N * shape.M,
      &context_);
  math::Transpose(
      shape.Y_transposed_dims.size(),
      shape.Y_transposed_dims.data(),
      shape.Y_axes.data(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
  // Add the bias directly on the NHWC-ordered output.
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1,
        1.0f,
        bias_multiplier_.template data<T>(),
        bias_data,
        1.0f,
        Y_data,
        &context_);
  }
}

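// Gradient op, NCHW storage order: sets up the same shape metadata as the
// forward pass and allocates dfilter, and optionally dX and dbias, before
// dispatching to RunOnDeviceWithOrderNCHWImpl.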
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);

  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());

  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1) * group_, shape.C);
  CAFFE_ENFORCE_EQ(shape.M % group_, 0);

  const std::vector<int> input_image_dims = GetDims(X);
  shape.input_image_size = GetDimsSize(X);
  const std::vector<int> output_image_dims = GetDims(dY);
  shape.output_image_size = GetDimsSize(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }
  ConvPoolOpBase<Context>::ComputePads(input_image_dims);

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.X_dims.assign(X.sizes().cbegin() + 1, X.sizes().cend());
  shape.kernel_size = shape.C / group_ * kernel_dims_size;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(
        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    std::vector<int64_t> dbias_dims;
    std::copy(
        output_image_dims.cbegin(),
        output_image_dims.cend(),
        std::back_inserter(dbias_dims));
    dbias_dims.push_back(shape.M);
    auto* dbias = Output(BIAS_OR_INPUT_GRAD, dbias_dims, at::dtype<T>());
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);

  return true;
}

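// Gradient op, NHWC storage order: 2-D only, mirroring the forward NHWC path.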
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);

  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.X_dims = {X.dim32(1), X.dim32(2), X.dim32(3)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1), kernel_h());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 2), kernel_w());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 3), shape.C);
  const std::vector<int> input_image_dims = {X.dim32(1), X.dim32(2)};
  ConvPoolOpBase<Context>::ComputePads(input_image_dims);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(dY);
  const std::vector<int> output_image_dims = GetDims(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }

  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(
        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    std::vector<int64_t> dbias_dims;
    std::copy(
        output_image_dims.cbegin(),
        output_image_dims.cend(),
        std::back_inserter(dbias_dims));
    dbias_dims.push_back(shape.M);
    auto* dbias = Output(BIAS_OR_INPUT_GRAD, dbias_dims, at::dtype<T>());
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);

  return true;
}

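// NCHW backward kernel: recomputes the column buffer with im2col, transposes
// dY into the per-location layout, then uses batched GEMMs to form dfilter
// and (optionally) the column gradient, which is folded back into dX with
// col2im. dbias is reduced over the batch with a GEMV against the bias
// multiplier (a vector of ones).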
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();

  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2Col<T, Context, StorageOrder::NCHW>(
            shape.C / group_,
            shape.X_dims[1],
            shape.X_dims[2],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2ColNd<T, Context, StorageOrder::NCHW>(
            kernel_.size(),
            shape.C * shape.input_image_size,
            column_stride,
            shape.X_dims.data(),
            shape.column_slice_dims.data(),
            kernel_.data(),
            stride_.data(),
            dilation_.data(),
            pads_.data(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);

  math::Transpose(
      shape.Y_dims.size(),
      shape.Y_dims.data(),
      shape.Y_axes.data(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to filter.
  math::GemmStridedBatched(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.kernel_size,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      shape.M / group_ * shape.N,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      0.0f,
      dfilter_data,
      shape.M / group_ * shape.kernel_size,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to bias.
    math::Gemv<T, Context>(
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1.0f,
        dY_transposed_buffer_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to X.
    math::GemmStridedBatched(
        CblasTrans,
        CblasNoTrans,
        shape.output_image_size * group_,
        shape.kernel_size,
        shape.N,
        shape.M / group_,
        1.0f,
        filter_data,
        shape.kernel_size * shape.M / group_,
        dY_transposed_buffer_data,
        shape.M / group_ * shape.N,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        shape.kernel_size * shape.N,
        &context_);
    math::Transpose(
        shape.column_transposed_dims.size(),
        shape.column_transposed_dims.data(),
        shape.column_axes.data(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (kernel_.size() == 2) {
          math::Col2Im<T, Context, StorageOrder::NCHW>(
              shape.C / group_,
              shape.X_dims[1],
              shape.X_dims[2],
              kernel_h(),
              kernel_w(),
              dilation_h(),
              dilation_w(),
              pad_t(),
              pad_l(),
              pad_b(),
              pad_r(),
              stride_h(),
              stride_w(),
              const_column_buffer_data + group_id * column_stride,
              dX_data + group_id * input_stride,
              &context_);
        } else {
          math::Col2ImNd<T, Context, StorageOrder::NCHW>(
              kernel_.size(),
              shape.C * shape.input_image_size,
              column_stride,
              shape.X_dims.data(),
              shape.column_slice_dims.data(),
              kernel_.data(),
              stride_.data(),
              dilation_.data(),
              pads_.data(),
              const_column_buffer_data + group_id * column_stride,
              dX_data + group_id * input_stride,
              &context_);
        }
      }
      dX_data += input_stride * group_;
      const_column_buffer_data += column_stride * group_;
    }
  }
}

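// NHWC backward kernel: same structure as the NCHW version, without group
// support, using the NHWC im2col/col2im kernels.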
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2Col<T, Context, StorageOrder::NHWC>(
        shape.C,
        shape.X_dims[0],
        shape.X_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        X_data + image_id * input_stride,
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::Transpose(
      shape.Y_dims.size(),
      shape.Y_dims.data(),
      shape.Y_axes.data(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to filter.
  math::GemmStridedBatched(
      CblasTrans,
      CblasNoTrans,
      shape.output_image_size,
      shape.M,
      shape.kernel_size,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      shape.M * shape.N,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      0.0f,
      dfilter_data,
      shape.M * shape.kernel_size,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to bias.
    math::Gemv<T, Context>(
        CblasTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1.0f,
        dY_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to X.
    math::GemmStridedBatched(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size,
        shape.N,
        shape.kernel_size,
        shape.M,
        1.0f,
        dY_transposed_buffer_data,
        shape.N * shape.M,
        filter_data,
        shape.M * shape.kernel_size,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        shape.N * shape.kernel_size,
        &context_);
    math::Transpose(
        shape.column_transposed_dims.size(),
        shape.column_transposed_dims.data(),
        shape.column_axes.data(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      math::Col2Im<T, Context, StorageOrder::NHWC>(
          shape.C,
          shape.X_dims[0],
          shape.X_dims[1],
          kernel_h(),
          kernel_w(),
          dilation_h(),
          dilation_w(),
          pad_t(),
          pad_l(),
          pad_b(),
          pad_r(),
          stride_h(),
          stride_w(),
          const_column_buffer_data,
          dX_data,
          &context_);
      dX_data += input_stride;
      const_column_buffer_data += column_stride;
    }
  }
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_