1 #include <torch/serialize/input-archive.h> 3 #include <torch/types.h> 4 #include <torch/utils.h> 6 #include <torch/csrc/jit/import.h> 7 #include <torch/csrc/jit/script/module.h> 9 #include <c10/util/Exception.h> 20 : module_(
std::make_shared<jit::script::Module>()) {}
23 const std::string& key,
26 auto param = module_->find_parameter(key);
27 auto buffer = module_->find_buffer(key);
29 param !=
nullptr || buffer !=
nullptr,
30 "No such serialized tensor '",
34 auto read_param = is_buffer ? buffer : param;
35 auto read_tensor = read_param->slot()->toTensor();
37 bool(buffer) == is_buffer,
38 "Expected deserialized tensor for key '", key,
39 "' to ", is_buffer ?
"not " :
"",
"be a buffer, but it was not");
41 if (tensor.defined()) {
43 if (tensor.
device() != read_tensor.device()) {
46 tensor.set_(read_tensor);
49 tensor = std::move(read_tensor);
54 if (
auto* named_module = module_->find_module(key)) {
55 AT_ASSERT(named_module->module !=
nullptr);
56 archive.module_ = std::move(named_module->module);
58 AT_ERROR(
"No such serialized submodule: '", key,
"'");
64 module_ = torch::jit::load(filename, std::move(device));
69 module_ = torch::jit::load(stream, std::move(device));
Device device() const
Returns a Tensor's device.
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...