2 The torch.onnx module contains functions to export models into the ONNX 3 IR format. These models can be loaded with the ONNX library and then 4 converted to models which run on other deep learning frameworks. 21 from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
# Context manager that temporarily sets the training mode of `model` to `mode`
# and resets it when the with-block exits (per the docstring fragment below).
# The mode in effect on entry is read from `model.training`.
# NOTE(review): most of this definition's body is not visible in this view, so
# the no-op condition mentioned in the docstring and the restore logic cannot
# be confirmed from here.
25 @contextlib.contextmanager
26 def set_training(model, mode):
28 A context manager to temporarily set the training mode of 'model' 29 to 'mode', resetting it when we exit the with-block. A no-op if 35 old_mode = model.training
def export(model, args, f, export_params=True, verbose=False, training=False,
           input_names=None, output_names=None, aten=False, export_raw_ir=False,
           operator_export_type=None, opset_version=None, _retain_param_name=True):
    """
    Export a model into ONNX format.  This exporter runs your model
    once in order to get a trace of its execution to be exported;
    at the moment, it supports a limited set of dynamic models (e.g., RNNs.)

    See also: :ref:`onnx-export`

    Arguments:
        model (torch.nn.Module): the model to be exported.
        args (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model.  Any non-Tensor arguments will
            be hard-coded into the exported model; any Tensor arguments
            will become inputs of the exported model, in the order they
            occur in args.  If args is a Tensor, this is equivalent
            to having called it with a 1-ary tuple of that Tensor.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)
        f: a file-like object (has to implement fileno that returns a file
            descriptor) or a string containing a file name.  A binary
            Protobuf will be written to this file.
        export_params (bool, default True): if specified, all parameters will
            be exported.  Set this to False if you want to export an untrained
            model.  In this case, the exported model will first take all of
            its parameters as arguments, the ordering as specified by
            ``model.state_dict().values()``
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.
        training (bool, default False): export the model in training mode.  At
            the moment, ONNX is oriented towards exporting models for inference
            only, so you will generally not need to set this to True.
        input_names (list of strings, default empty list): names to assign to
            the input nodes of the graph, in order
        output_names (list of strings, default empty list): names to assign to
            the output nodes of the graph, in order
        aten (bool, default False): [DEPRECATED. use operator_export_type]
            export the model in aten mode.  If using aten mode, all the ops
            originally exported by the functions in symbolic.py are exported
            as ATen ops.
        export_raw_ir (bool, default False): [DEPRECATED. use
            operator_export_type] export the internal IR directly instead of
            converting it to ONNX ops.
        operator_export_type (enum, default OperatorExportTypes.ONNX):
            OperatorExportTypes.ONNX: all ops are exported as regular ONNX ops.
            OperatorExportTypes.ONNX_ATEN: all ops are exported as ATen ops.
            OperatorExportTypes.ONNX_ATEN_FALLBACK: if symbolic is missing,
                                                    fall back on ATen op.
            OperatorExportTypes.RAW: export raw ir.
            When left as None and neither deprecated flag is set, the default
            is ONNX_ATEN_FALLBACK if ``torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE``
            is truthy and ONNX otherwise.
        opset_version (int, default is 9): by default we export the model to the
            opset version of the onnx submodule.  Since ONNX's latest opset may
            evolve before next stable release, by default we export to one
            stable opset version.  Right now, supported stable opset version
            is 9.  The opset_version must be _onnx_master_opset or in
            _onnx_stable_opsets which are defined in torch/onnx/symbolic.py
        _retain_param_name (bool, default True): forwarded unchanged to
            ``_export``.

    Raises:
        AssertionError: if a deprecated flag (``aten`` / ``export_raw_ir``) is
            combined with an explicit ``operator_export_type``, or if both
            deprecated flags are set at once.
    """
    if aten or export_raw_ir:
        # The deprecated boolean flags are mutually exclusive with each other
        # and with the newer ``operator_export_type`` argument.
        assert operator_export_type is None
        assert aten ^ export_raw_ir
        # ATen mode is spelled ``ONNX_ATEN`` in OperatorExportTypes (the member
        # documented above and compared against during symbolic translation);
        # there is no plain ``ATEN`` member, so the old spelling made
        # ``aten=True`` fail with an AttributeError.
        operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
    elif operator_export_type is None:
        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
            operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = OperatorExportTypes.ONNX
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
            operator_export_type=operator_export_type, opset_version=opset_version,
            _retain_param_name=_retain_param_name)
