// Project includes followed by the trailing arguments of three command-line
// flag definitions (the defining macro lines are not visible in this view).
// The flags are read later in this file as FLAGS_caffe2_tensor_chunk_size,
// FLAGS_caffe2_max_tensor_serializer_threads and
// FLAGS_caffe2_serialize_fp16_as_bytes.
1 #include "caffe2/core/blob_serialization.h" 6 #include "caffe2/core/blob.h" 7 #include "caffe2/utils/proto_utils.h" 10 caffe2_tensor_chunk_size,
12 "Chunk size to split tensor data into");
// Upper bound on the number of async tasks launched for one chunked tensor.
15 caffe2_max_tensor_serializer_threads,
17 "Maximal number of threads that can be used for tensor serialization");
// Selects the byte_data encoding for FLOAT16 tensors instead of int32_data.
20 caffe2_serialize_fp16_as_bytes,
22 "Serialize FLOAT16 tensors using byte_data field");
// Serialize for std::string blobs (presumably StringSerializer::Serialize;
// the head of the signature is not visible in this view).
// Enforces that typeMeta matches std::string, wraps the pointed-to string in
// a BlobProto whose name is `name` and whose type is "std::string", and
// hands the serialized proto to `acceptor` under the key `name`.
43 SerializationAcceptor acceptor)
override {
44 CAFFE_ENFORCE(typeMeta.Match<std::string>());
47 blob_proto.set_name(name);
48 blob_proto.set_type(
"std::string");
// `pointer` is untyped; the Match<> check above is what justifies this cast.
49 blob_proto.set_content(*static_cast<const std::string*>(pointer));
50 acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
// Deserialize for std::string blobs: overwrites the blob's std::string
// (obtained through GetMutable<std::string>()) with proto.content().
//   proto - source BlobProto; only its content field is read here.
//   blob  - destination blob.
60 void Deserialize(
const BlobProto& proto,
Blob* blob)
override {
61 *blob->
GetMutable<std::string>() = proto.content();
// Generic blob serialization entry (signature head not visible in this
// view). Looks up a serializer by typeMeta.id(), fails with
// "No known serializer for <type name>" when none is returned, and
// otherwise forwards all arguments, including the caller's chunk_size, to
// the serializer's SerializeWithChunkSize.
70 BlobSerializerBase::SerializationAcceptor acceptor,
72 std::unique_ptr<BlobSerializerBase> serializer(
73 CreateSerializer(typeMeta.
id()));
74 CAFFE_ENFORCE(serializer,
"No known serializer for ", typeMeta.
name());
75 serializer->SerializeWithChunkSize(
76 pointer, typeMeta, name, acceptor, chunk_size);
// Fragment of a SerializeBlob variant that collects its output through a
// local acceptor lambda. The lambda captures `data` by reference, ignores
// its first string argument (unnamed) and receives the serialized text as
// blob_str; its body is not visible in this view. Serialization is
// requested with kNoChunking.
82 BlobSerializerBase::SerializationAcceptor acceptor =
83 [&data](
const std::string&,
const std::string& blob_str) {
87 SerializeBlob(pointer, typeMeta, name, acceptor, kNoChunking);
// The first line below is a lone parameter line from a declaration at
// original line 95 whose remainder is not visible in this view.
95 BlobSerializerBase::SerializationAcceptor acceptor,
// Serialize without an explicit chunk size (presumably
// TensorSerializer::Serialize; signature head not visible): delegates to
// SerializeWithChunkSize with kDefaultChunkSize.
108 BlobSerializerBase::SerializationAcceptor acceptor) {
109 this->SerializeWithChunkSize(
110 pointer, typeMeta, name, acceptor, kDefaultChunkSize);
// Serializes a Tensor as one or more chunks. Each chunk becomes its own
// BlobProto of type kTensorBlobType; a string built from `name`,
// kChunkIdSeparator and the chunk index is the first value of the call made
// per chunk (the call head is not visible in this view).
// Several parameter lines and statements are missing from this view.
113 void TensorSerializer::SerializeWithChunkSize(
117 BlobSerializerBase::SerializationAcceptor acceptor,
119 CAFFE_ENFORCE(typeMeta.Match<
Tensor>());
120 const auto& tensor = *
static_cast<const Tensor*
>(pointer);
// Resolve the two sentinel values: kNoChunking picks a size larger than the
// tensor so everything lands in one chunk; kDefaultChunkSize takes the flag.
// NOTE(review): chunk_size's declaration is not visible here; if it is
// narrower than numel()'s type, `numel() + 1` can be truncated - confirm.
121 if (chunk_size == kNoChunking) {
122 chunk_size = tensor.numel() + 1;
123 }
else if (chunk_size == kDefaultChunkSize) {
124 chunk_size = FLAGS_caffe2_tensor_chunk_size;
// Builds and emits the proto for the chunk that starts at chunkStart.
// Captures by reference; it is also invoked from the queue-draining loop
// further down.
127 auto processChunk = [&](int64_t chunkStart) {
128 BlobProto blob_proto;
129 blob_proto.set_name(name);
130 blob_proto.set_type(kTensorBlobType);
131 TensorProto& proto = *blob_proto.mutable_tensor();
132 proto.set_name(name);
134 tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
// NOTE(review): the chunk index divides by chunk_size; a flag value of 0
// would divide by zero unless it is validated elsewhere - confirm.
136 c10::str(name, kChunkIdSeparator, chunkStart / chunk_size),
137 SerializeBlobProtoAsString_EnforceCheck(blob_proto));
// Worker loop: keeps popping chunk offsets from chunkQueue until Pop
// returns false.
145 while (chunkQueue.Pop(&chunkStart)) {
146 processChunk(chunkStart);
// Async tasks are started only when the tensor spans more than one chunk.
149 std::vector<std::future<void>> futures;
150 if (tensor.numel() > chunk_size) {
151 futures.reserve(FLAGS_caffe2_max_tensor_serializer_threads);
152 for (
int i = 0; i < FLAGS_caffe2_max_tensor_serializer_threads; ++i) {
153 futures.emplace_back(std::async(std::launch::async, task));
158 VLOG(1) <<
"Serializing blob " << name;
// The max(numel, 1) bound makes the loop run once for an empty tensor.
// NOTE(review): chunkBegin is size_t while the bound is int64_t, so the
// comparison mixes signedness.
161 for (
size_t chunkBegin = 0;
162 chunkBegin < std::max(tensor.numel(),
static_cast<int64_t
>(1));
163 chunkBegin += chunk_size) {
164 VLOG(2) <<
"Starting a chunk at " << chunkBegin;
// Multi-chunk tensors queue the offset for the workers; the other visible
// paths process the chunk inline (the lines selecting between them are not
// visible in this view).
166 if (tensor.numel() > chunk_size) {
167 chunkQueue.Push(chunkBegin);
170 processChunk(chunkBegin);
174 processChunk(chunkBegin);
// Tell the queue that nothing more will be pushed, then visit each future
// (the loop body is not visible in this view).
179 chunkQueue.NoMoreJobs();
180 for (
auto& fut : futures) {
// Chunk-level tensor serialization (presumably TensorSerializer::Serialize;
// the signature head is not visible in this view). Writes the element range
// that starts at chunkBegin and holds at most chunkSize elements of `input`
// into *proto_ptr, together with shape, data type and device detail.
189 TensorProto* proto_ptr,
// The chunk must start inside the tensor; starting exactly at numel() is
// accepted by this check.
193 chunkBegin <= input.numel(),
194 "Chunk begin is out of tensor: ",
// Clamp the chunk so it never extends past the last element.
198 if (chunkBegin + chunkSize > input.numel()) {
199 chunkSize = input.numel() - chunkBegin;
// For a non-empty chunk a check runs whose failure message follows (its
// condition is not visible in this view). An empty chunk on a tensor with
// no dtype only logs a warning.
202 if (chunkSize != 0) {
205 "The input does not have data input yet. This is probably because you " 206 "created a tensor of non-zero shape but never filled its data via " 207 "mutable_data() calls. This means that it makes no sense to serialize " 208 "the tensor content.");
209 }
else if (!input.dtype_initialized()) {
210 C10_LOG_EVERY_MS(WARNING, 1000)
211 <<
"You're trying to serialize tensor with zero numel and no dtype. " 212 <<
"This is a legacy behavior and it WILL BREAK. Contact PyTorch team " 213 <<
"for details. Offending blob name: " << name;
// Record the element range covered by this proto; DeserializeToTensor reads
// segment().begin()/end() back to place the chunk.
216 TensorProto& proto = *proto_ptr;
217 proto.mutable_segment()->set_begin(chunkBegin);
218 proto.mutable_segment()->set_end(chunkBegin + chunkSize);
// Every chunk carries the full tensor shape.
220 for (
int i = 0; i < input.dim(); ++i) {
221 proto.add_dims(input.size(i));
223 const TensorProto::DataType data_type = TypeMetaToDataType(input.dtype());
224 proto.set_data_type(data_type);
225 StoreDeviceDetail(input, &proto);
228 auto uniq_ptr = CreateContext(input.GetDevice());
// Per-dtype payload copy. float, int32, int64 and double go to their own
// proto fields via CopyToProtoAsIs; bool and the 8/16-bit integer types go
// to int32_data via CopyToProtoWithCast.
231 case TensorProto_DataType_FLOAT:
232 detail::CopyToProtoAsIs(
234 input.template data<float>() + chunkBegin,
235 proto.mutable_float_data(),
238 case TensorProto_DataType_INT32:
239 detail::CopyToProtoAsIs(
241 input.template data<int>() + chunkBegin,
242 proto.mutable_int32_data(),
// BYTE is rejected with a fatal log when serializing.
245 case TensorProto_DataType_BYTE:
246 LOG(FATAL) <<
"This should not happen. When serializing, " 247 "BYTE is deprecated and moved to UINT8.";
// Strings are appended one element at a time.
// NOTE(review): the loop index is `int`; if chunkBegin/chunkSize are 64-bit
// (their declarations are not visible here) large offsets would be
// truncated - confirm.
249 case TensorProto_DataType_STRING: {
250 proto.mutable_string_data()->Reserve(chunkSize);
251 const string* content = input.template data<string>();
252 for (
int i = chunkBegin; i < chunkBegin + chunkSize; ++i) {
253 proto.add_string_data(content[i]);
257 case TensorProto_DataType_BOOL:
258 detail::CopyToProtoWithCast(
260 input.template data<bool>() + chunkBegin,
261 proto.mutable_int32_data(),
264 case TensorProto_DataType_UINT8:
265 detail::CopyToProtoWithCast(
267 input.template data<uint8_t>() + chunkBegin,
268 proto.mutable_int32_data(),
271 case TensorProto_DataType_INT8:
272 detail::CopyToProtoWithCast(
274 input.template data<int8_t>() + chunkBegin,
275 proto.mutable_int32_data(),
278 case TensorProto_DataType_UINT16:
279 detail::CopyToProtoWithCast(
281 input.template data<uint16_t>() + chunkBegin,
282 proto.mutable_int32_data(),
285 case TensorProto_DataType_INT16:
286 detail::CopyToProtoWithCast(
288 input.template data<int16_t>() + chunkBegin,
289 proto.mutable_int32_data(),
292 case TensorProto_DataType_INT64:
293 detail::CopyToProtoAsIs(
295 input.template data<int64_t>() + chunkBegin,
296 proto.mutable_int64_data(),
// FLOAT16 has two encodings: 2 raw bytes per element in byte_data when
// FLAGS_caffe2_serialize_fp16_as_bytes is set, otherwise the 16-bit
// patterns passed through CopyToProtoWithCast into int32_data.
299 case TensorProto_DataType_FLOAT16: {
300 if (FLAGS_caffe2_serialize_fp16_as_bytes) {
// Runtime endianness probe on the first byte of an int holding 1 (the
// comparison line is not visible); the message says big endian is not
// supported.
301 const int kValue = 1;
303 reinterpret_cast<const char*>(&kValue)[0],
305 "Serialization of FLOAT16 on big endian platform " 306 "is not written yet.");
307 unique_ptr<char[]> buffer(
new char[2 * chunkSize]);
308 this->context_->template CopyToCPU<char>(
310 reinterpret_cast<const char*
>(
311 input.template data<at::Half>() + chunkBegin),
313 this->context_->FinishDeviceComputation();
// NOTE(review): release() gives up ownership of the new[] buffer. Unless
// this set_byte_data overload takes ownership of the pointer (protobuf's
// pointer-plus-size string setters normally copy), the buffer is never
// freed - confirm, and pass buffer.get() instead if it copies.
314 proto.set_byte_data(buffer.release(), 2 * chunkSize);
316 detail::CopyToProtoWithCast(
318 reinterpret_cast<const uint16_t*>(input.template data<at::Half>()) +
320 proto.mutable_int32_data(),
324 case TensorProto_DataType_DOUBLE:
325 detail::CopyToProtoAsIs(
327 input.template data<double>() + chunkBegin,
328 proto.mutable_double_data(),
// Tensors without a known proto data type: each element, addressed as
// raw_data + i * itemsize(), is passed with input.dtype() and an empty
// string to a call whose head is not visible in this view.
// NOTE(review): same `int` loop index concern as the STRING case.
331 case TensorProto_DataType_UNDEFINED: {
332 proto.mutable_string_data()->Reserve(chunkSize);
334 const char* raw_data =
static_cast<const char*
>(input.raw_data());
335 for (
int i = chunkBegin; i < chunkBegin + chunkSize; ++i) {
337 raw_data + i * input.itemsize(), input.dtype(),
""));
// Fills proto's device_detail field from the device reported by
// input.GetDevice(), using ExtractDeviceOption.
348 void TensorSerializer::StoreDeviceDetail(
350 TensorProto* proto) {
351 ExtractDeviceOption(proto->mutable_device_detail(), input.GetDevice());
// Defines the typed registry named BlobSerializerRegistry; the remaining
// macro arguments are not visible in this view.
354 C10_DEFINE_TYPED_REGISTRY(
355 BlobSerializerRegistry,
// Body of the string-based blob deserializer (presumably
// DeserializeBlob(content, result); the signature is not visible in this
// view). Parses `content` into a BlobProto, then picks a deserializer:
// tensor blobs use a key built with the device type name from the tensor's
// device_detail, all other blobs use blob_proto.type() directly.
363 BlobProto blob_proto;
365 blob_proto.ParseFromString(content),
366 "Cannot parse content into a BlobProto.");
371 if (blob_proto.type() == kTensorBlobType) {
// Part of the lookup key precedes DeviceTypeName(...) on a line that is not
// visible in this view.
374 auto deserializer = CreateDeserializer(
376 DeviceTypeName(blob_proto.tensor().device_detail().device_type()));
379 CAFFE_ENFORCE(deserializer.get());
380 deserializer->Deserialize(blob_proto, result);
// Non-tensor path: a missing deserializer is reported with the message below.
382 auto deserializer = CreateDeserializer(blob_proto.type());
385 "No registered deserializer for type ",
387 deserializer->Deserialize(blob_proto, result);
// Builds a std::vector<int64_t> from proto.dims(). Capacity is reserved up
// front; the loop body and the return statement are not visible in this view.
393 static std::vector<int64_t> DimsFromTensorProto(
const TensorProto& proto) {
394 std::vector<int64_t> dims;
395 dims.reserve(proto.dims().size());
396 for (
const int64_t d : proto.dims()) {
// Computes an int64_t from the entries of tensor_proto.dims(); going by the
// name this is the element count (presumably the product of the dims - the
// accumulator and loop body are not visible in this view).
403 static int64_t NumelFromTensorProto(
const TensorProto& tensor_proto) {
405 for (
const int64_t d : tensor_proto.dims()) {
// Resolves the TypeMeta for a TensorProto. When the proto carries a data
// type other than UNDEFINED it is mapped with DataTypeToTypeMeta; on the
// other path the type is taken from temp_blob.meta() (how temp_blob is
// populated is not visible in this view).
412 static TypeMeta GetDataType(
const TensorProto& tensor_proto) {
414 if (tensor_proto.data_type() != TensorProto_DataType_UNDEFINED) {
415 dtype = DataTypeToTypeMeta(tensor_proto.data_type());
419 dtype = temp_blob.
meta();
// Builds tensor options from a TensorProto (presumably
// TensorOptionsFromProto; the head of the signature is not visible in this
// view): dtype comes from GetDataType, device from the proto's device_detail.
427 const TensorProto& tensor_proto) {
428 return at::dtype(GetDataType(tensor_proto))
429 .
device(OptionToDevice(tensor_proto.device_detail()));
// Creates a context for the device described by tensor_proto.device_detail()
// and returns it as an owning std::unique_ptr<BaseContext>.
432 static std::unique_ptr<BaseContext> ContextFromProto(
433 const TensorProto& tensor_proto) {
434 auto device = OptionToDevice(tensor_proto.device_detail());
435 return CreateContext(device);
// Creates an empty Tensor matching a TensorProto, after switching to the
// proto's device context.
440 Tensor EmptyTensorFromProto(
const TensorProto& tensor_proto) {
441 auto context = ContextFromProto(tensor_proto);
442 context->SwitchToDevice();
// Special case: zero elements and an UNDEFINED data type. The tensor is
// created with a float dtype on the proto's device (the first argument to
// caffe2::empty is not visible in this view).
443 if (NumelFromTensorProto(tensor_proto) == 0 &&
444 tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
446 return caffe2::empty(
448 at::dtype<float>().device(
449 OptionToDevice(tensor_proto.device_detail())));
// General case: shape and options both come from the proto.
451 return caffe2::empty(
452 DimsFromTensorProto(tensor_proto),
453 TensorOptionsFromProto(tensor_proto));
// Deserializes the tensor stored in blob_proto into `blob`. The two visible
// BlobGetMutableTensor calls mirror the two cases in EmptyTensorFromProto:
// float dtype for a zero-element UNDEFINED proto, otherwise dims and options
// taken from the proto. Several argument lines are not visible in this view.
457 void TensorDeserializer::Deserialize(
const BlobProto& blob_proto,
Blob* blob) {
// NOTE(review): plain `auto` makes tensor_proto a copy of whatever
// blob_proto.tensor() returns; if that is a const reference (usual for
// protobuf accessors) the whole payload is copied - confirm.
458 auto tensor_proto = blob_proto.tensor();
459 auto context = ContextFromProto(tensor_proto);
460 context->SwitchToDevice();
461 if (NumelFromTensorProto(tensor_proto) == 0 &&
462 tensor_proto.data_type() == TensorProto_DataType_UNDEFINED) {
// NOTE(review): the log text below misspells "Deserializing".
464 VLOG(1) <<
"Deseriralizing an empty Tensor.";
465 BlobGetMutableTensor(
468 at::dtype<float>().device(
469 OptionToDevice(tensor_proto.device_detail())));
473 BlobGetMutableTensor(
475 DimsFromTensorProto(tensor_proto),
476 TensorOptionsFromProto(tensor_proto)));
// Copies the payload of tensor_proto into an already-allocated tensor.
// The destination range is [chunkBegin, chunkEnd): the whole tensor by
// default, or the proto's segment when one is present.
// Several argument lines and statements are missing from this view.
480 void TensorDeserializer::DeserializeToTensor(
481 const TensorProto& tensor_proto,
// The destination must already have storage and a dtype.
484 tensor->storage_initialized() && tensor->dtype_initialized(),
485 "Tensor must be initialized before passed into Deserialize function.");
488 auto uniq_ptr = ContextFromProto(tensor_proto);
490 auto context = uniq_ptr.get();
491 context->SwitchToDevice();
493 int64_t chunkBegin = 0;
494 auto chunkEnd = tensor->numel();
495 if (tensor_proto.has_segment()) {
496 chunkBegin = tensor_proto.segment().begin();
497 chunkEnd = tensor_proto.segment().end();
// The range must be ordered and lie inside the tensor.
500 0 <= chunkBegin && chunkBegin <= chunkEnd && chunkEnd <= tensor->numel(),
505 " with total tensor size ",
507 auto chunkSize = chunkEnd - chunkBegin;
// Per-dtype copy, mirroring the serializer: float/int32/int64/double use
// CopyFromProtoAsIs from their own fields; bool and the 8/16-bit integer
// types use CopyFromProtoWithCast from int32_data.
509 switch (tensor_proto.data_type()) {
510 case TensorProto_DataType_FLOAT:
511 detail::CopyFromProtoAsIs(
513 tensor_proto.float_data(),
514 tensor->template mutable_data<float>() + chunkBegin,
517 case TensorProto_DataType_INT32:
518 detail::CopyFromProtoAsIs(
520 tensor_proto.int32_data(),
521 tensor->template mutable_data<int>() + chunkBegin,
// BYTE is still accepted on read: after a size check, byte_data is copied
// into the tensor as uint8_t.
524 case TensorProto_DataType_BYTE:
529 tensor_proto.byte_data().size(),
530 "Incorrect proto field size.");
531 context->template CopyToCPU<uint8_t>(
533 reinterpret_cast<const uint8_t*
>(tensor_proto.byte_data().data()),
534 tensor->template mutable_data<uint8_t>() + chunkBegin);
// Strings are assigned one element at a time; the proto index is relative
// to the chunk, the tensor index is offset by chunkBegin.
// NOTE(review): `int i` is compared against chunkSize, which is derived
// from int64_t values.
536 case TensorProto_DataType_STRING:
539 string* content = tensor->template mutable_data<string>();
540 for (
int i = 0; i < chunkSize; ++i) {
541 content[i + chunkBegin] = tensor_proto.string_data(i);
545 case TensorProto_DataType_BOOL:
546 detail::CopyFromProtoWithCast(
548 tensor_proto.int32_data(),
549 tensor->template mutable_data<bool>() + chunkBegin,
552 case TensorProto_DataType_UINT8:
553 detail::CopyFromProtoWithCast(
555 tensor_proto.int32_data(),
556 tensor->template mutable_data<uint8_t>() + chunkBegin,
559 case TensorProto_DataType_INT8:
560 detail::CopyFromProtoWithCast(
562 tensor_proto.int32_data(),
563 tensor->template mutable_data<int8_t>() + chunkBegin,
566 case TensorProto_DataType_UINT16:
567 detail::CopyFromProtoWithCast(
569 tensor_proto.int32_data(),
570 tensor->template mutable_data<uint16_t>() + chunkBegin,
573 case TensorProto_DataType_INT16:
574 detail::CopyFromProtoWithCast(
576 tensor_proto.int32_data(),
577 tensor->template mutable_data<int16_t>() + chunkBegin,
580 case TensorProto_DataType_INT64:
581 detail::CopyFromProtoAsIs(
583 tensor_proto.int64_data(),
584 tensor->template mutable_data<int64_t>() + chunkBegin,
// FLOAT16 is read from byte_data when that field is present (with the same
// endianness probe and a size check), otherwise from int32_data.
587 case TensorProto_DataType_FLOAT16:
588 if (tensor_proto.has_byte_data()) {
589 const int kValue = 1;
591 reinterpret_cast<const char*>(&kValue)[0],
593 "Serialization of FLOAT16 on big endian platform " 594 "is not written yet.");
597 tensor_proto.byte_data().size(),
598 "Incorrect proto field size.");
599 context->template CopyToCPU<at::Half>(
601 reinterpret_cast<const at::Half*
>(tensor_proto.byte_data().data()),
602 tensor->template mutable_data<at::Half>() + chunkBegin);
605 detail::CopyFromProtoWithCast(
607 tensor_proto.int32_data(),
608 reinterpret_cast<uint16_t*
>(
609 tensor->template mutable_data<at::Half>()) +
614 case TensorProto_DataType_DOUBLE:
615 detail::CopyFromProtoAsIs(
617 tensor_proto.double_data(),
618 tensor->template mutable_data<double>() + chunkBegin,
// UNDEFINED: elements are handled one at a time. raw_ptr is obtained from
// raw_mutable_data(temp_blob.meta()) and then written through a char*
// offset (the offset expression and the surrounding statements are not
// visible in this view).
// NOTE(review): same `int` loop index concern as the STRING case.
621 case TensorProto_DataType_UNDEFINED: {
623 void* raw_ptr =
nullptr;
624 for (
int i = 0; i < chunkSize; ++i) {
627 raw_ptr = tensor->raw_mutable_data(temp_blob.
meta());
631 static_cast<char*
>(raw_ptr) +
638 context->FinishDeviceComputation();
// Convenience overload returning a Tensor: allocates an empty tensor that
// matches the proto, then fills it with DeserializeToTensor. The return
// statement is not visible in this view.
641 Tensor TensorDeserializer::Deserialize(
const TensorProto& tensor_proto) {
642 auto tensor = EmptyTensorFromProto(tensor_proto);
643 DeserializeToTensor(tensor_proto, &tensor);
// Serializes a protobuf message to a std::string and enforces that
// serialization succeeded.
//   msg            - message to serialize.
//   error_location - optional text for the failure message; when null the
//                    message is just "protobuf::SerializeToString failed".
// Returns the serialized bytes.
651 std::string SerializeAsString_EnforceCheck(
652 const google::protobuf::MessageLite& msg,
653 const char* error_location) {
654 std::string serialize_output;
655 bool result = msg.SerializeToString(&serialize_output);
656 if (!error_location) {
657 CAFFE_ENFORCE(result,
"protobuf::SerializeToString failed");
// The line introducing the other branch is not visible in this view; the
// enforce below is the variant that includes error_location.
659 CAFFE_ENFORCE(result,
660 "protobuf::SerializeToString failed for ", error_location);
662 return serialize_output;
Blob is a general container that hosts a typed pointer.
C10_NODISCARD TensorOptions device(c10::optional< Device > device) const noexcept
Return a copy of TensorOptions with device set to the given one, or cleared if device is nullopt...
TensorSerializer is the serializer for Tensors.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
void Serialize(const void *pointer, TypeMeta typeMeta, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
void DeserializeBlob(const string &content, Blob *result)
Deserializes from a string containing either BlobProto or TensorProto.
StringSerializer is the serializer for String.
A type id is a unique id for a given C++ type.
void Serialize(const void *pointer, TypeMeta typeMeta, const string &name, SerializationAcceptor acceptor) override
Serializes a Blob.
TensorDeserializer is the deserializer for Tensors.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
int GetGPUIDForPointer(const void *ptr)
Gets the GPU id that the current pointer is located at.
const TypeMeta & meta() const noexcept
Returns the meta info of the blob.
T * GetMutable()
Gets a mutable pointer to the stored object.
void SerializeBlob(const Blob &blob, const string &name, BlobSerializerBase::SerializationAcceptor acceptor, int chunk_size)
Serializes the given blob, if possible.
StringDeserializer is the deserializer for Strings.
BlobSerializerBase is an abstract class that serializes a blob to a string.