1 #include "quant_decomp_zstd_op.h" 4 #include "caffe2/core/tensor.h" 5 #include "caffe2/proto/caffe2_pb.h" 11 #define REGISTER_TYPE(index, type) \ 13 index, [](TensorCPU* tensor_) -> uint8_t* { \ 14 return reinterpret_cast<uint8_t*>(tensor_->mutable_data<type>()); \ 21 uint8_t* GetMutableData(
int type_index, TensorCPU* tensor) {
23 static const std::map<int, std::function<uint8_t*(TensorCPU * tensor)>>
24 gTypeMapper = {REGISTER_TYPE(TensorProto::UINT8, uint8_t),
25 REGISTER_TYPE(TensorProto::UINT16, uint16_t),
26 REGISTER_TYPE(TensorProto::INT32, int32_t),
27 REGISTER_TYPE(TensorProto::FLOAT,
float)};
30 gTypeMapper.count(type_index),
32 "Invalid type index " + c10::to_string(type_index) +
".");
33 return gTypeMapper.at(type_index)(tensor);
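// Illustration of the REGISTER_TYPE macro above: each entry expands roughly to
//
//   {TensorProto::UINT8,
//    [](TensorCPU* tensor_) -> uint8_t* {
//      return reinterpret_cast<uint8_t*>(tensor_->mutable_data<uint8_t>());
//    }}
//
// so gTypeMapper maps a TensorProto data type to a lambda that allocates the
// output tensor's storage with the matching element type and returns it as a
// raw byte pointer for ZSTD_decompress() to write into.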
// Returns a pointer to the compressed payload and writes its size in bytes
// to out_size. The payload is either a 1-d uint8_t tensor or a tensor
// holding a single std::string.
const uint8_t* GetCompressedPtr(
    const TensorCPU& compressed,
    size_t* out_size) {
  CAFFE_ENFORCE(
      // array of uint8_t
      compressed.template IsType<uint8_t>() ||
      // tensor with a single string element
      compressed.template IsType<std::string>());

  if (compressed.template IsType<uint8_t>()) {
    *out_size = compressed.numel();
    return compressed.data<uint8_t>();
  }

  // string type
  CAFFE_ENFORCE_EQ(compressed.numel(), 1);
  auto& str = compressed.data<std::string>()[0];
  *out_size = str.size();
  return reinterpret_cast<const uint8_t*>(str.data());
}
// Deserializes the compressed blob into a TensorProtos message; each entry
// still holds zstd-compressed bytes in its byte_data field.
TensorProtos GetTensorsProto(const TensorCPU& compressed) {
  size_t sz;
  auto* ptr = GetCompressedPtr(compressed, &sz);
  TensorProtos tensors;
  CAFFE_ENFORCE(tensors.ParseFromArray(ptr, sz));
  return tensors;
}
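// For reference, the overall layout of the compressed blob (as produced by
// mutils.compress_data_list(), per the operator doc below) is a serialized
// TensorProtos message, e.g.:
//
//   TensorProtos {
//     protos[0]: { dims: [...], data_type: FLOAT, byte_data: <zstd bytes> }
//     protos[1]: { dims: [...], data_type: UINT8, byte_data: <zstd bytes> }
//     ...
//   }
//
// GetTensorsProto() only parses this outer container; Decompress() below
// handles the per-entry zstd payload.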
// Decompresses a single zstd-compressed TensorProto into outDecomp.
void Decompress(const TensorProto& compressed, TensorCPU* outDecomp) {
  // dims() holds the shape of the tensor before compression.
  vector<int64_t> shape(compressed.dims().begin(), compressed.dims().end());
  outDecomp->Resize(shape);
  auto* out_ptr = GetMutableData(compressed.data_type(), outDecomp);

  auto* src =
      reinterpret_cast<const uint8_t*>(compressed.byte_data().data());
  size_t comp_size = compressed.byte_data().size();
  size_t decomp_size = outDecomp->nbytes();

  // call zstd
  size_t dc_size = ZSTD_decompress(out_ptr, decomp_size, src, comp_size);
  CAFFE_ENFORCE(!ZSTD_isError(dc_size), ZSTD_getErrorName(dc_size));
  CAFFE_ENFORCE_EQ(decomp_size, dc_size);
}

} // namespace
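// For orientation only: a minimal sketch of the producer side that
// Decompress() inverts. The real encoder is mutils.compress_data_list()
// (Python, see the operator doc below); the helper name MakeCompressedProto
// and the compression level are illustrative, not part of this file.
//
//   TensorProto MakeCompressedProto(const TensorCPU& src, int type_index) {
//     TensorProto proto;
//     for (auto d : src.sizes()) {
//       proto.add_dims(d); // shape of the uncompressed tensor
//     }
//     proto.set_data_type(static_cast<TensorProto::DataType>(type_index));
//     std::string dst(ZSTD_compressBound(src.nbytes()), '\0');
//     size_t n = ZSTD_compress(
//         &dst[0], dst.size(), src.raw_data(), src.nbytes(), /*level=*/1);
//     CAFFE_ENFORCE(!ZSTD_isError(n), ZSTD_getErrorName(n));
//     dst.resize(n);
//     proto.set_byte_data(dst); // zstd-compressed element bytes
//     return proto;
//   }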
bool QuantDecompZstdOp::RunOnDevice() {
  const auto& op_compressed = Input(0);

  // The compressed input is either an array of uint8_t or a single string.
  CAFFE_ENFORCE(
      op_compressed.template IsType<uint8_t>() ||
          op_compressed.template IsType<std::string>(),
      op_compressed.dtype().name());

  if (op_compressed.template IsType<uint8_t>()) {
    CAFFE_ENFORCE_EQ(op_compressed.dim(), 1, op_compressed.dim());
  } else {
    CAFFE_ENFORCE_EQ(op_compressed.numel(), 1, op_compressed.numel());
  }

  auto tensors = GetTensorsProto(op_compressed);
  CAFFE_ENFORCE_EQ(tensors.protos_size(), OutputSize());

  for (int i = 0; i < OutputSize(); i++) {
    Decompress(tensors.protos(i), Output(i));
  }

  return true;
}
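// Sketch of how the operator might be driven from C++; the test
// quant_decomp_op_test.py referenced in the schema below exercises it from
// Python. Blob names "compressed", "out0" and "out1" are example names, and
// the "compressed" blob is assumed to already hold data produced by
// mutils.compress_data_list() for two tensors.
//
//   Workspace ws;
//   // ... fill the "compressed" blob with a 1-d uint8_t tensor
//   //     (or a tensor holding one std::string) ...
//   OperatorDef def = CreateOperatorDef(
//       "QuantDecompZstd",
//       "",
//       std::vector<std::string>{"compressed"},
//       std::vector<std::string>{"out0", "out1"});
//   auto op = CreateOperator(def, &ws);
//   CAFFE_ENFORCE(op->Run());
//   // ws.GetBlob("out0") and ws.GetBlob("out1") now hold the decompressed
//   // tensors with their original shapes and data types.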
REGISTER_CPU_OPERATOR(QuantDecompZstd, QuantDecompZstdOp);

OPERATOR_SCHEMA(QuantDecompZstd)
    .NumInputs(1)
    .NumOutputs(1, INT_MAX)
    .SetDoc(R"DOC(
Decompress a set of tensors that are compressed using zstd.
The data can be compressed using mutils.compress_data_list(); see
quant_decomp_op_test.py for an example.
The number of outputs depends on the input.
)DOC")
    .Input(
        0,
        "compressed",
        "Compressed data in a 1d tensor (uint8_t), "
        "or a 0d tensor with one element of string type. "
        "The data is compressed using mutils.compress_data_list().")
    .Output(0, "output0", "Decompressed data 0")
    .Output(1, "output1", "Decompressed data 1, if it exists")
    .Output(2, "output2", "Decompressed data 2, if it exists")
    .Output(3, "outputn", "Decompressed data n, if it exists");

SHOULD_NOT_DO_GRADIENT(QuantDecompZstd);

} // namespace caffe2