Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_permutation_dnnlowp_op.cc
#include "caffe2/quantization/server/batch_permutation_dnnlowp_op.h"

#include <cstdint>
#include <cstring>
2 
3 namespace caffe2 {
4 
5 template <typename T>
6 bool BatchPermutationDNNLowPOp<T>::RunOnDevice() {
7  using namespace dnnlowp;
8 
9  this->ParseDNNLowPOperatorArguments_();
10 
11  // Choose quantization params
12  in_qparams_[INPUT] =
13  GetInputTensorQuantizationParamsOf(this, INPUT, qfactory_.get());
14 
15  const auto& X = InputTensorCPU_(INPUT);
16  const auto& indices = Input(INDICES);
17  auto* Y = OutputTensorCPU_(OUTPUT);
18 
19  CAFFE_ENFORCE(indices.ndim() == 1, "indices must be 1-d");
20  CAFFE_ENFORCE(
21  X.dim32(0) == indices.dim32(0),
22  "X.dim32(0) must be equal to indices.dim32(0)",
23  "(",
24  X.dim32(0),
25  " vs. ",
26  indices.dim32(0),
27  ")");
28  CAFFE_ENFORCE_GT(X.dim32(0), 0);
29 
30  Y->ResizeLike(X);
31  const T* X_data = X.template data<T>();
32  const int* indices_data = indices.template data<int>();
33  T* Y_data = Y->template mutable_data<T>();
34 
35  int N = X.dim32(0);
36  int K = X.numel() / N;
37 
38 #ifdef _OPENMP
39 #pragma omp parallel for
40 #endif
41  for (int i = 0; i < N; ++i) {
42  int origIdx = i * K;
43  int permuteIdx = indices_data[i] * K;
44  std::memcpy(Y_data + origIdx, X_data + permuteIdx, K * sizeof(T));
45  }
46 
47  // Even if there is a pre-chosen quantization parameters for the output,
48  // it is ignored because batch permutation output quantization should be same
49  // as the input.
50  PropagateOutputTensorQuantizationParams(this, 0, in_qparams_[INPUT]);
51 
52  return true;
53 }
54 
55 REGISTER_CPU_OPERATOR_WITH_ENGINE(
56  BatchPermutation,
57  DNNLOWP,
58  BatchPermutationDNNLowPOp<uint8_t>);
59 REGISTER_CPU_OPERATOR_WITH_ENGINE(
60  Int8BatchPermutation,
61  DNNLOWP,
62  BatchPermutationDNNLowPOp<uint8_t>);
63 
64 OPERATOR_SCHEMA(Int8BatchPermutation).NumInputs(2).NumOutputs(1);
65 
66 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13