1 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_ 2 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_ 4 #include <c10/util/Optional.h> 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/conversions.h" 8 #include "caffe2/utils/math.h" 10 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN 19 class Engine = DefaultEngine,
20 bool TransposeWeight =
true>
23 USE_OPERATOR_CONTEXT_FUNCTIONS;
24 template <
class... Args>
27 axis_(this->
template GetSingleArgument<int32_t>(
"axis", 1)),
28 axis_w_(this->
template GetSingleArgument<int32_t>(
"axis_w", 1)),
30 this->
template GetSingleArgument<bool>(
"float16_compute",
false)) {}
31 ~FullyConnectedOp() {}
39 bool DoRunWithType() {
40 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN 41 std::chrono::time_point<std::chrono::system_clock> t_very_begin, t_begin,
45 t_begin = std::chrono::system_clock::now();
46 t_very_begin = t_begin;
50 const auto& X =
Input(0);
51 const auto& W =
Input(1);
52 const auto& b =
Input(2);
54 CAFFE_ENFORCE(b.dim() == 1, b.dim());
56 const auto canonical_axis = X.canonical_axis_index(axis_);
57 const auto M = X.size_to_dim(canonical_axis);
58 const auto K = X.size_from_dim(canonical_axis);
59 const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
60 const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
61 : W.size_from_dim(canonical_axis_w);
63 auto dimErrorString = [&]() {
65 "Dimension mismatch: ",
83 CAFFE_ENFORCE(
M == X.numel() / K, dimErrorString());
84 CAFFE_ENFORCE(K == W.numel() / N, dimErrorString());
85 CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
86 CAFFE_ENFORCE(N == b.numel(), dimErrorString());
88 Y_shape_cache_ = X.sizes().vec();
90 DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
91 Y_shape_cache_.resize(canonical_axis + 1);
92 Y_shape_cache_[canonical_axis] = N;
93 auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());
94 CAFFE_ENFORCE(
M * N == Y->numel(), dimErrorString());
98 Y->template mutable_data<T_Y>();
103 TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
104 if (fp16_type<MATH>()) {
105 math_type = TensorProto_DataType_FLOAT16;
108 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN 111 t_end = std::chrono::system_clock::now();
112 double dt = std::chrono::duration<double>(t_end - t_begin).count();
113 LOG(INFO) <<
"@PERF this=" <<
this <<
" before_gemm: " << dt * 1e3
115 t_begin = std::chrono::system_clock::now();
119 math::Gemm<T_X, Context, Engine>(
121 TransposeWeight ? CblasTrans : CblasNoTrans,
126 X.template data<T_X>(),
127 W.template data<T_W>(),
129 Y->template mutable_data<T_Y>(),
133 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN 136 t_end = std::chrono::system_clock::now();
137 double dt = std::chrono::duration<double>(t_end - t_begin).count();
138 LOG(INFO) <<
"@PERF this=" <<
this <<
" gemm: " << dt * 1e3 <<
" ms";
139 t_begin = std::chrono::system_clock::now();
143 if (!bias_multiplier_.has_value()) {
145 caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
146 math::Set<T_B, Context>(
148 convert::To<float, T_B>(1),
149 bias_multiplier_->template mutable_data<T_B>(),
151 }
else if (bias_multiplier_->numel() != M) {
152 bias_multiplier_->Resize(M);
153 math::Set<T_B, Context>(
155 convert::To<float, T_B>(1),
156 bias_multiplier_->template mutable_data<T_B>(),
160 math::Gemm<T_B, Context, Engine>(
167 bias_multiplier_->template data<T_B>(),
168 b.template data<T_B>(),
170 Y->template mutable_data<T_Y>(),
174 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN 177 t_end = std::chrono::system_clock::now();
178 double dt = std::chrono::duration<double>(t_end - t_begin).count();
179 LOG(INFO) <<
"@PERF this=" <<
this <<
" add_bias : " << dt * 1e3 <<
" ms";
180 t_begin = std::chrono::system_clock::now();
186 bool RunOnDevice()
override {
187 return DoRunWithType<
200 vector<int64_t> Y_shape_cache_;
203 bool float16_compute_;
209 bool TransposeWeight =
true>
212 USE_OPERATOR_CONTEXT_FUNCTIONS;
213 template <
class... Args>
216 axis_(this->
template GetSingleArgument<int32_t>(
"axis", 1)),
217 axis_w_(this->
template GetSingleArgument<int32_t>(
"axis_w", 1)),
219 this->
template GetSingleArgument<bool>(
"float16_compute",
false)) {}
220 ~FullyConnectedGradientOp() {}
231 bool DoRunWithType() {
232 const auto& X =
Input(0);
233 const auto& W =
Input(1);
234 const auto& dY =
Input(2);
236 const auto canonical_axis = X.canonical_axis_index(axis_);
237 const int M = X.size_to_dim(canonical_axis);
238 const int K = X.size_from_dim(canonical_axis);
239 const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
240 const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
241 : W.size_from_dim(canonical_axis_w);
243 auto dimErrorString = [&]() {
245 "Dimension mismatch: ",
262 CAFFE_ENFORCE(M * K == X.numel(), dimErrorString());
263 CAFFE_ENFORCE(K * N == W.numel(), dimErrorString());
265 auto* dW = Output(0, W.sizes(), at::dtype<T_DW>());
266 auto* db = Output(1, {N}, at::dtype<T_DB>());
268 if (X.numel() == 0) {
270 math::Set<T_DB, Context>(
272 convert::To<float, T_DB>(0),
273 db->template mutable_data<T_DB>(),
275 math::Set<T_DW, Context>(
277 convert::To<float, T_DW>(0),
278 dW->template mutable_data<T_DW>(),
281 if (OutputSize() == 3) {
282 Output(2, X.sizes(), at::dtype<T_DX>());
289 TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
290 if (fp16_type<MATH>()) {
291 math_type = TensorProto_DataType_FLOAT16;
295 math::Gemm<T_DY, Context, Engine>(
298 TransposeWeight ? N : K,
299 TransposeWeight ? K : N,
302 TransposeWeight ? dY.template data<T_DY>() : X.template data<T_X>(),
303 TransposeWeight ? X.template data<T_X>() : dY.template data<T_DY>(),
305 dW->template mutable_data<T_DW>(),
308 if (!bias_multiplier_.has_value()) {
309 bias_multiplier_ = caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
310 math::Set<T_B, Context>(
312 convert::To<float, T_B>(1),
313 bias_multiplier_->template mutable_data<T_B>(),
315 }
else if (bias_multiplier_->numel() != M) {
316 bias_multiplier_->Resize(M);
317 math::Set<T_B, Context>(
319 convert::To<float, T_B>(1),
320 bias_multiplier_->template mutable_data<T_B>(),
324 math::Gemv<T_DY, Context>(
329 dY.template data<T_DY>(),
330 bias_multiplier_->template data<T_B>(),
332 db->template mutable_data<T_DB>(),
336 if (OutputSize() == 3) {
337 auto* dX = Output(2, X.sizes(), at::dtype<T_DX>());
338 math::Gemm<T_DX, Context, Engine>(
340 TransposeWeight ? CblasNoTrans : CblasTrans,
345 dY.template data<T_DY>(),
346 W.template data<T_W>(),
348 dX->template mutable_data<T_DX>(),
355 bool RunOnDevice()
override {
356 return DoRunWithType<
371 bool float16_compute_;
376 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
// NOTE(review): the text below is documentation-browser tooltip residue that
// was captured along with the source; it is not part of this header. Kept as
// a comment so the file remains valid C++:
//   const Tensor& Input(int idx, DeviceType type = Context::GetDeviceType())
//     Retrieve a non-owning reference to the input at position 'idx' for
//     this operator.
//   (Followed by a description of the global dictionary that holds
//   information about what Caffe2 modules have been loaded.)