1 #include "conv_dnnlowp_acc16_op.h" 5 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN 12 #include "dnnlowp_op.h" 13 #include "dnnlowp_partition.h" 14 #include "fbgemm_pack_op.h" 15 #include "im2col_dnnlowp.h" 17 C10_DECLARE_int32(caffe2_dnnlowp_nbits_in_non_outlier);
18 C10_DECLARE_int32(caffe2_dnnlowp_copy_to_32bit_frequency);
19 C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer);
24 caffe2_dnnlowp_acc16_density_threshold,
26 "If density of outlier is higher than this, fallback to 32-bit accumulation");
28 caffe2_dnnlowp_acc16_m_threshold,
30 "If m is smaller than this, fallback to 32-bit accumulation");
32 caffe2_dnnlowp_acc16_n_threshold,
34 "If n is smaller than this, fallback to 32-bit accumulation");
36 caffe2_dnnlowp_acc16_k_threshold,
38 "If k is smaller than this, fallback to 32-bit accumulation");
template <bool ReluFused>
ConvDNNLowPAcc16Op<ReluFused>::ConvDNNLowPAcc16Op(
    const OperatorDef& operator_def,
    Workspace* ws)
    : ConvDNNLowPOp<uint8_t, ReluFused>(operator_def, ws),
      nbits_in_non_outlier_(this->template GetSingleArgument<int>(
          "nbits_in_non_outlier",
          FLAGS_caffe2_dnnlowp_nbits_in_non_outlier)),
      copy_to_32bit_frequency_(this->template GetSingleArgument<int>(
          "copy_to_32bit_frequency",
          FLAGS_caffe2_dnnlowp_copy_to_32bit_frequency)) {
  if (nbits_in_non_outlier_ == 0) {
    LOG(INFO) << "nbits_in_non_outlier == 0 means everything is outlier so we "
                 "fall back to acc32";
    fallback_to_32_bit_accumulation_ = true;
  }
}
template <bool ReluFused>
bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
  if (fallback_to_32_bit_accumulation_) {
    return true;
  }

  if (!BaseType::GetQuantizationParameters_()) {
    return false;
  }

  if (!Wq_acc16_packed_ &&
      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
    CAFFE_ENFORCE_EQ(
        this->order_,
        StorageOrder::NHWC,
        "Pre-packed weight only works with NHWC layout");

    const auto& packed_filter =
        this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
    Wq_outlier_ = packed_filter.W_outlier;
    Wq_acc16_packed_ = packed_filter.W_acc16;

    if (nbits_in_non_outlier_ != packed_filter.nbits_in_non_outlier) {
      LOG(WARNING)
          << "nbits_in_non_outlier in packed weight "
          << packed_filter.nbits_in_non_outlier
          << " doesn't match with nbits_in_non_outlier specified in operator "
          << nbits_in_non_outlier_;
    }

    first_invocation_ = false;
    return true;
  }
  int kernel_dim = this->KernelDim_();
  const auto& filter = InputTensorCPU_(FILTER);
  int num_out_channels = filter.dim32(0);

  // Check whether we should fall back to 32-bit accumulation.
  if (this->order_ == StorageOrder::NHWC) {
    const Tensor& X = InputTensorCPU_(INPUT);
    int N = X.dim32(0);

    auto sizes = this->GetOutputSize(X, filter.dim32(0));
    Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());
    const int output_image_size = this->GetDimsSize(*Y);

    if (N * output_image_size < FLAGS_caffe2_dnnlowp_acc16_m_threshold) {
      LOG(INFO) << "M " << N * output_image_size
                << " of Conv layer with weight blob "
                << this->debug_def().input(1) << " is smaller than threshold "
                << FLAGS_caffe2_dnnlowp_acc16_m_threshold
                << " . Falling back to acc32";
      fallback_to_32_bit_accumulation_ = true;
      return true;
    }
    if (num_out_channels / group_ < FLAGS_caffe2_dnnlowp_acc16_n_threshold) {
      LOG(INFO) << "N " << num_out_channels / group_
                << " of Conv layer with weight blob "
                << this->debug_def().input(1) << " is smaller than threshold "
                << FLAGS_caffe2_dnnlowp_acc16_n_threshold
                << " . Falling back to acc32";
      fallback_to_32_bit_accumulation_ = true;
      return true;
    }
    if (kernel_dim < FLAGS_caffe2_dnnlowp_acc16_k_threshold) {
      LOG(INFO) << "K " << kernel_dim << " of Conv layer with weight blob "
                << this->debug_def().input(1) << " is smaller than threshold "
                << FLAGS_caffe2_dnnlowp_acc16_k_threshold
                << " . Falling back to acc32";
      fallback_to_32_bit_accumulation_ = true;
      return true;
    }
  }
  // Separate out the outlier weights into a sparse matrix.
  if (!Wq_outlier_ && this->order_ == StorageOrder::NHWC &&
      nbits_in_non_outlier_ < 8) {
    CAFFE_ENFORCE(!W_quantized_.empty());

    Wq_outlier_.reset(ExtractOutlierMatrix(
        group_,
        kernel_dim,
        num_out_channels,
        nbits_in_non_outlier_,
        W_quantized_));
    int outlier_cnt = Wq_outlier_->ColPtr()[num_out_channels];

    LOG(INFO) << "Proportion of outlier for Conv layer with weight blob "
              << this->debug_def().input(1) << " is "
              << static_cast<float>(outlier_cnt) / W_quantized_.size();
    LOG(INFO) << "nbits_in_non_outlier " << nbits_in_non_outlier_
              << " copy_to_32bit_frequency " << copy_to_32bit_frequency_;

    if (static_cast<float>(outlier_cnt) / W_quantized_.size() >
        FLAGS_caffe2_dnnlowp_acc16_density_threshold) {
      LOG(INFO) << "Density of outliers is higher than threshold "
                << FLAGS_caffe2_dnnlowp_acc16_density_threshold
                << " . Falling back to acc32";
      fallback_to_32_bit_accumulation_ = true;
      return true;
    }
  }
  bool packW = this->order_ == StorageOrder::NHWC && GetCpuId().avx2();

  if (first_invocation_) {
    if (!packW) {
      string reason;
      if (this->order_ != StorageOrder::NHWC) {
        reason = "fbgemm only supports NHWC layout";
      } else if (!GetCpuId().avx2()) {
        reason = "fbgemm only supports AVX2+";
      }

      if (!reason.empty()) {
        static int log_occurences = 0;
        if (log_occurences < 32) {
          ++log_occurences;
          LOG(WARNING) << "Conv with weight " << this->debug_def().input(FILTER)
                       << " falls back to slow path because " << reason;
        }
      }
    }
    if (nbits_in_non_outlier_ < 8 && this->order_ != StorageOrder::NHWC) {
      static int log_occurences = 0;
      if (log_occurences < 32) {
        ++log_occurences;
        LOG(WARNING) << "Outlier-aware quantization only supports "
                        "NHWC layout";
      }
    }
    first_invocation_ = false;
  }

  if (packW && !Wq_acc16_packed_) {
    Wq_acc16_packed_.reset(new fbgemm::PackBMatrix<int8_t, int16_t>(
        fbgemm::matrix_op_t::Transpose,
        group_ * kernel_dim,
        num_out_channels / group_,
        W_quantized_.data(),
        kernel_dim, // ld
        nullptr, // pmat
        group_));

    vector<int8_t>().swap(W_quantized_);
  }

  return true;
}
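// NCHW is the slow reference path: im2col per image and group followed by a
// naive GEMM. When DNNLOWP_ACC16_IN_SLOW_PATH is defined, the inner loop below
// emulates 16-bit accumulation (saturating partial sums plus the periodic copy
// into a 32-bit accumulator); otherwise it simply accumulates in 32 bits.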
template <bool ReluFused>
bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHW() {
  VLOG(2) << "Running DNNLOWP_ACC16 Conv";

  using namespace dnnlowp;

  if (!GetQuantizationParameters_()) {
    return false;
  }
  if (fallback_to_32_bit_accumulation_) {
    return BaseType::RunOnDeviceWithOrderNCHW();
  }

  const Tensor& X = InputTensorCPU_(INPUT);
  auto& filter = InputTensorCPU_(FILTER);
  const int N = X.dim32(0), C = X.dim32(1);
  CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
  const int M = filter.dim32(0);
  CAFFE_ENFORCE_EQ(
      C,
      filter.dim32(1) * group_,
      "Convolution op: input channels does not match: # of input channels ",
      C,
      " is not equal to kernel channels * group:",
      filter.dim32(1),
      "*",
      group_);
  CAFFE_ENFORCE_EQ(
      M % group_,
      0,
      "The number of output channels is not divisible by group.");

  auto sizes = this->GetOutputSize(X, filter.dim32(0));
  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());

  const vector<int> input_dims = GetDims(X);
  const vector<int> output_dims = GetDims(*Y);
  const int input_image_size = this->GetDimsSize(X);
  const int output_image_size = this->GetDimsSize(*Y);

  // The dimension of each kernel.
  const int kernel_dim = this->KernelDim_();

  vector<int> img_shape;
  img_shape.assign(X.sizes().begin() + 1, X.sizes().end());

  vector<int> buffer_shape;
  buffer_shape.push_back(kernel_dim);
  buffer_shape.insert(
      buffer_shape.end(), output_dims.begin(), output_dims.end());
  buffer_shape.insert(buffer_shape.begin(), dnnlowp_get_max_threads());
  if (this->kernel_.size() != 2) {
    SetDeviceTensor(img_shape, &(this->img_shape_device_));
    SetDeviceTensor(buffer_shape, &(this->col_buffer_shape_device_));
  }

  const int col_buffer_size = kernel_dim * output_image_size;

  // Offset of one group within a single input image.
  const int input_offset = C / group_ * input_image_size;

  // The col buffer is stored in CHW order: kernel_dim, then height and width.
  const uint8_t* Xdata = X.template data<uint8_t>();

  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
    col_buffer->Resize(buffer_shape);
    uint8_t* col_buffer_data = col_buffer->template mutable_data<uint8_t>();

    Y_int32->resize(M * output_image_size * dnnlowp_get_max_threads());
    vector<int> buffer_shape_per_thread(
        buffer_shape.begin() + 1, buffer_shape.end());

    // Im2Col, followed by GEMM.
    uint8_t* Y_data = Y->template mutable_data<uint8_t>();
    this->column_offsets_->resize(
        output_image_size * dnnlowp_get_max_threads());

#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int image_id = 0; image_id < N; ++image_id) {
      int tid = dnnlowp_get_thread_num();
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (this->kernel_.size() == 2) {
          math::Im2ColNCHW<uint8_t>(
              C / group_,
              input_dims[1],
              input_dims[2],
              kernel_h(),
              kernel_w(),
              dilation_h(),
              dilation_w(),
              pad_t(),
              pad_l(),
              pad_b(),
              pad_r(),
              stride_h(),
              stride_w(),
              Xdata + (group_ * image_id + group_id) * input_offset,
              col_buffer_data + tid * col_buffer_size,
              &context_,
              in_qparams_[INPUT].zero_point);
        } else {
          math::Im2ColNdNCHW<uint8_t>(
              this->kernel_.size(),
              C * input_image_size,
              col_buffer_size,
              img_shape.data(),
              buffer_shape_per_thread.data(),
              this->kernel_.data(),
              this->stride_.data(),
              this->dilation_.data(),
              this->pads_.data(),
              Xdata + (group_ * image_id + group_id) * input_offset,
              col_buffer_data + tid * col_buffer_size,
              &context_,
              in_qparams_[INPUT].zero_point);
        }
        // Main GEMM over the col buffer of this thread and group.
        uint8_t* col_buffer_private = col_buffer_data + tid * col_buffer_size;

        int32_t* Y_int32_temp = Y_int32->data() +
            ((M / group_) * group_id + M * tid) * output_image_size;
        int8_t* W_quantized_group =
            W_quantized_.data() + (M / group_) * group_id * kernel_dim;

        static int log_occurences = 0;
        if (log_occurences < 32) {
          ++log_occurences;
          LOG(WARNING)
              << "Consider using DNNLOWP instead of DNNLOWP_ACC16 engine since "
                 "we're falling back to a slow path because of NCHW layout";
        }

        for (int i = 0; i < M / group_; ++i) {
          for (int j = 0; j < output_image_size; ++j) {
            int32_t int32_sum = 0;
            int16_t int16_sum = 0;
            for (int k = 0; k < kernel_dim; ++k) {
              int32_t w = W_quantized_group[i * kernel_dim + k];
              int32_t x = col_buffer_private[k * output_image_size + j];
#ifdef DNNLOWP_ACC16_IN_SLOW_PATH
              int16_sum = std::max<int32_t>(
                  numeric_limits<int16_t>::min(),
                  std::min<int32_t>(
                      numeric_limits<int16_t>::max(), int16_sum + x * w));
              if (k % copy_to_32bit_frequency_ ==
                  copy_to_32bit_frequency_ - 1) {
                int32_sum += int16_sum;
                int16_sum = 0;
              }
#else
              int32_sum += x * w;
#endif
            }
            Y_int32_temp[i * output_image_size + j] = int32_sum + int16_sum;
          }
        }

        this->RunOnDeviceEpilogueNCHW_(
            col_buffer_private,
            Y_int32_temp,
            Y_data + (M * image_id + M / group_ * group_id) * output_image_size,
            M / group_ * group_id,
            group_id);
      } // for each group
    } // for each image_id
  }; // f

  this->RunWithSharedBuffer_(&col_buffer_, &(this->Y_int32_), f);

  PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);

  this->MeasureQuantizationError_();

  return true;
}
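// Reference NHWC implementation used when the weights could not be packed for
// fbgemm (e.g. non-AVX2 CPUs). It mirrors the acc16 behavior: saturating
// 16-bit partial sums flushed into a 32-bit accumulator at a fixed frequency,
// optionally counting how often the 16-bit sum over- or underflows.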
static void conv_nhwc_acc16_ref_(
    int num_groups,
    int N,
    int output_image_size,
    int M,
    int kernel_dim,
    const uint8_t* col_buffer,
    const int8_t* W,
    int32_t* Y
#ifdef DNNLOWP_ACC16_IN_SLOW_PATH
    ,
    OperatorBase* op
#endif
) {
#ifdef DNNLOWP_ACC16_IN_SLOW_PATH
  uint64_t underflow_cnt = 0, overflow_cnt = 0;
#endif
  for (int group_id = 0; group_id < num_groups; ++group_id) {
    for (int i = 0; i < N * output_image_size; ++i) {
      for (int j = 0; j < M / num_groups; ++j) {
        int32_t int32_sum = 0;
        int16_t int16_sum = 0;
#ifdef DNNLOWP_ACC16_IN_SLOW_PATH
        bool overflowed = false, underflowed = false;
#endif
        for (int k = 0; k < kernel_dim; ++k) {
          int32_t x = col_buffer[(i * num_groups + group_id) * kernel_dim + k];
          int32_t w = W[(group_id * (M / num_groups) + j) * kernel_dim + k];
#ifdef DNNLOWP_ACC16_IN_SLOW_PATH
          if (!overflowed && !underflowed) {
            if (int16_sum + x * w > numeric_limits<int16_t>::max()) {
              overflowed = true;
            } else if (int16_sum + x * w < numeric_limits<int16_t>::min()) {
              underflowed = true;
            }
          }
#endif

          int16_sum = std::max<int32_t>(
              numeric_limits<int16_t>::min(),
              std::min<int32_t>(
                  numeric_limits<int16_t>::max(), int16_sum + x * w));
          if (k % FLAGS_caffe2_dnnlowp_copy_to_32bit_frequency ==
              FLAGS_caffe2_dnnlowp_copy_to_32bit_frequency - 1) {
            int32_sum += int16_sum;
            int16_sum = 0;
          }
        }

        Y[i * M + group_id * (M / num_groups) + j] = int32_sum + int16_sum;
#ifdef DNNLOWP_ACC16_IN_SLOW_PATH
        if (overflowed) {
          ++overflow_cnt;
        } else if (underflowed) {
          ++underflow_cnt;
        }
#ifdef DNNLOWP_DETAILED_LOG_IN_ACC16_SLOW_PATH
        if (overflowed || underflowed) {
          int32_t sum = 0;
          for (int k = 0; k < kernel_dim; ++k) {
            int32_t x =
                col_buffer[(i * num_groups + group_id) * kernel_dim + k];
            int32_t w = W[k * M + group_id * (M / num_groups) + j];
            LOG(INFO) << k << ": " << sum << " + " << x << " * " << w << " = "
                      << sum + x * w;
            sum += x * w;
          }
        }
#endif
#endif
      }
    }
  }

#ifdef DNNLOWP_ACC16_IN_SLOW_PATH
  LOG(INFO) << op->debug_def().input(1) << " underflow_cnt " << underflow_cnt
            << " (" << (float)underflow_cnt / (N * output_image_size * M) * 100
            << ") overflow_cnt " << overflow_cnt << " ("
            << (float)overflow_cnt / (N * output_image_size * M) * 100 << ")";
#endif
}
template <bool ReluFused>
template <fbgemm::QuantizationGranularity Q_GRAN>
void ConvDNNLowPAcc16Op<ReluFused>::DispatchFBGEMM_(
    fbgemm::PackAWithRowOffset<uint8_t, int16_t>& packA,
    const uint8_t* col_buffer_data,
    vector<int32_t>* Y_int32,
    uint8_t* Y_uint8_data) {
  // This function is called inside an OpenMP parallel region.
  auto& filter = InputTensorCPU_(FILTER);
  const int M = filter.dim32(0);

  assert(Wq_acc16_packed_.get());
  int kernel_dim = this->KernelDim_();

  int nthreads = dnnlowp_get_num_threads();
  int tid = dnnlowp_get_thread_num();

  using namespace fbgemm;
  DoNothing<> doNothingObj{};
  ReQuantizeOutput<ReluFused, Q_GRAN> reqObj(
      doNothingObj,
      this->requantization_multipliers_.data(),
      out_qparams_.zero_point,
      in_qparams_[INPUT].zero_point,
      this->filter_zero_points_.data(),
      packA.getRowOffsetBuffer(),
      this->column_offsets_->data(),
      InputSize() == 3 ? this->b_quantized_data_ : nullptr,
      M,
      group_);

  if (nbits_in_non_outlier_ < 8) {
    DoSpmdmOnInpBuffer<
        typename ReQuantizeOutput<ReluFused>::outType,
        int32_t,
        ReQuantizeOutput<ReluFused, Q_GRAN>>
        spmdmObj(
            reqObj, col_buffer_data, group_ * kernel_dim, *Wq_outlier_, group_);

    fbgemmPacked(
        packA, *Wq_acc16_packed_, Y_uint8_data, Y_int32->data(), M, spmdmObj,
        tid, nthreads);
  } else {
    fbgemmPacked(
        packA, *Wq_acc16_packed_, Y_uint8_data, Y_int32->data(), M, reqObj,
        tid, nthreads);
  }
}
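// Outlier pass for the case where the dense acc16 GEMM ran without the fused
// spMDM (i.e. when no fbgemm-packed weight is available): the sparse outlier
// weight matrix is multiplied with the same col buffer and accumulated into
// the 32-bit output before requantization.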
template <bool ReluFused>
void ConvDNNLowPAcc16Op<ReluFused>::ConvOutlier_(
    const uint8_t* col_buffer,
    vector<int32_t>* Y_int32) {
  if (nbits_in_non_outlier_ < 8) {
    const Tensor& X = InputTensorCPU_(INPUT);
    auto& filter = InputTensorCPU_(FILTER);
    Tensor* Y = OutputTensorCPU_(0);
    const int N = X.dim32(0);
    const int M = filter.dim32(0);

    const int kernel_dim = this->KernelDim_();
    const int output_image_size = this->GetDimsSize(*Y);

#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      int group_begin, group_end, i_begin, i_end;
      this->PartitionGroupedNHWCConv_(
          &group_begin,
          &group_end,
          &i_begin,
          &i_end,
          group_,
          N * output_image_size,
          dnnlowp_get_num_threads(),
          dnnlowp_get_thread_num());

      for (int group_id = group_begin; group_id < group_end; ++group_id) {
        CAFFE_ENFORCE_EQ(Wq_outlier_->NumOfRows(), kernel_dim);
        // Sparse outlier GEMM, accumulated into the 32-bit output.
        fbgemm::block_type_t block = {
            0, i_end - i_begin, group_id * (M / group_), M / group_};
        Wq_outlier_->SpMDM(
            block,
            col_buffer + (i_begin * group_ + group_id) * kernel_dim,
            group_ * kernel_dim,
            true, // accumulate
            Y_int32->data() + i_begin * M + group_id * (M / group_),
            M);
      } // for each group
    } // parallel region
  }
}
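// NHWC entry point. Flow: optional im2col -> dense acc16 GEMM (fbgemm fast
// path, or conv_nhwc_acc16_ref_ slow path) -> sparse outlier pass and
// requantization epilogue for the slow path. Per-stage timings are logged when
// DNNLOWP_MEASURE_TIME_BREAKDOWN is defined.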
template <bool ReluFused>
bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWC() {
  CAFFE_ENFORCE_LE(
      this->kernel_.size(),
      3,
      "Only 1-3d convolution is supported for NHWC storage type");

  using namespace dnnlowp;

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
  chrono::time_point<chrono::system_clock> t_very_begin, t_begin, t_end;

  t_begin = chrono::system_clock::now();
  t_very_begin = t_begin;
#endif

  if (!GetQuantizationParameters_()) {
    return false;
  }
  if (fallback_to_32_bit_accumulation_) {
    return BaseType::RunOnDeviceWithOrderNHWC();
  }

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
  t_end = chrono::system_clock::now();
  double dt = chrono::duration<double>(t_end - t_begin).count();
  LOG(INFO) << "this=" << this << " get_quant_params: " << dt * 1e3 << " ms";
#endif

  const Tensor& X = InputTensorCPU_(INPUT);
  auto& filter = InputTensorCPU_(FILTER);
  const int N = X.dim32(0), C = X.dim32(X.ndim() - 1);

  CAFFE_ENFORCE_EQ(X.ndim(), filter.ndim());
  const int M = filter.dim32(0);
  CAFFE_ENFORCE_EQ(filter.dim32(filter.ndim() - 1), C / group_);

  auto sizes = this->GetOutputSize(X, filter.dim32(0));
  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<uint8_t>());

  const int kernel_dim = this->KernelDim_();

  const int output_image_size = this->GetDimsSize(*Y);
  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
    Y_int32->resize(Y->numel());

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    t_begin = chrono::system_clock::now();
#endif

    bool no_im2col = this->NoIm2ColNHWC_();

    // Im2col, followed by GEMM.
    const uint8_t* Xdata = X.template data<uint8_t>();
    const uint8_t* col_buffer_data =
        no_im2col ? Xdata : this->Im2ColNHWC_(col_buffer);

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    t_end = chrono::system_clock::now();
    dt = chrono::duration<double>(t_end - t_begin).count();
    LOG(INFO) << "this=" << this << " im2col: " << dt * 1e3 << " ms";
    t_begin = chrono::system_clock::now();
#endif

    using namespace fbgemm;
    int row_offset_size_per_thread = -1;
    int x_pack_buf_size_per_thread = -1;
    if (Wq_acc16_packed_) {
      row_offset_size_per_thread =
          PackAWithRowOffset<uint8_t, int16_t>::rowOffsetBufferSize();
      x_pack_buf_size_per_thread =
          PackAWithRowOffset<uint8_t, int16_t>::packedBufferSize();
      row_offsets_.resize(
          dnnlowp_get_max_threads() * row_offset_size_per_thread);
      X_pack_buf_.resize(
          dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
    }

    uint8_t* Y_uint8_data = Y->template mutable_data<uint8_t>();

    // Main GEMM for the non-outlier weights.
    if (Wq_acc16_packed_)
#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      // Fast path using fbgemm.
      int tid = dnnlowp_get_thread_num();

      PackAWithRowOffset<uint8_t, int16_t> packA(
          matrix_op_t::NoTranspose,
          N * output_image_size,
          group_ * kernel_dim,
          col_buffer_data,
          group_ * kernel_dim,
          X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
          1, // group
          row_offsets_.data() + tid * row_offset_size_per_thread);

      if (this->quantize_groupwise_) {
        DispatchFBGEMM_<QuantizationGranularity::GROUP>(
            packA, col_buffer_data, Y_int32, Y_uint8_data);
      } else {
        DispatchFBGEMM_<QuantizationGranularity::TENSOR>(
            packA, col_buffer_data, Y_int32, Y_uint8_data);
      }
    } else {
      // Slow path.
      conv_nhwc_acc16_ref_(
          group_,
          N,
          output_image_size,
          M,
          kernel_dim,
          col_buffer_data,
          W_quantized_.data(),
          Y_int32->data()
#ifdef DNNLOWP_ACC16_IN_SLOW_PATH
          ,
          this
#endif
      );
    }

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    t_end = chrono::system_clock::now();
    dt = chrono::duration<double>(t_end - t_begin).count();
    double ops = 2. * N * output_image_size * M * kernel_dim;
    double gops = ops / dt / 1e9;
    LOG(INFO) << "this=" << this << " GEMM: " << dt * 1e3 << " ms " << gops
              << " gops";
    t_begin = chrono::system_clock::now();
#endif

    if (!Wq_acc16_packed_) {
      ConvOutlier_(col_buffer_data, Y_int32);
    }

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    t_end = chrono::system_clock::now();
    dt = chrono::duration<double>(t_end - t_begin).count();
    LOG(INFO) << "this=" << this << " out-lier: " << dt * 1e3 << " ms";
    t_begin = chrono::system_clock::now();
#endif

    if (!Wq_acc16_packed_) {
      this->RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
    } else {
      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
    }
  }; // f

  this->RunWithSharedBuffer_(&col_buffer_, &(this->Y_int32_), f);
#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
  t_end = chrono::system_clock::now();
  dt = chrono::duration<double>(t_end - t_begin).count();
  LOG(INFO) << "this=" << this << " prologue: " << dt * 1e3 << " ms";
  t_begin = chrono::system_clock::now();

  t_end = chrono::system_clock::now();
  dt = chrono::duration<double>(t_end - t_very_begin).count();
  double ops = 2. * N * output_image_size * M * kernel_dim;
  double gops = ops / dt / 1e9;
  LOG(INFO) << "this=" << this << " " << this->debug_def().type()
            << " output=" << this->debug_def().output(0) << " "
            << N * output_image_size << "x" << M << "x" << kernel_dim
            << " G=" << group_ << " C/G=" << C / group_ << " K/G=" << M / group_
            << " R=" << kernel_h() << " S=" << kernel_w() << " : " << dt * 1e3
            << " ms " << gops << " gops";
#endif

  this->MeasureQuantizationError_();

  return true;
}

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Conv,
    DNNLOWP_ACC16,
    ConvDNNLowPAcc16Op<false>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    ConvRelu,
    DNNLOWP_ACC16,
    ConvDNNLowPAcc16Op<true>);

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Int8Conv,
    DNNLOWP_ACC16,
    ConvDNNLowPAcc16Op<false>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Int8ConvRelu,
    DNNLOWP_ACC16,
    ConvDNNLowPAcc16Op<true>);

} // namespace caffe2
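// The DNNLOWP_ACC16 engine registered above is selected per operator (Conv,
// ConvRelu, Int8Conv, Int8ConvRelu) through the engine field of the
// OperatorDef. nbits_in_non_outlier and copy_to_32bit_frequency can be
// overridden per operator via arguments of the same name; otherwise the
// corresponding caffe2_dnnlowp_* command-line flags supply the defaults (see
// the constructor at the top of this file).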