# Recursively walk `block` (descending into every sub-block of every node) and
# rewrite each `prim::Constant` node whose output type is a subtype of
# "list of tensors": one `prim::Constant` node per tensor in the node's
# 'value' attribute is created in graph `g` and inserted before the original
# node, a `prim::ListConstruct` over those constants is built and typed as
# ListType.ofTensors(), and every use of the original output is redirected to
# the new list value.
# NOTE(review): part of the ListConstruct call chain is not visible in this
# view, so where that node is inserted cannot be confirmed from here.
116 def _split_tensor_list_constants(g, block):
117 for node
in block.nodes():
118 for subblock
in node.blocks():
119 _split_tensor_list_constants(g, subblock)
120 if node.kind() ==
"prim::Constant":
121 output_type = node.output().type()
122 if output_type.isSubtypeOf(ListType.ofTensors()):
123 inputs = [g.create(
"prim::Constant").t_(
'value', t)
124 .insertBefore(node).output()
125 for t
in node[
'value']]
126 lc = (g.create(
"prim::ListConstruct", inputs)
129 .setType(ListType.ofTensors()))
130 node.output().replaceAllUsesWith(lc)
# Run a fixed sequence of torch._C JIT passes over `graph` to prepare it for
# export, linting the graph between stages: inline fork/wait, dead-code
# elimination, in-place op removal, constant propagation, splitting of
# tensor-list constants, op canonicalization, peephole, division preparation
# for ONNX, number-type erasure and tuple lowering.
# When `operator_export_type` is not OperatorExportTypes.RAW the graph is
# additionally converted by `_jit_pass_onnx` (which returns a new graph that
# is rebound to `graph`). ONNX peephole, DCE, ONNX loop fixup and a final
# `_jit_pass_canonicalize` (also rebinding `graph`) are visible after that.
# NOTE(review): indentation is lost in this view, so which of the trailing
# passes sit inside the `!= RAW` conditional cannot be determined here; the
# return statement is not visible either (callers use the result as a graph).
133 def _optimize_graph(graph, operator_export_type):
135 torch._C._jit_pass_inline_fork_wait(graph)
136 torch._C._jit_pass_dce(graph)
137 torch._C._jit_pass_lint(graph)
139 torch._C._jit_pass_remove_inplace_ops(graph)
144 torch._C._jit_pass_constant_propagation(graph)
145 _split_tensor_list_constants(graph, graph)
148 torch._C._jit_pass_dce(graph)
149 torch._C._jit_pass_lint(graph)
151 torch._C._jit_pass_canonicalize_ops(graph)
152 torch._C._jit_pass_lint(graph)
154 torch._C._jit_pass_peephole(graph,
True)
155 torch._C._jit_pass_lint(graph)
158 torch._C._jit_pass_prepare_division_for_onnx(graph)
160 torch._C._jit_pass_erase_number_types(graph)
162 torch._C._jit_pass_lower_all_tuples(graph)
163 torch._C._jit_pass_peephole(graph,
True)
164 torch._C._jit_pass_lint(graph)
166 if operator_export_type != OperatorExportTypes.RAW:
167 graph = torch._C._jit_pass_onnx(graph, operator_export_type)
168 torch._C._jit_pass_lint(graph)
169 torch._C._jit_pass_onnx_peephole(graph)
170 torch._C._jit_pass_lint(graph)
171 torch._C._jit_pass_dce(graph)
172 torch._C._jit_pass_lint(graph)
173 torch._C._jit_pass_fixup_onnx_loops(graph)
174 torch._C._jit_pass_lint(graph)
175 graph = torch._C._jit_pass_canonicalize(graph)
176 torch._C._jit_pass_lint(graph)
# Trace `func` on `args` and return the pair `(trace, torch_out)`.
# Before returning, the trace's graph is replaced by the result of
# `_optimize_graph(trace.graph(), operator_export_type)`.
# A lone torch.Tensor passed as `args` is special-cased (the handling itself
# is not visible here).
# NOTE(review): the statement that produces `trace` / `torch_out` and the use
# of `return_outs` are not visible in this view.
180 def _trace(func, args, operator_export_type, return_outs=False):
182 if isinstance(args, torch.Tensor):
186 trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))
188 return trace, torch_out
# Trace `model` on `args` with its training mode temporarily set to `training`
# (via the `set_training` context manager) and return
# `(trace.graph(), torch_out)`.
# The keys of `_unique_state_dict(model)` are captured before tracing and
# compared afterwards; a RuntimeError is raised if they differ, since tracing
# is not expected to change the model's state_dict.
# NOTE(review): the tracing call inside the with-block is not visible in this
# view.
192 def _trace_and_get_graph_from_model(model, args, training):
196 orig_state_dict_keys = _unique_state_dict(model).keys()
203 with set_training(model, training):
206 if orig_state_dict_keys != _unique_state_dict(model).keys():
207 raise RuntimeError(
"state_dict changed after running the tracer; " 208 "something weird is happening in your model!")
210 return trace.graph(), torch_out
# Build the optimized export graph for `model` and return
# `(graph, params, torch_out)`.
#
# Two ways of obtaining the graph are visible:
#   * Script path: `example_outputs` must be provided (the assertion message
#     says this applies when exporting a ScriptModule); a single Tensor is
#     wrapped in a list. The model's 'forward' method is asked to propagate
#     input/output shapes, and `params` come from `method.initial_ivalues()`.
#     An AttributeError is turned into
#     RuntimeError("'forward' method must be a script method").
#   * Tracing path: the graph and `torch_out` come from
#     `_trace_and_get_graph_from_model`, and `params` are the values of
#     `_unique_state_dict(model)`. When `_retain_param_name` is true, the
#     trailing graph inputs (those after the first
#     `len(graph_inputs) - len(state_dict)` user inputs) are renamed to the
#     corresponding state_dict keys.
# Afterwards the graph goes through `_optimize_graph`; if `torch_out` is not
# None, each graph output's type is inferred from the matching flattened
# output tensor; finally input/output names are applied.
# NOTE(review): the conditions selecting between the two paths, the `try:`
# line, and any use of `f` / `verbose` are not visible in this view.
213 def _model_to_graph(model, args, f, verbose=False, training=False,
214 input_names=
None, output_names=
None,
215 operator_export_type=OperatorExportTypes.ONNX,
216 example_outputs=
None, propagate=
False,
217 _retain_param_name=
False):
219 if isinstance(args, torch.Tensor):
224 assert example_outputs
is not None,
"example_outputs must be provided when exporting a ScriptModule" 225 if isinstance(example_outputs, torch.Tensor):
226 example_outputs = [example_outputs]
228 method = model.__getattr__(
'forward')
229 graph = method.propagate_and_assign_input_and_output_shapes(
230 args, example_outputs,
False, propagate)
232 params = method.initial_ivalues()
233 except AttributeError:
235 raise RuntimeError(
'\'forward\' method must be a script method')
237 graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
238 state_dict = _unique_state_dict(model)
239 params = list(state_dict.values())
240 if _retain_param_name:
241 graph_inputs = list(graph.inputs())
242 user_input_num = len(graph_inputs) - len(state_dict)
243 param_names = list(state_dict.keys())
244 for i, inp
in enumerate(graph_inputs):
245 if i >= user_input_num:
246 inp.setUniqueName(param_names[i - user_input_num])
248 graph = _optimize_graph(graph, operator_export_type)
252 if torch_out
is not None:
253 output_tensors, _ = torch._C._jit_flatten(torch_out)
254 for output, tensor
in zip(graph.outputs(), output_tensors):
255 output.inferTypeFrom(tensor)
257 _set_input_and_output_names(graph, input_names, output_names)
261 return graph, params, torch_out
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
                            input_names=None, output_names=None, aten=False, export_raw_ir=False,
                            operator_export_type=None, export_type=ExportTypes.PROTOBUF_FILE,
                            example_outputs=None, propagate=False, google_printer=False,
                            opset_version=None, _retain_param_name=True):
    """
    Resolve the operator export mode and delegate to
    ``_export_to_pretty_string``, returning whatever it returns.

    The deprecated boolean flags ``aten`` and ``export_raw_ir`` are translated
    into an ``OperatorExportTypes`` value; every other argument is forwarded
    positionally, unchanged.

    Arguments:
        aten (bool, default False): [DEPRECATED. use operator_export_type]
            select ``OperatorExportTypes.ONNX_ATEN``.
        export_raw_ir (bool, default False): [DEPRECATED. use
            operator_export_type] select ``OperatorExportTypes.RAW``.
        operator_export_type (enum, default None): explicit export mode.
            When None and neither deprecated flag is set,
            ``OperatorExportTypes.ONNX`` is used.
        model, args, f, export_params, verbose, training, input_names,
        output_names, export_type, example_outputs, propagate,
        google_printer, opset_version, _retain_param_name: passed through to
            ``_export_to_pretty_string``.

    Raises:
        AssertionError: if a deprecated flag is combined with an explicit
            ``operator_export_type``, or if both deprecated flags are set.
    """
    if aten or export_raw_ir:
        # The deprecated flags are mutually exclusive with each other and
        # with the newer ``operator_export_type`` argument.
        assert operator_export_type is None
        assert aten ^ export_raw_ir
        # ATen mode is the ``ONNX_ATEN`` member of OperatorExportTypes (the
        # one compared against during symbolic translation); there is no
        # plain ``ATEN`` member, so the old spelling made ``aten=True`` fail
        # with an AttributeError.
        operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
    elif operator_export_type is None:
        operator_export_type = OperatorExportTypes.ONNX
    return _export_to_pretty_string(model, args, f, export_params, verbose, training,
                                    input_names, output_names, operator_export_type,
                                    export_type, example_outputs, propagate, google_printer,
                                    opset_version, _retain_param_name)
