Caffe2 - C++ API
A deep learning, cross platform ML framework
quant_decomp_zstd_op.cc
1 
17 #include "quant_decomp_zstd_op.h"
18 #include <stdint.h>
19 #include <zstd.h>
20 #include "caffe2/core/tensor.h"
21 #include "caffe2/proto/caffe2.pb.h"
22 
23 namespace caffe2 {
24 
25 namespace {
26 
// Builds one {index, accessor} initializer entry for the type table in
// GetMutableData(). The accessor calls mutable_data<type>() on the tensor it
// is given and returns the resulting pointer reinterpreted as uint8_t*.
#define REGISTER_TYPE(index, type)                              \
  {                                                             \
    index, [](TensorCPU* typed_tensor_) -> uint8_t* {           \
      auto* typed_ptr_ = typed_tensor_->mutable_data<type>();   \
      return reinterpret_cast<uint8_t*>(typed_ptr_);            \
    }                                                           \
  }
33 
34 // return a mutable pointer to the tensor in uint8_t format, the memory is
35 // allocated based on the type 'type_index'
36 // supported type is defined in 'gTypeMapper'
37 uint8_t* GetMutableData(int type_index, TensorCPU* tensor) {
38  // see COMP_DATA_TYPE_MAPPER in mutils.py for the mapping
39  static const std::map<int, std::function<uint8_t*(TensorCPU * tensor)>>
40  gTypeMapper = {REGISTER_TYPE(TensorProto::UINT8, uint8_t),
41  REGISTER_TYPE(TensorProto::UINT16, uint16_t),
42  REGISTER_TYPE(TensorProto::INT32, int32_t),
43  REGISTER_TYPE(TensorProto::FLOAT, float)};
44 
45  CAFFE_ENFORCE_EQ(
46  gTypeMapper.count(type_index),
47  1,
48  "Invalid type index " + caffe2::to_string(type_index) + ".");
49  return gTypeMapper.at(type_index)(tensor);
50 }
51 
52 const uint8_t* GetCompressedPtr(const TensorCPU& compressed, size_t* out_size) {
53  CAFFE_ENFORCE(
54  // array of uint8_t
55  compressed.template IsType<uint8_t>() ||
56  // array with one string
57  compressed.template IsType<std::string>());
58 
59  if (compressed.template IsType<uint8_t>()) {
60  *out_size = compressed.size();
61  return compressed.data<uint8_t>();
62  }
63 
64  // string type
65  CAFFE_ENFORCE_EQ(compressed.size(), 1);
66  auto& str = compressed.data<std::string>()[0];
67  *out_size = str.size();
68  return reinterpret_cast<const uint8_t*>(str.data());
69 }
70 
71 // Deserialize the string to get TensorProtos, storing tensors in compressed
72 // format
73 TensorProtos GetTensorsProto(const TensorCPU& compressed) {
74  size_t sz;
75  auto* ptr = GetCompressedPtr(compressed, &sz);
76  TensorProtos tensors;
77  CAFFE_ENFORCE(tensors.ParseFromArray(ptr, sz));
78  return tensors;
79 }
80 
81 // Decompress tensor stored in compressed format
82 // It is compressed using mutils.compress_data_list()
83 void Decompress(const TensorProto& compressed, TensorCPU* outDecomp) {
84  vector<TIndex> shape(compressed.dims().begin(), compressed.dims().end());
85  // shape stores the dimensions of data before compression,
86  // see _compress_data_single() in mutils.py
87  outDecomp->Resize(shape);
88  auto* out_ptr = GetMutableData(compressed.data_type(), outDecomp);
89 
90  auto* src = reinterpret_cast<const uint8_t*>(compressed.byte_data().data());
91  size_t comp_size = compressed.byte_data().size();
92  size_t decomp_size = outDecomp->nbytes();
93 
94  // call zstd
95  size_t dc_size = ZSTD_decompress(out_ptr, decomp_size, src, comp_size);
96  CAFFE_ENFORCE(!ZSTD_isError(dc_size), ZSTD_getErrorName(dc_size));
97  CAFFE_ENFORCE_EQ(decomp_size, dc_size);
98 }
99 
100 } // namespace
101 
102 bool QuantDecompZstdOp::RunOnDevice() {
103  const auto& op_compressed = Input(0);
104 
105  // Data could be an array of uint_t, or a string
106  CAFFE_ENFORCE(
107  // array of uint8_t
108  op_compressed.template IsType<uint8_t>() ||
109  // array with one string
110  op_compressed.template IsType<std::string>(),
111  op_compressed.meta().name());
112 
113  // op_compressed: compressed data, 1d
114  if (op_compressed.template IsType<uint8_t>()) {
115  CAFFE_ENFORCE_EQ(op_compressed.ndim(), 1, op_compressed.ndim());
116  } else {
117  // string type has 0 dimension
118  CAFFE_ENFORCE_EQ(op_compressed.size(), 1, op_compressed.size());
119  }
120 
121  auto tensors = GetTensorsProto(op_compressed);
122  CAFFE_ENFORCE_EQ(tensors.protos_size(), OutputSize());
123 
124  for (int i = 0; i < OutputSize(); i++) {
125  Decompress(tensors.protos(i), Output(i));
126  }
127 
128  return true;
129 }
130 
131 REGISTER_CPU_OPERATOR(QuantDecompZstd, QuantDecompZstdOp);
132 
133 OPERATOR_SCHEMA(QuantDecompZstd)
134  .NumInputs(1)
135  .NumOutputs(1, INT_MAX)
136  .SetDoc(R"DOC(
137  Decompress a set of tensors that are compressed using zstd.
138  The data can be compressed using mutils.compress_data_list(), see
139  quant_decomp_op_test.py for an example.
140  The number of outputs depended on the input.
141  )DOC")
142  .Input(
143  0,
144  "compressed",
145  "Compressed data in 1d tensor (uint8_t), "
146  "or 0d tensor with one element in string type."
147  "The data is compressed using mutils.compress_data_list().")
148  .Output(0, "output0", "Decompressed data 0")
149  .Output(1, "output1", "Decompressed data 1 if existed")
150  .Output(2, "output2", "Decompressed data 2 if existed")
151  .Output(3, "outputn", "Decompressed data n if existed");
152 
153 SHOULD_NOT_DO_GRADIENT(QuantDecompZstd);
154 
155 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.