Caffe2 - Python API
A deep learning, cross-platform ML framework
utils.py
1 r"""
2 The torch.onnx module contains functions to export models into the ONNX
3 IR format. These models can be loaded with the ONNX library and then
4 converted to models which run on other deep learning frameworks.
5 """
6 
7 import torch
8 import torch.jit
9 import torch.autograd
11 import re
12 from torch._six import container_abcs
13 import contextlib
14 import numbers
15 import warnings
16 import functools
17 import types
18 from torch._six import string_classes
19 from torch.autograd import Function, function
20 from torch.jit import _unique_state_dict
21 from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
22 from torch._C import ListType
23 
24 
@contextlib.contextmanager
def set_training(model, mode):
    r"""
    A context manager to temporarily set the training mode of 'model'
    to 'mode', resetting it when we exit the with-block. A no-op if
    mode is None.
    """
    if mode is None:
        yield
        return
    old_mode = model.training
    if old_mode != mode:
        model.train(mode)
    try:
        yield
    finally:
        if old_mode != mode:
            model.train(old_mode)

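# Example (illustrative; not part of the original module): temporarily
# switching a model into eval mode, with the previous mode restored on exit.
# Any nn.Module works here; resnet18 is just an assumed stand-in.
#
#     model = torchvision.models.resnet18()
#     model.train()
#     with set_training(model, False):
#         assert not model.training    # eval mode inside the block
#     assert model.training            # original mode restored
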
def export(model, args, f, export_params=True, verbose=False, training=False,
           input_names=None, output_names=None, aten=False, export_raw_ir=False,
           operator_export_type=None, opset_version=None, _retain_param_name=True):
    r"""
    Export a model into ONNX format.  This exporter runs your model
    once in order to get a trace of its execution to be exported;
    at the moment, it supports a limited set of dynamic models (e.g., RNNs.)

    See also: :ref:`onnx-export`

    Arguments:
        model (torch.nn.Module): the model to be exported.
        args (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model.  Any non-Tensor arguments will
            be hard-coded into the exported model; any Tensor arguments
            will become inputs of the exported model, in the order they
            occur in args.  If args is a Tensor, this is equivalent
            to having called it with a 1-ary tuple of that Tensor.
            (Note: passing keyword arguments to the model is not currently
            supported.  Give us a shout if you need it.)
        f: a file-like object (has to implement fileno that returns a file descriptor)
            or a string containing a file name.  A binary Protobuf will be written
            to this file.
        export_params (bool, default True): if specified, all parameters will
            be exported.  Set this to False if you want to export an untrained model.
            In this case, the exported model will first take all of its parameters
            as arguments, with the ordering as specified by ``model.state_dict().values()``
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.
        training (bool, default False): export the model in training mode.  At
            the moment, ONNX is oriented towards exporting models for inference
            only, so you will generally not need to set this to True.
        input_names (list of strings, default empty list): names to assign to the
            input nodes of the graph, in order
        output_names (list of strings, default empty list): names to assign to the
            output nodes of the graph, in order
        aten (bool, default False): [DEPRECATED. use operator_export_type] export the
            model in aten mode. If using aten mode, all the ops originally exported
            by the functions in symbolic.py are exported as ATen ops.
        export_raw_ir (bool, default False): [DEPRECATED. use operator_export_type]
            export the internal IR directly instead of converting it to ONNX ops.
        operator_export_type (enum, default OperatorExportTypes.ONNX):
            OperatorExportTypes.ONNX: all ops are exported as regular ONNX ops.
            OperatorExportTypes.ONNX_ATEN: all ops are exported as ATen ops.
            OperatorExportTypes.ONNX_ATEN_FALLBACK: if a symbolic is missing,
                fall back on the ATen op.
            OperatorExportTypes.RAW: export the raw IR.
        opset_version (int, default is 9): by default we export the model to the
            opset version of the onnx submodule. Since ONNX's latest opset may
            evolve before the next stable release, by default we export to one stable
            opset version. Right now, the supported stable opset version is 9.
            The opset_version must be _onnx_master_opset or in _onnx_stable_opsets,
            which are defined in torch/onnx/symbolic.py.
    """
    if aten or export_raw_ir:
        assert operator_export_type is None
        assert aten ^ export_raw_ir
        operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
    elif operator_export_type is None:
        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
            operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = OperatorExportTypes.ONNX
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
            operator_export_type=operator_export_type, opset_version=opset_version,
            _retain_param_name=_retain_param_name)

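# Example (a minimal sketch, mirroring the torch.onnx documentation of this
# era; assumes torchvision is installed and uses an illustrative file name):
#
#     import torch.onnx
#     import torchvision
#
#     dummy_input = torch.randn(1, 3, 224, 224)
#     model = torchvision.models.alexnet(pretrained=True)
#     torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True,
#                       input_names=["input"], output_names=["output"])
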
# ONNX can't handle constants that are lists of tensors, which can
# get generated in constant prop.  So we split them back into prim::ListConstructs
def _split_tensor_list_constants(g, block):
    for node in block.nodes():
        for subblock in node.blocks():
            _split_tensor_list_constants(g, subblock)
        if node.kind() == "prim::Constant":
            output_type = node.output().type()
            if output_type.isSubtypeOf(ListType.ofTensors()):
                inputs = [g.create("prim::Constant").t_('value', t)
                           .insertBefore(node).output()
                          for t in node['value']]
                lc = (g.create("prim::ListConstruct", inputs)
                      .insertBefore(node)
                      .output()
                      .setType(ListType.ofTensors()))
                node.output().replaceAllUsesWith(lc)

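# Schematically (illustrative IR, not actual output), a constant holding a
# list of tensors
#
#     %xs : Tensor[] = prim::Constant[value=[t0, t1]]()
#
# is split into one prim::Constant per tensor plus a prim::ListConstruct
# that reassembles the list:
#
#     %t0 : Tensor = prim::Constant[value=t0]()
#     %t1 : Tensor = prim::Constant[value=t1]()
#     %xs : Tensor[] = prim::ListConstruct(%t0, %t1)
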
def _optimize_graph(graph, operator_export_type):
    # Remove fork/wait nodes
    torch._C._jit_pass_inline_fork_wait(graph)
    torch._C._jit_pass_dce(graph)
    torch._C._jit_pass_lint(graph)

    torch._C._jit_pass_remove_inplace_ops(graph)
    # we now record some ops like ones/zeros
    # into a trace where we previously recorded constants;
    # use constant prop to maintain our current level of onnx support
    # without implementing symbolics for all of them
    torch._C._jit_pass_constant_propagation(graph)
    _split_tensor_list_constants(graph, graph)
    # run dce to eliminate dead parts of the graph that might have been
    # left behind by things like symbolic_override
    torch._C._jit_pass_dce(graph)
    torch._C._jit_pass_lint(graph)

    torch._C._jit_pass_canonicalize_ops(graph)
    torch._C._jit_pass_lint(graph)

    torch._C._jit_pass_peephole(graph, True)
    torch._C._jit_pass_lint(graph)

    # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
    torch._C._jit_pass_prepare_division_for_onnx(graph)
    # onnx only supports tensors, so we turn all number types into tensors
    torch._C._jit_pass_erase_number_types(graph)
    # onnx does not support tuples, so try to remove them
    torch._C._jit_pass_lower_all_tuples(graph)
    torch._C._jit_pass_peephole(graph, True)
    torch._C._jit_pass_lint(graph)

    if operator_export_type != OperatorExportTypes.RAW:
        graph = torch._C._jit_pass_onnx(graph, operator_export_type)
        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_onnx_peephole(graph)
        torch._C._jit_pass_lint(graph)
    torch._C._jit_pass_dce(graph)
    torch._C._jit_pass_lint(graph)
    torch._C._jit_pass_fixup_onnx_loops(graph)
    torch._C._jit_pass_lint(graph)
    graph = torch._C._jit_pass_canonicalize(graph)
    torch._C._jit_pass_lint(graph)
    return graph

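# The division preparation above matters because, in this version of PyTorch,
# division of integer tensors truncates (illustrative):
#
#     1 / 2                               # -> 0.5 in Python 3
#     torch.tensor(1) / torch.tensor(2)   # -> tensor(0), integer division
#
# so integer inputs to division must be cast to float before conversion.
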
def _trace(func, args, operator_export_type, return_outs=False):
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args, )

    trace, torch_out = torch.jit.get_trace_graph(func, args, _force_outplace=True)
    trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))
    if return_outs:
        return trace, torch_out
    return trace

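# Example (illustrative): tracing a plain function rather than a Module.
#
#     def foo(x):
#         return x + x
#
#     trace = _trace(foo, torch.randn(2, 2), OperatorExportTypes.ONNX)
#     print(trace.graph())
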
def _trace_and_get_graph_from_model(model, args, training):

    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model.  Fail fast!
    orig_state_dict_keys = _unique_state_dict(model).keys()

    # By default, training=False, which is good because running a model in
    # training mode could result in internal buffers getting updated, dropout
    # getting applied, etc.  If you really know what you're doing, you
    # can turn training=True (or None, to preserve whatever the original
    # training mode was.)
    with set_training(model, training):
        trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)

    if orig_state_dict_keys != _unique_state_dict(model).keys():
        raise RuntimeError("state_dict changed after running the tracer; "
                           "something weird is happening in your model!")

    return trace.graph(), torch_out

def _model_to_graph(model, args, f, verbose=False, training=False,
                    input_names=None, output_names=None,
                    operator_export_type=OperatorExportTypes.ONNX,
                    example_outputs=None, propagate=False,
                    _retain_param_name=False):
    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args, )

    if isinstance(model, torch.jit.ScriptModule):
        torch_out = None
        assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
        if isinstance(example_outputs, torch.Tensor):
            example_outputs = [example_outputs]
        try:
            method = model.__getattr__('forward')
            graph = method.propagate_and_assign_input_and_output_shapes(
                args, example_outputs, False, propagate)
            # Erase number types to bring the graph to a pre-NumberType state
            params = method.initial_ivalues()
        except AttributeError:
            # TODO: just trace it
            raise RuntimeError('\'forward\' method must be a script method')
    else:
        graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
        state_dict = _unique_state_dict(model)
        params = list(state_dict.values())
        if _retain_param_name:
            graph_inputs = list(graph.inputs())
            user_input_num = len(graph_inputs) - len(state_dict)
            param_names = list(state_dict.keys())
            for i, inp in enumerate(graph_inputs):
                if i >= user_input_num:
                    inp.setUniqueName(param_names[i - user_input_num])

    graph = _optimize_graph(graph, operator_export_type)

    # NB: ONNX requires complete information about output types, which might be
    # erased by some optimizations, so we need to set it explicitly again.
    if torch_out is not None:
        output_tensors, _ = torch._C._jit_flatten(torch_out)
        for output, tensor in zip(graph.outputs(), output_tensors):
            output.inferTypeFrom(tensor)

    _set_input_and_output_names(graph, input_names, output_names)
    if verbose:
        print(graph)

    return graph, params, torch_out

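# Example (illustrative sketch): a ScriptModule is not traced, so
# `example_outputs` must be supplied to fix the output shapes/types up front.
# The module definition and file name below are assumptions for illustration.
#
#     class Double(torch.jit.ScriptModule):
#         @torch.jit.script_method
#         def forward(self, x):
#             return x * 2
#
#     m = Double()
#     x = torch.randn(2, 3)
#     graph, params, torch_out = _model_to_graph(m, x, "double.onnx",
#                                                example_outputs=m(x))
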
def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
                            input_names=None, output_names=None, aten=False, export_raw_ir=False,
                            operator_export_type=None, export_type=ExportTypes.PROTOBUF_FILE,
                            example_outputs=None, propagate=False, google_printer=False,
                            opset_version=None, _retain_param_name=True):
    if aten or export_raw_ir:
        assert operator_export_type is None
        assert aten ^ export_raw_ir
        operator_export_type = OperatorExportTypes.ONNX_ATEN if aten else OperatorExportTypes.RAW
    elif operator_export_type is None:
        operator_export_type = OperatorExportTypes.ONNX
    return _export_to_pretty_string(model, args, f, export_params, verbose, training,
                                    input_names, output_names, operator_export_type,
                                    export_type, example_outputs, propagate, google_printer,
                                    opset_version, _retain_param_name)

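# Example (illustrative): render the export as human-readable text instead of
# writing a binary protobuf, reusing `model` and `dummy_input` from the export
# example above.
#
#     print(export_to_pretty_string(model, dummy_input, "alexnet.onnx"))
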
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
                             input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
                             export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
                             google_printer=False, opset_version=None, _retain_param_name=False):
    from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
    if opset_version is None:
        opset_version = _default_onnx_opset_version
    _set_opset_version(opset_version)
    graph, params, torch_out = _model_to_graph(model, args, f, verbose,
                                               training, input_names,
                                               output_names, operator_export_type,
                                               example_outputs, propagate, _retain_param_name)

    return graph._pretty_print_onnx(params, opset_version, False, operator_export_type, google_printer)

# NOTE: the output `torch_out` will contain the output tensors resulting from
# the trace of a Module.  In the case that a torch.nn.ScriptModule is passed in,
# this output will be None, since we are not doing any tracing but rather
# directly extracting the graph.
def _export(model, args, f, export_params=True, verbose=False, training=False,
            input_names=None, output_names=None, operator_export_type=OperatorExportTypes.ONNX,
            export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, propagate=False,
            opset_version=None, _retain_param_name=False):
    from torch.onnx.symbolic import _default_onnx_opset_version, _set_opset_version
    if opset_version is None:
        opset_version = _default_onnx_opset_version
    _set_opset_version(opset_version)
    graph, params, torch_out = _model_to_graph(model, args, f, verbose,
                                               training, input_names,
                                               output_names, operator_export_type,
                                               example_outputs, propagate,
                                               _retain_param_name)

    # TODO: Don't allocate an in-memory string for the protobuf
    defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
    if export_params:
        proto, export_map = graph._export_onnx(params, opset_version, defer_weight_export, operator_export_type)
    else:
        proto, export_map = graph._export_onnx([], opset_version, False, operator_export_type)

    if export_type == ExportTypes.PROTOBUF_FILE:
        assert(len(export_map) == 0)
        torch.serialization._with_file_like(f, "wb", lambda f: f.write(proto))
    elif export_type in [ExportTypes.ZIP_ARCHIVE, ExportTypes.COMPRESSED_ZIP_ARCHIVE]:
        import zipfile
        compression = zipfile.ZIP_DEFLATED \
            if export_type == ExportTypes.COMPRESSED_ZIP_ARCHIVE \
            else zipfile.ZIP_STORED
        with zipfile.ZipFile(f, 'w', compression=compression) as z:
            z.writestr(ONNX_ARCHIVE_MODEL_PROTO_NAME, proto)
            for k, v in export_map.items():
                z.writestr(k, v)
    elif export_type == ExportTypes.DIRECTORY:
        import os
        if os.path.exists(f):
            assert(os.path.isdir(f))
        else:
            os.makedirs(f)

        model_proto_file = os.path.join(f, ONNX_ARCHIVE_MODEL_PROTO_NAME)
        torch.serialization._with_file_like(
            model_proto_file, "wb", lambda f: f.write(proto))

        for k, v in export_map.items():
            weight_proto_file = os.path.join(f, k)
            torch.serialization._with_file_like(
                weight_proto_file, "wb", lambda f: f.write(v))
    else:
        raise RuntimeError('Unknown export type')
    return torch_out

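# Example (illustrative): bundling the model proto and externally stored
# weights into a single ZIP archive; the file name is an assumption.
#
#     _export(model, dummy_input, "alexnet.onnx.zip",
#             export_type=ExportTypes.ZIP_ARCHIVE)
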
def _set_input_and_output_names(graph, input_names, output_names):
    def set_names(node_list, name_list, descriptor):
        if name_list is None:
            return
        if len(name_list) > len(node_list):
            raise RuntimeError(
                "number of %s names provided (%d) exceeded number of %ss (%d)"
                % (descriptor, len(name_list), descriptor, len(node_list)))
        for name, node in zip(name_list, node_list):
            if node.uniqueName() != name:
                node.setUniqueName(name)
    set_names(list(graph.inputs()), input_names, 'input')
    set_names(list(graph.outputs()), output_names, 'output')


attr_pattern = re.compile("^(.+)_([ifstgz])$")

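# The suffix after the final underscore names the attribute kind, e.g.
# (illustrative):
#
#     attr_pattern.match("alpha_f").groups()   # -> ('alpha', 'f'): float attr
#     attr_pattern.match("value")              # -> None: suffix is required
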
def _run_symbolic_method(op_name, symbolic_fn, args):
    r"""
    This trampoline function gets invoked for every symbolic method
    call from C++.
    """
    try:
        return symbolic_fn(*args)
    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch
        # to symbolic_fn.  Otherwise, the backtrace will have the clues
        # you need.
        e.args = ("{} (occurred when translating {})".format(e.args[0], op_name), )
        raise

def _is_onnx_list(value):
    if not isinstance(value, string_classes) and \
            not isinstance(value, torch.Tensor) and \
            isinstance(value, container_abcs.Iterable):
        return True
    return False

def _add_attribute(node, key, value, aten):
    r""" initializes the right attribute based on type of value """
    m = attr_pattern.match(key)
    if m is None:
        raise IndexError((
            "Invalid attribute specifier '{}': names "
            "must be suffixed with type, e.g. 'dim_i' or 'dims_i'").format(key))
    name, kind = m.group(1), m.group(2)
    if _is_onnx_list(value):
        kind += "s"
    if aten:
        if isinstance(value, torch.Tensor):
            # Caffe2 proto does not support tensor attribute.
            if value.numel() > 1:
                raise ValueError("Should not pass tensor attribute")
            value = _scalar(value)
            if isinstance(value, float):
                kind = "f"
            else:
                kind = "i"
    return getattr(node, kind + "_")(name, value)

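# For instance (illustrative), `_add_attribute(node, "dims_i", [1, 2], aten=False)`
# parses name='dims', kind='i'; the list value promotes the kind to 'is', so
# the call dispatches to `node.is_("dims", [1, 2])`.
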
def _scalar(x):
    """Convert a scalar tensor into a Python value."""
    assert x.numel() == 1
    return x[0]

def _newNode(g, opname, outputs, *args, **kwargs):
    if "::" in opname:
        aten = False
        ns_opname = opname
    else:
        aten = kwargs.pop("aten", False)
        ns = "aten" if aten else "onnx"
        ns_opname = ns + "::" + opname
    n = g.create(ns_opname, args, outputs)
    for k, v in sorted(kwargs.items()):
        # TODO: enable inplace in aten exporting mode.
        if k == "inplace":
            continue
        _add_attribute(n, k, v, aten=aten)
    return n

def _graph_op(g, opname, *raw_args, **kwargs):
    r"""
    Create an ONNX operator 'opname', taking 'args' as inputs and attributes
    'kwargs'; returning the node representing the single output of this operator
    (see the `outputs` keyword argument for multi-return nodes).

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    This function is monkey-patched onto Graph.

    Arguments:
        opname (string): The ONNX operator name, e.g., `Abs` or `Add`.
        args (Node...): The inputs to the operator; usually provided
            as arguments to the `symbolic` definition.
        kwargs: The attributes of the ONNX operator, with keys named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`.  The valid type specifiers are
            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
            of a given type accepts either a single value or a list of values
            of that type (e.g., you would say `dims_i` for a `dims` attribute
            that takes a list of integers).
        outputs (int, optional): The number of outputs this operator returns;
            by default an operator is assumed to return a single output.
            If `outputs` is greater than one, this function returns a tuple
            of output `Node`s, representing each output of the ONNX operator
            in positional order.
    """
    outputs = kwargs.pop('outputs', 1)

    # Filter out None attributes, this can be convenient client side because
    # now they can pass through None attributes, and have them not show up
    kwargs = dict((k, v) for k, v in kwargs.items() if v is not None)

    def const_if_tensor(arg):
        if arg is None:
            return arg
        elif isinstance(arg, torch._C.Value):
            return arg
        else:
            return g.op("Constant", value_z=arg)

    args = list(const_if_tensor(arg) for arg in raw_args)
    n = g.insertNode(_newNode(g, opname, outputs, *args, **kwargs))
    if outputs == 1:
        return n.output()
    return tuple(o for o in n.outputs())

# Note [Export inplace]
# ~~~~~~~~~~~~~~~~~~~~~
# In abstract, it would be better for us to export inplace annotations,
# than to not export them, since it is useful information that can
# help the target of an ONNX export export more efficiently.  However,
# ONNX doesn't currently formalize inplace.  Fortunately, it's sound to drop
# inplace annotations, but we are losing information this way.

def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExportTypes.ONNX):
    # NB: Returning None means the node gets cloned as is into
    # the new graph
    try:
        import torch.onnx.symbolic

        # See Note [Export inplace]
        # TODO: I think this is not necessary anymore
        if n.kind().endswith('_'):
            ns_op_name = n.kind()[:-1]
        else:
            ns_op_name = n.kind()
        ns, op_name = ns_op_name.split("::")

        if ns == "onnx":
            # Use the original node directly
            return None

        elif ns == "aten":
            is_exportable_aten_op = hasattr(torch.onnx.symbolic, op_name)
            is_onnx_aten_export = operator_export_type == OperatorExportTypes.ONNX_ATEN
            is_aten_fallback_export = operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK
            if is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export):
                # Direct ATen export requested
                attrs = {k + "_" + n.kindOf(k)[0]: n[k] for k in n.attributeNames()}
                outputs = n.outputsSize()
                attrs["outputs"] = outputs
                return _graph_at(g, op_name, *inputs, aten=True, **attrs)
            else:
                # Export it regularly
                attrs = {k: n[k] for k in n.attributeNames()}
                if not is_exportable_aten_op:
                    warnings.warn("ONNX export failed on ATen operator {} because torch.onnx.symbolic.{} does not exist"
                                  .format(op_name, op_name))
                    return None
                fn = getattr(torch.onnx.symbolic, op_name)
                return fn(g, *inputs, **attrs)

        elif ns == "prim":
            if op_name == "Constant" and not n.mustBeNone():
                if n.kindOf("value") == "t":
                    return g.op("Constant", value_t=n["value"])
                elif n.kindOf("value") == "is":
                    value = torch.stack([torch.tensor(v) for v in n["value"]]) if n["value"] else []
                    return g.op("Constant", value_t=value)
                elif n.output().type().kind() == "DeviceObjType":
                    return None
                else:
                    raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                        n.kindOf("value")))
            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
                # None is not an ONNX operator; keep it as None
                # let the exporter handle finally eliminating these

                # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass
                return None
            elif op_name == 'Loop' or op_name == 'If':
                new_op_outputs = g.op(op_name, *inputs, outputs=n.outputsSize())
                new_node = new_op_outputs[0].node() if n.outputsSize() > 1 else new_op_outputs.node()
                for b in n.blocks():
                    new_block = new_node.addBlock()
                    torch._C._jit_pass_onnx_block(b, new_block, operator_export_type, env)
                return new_op_outputs
            else:
                symbolic_name = 'prim_' + op_name
                symbolic_fn = getattr(torch.onnx.symbolic, symbolic_name, None)
                if symbolic_fn is None:
                    warnings.warn("ONNX export failed on primitive operator {}; please report a bug".format(op_name))
                    return None
                attrs = {k: n[k] for k in n.attributeNames()}
                return symbolic_fn(g, *inputs, **attrs)

        else:
            warnings.warn("ONNX export failed on an operator with unrecognized namespace {}::{}; "
                          "please report a bug".format(ns, op_name))
            return None

    except TypeError as e:
        # Handle the specific case where we didn't successfully dispatch.
        # Otherwise, the backtrace will have the clues you need.
        e.args = ("{} (occurred when translating {})".format(e.args[0], op_name), )
        raise

582 
583 # Generate an ONNX ATen op node.
584 def _graph_at(g, opname, *args, **kwargs):
585  return g.op("ATen", *args, operator_s=opname, **kwargs)
586 
587 
# This helper function can create either constant tensor or constant scalar.
# If dims is None or 0 or [0], generate a 0-d tensor (scalar).
#
# TODO: We might not need this anymore, since most scalars now show up
# as tensors
def _graph_constant(g, value, dims, type, *args, **kwargs):
    assert isinstance(value, numbers.Number)
    assert type is not None
    isscalar = False
    if dims is None or dims == 0 or set(dims) == set([0]):
        dims = [1]
        isscalar = True
    type = type.lower()
    if type == "char":
        tensor = torch.CharTensor(*dims)
    elif type == "short":
        tensor = torch.ShortTensor(*dims)
    elif type == "int":
        tensor = torch.IntTensor(*dims)
    elif type == "long":
        tensor = torch.LongTensor(*dims)
    elif type == "half":
        tensor = torch.HalfTensor(*dims)
    elif type == "float":
        tensor = torch.FloatTensor(*dims)
    elif type == "double":
        tensor = torch.DoubleTensor(*dims)
    else:
        raise ValueError("Unknown type, type should be one of the following strings: "
                         "char, short, int, long, half, float, double")
    tensor.fill_(value)
    if isscalar:
        return g.op("Constant", *args, value_z=tensor, **kwargs)
    return g.op("Constant", *args, value_t=tensor, **kwargs)

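# For instance (illustrative), `_graph_constant(g, 2.0, None, "float")`
# produces a scalar Constant (stored via the `value_z` kind), while
# `_graph_constant(g, 0, [3], "long")` produces a 1-d int64 Constant of
# size 3 (stored via `value_t`).
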
def _node_getitem(self, k):
    r"""
    Accessor for attributes of a node which is polymorphic over
    return type.

    NB: This is monkey-patched onto Node.
    """
    sel = self.kindOf(k)
    return getattr(self, sel)(k)

torch._C.Graph.op = _graph_op
torch._C.Graph.at = _graph_at
torch._C.Graph.constant = _graph_constant
torch._C.Node.__getitem__ = _node_getitem
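# Example (illustrative): with these patches applied, attribute access on a
# Node dispatches on the attribute's kind, and symbolic code can build ONNX
# nodes directly on a graph `g` with inputs `x`, `y`:
#
#     n["value"]          # same as getattr(n, n.kindOf("value"))("value")
#     g.op("Add", x, y)   # inserts an onnx::Add node and returns its output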