1 #ifndef CAFFE2_OPERATORS_INT8_CONCAT_OP_H_ 2 #define CAFFE2_OPERATORS_INT8_CONCAT_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/core/tensor_int8.h" 7 #include "caffe2/operators/quantized/int8_utils.h" 15 template <
class... Args>
19 if (this->
template GetSingleArgument<string>(
"order",
"") ==
"NHWC") {
21 axis_ = this->
template GetSingleArgument<int>(
"axis", 3);
25 this->
template GetSingleArgument<string>(
"order",
"") ==
"NCHW") {
26 axis_ = this->
template GetSingleArgument<int>(
"axis", 1);
30 axis_ = this->
template GetSingleArgument<int>(
"axis", 0);
34 bool RunOnDevice()
override {
35 const auto& X0 = Inputs()[0]->template Get<Int8TensorCPU>();
36 auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
38 Y->zero_point = X0.zero_point;
39 int32_t Y_offset = this->
template GetSingleArgument<int>(
"Y_zero_point", 0);
40 auto Y_scale = this->
template GetSingleArgument<float>(
"Y_scale", 1);
41 CHECK_EQ(Y_offset, X0.zero_point);
42 CHECK_EQ(Y_scale, X0.scale);
43 CHECK_GE(X0.zero_point, std::numeric_limits<uint8_t>::min());
44 CHECK_LE(X0.zero_point, std::numeric_limits<uint8_t>::max());
45 auto Y_dims = X0.t.sizes().vec();
46 if (this->
template GetSingleArgument<string>(
"order",
"") ==
"NHWC") {
47 CHECK_EQ(Y_dims.size(), 4);
49 for (
auto i = 1; i < InputSize(); ++i) {
50 const auto& Xi = Inputs()[i]->template Get<Int8TensorCPU>();
51 CHECK_EQ(Xi.t.dim(), Y_dims.size());
52 for (
auto j = 0; j < Y_dims.size(); ++j) {
54 CHECK_EQ(Xi.t.size(j), Y_dims[j]);
57 Y_dims[axis_] += Xi.t.size(axis_);
60 int before = X0.t.size_to_dim(axis_);
61 int after = X0.t.size_from_dim(axis_ + 1);
62 const auto C_total = Y_dims[axis_];
64 for (
auto i = 0; i < InputSize(); ++i) {
65 const auto& Xi = Inputs()[i]->template Get<Int8TensorCPU>();
67 const auto Ci = Xi.t.size(axis_);
68 math::CopyMatrix<CPUContext>(
72 Xi.t.template data<uint8_t>(),
74 Y->t.template mutable_data<uint8_t>() + C_offset,
78 C_offset += Ci * after * Xi.t.itemsize();
91 #endif // CAFFE2_OPERATORS_INT8_CONCAT_OP_H_ void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...