Caffe2 - C++ API
A deep learning, cross-platform ML framework
import.cpp
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/import.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/script/script_type_parser.h>

#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/torch_pb.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"

#include <ATen/ATen.h>

#include <fstream>
#include <sstream> // for the std::stringstream used in error reporting below
#include <string>
#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {

using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::ReadAdapterInterface;

namespace {

// This is a deserializer class which loads script modules from .pt files.
// The content of the file is written with PyTorchStreamWriter; for details,
// please check caffe2/serialize/inline_container.h. All the records except
// the last one are tensor data, and the last record is a serialized
// ModelProto, defined in caffe2/proto/torch.proto. ModelProto contains all
// the metadata of the model, and it is serialized as JSON.
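//
// For orientation, the record layout of such an archive looks roughly like
// the sketch below. The tensor record keys are illustrative; only
// "model.json", "attributes.pkl", and the "extra/" prefix are names this
// file actually relies on:
//
//   tensors/0        raw storage bytes, one record per serialized storage
//   tensors/1        ...
//   attributes.pkl   pickled attribute table (only for proto_version >= 2)
//   extra/foo        caller-requested extra file records
//   model.json       JSON-serialized ModelProto holding all metadata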
class ScriptModuleDeserializer final {
 public:
  ScriptModuleDeserializer(const std::string& filename);
  ScriptModuleDeserializer(std::istream* is);
  explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
  void deserialize(
      script::ModuleLookup module_lookup,
      c10::optional<at::Device> device,
      script::ExtraFilesMap& extra_files);

 private:
  at::Tensor loadTensor(
      const torch::TensorDef& tensor_proto,
      std::unordered_map<std::string, at::Storage>& storageMap);

  void convertModule(const torch::ModuleDef& module_def);

  void loadTensorTable(torch::ModelDef* model_def);
  void loadAttributeTable();
  void loadLibs(torch::ModelDef* model_def);

  caffe2::serialize::PyTorchStreamReader reader_;
  // This is a hack to make sure the script module created in C++ is the
  // same as the one created in Python.
  script::ModuleLookup moduleLookup_;
  c10::optional<at::Device> device_;
  std::vector<std::string> moduleStack_;

  std::vector<at::Tensor> tensor_table_;
  std::vector<IValue> attribute_table_;
};
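
// Note: script::ModuleLookup maps a path of submodule names (possibly empty,
// for the root) to the corresponding script::Module, creating intermediate
// modules on demand; see the module_lookup lambda built in load() at the
// bottom of this file.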

ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
    : reader_(filename.c_str()) {
  // TODO: add appropriate support for mmap; right now we still use the
  // stream reader.
}

ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
    : reader_(is) {}

ScriptModuleDeserializer::ScriptModuleDeserializer(
    std::unique_ptr<ReadAdapterInterface> rai)
    : reader_(std::move(rai)) {}

void ScriptModuleDeserializer::deserialize(
    script::ModuleLookup module_lookup,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  torch::ModelDef model_def;
  at::DataPtr data_ptr;
  size_t data_size;
  std::tie(data_ptr, data_size) = reader_.getRecord("model.json");
  // NB: we cannot use JsonStringToMessage, since fbcode's protobuf is too
  // old; instead we transcode the JSON to the binary wire format and parse
  // that. The type URL prefix below matches the one JsonStringToMessage
  // uses, to stay consistent with it.
  std::string url_prefix = "type.googleapis.com";
  std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
      ::google::protobuf::util::NewTypeResolverForDescriptorPool(
          url_prefix, model_def.GetDescriptor()->file()->pool()));
  std::string json_string = std::string(
      static_cast<char*>(data_ptr.get()),
      static_cast<char*>(data_ptr.get()) + data_size);
  std::string binary_string;
  auto convert_result = ::google::protobuf::util::JsonToBinaryString(
      resolver.get(),
      url_prefix + "/" + model_def.GetDescriptor()->full_name(),
      json_string,
      &binary_string);
  if (!convert_result.ok()) {
    std::stringstream ss;
    ss << convert_result;
    AT_ERROR(ss.str());
  }
  AT_ASSERTM(
      model_def.ParseFromString(binary_string),
      "JSON transcoder produced invalid protobuf output.");
  moduleLookup_ = module_lookup;
  device_ = device;

  const auto& module_def = model_def.main_module();

  // Load extra files.
  for (const auto& kv : extra_files) {
    const std::string& key = "extra/" + kv.first;
    at::DataPtr meta_ptr;
    size_t meta_size;
    std::tie(meta_ptr, meta_size) = reader_.getRecord(key);
    extra_files[kv.first] =
        std::string(static_cast<char*>(meta_ptr.get()), meta_size);
  }

  loadTensorTable(&model_def);
  if (model_def.proto_version() >= 2) {
    loadAttributeTable();
    loadLibs(&model_def);
  }

  // TODO: this can be simplified when C++/Python interop lands, and the
  // submodules will be created the same way in either C++ or Python.
  convertModule(module_def);
}
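
// Note: the transcode above is the moral equivalent of
//   google::protobuf::util::JsonStringToMessage(json_string, &model_def);
// which is unavailable on fbcode's older protobuf; resolving types against
// the descriptor pool and round-tripping through the binary wire format
// yields the same parsed message.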

void ScriptModuleDeserializer::loadTensorTable(torch::ModelDef* model_def) {
  std::unordered_map<std::string, at::Storage> storageMap;
  for (const torch::TensorDef& tensor : model_def->tensors()) {
    tensor_table_.emplace_back(loadTensor(tensor, storageMap));
  }
}

void ScriptModuleDeserializer::loadAttributeTable() {
  at::DataPtr attributes_ptr;
  size_t attributes_size;
  std::tie(attributes_ptr, attributes_size) =
      reader_.getRecord("attributes.pkl");
  Unpickler unpickler(attributes_ptr.get(), attributes_size, &tensor_table_);
  attribute_table_ = unpickler.parse_ivalue_list();
}
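
// Note: the "attributes.pkl" record is decoded by the JIT's own Unpickler
// into a flat list of IValues; convertModule() later indexes into this list
// via AttributeDef::id. Tensors inside the pickle are resolved against
// tensor_table_, which is why the tensor table is loaded first in
// deserialize() above.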

void ScriptModuleDeserializer::loadLibs(torch::ModelDef* model_def) {
  const auto lib_def = model_def->libs();
  if (lib_def.has_torchscript_arena()) {
    at::DataPtr data;
    size_t size;
    std::tie(data, size) = reader_.getRecord(lib_def.torchscript_arena().key());
    std::string data_str(static_cast<const char*>(data.get()), size);
    script::import_libs(data_str, tensor_table_);
  }
}

at::Tensor ScriptModuleDeserializer::loadTensor(
    const torch::TensorDef& tensor_proto,
    std::unordered_map<std::string, at::Storage>& storageMap) {
  std::vector<int64_t> dims(
      tensor_proto.dims().begin(), tensor_proto.dims().end());
  std::vector<int64_t> strides(
      tensor_proto.strides().begin(), tensor_proto.strides().end());
  auto type = at::typeMetaToScalarType(
      caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
  const std::string& record_key = tensor_proto.data().key();
  AT_ASSERT(tensor_proto.has_device() && !tensor_proto.device().empty());
  at::Device device(tensor_proto.device());
  if (device_.has_value()) {
    // override the device, if the user provided a map_location
    device = device_.value();
  }

  auto storage_it = storageMap.find(record_key);
  if (storage_it == storageMap.end()) {
    at::DataPtr storage_ptr;
    uint64_t record_size;
    std::tie(storage_ptr, record_size) = reader_.getRecord(record_key);
    auto cpu_storage = at::Storage(
        at::CPU(type).typeMeta(),
        record_size / at::CPU(type).typeMeta().itemsize(),
        std::move(storage_ptr),
        /*allocator=*/nullptr,
        /*resizable=*/false); // NB: we did not set any allocator for the tensor
    if (device.type() == at::DeviceType::CPU) {
      storage_it =
          storageMap.insert(std::make_pair(record_key, cpu_storage)).first;
    } else if (device.type() == at::DeviceType::CUDA) {
      at::Tensor cpu_tensor =
          at::empty({0}, at::CPU(type).options()).set_(cpu_storage);
      at::Storage cuda_storage =
          cpu_tensor.to(device, cpu_tensor.scalar_type()).storage();
      storage_it =
          storageMap.insert(std::make_pair(record_key, cuda_storage)).first;
    } else {
      AT_ERROR(
          "supported devices include CPU and CUDA, however got ",
          at::DeviceTypeName(device.type(), false));
    }
  }
  if (storage_it->second.device().type() != device.type() ||
      (device.has_index() &&
       storage_it->second.device().index() != device.index())) {
    std::stringstream oss;
    oss << "storage previously was specified with device "
        << storage_it->second.device() << " but now is specified with device "
        << device << std::endl;
    AT_ERROR(oss.str());
  }

  at::Tensor result;
  if (device.type() == at::DeviceType::CPU) {
    result =
        at::empty({0}, at::CPU(type).options())
            .set_(storage_it->second, tensor_proto.offset(), dims, strides);
  } else if (device.type() == at::DeviceType::CUDA) {
    result =
        at::empty({0}, at::CUDA(type).options())
            .set_(storage_it->second, tensor_proto.offset(), dims, strides);
  }
  AT_ASSERT(result.defined());

  result = autograd::make_variable(result, tensor_proto.requires_grad());

  return result;
}
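
// Note: storageMap is keyed by the record name, so tensors that shared a
// storage when the model was saved come back sharing a single Storage here:
// a second TensorDef referencing the same record (with a different offset or
// strides, say) hits the map and triggers no second read from the archive.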

void ScriptModuleDeserializer::convertModule(
    const torch::ModuleDef& module_def) {
  std::shared_ptr<script::Module> module = moduleLookup_(moduleStack_);
  module->set_optimized(module_def.optimize());
  for (int i = 0; i < module_def.submodules_size(); ++i) {
    const torch::ModuleDef& sub_def = module_def.submodules(i);
    moduleStack_.emplace_back(sub_def.name());
    convertModule(sub_def);
    moduleStack_.pop_back();
  }
  for (int i = 0; i < module_def.parameters_size(); ++i) {
    const torch::ParameterDef& param_def = module_def.parameters(i);
    at::Tensor tensor = tensor_table_.at(param_def.tensor_id());
    if (param_def.is_buffer()) {
      module->register_buffer(param_def.name(), tensor);
    } else {
      module->register_parameter(param_def.name(), tensor, /*is_buffer=*/false);
    }
  }
  for (int i = 0; i < module_def.attributes_size(); ++i) {
    const torch::AttributeDef& attr_def = module_def.attributes(i);
    if (module->find_buffer(attr_def.name())) {
      // TODO: handle this above so this can be removed
      continue;
    }

    module->register_attribute(
        attr_def.name(),
        script::parseType(attr_def.type()),
        attribute_table_.at(attr_def.id()));
  }
  if (module_def.has_torchscript_arena()) {
    at::DataPtr data;
    size_t size;
    std::tie(data, size) =
        reader_.getRecord(module_def.torchscript_arena().key());
    std::string data_str(static_cast<const char*>(data.get()), size);
    script::import_methods(module, data_str, tensor_table_);
  }
}
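
// Note: moduleStack_ holds the qualified path of the module currently being
// converted. For a root module with a submodule named "conv1", moduleLookup_
// is called with {} for the root and with {"conv1"} when recursing into the
// child ("conv1" is an illustrative name, not one this file assumes).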

} // namespace

void import_ir_module(
    script::ModuleLookup module_lookup,
    std::istream& in,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  ScriptModuleDeserializer deserializer(&in);
  deserializer.deserialize(module_lookup, device, extra_files);
}

void import_ir_module(
    script::ModuleLookup module_lookup,
    const std::string& filename,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  ScriptModuleDeserializer deserializer(filename);
  deserializer.deserialize(module_lookup, device, extra_files);
}

void import_ir_module(
    script::ModuleLookup module_lookup,
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  ScriptModuleDeserializer deserializer(std::move(rai));
  deserializer.deserialize(module_lookup, device, extra_files);
}

std::shared_ptr<script::Module> load(
    std::istream& in,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  std::unique_ptr<IStreamAdapter> rai =
      caffe2::make_unique<IStreamAdapter>(&in);
  auto module = load(std::move(rai), device, extra_files);
  return module;
}

std::shared_ptr<script::Module> load(
    const std::string& filename,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
  auto module = load(std::move(rai), device, extra_files);
  return module;
}

std::shared_ptr<script::Module> load(
    std::unique_ptr<ReadAdapterInterface> rai,
    c10::optional<at::Device> device,
    script::ExtraFilesMap& extra_files) {
  auto module = std::make_shared<script::Module>();

  auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
    std::shared_ptr<script::Module> curr = module;
    for (const auto& name : qualified_name) {
      if (curr->find_module(name) == nullptr) {
        curr->register_module(name, std::make_shared<script::Module>());
      }
      curr = curr->get_module(name);
    }
    return curr;
  };

  ScriptModuleDeserializer deserializer(std::move(rai));
  deserializer.deserialize(module_lookup, device, extra_files);

  return module;
}

} // namespace jit
} // namespace torch
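
For context, here is a minimal caller-side sketch of the load() entry point defined above. The file name, device choice, and extra-file key are illustrative assumptions for this sketch, not values the API prescribes:

#include <torch/csrc/jit/import.h>

void example_load() {
  // Asking for "metadata" makes the deserializer read the archive record
  // "extra/metadata" and store its bytes into the map value ("metadata" is
  // a hypothetical key chosen for this sketch).
  torch::jit::script::ExtraFilesMap extra_files{{"metadata", ""}};
  // Deserialize "model.pt", mapping every tensor onto CUDA device 0
  // regardless of the device recorded at save time.
  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(
      "model.pt", at::Device(at::DeviceType::CUDA, 0), extra_files);
  (void)module; // the module's methods and parameters are usable from here
}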