Caffe2 - C++ API
A deep learning, cross-platform ML framework
quant_decomp_zstd_op.cc
1 #include "quant_decomp_zstd_op.h"
2 #include <stdint.h>
3 #include <zstd.h>
4 #include "caffe2/core/tensor.h"
5 #include "caffe2/proto/caffe2_pb.h"
6 
7 namespace caffe2 {
8 
9 namespace {
10 
11 #define REGISTER_TYPE(index, type) \
12  { \
13  index, [](TensorCPU* tensor_) -> uint8_t* { \
14  return reinterpret_cast<uint8_t*>(tensor_->mutable_data<type>()); \
15  } \
16  }
17 
18 // return a mutable pointer to the tensor in uint8_t format, the memory is
19 // allocated based on the type 'type_index'
20 // supported type is defined in 'gTypeMapper'
21 uint8_t* GetMutableData(int type_index, TensorCPU* tensor) {
22  // see COMP_DATA_TYPE_MAPPER in mutils.py for the mapping
23  static const std::map<int, std::function<uint8_t*(TensorCPU * tensor)>>
24  gTypeMapper = {REGISTER_TYPE(TensorProto::UINT8, uint8_t),
25  REGISTER_TYPE(TensorProto::UINT16, uint16_t),
26  REGISTER_TYPE(TensorProto::INT32, int32_t),
27  REGISTER_TYPE(TensorProto::FLOAT, float)};
28 
29  CAFFE_ENFORCE_EQ(
30  gTypeMapper.count(type_index),
31  1,
32  "Invalid type index " + c10::to_string(type_index) + ".");
33  return gTypeMapper.at(type_index)(tensor);
34 }

const uint8_t* GetCompressedPtr(const TensorCPU& compressed, size_t* out_size) {
  CAFFE_ENFORCE(
      // array of uint8_t
      compressed.template IsType<uint8_t>() ||
      // array with one string
      compressed.template IsType<std::string>());

  if (compressed.template IsType<uint8_t>()) {
    *out_size = compressed.numel();
    return compressed.data<uint8_t>();
  }

  // string type
  CAFFE_ENFORCE_EQ(compressed.numel(), 1);
  auto& str = compressed.data<std::string>()[0];
  *out_size = str.size();
  return reinterpret_cast<const uint8_t*>(str.data());
}

// Deserialize the string to get TensorProtos, storing tensors in compressed
// format
TensorProtos GetTensorsProto(const TensorCPU& compressed) {
  size_t sz;
  auto* ptr = GetCompressedPtr(compressed, &sz);
  TensorProtos tensors;
  CAFFE_ENFORCE(tensors.ParseFromArray(ptr, sz));
  return tensors;
}

// Decompress a tensor stored in compressed format.
// It is compressed using mutils.compress_data_list().
void Decompress(const TensorProto& compressed, TensorCPU* outDecomp) {
  vector<int64_t> shape(compressed.dims().begin(), compressed.dims().end());
  // 'shape' stores the dimensions of the data before compression;
  // see _compress_data_single() in mutils.py.
  outDecomp->Resize(shape);
  auto* out_ptr = GetMutableData(compressed.data_type(), outDecomp);

  auto* src = reinterpret_cast<const uint8_t*>(compressed.byte_data().data());
  size_t comp_size = compressed.byte_data().size();
  size_t decomp_size = outDecomp->nbytes();

  // call zstd
  size_t dc_size = ZSTD_decompress(out_ptr, decomp_size, src, comp_size);
  CAFFE_ENFORCE(!ZSTD_isError(dc_size), ZSTD_getErrorName(dc_size));
  CAFFE_ENFORCE_EQ(decomp_size, dc_size);
}

} // namespace

bool QuantDecompZstdOp::RunOnDevice() {
  const auto& op_compressed = Input(0);

  // Data could be an array of uint8_t, or a string.
  CAFFE_ENFORCE(
      // array of uint8_t
      op_compressed.template IsType<uint8_t>() ||
      // array with one string
      op_compressed.template IsType<std::string>(),
      op_compressed.dtype().name());

  // op_compressed: compressed data, 1d
  if (op_compressed.template IsType<uint8_t>()) {
    CAFFE_ENFORCE_EQ(op_compressed.dim(), 1, op_compressed.dim());
  } else {
    // string type has 0 dimensions
    CAFFE_ENFORCE_EQ(op_compressed.numel(), 1, op_compressed.numel());
  }

  auto tensors = GetTensorsProto(op_compressed);
  CAFFE_ENFORCE_EQ(tensors.protos_size(), OutputSize());

  for (int i = 0; i < OutputSize(); i++) {
    Decompress(tensors.protos(i), Output(i));
  }

  return true;
}

REGISTER_CPU_OPERATOR(QuantDecompZstd, QuantDecompZstdOp);

OPERATOR_SCHEMA(QuantDecompZstd)
    .NumInputs(1)
    .NumOutputs(1, INT_MAX)
    .SetDoc(R"DOC(
Decompress a set of tensors that are compressed using zstd.
The data can be compressed using mutils.compress_data_list(); see
quant_decomp_op_test.py for an example.
The number of outputs depends on the input.
)DOC")
    .Input(
        0,
        "compressed",
        "Compressed data in a 1d tensor (uint8_t), "
        "or a 0d tensor with one element of string type. "
        "The data is compressed using mutils.compress_data_list().")
    .Output(0, "output0", "Decompressed data 0")
    .Output(1, "output1", "Decompressed data 1, if present")
    .Output(2, "output2", "Decompressed data 2, if present")
    .Output(3, "outputn", "Decompressed data n, if present");

SHOULD_NOT_DO_GRADIENT(QuantDecompZstd);

} // namespace caffe2
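
For context, the single "compressed" input that RunOnDevice() parses is a serialized TensorProtos message: each contained TensorProto carries the pre-compression shape in dims, the output element type in data_type (one of the entries in gTypeMapper), and the zstd-compressed payload in byte_data. This is the format produced by mutils.compress_data_list() on the Python side. Below is a minimal C++ sketch of a producer for that format; the helper name MakeCompressedInput and the compression level are illustrative assumptions, not part of the Caffe2 API.

// Hypothetical helper (not part of Caffe2): builds the serialized
// TensorProtos payload that QuantDecompZstd expects as its input, for a
// single float tensor. Mirrors what mutils.compress_data_list() does.
#include <zstd.h>
#include <string>
#include <vector>
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"

namespace caffe2 {

std::string MakeCompressedInput(
    const std::vector<int64_t>& dims,
    const std::vector<float>& values) {
  TensorProtos protos;
  TensorProto* proto = protos.add_protos();
  for (auto d : dims) {
    proto->add_dims(d); // shape of the data *before* compression
  }
  proto->set_data_type(TensorProto::FLOAT); // must be a type in gTypeMapper

  // zstd-compress the raw bytes of the tensor data.
  const size_t raw_size = values.size() * sizeof(float);
  std::string compressed(ZSTD_compressBound(raw_size), '\0');
  size_t comp_size = ZSTD_compress(
      &compressed[0],
      compressed.size(),
      values.data(),
      raw_size,
      /*compressionLevel=*/1); // level 1 is an arbitrary choice here
  CAFFE_ENFORCE(!ZSTD_isError(comp_size), ZSTD_getErrorName(comp_size));
  compressed.resize(comp_size);

  // Decompress() reads the payload back from byte_data().
  proto->set_byte_data(compressed);

  // The resulting string is what the operator consumes, either as a 1d
  // uint8_t tensor or as a 0d tensor holding one std::string element.
  return protos.SerializeAsString();
}

} // namespace caffe2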