Caffe2 - Python API
A deep learning, cross-platform ML framework
symbolic.py
1 import numbers
2 
3 import torch
4 from torch._C import TensorType, ListType, OptionalType
5 from torch.nn.modules.utils import _single, _pair, _triple
6 from torch.nn.utils.rnn import PackedSequence
7 import warnings
8 
9 import torch.onnx
10 # This import monkey-patches graph manipulation methods on Graph, used for the
11 # ONNX symbolics
12 import torch.onnx.utils
13 
14 from collections import Iterable
15 from functools import partial, wraps
16 import itertools
17 
18 import numpy
19 import math
20 
21 # EDITING THIS FILE? READ THIS FIRST!
22 #
23 # - This file is ONLY for ATen operators (e.g., operators that show up in the
24 # trace as aten::blah). If you need to special case a primitive operator,
25 # look at _run_symbolic_function
26 # - Parameter ordering does NOT necessarily match what is in VariableType.cpp;
27 # tensors are always first, then non-tensor arguments.
28 # - Parameter names must *exactly* match the names in VariableType.cpp, because
29 # dispatch is done with keyword arguments.
30 # - Looking for inplace ops? They're detected by the trailing underscore, and
31 # transparently dispatched to their non inplace versions in
32 # 'run_symbolic_function'. See Note [Export inplace]
33 #
34 # ----------------------------------------------------------------------------------
35 # A note on Tensor types
36 # ----------------------------------------------------------------------------------
37 #
38 # In general, we should avoid depending on the type of Tensor Values contained
39 # within the trace graph. However, this is sometimes unavoidable (due to ONNX
40 # spec requirements, etc). If you are implementing a symbolic and need Tensor
41 # type information, note that there are several levels of Tensor types, defined
42 # in aten/src/ATen/core/jit_type.h:
43 #
44 # TensorType - This is a Tensor, but we don't know anything about its
45 # properties (e.g. scalar type, # dims, shapes).
46 # Appears as `Tensor` in graph print-outs.
47 # DimensionedTensorType <: TensorType - Denotes a Tensor for which we know the scalar
48 #     type and number of dimensions, but not the concrete
49 #     shapes. For example, appears as 'Float(*, *)' in
50 #     graph print-outs. Useful accessor methods include
51 #     dim() and scalarType()
52 # CompleteTensorType <: DimensionedTensorType - Denotes a Tensor for which we know the
53 #     concrete sizes in addition to the information
54 #     contained in DimensionedTensorType. This adds a sizes()
55 #     method which can be used to retrieve the
56 #     concrete sizes.
57 #
58 # In general, we should prefer to rely on the least specific information possible.
59 # For example, not relying on tensor properties at all is better than relying
60 # on the number of dimensions (DimensionedTensorType) which is better than relying on
61 # concrete shapes (CompleteTensorType). Doing so will make the export symbolics
62 # more robust to different graphs.
63 
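# As an illustrative sketch (not a real operator in this file), a symbolic that
# needs the rank of its input could branch on the type kind like this:
#
#     def some_op(g, self):
#         if self.type().kind() in ("DimensionedTensorType", "CompleteTensorType"):
#             rank = self.type().dim()   # number of dimensions is known
#             return g.op("SomeOp", self, axes_i=list(range(rank)))
#         # plain TensorType: dim()/sizes() are unavailable, so give up
#         return _unimplemented("some_op", "unknown input rank")
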
64 # ---------------------------------------------------------------------------------
65 # Helper functions
66 # ---------------------------------------------------------------------------------
67 
68 # Save some builtins as locals, because we'll shadow them below
69 _sum = sum
70 
71 
72 def _parse_arg(value, desc):
73  if desc == 'none':
74  return value
75  if desc == 'v' or not _is_value(value):
76  return value
77  if value.node().kind() != 'onnx::Constant':
78  raise RuntimeError("ONNX symbolic expected a constant value in the trace")
79  tval = value.node()['value']
80  if desc == 'i':
81  return int(tval)
82  elif desc == 'f':
83  return float(tval)
84  elif desc == 't':
85  return tval
86  elif desc == 'is':
87  return [int(v) for v in tval]
88  else:
89  raise RuntimeError("Casting constants to `{}` is not implemented".format(desc))
90 
91 
92 def _maybe_get_const(value, desc):
93  if _is_value(value) and value.node().kind() == 'onnx::Constant':
94  return _parse_arg(value, desc)
95  return value
96 
97 
98 def _maybe_get_scalar(value):
99  value_t = _maybe_get_const(value, 't')
100  if isinstance(value_t, torch.Tensor) and value_t.shape == ():
101  return value_t
102  return value
103 
104 
105 def _get_const(value, desc, arg_name):
106  if _is_value(value) and value.node().kind() != 'onnx::Constant':
107  raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
108  return _parse_arg(value, desc)
109 
110 
111 def _unpack_list(list_value):
112  list_node = list_value.node()
113  assert list_node.kind() == "prim::ListConstruct"
114  return list(list_node.inputs())
115 
116 
117 def parse_args(*arg_descriptors):
118  def decorator(fn):
119  def wrapper(g, *args):
120  # some args may be optional, so the length may be smaller
121  assert len(arg_descriptors) >= len(args)
122  args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
123  return fn(g, *args)
124  # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround
125  try:
126  wrapper = wraps(fn)(wrapper)
127  except Exception:
128  pass
129  return wrapper
130  return decorator
131 
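# The descriptors accepted by parse_args/_parse_arg are:
#   'v'    - keep the raw torch._C.Value
#   'i'    - unwrap an onnx::Constant into a Python int
#   'f'    - unwrap into a Python float
#   't'    - unwrap into the underlying tensor value
#   'is'   - unwrap into a list of Python ints
#   'none' - pass the (possibly missing) value through untouched
# A hypothetical usage sketch ("FancyOp" is not a real ONNX operator):
#
#     @parse_args('v', 'i', 'f')
#     def fancy_op(g, input, dim, alpha):
#         return g.op("FancyOp", input, axis_i=dim, alpha_f=alpha)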
132 
133 def _scalar(x):
134  """Convert a scalar tensor into a Python value."""
135  assert x.numel() == 1
136  return x.item()
137 
138 
139 def _if_scalar_type_as(g, self, tensor):
140  """
141  Convert self into the same type of tensor, as necessary.
142 
143  We only support implicit casting for scalars, so we never
144  actually need to insert an ONNX cast operator here; just
145  fix up the scalar.
146  """
147  if isinstance(self, torch._C.Value):
148  return self
149  elif tensor.type().kind() == "DimensionedTensorType" or tensor.type().kind() == "CompleteTensorType":
150  ty = tensor.type().scalarType().lower()
151  return getattr(self, ty)()
152  else:
153  return self
154 
155 
156 def _is_value(x):
157  return isinstance(x, torch._C.Value)
158 
159 
160 def _is_tensor_list(x):
161  return x.type().isSubtypeOf(ListType.ofTensors())
162 
163 
164 def _unimplemented(op, msg):
165  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
166 
167 
168 def _try_get_scalar_type(*args):
169  for arg in args:
170  try:
171  return arg.type().scalarType()
172  except RuntimeError:
173  pass
174  return None
175 
176 
177 # ---------------------------------------------------------------------
178 # ONNX operator version
179 # ---------------------------------------------------------------------
180 
181 # READ ME BEFORE EDITING _default_onnx_opset_version:
182 #
183 # The variable below controls which ONNX operator set version we are
184 # targeting. THIS VARIABLE HAS SEMANTIC EFFECT! Say a breaking
185 # change occurred in version 8. As long as this variable < 8, you can
186 # export models targeting the old behavior. However, if you bump
187 # this variable to 8 or later, the breaking change will take effect:
188 # you MUST adjust any symbolic affected by breaking changes. The ONNX
189 # spec publishes a *comprehensive* list of BC-breaking changes for every
190 # operator revision at:
191 #
192 # https://github.com/onnx/onnx/blob/master/docs/Changelog.md
193 #
194 # Please be sure to go through and check all of our implementations here before
195 # increasing this number. This includes symbolic definitions NOT in this
196 # file, so grep for "OpName" (with quotes)
197 #
198 # In addition, opset_version can be specified in the invocation of export()
199 # and export_to_pretty_string(); _export_onnx_opset_version will be set
200 # accordingly, and the symbolic functions should check it to determine the
201 # behavior of the exporter.
202 
203 
204 _default_onnx_opset_version = 9
205 _onnx_master_opset = 10
206 _onnx_stable_opsets = [9]
207 _export_onnx_opset_version = _default_onnx_opset_version
208 
209 
210 def _set_opset_version(opset_version):
211  global _export_onnx_opset_version
212  if opset_version == _default_onnx_opset_version:
213  return
214  if opset_version in _onnx_stable_opsets + [_onnx_master_opset]:
215  _export_onnx_opset_version = opset_version
216  return
217  raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
218 
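# Symbolics that need opset-dependent behavior can branch on the module-level
# _export_onnx_opset_version, e.g. (sketch with a made-up operator):
#
#     def some_op(g, self):
#         if _export_onnx_opset_version >= 10:
#             return g.op("SomeOp", self, new_attr_i=1)
#         return g.op("SomeOp", self)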
219 
220 # ---------------------------------------------------------------------
221 # Symbolic definitions
222 # ---------------------------------------------------------------------
223 
224 
225 # Note [Pointwise by scalar]
226 # ~~~~~~~~~~~~~~~~~~~~~~~~~~
227 # What happens if you add a tensor with a constant (e.g., x + 2)? There are
228 # some moving parts to implementing the ONNX translation in this case:
229 #
230 # - By the time we get the scalar in a symbolic function here, it is no longer
231 # a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
232 # want it to be a zero dim tensor but this change has not happened yet.)
233 # However, the type of this scalar is *exactly* what the user wrote in
234 # Python, which may not match the tensor it is being added to. PyTorch
235 # will do implicit conversions on scalars; however, ONNX will not, so
236 # we must do the conversion ourselves. This is what _if_scalar_type_as
237 # does.
238 #
239 # - Dispatch to these functions takes advantage of an outrageous coincidence
240 # between the tensor and scalar name. When we add two tensors together,
241 # you get the dispatch:
242 #
243 # add(*[self, other], **{"alpha": alpha})
244 #
245 # When you add a tensor and a scalar, you get the dispatch:
246 #
247 # add(*[self], **{"other": other, "alpha": alpha})
248 #
249 # By having the argument name line up with the name of the scalar attribute
250 # if it exists, we can write a single function for both overloads.
251 #
252 
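# Concretely, for `x + 2` the exporter dispatches
#     add(g, x, other, alpha)
# with `other` arriving as a one-element constant tensor whose type reflects the
# Python literal, and the symbolic roughly does (sketch):
#     other = _maybe_get_scalar(other)
#     g.op("Add", x, _if_scalar_type_as(g, other, x))   # scalar coerced to x's scalar type
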
253 # used to represent "missing" optional inputs
254 def unused(g):
255  n = g.op("prim::Constant")
256  n.setType(OptionalType.ofTensor())
257  return n
258 
259 
260 def _shape_as_tensor(g, input):
261  return g.op('Shape', input)
262 
263 
264 def _reshape_from_tensor(g, input, shape):
265  return g.op('Reshape', input, shape)
266 
267 
268 def reshape(g, self, shape):
269  return view(g, self, shape)
270 
271 
272 def reshape_as(g, self, other):
273  shape = g.op('Shape', other)
274  return reshape(g, self, shape)
275 
276 
277 def add(g, self, other, alpha=None):
278  # default alpha arg is to allow no-alpha add (aten add st overload no alpha)
279  if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
280  return _unimplemented("add", "alpha != 1")
281  # See Note [Pointwise by scalar]
282  other = _maybe_get_scalar(other)
283  return g.op("Add", self, _if_scalar_type_as(g, other, self))
284 
285 
286 def sub(g, self, other, alpha=None):
287  # default alpha arg is to allow no-alpha sub (aten sub st overload no alpha)
288  if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
289  return _unimplemented("sub", "alpha != 1")
290  # See Note [Pointwise by scalar]. Note that self or other may be scalars.
291  other = _maybe_get_scalar(other)
292  return g.op("Sub", self, _if_scalar_type_as(g, other, self))
293 
294 
295 def rsub(g, self, other, alpha=None):
296  other = _maybe_get_scalar(other)
297  other = _if_scalar_type_as(g, other, self)
298  return sub(g, other, self, alpha=alpha)
299 
300 
301 def mul(g, self, other):
302  # See Note [Pointwise by scalar]
303  other = _maybe_get_scalar(other)
304  return g.op("Mul", self, _if_scalar_type_as(g, other, self))
305 
306 
307 def div(g, self, other):
308  # See Note [Pointwise by scalar]
309  other = _maybe_get_scalar(other)
310  return g.op("Div", self, _if_scalar_type_as(g, other, self))
311 
312 
313 def reciprocal(g, self):
314  return g.op("Div", _if_scalar_type_as(g, torch.ones(1), self), self)
315 
316 
317 @parse_args('v', 'i')
318 def cat(g, tensor_list, dim):
319  tensors = _unpack_list(tensor_list)
320  return g.op("Concat", *tensors, axis_i=dim)
321 
322 
323 @parse_args('v', 'i')
324 def stack(g, tensor_list, dim):
325  unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in _unpack_list(tensor_list)]
326  return g.op("Concat", *unsqueezed, axis_i=dim)
327 
328 
329 def mm(g, self, other):
330  # Create a dummy C tensor. Only needed for API purposes; the value is
331  # irrelevant since beta = 0
332  ty = _try_get_scalar_type(self, other).lower()
333  C = g.constant(0, [1], ty)
334  return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
335 
336 
337 def bmm(g, self, other):
338  return g.op("MatMul", self, other)
339 
340 
341 def matmul(g, self, other):
342  return g.op("MatMul", self, other)
343 
344 
345 @parse_args('v', 'v', 'v', 't', 't')
346 def addmm(g, self, mat1, mat2, beta, alpha):
347  return g.op("Gemm", mat1, mat2, self, beta_f=_scalar(beta), alpha_f=_scalar(alpha))
348 
349 
350 def neg(g, self):
351  return g.op("Neg", self)
352 
353 
354 def sqrt(g, self):
355  return g.op("Sqrt", self)
356 
357 
358 def tanh(g, self):
359  return g.op("Tanh", self)
360 
361 
362 def sin(g, self):
363  return g.op("Sin", self)
364 
365 
366 def cos(g, self):
367  return g.op("Cos", self)
368 
369 
370 def tan(g, self):
371  return g.op("Tan", self)
372 
373 
374 def asin(g, self):
375  return g.op("Asin", self)
376 
377 
378 def acos(g, self):
379  return g.op("Acos", self)
380 
381 
382 def atan(g, self):
383  return g.op("Atan", self)
384 
385 
386 def sigmoid(g, self):
387  return g.op("Sigmoid", self)
388 
389 
390 def _reduce_op_symbolic(onnx_op_name):
391  def symbolic(g, self, dim=None, keepdim=None):
392  if dim is None:
393  # all-reduce path
394  return g.op(onnx_op_name, self, keepdims_i=0)
395  else:
396  # dim-reduce path
397  dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim')
398  return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim)
399  return symbolic
400 
401 mean = _reduce_op_symbolic('ReduceMean')
402 sum = _reduce_op_symbolic('ReduceSum')
403 prod = _reduce_op_symbolic('ReduceProd')
404 
405 
406 @parse_args('v', 'i')
407 def cumsum(g, input, dim):
408  return g.op("ATen", input, operator_s="cumsum", dim_i=dim)
409 
410 
411 def t(g, self):
412  return g.op("Transpose", self, perm_i=(1, 0))
413 
414 
415 def expand(g, self, size, implicit):
416  size = _maybe_get_const(size, 'is')
417  if not _is_value(size):
418  size = g.op("Constant", value_t=torch.LongTensor(size))
419  return g.op("Expand", self, size)
420 
421 
422 def expand_as(g, self, other):
423  shape = g.op("Shape", other)
424  return g.op("Expand", self, shape)
425 
426 
427 def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
428  return g.op("Gather", weight, indices)
429 
430 
431 @parse_args('v', 'v', 'v', 'i', 'i', 'i')
432 def embedding_bag(g,
433  embedding_matrix,
434  indices,
435  offsets,
436  scale_grad_by_freq,
437  mode,
438  sparse):
439  return g.op("ATen",
440  embedding_matrix,
441  indices,
442  offsets,
443  operator_s="embedding_bag",
444  outputs=4,
445  scale_grad_by_freq_i=scale_grad_by_freq,
446  mode_i=mode,
447  sparse_i=sparse)
448 
449 
450 def size(g, self, dim):
451  full_shape = g.op("Shape", self)
452  return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
453 
454 
455 @parse_args('v', 'i', 'i')
456 def transpose(g, self, dim0, dim1):
457  if dim0 == dim1: # micro-optimization
458  return self
459 
460  # NB: Transpose in ONNX is actually a Permute
461  axes = list(range(self.type().dim()))
462  axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
463  return g.op("Transpose", self, perm_i=axes)
464 
465 
466 @parse_args('v', 'is')
467 def permute(g, self, dims):
468  if dims == list(range(0, len(dims))):
469  return self
470  return g.op("Transpose", self, perm_i=dims)
471 
472 
473 def view(g, self, size):
474  size = _maybe_get_const(size, 'is')
475  if _is_value(size):
476  shape = size
477  else:
478  if self.isTensor():
479  self_sizes = self.type().sizes()
480  if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
481  return g.op("Flatten", self, axis_i=1)
482  shape = g.op("Constant", value_t=torch.LongTensor(size))
483  return g.op("Reshape", self, shape)
484 
485 
486 def prim_ConstantSplit(g, self, split_size, dim):
487  size = self.type().sizes()[dim]
488  splits = [split_size] * (size // split_size)
489  leftover = size % split_size
490  if leftover:
491  splits.append(leftover)
492  return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))
493 
494 
495 # TODO: It would be better to export this as a chunk directly, as this is
496 # less sensitive to changes in input size.
497 # TODO: Once we have proper scoping, stop reimplementing chunk, delete this
498 # method, and use the desugared version
499 def prim_ConstantChunk(g, self, chunks, dim):
500  split_size = (self.type().sizes()[dim] + chunks - 1) // chunks
501  return prim_ConstantSplit(g, self, split_size, dim)
502 
503 
504 @parse_args('v', 'i', 'i')
505 def split(g, self, split_size, dim):
506  size = self.type().sizes()[dim]
507  splits = [split_size] * (size // split_size)
508  leftover = size % split_size
509  if leftover:
510  splits.append(leftover)
511  return g.op("Split", self, split_i=splits, axis_i=dim, outputs=1)
512 
513 
514 @parse_args('v', 'is', 'i')
515 def split_with_sizes(g, self, split_sizes, dim):
516  return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=1)
517 
518 
519 @parse_args('v', 'i', 'v')
520 def select(g, self, dim, index):
521  if dim > 1:
522  # TODO: this is a temporary hack because of the implementation details
523  # of Gather in caffe2. We need to change this as soon as possible.
524  # TODO: this breaks if index == -1
525  index_val = _parse_arg(index, 'i')
526  slice_node = g.op("Slice", self, axes_i=[dim], starts_i=[index_val], ends_i=[index_val + 1])
527  return g.op("Squeeze", slice_node, axes_i=[dim])
528  else:
529  return g.op("Gather", self, index, axis_i=dim)
530 
531 
532 def squeeze(g, self, dim=None):
533  if dim is None:
534  dims = []
535  for i, size in enumerate(self.type().sizes()):
536  if size == 1:
537  dims.append(i)
538  else:
539  dims = [_get_const(dim, 'i', 'dim')]
540  return g.op("Squeeze", self, axes_i=dims)
541 
542 
543 def prelu(g, self, weight):
544  return g.op("PRelu", self, weight)
545 
546 
547 def relu(g, input):
548  return g.op("Relu", input)
549 
550 
551 @parse_args('v', 't', 't')
552 def threshold(g, self, threshold, value):
553  # See Note [Export inplace]
554  if _scalar(threshold) != 0:
555  return _unimplemented("threshold", "non-zero threshold")
556  if _scalar(value) != 0:
557  return _unimplemented("threshold", "non-zero value")
558  return g.op("Relu", self)
559 
560 
561 def leaky_relu(g, input, negative_slope, inplace=False):
562  negative_slope = _get_const(negative_slope, 't', 'negative_slope')
563  # See Note [Export inplace]
564  # TODO: Talk to ONNX about unconditional cast of scalar to float
565  return g.op("LeakyRelu", input, alpha_f=_scalar(negative_slope))
566 
567 
568 @parse_args('v', 'i')
569 def glu(g, input, dim):
570  assert input.type().sizes()[dim] % 2 == 0
571 
572  first, second = g.op('Split', input, axis_i=dim, outputs=2)
573  return g.op('Mul', first, g.op('Sigmoid', second))
574 
575 
576 @parse_args('v', 'i', 'i')
577 def softmax(g, input, dim, dtype=None):
578  # Softmax does normalization at vector level.
579  # PyTorch and ONNX use different strategies to split the input tensor into vectors.
580  # Thus dim and axis have different meanings.
581  # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
582  # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
583  # If input is a 2 x 3 tensor:
584  # input = [[1.0, 1.0, 1.0],
585  # [1.0, 1.0, 1.0]]
586  # with dim = 0, the result is:
587  # result = [[0.5, 0.5, 0.5],
588  # [0.5, 0.5, 0.5]]
589  # with axis = 0, the result is:
590  # result = [[0.167, 0.167, 0.167],
591  # [0.167, 0.167, 0.167]]
592  # So only when dim and axis both equal ndim - 1 (the last dimension),
593  # their semantics are equivalent.
594  if dim < 0:
595  dim = input.type().dim() + dim
596  if input.type().dim() != dim + 1:
597  return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
598  return_op = g.op('Softmax', input, axis_i=dim)
599  if dtype:
600  return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
601  return return_op
602 
603 
604 @parse_args('v', 't', 'v')
605 def softplus(g, self, beta, threshold):
606  if beta != 1:
607  return _unimplemented("beta", "has to be 1")
608  return g.op('Softplus', self)
609 
610 
611 def get_pool_ceil_padding(input, kernel_size, stride, padding):
612  dim = input.type().sizes()[-len(padding):]
613  ceiled_output_dim = [int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + 1
614  for i in range(0, len(padding))]
615  # ensure last pooling starts inside
616  ceiled_output_dim = [ceiled_output_dim[i] - 1
617  if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
618  else ceiled_output_dim[i]
619  for i in range(0, len(ceiled_output_dim))]
620  padding_ceil = [0
621  if (stride[i] == 1)
622  else
623  (kernel_size[i] - (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1)))
624  for i in range(0, len(padding))]
625  # ensure padding is not > kernel_size
626  padding_ceil = [(int(padding_ceil[i]) if padding_ceil[i] < kernel_size[i] - 1 else int(kernel_size[i] - 1))
627  if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
628  else
629  int(padding_ceil[i])
630  for i in range(0, len(padding_ceil))]
631  return padding_ceil
632 
633 
634 @parse_args('v', 'is', 'is', 'is', 'is', 'i')
635 def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
636  if ceil_mode and input.type().kind() != "CompleteTensorType":
637  return _unimplemented("max_pool1d_with_indices", "input size not accessible")
638  if set(_single(dilation)) != {1}:
639  return _unimplemented("max_pool1d_with_indices", "dilation")
640  if stride is None:
641  stride = kernel_size
642  padding = tuple(_single(padding))
643  if ceil_mode:
644  padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
645  padding = padding + tuple(numpy.add(padding_ceil, padding))
646  else:
647  padding = padding * 2
648  r, indices = g.op("MaxPool", input, outputs=2,
649  kernel_shape_i=_single(kernel_size),
650  pads_i=padding,
651  strides_i=_single(stride))
652  # easy but hacky way to get flattened indices values
653  # to be used to convert the indices values to non-flattened.
654  # In ONNX the indices are computed as a flattened 1-D tensor,
655  # so the values in indices are in [0, N x C x D1 x ... x Dn).
656  # To convert the indices to the same format used by PyTorch,
657  # we first execute a maxpool with a kernel and stride of 1 on the same input.
658  # This will result in a tensor of indices in which each index will have its own value.
659  # Using this tensor as a reference, we extract the first index of each axis and subtract
660  # it from each index of this axis in the indices to convert.
661  # This step will result in a tensor where each dimension has values of indices within
662  # the dimension it is in.
663  # For more information:
664  # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
665  _, flattened_indices = g.op("MaxPool", input, outputs=2,
666  kernel_shape_i=[1],
667  strides_i=[1])
668  # convert indices to have non-flattened indices values
669  s = g.op("Slice", flattened_indices, axes_i=[2], starts_i=[0], ends_i=[1])
670  indices = sub(g, indices, s)
671  return r, indices
672 
673 
674 @parse_args('v', 'is', 'is', 'is', 'is', 'i')
675 def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
676  if ceil_mode and input.type().kind() != "CompleteTensorType":
677  return _unimplemented("max_pool2d_with_indices", "input size not accessible")
678  if set(_pair(dilation)) != {1}:
679  return _unimplemented("max_pool2d_with_indices", "dilation")
680  if not stride:
681  stride = kernel_size
682  padding = tuple(_pair(padding))
683  if ceil_mode:
684  padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
685  padding = padding + tuple(numpy.add(padding_ceil, padding))
686  else:
687  padding = padding * 2
688  r, indices = g.op("MaxPool", input, outputs=2,
689  kernel_shape_i=_pair(kernel_size),
690  pads_i=padding,
691  strides_i=_pair(stride))
692  # easy but hacky way to get flattened indices values
693  # to be used to convert the indices values to non-flattened
694  # See comment in max_pool1d_with_indices for details.
695  _, flattened_indices = g.op("MaxPool", input, outputs=2,
696  kernel_shape_i=[1, 1],
697  strides_i=[1, 1])
698  # convert indices to have non-flattened indices values
699  s = g.op("Slice", flattened_indices, axes_i=[2, 3], starts_i=[0, 0], ends_i=[1, 1])
700  indices = sub(g, indices, s)
701  return r, indices
702 
703 
704 @parse_args('v', 'is', 'is', 'is', 'is', 'i')
705 def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
706  if ceil_mode and input.type().kind() != "CompleteTensorType":
707  return _unimplemented("max_pool3d_with_indices", "input size not accessible")
708  if set(_triple(dilation)) != {1}:
709  return _unimplemented("max_pool3d_with_indices", "dilation")
710  if not stride:
711  stride = kernel_size
712  padding = tuple(_triple(padding))
713  if ceil_mode:
714  padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
715  padding = padding + tuple(numpy.add(padding_ceil, padding))
716  else:
717  padding = padding * 2
718  r, indices = g.op("MaxPool", input, outputs=2,
719  kernel_shape_i=_triple(kernel_size),
720  pads_i=padding,
721  strides_i=_triple(stride))
722  # easy but hacky way to get flattened indices values
723  # to be used to convert the indices values to non-flattened
724  # See comment in max_pool1d_with_indices for details.
725  _, flattened_indices = g.op("MaxPool", input, outputs=2,
726  kernel_shape_i=[1, 1, 1],
727  strides_i=[1, 1, 1])
728  # convert indices to have non-flattened indices values
729  s = g.op("Slice", flattened_indices, axes_i=[2, 3, 4], starts_i=[0, 0, 0], ends_i=[1, 1, 1])
730  indices = sub(g, indices, s)
731  return r, indices
732 
733 
734 def _avg_pool(name, tuple_fn):
735  @parse_args('v', 'is', 'is', 'is', 'i', 'i')
736  def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
737  if ceil_mode and input.type().kind() != "CompleteTensorType":
738  return _unimplemented(name, "input size not accessible")
739  if not stride:
740  stride = kernel_size
741  padding = tuple(tuple_fn(padding))
742  if ceil_mode:
743  padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
744  if count_include_pad:
745  input = g.op("Pad", input,
746  pads_i=((0,) * 2 + padding) * 2,
747  mode_s='constant',
748  value_f=0.)
749  padding = (0,) * len(padding)
750  if ceil_mode:
751  padding = padding + tuple(numpy.add(padding_ceil, padding))
752  else:
753  padding = padding * 2
754  output = g.op("AveragePool", input,
755  kernel_shape_i=tuple_fn(kernel_size),
756  strides_i=tuple_fn(stride),
757  pads_i=padding)
758  return output
759  return symbolic_fn
760 
761 
762 avg_pool1d = _avg_pool('avg_pool1d', _single)
763 avg_pool2d = _avg_pool('avg_pool2d', _pair)
764 avg_pool3d = _avg_pool('avg_pool3d', _triple)
765 
766 
767 def _adaptive_pool(name, type, tuple_fn, fn=None):
768  @parse_args('v', 'is')
769  def symbolic_fn(g, input, output_size):
770  # _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
771  # by executing a GlobalPool.
772  # It is also supported for cases where the output size is a factor of the input size.
773  # For these cases the stride and kernel size are uniform along all the indices of
774  # the same dimension, which makes it possible to export it to ONNX.
775  # for MaxPool, GlobalMaxPool does not return indices,
776  # so we try using max_poolxd_with_indices, and if that is not possible
777  # (input is not a CompleteTensorType or the output size is not a factor of the input size)
778  # then we call GlobalMaxPool and return None for the indices
779  if output_size == [1] * len(output_size) and type == "AveragePool":
780  return g.op("GlobalAveragePool", input)
781  if input.type().kind() != "CompleteTensorType":
782  if output_size == [1] * len(output_size):
783  return g.op("GlobalMaxPool", input), None
784  return _unimplemented(name, 'input size not accessible')
785  dim = input.type().sizes()[2:]
786  # verify that the input size is divisible by the output size for every dim
787  mod = [dim[i] % output_size[i] for i in range(0, len(dim))]
788  if mod != [0] * len(mod):
789  if output_size == [1] * len(output_size):
790  return g.op("GlobalMaxPool", input), None
791  return _unimplemented(name, 'output size that is not a factor of the input size')
792  k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
793  # call max_poolxd_with_indices to get indices in the output
794  if type == "MaxPool":
795  return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
796  output = g.op(type, input,
797  kernel_shape_i=tuple_fn(k),
798  strides_i=tuple_fn(k))
799  return output
800  return symbolic_fn
801 
802 
803 adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', "AveragePool", _single)
804 adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', "AveragePool", _pair)
805 adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', "AveragePool", _triple)
806 
807 adaptive_max_pool1d = _adaptive_pool('adaptive_max_pool1d', "MaxPool", _single, max_pool1d_with_indices)
808 adaptive_max_pool2d = _adaptive_pool('adaptive_max_pool2d', "MaxPool", _pair, max_pool2d_with_indices)
809 adaptive_max_pool3d = _adaptive_pool('adaptive_max_pool3d', "MaxPool", _triple, max_pool3d_with_indices)
810 
811 
812 @parse_args('v', 'is', 'f')
813 def constant_pad_nd(g, input, padding, value):
814  from torch.autograd._functions.utils import prepare_onnx_paddings
815  mode = "constant"
816  paddings = prepare_onnx_paddings(input.type().dim(), padding)
817  return g.op("Pad", input, pads_i=paddings, mode_s=mode, value_f=value)
818 
819 
820 @parse_args('v', 'is')
821 def reflection_pad(g, input, padding):
822  from torch.autograd._functions.utils import prepare_onnx_paddings
823  mode = "reflect"
824  paddings = prepare_onnx_paddings(input.type().dim(), padding)
825  return g.op("Pad", input, pads_i=paddings, mode_s=mode)
826 
827 
828 @parse_args('v', 'is')
829 def replication_pad(g, input, padding):
830  from torch.autograd._functions.utils import prepare_onnx_paddings
831  mode = "edge"
832  paddings = prepare_onnx_paddings(input.type().dim(), padding)
833  return g.op("Pad", input, pads_i=paddings, mode_s=mode)
834 
835 
836 reflection_pad1d = reflection_pad
837 reflection_pad2d = reflection_pad
838 reflection_pad3d = reflection_pad
839 replication_pad1d = replication_pad
840 replication_pad2d = replication_pad
841 replication_pad3d = replication_pad
842 
843 
844 @parse_args('v', 'is')
845 def upsample_nearest2d(g, input, output_size):
846  height_scale = float(output_size[-2]) / input.type().sizes()[-2]
847  width_scale = float(output_size[-1]) / input.type().sizes()[-1]
848  scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
849  width_scale]))
850 
851  return g.op("Upsample", input, scales,
852  mode_s="nearest")
853 
854 
855 @parse_args('v', 'is', 'i')
856 def upsample_bilinear2d(g, input, output_size, align_corners):
857  if align_corners:
858  return _unimplemented("upsample_bilinear2d", "align_corners == True")
859  height_scale = float(output_size[-2]) / input.type().sizes()[-2]
860  width_scale = float(output_size[-1]) / input.type().sizes()[-1]
861  scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
862  width_scale]))
863  return g.op("Upsample", input, scales,
864  mode_s="linear")
865 
866 
867 def wrap_logical_op_with_cast_to_uint8(func):
868  def wrap_with_cast(g, input, other):
869  return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte'])
870  return wrap_with_cast
871 
872 
873 def wrap_logical_op_with_negation(func):
874  def wrap_with_not(g, input, other):
875  return g.op("Not", func(g, input, other))
876  return wrap_with_not
877 
878 
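# Note: stacking both wrappers, as on `ne` below, exports the comparison as
# Cast(Not(Equal(x, y)), to=UINT8), matching PyTorch's uint8 comparison results.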
879 @wrap_logical_op_with_cast_to_uint8
880 def eq(g, self, other):
881  return g.op("Equal", self, other)
882 
883 
884 @wrap_logical_op_with_cast_to_uint8
885 @wrap_logical_op_with_negation
886 def ne(g, self, other):
887  return g.op("Equal", self, other)
888 
889 
890 @wrap_logical_op_with_cast_to_uint8
891 def gt(g, input, other):
892  return gt_impl(g, input, other)
893 
894 
895 def gt_impl(g, input, other):
896  other = _maybe_get_scalar(other)
897  return g.op("Greater", input, _if_scalar_type_as(g, other, input))
898 
899 
900 @wrap_logical_op_with_cast_to_uint8
901 def lt(g, input, other):
902  return lt_impl(g, input, other)
903 
904 
905 def lt_impl(g, input, other):
906  other = _maybe_get_scalar(other)
907  return g.op("Less", input, _if_scalar_type_as(g, other, input))
908 
909 
910 @wrap_logical_op_with_cast_to_uint8
911 @wrap_logical_op_with_negation
912 def ge(g, input, other):
913  other = _maybe_get_scalar(other)
914  return lt_impl(g, input, _if_scalar_type_as(g, other, input))
915 
916 
917 @wrap_logical_op_with_cast_to_uint8
918 @wrap_logical_op_with_negation
919 def le(g, input, other):
920  other = _maybe_get_scalar(other)
921  return gt_impl(g, input, _if_scalar_type_as(g, other, input))
922 
923 
924 def where(g, condition, self, other):
925  return g.op("ATen", condition, self, other, operator_s="where")
926 
927 
928 @parse_args('v', 'i', 'i')
929 def log_softmax(g, input, dim=None, dtype=None):
930  # PyTorch dim and ONNX axis have different meanings.
931  # See Softmax comment for details.
932  if dim < 0:
933  dim = input.type().dim() + dim
934  if input.type().dim() != dim + 1:
935  return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
936  return_op = g.op("LogSoftmax", input, axis_i=dim)
937  if dtype:
938  return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
939  return return_op
940 
941 
942 @parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i')
943 def _convolution(g, input, weight, bias, stride, padding, dilation,
944  transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled):
945  weight_size = weight.type().sizes()
946 
947  args = [input, weight]
948  # ONNX only supports 1D bias
949  if not bias.node().mustBeNone() and bias.type().dim() == 1:
950  args.append(bias)
951 
952  kwargs = {"kernel_shape_i": weight_size[2:],
953  "strides_i": stride,
954  # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
955  # symmetric padding
956  "pads_i": padding + padding,
957  "dilations_i": dilation,
958  "group_i": groups}
959 
960  if any(o != 0 for o in output_padding):
961  # ONNX supports both output_shape and output_padding; they are equally expressive.
962  # output_padding is more straightforward, so we use it here.
963  # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
964  assert transposed
965  assert len(stride) == len(output_padding)
966  kwargs["output_padding_i"] = output_padding
967 
968  n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
969 
970  if not bias.node().mustBeNone() and bias.type().dim() != 1:
971  return g.op("Add", n, bias)
972  else:
973  return n
974 
975 
976 @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
977 def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
978  input_sizes = input.type().sizes()
979  if len(input_sizes) == 2:
980  # batchnorm1d accepts 2d and 3d array, but ONNX only accepts 3d
981  input = g.op("Unsqueeze", input, axes_i=[2])
982 
983  if weight is None or weight.node().mustBeNone():
984  assert len(input_sizes) > 1
985  weight_value = torch.tensor([1.] * input_sizes[1]).type(
986  'torch.' + input.type().scalarType() + 'Tensor')
987  weight = g.op("Constant", value_t=weight_value)
988  if bias is None or bias.node().mustBeNone():
989  assert len(input_sizes) > 1
990  bias_value = torch.tensor([0.] * input_sizes[1]).type(
991  'torch.' + input.type().scalarType() + 'Tensor')
992  bias = g.op("Constant", value_t=bias_value)
993  out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
994  epsilon_f=eps,
995  momentum_f=1 - momentum,
996  outputs=1 if not training else 5)
997  if not training:
998  if len(input_sizes) == 2:
999  out = g.op("Squeeze", out, axes_i=[2])
1000  return out
1001  else:
1002  res, new_running_mean, new_running_var, saved_mean, saved_var = out
1003  new_running_mean.setType(running_mean.type())
1004  new_running_var.setType(running_var.type())
1005  saved_mean.setUniqueName("batch_norm_dead_output-" + saved_mean.uniqueName())
1006  saved_var.setUniqueName("batch_norm_dead_output-" + saved_var.uniqueName())
1007  if len(input_sizes) == 2:
1008  res = g.op("Squeeze", res, axes_i=[2])
1009  return res
1010 
1011 
1012 @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
1013 def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
1014  input_sizes = input.type().sizes()
1015  if weight is None or weight.node().mustBeNone():
1016  assert len(input_sizes) > 1
1017  weight_value = torch.tensor([1.] * input_sizes[1]).type(
1018  'torch.' + input.type().scalarType() + 'Tensor')
1019  weight = g.op("Constant", value_t=weight_value)
1020  if bias is None or bias.node().mustBeNone():
1021  assert len(input_sizes) > 1
1022  bias_value = torch.tensor([0.] * input_sizes[1]).type(
1023  'torch.' + input.type().scalarType() + 'Tensor')
1024  bias = g.op("Constant", value_t=bias_value)
1025  return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
1026 
1027 
1028 @parse_args('v', 'i', 'i', 'i')
1029 def unfold(g, input, dimension, size, step):
1030  return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)
1031 
1032 
1033 @parse_args('v', 'v', 'i')
1034 def _weight_norm(graph, v, g, dim):
1035  return graph.op("ATen", v, g, dim_i=dim, operator_s="_weight_norm")
1036 
1037 
1038 @parse_args('v', 't', 't', 't')
1039 def elu(g, input, alpha, scale, input_scale):
1040  if scale and scale != 1.:
1041  return _unimplemented("scale", "does not support scale in Elu")
1042  if input_scale and input_scale != 1.:
1043  return _unimplemented("input_scale", "does not support input_scale in Elu")
1044  # See Note [Export inplace]
1045  return g.op("Elu", input, alpha_f=_scalar(alpha))
1046 
1047 
1048 def selu(g, input):
1049  return g.op("Selu", input)
1050 
1051 
1052 @parse_args('v', 'i', 'v')
1053 def index_select(g, self, dim, index):
1054  return g.op("Gather", self, index, axis_i=dim)
1055 
1056 
1057 def index_put(g, self, indices_list_value, values, accumulate):
1058  indices_list = _unpack_list(indices_list_value)
1059  args = [self] + indices_list + [values, accumulate]
1060  return g.op("ATen", *args, operator_s='index_put')
1061 
1062 
1063 def type_as(g, self, other):
1064  if self.isTensor() and other.isTensor() and self.type().scalarType() == other.type().scalarType():
1065  return self
1066 
1067  if other.isTensor():
1068  other_type_name = other.type().scalarType()
1069  return g.op("Cast", self, to_i=cast_pytorch_to_onnx[other_type_name])
1070  else:
1071  # We don't know the type of other, bail by emitting ATen
1072  return g.op("ATen", self, other, operator_s="type_as")
1073 
1074 
1075 @parse_args('v', 'is', 'v', 'v', 'f', 'i')
1076 def layer_norm(g, self, normalized_shape, weight, bias, eps, cudnn_enable):
1077  return g.op("ATen", self, weight, bias, normalized_shape_i=normalized_shape,
1078  eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm")
1079 
1080 
1081 # ignore clone operators that are inserted by PyTorch autograd
1082 def clone(g, input):
1083  return input
1084 
1085 
1086 def abs(g, self):
1087  return g.op("Abs", self)
1088 
1089 
1090 def log(g, self):
1091  return g.op("Log", self)
1092 
1093 
1094 def pow(g, self, exponent):
1095  exponent = _maybe_get_scalar(exponent)
1096  return g.op("Pow", self, _if_scalar_type_as(g, exponent, self))
1097 
1098 
1099 def clamp(g, self, min, max):
1100  # min or max may be None, in which case we need to dispatch to
1101  # Clip separately, as ONNX does not have a None syntax
1102  if min.node().mustBeNone():
1103  return clamp_max(g, self, max)
1104  elif max.node().mustBeNone():
1105  return clamp_min(g, self, min)
1106  else:
1107  min = _parse_arg(min, 'f')
1108  max = _parse_arg(max, 'f')
1109  return g.op("Clip", self, min_f=min, max_f=max)
1110 
1111 
1112 @parse_args('v', 'f')
1113 def clamp_min(g, self, min):
1114  return g.op("Clip", self, min_f=min)
1115 
1116 
1117 @parse_args('v', 'f')
1118 def clamp_max(g, self, max):
1119  return g.op("Clip", self, max_f=max)
1120 
1121 
1122 # torch.max (same for torch.min) actually has two interfaces smashed together:
1123 # torch.max(x, dim, keepdim) and torch.max(x, y)
1124 def max(g, self, dim_or_y=None, keepdim=None):
1125  if dim_or_y is None and keepdim is None:
1126  return g.op("ReduceMax", self, keepdims_i=0)
1127  if keepdim is None:
1128  return g.op("Max", self, dim_or_y)
1129  else:
1130  dim = _get_const(dim_or_y, 'i', 'dim')
1131  keepdim = _get_const(keepdim, 'i', 'keepdim')
1132  # TODO: export it as ReduceMax
1133  return g.op("ATen",
1134  self,
1135  operator_s="max",
1136  dim_i=dim,
1137  keepdim_i=keepdim,
1138  outputs=2)
1139 
1140 
1141 def min(g, self, dim_or_y=None, keepdim=None):
1142  if dim_or_y is None and keepdim is None:
1143  return g.op("ReduceMin", self, keepdims_i=0)
1144  if keepdim is None:
1145  return g.op("Min", self, dim_or_y)
1146  else:
1147  dim = _get_const(dim_or_y, 'i', 'dim')
1148  keepdim = _get_const(keepdim, 'i', 'keepdim')
1149  # TODO: export it as ReduceMin
1150  return g.op("ATen",
1151  self,
1152  operator_s="min",
1153  dim_i=dim,
1154  keepdim_i=keepdim,
1155  outputs=2)
1156 
1157 
1158 def exp(g, self):
1159  return g.op("Exp", self)
1160 
1161 
1162 @parse_args('v', 'f', 'i')
1163 def dropout(g, input, p, train):
1164  if not train: # in eval mode, dropout is non-op
1165  return input
1166  r, _ = g.op("Dropout", input, ratio_f=p, outputs=2)
1167  return r
1168 
1169 
1170 def _unsupported_dropout(name):
1171  @parse_args('v', 'f', 'i')
1172  def feature_dropout(g, input, p, train):
1173  # NB: In inference mode, FeatureDropout is exported as an identity op.
1174  from torch.onnx.symbolic import _unimplemented
1175  if train:
1176  return _unimplemented(name, "training mode")
1177  return input
1178  return feature_dropout
1179 
1180 
1181 feature_dropout = _unsupported_dropout("feature_dropout")
1182 alpha_dropout = _unsupported_dropout("alpha_dropout")
1183 feature_alpha_dropout = _unsupported_dropout("feature_alpha_dropout")
1184 
1185 # See Note [Export inplace]
1186 dropout_ = dropout
1187 feature_dropout_ = feature_dropout
1188 alpha_dropout_ = alpha_dropout
1189 feature_alpha_dropout_ = feature_alpha_dropout
1190 
1191 
1192 @parse_args('v', 't', 'i', 'i')
1193 def norm(g, self, p, dim, keepdim):
1194  if p == 1:
1195  f = _reduce_op_symbolic("ReduceL1")
1196  elif p == 2:
1197  f = _reduce_op_symbolic("ReduceL2")
1198  else:
1199  raise RuntimeError("ONNX export only supports p-norms with p of 1 or 2")
1200  return f(g, self, dim=dim, keepdim=keepdim)
1201 
1202 
1203 @parse_args('v', 'v', 'v', 'i')
1204 def conv_tbc(g, input, weight, bias, pad):
1205  return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
1206 
1207 
1208 @parse_args('v', 'i', 'i')
1209 def _unique(g, input, sorted, return_inverse):
1210  return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
1211  return_inverse_i=return_inverse, outputs=2)
1212 
1213 
1214 # Metaprogram symbolics for each ATen native specialized cast operator.
1215 # For example, we specify a function named `_cast_Byte` that instantiates an
1216 # ONNX cast node with `to` attribute 'UINT8'
1217 #
1218 # TODO: remove these once we support Type's in the JIT IR and we can once again
1219 # use the unified toType operator
1220 cast_pytorch_to_onnx = {
1221  'Byte': torch.onnx.TensorProtoDataType.UINT8,
1222  'Char': torch.onnx.TensorProtoDataType.INT8,
1223  'Double': torch.onnx.TensorProtoDataType.DOUBLE,
1224  'Float': torch.onnx.TensorProtoDataType.FLOAT,
1225  'Half': torch.onnx.TensorProtoDataType.FLOAT16,
1226  'Int': torch.onnx.TensorProtoDataType.INT32,
1227  'Long': torch.onnx.TensorProtoDataType.INT64,
1228  'Short': torch.onnx.TensorProtoDataType.INT16,
1229 }
1230 
1231 scalar_name_to_pytorch = {
1232  'uint8_t': 'Byte',
1233  'int8_t': 'Char',
1234  'double': 'Double',
1235  'float': 'Float',
1236  'half': 'Half',
1237  'int': 'Int',
1238  'int64_t': 'Long',
1239  'int16_t': 'Short',
1240 }
1241 
1242 
1243 # This indicates each scalar type's corresponding
1244 # torch type. Related source:
1245 # https://github.com/pytorch/pytorch/blob/da7468853ae322252270bbb58032668bd21b7457/c10/core/ScalarType.h
1246 scalar_type_to_pytorch_type = [
1247  torch.uint8, # 0
1248  torch.int8, # 1
1249  torch.short, # 2
1250  torch.int, # 3
1251  torch.int64, # 4
1252  torch.half, # 5
1253  torch.float, # 6
1254  torch.double, # 7
1255 ]
1256 
1257 
1258 def _cast_func_template(to_i, g, input, non_blocking):
1259  return g.op("Cast", input, to_i=to_i)
1260 
1261 
1262 for k, v in cast_pytorch_to_onnx.items():
1263  name = '_cast_{}'.format(k)
1264  globals()[name] = parse_args('v', 'i')(partial(_cast_func_template, v))
1265 
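# After this loop, functions such as _cast_Byte, _cast_Float and _cast_Long
# exist in this module; each one simply emits an ONNX Cast, e.g. (sketch):
#
#     _cast_Long(g, lengths, False)   # -> g.op("Cast", lengths, to_i=TensorProtoDataType.INT64)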
1266 
1267 scalar_type_to_onnx = [
1268  cast_pytorch_to_onnx["Byte"],
1269  cast_pytorch_to_onnx["Char"],
1270  cast_pytorch_to_onnx["Short"],
1271  cast_pytorch_to_onnx["Int"],
1272  cast_pytorch_to_onnx["Long"],
1273  cast_pytorch_to_onnx["Half"],
1274  cast_pytorch_to_onnx["Float"],
1275  cast_pytorch_to_onnx["Double"],
1276 ]
1277 
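# Both tables are indexed by the integer ScalarType value recorded in the trace,
# e.g. a dtype of 6 selects torch.float / TensorProtoDataType.FLOAT; the symbolics
# below use them roughly as (sketch):
#
#     torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype])   # constant payload
#     g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])         # dtype cast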
1278 
1279 @parse_args('v', 'i', 'v', 'v')
1280 def zeros(g, sizes, dtype, layout, device):
1281  # NOTE: no way to set device and layout in ONNX, so we ignore it
1282  return g.op("ConstantOfShape", sizes,
1283  value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))
1284 
1285 
1286 @parse_args('v', 'i', 'v', 'v')
1287 def zeros_like(g, input, dtype, layout, device):
1288  shape = g.op("Shape", input)
1289  return g.op("ConstantOfShape", shape,
1290  value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))
1291 
1292 
1293 @parse_args('v', 'i', 'v', 'v')
1294 def ones(g, sizes, dtype, layout, device):
1295  return g.op("ConstantOfShape", sizes,
1296  value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))
1297 
1298 
1299 @parse_args('v', 'i', 'v', 'v')
1300 def ones_like(g, input, dtype, layout, device):
1301  shape = g.op("Shape", input)
1302  return g.op("ConstantOfShape", shape,
1303  value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))
1304 
1305 
1306 def full(g, sizes, value, dtype, layout, device):
1307  const_value = _maybe_get_const(value, 't')
1308  if _is_value(const_value):
1309  tmp = zeros(g, sizes, dtype, layout, device)
1310  return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
1311  else:
1312  dtype = _get_const(dtype, 'i', 'dtype')
1313  return g.op("ConstantOfShape", sizes,
1314  value_t=torch.tensor(const_value, dtype=scalar_type_to_pytorch_type[dtype]))
1315 
1316 
1317 @parse_args('v', 'f', 'i', 'v', 'v')
1318 def full_like(g, input, fill_value, dtype, layout, device):
1319  shape = g.op("Shape", input)
1320  return g.op("ConstantOfShape", shape,
1321  value_t=torch.tensor(fill_value, dtype=scalar_type_to_pytorch_type[dtype]))
1322 
1323 
1324 @parse_args('v', 'v', 'v', 'v', 'i')
1325 def slice(g, self, dim, start, end, step):
1326  if step != 1:
1327  _unimplemented("slice", "step!=1 is currently not supported")
1328  if start.node().kind() != 'onnx::Constant' or \
1329  end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant':
1330  start_unsqueezed = g.op("Unsqueeze", start, axes_i=[0])
1331  end_unsqueezed = g.op("Unsqueeze", end, axes_i=[0])
1332  dim_unsqueezed = g.op("Unsqueeze", dim, axes_i=[0])
1333  return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)
1334  else:
1335  start = _parse_arg(start, 'i')
1336  end = _parse_arg(end, 'i')
1337  dim = _parse_arg(dim, 'i')
1338  return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end])
1339 
1340 
1341 @parse_args('v', 'f', 'f')
1342 def hardtanh(g, self, min_val, max_val):
1343  return g.op("Clip", self, min_f=min_val, max_f=max_val)
1344 
1345 
1346 def alias(g, self):
1347  return self
1348 
1349 
1350 @parse_args('v', 'i')
1351 def unsqueeze(g, self, dim):
1352  return g.op("Unsqueeze", self, axes_i=[dim])
1353 
1354 
1355 @parse_args('v', 'i', 'i', 'i', 'i')
1356 def topk(g, self, k, dim, largest, sorted, out=None):
1357  if out is not None:
1358  _unimplemented("TopK", "Out parameter is not supported for topk")
1359  if not largest:
1360  _unimplemented("TopK", "Ascending TopK is not supported")
1361 
1362  return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
1363 
1364 
1365 def to(g, self, *args):
1366  # ONNX doesn't have a concept of a device, so we ignore device casts
1367  if len(args) == 3:
1368  if args[0].type().isSubtypeOf(ListType.ofInts()):
1369  # aten::to(Tensor, Device, bool, bool)
1370  return self
1371  else:
1372  # aten::to(Tensor, ScalarType, bool, bool)
1373  dtype = _get_const(args[0], 'i', 'dtype')
1374  return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
1375  elif len(args) == 4:
1376  # aten::to(Tensor, Device, ScalarType, bool, bool)
1377  dtype = _get_const(args[1], 'i', 'dtype')
1378  return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
1379  elif len(args) == 5:
1380  # aten::to(Tensor, ScalarType, Layout, Device, bool, bool) -> Tensor
1381  dtype = _get_const(args[0], 'i', 'dtype')
1382  # Layout and device are ignored
1383  return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
1384  else:
1385  raise NotImplementedError("Unknown aten::to signature")
1386 
1387 
1388 def repeat(g, self, repeats):
1389  if not _is_value(repeats):
1390  repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
1391  const_repeats = _maybe_get_const(repeats, 'is')
1392 
1393  if self.isTensor() and not _is_value(const_repeats):
1394  sizes = self.type().sizes()
1395  diff_dims = len(const_repeats) - len(sizes)
1396  if diff_dims > 0:
1397  self = view(g, self, [1] * diff_dims + sizes)
1398  return g.op("Tile", self, repeats)
1399 
1400 
1401 @parse_args('v', 'i')
1402 def pixel_shuffle(g, self, upscale_factor):
1403  dims = self.type().sizes()
1404  if len(dims) != 4:
1405  return _unimplemented("pixel_shuffle", "only support 4d input")
1406  output_channel = dims[1] // upscale_factor // upscale_factor
1407  after_view = view(g, self, [-1, upscale_factor, upscale_factor,
1408  output_channel, dims[2], dims[3]])
1409  after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
1410  return view(g, after_transpose,
1411  [-1, output_channel, dims[2] * upscale_factor, dims[3] *
1412  upscale_factor])
1413 
1414 
1415 @parse_args('v', 'i', 'v', 'v', 'f', 'i')
1416 def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
1417  return g.op("ATen", input, weight, bias, num_groups_i=num_groups,
1418  eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm")
1419 
1420 
1421 def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
1422  num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
1423  weights_per_layer = 4 if has_biases else 2
1424  assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
1425  layer_weights = [all_weights[i:i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer)]
1426  if batch_first:
1427  return _unimplemented("RNN/GRU/LSTM", "batch_first")
1428  if dropout and train:
1429  return _unimplemented("RNN/GRU/LSTM", "dropout in training mode")
1430 
1431  if variant.startswith('RNN'):
1432  nonlinearity = variant[4:].lower()
1433  variant = 'RNN'
1434 
1435  w_hh = all_weights[1]
1436  hidden_size = w_hh.type().sizes()[1]
1437 
1438  unidirectional = not bidirectional
1439 
1440  prev_output = input
1441 
1442  h_outs = []
1443  if variant == 'RNN' or variant == 'GRU':
1444  h0 = initial_states
1445  elif variant == 'LSTM':
1446  h0, c0 = initial_states
1447  c_outs = []
1448 
1449  sequence_lens = unused(g) if batch_sizes is None else batch_sizes
1450 
1451  if variant == 'GRU':
1452  # pytorch is reset, input, hidden
1453  # onnx is input, reset, hidden
1454  reform_permutation = [(1, 2), (0, 1), (2, 3)]
1455  elif variant == 'LSTM':
1456  # pytorch is input, forget, cell, output.
1457  # onnx is input, output, forget, cell.
1458  reform_permutation = [(0, 1), (3, 4), (1, 3)]
1459 
1460  def reform_weights(g, w, n, intervals):
1461  slices = [g.op('Slice', w, axes_i=[0], starts_i=[x * n], ends_i=[y * n]) for x, y in intervals]
1462  return g.op('Concat', *slices, axis_i=0)
1463 
1464  def transform_weights(layer_index):
1465  if variant == 'RNN':
1466  weight_ih, weight_hh, bias_ih, bias_hh = layer_weights[layer_index]
1467  elif variant == 'GRU' or variant == 'LSTM':
1468  weight_ih, weight_hh, bias_ih, bias_hh = \
1469  [reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
1470  bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)
1471 
1472  return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat))
1473 
1474  def retrieve_state(x, start, end):
1475  return x if num_layers == 1 else g.op('Slice', x, axes_i=[0], starts_i=[start], ends_i=[end])
1476 
1477  for i in range(num_layers):
1478  if unidirectional:
1479  weight_ih, weight_hh, bias_concat = transform_weights(i)
1480  state_indices = i, i + 1
1481  else:
1482  weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
1483  weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
1484 
1485  weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
1486  weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
1487  bias_concat = g.op('Concat', bias_f, bias_b, axis_i=0)
1488 
1489  state_indices = 2 * i, 2 * i + 2
1490 
1491  inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]
1492 
1493  inputs.append(retrieve_state(h0, *state_indices))
1494  if variant == 'LSTM':
1495  inputs.append(retrieve_state(c0, *state_indices))
1496 
1497  extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
1498  if variant == 'RNN':
1499  prev_output, h_out = g.op('RNN', *inputs, outputs=2,
1500  hidden_size_i=hidden_size,
1501  activations_s=[nonlinearity],
1502  **extra_kwargs)
1503  elif variant == 'GRU':
1504  prev_output, h_out = g.op('GRU', *inputs, outputs=2,
1505  hidden_size_i=hidden_size,
1506  linear_before_reset_i=1,
1507  **extra_kwargs)
1508  elif variant == 'LSTM':
1509  prev_output, h_out, c_out = g.op('LSTM', *inputs, outputs=3,
1510  hidden_size_i=hidden_size,
1511  **extra_kwargs)
1512 
1513  if bidirectional:
1514  # The ONNX RNN/GRU/LSTM produce an output of dimensions
1515  # seq_len, num_directions, batch, hidden_size
1516  # We have to convert to match pytorch's expected
1517  # seq_len, batch, num_directions * hidden_size
1518  # by first moving num_directions before hidden_size with
1519  # Transpose, and then combining it with hidden_size
1520  # with Reshape.
1521  prev_output = g.op('Transpose', prev_output, perm_i=[0, 2, 1, 3])
1522  prev_output = g.op('Reshape', prev_output, g.op('Constant', value_t=torch.LongTensor([0, 0, -1])))
1523  else:
1524  prev_output = g.op('Squeeze', prev_output, axes_i=[1])
1525 
1526  h_outs.append(h_out)
1527  if variant == 'LSTM':
1528  c_outs.append(c_out)
1529  h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
1530  if variant == 'RNN' or variant == 'GRU':
1531  return prev_output, h_outs
1532  elif variant == 'LSTM':
1533  c_outs = c_out if num_layers == 1 else g.op('Concat', *c_outs, axis_i=0)
1534  return prev_output, h_outs, c_outs
1535 
1536 
1537 @parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
1538 def _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
1539  hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
1540  return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
1541  dropout, train, bidirectional, batch_first)
1542 
1543 
1544 @parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
1545 def _lstm_packed(g, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional):
1546  hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
1547  return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
1548  dropout, train, bidirectional, batch_sizes=batch_sizes)
1549 
1550 
1551 def lstm(g, *args):
1552  if _is_tensor_list(args[3]):
1553  return _lstm_packed(g, *args)
1554  else:
1555  return _lstm_full(g, *args)
1556 
1557 
1558 def _one_hidden_rnn(kind):
1559  @parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
1560  def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
1561  weight = _unpack_list(weight_v)
1562  return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
1563  dropout, train, bidirectional, batch_first)
1564 
1565  @parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
1566  def _rnn_packed(g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional):
1567  weight = _unpack_list(weight_v)
1568  return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
1569  dropout, train, bidirectional, batch_sizes=batch_sizes)
1570 
1571  def symbolic(g, *args):
1572  if _is_tensor_list(args[3]):
1573  return _rnn_packed(g, *args)
1574  else:
1575  return _rnn_full(g, *args)
1576 
1577  return symbolic
1578 
1579 
1580 gru = _one_hidden_rnn('GRU')
1581 rnn_tanh = _one_hidden_rnn('RNN_TANH')
1582 rnn_relu = _one_hidden_rnn('RNN_RELU')
1583 
1584 
1585 @parse_args('v', 'i')
1586 def _dim_arange(g, like, dim):
1587  return g.op('ATen', like, dim_i=dim, operator_s='_dim_arange')
1588 
1589 
1590 def detach(g, input):
1591  # Erase aten::detach nodes because ONNX is inference only
1592  return input
1593 
1594 
1595 def contiguous(g, input):
1596  return input
1597 
1598 
1599 @parse_args('v', 'v', 'i')
1600 def _pack_padded_sequence(g, input, lengths, batch_first):
1601  # There currently is no PackPadded operator in ONNX. We rely on an
1602  # optimization pass to remove this later. It is an error if all
1603  # PackPadded operators cannot be optimized out.
1604  if batch_first:
1605  input = g.op('Transpose', input, perm_i=[1, 0, 2])
1606  if not lengths.type().isSubtypeOf(torch._C.TensorType.get()):
1607  raise RuntimeError("Lengths must be a Tensor for ONNX export")
1608  # We know it's a TensorType so this check is now safe.
1609  # It's really only necessary because those operators expand to something that
1610  # only works with int32 types in Caffe2...
1611  if lengths.type().scalarType() != 'Int':
1612  lengths = _cast_Int(g, lengths, False)
1613  return g.op("prim::PackPadded", input, lengths, outputs=2)
1614 
1615 
1616 @parse_args('v', 'v', 'i', 't', 'v')
1617 def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total_length):
1618  # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence
1619  # It is only useful/used when training with a data_parallel model, so
1620  # it shouldn't be relevant for ONNX anyway
1621  data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
1622  if batch_first:
1623  data = g.op('Transpose', data, perm_i=[1, 0, 2])
1624  return data, lengths
1625 
1626 
1627 def randn(g, *shapes):
1628  shapes_list = list(shapes)
1629  shape = _maybe_get_const(shapes_list[0], "is")
1630  return g.op('RandomNormal', shape_i=shape)
1631 
1632 
1633 @parse_args('v', 'f', 'f', 'i', 'none')
1634 def rrelu(g, input, lower, upper, training, generator):
1635  p = g.op('RandomUniformLike', input, high_f=upper, low_f=lower)
1636  return g.op('PRelu', input, p)
1637 
1638 
1639 @parse_args('v')
1640 def log_sigmoid(g, input):
1641  p = g.op('Sigmoid', input)
1642  return g.op('Log', p)
1643 
1644 
1645 @parse_args('v')
1646 def erf(g, input):
1647  return g.op('Erf', input)
1648 
1649 
1650 @parse_args('v', 'i', 'i')
1651 def flatten(g, input, start_dim, end_dim):
1652  dim = input.type().dim()
1653  if end_dim < 0 :
1654  end_dim = dim + end_dim
1655  # use ONNX's Flatten operator for cases where the output shape is 2D
1656  if start_dim == 1 and end_dim == dim - 1 :
1657  return g.op("Flatten", input, axis_i=start_dim)
1658  if start_dim == 0 and end_dim == dim - 2 :
1659  return g.op("Flatten", input, axis_i=end_dim + 1)
1660  # use Reshape for cases where the output shape is not 2D
1661  if input.type().kind() != "CompleteTensorType":
1662  return _unimplemented("flatten", "input size not accessible")
1663  input_dims = input.type().sizes()
1664  output_dims = []
1665  for i in range(0, dim):
1666  if start_dim < i and end_dim >= i:
1667  output_dims[start_dim] = output_dims[start_dim] * input_dims[i]
1668  else:
1669  output_dims.append(input_dims[i])
1670  shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
1671  p = _reshape_from_tensor(g, input, shape)
1672  return p
1673 
1674 
1675 @parse_args('v')
1676 def nonzero(g, input):
1677  return t(g, g.op('NonZero', input))
1678 
1679 
1680 @parse_args('v')
1681 def isnan(g, input):
1682  output = g.op('IsNaN', input)
1683  output = _cast_func_template(cast_pytorch_to_onnx['Byte'], g, output, None)
1684  return output
1685 
1686 
1687 @parse_args('v', 'i', 'i', 'i')
1688 def narrow(g, input, dim, start, length):
1689  return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])
1690 
1691 
1692 @parse_args('v', 'i', 'i')
1693 def argmax(g, input, dim, keepdim):
1694  return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)
1695 
1696 
1697 @parse_args('v', 'i', 'i')
1698 def argmin(g, input, dim, keepdim):
1699  return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)