1 #ifndef CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_ 2 #define CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/operators/conv_op_shared.h" 9 #include "caffe2/operators/conv_pool_op_base.h" 10 #include "caffe2/operators/locally_connected_op_util.h" 14 template <
typename T,
class Context>
17 USE_CONV_POOL_BASE_FUNCTIONS(Context);
19 template <
class... Args>
25 group_ == 1 || order_ == StorageOrder::NCHW,
26 "Group locally connected only supports NCHW order right now.");
29 ~LocallyConnectedOp() =
default;
31 bool RunOnDeviceWithOrderNCHW()
override;
32 bool RunOnDeviceWithOrderNHWC()
override;
35 void RunOnDeviceWithOrderNCHWImpl(
42 Tensor* column_transposed_buffer,
45 void RunOnDeviceWithOrderNHWCImpl(
52 Tensor* column_transposed_buffer,
53 Tensor* Y_transposed_buffer);
55 Tensor bias_multiplier_{Context::GetDeviceType()};
58 Tensor column_buffer_{Context::GetDeviceType()};
59 Tensor column_transposed_buffer_{Context::GetDeviceType()};
60 Tensor Y_transposed_buffer_{Context::GetDeviceType()};
64 INPUT_TAGS(INPUT, FILTER, BIAS);
67 template <
typename T,
class Context>
70 USE_CONV_POOL_BASE_FUNCTIONS(Context);
72 template <
class... Args>
75 OP_SINGLE_ARG(
bool,
"no_bias", no_bias_,
false) {
77 !(no_bias_ && OutputSize() == 3),
78 "If bias is not present, you should not have 3 grad output.");
80 group_ == 1 || order_ == StorageOrder::NCHW,
81 "Group locally connected only supports NCHW order right now.");
84 ~LocallyConnectedGradientOp() =
default;
86 bool RunOnDeviceWithOrderNCHW()
override;
87 bool RunOnDeviceWithOrderNHWC()
override;
90 void RunOnDeviceWithOrderNCHWImpl(
99 Tensor* column_transposed_buffer,
100 Tensor* dY_transposed_buffer);
102 void RunOnDeviceWithOrderNHWCImpl(
105 const T* filter_data,
111 Tensor* column_transposed_buffer,
112 Tensor* dY_transposed_buffer);
116 Tensor bias_multiplier_{Context::GetDeviceType()};
119 Tensor column_buffer_{Context::GetDeviceType()};
120 Tensor column_transposed_buffer_{Context::GetDeviceType()};
121 Tensor dY_transposed_buffer_{Context::GetDeviceType()};
125 INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
126 OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
131 #endif // CAFFE2_OPERATORS_LOCALLY_CONNECTED_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...