// NOTE(review): this file appears to be a lossily-extracted chunk -- original
// line numbers are fused into the text, and several declarations, access
// specifiers, and closing braces are missing from view. Comments below state
// only what the visible code establishes; anything else is hedged.
//
// ScriptModuleDeserializer: reads a serialized TorchScript module archive
// (through the caffe2::serialize reader held in `reader_`) and materializes a
// tree of script::Module objects via a caller-supplied lookup callback.
1 #include <google/protobuf/util/json_util.h> 2 #include <google/protobuf/util/type_resolver_util.h> 4 #include <ATen/core/functional.h> 5 #include <c10/util/Exception.h> 6 #include <torch/csrc/jit/import.h> 7 #include <torch/csrc/jit/import_source.h> 8 #include <torch/csrc/jit/ir.h> 9 #include <torch/csrc/jit/operator.h> 10 #include <torch/csrc/jit/pickler.h> 11 #include <torch/csrc/jit/script/script_type_parser.h> 13 #include "caffe2/core/common.h" 14 #include "caffe2/core/types.h" 15 #include "caffe2/proto/caffe2_pb.h" 16 #include "caffe2/proto/torch_pb.h" 17 #include "caffe2/serialize/file_adapter.h" 18 #include "caffe2/serialize/inline_container.h" 19 #include "caffe2/serialize/istream_adapter.h" 21 #include <ATen/ATen.h> 25 #include <unordered_map> 43 class ScriptModuleDeserializer final {
// Construct from a file path on disk.
45 ScriptModuleDeserializer(
const std::string& filename);
// Construct from a caller-owned input stream.
46 ScriptModuleDeserializer(std::istream* is);
// Construct from an arbitrary read adapter; takes ownership.
47 explicit ScriptModuleDeserializer(explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
// deserialize(...) entry point -- only part of its parameter list is visible
// here (a device parameter is referenced by the definition below).
49 script::ModuleLookup module_lookup,
51 script::ExtraFilesMap& extra_files);
// loadTensor(...) parameters: one TensorDef plus a cache of storages keyed by
// record name (the declaration header itself is not visible in this chunk).
55 const torch::TensorDef& tensor_proto,
56 std::unordered_map<std::string, at::Storage>& storageMap);
// Recursively converts one ModuleDef (and its submodules) into script::Module.
58 void convertModule(
const torch::ModuleDef& module_def);
// Populates tensor_table_ from the model's tensor records.
60 void loadTensorTable(torch::ModelDef* model_def);
// Unpickles the "attributes.pkl" record into attribute_table_.
61 void loadAttributeTable();
// Imports shared TorchScript source ("libs") bundled with the model.
62 void loadLibs(torch::ModelDef* model_def);
// Callback used to locate/create modules by qualified name path.
67 script::ModuleLookup moduleLookup_;
// Stack of submodule names maintained while recursing in convertModule().
69 std::vector<std::string> moduleStack_;
// Tensors referenced by index from parameters and pickled attributes.
71 std::vector<at::Tensor> tensor_table_;
// IValue attributes referenced by id (only loaded when proto_version >= 2).
72 std::vector<IValue> attribute_table_;
// Out-of-line constructor definitions. NOTE(review): the extraction truncated
// the first two constructors -- the file-path constructor's closing brace and
// the istream constructor's init-list/body are not visible here.
//
// Opens the archive directly from a file path.
75 ScriptModuleDeserializer::ScriptModuleDeserializer(
const std::string& filename)
76 : reader_(filename.c_str()) {
// Opens the archive from a caller-owned std::istream.
80 ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
// Takes ownership of an arbitrary ReadAdapterInterface.
83 ScriptModuleDeserializer::ScriptModuleDeserializer(
84 std::unique_ptr<ReadAdapterInterface> rai)
85 : reader_(
std::move(rai)) {}
// Top-level deserialization driver:
//  1. read the "model.json" record and transcode it from protobuf-JSON to
//     binary, then parse it into a torch::ModelDef;
//  2. read back each caller-requested extra file from an "extra/<name>" record;
//  3. load the tensor table, then (for proto_version >= 2) the pickled
//     attribute table and the TorchScript libs;
//  4. recursively convert the main module via convertModule().
// NOTE(review): several locals (data_ptr/data_size, meta_ptr/meta_size), the
// device parameter, and the error-raising macro lines are missing from this
// extracted view -- do not edit this block mechanically.
87 void ScriptModuleDeserializer::deserialize(
88 script::ModuleLookup module_lookup,
90 script::ExtraFilesMap& extra_files) {
91 torch::ModelDef model_def;
// Raw bytes of the JSON-encoded model manifest.
94 std::tie(data_ptr, data_size) = reader_.getRecord(
"model.json");
97 std::string url_prefix =
"type.googleapis.com";
// Type resolver lets the JSON transcoder find the ModelDef descriptor by URL.
98 std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
99 ::google::protobuf::util::NewTypeResolverForDescriptorPool(
100 url_prefix, model_def.GetDescriptor()->file()->pool()));
101 std::string json_string = std::string(
102 static_cast<char*>(data_ptr.get()),
103 static_cast<char*>(data_ptr.get()) + data_size);
104 std::string binary_string;
// JSON -> binary protobuf; the status is checked just below.
105 auto convert_result = ::google::protobuf::util::JsonToBinaryString(
107 url_prefix +
"/" + model_def.GetDescriptor()->full_name(),
110 if (!convert_result.ok()) {
111 std::stringstream ss;
112 ss << convert_result;
// (the error is raised here in the full source; that macro line is not
// visible in this chunk)
116 model_def.ParseFromString(binary_string),
117 "JSON transcoder produced invalid protobuf output.");
118 moduleLookup_ = module_lookup;
121 const auto& module_def = model_def.main_module();
// Fill in each caller-requested extra file from its "extra/<key>" record.
124 for (
const auto& kv : extra_files) {
125 const std::string& key =
"extra/" + kv.first;
128 std::tie(meta_ptr, meta_size) = reader_.getRecord(key);
129 extra_files[kv.first] =
130 std::string(static_cast<char*>(meta_ptr.get()), meta_size);
// Tensors must be loaded first: attributes and libs reference them by index.
133 loadTensorTable(&model_def);
134 if (model_def.proto_version() >= 2) {
135 loadAttributeTable();
136 loadLibs(&model_def);
141 convertModule(module_def);
144 void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
145 std::unordered_map<std::string, at::Storage> storageMap;
146 for (
const torch::TensorDef& tensor : model_def->tensors()) {
147 tensor_table_.emplace_back(loadTensor(tensor, storageMap));
151 void ScriptModuleDeserializer::loadAttributeTable() {
153 size_t attributes_size;
154 std::tie(attributes_ptr, attributes_size) =
155 reader_.getRecord(
"attributes.pkl");
156 Unpickler unpickler(attributes_ptr.get(), attributes_size, &tensor_table_);
157 attribute_table_ = unpickler.parse_ivalue_list();
160 void ScriptModuleDeserializer::loadLibs(torch::ModelDef* model_def) {
161 const auto lib_def = model_def->libs();
162 if (lib_def.has_torchscript_arena()) {
165 std::tie(data, size) = reader_.getRecord(lib_def.torchscript_arena().key());
166 std::string data_str(static_cast<const char*>(data.get()), size);
167 script::import_libs(data_str, tensor_table_);
// Builds one at::Tensor from its TensorDef: decodes dims/strides/dtype, then
// views the (possibly shared) storage record at the given offset. CPU storages
// are read directly; CUDA storages are staged through a CPU tensor and copied
// to the target device. Loaded storages are cached in `storageMap` by record
// key so multiple tensors can alias one storage.
// NOTE(review): several locals (`device`, `storage_ptr`, `cpu_storage`,
// `cpu_tensor`, `cuda_storage`, `result`), the storage-construction call, the
// error macros, and various closing braces are missing from this extracted
// view -- do not edit this block mechanically.
171 at::Tensor ScriptModuleDeserializer::loadTensor(
172 const torch::TensorDef& tensor_proto,
173 std::unordered_map<std::string, at::Storage>& storageMap) {
174 std::vector<int64_t> dims(
175 tensor_proto.dims().begin(), tensor_proto.dims().end());
176 std::vector<int64_t> strides(
177 tensor_proto.strides().begin(), tensor_proto.strides().end());
// Map the caffe2 proto dtype onto an ATen scalar type.
178 auto type = at::typeMetaToScalarType(
179 caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
180 const std::string& record_key = tensor_proto.data().key();
// Every serialized tensor must carry a non-empty device string.
181 AT_ASSERT(tensor_proto.has_device() && !tensor_proto.device().empty());
// A deserializer-level device override (device_) wins over the serialized one.
183 if (device_.has_value()) {
185 device = device_.value();
// First use of this storage record: read it and cache it.
188 auto storage_it = storageMap.find(record_key);
189 if (storage_it == storageMap.end()) {
191 uint64_t record_size;
192 std::tie(storage_ptr, record_size) = reader_.getRecord(record_key);
// Element count = record byte size / element size of the dtype.
194 at::CPU(type).typeMeta(),
195 record_size / at::CPU(type).typeMeta().itemsize(),
196 std::move(storage_ptr),
199 if (device.type() == at::DeviceType::CPU) {
201 storageMap.insert(std::make_pair(record_key, cpu_storage)).first;
202 }
else if (device.type() == at::DeviceType::CUDA) {
// Stage through a zero-sized CPU tensor, then copy to the CUDA device and
// cache the resulting CUDA storage.
204 at::empty({0}, at::CPU(type).options()).set_(cpu_storage);
206 cpu_tensor.to(device, cpu_tensor.scalar_type()).storage();
208 storageMap.insert(std::make_pair(record_key, cuda_storage)).first;
// Any other device type is rejected (error macro not visible in this chunk).
211 "supported devices include CPU and CUDA, however got ",
212 at::DeviceTypeName(device.type(),
false));
// Cached storage must agree with this tensor's requested device/index.
215 if (storage_it->second.device().type() != device.type() ||
216 (device.has_index() &&
217 storage_it->second.device().
index() != device.index())) {
218 std::stringstream oss;
// NOTE(review): this message likely concatenates without a space before
// "but" -- confirm against the full source before changing the literal.
219 oss <<
"storage previously was specified with device " 220 << storage_it->second.device() <<
"but now is specified with device " 221 << device << std::endl;
// View the cached storage at (offset, dims, strides) on the right backend.
226 if (device.type() == at::DeviceType::CPU) {
228 at::empty({0}, at::CPU(type).options())
229 .set_(storage_it->second, tensor_proto.offset(), dims, strides);
230 }
else if (device.type() == at::DeviceType::CUDA) {
232 at::empty({0}, at::CUDA(type).options())
233 .set_(storage_it->second, tensor_proto.offset(), dims, strides);
235 AT_ASSERT(result.defined());
// Wrap as an autograd variable, preserving the serialized requires_grad flag.
237 result = autograd::make_variable(result, tensor_proto.requires_grad());
// Recursively converts a ModuleDef into the live script::Module located by
// moduleLookup_(moduleStack_): first recurses into submodules (pushing each
// name onto moduleStack_), then registers parameters/buffers from
// tensor_table_, then attributes from attribute_table_, and finally compiles
// the module's TorchScript methods if a source arena is present.
// NOTE(review): some closing braces, a register_attribute argument line, and
// the `data`/`size` locals are missing from this extracted view.
242 void ScriptModuleDeserializer::convertModule(
243 const torch::ModuleDef& module_def) {
244 std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
245 module->set_optimized(module_def.optimize());
// Depth-first: submodules are converted before this module's own methods.
246 for (
int i = 0; i < module_def.submodules_size(); ++i) {
247 const torch::ModuleDef& sub_def = module_def.submodules(i);
248 moduleStack_.emplace_back(sub_def.name());
249 convertModule(sub_def);
250 moduleStack_.pop_back();
// Parameters and buffers both live in the parameters list; is_buffer()
// distinguishes them.
252 for (
int i = 0; i < module_def.parameters_size(); ++i) {
253 const torch::ParameterDef& param_def = module_def.parameters(i);
254 at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
255 if (param_def.is_buffer()) {
256 module->register_buffer(param_def.name(), tensor);
258 module->register_parameter(param_def.name(), tensor,
false);
// Attributes come from the pickled attribute table, keyed by id.
261 for (
int i = 0; i < module_def.attributes_size(); ++i) {
262 const torch::AttributeDef& attr_def = module_def.attributes(i);
// Skip names already registered as buffers (body of this branch not visible).
263 if (module->find_buffer(attr_def.name())) {
268 module->register_attribute(
270 script::parseType(attr_def.type()),
271 attribute_table_.at(attr_def.id())
// Compile the module's methods from its TorchScript source record, if any.
274 if (module_def.has_torchscript_arena()) {
277 std::tie(data, size) =
278 reader_.getRecord(module_def.torchscript_arena().key());
279 std::string data_str(static_cast<const char*>(data.get()), size);
280 script::import_methods(module, data_str, tensor_table_);
// Deserializes a TorchScript archive from an input stream into modules
// obtained through `module_lookup`. NOTE(review): the `std::istream& in` and
// device parameter lines are missing from this extracted view (the body
// references both `in` and `device`).
286 void import_ir_module(
287 script::ModuleLookup module_lookup,
290 script::ExtraFilesMap& extra_files) {
291 ScriptModuleDeserializer deserializer(&in);
292 deserializer.deserialize(module_lookup, device, extra_files);
// Deserializes a TorchScript archive from a file path into modules obtained
// through `module_lookup`. NOTE(review): the device parameter line is missing
// from this extracted view (the body references `device`).
295 void import_ir_module(
296 script::ModuleLookup module_lookup,
297 const std::string& filename,
299 script::ExtraFilesMap& extra_files) {
300 ScriptModuleDeserializer deserializer(filename);
301 deserializer.deserialize(module_lookup, device, extra_files);
// Deserializes a TorchScript archive from a caller-supplied read adapter
// (ownership transferred) into modules obtained through `module_lookup`.
// NOTE(review): the device parameter line is missing from this extracted view.
304 void import_ir_module(
305 script::ModuleLookup module_lookup,
306 std::unique_ptr<ReadAdapterInterface> rai,
308 script::ExtraFilesMap& extra_files) {
309 ScriptModuleDeserializer deserializer(std::move(rai));
310 deserializer.deserialize(module_lookup, device, extra_files);
// Stream overload of load(): wraps the istream in an IStreamAdapter and
// delegates to the adapter-based overload. NOTE(review): the `std::istream&
// in` / device parameter lines and the final `return module;` are missing
// from this extracted view.
313 std::shared_ptr<script::Module> load(
316 script::ExtraFilesMap& extra_files) {
317 std::unique_ptr<IStreamAdapter> rai =
318 caffe2::make_unique<IStreamAdapter>(&in);
319 auto module = load(std::move(rai), device, extra_files);
// File-path overload of load(): wraps the path in a FileAdapter and delegates
// to the adapter-based overload. NOTE(review): the device parameter line and
// the final `return module;` are missing from this extracted view.
323 std::shared_ptr<script::Module> load(
324 const std::string& filename,
326 script::ExtraFilesMap& extra_files) {
327 std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
328 auto module = load(std::move(rai), device, extra_files);
// Core load(): creates an empty root script::Module and deserializes into it.
// `module_lookup` walks the submodule path named by `qualified_name` starting
// at the root, lazily creating any submodule that does not yet exist, and
// returns the module at the end of the path.
// NOTE(review): the device parameter line, some closing braces, and the final
// `return module;` are missing from this extracted view.
332 std::shared_ptr<script::Module> load(
333 std::unique_ptr<ReadAdapterInterface> rai,
335 script::ExtraFilesMap& extra_files) {
336 auto module = std::make_shared<script::Module>();
338 auto module_lookup = [&](
const std::vector<std::string>& qualified_name) {
339 std::shared_ptr<script::Module> curr = module;
340 for (
const auto& name : qualified_name) {
// Create missing intermediate submodules on demand.
341 if (curr->find_module(name) ==
nullptr) {
342 curr->register_module(name, std::make_shared<script::Module>());
344 curr = curr->get_module(name);
349 ScriptModuleDeserializer deserializer(std::move(rai));
350 deserializer.deserialize(module_lookup, device, extra_files);
Represents a compute device on which a tensor is located.
DeviceIndex index() const noexcept
Returns the optional index.