Caffe2 - C++ API
A deep learning, cross-platform ML framework
qtensor_serialization.h
1 #ifndef CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
2 #define CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
3 
4 #include "caffe2/core/blob_serialization.h"
5 #include "caffe2/core/qtensor.h"
6 
7 namespace caffe2 {
8 
// Type tag that QTensorSerializer::Serialize writes into BlobProto::type.
constexpr auto kQTensorBlobQType = "QTensor";
10 
11 template <class Context>
13  public:
14  QTensorSerializer() : context_() {}
15  ~QTensorSerializer() {}
19  void Serialize(
20  const void* pointer,
21  TypeMeta typeMeta,
22  const string& name,
23  SerializationAcceptor acceptor) override;
24 
25  private:
26  Context context_;
27 };
28 
29 template <class Context>
31  public:
32  void Deserialize(const BlobProto& proto, Blob* blob) override;
33  void Deserialize(const QTensorProto& proto, QTensor<Context>* tensor);
34 };
35 
36 template <class Context>
38  const void* pointer,
39  TypeMeta typeMeta,
40  const string& name,
41  BlobSerializerBase::SerializationAcceptor acceptor) {
42  CAFFE_ENFORCE(typeMeta.Match<QTensor<Context>>());
43  const auto& qtensor = *static_cast<const QTensor<Context>*>(pointer);
44  BlobProto blob_proto;
45  blob_proto.set_name(name);
46  blob_proto.set_type(kQTensorBlobQType);
47  QTensorProto& proto = *blob_proto.mutable_qtensor();
48  proto.set_name(name);
49  for (int i = 0; i < qtensor.ndim(); ++i) {
50  proto.add_dims(qtensor.dim32(i));
51  }
52  proto.set_precision(qtensor.precision());
53  proto.set_scale(qtensor.scale());
54  proto.set_bias(qtensor.bias());
55  proto.set_is_signed(qtensor.is_signed());
56  detail::CopyToProtoWithCast(
57  qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_);
58  acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
59 }
60 
61 template <class Context>
63  const BlobProto& blob_proto,
64  Blob* blob) {
65  Deserialize(blob_proto.qtensor(), blob->GetMutable<QTensor<Context>>());
66 }
67 
68 template <class Context>
70  const QTensorProto& proto,
71  QTensor<Context>* qtensor) {
72  Context context{};
73  vector<int> dims;
74  for (const int d : proto.dims()) {
75  dims.push_back(d);
76  }
77  qtensor->Resize(dims);
78  qtensor->SetPrecision(proto.precision());
79  qtensor->SetScale(proto.scale());
80  qtensor->SetBias(proto.bias());
81  qtensor->SetSigned(proto.is_signed());
82 
83  detail::CopyFromProtoWithCast(
84  qtensor->nbytes(), proto.data(), qtensor->mutable_data(), &context);
85 }
86 
87 } // namespace caffe2
88 
89 #endif // CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
void Serialize(const void *pointer, TypeMeta typeMeta, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob. Note that this blob has to contain QTensor<Context>.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
T * GetMutable()
Gets a mutable pointer to the stored object.
Definition: blob.h:100
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324
BlobSerializerBase is an abstract class that serializes a blob to a string.