1 #include "caffe2/quantization/server/batch_permutation_dnnlowp_op.h" 6 bool BatchPermutationDNNLowPOp<T>::RunOnDevice() {
9 this->ParseDNNLowPOperatorArguments_();
13 GetInputTensorQuantizationParamsOf(
this, INPUT, qfactory_.get());
15 const auto& X = InputTensorCPU_(INPUT);
16 const auto& indices = Input(INDICES);
17 auto* Y = OutputTensorCPU_(OUTPUT);
19 CAFFE_ENFORCE(indices.ndim() == 1,
"indices must be 1-d");
21 X.dim32(0) == indices.dim32(0),
22 "X.dim32(0) must be equal to indices.dim32(0)",
28 CAFFE_ENFORCE_GT(X.dim32(0), 0);
31 const T* X_data = X.template data<T>();
32 const int* indices_data = indices.template data<int>();
33 T* Y_data = Y->template mutable_data<T>();
36 int K = X.numel() / N;
39 #pragma omp parallel for 41 for (
int i = 0; i < N; ++i) {
43 int permuteIdx = indices_data[i] * K;
44 std::memcpy(Y_data + origIdx, X_data + permuteIdx, K *
sizeof(
T));
50 PropagateOutputTensorQuantizationParams(
this, 0, in_qparams_[INPUT]);
55 REGISTER_CPU_OPERATOR_WITH_ENGINE(
58 BatchPermutationDNNLowPOp<uint8_t>);
59 REGISTER_CPU_OPERATOR_WITH_ENGINE(
62 BatchPermutationDNNLowPOp<uint8_t>);
64 OPERATOR_SCHEMA(Int8BatchPermutation).NumInputs(2).NumOutputs(1);
A global dictionary that holds information about which Caffe2 modules have been loaded in the current runtime.