Caffe2 - C++ API
A deep learning, cross-platform ML framework
locally_connected_op_impl.h
// Copyright (c) 2016-present, Facebook, Inc.
// locally_connected_op_impl.h is the templated implementation of the
// locally_connected_op.h file.
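// A locally connected layer computes the same sliding-window dot products as
// a convolution, but its filter weights are not shared across output
// locations: every spatial position of the output has its own filter.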

#ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_
#define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_

#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/operators/locally_connected_op.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

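// Computes the shape of the im2col column buffer, the shape of its
// transposed counterpart, and the axis permutations that map one onto the
// other. For NCHW the column buffer is (N, kernel_dim, out_spatial...) and
// its transpose is (out_spatial..., kernel_dim, N); for NHWC the column
// buffer is (N, out_spatial..., kernel_dim) and its transpose is
// (out_spatial..., N, kernel_dim).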
void SetColumnBufferShapeImpl(
    const int N,
    const int C,
    const int kernel_dim,
    const StorageOrder order,
    const std::vector<int>& output_image_dims,
    std::vector<int>* column_dims,
    std::vector<int>* column_transposed_dims,
    std::vector<int>* column_axes,
    std::vector<int>* column_transposed_axes) {
  const int n_column_dims = output_image_dims.size() + 2;
  column_dims->resize(n_column_dims);
  column_transposed_dims->resize(n_column_dims);
  column_axes->resize(n_column_dims);
  if (order == StorageOrder::NCHW) {
    for (int i = 0; i < n_column_dims - 2; ++i) {
      column_dims->at(i + 2) = output_image_dims[i];
      column_transposed_dims->at(i) = output_image_dims[i];
      column_axes->at(i) = i + 2;
    }
    column_dims->at(0) = N;
    column_dims->at(1) = kernel_dim;
    column_transposed_dims->at(n_column_dims - 2) = kernel_dim;
    column_transposed_dims->at(n_column_dims - 1) = N;
    column_axes->at(n_column_dims - 1) = 0;
    column_axes->at(n_column_dims - 2) = 1;
  } else {
    for (int i = 0; i < n_column_dims - 2; ++i) {
      column_dims->at(i + 1) = output_image_dims[i];
      column_transposed_dims->at(i) = output_image_dims[i];
      column_axes->at(i) = i + 1;
    }
    column_dims->at(0) = N;
    column_dims->at(n_column_dims - 1) = kernel_dim;
    column_transposed_dims->at(n_column_dims - 2) = N;
    column_transposed_dims->at(n_column_dims - 1) = kernel_dim;
    column_axes->at(n_column_dims - 2) = 0;
    column_axes->at(n_column_dims - 1) = n_column_dims - 1;
  }
  if (column_transposed_axes != nullptr) {
    column_transposed_axes->resize(n_column_dims);
    for (int i = 0; i < n_column_dims; ++i) {
      column_transposed_axes->at(column_axes->at(i)) = i;
    }
  }
}

} // namespace

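// NCHW forward entry point: validates the input, filter, and bias shapes,
// sizes the output and the intermediate buffers, then dispatches to
// RunOnDeviceWithOrderNCHWImpl below.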
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(
      shape.C == filter.dim32(image_ndim + 1) * group_,
      "Locally Connected op: input channels do not match: "
      "# of input channels ",
      shape.C,
      " is not equal to kernel channels * group: ",
      filter.dim32(image_ndim + 1),
      " * ",
      group_);
  CAFFE_ENFORCE(
      shape.M % group_ == 0,
      "The number of output channels is not divisible by group.");

  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.input_image_dims = GetDims(X);
  const std::vector<int> input_dims(X.dims().cbegin() + 1, X.dims().cend());
  SetDeviceTensor(input_dims, &input_dims_device_);
  shape.kernel_dim = shape.C / group_ * kernel_dims_size;

  const std::vector<int> Y_dims(Y->dims().cbegin(), Y->dims().cend());
  SetColumnBufferShape(
      shape.N,
      shape.C,
      shape.kernel_dim,
      output_image_dims,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetYTranposedBufferShape(Y_dims, &shape.Y_transposed_dims);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE(bias.ndim() == image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE(bias.dim32(i) == output_image_dims[i]);
    }
    CAFFE_ENFORCE(bias.dim32(image_ndim) == shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);

  return true;
}

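// NHWC forward entry point; only the 2-D, ungrouped case is supported for
// this storage order.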
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.input_image_dims = {X.dim32(1), X.dim32(2)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) == kernel_h());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 2) == kernel_w());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 3) == shape.C);
  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  shape.kernel_dim = kernel_h() * kernel_w() * shape.C;
  const std::vector<int> Y_dims(Y->dims().cbegin(), Y->dims().cend());
  SetColumnBufferShape(
      shape.N,
      shape.C,
      shape.kernel_dim,
      output_image_dims,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetYTranposedBufferShape(Y_dims, &shape.Y_transposed_dims);

  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE(bias.ndim() == image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE(bias.dim32(i) == output_image_dims[i]);
    }
    CAFFE_ENFORCE(bias.dim32(image_ndim) == shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();

  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);

  return true;
}

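// NCHW forward computation: im2col gathers the input patch of every output
// location into the column buffer; a transpose makes each output location a
// contiguous (kernel_dim x N) slab; one batched GEMM then applies a distinct
// filter per output location (and per group); the bias, if present, is
// broadcast-added; a final transpose restores the NCHW output layout.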
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* Y_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_dim * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data =
      Y_transposed_buffer->template mutable_data<T>();

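  // im2col, one image and one group at a time.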
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2col<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            shape.C / group_,
            shape.input_image_dims[0],
            shape.input_image_dims[1],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2colNd<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            input_dims_device_.template data<int>(),
            column_dims_device_.template data<int>() + 1,
            shape.C * shape.input_image_size,
            column_stride,
            kernel_device_.template data<int>(),
            stride_device_.template data<int>(),
            dilation_device_.template data<int>(),
            pads_device_.template data<int>(),
            kernel_.size(),
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
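  // Rearrange the column buffer from (N, kernel_dim, out_spatial...) to
  // (out_spatial..., kernel_dim, N) so the batched GEMM below can index it
  // by output location.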
  math::Transpose(
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
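  // One (M / group_) x N GEMM per output location and group; the filter is
  // consumed as output_image_size * group_ consecutive slabs of shape
  // (M / group_) x kernel_dim.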
  math::GemmBatched(
      CblasNoTrans,
      CblasNoTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.N,
      shape.kernel_dim,
      1.0f,
      filter_data,
      column_transposed_buffer->template data<T>(),
      0.0f,
      Y_transposed_buffer_data,
      &context_);
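  // Broadcast-add the per-location bias across the batch via a rank-1 GEMM
  // against the all-ones bias multiplier.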
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1,
        1.0,
        bias_data,
        bias_multiplier_.template data<T>(),
        1.0,
        Y_transposed_buffer_data,
        &context_);
  }
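  // Transpose (out_spatial..., M, N) back to the NCHW output layout
  // (N, M, out_spatial...).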
  math::Transpose(
      shape.Y_transposed_dims.size(),
      Y_transposed_dims_device_.template data<int>(),
      Y_dims_device_.template data<int>(),
      Y_transposed_axes_device_.template data<int>(),
      Y_transposed_buffer->size(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
}

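// NHWC forward computation: the same im2col / transpose / batched GEMM
// pipeline, without grouping. Here each per-location GEMM produces an
// N x M block from the column slab and the (transposed) filter slab.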
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* Y_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_dim * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data =
      Y_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2col<T, Context, StorageOrder::NHWC>(
        X_data + image_id * input_stride,
        shape.C,
        shape.input_image_dims[0],
        shape.input_image_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::GemmBatched(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size,
      shape.N,
      shape.M,
      shape.kernel_dim,
      1.0f,
      column_transposed_buffer->template data<T>(),
      filter_data,
      0.0f,
      Y_transposed_buffer_data,
      &context_);
  math::Transpose(
      shape.Y_transposed_dims.size(),
      Y_transposed_dims_device_.template data<int>(),
      Y_dims_device_.template data<int>(),
      Y_transposed_axes_device_.template data<int>(),
      Y_transposed_buffer->size(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
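  // In NHWC the bias can be broadcast-added directly onto the final output,
  // since (out_spatial..., M) is the contiguous per-image block of Y.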
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1,
        1.0f,
        bias_multiplier_.template data<T>(),
        bias_data,
        1.0f,
        Y_data,
        &context_);
  }
}

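// Computes the column buffer shapes for the forward pass and uploads the
// dims and axes to device tensors for math::Transpose. The inverse
// permutation is not needed in the forward pass, hence the nullptr.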
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::SetColumnBufferShape(
    const int N,
    const int C,
    const int kernel_dim,
    const std::vector<int>& output_image_dims,
    std::vector<int>* column_dims,
    std::vector<int>* column_transposed_dims) {
  std::vector<int> column_axes;
  SetColumnBufferShapeImpl(
      N,
      C,
      kernel_dim,
      order_,
      output_image_dims,
      column_dims,
      column_transposed_dims,
      &column_axes,
      nullptr);
  SetDeviceTensor(*column_dims, &column_dims_device_);
  SetDeviceTensor(*column_transposed_dims, &column_transposed_dims_device_);
  SetDeviceTensor(column_axes, &column_axes_device_);
}

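// Computes the shape of the transposed output buffer: (out_spatial..., M, N)
// for NCHW and (out_spatial..., N, M) for NHWC. The stored axes are the
// permutation that maps the transposed buffer back onto Y.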
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::SetYTranposedBufferShape(
    const std::vector<int>& Y_dims,
    std::vector<int>* Y_transposed_dims) {
  const int n_Y_dims = Y_dims.size();
  Y_transposed_dims->resize(n_Y_dims);
  std::vector<int> Y_transposed_axes(n_Y_dims);
  if (order_ == StorageOrder::NCHW) {
    for (int i = 0; i < n_Y_dims - 2; ++i) {
      Y_transposed_dims->at(i) = Y_dims[i + 2];
      Y_transposed_axes[i + 2] = i;
    }
    Y_transposed_dims->at(n_Y_dims - 2) = Y_dims[1];
    Y_transposed_dims->at(n_Y_dims - 1) = Y_dims[0];
    Y_transposed_axes[1] = n_Y_dims - 2;
    Y_transposed_axes[0] = n_Y_dims - 1;
  } else {
    for (int i = 0; i < n_Y_dims - 2; ++i) {
      Y_transposed_dims->at(i) = Y_dims[i + 1];
      Y_transposed_axes[i + 1] = i;
    }
    Y_transposed_dims->at(n_Y_dims - 2) = Y_dims[0];
    Y_transposed_dims->at(n_Y_dims - 1) = Y_dims[n_Y_dims - 1];
    Y_transposed_axes[0] = n_Y_dims - 2;
    Y_transposed_axes[n_Y_dims - 1] = n_Y_dims - 1;
  }
  SetDeviceTensor(Y_dims, &Y_dims_device_);
  SetDeviceTensor(*Y_transposed_dims, &Y_transposed_dims_device_);
  SetDeviceTensor(Y_transposed_axes, &Y_transposed_axes_device_);
}

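// NCHW gradient entry point: mirrors the forward setup, then dispatches to
// RunOnDeviceWithOrderNCHWImpl to compute dfilter and, when requested, dbias
// and dX.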
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  auto* dfilter = Output(FILTER_GRAD);
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());

  ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) * group_ == shape.C);
  CAFFE_ENFORCE(shape.M % group_ == 0);

  shape.input_image_dims = GetDims(X);
  shape.input_image_size = GetDimsSize(X);
  const std::vector<int> output_image_dims = GetDims(dY);
  shape.output_image_size = GetDimsSize(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }
  ConvPoolOpBase<Context>::ComputePads(shape.input_image_dims);

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  const std::vector<int> input_dims(X.dims().cbegin() + 1, X.dims().cend());
  SetDeviceTensor(input_dims, &input_dims_device_);
  shape.kernel_dim = shape.C / group_ * kernel_dims_size;

  const std::vector<int> dY_dims(dY.dims().cbegin(), dY.dims().cend());
  SetColumnBufferShape(
      shape.N,
      shape.C,
      shape.kernel_dim,
      output_image_dims,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetDYTranposedBufferShape(dY_dims, &shape.dY_transposed_dims);

  dfilter->ResizeLike(filter);
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
    dX->ResizeLike(X);
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
    std::vector<int> dbias_dims = output_image_dims;
    dbias_dims.push_back(shape.M);
    dbias->Resize(dbias_dims);
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);

  return true;
}

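// NHWC gradient entry point; as in the forward pass, only the 2-D, ungrouped
// case is supported.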
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  auto* dfilter = Output(FILTER_GRAD);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.ndim() - 2;
  CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
  ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.input_image_dims = {X.dim32(1), X.dim32(2)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(filter.dim32(image_ndim + 1) == kernel_h());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 2) == kernel_w());
  CAFFE_ENFORCE(filter.dim32(image_ndim + 3) == shape.C);
  ConvPoolOpBase<Context>::ComputePads(shape.input_image_dims);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(dY);
  const std::vector<int> output_image_dims = GetDims(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE(output_image_dims[i] == filter.dim32(i));
  }

  shape.kernel_dim = kernel_h() * kernel_w() * shape.C;
  const std::vector<int> dY_dims(dY.dims().cbegin(), dY.dims().cend());
  SetColumnBufferShape(
      shape.N,
      shape.C,
      shape.kernel_dim,
      output_image_dims,
      &shape.column_dims,
      &shape.column_transposed_dims);
  SetDYTranposedBufferShape(dY_dims, &shape.dY_transposed_dims);

  dfilter->ResizeLike(filter);
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
    dX->ResizeLike(X);
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
    std::vector<int> dbias_dims = output_image_dims;
    dbias_dims.push_back(shape.M);
    dbias->Resize(dbias_dims);
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);

  return true;
}

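// NCHW gradient computation: recompute the transposed column buffer from X,
// transpose dY so each output location owns an (M x N) slab, then take
// dfilter = dY_t * column_t^T per location (and group), dbias as the sum of
// dY_t over the batch, and dX as col2im of filter^T * dY_t.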
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* dY_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_dim * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.dY_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();

  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2col<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            shape.C / group_,
            shape.input_image_dims[0],
            shape.input_image_dims[1],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2colNd<T, Context, StorageOrder::NCHW>(
            X_data + group_id * input_stride,
            input_dims_device_.template data<int>(),
            column_dims_device_.template data<int>() + 1,
            shape.C * shape.input_image_size,
            column_stride,
            kernel_device_.template data<int>(),
            stride_device_.template data<int>(),
            dilation_device_.template data<int>(),
            pads_device_.template data<int>(),
            kernel_.size(),
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  math::Transpose(
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);

  math::Transpose(
      shape.dY_transposed_dims.size(),
      dY_dims_device_.template data<int>(),
      dY_transposed_dims_device_.template data<int>(),
      dY_axes_device_.template data<int>(),
      dY_transposed_buffer->size(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to the filter.
  math::GemmBatched(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.kernel_dim,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      column_transposed_buffer->template data<T>(),
      0.0f,
      dfilter_data,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to the bias.
    math::Gemv<T, Context>(
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1.0f,
        dY_transposed_buffer_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to X.
    math::GemmBatched(
        CblasTrans,
        CblasNoTrans,
        shape.output_image_size * group_,
        shape.kernel_dim,
        shape.N,
        shape.M / group_,
        1.0f,
        filter_data,
        dY_transposed_buffer_data,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        &context_);
    math::Transpose(
        shape.column_dims.size(),
        column_transposed_dims_device_.template data<int>(),
        column_dims_device_.template data<int>(),
        column_transposed_axes_device_.template data<int>(),
        column_transposed_buffer->size(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (kernel_.size() == 2) {
          math::Col2im<T, Context, StorageOrder::NCHW>(
              const_column_buffer_data + group_id * column_stride,
              shape.C / group_,
              shape.input_image_dims[0],
              shape.input_image_dims[1],
              kernel_h(),
              kernel_w(),
              dilation_h(),
              dilation_w(),
              pad_t(),
              pad_l(),
              pad_b(),
              pad_r(),
              stride_h(),
              stride_w(),
              dX_data + group_id * input_stride,
              &context_);
        } else {
          math::Col2imNd<T, Context, StorageOrder::NCHW>(
              const_column_buffer_data + group_id * column_stride,
              input_dims_device_.template data<int>(),
              column_dims_device_.template data<int>() + 1,
              shape.C * shape.input_image_size,
              column_stride,
              kernel_device_.template data<int>(),
              stride_device_.template data<int>(),
              dilation_device_.template data<int>(),
              pads_device_.template data<int>(),
              kernel_.size(),
              dX_data + group_id * input_stride,
              &context_);
        }
      }
      dX_data += input_stride * group_;
      const_column_buffer_data += column_stride * group_;
    }
  }
}

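// NHWC gradient computation: the 2-D, ungrouped analogue of the NCHW path.
// Note dbias is reduced from the untransposed dY with a transposed GEMV.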
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor<Context>* column_buffer,
    Tensor<Context>* column_transposed_buffer,
    Tensor<Context>* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_dim * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.dY_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2col<T, Context, StorageOrder::NHWC>(
        X_data + image_id * input_stride,
        shape.C,
        shape.input_image_dims[0],
        shape.input_image_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      column_dims_device_.template data<int>(),
      column_transposed_dims_device_.template data<int>(),
      column_axes_device_.template data<int>(),
      column_buffer->size(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);

  math::Transpose(
      shape.dY_transposed_dims.size(),
      dY_dims_device_.template data<int>(),
      dY_transposed_dims_device_.template data<int>(),
      dY_axes_device_.template data<int>(),
      dY_transposed_buffer->size(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to the filter.
  math::GemmBatched(
      CblasTrans,
      CblasNoTrans,
      shape.output_image_size,
      shape.M,
      shape.kernel_dim,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      column_transposed_buffer->template data<T>(),
      0.0f,
      dfilter_data,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to the bias.
    math::Gemv<T, Context>(
        CblasTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1.0f,
        dY_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to X.
    math::GemmBatched(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size,
        shape.N,
        shape.kernel_dim,
        shape.M,
        1.0f,
        dY_transposed_buffer_data,
        filter_data,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        &context_);
    math::Transpose(
        shape.column_dims.size(),
        column_transposed_dims_device_.template data<int>(),
        column_dims_device_.template data<int>(),
        column_transposed_axes_device_.template data<int>(),
        column_transposed_buffer->size(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      math::Col2im<T, Context, StorageOrder::NHWC>(
          const_column_buffer_data,
          shape.C,
          shape.input_image_dims[0],
          shape.input_image_dims[1],
          kernel_h(),
          kernel_w(),
          dilation_h(),
          dilation_w(),
          pad_t(),
          pad_l(),
          pad_b(),
          pad_r(),
          stride_h(),
          stride_w(),
          dX_data,
          &context_);
      dX_data += input_stride;
      const_column_buffer_data += column_stride;
    }
  }
}

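// Unlike the forward-pass version, the gradient also needs the inverse axis
// permutation, so the column buffer can be transposed back before col2im.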
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::SetColumnBufferShape(
    const int N,
    const int C,
    const int kernel_dim,
    const std::vector<int>& output_image_dims,
    std::vector<int>* column_dims,
    std::vector<int>* column_transposed_dims) {
  std::vector<int> column_axes;
  std::vector<int> column_transposed_axes;
  SetColumnBufferShapeImpl(
      N,
      C,
      kernel_dim,
      order_,
      output_image_dims,
      column_dims,
      column_transposed_dims,
      &column_axes,
      &column_transposed_axes);
  SetDeviceTensor(*column_dims, &column_dims_device_);
  SetDeviceTensor(*column_transposed_dims, &column_transposed_dims_device_);
  SetDeviceTensor(column_axes, &column_axes_device_);
  SetDeviceTensor(column_transposed_axes, &column_transposed_axes_device_);
}

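// Computes the transposed dY shape, analogous to SetYTranposedBufferShape
// above, but the stored axes describe the dY -> dY_t transpose rather than
// its inverse.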
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::SetDYTranposedBufferShape(
    const std::vector<int>& dY_dims,
    std::vector<int>* dY_transposed_dims) {
  const int n_dY_dims = dY_dims.size();
  dY_transposed_dims->resize(n_dY_dims);
  std::vector<int> dY_axes(n_dY_dims);
  if (order_ == StorageOrder::NCHW) {
    for (int i = 0; i < n_dY_dims - 2; ++i) {
      dY_transposed_dims->at(i) = dY_dims[i + 2];
      dY_axes[i] = i + 2;
    }
    dY_transposed_dims->at(n_dY_dims - 2) = dY_dims[1];
    dY_transposed_dims->at(n_dY_dims - 1) = dY_dims[0];
    dY_axes[n_dY_dims - 2] = 1;
    dY_axes[n_dY_dims - 1] = 0;
  } else {
    for (int i = 0; i < n_dY_dims - 2; ++i) {
      dY_transposed_dims->at(i) = dY_dims[i + 1];
      dY_axes[i] = i + 1;
    }
    dY_transposed_dims->at(n_dY_dims - 2) = dY_dims[0];
    dY_transposed_dims->at(n_dY_dims - 1) = dY_dims[n_dY_dims - 1];
    dY_axes[n_dY_dims - 2] = 0;
    dY_axes[n_dY_dims - 1] = n_dY_dims - 1;
  }
  SetDeviceTensor(dY_dims, &dY_dims_device_);
  SetDeviceTensor(*dY_transposed_dims, &dY_transposed_dims_device_);
  SetDeviceTensor(dY_axes, &dY_axes_device_);
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_