1 #include <google/protobuf/util/json_util.h> 2 #include <google/protobuf/util/type_resolver_util.h> 4 #include <torch/csrc/autograd/symbolic.h> 5 #include <torch/csrc/jit/export.h> 6 #include <torch/csrc/onnx/onnx.h> 8 #include <ATen/core/functional.h> 9 #include <c10/util/Exception.h> 10 #include <torch/csrc/jit/passes/dead_code_elimination.h> 11 #include <torch/csrc/jit/passes/python_print.h> 12 #include <torch/csrc/jit/pickler.h> 14 #include <caffe2/core/types.h> 15 #include <caffe2/proto/caffe2_pb.h> 16 #include <caffe2/proto/torch_pb.h> 17 #include <caffe2/serialize/inline_container.h> 18 #include <onnx/onnx_pb.h> 20 #include <ATen/ATen.h> 21 #include <c10/util/Optional.h> 35 namespace onnx = ::ONNX_NAMESPACE;
// Forward declaration; the class is defined further down in this file.
37 class ScriptModuleSerializer;
// Describes where node `n` came from: its source location is rendered with
// highlight() into `ss`, or "<unknown location>" is written when the node has
// no source location. Used to enrich the export error messages in
// validateBlock.
// NOTE(review): the declaration of `ss` and the return statement are on lines
// not visible in this excerpt; presumably the function returns ss.str().
39 std::string getNodeStackTraceString(
const Node* n) {
41 if (n->getSourceLocation()) {
42 n->getSourceLocation()->highlight(ss);
44 ss <<
"<unknown location>";
// Validates one block, and recursively its sub-blocks, before ONNX export.
// NOTE(review): the opening line of this definition (function name and the
// `Block* b` parameter) is not visible here; `b` is the block being checked.
// Visible per-node checks, in order:
//  - recurse into every nested block;
//  - prim::PythonOp is reported as not exportable;
//  - under ONNX_ATEN_FALLBACK, aten::expand is replaced by an onnx::ATen node
//    carrying attr::operator = "expand";
//  - a lone prim::PackPadded / prim::PadPacked is reported as an error;
//  - any non-ONNX op is reported unless ATen export is enabled or the node
//    mustBeNone().
// The error messages are visible, but the FAIL_EXPORT call lines themselves
// are not; presumably each message is passed to FAIL_EXPORT.
51 onnx_torch::OperatorExportTypes operator_export_type) {
52 for (
auto node : b->nodes()) {
53 for (Block* sub_block : node->blocks()) {
54 validateBlock(sub_block, operator_export_type);
// FAIL_EXPORT(name) throws std::runtime_error with the prefix
// "ONNX export failed: " and appends the owning graph's textual dump.
57 #define FAIL_EXPORT(name) \ 58 throw std::runtime_error( \ 59 std::string("ONNX export failed: ") + name + \ 60 "\n\nGraph we tried to export:\n" + b->owningGraph()->toString()); 61 if (node->kind() == prim::PythonOp) {
62 auto py_node =
static_cast<PythonOp*
>(node);
64 "Couldn't export Python operator " + py_node->name() +
65 "\n\nDefined at:\n" + getNodeStackTraceString(node))
// aten::expand in ATen-fallback mode: a generic onnx::ATen node with the same
// number of outputs is inserted at the node's position, every use of the old
// outputs is redirected to it, and the ATen op name is recorded as an attribute.
68 if (node->kind() == aten::expand) {
69 if (operator_export_type ==
70 onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
71 WithInsertPoint guard(node);
73 b->owningGraph()->insertNode(b->owningGraph()->create(
74 Symbol(::c10::onnx::ATen),
76 node->outputs().size()));
77 for (
size_t i = 0; i < node->outputs().size(); ++i) {
78 node->output(i)->replaceAllUsesWith(new_node->output(i));
80 new_node->s_(Symbol::fromQualString(
"attr::operator"),
"expand");
// pack_padded_sequence / pad_packed_sequence can only be exported as a pair.
83 if (node->kind() == prim::PackPadded || node->kind() == prim::PadPacked) {
85 "Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
86 getNodeStackTraceString(node));
// Non-ONNX ops are acceptable only when ATen ops may be emitted (ONNX_ATEN or
// ONNX_ATEN_FALLBACK) or when the node mustBeNone().
88 bool is_aten_enabled = operator_export_type ==
89 onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
90 operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
91 if (!node->kind().is_onnx() && !is_aten_enabled && !node->mustBeNone()) {
93 "Couldn't export operator " + node->kind().toDisplayString() +
94 "\n\nDefined at:\n" + getNodeStackTraceString(node));
// Presumably the body of validateGraph(graph, operator_export_type), which the
// GraphEncoder constructor calls; its opening line is not visible here.
// It validates the graph's top-level block (validateBlock recurses into
// sub-blocks), then runs dead code elimination. Presumably DCE removes nodes
// whose uses were redirected during validation, such as a replaced aten::expand.
102 const std::shared_ptr<Graph>& graph,
103 onnx_torch::OperatorExportTypes operator_export_type) {
104 validateBlock(graph->block(), operator_export_type);
105 EliminateDeadCode(graph->block());
// Fragment of the EncoderBase declaration (its opening lines are not visible).
// EncoderBase converts JIT graphs/blocks into ONNX protobuf messages;
// GraphEncoder derives from it and overrides EncodeTensor.
// Constructor parameter (see the EncoderBase::EncoderBase definition).
111 onnx_torch::OperatorExportTypes operator_export_type,
// Accessor for the ModelProto being built (body not visible here).
114 onnx::ModelProto get_model_proto() {
// Presumably EncodeGraph(graph_proto, graph, initializers); the parameters
// match its out-of-line definition.
120 onnx::GraphProto* graph_proto,
121 const std::shared_ptr<Graph>& graph,
122 const std::vector<at::Tensor>& initializers = {});
// Presumably EncodeBlock(graph_proto, block, initializers).
125 onnx::GraphProto* graph_proto,
127 const std::vector<at::Tensor>& initializers = {});
// Serializes a tensor into a TensorProto. It is virtual so a subclass decides
// how tensor data is stored (GraphEncoder can defer it to an external map).
129 virtual void EncodeTensor(
130 onnx::TensorProto* tensor_proto,
// Called by EncodeBlock for every node output; the base implementation is not
// visible here.
134 virtual void EncodeIntermediateValueInfo(
135 onnx::GraphProto* graph_proto,
// Fills a ValueInfoProto (name, element type, shape) for a JIT value.
138 virtual void EncodeValueInfo(
139 onnx::GraphProto* graph_proto,
140 onnx::ValueInfoProto* v,
// Presumably AddAttribute(node_proto, node, name): copies one JIT attribute
// onto an ONNX node.
144 onnx::NodeProto* node_proto,
145 const jit::Node* node,
146 const jit::Symbol name);
// The model under construction.
148 onnx::ModelProto model_proto_;
// Export mode chosen by the caller (e.g. ONNX, ONNX_ATEN, ONNX_ATEN_FALLBACK,
// RAW).
150 onnx_torch::OperatorExportTypes operator_export_type_;
// Maps an ATen scalar type to the ONNX TensorProto element type.
// The case labels are not visible in this excerpt. The reachable results are
// DOUBLE, FLOAT, FLOAT16, UINT8, INT8, INT16, INT32 and INT64. Any other type
// presumably falls through to the AT_ERROR below.
154 onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) {
157 return onnx::TensorProto_DataType_DOUBLE;
159 return onnx::TensorProto_DataType_FLOAT;
161 return onnx::TensorProto_DataType_FLOAT16;
163 return onnx::TensorProto_DataType_UINT8;
165 return onnx::TensorProto_DataType_INT8;
167 return onnx::TensorProto_DataType_INT16;
169 return onnx::TensorProto_DataType_INT32;
171 return onnx::TensorProto_DataType_INT64;
173 AT_ERROR(
"unexpected tensor scalar type");
// Stores the export mode and strip_doc flag, then stamps the ModelProto with
// producer "pytorch", ir_version = onnx::IR_VERSION and producer_version "1.1".
// NOTE(review): ScriptModuleSerializer::convertModel hard-codes
// producer_version "1.0" for TorchScript archives. Confirm whether the two
// different version strings are intentional.
177 EncoderBase::EncoderBase(
178 onnx_torch::OperatorExportTypes operator_export_type,
181 operator_export_type_(operator_export_type),
182 strip_doc_(strip_doc) {
183 model_proto_.set_producer_name(
"pytorch");
184 model_proto_.set_ir_version(onnx::IR_VERSION);
186 model_proto_.set_producer_version(
"1.1");
// Fills `v` for JIT value `n`: the name is n->uniqueName() and the type is a
// tensor type. When n's JIT type is a CompleteTensorType, every dimension size
// and the element type are recorded. Otherwise the element type is UNDEFINED
// (the `else` line is not visible; L98's statement is presumably that branch).
189 void EncoderBase::EncodeValueInfo(
190 onnx::GraphProto* graph_proto,
191 onnx::ValueInfoProto* v,
193 v->set_name(n->uniqueName());
194 onnx::TypeProto* t = v->mutable_type();
195 onnx::TypeProto_Tensor* tensor_type = t->mutable_tensor_type();
197 onnx::TensorShapeProto* shape = tensor_type->mutable_shape();
198 if (CompleteTensorTypePtr node_type = n->type()->cast<CompleteTensorType>()) {
199 const std::vector<std::int64_t>& sizes = node_type->sizes();
// mutable_dim(i) requires dimension i to exist already; presumably a hidden
// line adds it (add_dim()) before this write.
200 for (
size_t i = 0; i < sizes.size(); i++) {
202 shape->mutable_dim(i)->set_dim_value(sizes[i]);
204 tensor_type->set_elem_type(ATenTypeToOnnxType(node_type->scalarType()));
206 tensor_type->set_elem_type(onnx::TensorProto_DataType_UNDEFINED);
210 void EncoderBase::EncodeGraph(
211 onnx::GraphProto* graph_proto,
212 const std::shared_ptr<Graph>& graph,
213 const std::vector<at::Tensor>& initializers) {
214 EncodeBlock(graph_proto, graph->block(), initializers);
// Encodes `block` into `graph_proto`. It writes the graph name, the value info
// for inputs and outputs, and one NodeProto per JIT node (with attributes and
// nested graphs). The initializer tensors are written last.
// NOTE(review): several lines of this definition are not visible in this
// excerpt (the `block` parameter, the num_blocks_ bookkeeping, the
// `is_raw_export` declaration and some branch conditions). The comments below
// describe only what is visible.
217 void EncoderBase::EncodeBlock(
218 onnx::GraphProto* graph_proto,
220 const std::vector<at::Tensor>& initializers) {
221 AT_ASSERT(graph_proto !=
nullptr);
// The graph is named "torch-jit-export", with num_blocks_ appended (the
// condition guarding the append is not visible).
222 std::string block_name =
"torch-jit-export";
224 block_name += std::to_string(num_blocks_);
227 graph_proto->set_name(block_name);
// Declare graph inputs and outputs together with their type/shape info.
229 for (
auto input : block->inputs()) {
230 onnx::ValueInfoProto* v = graph_proto->add_input();
231 EncodeValueInfo(graph_proto, v, input);
233 for (
auto output : block->outputs()) {
234 onnx::ValueInfoProto* v = graph_proto->add_output();
235 EncodeValueInfo(graph_proto, v, output);
// Emit one NodeProto per JIT node.
237 for (
auto node : block->nodes()) {
239 operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
// Outside RAW export, mustBeNone() nodes are not emitted (the skip itself is
// on lines not visible here).
240 if (node->mustBeNone() && !is_raw_export) {
246 auto p_n = graph_proto->add_node();
// Attach the highlighted source location as doc_string unless strip_doc_.
247 if (node->getSourceLocation() && !strip_doc_) {
248 std::stringstream ss;
249 node->getSourceLocation()->highlight(ss);
250 p_n->set_doc_string(ss.str());
// Outside RAW export, inputs produced by mustBeNone() nodes are handled
// specially (not visible). All other inputs are referenced by unique name.
252 for (
auto input : node->inputs()) {
253 if (input->node()->mustBeNone() && !is_raw_export) {
256 p_n->add_input(input->uniqueName());
// Outputs are referenced by unique name and reported through
// EncodeIntermediateValueInfo.
259 for (
auto output : node->outputs()) {
260 p_n->add_output(output->uniqueName());
261 EncodeIntermediateValueInfo(graph_proto, output);
// Presumably the branch for non-ONNX ops (its condition is not visible). The
// symbol's namespace becomes the NodeProto domain. In plain ONNX mode every op
// must be an ONNX op.
264 AT_ASSERT(!node->kind().is_onnx());
265 p_n->set_domain(node->kind().domainString());
266 }
else if (operator_export_type_ == onnx_torch::OperatorExportTypes::ONNX) {
267 AT_ASSERT(node->kind().is_onnx());
// op_type is the unqualified symbol name; every JIT attribute is copied.
269 p_n->set_op_type(node->kind().toUnqualString());
270 for (
auto attr_name : node->attributeNames()) {
271 AddAttribute(p_n, node, attr_name);
// RAW export keeps arbitrary nested blocks as a GRAPHS attribute "_blocks".
// NOTE(review): these sub-blocks receive the top-level `initializers`, while
// Loop/If bodies below are encoded with none. In a sub-block, the
// initializer-to-input pairing at the end of this function would bind those
// tensors to the sub-block's own trailing inputs. It would also trip the
// AT_ASSERT if the sub-block has fewer inputs than there are initializers.
// Confirm whether this is intended.
273 if (is_raw_export && node->blocks().size() > 0) {
274 auto blocks = p_n->add_attribute();
275 blocks->set_name(
"_blocks");
276 blocks->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
277 for (
auto block : node->blocks()) {
278 auto graph = blocks->add_graphs();
279 EncodeBlock(graph, block, initializers);
// onnx::Loop: the single body block becomes the GRAPH attribute "body".
282 if (node->kind() == ::c10::onnx::Loop) {
283 AT_ASSERT(node->blocks().size() == 1);
285 auto body = p_n->add_attribute();
286 body->set_name(
"body");
287 body->set_type(onnx::AttributeProto_AttributeType_GRAPH);
288 auto g = body->mutable_g();
289 EncodeBlock(g, node->blocks()[0]);
// onnx::If: blocks 0 and 1 become "then_branch" and "else_branch".
291 if (node->kind() == ::c10::onnx::If) {
292 AT_ASSERT(node->blocks().size() == 2);
294 auto true_branch = p_n->add_attribute();
295 true_branch->set_name(
"then_branch");
296 true_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
297 auto true_g = true_branch->mutable_g();
298 EncodeBlock(true_g, node->blocks()[0]);
300 auto false_branch = p_n->add_attribute();
301 false_branch->set_name(
"else_branch");
302 false_branch->set_type(onnx::AttributeProto_AttributeType_GRAPH);
303 auto false_g = false_branch->mutable_g();
304 EncodeBlock(false_g, node->blocks()[1]);
// Initializers bind, in order, to the LAST initializers.size() inputs of the
// block. Each tensor is stored under the name of its matching input.
307 auto num_initializers = initializers.size();
308 AT_ASSERT(block->inputs().size() >= num_initializers);
309 size_t inputs_count = block->inputs().size() - num_initializers;
310 for (
auto& tensor : initializers) {
313 std::string name = graph_proto->input(inputs_count++).name();
314 auto p = graph_proto->add_initializer();
316 EncodeTensor(p, tensor, name);
// Appends one AttributeProto to `node_proto` for JIT attribute `name` of
// `node`. The name is the unqualified symbol. The ONNX attribute type follows
// the JIT AttributeKind: f/fs -> FLOAT(S), i/is -> INT(S), s/ss -> STRING(S),
// t/ts -> TENSOR(S) via EncodeTensor, g/gs -> GRAPH(S) via EncodeGraph.
// Any other kind throws std::runtime_error("unexpected attribute kind"),
// presumably from a default label that is not visible. Some loop bodies and
// `break`s are also on lines not visible in this excerpt.
320 void EncoderBase::AddAttribute(
321 onnx::NodeProto* node_proto,
322 const jit::Node* node,
323 const jit::Symbol name) {
324 auto attr = node_proto->add_attribute();
325 AT_ASSERT(name.is_attr());
326 attr->set_name(name.toUnqualString());
327 switch (node->kindOf(name)) {
328 case AttributeKind::f:
329 attr->set_f(node->f(name));
330 attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
332 case AttributeKind::fs:
333 attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
334 for (
auto& v : node->fs(name))
337 case AttributeKind::i:
338 attr->set_type(onnx::AttributeProto_AttributeType_INT);
339 attr->set_i(node->i(name));
341 case AttributeKind::is:
342 attr->set_type(onnx::AttributeProto_AttributeType_INTS);
343 for (
auto& v : node->is(name))
346 case AttributeKind::s:
347 attr->set_type(onnx::AttributeProto_AttributeType_STRING);
348 attr->set_s(node->s(name));
350 case AttributeKind::ss:
351 attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
352 for (
auto& v : node->ss(name))
353 attr->add_strings(v);
// Tensor-valued attributes are serialized through the (virtual) EncodeTensor.
355 case AttributeKind::t: {
356 attr->set_type(onnx::AttributeProto_AttributeType_TENSOR);
357 auto t = attr->mutable_t();
358 EncodeTensor(t, node->t(name));
360 case AttributeKind::ts:
361 attr->set_type(onnx::AttributeProto_AttributeType_TENSORS);
362 for (
auto& v : node->ts(name)) {
363 auto t = attr->add_tensors();
// Graph-valued attributes are encoded as full subgraphs via EncodeGraph.
367 case AttributeKind::g: {
368 attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
369 auto g = attr->mutable_g();
370 EncodeGraph(g, node->g(name));
372 case AttributeKind::gs:
373 attr->set_type(onnx::AttributeProto_AttributeType_GRAPHS);
374 for (
auto& v : node->gs(name)) {
375 auto g = attr->add_graphs();
380 throw std::runtime_error(
"unexpected attribute kind");
// GraphEncoder: an EncoderBase that encodes an entire JIT graph into
// model_proto_ from its constructor. With defer_weight_export it can keep
// tensor payloads out of the proto. Only a fragment of the declaration is
// visible here.
384 class GraphEncoder :
public EncoderBase {
// Constructor parameters (see the GraphEncoder::GraphEncoder definition).
387 const std::shared_ptr<Graph>& graph,
388 int64_t onnx_opset_version,
389 onnx_torch::OperatorExportTypes operator_export_type,
390 const std::vector<at::Tensor>& initializers,
391 bool defer_weight_export,
// Returns a copy (by value) of the tensors whose raw data was deferred, keyed
// by their external reference name.
394 RawDataExportMap get_raw_data_export_map() {
395 return raw_data_export_map_;
// Parameter of the EncodeTensor override (see GraphEncoder::EncodeTensor).
400 onnx::TensorProto* tensor_proto,
// Deferred tensor payloads collected by EncodeTensor.
404 RawDataExportMap raw_data_export_map_;
// When true, tensors that have an external reference are stored in
// raw_data_export_map_ instead of inline in the proto.
405 bool defer_weight_export_;
// Builds the complete ONNX model for `graph` at construction time.
// Validation (validateGraph) runs for every mode except RAW. A single opset
// import is added with `onnx_opset_version`. The graph and its initializers
// are then encoded into model_proto_.
408 GraphEncoder::GraphEncoder(
409 const std::shared_ptr<Graph>& graph,
410 int64_t onnx_opset_version,
411 onnx_torch::OperatorExportTypes operator_export_type,
412 const std::vector<at::Tensor>& initializers,
413 bool defer_weight_export,
415 : EncoderBase(operator_export_type, strip_doc),
416 defer_weight_export_(defer_weight_export) {
417 if (operator_export_type != onnx_torch::OperatorExportTypes::RAW) {
418 validateGraph(graph, operator_export_type);
421 auto* imp = model_proto_.add_opset_import();
423 imp->set_version(onnx_opset_version);
425 EncodeGraph(model_proto_.mutable_graph(), graph, initializers);
// Serializes `tensor` into `tensor_proto`: its dims, its ONNX element type,
// then its data taken from a contiguous CPU copy `t`.
// With defer_weight_export_ and an external_ref, the copy is stashed in
// raw_data_export_map_ under that name and the proto only holds the
// placeholder "__EXTERNAL". Otherwise element_size() * numel() raw bytes are
// embedded.
// NOTE(review): the AT_ASSERT below expects tensor_proto's name to equal
// external_ref. The line that sets the name is not visible here; presumably
// it runs before this check.
428 void GraphEncoder::EncodeTensor(
429 onnx::TensorProto* tensor_proto,
432 for (
auto d : tensor.sizes()) {
433 tensor_proto->add_dims(d);
435 tensor_proto->set_data_type(ATenTypeToOnnxType(tensor.scalar_type()));
437 auto t = tensor.contiguous().cpu();
441 if (defer_weight_export_ && external_ref) {
444 AT_ASSERT(external_ref.value() == tensor_proto->name());
445 AT_ASSERT(raw_data_export_map_.count(external_ref.value()) == 0);
446 raw_data_export_map_[external_ref.value()] = t;
447 tensor_proto->set_raw_data(
"__EXTERNAL");
449 AT_ASSERT(t.is_contiguous());
450 tensor_proto->set_raw_data(std::string(
451 static_cast<char*>(t.data_ptr()),
452 t.element_size() * t.numel()));
// Serializes a script::Module into an archive through writer_. The archive
// holds model.json, code/*.py, tensors/*, libs.py, attributes.pkl and
// extra/* records (see the member definitions). Only a fragment of the
// declaration is visible here.
462 class ScriptModuleSerializer final {
// Write to the file at `filename`.
// NOTE(review): both single-argument constructors are non-explicit, so a
// std::string or std::ostream* converts implicitly; consider `explicit`.
464 ScriptModuleSerializer(
const std::string& filename);
// Write to an existing output stream.
466 ScriptModuleSerializer(std::ostream* ofs);
// Presumably serialize(module, extra_files): writes the whole archive.
469 const script::Module& module,
470 const script::ExtraFilesMap& extra_files = script::ExtraFilesMap());
// Presumably convertModel(module, model_def, extra_files).
474 const script::Module& module,
475 torch::ModelDef* model_def,
476 const script::ExtraFilesMap& extra_files);
// Writes one tensor's metadata and its storage. storageMap deduplicates
// shared storages.
484 void convertAndWriteTensor(
487 torch::TensorDef* tensor_proto,
488 std::unordered_map<const void*, std::string>& storageMap);
494 void writeTensorTable(torch::ModelDef* model_def);
496 void writeAttributeTable();
497 void writeLibs(torch::ModelDef* model_def);
// Presumably convertModule(module, prefix, name, module_def): converts one
// module and recurses into its submodules.
500 const script::Module& module,
501 const std::string& prefix,
502 const std::string& name,
503 torch::ModuleDef* module_def);
505 void convertParameter(
506 const script::NamedIValue& param,
507 torch::ParameterDef* param_def,
510 void convertClass(
const ClassTypePtr& type, torch::ModelDef* model_def);
// Tensors referenced by id; the id is the index returned by addTensor.
516 std::vector<at::Tensor> tensor_table_;
// Module attribute values, pickled into "attributes.pkl".
518 std::vector<IValue> attribute_table_;
// Classes whose source is emitted into "libs.py".
521 std::vector<ClassTypePtr> class_table_;
// Written as the "op_version_set = ..." header of every emitted code file.
524 static const size_t op_version_set = 0;
// Serializes into the file at `filename`; writer_ is constructed from the path.
// The constructor body is on lines not visible in this excerpt.
528 ScriptModuleSerializer::ScriptModuleSerializer(
const std::string& filename)
529 : writer_(filename.c_str()) {
533 ScriptModuleSerializer::ScriptModuleSerializer(std::ostream* ofs)
534 : ofs_(), writer_(ofs) {}
// Writes the complete archive for `module`.
// convertModel fills a torch::ModelDef and writes the tensor, code, attribute,
// library and extra-file records along the way. The ModelDef is then converted
// from binary protobuf to JSON with a TypeResolver over its descriptor pool,
// written as "model.json", and the archive is finalized.
// On a failed JSON conversion the status is formatted into `ss`. The statement
// that reports it, and the declaration of `output`, are on lines not visible
// here; presumably an error is raised.
536 void ScriptModuleSerializer::serialize(
537 const script::Module& module,
538 const script::ExtraFilesMap& extra_files) {
539 torch::ModelDef model_def;
540 convertModel(module, &model_def, extra_files);
544 std::string url_prefix =
"type.googleapis.com";
545 std::unique_ptr<::google::protobuf::util::TypeResolver> resolver(
546 ::google::protobuf::util::NewTypeResolverForDescriptorPool(
547 url_prefix, model_def.GetDescriptor()->file()->pool()));
548 ::google::protobuf::util::Status convert_result =
549 ::google::protobuf::util::BinaryToJsonString(
551 url_prefix +
"/" + model_def.GetDescriptor()->full_name(),
552 model_def.SerializeAsString(),
554 if (!convert_result.ok()) {
555 std::stringstream ss;
556 ss << convert_result;
559 writer_.writeRecord(
"model.json", output.data(), output.size());
560 writer_.writeEndOfFile();
// Emits the source of every class in class_table_ (plus dependencies pulled in
// by convertClass) into a single "libs.py" record. The record starts with the
// op_version_set header, and its key is stored in libs.torchscript_arena.
563 void ScriptModuleSerializer::writeLibs(torch::ModelDef* model_def) {
564 auto lib_def = model_def->mutable_libs();
565 std::ostringstream lib_stream;
566 lib_stream <<
"op_version_set = " << op_version_set <<
"\n";
// Convert first (this fills converted_classes_), then print everything that
// was converted, in converted_classes_ iteration order.
568 for (
const auto& class_type : class_table_) {
569 convertClass(class_type, model_def);
572 for (
const auto& c : converted_classes_) {
573 lib_stream << *c <<
"\n";
576 torch::RecordRef* lib_record = lib_def->mutable_torchscript_arena();
577 const auto filename =
"libs.py";
// Binding the temporary returned by str() to a const reference extends its
// lifetime to the end of the scope.
578 const auto& lib_str = lib_stream.str();
579 writer_.writeRecord(filename, lib_str.c_str(), lib_str.size());
580 lib_record->set_key(filename);
// Renders `class_type` to source and records it in converted_classes_.
// Already-converted classes are detected first (the early exit is on a line
// not visible here). Dependencies collected into class_deps are converted
// recursively before this class is inserted; a self-dependency gets special
// handling on lines not visible here. Presumably converted_classes_ keeps
// insertion order, so dependencies precede their dependents in libs.py.
// NOTE(review): the lines that print the class into class_stream and fill
// class_deps are not visible in this excerpt.
585 void ScriptModuleSerializer::convertClass(
586 const ClassTypePtr& class_type,
587 torch::ModelDef* model_def) {
588 if (converted_classes_.contains(class_type)) {
592 std::vector<ClassTypePtr> class_deps;
593 std::ostringstream class_stream;
601 for (
const auto& c : class_deps) {
602 if (c == class_type) {
608 convertClass(c, model_def);
612 converted_classes_.insert(class_type, class_stream.str());
// Fills `model_def` and writes every archive record except model.json.
// It stamps producer info ("pytorch", "1.0") and the newest proto version,
// then converts the module tree. L371-372 are the arguments of a call whose
// first line is not visible; presumably convertModule with an empty prefix and
// writer_.archiveName() as the root name. After that it writes
// attributes.pkl, the tensor table and libs.py, and finally each extra file
// under "extra/<name>".
// NOTE(review): writeAttributeTable runs before writeTensorTable, presumably
// because the Pickler is handed &tensor_table_ and may add tensors to it.
615 void ScriptModuleSerializer::convertModel(
616 const script::Module& module,
617 torch::ModelDef* model_def,
618 const script::ExtraFilesMap& extra_files) {
619 model_def->set_producer_name(
"pytorch");
620 model_def->set_producer_version(
"1.0");
622 model_def->set_proto_version(torch::ProtoVersion::PROTO_VERSION_NEWEST);
625 module,
"", writer_.archiveName(), model_def->mutable_main_module());
628 writeAttributeTable();
630 writeTensorTable(model_def);
631 writeLibs(model_def);
634 for (
const auto& kv : extra_files) {
635 const std::string key =
"extra/" + kv.first;
636 writer_.writeRecord(key, kv.second.data(), kv.second.size());
640 size_t ScriptModuleSerializer::addTensor(
const at::Tensor& tensor) {
641 tensor_table_.push_back(tensor);
642 return tensor_table_.size() - 1;
// Records one tensor in `tensor_proto` and ensures its storage is written to
// the archive exactly once. The recorded metadata is dims, strides, dtype,
// storage offset, requires_grad and device.
// The whole underlying storage is written (element_size * storage().size()
// bytes), not just the viewed elements. The view is reconstructed from the
// offset and strides.
// NOTE(review): the `tensor_id` / `tensor` parameter lines and parts of the
// CUDA-copy logic are not visible in this excerpt.
645 void ScriptModuleSerializer::convertAndWriteTensor(
648 torch::TensorDef* tensor_proto,
649 std::unordered_map<const void*, std::string>& storageMap) {
650 for (
auto d : tensor.sizes()) {
651 tensor_proto->add_dims(d);
653 for (
auto s : tensor.strides()) {
654 tensor_proto->add_strides(s);
656 tensor_proto->set_data_type(caffe2::TypeMetaToDataType(
657 at::scalarTypeToTypeMeta(tensor.scalar_type())));
658 tensor_proto->set_offset(tensor.storage_offset());
660 tensor_proto->set_requires_grad(tensor.requires_grad());
662 uint64_t record_size =
663 tensor.element_size() * tensor.storage().size();
// Storages are deduplicated by StorageImpl address. Tensors that share a
// storage share one "tensors/<id>" record, named after the first tensor that
// reached it.
664 auto* key = tensor.storage().unsafeGetStorageImpl();
666 auto storage_it = storageMap.find(key);
667 if (storage_it == storageMap.end()) {
// CUDA storages are presumably staged through a CPU tensor covering the full
// storage before writing; the copy lines are only partially visible.
670 if (tensor.storage().device_type() == at::DeviceType::CUDA) {
674 storage_tensor = at::empty({0}, tensor.
options())
679 {
static_cast<int64_t
>(tensor.storage().size())},
683 storage_tensor.element_size() *
684 storage_tensor.storage().size() ==
687 std::string name =
"tensors/" + std::to_string(tensor_id);
688 writer_.writeRecord(name, storage_tensor.storage().data(), record_size);
689 storage_it = storageMap.insert({key, name}).first;
// Point the TensorDef at the (new or pre-existing) storage record.
692 auto* data = tensor_proto->mutable_data();
693 data->set_key(storage_it->second);
// Device string; the line that writes into `ss` is not visible here.
696 std::stringstream ss;
698 tensor_proto->set_device(ss.str());
// Adds one TensorDef per tensor, with sequential ids starting at 0. The loop
// header is not visible; `t` presumably iterates tensor_table_. A single
// storageMap is shared across the whole table, so shared storages are written
// once per archive.
701 void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
702 std::unordered_map<const void*, std::string> storageMap;
703 size_t tensor_id = 0;
705 auto* tensor_proto = model_def->add_tensors();
706 convertAndWriteTensor(tensor_id++, t, tensor_proto, storageMap);
// Pickles every entry of attribute_table_ and writes the pickler's output as
// the "attributes.pkl" record. The pickler is given &tensor_table_, presumably
// so tensors can be referenced by table index.
// NOTE(review): the pickler start/finish calls and the first line of the
// writeRecord call are not visible in this excerpt.
710 void ScriptModuleSerializer::writeAttributeTable() {
711 Pickler pickler(&tensor_table_);
713 for (
const IValue& ivalue : attribute_table_) {
714 pickler.addIValue(ivalue);
718 "attributes.pkl", pickler.stack().data(), pickler.stack().size());
// Converts `module` into `module_def`, writing the following in order:
//  - the name and optimize flag;
//  - parameters (their tensors go into tensor_table_);
//  - attributes (their values go into attribute_table_ and are referenced by
//    id);
//  - the module's methods, as a "code/<module_name>.py" record;
//  - each submodule, recursively, with this module's name as the prefix.
721 void ScriptModuleSerializer::convertModule(
722 const script::Module& module,
723 const std::string& prefix,
724 const std::string& name,
725 torch::ModuleDef* module_def) {
726 module_def->set_name(name);
727 module_def->set_optimize(module.is_optimized());
// The trailing `false` is convertParameter's third argument, which it stores
// as ParameterDef::is_buffer.
728 for (
const auto& elem : module.get_parameters()) {
729 torch::ParameterDef* param_def = module_def->add_parameters();
730 convertParameter(elem.value(), param_def,
false);
// Each attribute's id is its index in attribute_table_; the values are pickled
// later by writeAttributeTable.
733 for (
const auto& item : module.get_attributes()) {
734 auto& attribute = item.value();
736 torch::AttributeDef* attribute_def = module_def->add_attributes();
737 attribute_def->set_name(attribute.name_);
738 attribute_def->set_type(attribute.type->python_str());
740 attribute_table_.push_back(*attribute.slot());
741 attribute_def->set_id(attribute_table_.size() - 1);
// module_name starts with prefix + "_". The condition around it and the part
// that appends `name` are not visible here.
744 std::stringstream module_name;
746 module_name << prefix <<
"_";
// Only modules with methods get a code record, which begins with the
// op_version_set header. The lines that print the methods are not visible.
749 if (module.get_methods().size() > 0) {
750 std::ostringstream methods;
751 methods <<
"op_version_set = " << op_version_set <<
"\n";
758 torch::RecordRef* record = module_def->mutable_torchscript_arena();
760 std::stringstream filename;
761 filename <<
"code/" << module_name.str() <<
".py";
762 std::string methods_str = methods.str();
764 filename.str(), methods_str.c_str(), methods_str.size());
765 record->set_key(filename.str());
// Recurse into submodules.
768 for (
const auto& elem : module.get_modules()) {
769 torch::ModuleDef* sub_def = module_def->add_submodules();
770 convertModule(*elem->module, module_name.str(), elem.key(), sub_def);
// Fills `param_def` for one parameter: its name, an is_buffer flag, and the id
// of its tensor in tensor_table_ (obtained via addTensor).
// NOTE(review): the flag is named `is_parameter` but is stored through
// set_is_buffer(). The visible caller (convertModule) passes `false` for real
// parameters, so the stored value is right but the name reads inverted.
// Consider renaming it to `is_buffer`; its declaration line is not visible
// here.
774 void ScriptModuleSerializer::convertParameter(
775 const script::NamedIValue& param,
776 torch::ParameterDef* param_def,
778 param_def->set_name(param.name_);
779 param_def->set_is_buffer(is_parameter);
780 param_def->set_tensor_id(addTensor(param.slot()->toTensor()));
// Pretty-printer indentation: each nesting level is rendered as
// `indent_multiplier` copies of `indent_char`, i.e. two spaces per level.
constexpr char indent_char = ' ';
constexpr size_t indent_multiplier = 2;

// Returns the leading whitespace for nesting depth `indent`.
std::string idt(size_t indent) {
  const size_t width = indent_multiplier * indent;
  return std::string(width, indent_char);
}

// Returns a line break followed by the leading whitespace for depth `indent`.
std::string nlidt(size_t indent) {
  std::string line_start(1, '\n');
  line_start.append(idt(indent));
  return line_start;
}
// Prints a TensorProto as "TensorProto shape: [" followed by its dims
// separated by single spaces. The lines that close the bracket are not
// visible here. The visible lines print no tensor data.
795 void dump(
const onnx::TensorProto& tensor, std::ostream& stream) {
796 stream <<
"TensorProto shape: [";
797 for (
int i = 0; i < tensor.dims_size(); ++i) {
798 stream << tensor.dims(i) << (i == tensor.dims_size() - 1 ?
"" :
" ");
// Prints a shape's dimensions separated by single spaces. Dimensions with a
// concrete dim_value are printed as numbers. The branch for other dimensions
// (e.g. symbolic ones) is on lines not visible here.
803 void dump(
const onnx::TensorShapeProto& shape, std::ostream& stream) {
804 for (
int i = 0; i < shape.dim_size(); ++i) {
805 auto& dim = shape.dim(i);
806 if (dim.has_dim_value()) {
807 stream << dim.dim_value();
811 stream << (i == shape.dim_size() - 1 ?
"" :
" ");
815 void dump(
const onnx::TypeProto_Tensor& tensor_type, std::ostream& stream) {
816 stream <<
"Tensor dims: ";
817 dump(tensor_type.shape(), stream);
820 void dump(
const onnx::TypeProto& type, std::ostream& stream) {
821 dump(type.tensor_type(), stream);
// Prints a ValueInfoProto as {name: "<name>", type:<type dump>. The closing
// output and the end of the function are on lines not visible here.
824 void dump(
const onnx::ValueInfoProto& value_info, std::ostream& stream) {
825 stream <<
"{name: \"" << value_info.name() <<
"\", type:";
826 dump(value_info.type(), stream);
// Forward declaration. Attribute dumping recurses into graph-valued attributes
// before the GraphProto overload is defined further down.
830 void dump(
const onnx::GraphProto& graph, std::ostream& stream,
size_t indent);
// Presumably the AttributeProto overload of dump(attr, stream, indent); its
// `void dump(` line and the `indent` parameter line are not visible here.
// It prints "{ name: '<name>', type: " and then the value. The value branch is
// chosen by which field is populated (has_i/has_s/has_g/has_t, then non-empty
// repeated fields), not by attr.type(). The float branch's condition is not
// visible. Graph values are printed one indentation level deeper.
// NOTE(review): an attribute whose repeated field is empty (e.g. an INTS
// attribute with zero values) matches none of the visible branches.
833 const onnx::AttributeProto& attr,
834 std::ostream& stream,
836 stream <<
"{ name: '" << attr.name() <<
"', type: ";
838 stream <<
"float, value: " << attr.f();
839 }
else if (attr.has_i()) {
840 stream <<
"int, value: " << attr.i();
841 }
else if (attr.has_s()) {
842 stream <<
"string, value: '" << attr.s() <<
"'";
843 }
else if (attr.has_g()) {
844 stream <<
"graph, value:\n";
845 dump(attr.g(), stream, indent + 1);
846 stream << nlidt(indent);
847 }
else if (attr.has_t()) {
848 stream <<
"tensor, value:";
849 dump(attr.t(), stream);
850 }
else if (attr.floats_size()) {
851 stream <<
"floats, values: [";
852 for (
int i = 0; i < attr.floats_size(); ++i)
853 stream << attr.floats(i) << (i == attr.floats_size() - 1 ?
"" :
" ");
855 }
else if (attr.ints_size()) {
856 stream <<
"ints, values: [";
857 for (
int i = 0; i < attr.ints_size(); ++i)
858 stream << attr.ints(i) << (i == attr.ints_size() - 1 ?
"" :
" ");
860 }
else if (attr.strings_size()) {
861 stream <<
"strings, values: [";
862 for (
int i = 0; i < attr.strings_size(); ++i)
863 stream <<
"'" << attr.strings(i) <<
"'" 864 << (i == attr.strings_size() - 1 ?
"" :
" ");
866 }
else if (attr.tensors_size()) {
867 stream <<
"tensors, values: [";
868 for (
auto& t : attr.tensors()) {
872 }
else if (attr.graphs_size()) {
873 stream <<
"graphs, values: [";
874 for (
auto& g : attr.graphs()) {
875 dump(g, stream, indent + 1);
// Prints a node as Node {type: "<op_type>", inputs: [...], outputs: [...],
// attributes: [...]. Names are comma-separated and attributes are dumped one
// indentation level deeper. The closing output is on lines not visible here.
884 void dump(
const onnx::NodeProto& node, std::ostream& stream,
size_t indent) {
885 stream <<
"Node {type: \"" << node.op_type() <<
"\", inputs: [";
886 for (
int i = 0; i < node.input_size(); ++i) {
887 stream << node.input(i) << (i == node.input_size() - 1 ?
"" :
",");
889 stream <<
"], outputs: [";
890 for (
int i = 0; i < node.output_size(); ++i) {
891 stream << node.output(i) << (i == node.output_size() - 1 ?
"" :
",");
893 stream <<
"], attributes: [";
894 for (
int i = 0; i < node.attribute_size(); ++i) {
895 dump(node.attribute(i), stream, indent + 1);
896 stream << (i == node.attribute_size() - 1 ?
"" :
",");
// Prints a graph as an indented GraphProto { ... } block. It lists the name,
// inputs, outputs and initializers (comma-separated), then the nodes, one per
// line at indent + 2. The output ends with "}\n".
901 void dump(
const onnx::GraphProto& graph, std::ostream& stream,
size_t indent) {
902 stream << idt(indent) <<
"GraphProto {" << nlidt(indent + 1) <<
"name: \"" 903 << graph.name() <<
"\"" << nlidt(indent + 1) <<
"inputs: [";
904 for (
int i = 0; i < graph.input_size(); ++i) {
905 dump(graph.input(i), stream);
906 stream << (i == graph.input_size() - 1 ?
"" :
",");
908 stream <<
"]" << nlidt(indent + 1) <<
"outputs: [";
909 for (
int i = 0; i < graph.output_size(); ++i) {
910 dump(graph.output(i), stream);
911 stream << (i == graph.output_size() - 1 ?
"" :
",");
913 stream <<
"]" << nlidt(indent + 1) <<
"initializers: [";
914 for (
int i = 0; i < graph.initializer_size(); ++i) {
915 dump(graph.initializer(i), stream);
916 stream << (i == graph.initializer_size() - 1 ?
"" :
",");
918 stream <<
"]" << nlidt(indent + 1) <<
"nodes: [" << nlidt(indent + 2);
919 for (
int i = 0; i < graph.node_size(); ++i) {
920 dump(graph.node(i), stream, indent + 2);
921 if (i != graph.node_size() - 1)
922 stream <<
"," << nlidt(indent + 2);
924 stream << nlidt(indent + 1) <<
"]\n" << idt(indent) <<
"}\n";
928 const onnx::OperatorSetIdProto& operator_set_id,
929 std::ostream& stream) {
930 stream <<
"OperatorSetIdProto { domain: " << operator_set_id.domain() <<
"}";
// Prints a model as an indented ModelProto { ... } block. It shows the
// producer_name, domain and doc_string (quoted), the graph (if present, at
// indent + 2) and the opset imports.
// NOTE(review): "opset_import: [" is emitted with idt(), not nlidt(). It
// relies on the preceding graph dump ending with "}\n"; without a graph it
// lands on the doc_string line. The visible loop also prints no separator
// between entries. The lines closing the opset list are not visible here.
933 void dump(
const onnx::ModelProto& model, std::ostream& stream,
size_t indent) {
934 stream << idt(indent) <<
"ModelProto {" << nlidt(indent + 1)
935 <<
"producer_name: \"" << model.producer_name() <<
"\"" 936 << nlidt(indent + 1) <<
"domain: \"" << model.domain() <<
"\"" 937 << nlidt(indent + 1) <<
"doc_string: \"" << model.doc_string() <<
"\"";
938 if (model.has_graph()) {
939 stream << nlidt(indent + 1) <<
"graph:\n";
940 dump(model.graph(), stream, indent + 2);
942 if (model.opset_import_size()) {
943 stream << idt(indent + 1) <<
"opset_import: [";
944 for (
auto& opset_imp : model.opset_import()) {
945 dump(opset_imp, stream);
949 stream << idt(indent) <<
"}\n";
// Human-readable rendering of a ModelProto. Only the creation of the string
// stream is visible here. Presumably the rest dumps `model` into `ss` and
// returns ss.str().
952 std::string prettyPrint(
const onnx::ModelProto& model) {
953 std::stringstream ss;
// Encodes `graph` with a GraphEncoder and returns a textual rendering of the
// resulting ModelProto. With google_printer it uses protobuf's DebugString();
// otherwise it uses this file's prettyPrint(). Some GraphEncoder constructor
// arguments are on lines not visible here.
960 std::string pretty_print_onnx(
961 const std::shared_ptr<Graph>& graph,
962 const std::vector<at::Tensor>& initializers,
963 int64_t onnx_opset_version,
964 bool defer_weight_export,
965 ::torch::onnx::OperatorExportTypes operator_export_type,
966 bool google_printer) {
967 auto graph_encoder = GraphEncoder(
970 operator_export_type,
974 if (google_printer) {
975 return graph_encoder.get_model_proto().DebugString();
977 return prettyPrint(graph_encoder.get_model_proto());
// Encodes `graph` with a GraphEncoder. It returns the binary-serialized
// ModelProto together with the map of tensors whose raw data was deferred
// (populated only when weight export is deferred). Some GraphEncoder
// constructor arguments are on lines not visible here.
985 std::tuple<std::string, RawDataExportMap> export_onnx(
986 const std::shared_ptr<Graph>& graph,
987 const std::vector<at::Tensor>& initializers,
988 int64_t onnx_opset_version,
989 bool defer_weight_export,
990 ::torch::onnx::OperatorExportTypes operator_export_type) {
991 auto graph_encoder = GraphEncoder(
994 operator_export_type,
998 return std::make_tuple(
999 graph_encoder.get_model_proto().SerializeAsString(),
1000 graph_encoder.get_raw_data_export_map());
// Stream overload of the module export entry point; the function name and the
// `out` parameter are on lines not visible here. It serializes `module` plus
// `extra_files` into the stream `out`.
1004 const script::Module& module,
1006 const script::ExtraFilesMap& extra_files) {
1007 ScriptModuleSerializer serializer(&out);
1008 serializer.serialize(module, extra_files);
// File overload of the module export entry point; the function name is on a
// line not visible here. It serializes `module` plus `extra_files` into the
// file at `filename`.
1012 const script::Module& module,
1013 const std::string& filename,
1014 const script::ExtraFilesMap& extra_files) {
1015 ScriptModuleSerializer serializer(filename);
1016 serializer.serialize(module, extra_files);
TensorOptions options() const
Returns the TensorOptions corresponding to this Tensor.
Device device() const
Returns a Tensor's device.
An ordered dictionary implementation, akin to Python's OrderedDict.