Caffe2 - C++ API
A deep learning, cross-platform ML framework
export.cpp
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/type_resolver_util.h>

#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/onnx/onnx.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pickler.h>

#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
#include <caffe2/proto/torch_pb.h>
#include <caffe2/serialize/inline_container.h>
#include <onnx/onnx_pb.h>

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

#include <fstream>
#include <memory>
#include <sstream>
#include <stack>
#include <string>
#include <vector>

namespace torch {
namespace jit {

namespace {
namespace onnx_torch = ::torch::onnx;
namespace onnx = ::ONNX_NAMESPACE;

class ScriptModuleSerializer;

std::string getNodeStackTraceString(const Node* n) {
  std::stringstream ss;
  if (n->getSourceLocation()) {
    n->getSourceLocation()->highlight(ss);
  } else {
    ss << "<unknown location>";
  }
  return ss.str();
}

void validateBlock(
    Block* b,
    onnx_torch::OperatorExportTypes operator_export_type) {
  for (auto node : b->nodes()) {
    for (Block* sub_block : node->blocks()) {
      validateBlock(sub_block, operator_export_type);
    }
    // Macro'ed so we get a marginally better line number on failed export
#define FAIL_EXPORT(name)                          \
  throw std::runtime_error(                        \
      std::string("ONNX export failed: ") + name + \
      "\n\nGraph we tried to export:\n" + b->owningGraph()->toString());
    if (node->kind() == prim::PythonOp) {
      auto py_node = static_cast<PythonOp*>(node);
      FAIL_EXPORT(
          "Couldn't export Python operator " + py_node->name() +
          "\n\nDefined at:\n" + getNodeStackTraceString(node))
    } else {
      // Special error messages for certain types of operators
      if (node->kind() == aten::expand) {
        if (operator_export_type ==
            onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
          WithInsertPoint guard(node);
          auto* new_node =
              b->owningGraph()->insertNode(b->owningGraph()->create(
                  Symbol(::c10::onnx::ATen),
                  node->inputs(),
                  node->outputs().size()));
          for (size_t i = 0; i < node->outputs().size(); ++i) {
            node->output(i)->replaceAllUsesWith(new_node->output(i));
          }
          new_node->s_(Symbol::fromQualString("attr::operator"), "expand");
        }
      }
      if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
        FAIL_EXPORT(
            "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
            getNodeStackTraceString(node));
      }
      bool is_aten_enabled = operator_export_type ==
              onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
          operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
      if (!node->kind().is_onnx() && !is_aten_enabled && !node->mustBeNone()) {
        FAIL_EXPORT(
            "Couldn't export operator " + node->kind().toDisplayString() +
            "\n\nDefined at:\n" + getNodeStackTraceString(node));
      }
    }
#undef FAIL_EXPORT
  }
}

void validateGraph(
    const std::shared_ptr<Graph>& graph,
    onnx_torch::OperatorExportTypes operator_export_type) {
  validateBlock(graph->block(), operator_export_type);
  EliminateDeadCode(graph->block());
}
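
// Note on the ONNX_ATEN_FALLBACK rewrite above (an illustrative, schematic
// sketch of the IR; the value names are hypothetical): an aten::expand that
// cannot be exported directly is replaced by a generic ATen node in the onnx
// domain carrying the original operator name as a string attribute, roughly:
//
//   %y = onnx::ATen[operator="expand"](%x, %size, %implicit)
//
// Only the node kind and the "operator" attribute are taken from the code
// above; everything else here is for illustration.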

class EncoderBase {
 public:
  EncoderBase(
      onnx_torch::OperatorExportTypes operator_export_type,
      bool strip_doc);

  onnx::ModelProto get_model_proto() {
    return model_proto_;
  }

 protected:
  void EncodeGraph(
      onnx::GraphProto* graph_proto,
      const std::shared_ptr<Graph>& graph,
      const std::vector<at::Tensor>& initializers = {});

  void EncodeBlock(
      onnx::GraphProto* graph_proto,
      const Block* block,
      const std::vector<at::Tensor>& initializers = {});

  virtual void EncodeTensor(
      onnx::TensorProto* tensor_proto,
      const at::Tensor& tensor,
      const c10::optional<std::string> external_ref = {}) = 0;

  virtual void EncodeIntermediateValueInfo(
      onnx::GraphProto* graph_proto,
      const Value* n) {}

  virtual void EncodeValueInfo(
      onnx::GraphProto* graph_proto,
      onnx::ValueInfoProto* v,
      const Value* n);

  void AddAttribute(
      onnx::NodeProto* node_proto,
      const jit::Node* node,
      const jit::Symbol name);

  onnx::ModelProto model_proto_;
  size_t num_blocks_;
  onnx_torch::OperatorExportTypes operator_export_type_;
  bool strip_doc_;
};

onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
  switch (at_type) {
    case at::kDouble:
      return onnx::TensorProto_DataType_DOUBLE;
    case at::kFloat:
      return onnx::TensorProto_DataType_FLOAT;
    case at::kHalf:
      return onnx::TensorProto_DataType_FLOAT16;
    case at::kByte:
      return onnx::TensorProto_DataType_UINT8;
    case at::kChar:
      return onnx::TensorProto_DataType_INT8;
    case at::kShort:
      return onnx::TensorProto_DataType_INT16;
    case at::kInt:
      return onnx::TensorProto_DataType_INT32;
    case at::kLong:
      return onnx::TensorProto_DataType_INT64;
    default:
      AT_ERROR("unexpected tensor scalar type");
  }
}

EncoderBase::EncoderBase(
    onnx_torch::OperatorExportTypes operator_export_type,
    bool strip_doc)
    : num_blocks_(0),
      operator_export_type_(operator_export_type),
      strip_doc_(strip_doc) {
  model_proto_.set_producer_name("pytorch");
  model_proto_.set_ir_version(onnx::IR_VERSION);
  // TODO: set the producer version using appropriate function call
  model_proto_.set_producer_version("1.1");
}

void EncoderBase::EncodeValueInfo(
    onnx::GraphProto* graph_proto,
    onnx::ValueInfoProto* v,
    const Value* n) {
  v->set_name(n->uniqueName());
  onnx::TypeProto* t = v->mutable_type();
  onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();

  onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
  if (CompleteTensorTypePtr node_type = n->type()->cast<CompleteTensorType>()) {
    const std::vector<std::int64_t>& sizes = node_type->sizes();
    for (size_t i = 0; i < sizes.size(); i++) {
      shape->add_dim();
      shape->mutable_dim(i)->set_dim_value(sizes[i]);
    }
    tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
  } else {
    tensor_type->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
  }
}

void EncoderBase::EncodeGraph(
    onnx::GraphProto* graph_proto,
    const std::shared_ptr<Graph>& graph,
    const std::vector<at::Tensor>& initializers) {
  EncodeBlock(graph_proto, graph->block(), initializers);
}

void EncoderBase::EncodeBlock(
    onnx::GraphProto* graph_proto,
    const Block* block,
    const std::vector<at::Tensor>& initializers) {
  AT_ASSERT(graph_proto != nullptr);
  std::string block_name = "torch-jit-export";
  if (num_blocks_) {
    block_name += std::to_string(num_blocks_);
  }
  num_blocks_++;
  graph_proto->set_name(block_name);

  for (auto input : block->inputs()) {
    onnx::ValueInfoProto* v = graph_proto->add_input();
    EncodeValueInfo(graph_proto, v, input);
  }
  for (auto output : block->outputs()) {
    onnx::ValueInfoProto* v = graph_proto->add_output();
    EncodeValueInfo(graph_proto, v, output);
  }
  for (auto node : block->nodes()) {
    bool is_raw_export =
        operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
    if (node->mustBeNone() && !is_raw_export) {
      // None nodes are used to implement optional inputs. One
      // way to "not provide" an optional input is to create an
      // Undefined node, and pass its output as that input.
      continue;
    }
    auto p_n = graph_proto->add_node();
    if (node->getSourceLocation() && !strip_doc_) {
      std::stringstream ss;
      node->getSourceLocation()->highlight(ss);
      p_n->set_doc_string(ss.str());
    }
    for (auto input : node->inputs()) {
      if (input->node()->mustBeNone() && !is_raw_export) {
        p_n->add_input("");
      } else {
        p_n->add_input(input->uniqueName());
      }
    }
    for (auto output : node->outputs()) {
      p_n->add_output(output->uniqueName());
      EncodeIntermediateValueInfo(graph_proto, output);
    }
    if (is_raw_export) {
      AT_ASSERT(!node->kind().is_onnx());
      p_n->set_domain(node->kind().domainString());
    } else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
      AT_ASSERT(node->kind().is_onnx());
    }
    p_n->set_op_type(node->kind().toUnqualString());
    for (auto attr_name : node->attributeNames()) {
      AddAttribute(p_n, node, attr_name);
    }
    if (is_raw_export && node->blocks().size() > 0) {
      auto blocks = p_n->add_attribute();
      blocks->set_name("_blocks");
      blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
      for (auto block : node->blocks()) {
        auto graph = blocks->add_graphs();
        EncodeBlock(graph, block, initializers);
      }
    }
    if (node->kind() == ::c10::onnx::Loop) {
      AT_ASSERT(node->blocks().size() == 1);

      auto body = p_n->add_attribute();
      body->set_name("body");
      body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      auto g = body->mutable_g();
      EncodeBlock(g, node->blocks()[0]);
    }
    if (node->kind() == ::c10::onnx::If) {
      AT_ASSERT(node->blocks().size() == 2);

      auto true_branch = p_n->add_attribute();
      true_branch->set_name("then_branch");
      true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      auto true_g = true_branch->mutable_g();
      EncodeBlock(true_g, node->blocks()[0]);

      auto false_branch = p_n->add_attribute();
      false_branch->set_name("else_branch");
      false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      auto false_g = false_branch->mutable_g();
      EncodeBlock(false_g, node->blocks()[1]);
    }
  }
  auto num_initializers = initializers.size();
  AT_ASSERT(block->inputs().size() >= num_initializers);
  size_t inputs_count = block->inputs().size() - num_initializers;
  for (auto& tensor : initializers) {
    // TODO: stop using positions to determine which initializers
    // match to which inputs
    std::string name = graph_proto->input(inputs_count++).name();
    auto p = graph_proto->add_initializer();
    p->set_name(name);
    EncodeTensor(p, tensor, name);
  }
}
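
// A sketch of the positional matching performed above (the value names are
// hypothetical): for a block with inputs (%x, %weight, %bias) and
// initializers {weight_tensor, bias_tensor}, inputs_count starts at 1, so
// the two initializer entries are named "weight" and "bias" after the
// trailing graph inputs, in order.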

void EncoderBase::AddAttribute(
    onnx::NodeProto* node_proto,
    const jit::Node* node,
    const jit::Symbol name) {
  auto attr = node_proto->add_attribute();
  AT_ASSERT(name.is_attr());
  attr->set_name(name.toUnqualString());
  switch (node->kindOf(name)) {
    case AttributeKind::f:
      attr->set_f(node->f(name));
      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      break;
    case AttributeKind::fs:
      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (auto& v : node->fs(name))
        attr->add_floats(v);
      break;
    case AttributeKind::i:
      attr->set_type(onnx::AttributeProto_AttributeType_INT);
      attr->set_i(node->i(name));
      break;
    case AttributeKind::is:
      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (auto& v : node->is(name))
        attr->add_ints(v);
      break;
    case AttributeKind::s:
      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      attr->set_s(node->s(name));
      break;
    case AttributeKind::ss:
      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (auto& v : node->ss(name))
        attr->add_strings(v);
      break;
    case AttributeKind::t: {
      attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
      auto t = attr->mutable_t();
      EncodeTensor(t, node->t(name));
    } break;
    case AttributeKind::ts:
      attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
      for (auto& v : node->ts(name)) {
        auto t = attr->add_tensors();
        EncodeTensor(t, v);
      }
      break;
    case AttributeKind::g: {
      attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
      auto g = attr->mutable_g();
      EncodeGraph(g, node->g(name));
    } break;
    case AttributeKind::gs:
      attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
      for (auto& v : node->gs(name)) {
        auto g = attr->add_graphs();
        EncodeGraph(g, v);
      }
      break;
    default:
      throw std::runtime_error("unexpected attribute kind");
  }
}

class GraphEncoder : public EncoderBase {
 public:
  GraphEncoder(
      const std::shared_ptr<Graph>& graph,
      int64_t onnx_opset_version,
      onnx_torch::OperatorExportTypes operator_export_type,
      const std::vector<at::Tensor>& initializers,
      bool defer_weight_export,
      bool strip_doc);

  RawDataExportMap get_raw_data_export_map() {
    return raw_data_export_map_;
  }

 private:
  void EncodeTensor(
      onnx::TensorProto* tensor_proto,
      const at::Tensor& tensor,
      const c10::optional<std::string> external_ref = {}) override;

  RawDataExportMap raw_data_export_map_;
  bool defer_weight_export_;
};

GraphEncoder::GraphEncoder(
    const std::shared_ptr<Graph>& graph,
    int64_t onnx_opset_version,
    onnx_torch::OperatorExportTypes operator_export_type,
    const std::vector<at::Tensor>& initializers,
    bool defer_weight_export,
    bool strip_doc)
    : EncoderBase(operator_export_type, strip_doc),
      defer_weight_export_(defer_weight_export) {
  if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
    validateGraph(graph, operator_export_type);
  }

  auto* imp = model_proto_.add_opset_import();
  // This is the version of the ONNX operator set we are targeting
  imp->set_version(onnx_opset_version);

  EncodeGraph(model_proto_.mutable_graph(), graph, initializers);
}

void GraphEncoder::EncodeTensor(
    onnx::TensorProto* tensor_proto,
    const at::Tensor& tensor,
    const c10::optional<std::string> external_ref) {
  for (auto d : tensor.sizes()) {
    tensor_proto->add_dims(d);
  }
  tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
  // CPU's HalfTensor doesn't have contiguous(), so call contiguous() first,
  // then move the result to CPU
  auto t = tensor.contiguous().cpu();
  // Add a buffer to the raw_data_export_map for the caller to dump into an
  // external data store. If external_ref is not specified, we instead dump
  // the contiguous data into the protobuf itself
  if (defer_weight_export_ && external_ref) {
    // For now, we use the name of the tensor as the external lookup name to
    // avoid ONNX protobuf changes.
    AT_ASSERT(external_ref.value() == tensor_proto->name());
    AT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0);
    raw_data_export_map_[external_ref.value()] = t;
    tensor_proto->set_raw_data("__EXTERNAL");
  } else {
    AT_ASSERT(t.is_contiguous());
    tensor_proto->set_raw_data(std::string(
        static_cast<char*>(t.data_ptr()),
        t.element_size() * t.numel()));
  }
}
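
// Usage sketch for the deferred path above (illustrative; `save_blob` is a
// hypothetical helper, not part of this file). A caller that constructed the
// encoder with defer_weight_export=true can persist each deferred tensor
// itself, keyed by initializer name:
//
//   for (const auto& kv : graph_encoder.get_raw_data_export_map()) {
//     const at::Tensor& t = kv.second; // contiguous CPU tensor (see above)
//     save_blob(kv.first, t.data_ptr(), t.element_size() * t.numel());
//   }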

// This is a serializer class which saves script modules to pt files. The
// content of the file is written using PyTorchStreamWriter; for details,
// please check caffe2/serialize/inline_container.h. All the records except
// the last one are tensor data, and the last record is a serialized
// ModelDef, defined in caffe2/proto/torch.proto. ModelDef contains all the
// metadata of the model, and it is serialized as JSON.
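//
// For orientation, a sketch of the archive layout this class produces
// (derived from the writeRecord calls below; the exact record set depends on
// the module being saved):
//
//   tensors/0, tensors/1, ...   raw tensor storage data
//   attributes.pkl              pickled attribute table
//   code/<module name>.py       python-printed TorchScript methods
//   libs.py                     python-printed classes the model depends on
//   extra/<name>                caller-provided extra files
//   model.json                  JSON-serialized ModelDef metadata (last record)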
class ScriptModuleSerializer final {
 public:
  ScriptModuleSerializer(const std::string& filename);

  ScriptModuleSerializer(std::ostream* ofs);

  void serialize(
      const script::Module& module,
      const script::ExtraFilesMap& extra_files = script::ExtraFilesMap());

 private:
  void convertModel(
      const script::Module& module,
      torch::ModelDef* model_def,
      const script::ExtraFilesMap& extra_files);

  // add a tensor to the tensorTable
  // returns the offset into the tensor table
  size_t addTensor(const at::Tensor& tensor);

  // write the content of the tensor to the file/stream, and save the
  // offset in the storageMap_
  void convertAndWriteTensor(
      size_t tensor_id,
      const at::Tensor& tensor,
      torch::TensorDef* tensor_proto,
      std::unordered_map<const void*, std::string>& storageMap);

  // dump all the tensors in the tensorTable_ to a ModelDef (metadata) and
  // the file/stream (the content), assuming all the information of the
  // tensors has been collected. the method calls convertAndWriteTensor
  // to dump the content of a tensor
  void writeTensorTable(torch::ModelDef* model_def);

  void writeAttributeTable();
  void writeLibs(torch::ModelDef* model_def);

  void convertModule(
      const script::Module& module,
      const std::string& prefix,
      const std::string& name,
      torch::ModuleDef* module_def);

  void convertParameter(
      const script::NamedIValue& param,
      torch::ParameterDef* param_def,
      bool is_buffer);

  void convertClass(const ClassTypePtr& type, torch::ModelDef* model_def);

  std::ofstream ofs_;
  caffe2::serialize::PyTorchStreamWriter writer_;

  // all tensors that will be stored
  std::vector<at::Tensor> tensor_table_;

  std::vector<IValue> attribute_table_;

  // all classes used by this module hierarchy
  std::vector<ClassTypePtr> class_table_;
  OrderedDict<ClassTypePtr, std::string> converted_classes_;

  static const size_t op_version_set = 0;
};

// ScriptModuleSerializer's methods
ScriptModuleSerializer::ScriptModuleSerializer(const std::string& filename)
    : writer_(filename.c_str()) {
  // TODO appropriate support for mmap, right now we still use stream writer
}

ScriptModuleSerializer::ScriptModuleSerializer(std::ostream* ofs)
    : ofs_(), writer_(ofs) {}

void ScriptModuleSerializer::serialize(
    const script::Module& module,
    const script::ExtraFilesMap& extra_files) {
  torch::ModelDef model_def;
  convertModel(module, &model_def, extra_files);
  std::string output;
  // NB: cannot use MessageToJsonString, since fbcode's protobuf is too old;
  // the manual conversion below stays consistent with MessageToJsonString
  std::string url_prefix = "type.googleapis.com";
  std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
      ::google::protobuf::util::NewTypeResolverForDescriptorPool(
          url_prefix, model_def.GetDescriptor()->file()->pool()));
  ::google::protobuf::util::Status convert_result =
      ::google::protobuf::util::BinaryToJsonString(
          resolver.get(),
          url_prefix + "/" + model_def.GetDescriptor()->full_name(),
          model_def.SerializeAsString(),
          &output);
  if (!convert_result.ok()) {
    std::stringstream ss;
    ss << convert_result;
    AT_ERROR(ss.str());
  }
  writer_.writeRecord("model.json", output.data(), output.size());
  writer_.writeEndOfFile();
}

void ScriptModuleSerializer::writeLibs(torch::ModelDef* model_def) {
  auto lib_def = model_def->mutable_libs();
  std::ostringstream lib_stream;
  lib_stream << "op_version_set = " << op_version_set << "\n";
  // Convert all the classes that this module depends on
  for (const auto& class_type : class_table_) {
    convertClass(class_type, model_def);
  }

  for (const auto& c : converted_classes_) {
    lib_stream << *c << "\n";
  }

  torch::RecordRef* lib_record = lib_def->mutable_torchscript_arena();
  const auto filename = "libs.py";
  const auto& lib_str = lib_stream.str();
  writer_.writeRecord(filename, lib_str.c_str(), lib_str.size());
  lib_record->set_key(filename);
}

// python print the class and add it to converted_classes_. Recursively
// python print all classes that this class depends on.
void ScriptModuleSerializer::convertClass(
    const ClassTypePtr& class_type,
    torch::ModelDef* model_def) {
  if (converted_classes_.contains(class_type)) {
    return;
  }

  std::vector<ClassTypePtr> class_deps;
  std::ostringstream class_stream;
  PythonPrint(
      class_stream,
      class_type,
      tensor_table_,
      class_deps,
      /*enforce_importable=*/true);

  for (const auto& c : class_deps) {
    if (c == class_type) {
      // Don't re-process this class and enter an infinite loop. We need this
      // because we insert to converted_classes_ post-traversal, so the current
      // class isn't in there yet.
      continue;
    }
    convertClass(c, model_def);
  }
  // Insert *after* we've traversed the dependencies. This ensures that any
  // given class will appear after its dependencies in the order.
  converted_classes_.insert(class_type, class_stream.str());
}
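
// Resulting order (an illustrative example): if class A holds a field of
// class B, convertClass(A) recurses into B before inserting A, so libs.py
// lists B's definition ahead of A's and the printed source imports cleanly
// from top to bottom.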

void ScriptModuleSerializer::convertModel(
    const script::Module& module,
    torch::ModelDef* model_def,
    const script::ExtraFilesMap& extra_files) {
  model_def->set_producer_name("pytorch");
  model_def->set_producer_version("1.0"); // TODO: set the producer version
                                          // using appropriate function call
  model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);

  convertModule(
      module, "", writer_.archiveName(), model_def->mutable_main_module());

  // This may write some attributes to the tensor_table_
  writeAttributeTable();

  writeTensorTable(model_def);
  writeLibs(model_def);

  // Write out extra files.
  for (const auto& kv : extra_files) {
    const std::string key = "extra/" + kv.first;
    writer_.writeRecord(key, kv.second.data(), kv.second.size());
  }
}

size_t ScriptModuleSerializer::addTensor(const at::Tensor& tensor) {
  tensor_table_.push_back(tensor);
  return tensor_table_.size() - 1;
}

void ScriptModuleSerializer::convertAndWriteTensor(
    size_t tensor_id,
    const at::Tensor& tensor,
    torch::TensorDef* tensor_proto,
    std::unordered_map<const void*, std::string>& storageMap) {
  for (auto d : tensor.sizes()) {
    tensor_proto->add_dims(d);
  }
  for (auto s : tensor.strides()) {
    tensor_proto->add_strides(s);
  }
  tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
      at::scalarTypeToTypeMeta(tensor.scalar_type())));
  tensor_proto->set_offset(tensor.storage_offset());

  tensor_proto->set_requires_grad(tensor.requires_grad());

  uint64_t record_size =
      tensor.element_size() * tensor.storage().size();
  auto* key = tensor.storage().unsafeGetStorageImpl();

  auto storage_it = storageMap.find(key);
  if (storage_it == storageMap.end()) {
    at::Tensor storage_tensor = tensor;
    // TODO HIP support
    if (tensor.storage().device_type() == at::DeviceType::CUDA) {
      // NB: This new tensor is created to support cuda tensors.
      // Storages can be mutated when converting tensors from cuda to cpu,
      // and we need a cpu tensor to copy data from.
      storage_tensor = at::empty({0}, tensor.options())
                           .set_(
                               tensor.storage(),
                               /* storageOffset = */ 0,
                               /* size = */
                               {static_cast<int64_t>(tensor.storage().size())},
                               /* stride = */ {1})
                           .cpu();
      AT_ASSERT(
          storage_tensor.element_size() *
              storage_tensor.storage().size() ==
          record_size);
    }
    std::string name = "tensors/" + std::to_string(tensor_id);
    writer_.writeRecord(name, storage_tensor.storage().data(), record_size);
    storage_it = storageMap.insert({key, name}).first;
  }

  auto* data = tensor_proto->mutable_data();
  data->set_key(storage_it->second);

  // handle device case, set the device_detail and load to CUDA device
  std::stringstream ss;
  ss << tensor.device();
  tensor_proto->set_device(ss.str());
}

void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
  std::unordered_map<const void*, std::string> storageMap;
  size_t tensor_id = 0;
  for (const at::Tensor& t : tensor_table_) {
    auto* tensor_proto = model_def->add_tensors();
    convertAndWriteTensor(tensor_id++, t, tensor_proto, storageMap);
  }
}

void ScriptModuleSerializer::writeAttributeTable() {
  Pickler pickler(&tensor_table_);
  pickler.start();
  for (const IValue& ivalue : attribute_table_) {
    pickler.addIValue(ivalue);
  }
  pickler.finish();
  writer_.writeRecord(
      "attributes.pkl", pickler.stack().data(), pickler.stack().size());
}

void ScriptModuleSerializer::convertModule(
    const script::Module& module,
    const std::string& prefix,
    const std::string& name,
    torch::ModuleDef* module_def) {
  module_def->set_name(name);
  module_def->set_optimize(module.is_optimized());
  for (const auto& elem : module.get_parameters()) {
    torch::ParameterDef* param_def = module_def->add_parameters();
    convertParameter(elem.value(), param_def, /*is_buffer=*/false);
  }

  for (const auto& item : module.get_attributes()) {
    auto& attribute = item.value();
    // Add attribute to ModuleDef
    torch::AttributeDef* attribute_def = module_def->add_attributes();
    attribute_def->set_name(attribute.name_);
    attribute_def->set_type(attribute.type->python_str());

    attribute_table_.push_back(*attribute.slot());
    attribute_def->set_id(attribute_table_.size() - 1);
  }

  std::stringstream module_name;
  if (prefix != "")
    module_name << prefix << "_";
  module_name << name;

  if (module.get_methods().size() > 0) {
    std::ostringstream methods;
    methods << "op_version_set = " << op_version_set << "\n";
    PythonPrint(
        methods,
        module,
        tensor_table_,
        class_table_,
        /*enforce_importable=*/true);
    torch::RecordRef* record = module_def->mutable_torchscript_arena();

    std::stringstream filename;
    filename << "code/" << module_name.str() << ".py";
    std::string methods_str = methods.str();
    writer_.writeRecord(
        filename.str(), methods_str.c_str(), methods_str.size());
    record->set_key(filename.str());
  }

  for (const auto& elem : module.get_modules()) {
    torch::ModuleDef* sub_def = module_def->add_submodules();
    convertModule(*elem->module, module_name.str(), elem.key(), sub_def);
  }
}

void ScriptModuleSerializer::convertParameter(
    const script::NamedIValue& param,
    torch::ParameterDef* param_def,
    bool is_buffer) {
  param_def->set_name(param.name_);
  param_def->set_is_buffer(is_buffer);
  param_def->set_tensor_id(addTensor(param.slot()->toTensor()));
}

// Pretty printing for ONNX
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;

std::string idt(size_t indent) {
  return std::string(indent * indent_multiplier, indent_char);
}

std::string nlidt(size_t indent) {
  return std::string("\n") + idt(indent);
}

void dump(const onnx::TensorProto& tensor, std::ostream& stream) {
  stream << "TensorProto shape: [";
  for (int i = 0; i < tensor.dims_size(); ++i) {
    stream << tensor.dims(i) << (i == tensor.dims_size() - 1 ? "" : " ");
  }
  stream << "]";
}

void dump(const onnx::TensorShapeProto& shape, std::ostream& stream) {
  for (int i = 0; i < shape.dim_size(); ++i) {
    auto& dim = shape.dim(i);
    if (dim.has_dim_value()) {
      stream << dim.dim_value();
    } else {
      stream << "?";
    }
    stream << (i == shape.dim_size() - 1 ? "" : " ");
  }
}

void dump(const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
  stream << "Tensor dims: ";
  dump(tensor_type.shape(), stream);
}

void dump(const onnx::TypeProto& type, std::ostream& stream) {
  dump(type.tensor_type(), stream);
}

void dump(const onnx::ValueInfoProto& value_info, std::ostream& stream) {
  stream << "{name: \"" << value_info.name() << "\", type:";
  dump(value_info.type(), stream);
  stream << "}";
}

void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent);

void dump(
    const onnx::AttributeProto& attr,
    std::ostream& stream,
    size_t indent) {
  stream << "{ name: '" << attr.name() << "', type: ";
  if (attr.has_f()) {
    stream << "float, value: " << attr.f();
  } else if (attr.has_i()) {
    stream << "int, value: " << attr.i();
  } else if (attr.has_s()) {
    stream << "string, value: '" << attr.s() << "'";
  } else if (attr.has_g()) {
    stream << "graph, value:\n";
    dump(attr.g(), stream, indent + 1);
    stream << nlidt(indent);
  } else if (attr.has_t()) {
    stream << "tensor, value:";
    dump(attr.t(), stream);
  } else if (attr.floats_size()) {
    stream << "floats, values: [";
    for (int i = 0; i < attr.floats_size(); ++i)
      stream << attr.floats(i) << (i == attr.floats_size() - 1 ? "" : " ");
    stream << "]";
  } else if (attr.ints_size()) {
    stream << "ints, values: [";
    for (int i = 0; i < attr.ints_size(); ++i)
      stream << attr.ints(i) << (i == attr.ints_size() - 1 ? "" : " ");
    stream << "]";
  } else if (attr.strings_size()) {
    stream << "strings, values: [";
    for (int i = 0; i < attr.strings_size(); ++i)
      stream << "'" << attr.strings(i) << "'"
             << (i == attr.strings_size() - 1 ? "" : " ");
    stream << "]";
  } else if (attr.tensors_size()) {
    stream << "tensors, values: [";
    for (auto& t : attr.tensors()) {
      dump(t, stream);
    }
    stream << "]";
  } else if (attr.graphs_size()) {
    stream << "graphs, values: [";
    for (auto& g : attr.graphs()) {
      dump(g, stream, indent + 1);
    }
    stream << "]";
  } else {
    stream << "UNKNOWN";
  }
  stream << "}";
}

void dump(const onnx::NodeProto& node, std::ostream& stream, size_t indent) {
  stream << "Node {type: \"" << node.op_type() << "\", inputs: [";
  for (int i = 0; i < node.input_size(); ++i) {
    stream << node.input(i) << (i == node.input_size() - 1 ? "" : ",");
  }
  stream << "], outputs: [";
  for (int i = 0; i < node.output_size(); ++i) {
    stream << node.output(i) << (i == node.output_size() - 1 ? "" : ",");
  }
  stream << "], attributes: [";
  for (int i = 0; i < node.attribute_size(); ++i) {
    dump(node.attribute(i), stream, indent + 1);
    stream << (i == node.attribute_size() - 1 ? "" : ",");
  }
  stream << "]}";
}

void dump(const onnx::GraphProto& graph, std::ostream& stream, size_t indent) {
  stream << idt(indent) << "GraphProto {" << nlidt(indent + 1) << "name: \""
         << graph.name() << "\"" << nlidt(indent + 1) << "inputs: [";
  for (int i = 0; i < graph.input_size(); ++i) {
    dump(graph.input(i), stream);
    stream << (i == graph.input_size() - 1 ? "" : ",");
  }
  stream << "]" << nlidt(indent + 1) << "outputs: [";
  for (int i = 0; i < graph.output_size(); ++i) {
    dump(graph.output(i), stream);
    stream << (i == graph.output_size() - 1 ? "" : ",");
  }
  stream << "]" << nlidt(indent + 1) << "initializers: [";
  for (int i = 0; i < graph.initializer_size(); ++i) {
    dump(graph.initializer(i), stream);
    stream << (i == graph.initializer_size() - 1 ? "" : ",");
  }
  stream << "]" << nlidt(indent + 1) << "nodes: [" << nlidt(indent + 2);
  for (int i = 0; i < graph.node_size(); ++i) {
    dump(graph.node(i), stream, indent + 2);
    if (i != graph.node_size() - 1)
      stream << "," << nlidt(indent + 2);
  }
  stream << nlidt(indent + 1) << "]\n" << idt(indent) << "}\n";
}

void dump(
    const onnx::OperatorSetIdProto& operator_set_id,
    std::ostream& stream) {
  stream << "OperatorSetIdProto { domain: " << operator_set_id.domain() << "}";
}

void dump(const onnx::ModelProto& model, std::ostream& stream, size_t indent) {
  stream << idt(indent) << "ModelProto {" << nlidt(indent + 1)
         << "producer_name: \"" << model.producer_name() << "\""
         << nlidt(indent + 1) << "domain: \"" << model.domain() << "\""
         << nlidt(indent + 1) << "doc_string: \"" << model.doc_string() << "\"";
  if (model.has_graph()) {
    stream << nlidt(indent + 1) << "graph:\n";
    dump(model.graph(), stream, indent + 2);
  }
  if (model.opset_import_size()) {
    stream << idt(indent + 1) << "opset_import: [";
    for (auto& opset_imp : model.opset_import()) {
      dump(opset_imp, stream);
    }
    stream << "],\n";
  }
  stream << idt(indent) << "}\n";
}

std::string prettyPrint(const onnx::ModelProto& model) {
  std::stringstream ss;
  dump(model, ss, 0);
  return ss.str();
}

} // namespace

std::string pretty_print_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::vector<at::Tensor>& initializers,
    int64_t onnx_opset_version,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    bool google_printer) {
  auto graph_encoder = GraphEncoder(
      graph,
      onnx_opset_version,
      operator_export_type,
      initializers,
      defer_weight_export,
      true);
  if (google_printer) {
    return graph_encoder.get_model_proto().DebugString();
  }
  return prettyPrint(graph_encoder.get_model_proto());
}

// export_raw_ir will export IR ops without turning them into ONNX ops.
// The output will use the ONNX protobuf format, but the ops will not
// conform to the ONNX op specification. Thus, the output will not
// be interpretable by an ONNX-compatible framework. However, PyTorch or
// libtorch will be able to import the IR and play it back.
std::tuple<std::string, RawDataExportMap> export_onnx(
    const std::shared_ptr<Graph>& graph,
    const std::vector<at::Tensor>& initializers,
    int64_t onnx_opset_version,
    bool defer_weight_export,
    ::torch::onnx::OperatorExportTypes operator_export_type) {
  auto graph_encoder = GraphEncoder(
      graph,
      onnx_opset_version,
      operator_export_type,
      initializers,
      defer_weight_export,
      false);
  return std::make_tuple(
      graph_encoder.get_model_proto().SerializeAsString(),
      graph_encoder.get_raw_data_export_map());
}

void ExportModule(
    const script::Module& module,
    std::ostream& out,
    const script::ExtraFilesMap& extra_files) {
  ScriptModuleSerializer serializer(&out);
  serializer.serialize(module, extra_files);
}

void ExportModule(
    const script::Module& module,
    const std::string& filename,
    const script::ExtraFilesMap& extra_files) {
  ScriptModuleSerializer serializer(filename);
  serializer.serialize(module, extra_files);
}
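
// Usage sketch (illustrative; assumes `module`, `graph`, and `initializers`
// already exist in the caller's scope):
//
//   // Save a script module, with no extra files:
//   ExportModule(module, "model.pt", script::ExtraFilesMap());
//
//   // Serialize a graph to an ONNX model string plus deferred weights:
//   std::string model_string;
//   RawDataExportMap export_map;
//   std::tie(model_string, export_map) = export_onnx(
//       graph,
//       initializers,
//       /*onnx_opset_version=*/9,
//       /*defer_weight_export=*/true,
//       ::torch::onnx::OperatorExportTypes::ONNX);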

} // namespace jit
} // namespace torch