3 #include <torch/serialize/archive.h> 4 #include <torch/types.h> 21 serialize::OutputArchive& archive,
22 const std::string& key,
23 const std::vector<int64_t>& steps);
27 serialize::InputArchive& archive,
28 const std::string& key,
29 std::vector<int64_t>& steps);
32 template <
typename BufferContainer>
34 serialize::OutputArchive& archive,
35 const std::string& key,
36 const BufferContainer& buffers) {
38 key +
"/size", torch::tensor(static_cast<int64_t>(buffers.size())));
39 for (
size_t index = 0; index < buffers.size(); ++index) {
41 key +
"/" + std::to_string(index), buffers[index],
true);
46 template <
typename BufferContainer>
48 serialize::InputArchive& archive,
49 const std::string& key,
50 BufferContainer& buffers) {
53 archive.read(key +
"/size", size_tensor);
54 const size_t size = size_tensor.item<int64_t>();
55 for (
size_t index = 0; index < size; ++index) {
56 buffers.emplace_back();
58 key +
"/" + std::to_string(index), buffers.back(),
true);
// Expands to `torch::optim::serialize(archive, "<name>", self.<name>)`.
// The member `self.<name>` is passed along with its own stringified name,
// which serves as the archive key. The expansion site must have variables
// named `archive` and `self` in scope.
// NOTE(review): presumably overload resolution on the type of `archive`
// selects between the saving and loading overloads — confirm against the
// `serialize` declarations.
// NOTE(review): identifiers that begin with an underscore followed by an
// uppercase letter are reserved for the implementation in C++. Renaming would
// affect every call site, so the name is intentionally kept.
#define _TORCH_OPTIM_SERIALIZE(name) torch::optim::serialize(archive, #name, self.name)