Caffe2 - C++ API
A deep learning, cross platform ML framework
qtensor_serialization.h
1 
17 #ifndef CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
18 #define CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
19 
20 #include "caffe2/core/blob_serialization.h"
21 #include "caffe2/core/qtensor.h"
22 
23 namespace caffe2 {
24 
// Type tag written into BlobProto::type for blobs serialized by
// QTensorSerializer.
constexpr auto kQTensorBlobQType = "QTensor";
26 
27 template <class Context>
29  public:
30  QTensorSerializer() : context_() {}
31  ~QTensorSerializer() {}
35  void Serialize(
36  const Blob& blob,
37  const string& name,
38  SerializationAcceptor acceptor) override;
39 
40  private:
41  Context context_;
42 };
43 
44 template <class Context>
46  public:
47  void Deserialize(const BlobProto& proto, Blob* blob) override;
48  void Deserialize(const QTensorProto& proto, QTensor<Context>* tensor);
49 };
50 
51 template <class Context>
53  const Blob& blob,
54  const string& name,
55  BlobSerializerBase::SerializationAcceptor acceptor) {
56  const auto& qtensor = blob.template Get<QTensor<Context>>();
57  BlobProto blob_proto;
58  blob_proto.set_name(name);
59  blob_proto.set_type(kQTensorBlobQType);
60  QTensorProto& proto = *blob_proto.mutable_qtensor();
61  proto.set_name(name);
62  for (int i = 0; i < qtensor.ndim(); ++i) {
63  proto.add_dims(qtensor.dim32(i));
64  }
65  proto.set_precision(qtensor.precision());
66  proto.set_scale(qtensor.scale());
67  proto.set_bias(qtensor.bias());
68  proto.set_is_signed(qtensor.is_signed());
69  detail::CopyToProtoWithCast(
70  qtensor.nbytes(), qtensor.data(), proto.mutable_data(), &this->context_);
71  acceptor(name, blob_proto.SerializeAsString());
72 }
73 
74 template <class Context>
76  const BlobProto& blob_proto,
77  Blob* blob) {
78  Deserialize(blob_proto.qtensor(), blob->GetMutable<QTensor<Context>>());
79 }
80 
81 template <class Context>
83  const QTensorProto& proto,
84  QTensor<Context>* qtensor) {
85  Context context{};
86  vector<int> dims;
87  for (const int d : proto.dims()) {
88  dims.push_back(d);
89  }
90  qtensor->Resize(dims);
91  qtensor->SetPrecision(proto.precision());
92  qtensor->SetScale(proto.scale());
93  qtensor->SetBias(proto.bias());
94  qtensor->SetSigned(proto.is_signed());
95 
96  detail::CopyFromProtoWithCast(
97  qtensor->nbytes(), proto.data(), qtensor->mutable_data(), &context);
98 }
99 
100 } // namespace caffe2
101 
102 #endif // CAFFE2_CORE_QTENSOR_SERIALIZATION_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto.
Copyright (c) 2016-present, Facebook, Inc.
T * GetMutable(bool *is_new_object=nullptr)
Gets a mutable pointer to the stored object.
Definition: blob.h:117
void Serialize(const Blob &blob, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
BlobSerializerBase is an abstract class that serializes a blob to a string.