// Caffe2 - C++ API
// A deep learning, cross-platform ML framework
// serialize.h
1 #pragma once
2 
3 #include <torch/serialize/archive.h>
4 #include <torch/types.h>
5 
6 #include <cstddef>
7 #include <cstdint>
8 #include <deque>
9 #include <string>
10 #include <vector>
11 
12 namespace torch {
13 namespace optim {
14 
15 // Note: These functions are all called `serialize()` so they can be called
16 // inside a template where the archive type is a template type and can thus be
17 // passed such that the appropriate overload is selected.
18 
/// Writes `steps` (presumably per-parameter optimizer step counts — confirm
/// against callers) into `archive` under `key`. Counterpart of the
/// InputArchive overload below.
void serialize(
    serialize::OutputArchive& archive,
    const std::string& key,
    const std::vector<int64_t>& steps);
24 
/// Reads a vector of `int64_t` values previously written by the
/// OutputArchive overload above from `archive` under `key` into `steps`.
void serialize(
    serialize::InputArchive& archive,
    const std::string& key,
    std::vector<int64_t>& steps);
30 
32 template <typename BufferContainer>
33 void serialize(
34  serialize::OutputArchive& archive,
35  const std::string& key,
36  const BufferContainer& buffers) {
37  archive.write(
38  key + "/size", torch::tensor(static_cast<int64_t>(buffers.size())));
39  for (size_t index = 0; index < buffers.size(); ++index) {
40  archive.write(
41  key + "/" + std::to_string(index), buffers[index], /*is_buffer=*/true);
42  }
43 }
44 
46 template <typename BufferContainer>
47 void serialize(
48  serialize::InputArchive& archive,
49  const std::string& key,
50  BufferContainer& buffers) {
51  buffers.clear();
52  torch::Tensor size_tensor;
53  archive.read(key + "/size", size_tensor);
54  const size_t size = size_tensor.item<int64_t>();
55  for (size_t index = 0; index < size; ++index) {
56  buffers.emplace_back();
57  archive.read(
58  key + "/" + std::to_string(index), buffers.back(), /*is_buffer=*/true);
59  }
60 }
61 
/// Convenience macro: serializes member `self.name` under the string key
/// "name" via the matching `torch::optim::serialize()` overload. Relies on
/// variables named `archive` and `self` being in scope at the expansion site.
#define _TORCH_OPTIM_SERIALIZE(name) \
  torch::optim::serialize(archive, #name, self.name)
64 
65 } // namespace optim
66 } // namespace torch
// Definition: jit_type.h:17