#include "fully_connected_fake_lowp_op.h"

namespace caffe2 {

constexpr int nlines_log = 10000;

template <
    void (*Q)(const float*, size_t, float*),
    class Context,
    class Engine,
    bool TransposeWeight>
template <typename T_X, typename T_W, typename T_B, typename T_Y, typename MATH>
bool FullyConnectedFakeLowpFPOp<Q, Context, Engine, TransposeWeight>::
    DoRunWithType() {
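  // Forward pass of a fully connected layer, with every operand first passed
  // through the quantization function Q to emulate a low-precision format
  // while all storage and BLAS calls stay in fp32. Computes Y = X * W^T + b
  // (W^T when TransposeWeight is true, W otherwise).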
  const auto& X = Input(0);
  const auto& W = Input(1);
  const auto& b = Input(2);

  // b must be a vector
  CAFFE_ENFORCE(b.dim() == 1, b.dim());
  // batch size
  const auto canonical_axis = X.canonical_axis_index(axis_);
  const auto M = X.size_to_dim(canonical_axis);
  const auto K = X.size_from_dim(canonical_axis);
  const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                : W.size_from_dim(canonical_axis_w);

  auto dimErrorString = [&]() {
    return c10::str(
        "Dimension mismatch: ",
        "X: ", X.sizes(), ", W: ", W.sizes(), ", b: ", b.sizes(),
        ", axis: ", axis_, ", M: ", M, ", N: ", N, ", K: ", K);
  };

  // Error checking
  CAFFE_ENFORCE(M == X.size() / K, dimErrorString());
  CAFFE_ENFORCE(K == W.size() / N, dimErrorString());
  CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
  CAFFE_ENFORCE(N == b.size(), dimErrorString());
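  // Shape summary (illustrative sizes, not required by the op):
  //   X: M x K, e.g. a batch of 64 rows of 1024 features
  //   W: N x K when TransposeWeight, else K x N, e.g. 512 x 1024
  //   b: N,     e.g. 512
  //   Y: M x N, e.g. 64 x 512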
  // throttle logging to one line per nlines_log invocations
  static int log_occurences = 0;
  if (log_occurences % nlines_log == 0) {
    LOG(INFO) << "FAKE_FP16 fc running";
  }
  ++log_occurences;

  Y_shape_cache_ = X.sizes().vec();
  // This is an invariant of canonical_axis, so we can DCHECK.
  DCHECK_LE(canonical_axis + 1, Y_shape_cache_.size());
  Y_shape_cache_.resize(canonical_axis + 1);
  Y_shape_cache_[canonical_axis] = N;
  auto* Y = Output(0, Y_shape_cache_, at::dtype<T_Y>());
  CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());

  if (X.size() == 0) {
    // skip the rest of the computation if X is empty
    Y->template mutable_data<T_Y>();
    return true;
  }
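  // "Fake" low precision: Q reduces the precision of each fp32 value to that
  // of the target format (fp16, bfloat16, etc., truncated or rounded
  // depending on the variant) and writes the result back as fp32, so the
  // GEMMs below still run through the fp32 BLAS path while reproducing the
  // numeric error of the narrower type.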
  // default to FLOAT as math type
  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
  if (fp16_type<MATH>()) {
    math_type = TensorProto_DataType_FLOAT16;
  }

  // Y = X * W^T + b: quantize X, W, and b before the multiply
  auto type = Context::GetDeviceType();
  Tensor Xh(type);
  Xh.ResizeLike(X);
  Q(X.template data<T_X>(), Xh.size(), Xh.template mutable_data<T_X>());

  Tensor Wh(type);
  Wh.ResizeLike(W);
  Q(W.template data<T_W>(), Wh.size(), Wh.template mutable_data<T_W>());

  Tensor bh(type);
  bh.ResizeLike(b);
  Q(b.template data<T_B>(), bh.size(), bh.template mutable_data<T_B>());

  // W * x
  math::Gemm<T_X, Context, Engine>(
      CblasNoTrans,
      TransposeWeight ? CblasTrans : CblasNoTrans,
      M, N, K,
      1,
      Xh.template data<T_X>(),
      Wh.template data<T_W>(),
      0,
      Y->template mutable_data<T_Y>(),
      &context_,
      math_type);
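  // The bias is added as a rank-1 GEMM update rather than a per-row loop:
  // with a helper column vector of ones, Y (M x N) += ones (M x 1) * bh
  // (1 x N) broadcasts the quantized bias across all M rows in one BLAS call.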
  // Add bias term
  if (bias_multiplier_.size() != M) {
    // If the helper bias multiplier is not of length M, reshape and fill
    // it with ones.
    ReinitializeTensor(
        &bias_multiplier_,
        {M},
        at::dtype<T_B>().device(Context::GetDeviceType()));
    math::Set<T_B, Context>(
        M,
        convert::To<float, T_B>(1),
        bias_multiplier_.template mutable_data<T_B>(),
        &context_);
  }
  math::Gemm<T_B, Context, Engine>(
      CblasNoTrans, CblasNoTrans,
      M, N, 1,
      1,
      bias_multiplier_.template data<T_B>(),
      bh.template data<T_B>(),
      1,
      Y->template mutable_data<T_Y>(),
      &context_,
      math_type);

  return true;
}
template <
    void (*Q)(const float*, size_t, float*),
    class Context,
    class Engine,
    bool TransposeWeight>
template <
    typename T_X, typename T_W, typename T_DY, typename T_B,
    typename T_DX, typename T_DW, typename T_DB, typename MATH>
bool FullyConnectedGradientFakeLowpFPOp<Q, Context, Engine, TransposeWeight>::
    DoRunWithType() {
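  // Backward pass. With quantized copies Xh, Wh, dYh of the fp32 inputs,
  // the GEMM/GEMV calls below compute
  //   dW = dYh^T * Xh   (TransposeWeight) or Xh^T * dYh
  //   db = dYh^T * ones(M)
  //   dX = dYh * Wh     (TransposeWeight) or dYh * Wh^T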
  const auto& X = Input(0);
  const auto& W = Input(1);
  const auto& dY = Input(2);
  // batch size
  const auto canonical_axis = X.canonical_axis_index(axis_);
  const int M = X.size_to_dim(canonical_axis);
  const int K = X.size_from_dim(canonical_axis);
  const auto canonical_axis_w = W.canonical_axis_index(axis_w_);
  const int N = TransposeWeight ? W.size_to_dim(canonical_axis_w)
                                : W.size_from_dim(canonical_axis_w);
  CAFFE_ENFORCE(M * K == X.size());
  CAFFE_ENFORCE(K * N == W.size());

  auto* dW = Output(0, W.sizes(), at::dtype<T_DW>());
  auto* db = Output(1, {N}, at::dtype<T_DB>());
  if (X.size() == 0) {
    // generate zero gradients for db and dW when X is empty
    math::Set<T_DB, Context>(
        db->size(), convert::To<float, T_DB>(0),
        db->template mutable_data<T_DB>(), &context_);
    math::Set<T_DW, Context>(
        dW->size(), convert::To<float, T_DW>(0),
        dW->template mutable_data<T_DW>(), &context_);
    if (OutputSize() == 3) {
      Output(2, X.sizes(), at::dtype<T_DX>());
    }
    return true;
  }
  // default to FLOAT as math type
  TensorProto::DataType math_type = TensorProto_DataType_FLOAT;
  if (fp16_type<MATH>()) {
    math_type = TensorProto_DataType_FLOAT16;
  }

  // Quantize X, W, and dY
  auto type = Context::GetDeviceType();
  Tensor Xh(type);
  Xh.ResizeLike(X);
  Q(X.template data<T_X>(), Xh.size(), Xh.template mutable_data<T_X>());

  Tensor Wh(type);
  Wh.ResizeLike(W);
  Q(W.template data<T_W>(), Wh.size(), Wh.template mutable_data<T_W>());

  Tensor dYh(type);
  dYh.ResizeLike(dY);
  Q(dY.template data<T_DY>(), dYh.size(), dYh.template mutable_data<T_DY>());

  static int log_occurences = 0;
  if (log_occurences % nlines_log == 0) {
    LOG(INFO) << "FAKE_FP16 fc grad running";
  }
  ++log_occurences;
  // Compute dW
  math::Gemm<T_DY, Context, Engine>(
      CblasTrans, CblasNoTrans,
      TransposeWeight ? N : K,
      TransposeWeight ? K : N,
      M,
      1,
      TransposeWeight ? dYh.template data<T_DY>() : Xh.template data<T_X>(),
      TransposeWeight ? Xh.template data<T_X>() : dYh.template data<T_DY>(),
      0,
      dW->template mutable_data<T_DW>(),
      &context_,
      math_type);
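  // db is the column sum of dYh over the batch: a transposed GEMV against
  // the ones vector computes db[j] = sum_i dYh[i][j] in one BLAS call.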
  if (bias_multiplier_.size() != M) {
    // If the helper bias multiplier is not of length M, reshape and fill
    // it with ones.
    ReinitializeTensor(
        &bias_multiplier_,
        {M},
        at::dtype<T_B>().device(Context::GetDeviceType()));
    math::Set<T_B, Context>(
        M, convert::To<float, T_B>(1),
        bias_multiplier_.template mutable_data<T_B>(), &context_);
  }
  // Compute dB
  math::Gemv<T_DY, Context>(
      CblasTrans,
      M, N,
      1,
      dYh.template data<T_DY>(),
      bias_multiplier_.template data<T_B>(),
      0,
      db->template mutable_data<T_DB>(),
      &context_);
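  // dX propagates the gradient back through the (quantized) weights:
  // dX (M x K) = dYh (M x N) * Wh, with Wh transposed only when it is
  // stored K x N (i.e. when TransposeWeight is false).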
  // Compute dX
  if (OutputSize() == 3) {
    auto* dX = Output(2, X.sizes(), at::dtype<T_DX>());
    math::Gemm<T_DX, Context, Engine>(
        CblasNoTrans,
        TransposeWeight ? CblasNoTrans : CblasTrans,
        M, K, N,
        1,
        dYh.template data<T_DY>(),
        Wh.template data<T_W>(),
        0,
        dX->template mutable_data<T_DX>(),
        &context_,
        math_type);
  }

  return true;
}
// fp32 -> fp16 -> fp32
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC, FAKE_FP16, FullyConnectedFakeLowpFPOp<fp32_to_fp16, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_FP16,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_fp16, CPUContext>);

// fp32 -> bfp16 -> fp32
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC, FAKE_BFP_16, FullyConnectedFakeLowpFPOp<fp32_to_bfp16, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_BFP_16,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_bfp16, CPUContext>);

// fp32 -> bfp24 -> fp32
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC, FAKE_BFP_24, FullyConnectedFakeLowpFPOp<fp32_to_bfp24, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_BFP_24,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_bfp24, CPUContext>);

// fp32 -> bfp14 -> fp32
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC, FAKE_BFP_14, FullyConnectedFakeLowpFPOp<fp32_to_bfp14, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_BFP_14,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_bfp14, CPUContext>);

// fp32 -> bfp16 -> fp32 with rounding
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FC,
    FAKE_BFP_16_ROUND,
    FullyConnectedFakeLowpFPOp<fp32_to_bfp16_round, CPUContext>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    FCGradient,
    FAKE_BFP_16_ROUND,
    FullyConnectedGradientFakeLowpFPOp<fp32_to_bfp16_round, CPUContext>);

} // namespace caffe2
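// Usage sketch: the low-precision emulation is selected through the engine
// field of an ordinary FC operator. The blob names and shapes below are
// hypothetical; any fp32 FC inputs work the same way.
//
//   OperatorDef def;
//   def.set_type("FC");
//   def.set_engine("FAKE_FP16");  // or FAKE_BFP_16, FAKE_BFP_24, ...
//   def.add_input("X");           // e.g. 64 x 1024 fp32
//   def.add_input("W");           // e.g. 512 x 1024 fp32
//   def.add_input("b");           // e.g. 512 fp32
//   def.add_output("Y");          // 64 x 512 fp32, with fp16 rounding error
//   Workspace ws;
//   // ... feed X, W, b into ws ...
//   auto op = CreateOperator(def, &ws);
//   op->Run();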