Caffe2 - C++ API
A deep learning, cross-platform ML framework
mklmemory_serialization.cc
1 
17 #include "caffe2/core/blob.h"
18 #include "caffe2/core/blob_serialization.h"
19 #include "caffe2/mkl/mkl_utils.h"
20 
21 #ifdef CAFFE2_HAS_MKL_DNN
22 
23 namespace caffe2 {
24 namespace mkl {
31 class MKLMemorySerializer : public BlobSerializerBase {
32  public:
33  MKLMemorySerializer() {}
34  ~MKLMemorySerializer() {}
35 
36  void Serialize(
37  const Blob& blob,
38  const string& name,
39  SerializationAcceptor acceptor) override {
40  BlobProto blob_proto;
41  blob_proto.set_name(name);
42  blob_proto.set_type(kTensorBlobType);
43  TensorProto* proto = blob_proto.mutable_tensor();
44  auto* device_detail = proto->mutable_device_detail();
45  device_detail->set_device_type(MKLDNN);
46  proto->set_name(name);
47  if (blob.IsType<MKLMemory<float>>()) {
48  const MKLMemory<float>& src = blob.Get<MKLMemory<float>>();
49  CAFFE_ENFORCE(
50  src.buffer(), "Cannot serialize an empty MKLMemory object.");
51  size_t total = 1;
52  for (int i = 0; i < src.dims().size(); ++i) {
53  proto->add_dims(src.dims()[i]);
54  total *= src.dims()[i];
55  }
56  proto->mutable_float_data()->Reserve(total);
57  while (total--) {
58  proto->add_float_data(0);
59  }
60  src.CopyTo(proto->mutable_float_data()->mutable_data());
61  } else if (blob.IsType<MKLMemory<double>>()) {
62  const MKLMemory<double>& src = blob.Get<MKLMemory<double>>();
63  CAFFE_ENFORCE(
64  src.buffer(), "Cannot serialize an empty MKLMemory object.");
65  size_t total = 1;
66  for (int i = 0; i < src.dims().size(); ++i) {
67  proto->add_dims(src.dims()[i]);
68  total *= src.dims()[i];
69  }
70  proto->mutable_double_data()->Reserve(total);
71  while (total--) {
72  proto->add_double_data(0);
73  }
74  src.CopyTo(proto->mutable_double_data()->mutable_data());
75  } else {
76  CAFFE_THROW(
77  "MKLMemory could only be either float or double. "
78  "Encountered unsupported type.");
79  }
80  acceptor(name, blob_proto.SerializeAsString());
81  }
82 };
83 
93 class MKLMemoryDeserializer : public BlobDeserializerBase {
94  public:
95  void Deserialize(const BlobProto& blob_proto, Blob* blob) override {
96  const TensorProto& proto = blob_proto.tensor();
97  CAFFE_ENFORCE(
98  proto.data_type() == TensorProto_DataType_FLOAT ||
99  proto.data_type() == TensorProto_DataType_DOUBLE,
100  "MKLMemory only supports either float or double formats.");
101  CAFFE_ENFORCE(
102  !proto.has_segment(), "MKLMemory does not support segment right now.");
103  vector<TIndex> dims;
104  for (const TIndex d : proto.dims()) {
105  dims.push_back(d);
106  }
107  // TODO: right now, every time we do a deserializer we create a new MKL
108  // Memory object. Optionally, we can change that.
109  switch (proto.data_type()) {
110  case TensorProto_DataType_FLOAT: {
111  auto dst = make_unique<MKLMemory<float>>(dims);
112  dst->CopyFrom(proto.float_data().data());
113  blob->Reset(dst.release());
114  break;
115  }
116  case TensorProto_DataType_DOUBLE: {
117  auto dst = make_unique<MKLMemory<double>>(dims);
118  dst->CopyFrom(proto.double_data().data());
119  blob->Reset(dst.release());
120  break;
121  }
122  default:
123  CAFFE_THROW("This should not happen, we guarded things above already.");
124  }
125  }
126 };
127 
128 } // namespace mkl
129 
130 REGISTER_BLOB_SERIALIZER(
131  (TypeMeta::Id<mkl::MKLMemory<float>>()),
132  mkl::MKLMemorySerializer);
133 REGISTER_BLOB_SERIALIZER(
134  (TypeMeta::Id<mkl::MKLMemory<double>>()),
135  mkl::MKLMemorySerializer);
136 REGISTER_BLOB_DESERIALIZER(TensorMKLDNN, mkl::MKLMemoryDeserializer);
137 } // namespace caffe2
138 
139 #endif // CAFFE2_HAS_MKL_DNN
static CAFFE2_API CaffeTypeId Id()
Returns the unique id for the given type T.
Copyright (c) 2016-present, Facebook, Inc.