# Build the export graph for `model` and return its pretty-printed ONNX form.
# `opset_version` falls back to `_default_onnx_opset_version` when None and is
# then registered through `_set_opset_version`. The graph and parameters come
# from `_model_to_graph`, and the result is
# `graph._pretty_print_onnx(params, opset_version, False,
# operator_export_type, google_printer)`.
# NOTE(review): `export_params` and `export_type` are accepted but not used in
# the lines visible here; where `_default_onnx_opset_version` and
# `_set_opset_version` are brought into scope is also not visible.
281 def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
282 input_names=
None, output_names=
None, operator_export_type=OperatorExportTypes.ONNX,
283 export_type=ExportTypes.PROTOBUF_FILE, example_outputs=
None, propagate=
False,
284 google_printer=
False, opset_version=
None, _retain_param_name=
False):
286 if opset_version
is None:
287 opset_version = _default_onnx_opset_version
288 _set_opset_version(opset_version)
289 graph, params, torch_out = _model_to_graph(model, args, f, verbose,
290 training, input_names,
291 output_names, operator_export_type,
292 example_outputs, propagate, _retain_param_name)
294 return graph._pretty_print_onnx(params, opset_version,
False, operator_export_type, google_printer)
# Core export routine: builds the graph for `model`, serializes it with
# `graph._export_onnx`, and writes the result according to `export_type`.
#
# Visible steps:
#   * `opset_version` defaults to `_default_onnx_opset_version` and is
#     registered via `_set_opset_version`.
#   * `_model_to_graph` yields `(graph, params, torch_out)`.
#   * Weight export is deferred (`defer_weight_export`) for every export type
#     other than ExportTypes.PROTOBUF_FILE.
#   * Two alternative `_export_onnx` calls appear: one passing `params` and
#     the deferral flag, one passing an empty parameter list and False.
#   * PROTOBUF_FILE: `export_map` is asserted to be empty.
#   * ZIP_ARCHIVE / COMPRESSED_ZIP_ARCHIVE: a zip file is opened on `f`
#     (deflated only for the compressed variant) and the model proto is
#     stored under ONNX_ARCHIVE_MODEL_PROTO_NAME; `export_map` is iterated.
#   * DIRECTORY: if `f` exists it is asserted to be a directory; the model
#     proto and each `export_map` entry are written in "wb" mode to files
#     joined onto `f`.
#   * Any other export type raises RuntimeError('Unknown export type').
# NOTE(review): many lines are not visible in this view — the condition that
# chooses between the two `_export_onnx` calls (presumably `export_params`;
# confirm), the tail of the `_model_to_graph` call, the file-writing calls
# that the "wb" lambdas are passed to, the body of the zip loop, directory
# creation, and the return value.
301 def _export(model, args, f, export_params=True, verbose=False, training=False,
302 input_names=
None, output_names=
None, operator_export_type=OperatorExportTypes.ONNX,
303 export_type=ExportTypes.PROTOBUF_FILE, example_outputs=
None, propagate=
False,
304 opset_version=
None, _retain_param_name=
False):
306 if opset_version
is None:
307 opset_version = _default_onnx_opset_version
308 _set_opset_version(opset_version)
309 graph, params, torch_out = _model_to_graph(model, args, f, verbose,
310 training, input_names,
311 output_names, operator_export_type,
312 example_outputs, propagate,
316 defer_weight_export = export_type
is not ExportTypes.PROTOBUF_FILE
318 proto, export_map = graph._export_onnx(params, opset_version, defer_weight_export, operator_export_type)
320 proto, export_map = graph._export_onnx([], opset_version,
False, operator_export_type)
322 if export_type == ExportTypes.PROTOBUF_FILE:
323 assert(len(export_map) == 0)
325 elif export_type
in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
327 compression = zipfile.ZIP_DEFLATED \
328 if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
329 else zipfile.ZIP_STORED
330 with zipfile.ZipFile(f,
'w', compression=compression)
as z:
331 z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
332 for k, v
in export_map.items():
334 elif export_type == ExportTypes.DIRECTORY:
336 if os.path.exists(f):
337 assert(os.path.isdir(f))
341 model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
343 model_proto_file,
"wb",
lambda f: f.write(proto))
345 for k, v
in export_map.items():
346 weight_proto_file = os.path.join(f, k)
348 weight_proto_file,
"wb",
lambda f: f.write(v))
350 raise RuntimeError(
'Unknown export type')
# Apply user-supplied names to the inputs and outputs of `graph`.
# The nested helper `set_names` pairs `name_list` with `node_list` in order
# and calls `setUniqueName(name)` on each node whose current `uniqueName()`
# differs. Because the pairing uses `zip`, nodes beyond the end of a shorter
# name list are left untouched.
# A `name_list` of None is special-cased, and supplying more names than there
# are nodes produces the "number of %s names provided ... exceeded" message.
# NOTE(review): the statement taken for a None `name_list` and the call that
# raises/emits the "exceeded" message are not visible in this view.
354 def _set_input_and_output_names(graph, input_names, output_names):
355 def set_names(node_list, name_list, descriptor):
356 if name_list
is None:
358 if len(name_list) > len(node_list):
360 "number of %s names provided (%d) exceeded number of %ss (%d)" 361 % (descriptor, len(name_list), descriptor, len(node_list)))
362 for name, node
in zip(name_list, node_list):
363 if node.uniqueName() != name:
364 node.setUniqueName(name)
365 set_names(list(graph.inputs()), input_names,
'input')
366 set_names(list(graph.outputs()), output_names,
'output')
# Attribute keys are spelled "<name>_<kind>": group 1 captures the attribute
# name (greedy, so it may itself contain underscores) and group 2 the trailing
# single-letter kind, one of i, f, s, t, g or z.
attr_pattern = re.compile("^(.+)_([ifstgz])$")
# Trampoline that invokes `symbolic_fn(*args)` for a symbolic method and
# returns its result.
# A TypeError raised by the call is caught and its message is rewritten to
# append "(occurred when translating <op_name>)", so the failure identifies
# the operator being translated.
# NOTE(review): the `try:` line and whatever follows the `e.args` assignment
# (presumably a re-raise — confirm) are not visible in this view.
371 def _run_symbolic_method(op_name, symbolic_fn, args):
373 This trampoline function gets invoked for every symbolic method 377 return symbolic_fn(*args)
378 except TypeError
as e:
382 e.args = (
"{} (occurred when translating {})".format(e.args[0], op_name), )
# Predicate used when setting node attributes: tests whether `value` is an
# iterable (`container_abcs.Iterable`) that is neither a string
# (`string_classes`) nor a torch.Tensor.
# NOTE(review): the statements executed for each outcome of the test are not
# visible in this view; presumably it returns True/False — confirm.
386 def _is_onnx_list(value):
387 if not isinstance(value, string_classes)
and \
388 not isinstance(value, torch.Tensor)
and \
389 isinstance(value, container_abcs.Iterable):
# Set attribute `key` = `value` on `node`, picking the setter from the type
# suffix encoded in `key`.
# `key` is matched against `attr_pattern` and split into `name` (group 1) and
# `kind` (group 2); the attribute is finally set by calling the node method
# named `kind + "_"` with `(name, value)`, whose result is returned.
# Visible special cases along the way:
#   * `_is_onnx_list(value)` is consulted (its effect is not visible here).
#   * A torch.Tensor with more than one element is rejected with
#     ValueError("Should not pass tensor attribute"); otherwise the value is
#     converted with `_scalar`.
#   * A float value is special-cased (handling not visible).
# NOTE(review): the raise for a non-matching key, the list/float handling and
# any use of the `aten` argument are not visible in this view.
394 def _add_attribute(node, key, value, aten):
395 r""" initializes the right attribute based on type of value """ 396 m = attr_pattern.match(key)
399 "Invalid attribute specifier '{}' names " +
400 " must be suffixed with type, e.g. 'dim_i' or 'dims_i'").format(key))
401 name, kind = m.group(1), m.group(2)
402 if _is_onnx_list(value):
405 if isinstance(value, torch.Tensor):
407 if value.numel() > 1:
408 raise ValueError(
"Should not pass tensor attribute")
409 value = _scalar(value)
410 if isinstance(value, float):
414 return getattr(node, kind +
"_")(name, value)
# Fragment of a helper that converts a one-element tensor `x` into a Python
# value; it asserts `x.numel() == 1` first.
# NOTE(review): the `def` line and the conversion/return statement are not
# visible in this view; presumably this is the `_scalar` helper called by
# `_add_attribute` — confirm.
418 """Convert a scalar tensor into a Python value.""" 419 assert x.numel() == 1
# Create (but do not insert) a node in graph `g` for operator `opname`.
# The optional `aten` keyword (default False) is popped from `kwargs` and
# selects the namespace: the node kind is "aten::<opname>" when true and
# "onnx::<opname>" otherwise. The node is created with `args` as inputs and
# `outputs` outputs; the remaining kwargs are then applied as attributes via
# `_add_attribute`, iterating in sorted key order so the result is
# deterministic.
# NOTE(review): lines between the loop header and the `_add_attribute` call,
# and the return statement, are not visible in this view.
423 def _newNode(g, opname, outputs, *args, **kwargs):
428 aten = kwargs.pop(
"aten",
False)
429 ns =
"aten" if aten
else "onnx" 430 ns_opname = ns +
"::" + opname
431 n = g.create(ns_opname, args, outputs)
432 for k, v
in sorted(kwargs.items()):
436 _add_attribute(n, k, v, aten=aten)
# Implementation of `Graph.op` (per the docstring below, it is monkey-patched
# onto Graph): create operator `opname` in graph `g`, insert it, and hand back
# its output(s).
# Visible steps: the `outputs` keyword (default 1) is popped from `kwargs`;
# kwargs whose value is None are dropped; each raw argument is passed through
# the nested `const_if_tensor` helper, which can wrap an argument in a
# "Constant" op (`value_z=arg`); the node is built with `_newNode` and
# inserted with `g.insertNode`; a tuple of the node's outputs is returned in
# one of the visible return paths.
# NOTE(review): the branches of `const_if_tensor`, the single-output return
# path and the statements around the insertion are not visible in this view.
440 def _graph_op(g, opname, *raw_args, **kwargs):
442 Create an ONNX operator 'opname', taking 'args' as inputs and attributes 443 'kwargs'; returning the node representing the single output of this operator 444 (see the `outputs` keyword argument for multi-return nodes). 446 The set of operators and the inputs/attributes they take 447 is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md 449 This function is monkey-patched onto Graph. 452 opname (string): The ONNX operator name, e.g., `Abs` or `Add`. 453 args (Node...): The inputs to the operator; usually provided 454 as arguments to the `symbolic` definition. 455 kwargs: The attributes of the ONNX operator, with keys named 456 according to the following convention: `alpha_f` indicates 457 the `alpha` attribute with type `f`. The valid type specifiers are 458 `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute 459 specified with type float accepts either a single float, or a 460 list of floats (e.g., you would say `dims_i` for a `dims` attribute 461 that takes a list of integers). 462 outputs (int, optional): The number of outputs this operator returns; 463 by default an operator is assumed to return a single output. 464 If `outputs` is greater than one, this functions returns a tuple 465 of output `Node`, representing each output of the ONNX operator 468 outputs = kwargs.pop(
'outputs', 1)
472 kwargs = dict((k, v)
for k, v
in kwargs.items()
if v
is not None)
474 def const_if_tensor(arg):
477 elif isinstance(arg, torch._C.Value):
480 return g.op(
"Constant", value_z=arg)
482 args = list(const_if_tensor(arg)
for arg
in raw_args)
483 n = g.insertNode(_newNode(g, opname, outputs, *args, **kwargs))
486 return tuple(o
for o
in n.outputs())
# Translate a single node `n` (with already-translated `inputs`) into graph
# `g`, dispatching on the node kind "<namespace>::<op_name>". A trailing
# underscore on the kind is stripped before the kind is split, so the
# underscore-suffixed form resolves to the same op name.
#
# Dispatch visible in this view:
#   * ATen-style export: when `operator_export_type` is ONNX_ATEN, or when it
#     is ONNX_ATEN_FALLBACK and the op is not an exportable ATen op, the node
#     is emitted through `_graph_at` with `aten=True`. Attribute keys are
#     rebuilt as "<name>_<first letter of kind>" and the node's output count
#     is forwarded as `outputs`.
#   * Otherwise the node's attributes are passed by name to a symbolic
#     function `fn(g, *inputs, **attrs)`; a warning is issued when no symbolic
#     exists for the ATen operator.
#   * "Constant" nodes that are not None: a tensor value ("t") is re-emitted
#     as an ONNX Constant; an int list ("is") is stacked into a tensor first
#     (an empty list is passed through as []); a "DeviceObjType" output is
#     special-cased; other kinds raise RuntimeError("Unsupported
#     prim::Constant kind ...").
#   * None constants, "ListConstruct" and "ListUnpack" share one branch
#     (handling not visible).
#   * "Loop" / "If": a new op with the same number of outputs is created and
#     sub-blocks are converted with `torch._C._jit_pass_onnx_block`.
#   * Remaining ops are looked up under the name 'prim_' + op_name; a warning
#     is issued when no symbolic function is found.
#   * An unrecognized namespace produces a warning.
#   * A TypeError has "(occurred when translating <op_name>)" appended to its
#     message.
# NOTE(review): the namespace tests, the lookups that define
# `is_exportable_aten_op`, `fn` and `symbolic_fn`, the loop over the node's
# blocks, and what follows each warning are not visible in this view.
498 def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
506 if n.kind().endswith(
'_'):
507 ns_op_name = n.kind()[:-1]
509 ns_op_name = n.kind()
510 ns, op_name = ns_op_name.split(
"::")
518 is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
519 is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
520 if is_onnx_aten_export
or (
not is_exportable_aten_op
and is_aten_fallback_export):
522 attrs = {k +
"_" + n.kindOf(k)[0]: n[k]
for k
in n.attributeNames()}
523 outputs = n.outputsSize()
524 attrs[
"outputs"] = outputs
525 return _graph_at(g, op_name, *inputs, aten=
True, **attrs)
529 attrs = {k: n[k]
for k
in n.attributeNames()}
530 if not is_exportable_aten_op:
531 warnings.warn(
"ONNX export failed on ATen operator {} because torch.onnx.symbolic.{} does not exist" 532 .format(op_name, op_name))
535 return fn(g, *inputs, **attrs)
538 if op_name ==
"Constant" and not n.mustBeNone():
539 if n.kindOf(
"value") ==
"t":
540 return g.op(
"Constant", value_t=n[
"value"])
541 elif n.kindOf(
"value") ==
"is":
542 value = torch.stack([
torch.tensor(v)
for v
in n[
"value"]])
if n[
"value"]
else []
543 return g.op(
"Constant", value_t=value)
544 elif n.output().type().kind() ==
"DeviceObjType":
547 raise RuntimeError(
"Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
549 elif n.mustBeNone()
or op_name ==
"ListConstruct" or op_name ==
"ListUnpack":
555 elif op_name ==
'Loop' or op_name ==
'If':
556 new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
557 new_node = new_op_outputs[0].node()
if n.outputsSize() > 1
else new_op_outputs.node()
559 new_block = new_node.addBlock()
560 torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
561 return new_op_outputs
563 symbolic_name =
'prim_' + op_name
565 if symbolic_fn
is None:
566 warnings.warn(
"ONNX export failed on primitive operator {}; please report a bug".format(op_name))
568 attrs = {k: n[k]
for k
in n.attributeNames()}
569 return symbolic_fn(g, *inputs, **attrs)
572 warnings.warn(
"ONNX export failed on an operator with unrecognized namespace {}::{}; " 573 "please report a bug".format(ns, op_name))
576 except TypeError
as e:
579 e.args = (
"{} (occurred when translating {})".format(e.args[0], op_name), )
584 def _graph_at(g, opname, *args, **kwargs):
585 return g.op(
"ATen", *args, operator_s=opname, **kwargs)
# Implementation of `Graph.constant`: emit a "Constant" op on graph `g` whose
# tensor value is built from `dims` with an element type named by `type`.
# `value` must be a `numbers.Number` and `type` must not be None (both are
# asserted). `dims` of None, 0, or containing only zeros is special-cased.
# The tensor is constructed with the torch tensor class matching `type`
# (CharTensor, ShortTensor, IntTensor, LongTensor, HalfTensor, FloatTensor or
# DoubleTensor, called as `Tensor(*dims)`); the accepted strings listed in the
# error message are char, short, int, long, half, float and double, and any
# other string raises ValueError.
# Two return paths are visible: the tensor is attached either as `value_z` or
# as `value_t`.
# NOTE(review): the handling of the special-cased `dims`, how `value` is
# written into the tensor, and the condition choosing between `value_z` and
# `value_t` are not visible in this view.
593 def _graph_constant(g, value, dims, type, *args, **kwargs):
594 assert isinstance(value, numbers.Number)
595 assert type
is not None 597 if dims
is None or dims == 0
or set(dims) == set([0]):
602 tensor = torch.CharTensor(*dims)
603 elif type ==
"short":
604 tensor = torch.ShortTensor(*dims)
606 tensor = torch.IntTensor(*dims)
608 tensor = torch.LongTensor(*dims)
610 tensor = torch.HalfTensor(*dims)
611 elif type ==
"float":
612 tensor = torch.FloatTensor(*dims)
613 elif type ==
"double":
614 tensor = torch.DoubleTensor(*dims)
616 raise ValueError(
"Unknown type, type should be one of the following strings: " 617 "char, short, int, long, half, float, double")
620 return g.op(
"Constant", *args, value_z=tensor, **kwargs)
621 return g.op(
"Constant", *args, value_t=tensor, **kwargs)
# Implementation of `Node.__getitem__` (per the docstring below, it is
# monkey-patched onto Node): returns the attribute named `k` by calling the
# accessor method whose name is held in `sel`, i.e. `getattr(self, sel)(k)`.
# NOTE(review): the line computing `sel` is not visible in this view;
# presumably it is derived from the attribute's kind — confirm.
624 def _node_getitem(self, k):
626 Accessor for attributes of a node which is polymorphic over 629 NB: This is monkey-patched onto Node. 632 return getattr(self, sel)(k)
# Attach the Python helpers defined in this module to the JIT classes:
# node indexing (`node[attr_name]`) on Node, and the `op` / `at` / `constant`
# builders on Graph.
torch._C.Node.__getitem__ = _node_getitem
torch._C.Graph.op = _graph_op
torch._C.Graph.at = _graph_at
torch._C.Graph.constant = _graph_constant
def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, return_inputs=False)
def _with_file_like(f, mode, body)