#ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_
#define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_

#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/flags.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/operators/locally_connected_op.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

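// Implementation of the LocallyConnected (LC) operator. An LC layer is a
// convolution with untied weights: every output location has its own filter.
// The forward pass unfolds the input with Im2Col, transposes the column
// buffer so that each output location owns a contiguous slice, and then runs
// one small GEMM per output location via math::GemmStridedBatched.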
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE(
      shape.C == filter.dim32(image_ndim + 1) * group_,
      "Locally Connected op: input channels does not match: "
      "# of input channels ",
      shape.C,
      " is not equal to kernel channels * group:",
      filter.dim32(image_ndim + 1),
      "*",
      group_);
  CAFFE_ENFORCE_EQ(
      shape.M % group_,
      0,
      "The number of output channels is not divisible by group.");
  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);
  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.X_dims.assign(X.sizes().cbegin() + 1, X.sizes().cend());
  shape.kernel_size = shape.C / group_ * kernel_dims_size;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.dim(), image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE_EQ(bias.dim32(i), output_image_dims[i]);
    }
    CAFFE_ENFORCE_EQ(bias.dim32(image_ndim), shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();
  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);
  return true;
}
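
// NHWC forward entry point. Only 2-D kernels are supported in this layout
// (enforced below); shape checking mirrors the NCHW path but reads the
// channel dimension from the last axis.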
template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  auto* Y = Output(0);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.X_dims = {X.dim32(1), X.dim32(2), X.dim32(3)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1), kernel_h());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 2), kernel_w());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 3), shape.C);
  ConvPoolOpBase<Context>::SetOutputSize(X, Y, shape.M);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(*Y);
  const std::vector<int> output_image_dims = GetDims(*Y);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }
  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* bias_data = nullptr;
  if (InputSize() == 3) {
    const auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.dim(), image_ndim + 1);
    for (int i = 0; i < image_ndim; ++i) {
      CAFFE_ENFORCE_EQ(bias.dim32(i), output_image_dims[i]);
    }
    CAFFE_ENFORCE_EQ(bias.dim32(image_ndim), shape.M);
    bias_data = bias.template data<T>();
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
  }
  T* Y_data = Y->template mutable_data<T>();
  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      bias_data,
      Y_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &Y_transposed_buffer_);
  return true;
}
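
// NCHW forward implementation: (1) Im2Col every image and group into the
// column buffer, (2) transpose it so the per-location slices are contiguous,
// (3) run a batched GEMM (one GEMM per output location and group) against the
// untied filters, (4) optionally add the per-location bias, and (5) transpose
// the result back into NCHW order.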
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* Y_transposed_buffer) {
  const int input_stride = shape.C / group_ * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();
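  // Unfold every image (and group) into columns of shape
  // kernel_size x output_image_size; 2-D kernels use the specialized Im2Col,
  // higher-dimensional kernels fall back to Im2ColNd.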
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2Col<T, Context, StorageOrder::NCHW>(
            shape.C / group_,
            shape.X_dims[1],
            shape.X_dims[2],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2ColNd<T, Context, StorageOrder::NCHW>(
            kernel_.size(),
            shape.C * shape.input_image_size,
            column_stride,
            shape.X_dims.data(),
            shape.column_slice_dims.data(),
            kernel_.data(),
            stride_.data(),
            dilation_.data(),
            pads_.data(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
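  // Rearrange the column buffer so that, for every (output location, group)
  // pair, the corresponding kernel_size x N slice is contiguous for the
  // batched GEMM below.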
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
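  // One GEMM per (output location, group): the [M / group, kernel_size]
  // filter slice times the [kernel_size, N] column slice produces an
  // [M / group, N] block of the transposed output buffer.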
  math::GemmStridedBatched(
      CblasNoTrans,
      CblasNoTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.N,
      shape.kernel_size,
      1.0f,
      filter_data,
      shape.M / group_ * shape.kernel_size,
      column_transposed_buffer->template data<T>(),
      shape.kernel_size * shape.N,
      0.0f,
      Y_transposed_buffer_data,
      shape.M / group_ * shape.N,
      &context_);
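  // The bias is also untied (one value per output location and channel), so it
  // is broadcast across the batch dimension with a rank-1 GEMM against the
  // all-ones bias multiplier.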
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1,
        1.0f,
        bias_data,
        bias_multiplier_.template data<T>(),
        1.0f,
        Y_transposed_buffer_data,
        &context_);
  }
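  // Bring the [output_image_size, M, N]-ordered result back into NCHW order.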
  math::Transpose(
      shape.Y_transposed_dims.size(),
      shape.Y_transposed_dims.data(),
      shape.Y_axes.data(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
}
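
// NHWC forward implementation: the same Im2Col + transpose + batched GEMM
// pipeline as the NCHW path, but with ungrouped per-image columns, the filter
// used as the (transposed) right-hand GEMM operand, and the bias added after
// the final transpose directly in NHWC order.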
template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* bias_data,
    T* Y_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* Y_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  Y_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* Y_transposed_buffer_data = Y_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2Col<T, Context, StorageOrder::NHWC>(
        shape.C,
        shape.X_dims[0],
        shape.X_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        X_data + image_id * input_stride,
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::GemmStridedBatched(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size,
      shape.N,
      shape.M,
      shape.kernel_size,
      1.0f,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      filter_data,
      shape.kernel_size * shape.M,
      0.0f,
      Y_transposed_buffer_data,
      shape.N * shape.M,
      &context_);
  math::Transpose(
      shape.Y_transposed_dims.size(),
      shape.Y_transposed_dims.data(),
      shape.Y_axes.data(),
      Y_transposed_buffer_data,
      Y_data,
      &context_);
  if (bias_data != nullptr) {
    math::Gemm<T, Context>(
        CblasNoTrans,
        CblasNoTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1,
        1.0f,
        bias_multiplier_.template data<T>(),
        bias_data,
        1.0f,
        Y_data,
        &context_);
  }
}
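
// Gradient op, NCHW entry point: validates shapes the same way as the forward
// op, allocates FILTER_GRAD (plus the optional input and bias gradients), and
// dispatches to RunOnDeviceWithOrderNCHWImpl below.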
template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(1);
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1) * group_, shape.C);
  CAFFE_ENFORCE_EQ(shape.M % group_, 0);

  const std::vector<int> input_image_dims = GetDims(X);
  shape.input_image_size = GetDimsSize(X);
  const std::vector<int> output_image_dims = GetDims(dY);
  shape.output_image_size = GetDimsSize(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }
  ConvPoolOpBase<Context>::ComputePads(input_image_dims);

  int kernel_dims_size = 1;
  for (std::size_t i = 0; i < kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + image_ndim + 2), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }

  shape.X_dims.assign(X.sizes().cbegin() + 1, X.sizes().cend());
  shape.kernel_size = shape.C / group_ * kernel_dims_size;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(
        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    std::vector<int64_t> dbias_dims;
    std::copy(
        output_image_dims.cbegin(),
        output_image_dims.cend(),
        std::back_inserter(dbias_dims));
    dbias_dims.push_back(shape.M);
    auto* dbias = Output(BIAS_OR_INPUT_GRAD, dbias_dims, at::dtype<T>());
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNCHWImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);
  return true;
}

template <typename T, class Context>
bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
  const auto& X = Input(INPUT);
  const auto& filter = Input(FILTER);
  const auto& dY = Input(OUTPUT_GRAD);
  CAFFE_ENFORCE_EQ(
      kernel_.size(),
      2,
      "Only 2d locally connected op is supported for NHWC storage type.");
  const int image_ndim = X.dim() - 2;
  CAFFE_ENFORCE_EQ(X.dim() + image_ndim, filter.dim());
  lc_op_util::ShapeParams shape;
  shape.N = X.dim32(0);
  shape.C = X.dim32(3);
  shape.X_dims = {X.dim32(1), X.dim32(2), X.dim32(3)};
  shape.M = filter.dim32(image_ndim);
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 1), kernel_h());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 2), kernel_w());
  CAFFE_ENFORCE_EQ(filter.dim32(image_ndim + 3), shape.C);
  const std::vector<int> input_image_dims = {X.dim32(1), X.dim32(2)};
  ConvPoolOpBase<Context>::ComputePads(input_image_dims);

  shape.input_image_size = GetDimsSize(X);
  shape.output_image_size = GetDimsSize(dY);
  const std::vector<int> output_image_dims = GetDims(dY);
  for (int i = 0; i < image_ndim; ++i) {
    CAFFE_ENFORCE_EQ(output_image_dims[i], filter.dim32(i));
  }
  shape.kernel_size = kernel_h() * kernel_w() * shape.C;
  lc_op_util::SetColumnBufferShape(
      shape.N,
      shape.kernel_size,
      shape.output_image_size,
      output_image_dims,
      order_,
      &shape.column_slice_dims,
      &shape.column_dims,
      &shape.column_transposed_dims,
      &shape.column_axes);
  lc_op_util::SetYBufferShape(
      shape.N,
      shape.M,
      shape.output_image_size,
      order_,
      &shape.Y_dims,
      &shape.Y_transposed_dims,
      &shape.Y_axes);

  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());
  const T* X_data = X.template data<T>();
  const T* filter_data = filter.template data<T>();
  const T* dY_data = dY.template data<T>();
  T* dfilter_data = dfilter->template mutable_data<T>();
  T* dX_data = nullptr;
  T* dbias_data = nullptr;
  if (OutputSize() == 3 || (no_bias_ && OutputSize() == 2)) {
    auto* dX = Output(
        no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD, X.sizes(), at::dtype<T>());
    dX_data = dX->template mutable_data<T>();
  }
  if (!no_bias_) {
    std::vector<int64_t> dbias_dims;
    std::copy(
        output_image_dims.cbegin(),
        output_image_dims.cend(),
        std::back_inserter(dbias_dims));
    dbias_dims.push_back(shape.M);
    auto* dbias = Output(BIAS_OR_INPUT_GRAD, dbias_dims, at::dtype<T>());
    ConvPoolOpBase<Context>::template SetBiasMultiplier<T>(
        shape.N, &bias_multiplier_);
    dbias_data = dbias->template mutable_data<T>();
  }
  RunOnDeviceWithOrderNHWCImpl(
      shape,
      X_data,
      filter_data,
      dY_data,
      dfilter_data,
      dX_data,
      dbias_data,
      &column_buffer_,
      &column_transposed_buffer_,
      &dY_transposed_buffer_);
  return true;
}
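
// NCHW gradient implementation: re-run Im2Col on the input, transpose both the
// column buffer and dY into per-output-location layouts, then compute dFilter
// with a batched GEMM against the columns, dBias with a Gemv against the bias
// multiplier, and dX with a batched GEMM against the filter followed by the
// inverse transpose and Col2Im.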
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    for (int group_id = 0; group_id < group_; ++group_id) {
      if (kernel_.size() == 2) {
        math::Im2Col<T, Context, StorageOrder::NCHW>(
            shape.C / group_,
            shape.X_dims[1],
            shape.X_dims[2],
            kernel_h(),
            kernel_w(),
            dilation_h(),
            dilation_w(),
            pad_t(),
            pad_l(),
            pad_b(),
            pad_r(),
            stride_h(),
            stride_w(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      } else {
        math::Im2ColNd<T, Context, StorageOrder::NCHW>(
            kernel_.size(),
            shape.C * shape.input_image_size,
            column_stride,
            shape.X_dims.data(),
            shape.column_slice_dims.data(),
            kernel_.data(),
            stride_.data(),
            dilation_.data(),
            pads_.data(),
            X_data + group_id * input_stride,
            column_buffer_data + group_id * column_stride,
            &context_);
      }
    }
    X_data += input_stride * group_;
    column_buffer_data += column_stride * group_;
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::Transpose(
      shape.Y_dims.size(),
      shape.Y_dims.data(),
      shape.Y_axes.data(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to the filter.
  math::GemmStridedBatched(
      CblasNoTrans,
      CblasTrans,
      shape.output_image_size * group_,
      shape.M / group_,
      shape.kernel_size,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      shape.M / group_ * shape.N,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      0.0f,
      dfilter_data,
      shape.M / group_ * shape.kernel_size,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to the bias.
    math::Gemv<T, Context>(
        CblasNoTrans,
        shape.output_image_size * shape.M,
        shape.N,
        1.0f,
        dY_transposed_buffer_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to the input.
    math::GemmStridedBatched(
        CblasTrans,
        CblasNoTrans,
        shape.output_image_size * group_,
        shape.kernel_size,
        shape.N,
        shape.M / group_,
        1.0f,
        filter_data,
        shape.kernel_size * shape.M / group_,
        dY_transposed_buffer_data,
        shape.M / group_ * shape.N,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        shape.kernel_size * shape.N,
        &context_);
    math::Transpose(
        shape.column_transposed_dims.size(),
        shape.column_transposed_dims.data(),
        shape.column_axes.data(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (kernel_.size() == 2) {
          math::Col2Im<T, Context, StorageOrder::NCHW>(
              shape.C / group_,
              shape.X_dims[1],
              shape.X_dims[2],
              kernel_h(),
              kernel_w(),
              dilation_h(),
              dilation_w(),
              pad_t(),
              pad_l(),
              pad_b(),
              pad_r(),
              stride_h(),
              stride_w(),
              const_column_buffer_data + group_id * column_stride,
              dX_data + group_id * input_stride,
              &context_);
        } else {
          math::Col2ImNd<T, Context, StorageOrder::NCHW>(
              kernel_.size(),
              shape.C * shape.input_image_size,
              column_stride,
              shape.X_dims.data(),
              shape.column_slice_dims.data(),
              kernel_.data(),
              stride_.data(),
              dilation_.data(),
              pads_.data(),
              const_column_buffer_data + group_id * column_stride,
              dX_data + group_id * input_stride,
              &context_);
        }
      }
      dX_data += input_stride * group_;
      const_column_buffer_data += column_stride * group_;
    }
  }
}
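
// NHWC gradient implementation: mirrors the NCHW version with ungrouped
// per-image columns; dBias is reduced directly from dY in NHWC order.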
template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
    const lc_op_util::ShapeParams& shape,
    const T* X_data,
    const T* filter_data,
    const T* dY_data,
    T* dfilter_data,
    T* dX_data,
    T* dbias_data,
    Tensor* column_buffer,
    Tensor* column_transposed_buffer,
    Tensor* dY_transposed_buffer) {
  const int input_stride = shape.C * shape.input_image_size;
  const int column_stride = shape.kernel_size * shape.output_image_size;
  column_buffer->Resize(shape.column_dims);
  column_transposed_buffer->Resize(shape.column_transposed_dims);
  dY_transposed_buffer->Resize(shape.Y_transposed_dims);
  T* column_buffer_data = column_buffer->template mutable_data<T>();
  T* dY_transposed_buffer_data =
      dY_transposed_buffer->template mutable_data<T>();
  for (int image_id = 0; image_id < shape.N; ++image_id) {
    math::Im2Col<T, Context, StorageOrder::NHWC>(
        shape.C,
        shape.X_dims[0],
        shape.X_dims[1],
        kernel_h(),
        kernel_w(),
        dilation_h(),
        dilation_w(),
        pad_t(),
        pad_l(),
        pad_b(),
        pad_r(),
        stride_h(),
        stride_w(),
        X_data + image_id * input_stride,
        column_buffer_data + image_id * column_stride,
        &context_);
  }
  math::Transpose(
      shape.column_dims.size(),
      shape.column_dims.data(),
      shape.column_axes.data(),
      column_buffer->template data<T>(),
      column_transposed_buffer->template mutable_data<T>(),
      &context_);
  math::Transpose(
      shape.Y_dims.size(),
      shape.Y_dims.data(),
      shape.Y_axes.data(),
      dY_data,
      dY_transposed_buffer_data,
      &context_);

  // Gradient with respect to the filter.
  math::GemmStridedBatched(
      CblasTrans,
      CblasNoTrans,
      shape.output_image_size,
      shape.M,
      shape.kernel_size,
      shape.N,
      1.0f,
      dY_transposed_buffer_data,
      shape.N * shape.M,
      column_transposed_buffer->template data<T>(),
      shape.N * shape.kernel_size,
      0.0f,
      dfilter_data,
      shape.M * shape.kernel_size,
      &context_);

  if (dbias_data != nullptr) {
    // Gradient with respect to the bias.
    math::Gemv<T, Context>(
        CblasTrans,
        shape.N,
        shape.output_image_size * shape.M,
        1.0f,
        dY_data,
        bias_multiplier_.template data<T>(),
        0.0f,
        dbias_data,
        &context_);
  }

  if (dX_data != nullptr) {
    // Gradient with respect to the input.
    math::GemmStridedBatched(
        CblasNoTrans,
        CblasNoTrans,
        shape.output_image_size,
        shape.N,
        shape.kernel_size,
        shape.M,
        1.0f,
        dY_transposed_buffer_data,
        shape.N * shape.M,
        filter_data,
        shape.M * shape.kernel_size,
        0.0f,
        column_transposed_buffer->template mutable_data<T>(),
        shape.N * shape.kernel_size,
        &context_);
    math::Transpose(
        shape.column_transposed_dims.size(),
        shape.column_transposed_dims.data(),
        shape.column_axes.data(),
        column_transposed_buffer->template data<T>(),
        column_buffer->template mutable_data<T>(),
        &context_);
    const T* const_column_buffer_data = column_buffer->template data<T>();
    for (int image_id = 0; image_id < shape.N; ++image_id) {
      math::Col2Im<T, Context, StorageOrder::NHWC>(
          shape.C,
          shape.X_dims[0],
          shape.X_dims[1],
          kernel_h(),
          kernel_w(),
          dilation_h(),
          dilation_w(),
          pad_t(),
          pad_l(),
          pad_b(),
          pad_r(),
          stride_h(),
          stride_w(),
          const_column_buffer_data,
          dX_data,
          &context_);
      dX_data += input_stride;
      const_column_buffer_data += column_stride;
    }
  }
}

} // namespace caffe2

#endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_IMPL_H_