Caffe2 - C++ API
A deep learning, cross platform ML framework
input-archive.cpp
1 #include <torch/serialize/input-archive.h>
2 
3 #include <torch/types.h>
4 #include <torch/utils.h>
5 
6 #include <torch/csrc/jit/import.h>
7 #include <torch/csrc/jit/script/module.h>
8 
9 #include <c10/util/Exception.h>
10 
11 #include <istream>
12 #include <memory>
13 #include <string>
14 #include <utility>
15 
16 namespace torch {
17 namespace serialize {
18 
20  : module_(std::make_shared<jit::script::Module>()) {}
21 
23  const std::string& key,
24  Tensor& tensor,
25  bool is_buffer) {
26  auto param = module_->find_parameter(key);
27  auto buffer = module_->find_buffer(key);
28  AT_CHECK(
29  param != nullptr || buffer != nullptr,
30  "No such serialized tensor '",
31  key,
32  "'");
33  // clang-format off
34  auto read_param = is_buffer ? buffer : param;
35  auto read_tensor = read_param->slot()->toTensor();
36  AT_CHECK(
37  bool(buffer) == is_buffer,
38  "Expected deserialized tensor for key '", key,
39  "' to ", is_buffer ? "not " : "", "be a buffer, but it was not");
40  // clang-format on
41  if (tensor.defined()) {
42  torch::NoGradGuard guard;
43  if (tensor.device() != read_tensor.device()) {
44  tensor.set_data(autograd::Variable(read_tensor).data());
45  } else {
46  tensor.set_(read_tensor);
47  }
48  } else {
49  tensor = std::move(read_tensor);
50  }
51 }
52 
53 void InputArchive::read(const std::string& key, InputArchive& archive) {
54  if (auto* named_module = module_->find_module(key)) {
55  AT_ASSERT(named_module->module != nullptr);
56  archive.module_ = std::move(named_module->module);
57  } else {
58  AT_ERROR("No such serialized submodule: '", key, "'");
59  }
60 }
61 
62 void InputArchive::load_from(const std::string& filename,
63  c10::optional<torch::Device> device /*= c10::nullopt*/) {
64  module_ = torch::jit::load(filename, std::move(device));
65 }
66 
67 void InputArchive::load_from(std::istream& stream,
68  c10::optional<torch::Device> device /*= c10::nullopt*/) {
69  module_ = torch::jit::load(stream, std::move(device));
70 }
71 } // namespace serialize
72 } // namespace torch
void read(const std::string &key, Tensor &tensor, bool is_buffer=false)
Reads a tensor associated with a given key.
Device device() const
Returns a Tensor's device.
void load_from(const std::string &filename, c10::optional< torch::Device > device=c10::nullopt)
Loads the InputArchive from a serialized representation stored in the file at filename.
Variable: A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
InputArchive()
Default-constructs the InputArchive.
A recursive representation of tensors that can be deserialized from a file or stream.
Definition: input-archive.h:32