Caffe2 - C++ API
A deep learning, cross-platform ML framework
int8_serialization.cc
1 #include "caffe2/core/blob_serialization.h"
2 #include "caffe2/core/common.h"
3 #include "caffe2/core/context.h"
4 #include "caffe2/core/tensor_int8.h"
5 #include <c10/util/typeid.h>
6 #include "caffe2/core/types.h"
7 
8 namespace caffe2 {
9 namespace int8 {
10 
12  public:
13  void Serialize(
14  const void* pointer,
15  TypeMeta typeMeta,
16  const string& name,
17  SerializationAcceptor acceptor) override {
18  CAFFE_ENFORCE(typeMeta.Match<Int8TensorCPU>());
19  const auto& tensor = *static_cast<const Int8TensorCPU*>(pointer);
20  BlobProto blob_proto;
21  blob_proto.set_name(name);
22  blob_proto.set_type("Int8TensorCPU");
23  QTensorProto& proto = *blob_proto.mutable_qtensor();
24  proto.set_name(name);
25  for (int i = 0; i < tensor.t.dim(); ++i) {
26  proto.add_dims(tensor.t.dim32(i));
27  }
28  proto.set_precision(8);
29  proto.set_scale(tensor.scale);
30  proto.set_bias(tensor.zero_point);
31  proto.set_is_signed(false);
32 
33  const TensorProto::DataType data_type =
34  TypeMetaToDataType(tensor.t.dtype());
35  proto.set_data_type(data_type);
36  switch (data_type) {
37  case TensorProto_DataType_INT32:
38  detail::CopyToProtoAsIs(
39  tensor.t.numel(),
40  tensor.t.template data<int32_t>(),
41  proto.mutable_data(),
42  &this->context_);
43  break;
44  case TensorProto_DataType_UINT8:
45  detail::CopyToProtoWithCast(
46  tensor.t.numel(),
47  tensor.t.template data<uint8_t>(),
48  proto.mutable_data(),
49  &this->context_);
50  break;
51  default:
52  CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
53  }
54 
55  acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
56  }
57 
58  private:
59  CPUContext context_;
60 };
61 
63  public:
64  void Deserialize(const BlobProto& blob_proto, Blob* blob) override {
65  const QTensorProto& proto = blob_proto.qtensor();
66  Int8TensorCPU* tensor = blob->template GetMutable<Int8TensorCPU>();
67  tensor->scale = proto.scale();
68  tensor->zero_point = proto.bias();
69  vector<int> dims;
70  for (const int d : proto.dims()) {
71  dims.push_back(d);
72  }
73  tensor->t.Resize(dims);
74  switch (proto.data_type()) {
75  case TensorProto_DataType_INT32:
76  detail::CopyFromProtoAsIs(
77  tensor->t.numel(),
78  proto.data(),
79  tensor->t.template mutable_data<int32_t>(),
80  &this->context_);
81  break;
82  case TensorProto_DataType_UINT8:
83  detail::CopyFromProtoWithCast(
84  tensor->t.numel(),
85  proto.data(),
86  tensor->t.template mutable_data<uint8_t>(),
87  &this->context_);
88  break;
89  default:
90  CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
91  }
92  }
93 
94  private:
95  CPUContext context_;
96 };
97 
98 } // namespace int8
99 
100 namespace {
101 REGISTER_BLOB_SERIALIZER(
102  (TypeMeta::Id<int8::Int8TensorCPU>()),
104 REGISTER_BLOB_DESERIALIZER(Int8TensorCPU, int8::Int8TensorCPUDeserializer);
105 } // namespace
106 
107 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324
BlobSerializerBase is an abstract class that serializes a blob to a string.