import torch
from torch._C import TensorType, ListType, OptionalType
from torch.nn.modules.utils import _single, _pair, _triple

import warnings
import math

import numpy

from collections import Iterable
from functools import partial, wraps
def _parse_arg(value, desc):
    if desc == 'none':
        return value
    if desc == 'v' or not _is_value(value):
        return value
    if value.node().kind() != 'onnx::Constant':
        raise RuntimeError("ONNX symbolic expected a constant value in the trace")
    tval = value.node()['value']
    if desc == 'i':
        return int(tval)
    elif desc == 'f':
        return float(tval)
    elif desc == 't':
        return tval
    elif desc == 'is':
        return [int(v) for v in tval]
    else:
        raise RuntimeError("Casting constants to `{}` is not implemented".format(desc))
def _maybe_get_const(value, desc):
    if _is_value(value) and value.node().kind() == 'onnx::Constant':
        return _parse_arg(value, desc)
    return value


def _maybe_get_scalar(value):
    value_t = _maybe_get_const(value, 't')
    if isinstance(value_t, torch.Tensor) and value_t.shape == ():
        return value_t
    return value


def _get_const(value, desc, arg_name):
    if _is_value(value) and value.node().kind() != 'onnx::Constant':
        raise RuntimeError("ONNX symbolic expected a constant value of the {} argument".format(arg_name))
    return _parse_arg(value, desc)
def _unpack_list(list_value):
    list_node = list_value.node()
    assert list_node.kind() == "prim::ListConstruct"
    return list(list_node.inputs())
def parse_args(*arg_descriptors):
    def decorator(fn):
        def wrapper(g, *args):
            # some args may be optional, so the check is relaxed to ">="
            assert len(arg_descriptors) >= len(args)
            args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)]
            return fn(g, *args)
        wrapper = wraps(fn)(wrapper)
        return wrapper
    return decorator
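# Illustrative sketch (hypothetical symbolic, not part of this module): the
# descriptor string given to parse_args tells the wrapper how to coerce each
# traced Value before the body runs -- 'v' keeps the raw Value, while
# 'i'/'f'/'t'/'is' extract an int/float/tensor/int-list from an onnx::Constant.
#
# @parse_args('v', 'i', 'f')
# def _example_symbolic(g, self, dim, eps):
#     # `self` is still a graph Value here; `dim` is a Python int, `eps` a float
#     return g.op("SomeOp", self, axis_i=dim, epsilon_f=eps)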
134 """Convert a scalar tensor into a Python value.""" 135 assert x.numel() == 1
139 def _if_scalar_type_as(g, self, tensor):
141 Convert self into the same type of tensor, as necessary. 143 We only support implicit casting for scalars, so we never 144 actually need to insert an ONNX cast operator here; just 147 if isinstance(self, torch._C.Value):
149 elif tensor.type().kind() ==
"DimensionedTensorType" or tensor.type().kind() ==
"CompleteTensorType":
150 ty = tensor.type().scalarType().lower()
151 return getattr(self, ty)()
157 return isinstance(x, torch._C.Value)
160 def _is_tensor_list(x):
161 return x.type().isSubtypeOf(ListType.ofTensors())
164 def _unimplemented(op, msg):
165 warnings.warn(
"ONNX export failed on " + op +
" because " + msg +
" not supported")
168 def _try_get_scalar_type(*args):
171 return arg.type().scalarType()
_default_onnx_opset_version = 9
_onnx_master_opset = 10
_onnx_stable_opsets = [9]
_export_onnx_opset_version = _default_onnx_opset_version
def _set_opset_version(opset_version):
    global _export_onnx_opset_version
    if opset_version == _default_onnx_opset_version:
        return
    if opset_version in _onnx_stable_opsets + [_onnx_master_opset]:
        _export_onnx_opset_version = opset_version
        return
    raise ValueError("Unsupported ONNX opset version: " + str(opset_version))
def unused(g):
    # Represents a "missing" optional input as an undefined optional constant.
    n = g.op("prim::Constant")
    n.setType(OptionalType.ofTensor())
    return n


def _shape_as_tensor(g, input):
    return g.op('Shape', input)


def _reshape_from_tensor(g, input, shape):
    return g.op('Reshape', input, shape)


def reshape(g, self, shape):
    return view(g, self, shape)


def reshape_as(g, self, other):
    shape = g.op('Shape', other)
    return reshape(g, self, shape)
def add(g, self, other, alpha=None):
    if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
        return _unimplemented("add", "alpha != 1")
    other = _maybe_get_scalar(other)
    return g.op("Add", self, _if_scalar_type_as(g, other, self))


def sub(g, self, other, alpha=None):
    if alpha and _scalar(_maybe_get_scalar(alpha)) != 1:
        return _unimplemented("sub", "alpha != 1")
    other = _maybe_get_scalar(other)
    return g.op("Sub", self, _if_scalar_type_as(g, other, self))


def rsub(g, self, other, alpha=None):
    other = _maybe_get_scalar(other)
    other = _if_scalar_type_as(g, other, self)
    return sub(g, other, self, alpha=alpha)


def mul(g, self, other):
    other = _maybe_get_scalar(other)
    return g.op("Mul", self, _if_scalar_type_as(g, other, self))


def div(g, self, other):
    other = _maybe_get_scalar(other)
    return g.op("Div", self, _if_scalar_type_as(g, other, self))


def reciprocal(g, self):
    return g.op("Div", _if_scalar_type_as(g, torch.ones(1), self), self)
@parse_args('v', 'i')
def cat(g, tensor_list, dim):
    tensors = _unpack_list(tensor_list)
    return g.op("Concat", *tensors, axis_i=dim)


@parse_args('v', 'i')
def stack(g, tensor_list, dim):
    unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in _unpack_list(tensor_list)]
    return g.op("Concat", *unsqueezed, axis_i=dim)


def mm(g, self, other):
    ty = _try_get_scalar_type(self, other).lower()
    C = g.constant(0, [1], ty)
    return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)


def bmm(g, self, other):
    return g.op("MatMul", self, other)


def matmul(g, self, other):
    return g.op("MatMul", self, other)


@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
    return g.op("Gemm", mat1, mat2, self, beta_f=_scalar(beta), alpha_f=_scalar(alpha))
def neg(g, self):
    return g.op("Neg", self)


def sqrt(g, self):
    return g.op("Sqrt", self)


def tanh(g, self):
    return g.op("Tanh", self)


def sin(g, self):
    return g.op("Sin", self)


def cos(g, self):
    return g.op("Cos", self)


def tan(g, self):
    return g.op("Tan", self)


def asin(g, self):
    return g.op("Asin", self)


def acos(g, self):
    return g.op("Acos", self)


def atan(g, self):
    return g.op("Atan", self)


def sigmoid(g, self):
    return g.op("Sigmoid", self)
def _reduce_op_symbolic(onnx_op_name):
    def symbolic(g, self, dim=None, keepdim=None):
        if dim is None:
            # all-reduce path
            return g.op(onnx_op_name, self, keepdims_i=0)
        else:
            # dim-reduce path
            dim, keepdim = _get_const(dim, 'i', 'dim'), _get_const(keepdim, 'i', 'keepdim')
            return g.op(onnx_op_name, self, axes_i=[dim], keepdims_i=keepdim)
    return symbolic


mean = _reduce_op_symbolic('ReduceMean')
sum = _reduce_op_symbolic('ReduceSum')
prod = _reduce_op_symbolic('ReduceProd')


@parse_args('v', 'i')
def cumsum(g, input, dim):
    return g.op("ATen", input, operator_s="cumsum", dim_i=dim)


def t(g, self):
    return g.op("Transpose", self, perm_i=(1, 0))
def expand(g, self, size, implicit):
    size = _maybe_get_const(size, 'is')
    if not _is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    return g.op("Expand", self, size)


def expand_as(g, self, other):
    shape = g.op("Shape", other)
    return g.op("Expand", self, shape)


def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    return g.op("Gather", weight, indices)
@parse_args('v', 'v', 'v', 'i', 'i', 'i')
def embedding_bag(g, embedding_matrix, indices, offsets, scale_grad_by_freq, mode, sparse):
    return g.op("ATen",
                embedding_matrix,
                indices,
                offsets,
                operator_s="embedding_bag",
                scale_grad_by_freq_i=scale_grad_by_freq,
                mode_i=mode,
                sparse_i=sparse)
def size(g, self, dim):
    full_shape = g.op("Shape", self)
    return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
@parse_args('v', 'i', 'i')
def transpose(g, self, dim0, dim1):
    if dim0 == dim1:  # micro-optimization
        return self
    # NB: Transpose in ONNX is actually a Permute
    axes = list(range(self.type().dim()))
    axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
    return g.op("Transpose", self, perm_i=axes)


@parse_args('v', 'is')
def permute(g, self, dims):
    if dims == list(range(0, len(dims))):
        return self
    return g.op("Transpose", self, perm_i=dims)


def view(g, self, size):
    size = _maybe_get_const(size, 'is')
    if _is_value(size):
        shape = size
    else:
        self_sizes = self.type().sizes()
        if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
            return g.op("Flatten", self, axis_i=1)
        shape = g.op("Constant", value_t=torch.LongTensor(size))
    return g.op("Reshape", self, shape)
def prim_ConstantSplit(g, self, split_size, dim):
    size = self.type().sizes()[dim]
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))


def prim_ConstantChunk(g, self, chunks, dim):
    split_size = (self.type().sizes()[dim] + chunks - 1) // chunks
    return prim_ConstantSplit(g, self, split_size, dim)


@parse_args('v', 'i', 'i')
def split(g, self, split_size, dim):
    size = self.type().sizes()[dim]
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=1)


@parse_args('v', 'is', 'i')
def split_with_sizes(g, self, split_sizes, dim):
    return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=1)
@parse_args('v', 'i', 'v')
def select(g, self, dim, index):
    if dim > 1:
        # for higher dims, fall back to Slice + Squeeze instead of Gather
        index_val = _parse_arg(index, 'i')
        slice_node = g.op("Slice", self, axes_i=[dim], starts_i=[index_val], ends_i=[index_val + 1])
        return g.op("Squeeze", slice_node, axes_i=[dim])
    else:
        return g.op("Gather", self, index, axis_i=dim)


def squeeze(g, self, dim=None):
    if dim is None:
        dims = []
        for i, size in enumerate(self.type().sizes()):
            if size == 1:
                dims.append(i)
    else:
        dims = [_get_const(dim, 'i', 'dim')]
    return g.op("Squeeze", self, axes_i=dims)


def prelu(g, self, weight):
    return g.op("PRelu", self, weight)


def relu(g, input):
    return g.op("Relu", input)
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
    if _scalar(threshold) != 0:
        return _unimplemented("threshold", "non-zero threshold")
    if _scalar(value) != 0:
        return _unimplemented("threshold", "non-zero value")
    return g.op("Relu", self)


def leaky_relu(g, input, negative_slope, inplace=False):
    negative_slope = _get_const(negative_slope, 't', 'negative_slope')
    return g.op("LeakyRelu", input, alpha_f=_scalar(negative_slope))


@parse_args('v', 'i')
def glu(g, input, dim):
    assert input.type().sizes()[dim] % 2 == 0
    first, second = g.op('Split', input, axis_i=dim, outputs=2)
    return g.op('Mul', first, g.op('Sigmoid', second))
@parse_args('v', 'i', 'i')
def softmax(g, input, dim, dtype=None):
    # ONNX's Softmax coerces the input into a 2-D tensor before normalizing,
    # so only a softmax over the last dimension can be exported faithfully.
    if dim < 0:
        dim = input.type().dim() + dim
    if input.type().dim() != dim + 1:
        return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
    return_op = g.op('Softmax', input, axis_i=dim)
    if dtype:
        return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
    return return_op


@parse_args('v', 't', 'v')
def softplus(g, self, beta, threshold):
    if beta != 1:
        return _unimplemented("beta", "has to be 1")
    return g.op('Softplus', self)
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    dim = input.type().sizes()[-len(padding):]
    ceiled_output_dim = [int(math.ceil((dim[i] + 2 * padding[i] - kernel_size[i]) / float(stride[i]))) + 1
                         for i in range(0, len(padding))]
    # ensure the last pooling window starts inside the input
    ceiled_output_dim = [ceiled_output_dim[i] - 1
                         if (((ceiled_output_dim[i] - 1) * stride[i]) >= (dim[i] + padding[i]))
                         else ceiled_output_dim[i]
                         for i in range(0, len(ceiled_output_dim))]
    padding_ceil = [0
                    if (stride[i] == 1)
                    else
                    (kernel_size[i] - (dim[i] + 2 * padding[i] - ((ceiled_output_dim[i] - 1) * stride[i] + 1)))
                    for i in range(0, len(padding))]
    # ensure the extra padding does not exceed kernel_size - 1
    padding_ceil = [(int(padding_ceil[i]) if padding_ceil[i] < kernel_size[i] - 1 else int(kernel_size[i] - 1))
                    if ((padding_ceil[i] + 2 * padding[i]) >= (kernel_size[i]))
                    else
                    int(padding_ceil[i])
                    for i in range(0, len(padding_ceil))]
    return padding_ceil
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool1d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    if ceil_mode and input.type().kind() != "CompleteTensorType":
        return _unimplemented("max_pool1d_with_indices", "input size not accessible")
    if set(_single(dilation)) != {1}:
        return _unimplemented("max_pool1d_with_indices", "dilation")
    if not stride:
        stride = kernel_size
    padding = tuple(_single(padding))
    if ceil_mode:
        padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
        padding = padding + tuple(numpy.add(padding_ceil, padding))
    else:
        padding = padding * 2
    r, indices = g.op("MaxPool", input, outputs=2,
                      kernel_shape_i=_single(kernel_size),
                      pads_i=padding,
                      strides_i=_single(stride))
    # ONNX returns flattened indices; run a second MaxPool with kernel/stride 1
    # and subtract the per-slice offsets to recover PyTorch-style indices.
    _, flattened_indices = g.op("MaxPool", input, outputs=2,
                                kernel_shape_i=[1],
                                strides_i=[1])
    s = g.op("Slice", flattened_indices, axes_i=[2], starts_i=[0], ends_i=[1])
    indices = sub(g, indices, s)
    return r, indices
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool2d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    if ceil_mode and input.type().kind() != "CompleteTensorType":
        return _unimplemented("max_pool2d_with_indices", "input size not accessible")
    if set(_pair(dilation)) != {1}:
        return _unimplemented("max_pool2d_with_indices", "dilation")
    if not stride:
        stride = kernel_size
    padding = tuple(_pair(padding))
    if ceil_mode:
        padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
        padding = padding + tuple(numpy.add(padding_ceil, padding))
    else:
        padding = padding * 2
    r, indices = g.op("MaxPool", input, outputs=2,
                      kernel_shape_i=_pair(kernel_size),
                      pads_i=padding,
                      strides_i=_pair(stride))
    # recover PyTorch-style (non-flattened) indices, as in max_pool1d_with_indices
    _, flattened_indices = g.op("MaxPool", input, outputs=2,
                                kernel_shape_i=[1, 1],
                                strides_i=[1, 1])
    s = g.op("Slice", flattened_indices, axes_i=[2, 3], starts_i=[0, 0], ends_i=[1, 1])
    indices = sub(g, indices, s)
    return r, indices
@parse_args('v', 'is', 'is', 'is', 'is', 'i')
def max_pool3d_with_indices(g, input, kernel_size, stride, padding, dilation, ceil_mode):
    if ceil_mode and input.type().kind() != "CompleteTensorType":
        return _unimplemented("max_pool3d_with_indices", "input size not accessible")
    if set(_triple(dilation)) != {1}:
        return _unimplemented("max_pool3d_with_indices", "dilation")
    if not stride:
        stride = kernel_size
    padding = tuple(_triple(padding))
    if ceil_mode:
        padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
        padding = padding + tuple(numpy.add(padding_ceil, padding))
    else:
        padding = padding * 2
    r, indices = g.op("MaxPool", input, outputs=2,
                      kernel_shape_i=_triple(kernel_size),
                      pads_i=padding,
                      strides_i=_triple(stride))
    # recover PyTorch-style (non-flattened) indices, as in max_pool1d_with_indices
    _, flattened_indices = g.op("MaxPool", input, outputs=2,
                                kernel_shape_i=[1, 1, 1],
                                strides_i=[1, 1, 1])
    s = g.op("Slice", flattened_indices, axes_i=[2, 3, 4], starts_i=[0, 0, 0], ends_i=[1, 1, 1])
    indices = sub(g, indices, s)
    return r, indices
def _avg_pool(name, tuple_fn):
    @parse_args('v', 'is', 'is', 'is', 'i', 'i')
    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
        if ceil_mode and input.type().kind() != "CompleteTensorType":
            return _unimplemented(name, "input size not accessible")
        if not stride:
            stride = kernel_size
        padding = tuple(tuple_fn(padding))
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
        if count_include_pad:
            input = g.op("Pad", input,
                         pads_i=((0,) * 2 + padding) * 2,
                         mode_s='constant',
                         value_f=0.)
            padding = (0,) * len(padding)
        if ceil_mode:
            padding = padding + tuple(numpy.add(padding_ceil, padding))
        else:
            padding = padding * 2
        output = g.op("AveragePool", input,
                      kernel_shape_i=tuple_fn(kernel_size),
                      strides_i=tuple_fn(stride),
                      pads_i=padding)
        return output
    return symbolic_fn


avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)
def _adaptive_pool(name, type, tuple_fn, fn=None):
    @parse_args('v', 'is')
    def symbolic_fn(g, input, output_size):
        # Adaptive pooling is exported as a GlobalPool when the output size is 1
        # for every dimension, and as a regular pool when the output size evenly
        # divides the input size. For MaxPool, GlobalMaxPool does not return
        # indices, so max_poolXd_with_indices is used where possible and None is
        # returned for the indices otherwise.
        if output_size == [1] * len(output_size) and type == "AveragePool":
            return g.op("GlobalAveragePool", input)
        if input.type().kind() != "CompleteTensorType":
            if output_size == [1] * len(output_size):
                return g.op("GlobalMaxPool", input), None
            return _unimplemented(name, 'input size not accessible')
        dim = input.type().sizes()[2:]
        # verify that the output size is a factor of the input size for every dim
        mod = [dim[i] % output_size[i] for i in range(0, len(dim))]
        if mod != [0] * len(mod):
            if output_size == [1] * len(output_size):
                return g.op("GlobalMaxPool", input), None
            return _unimplemented(name, 'output size that are not factor of input size')
        k = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
        if type == "MaxPool":
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        output = g.op(type, input,
                      kernel_shape_i=tuple_fn(k),
                      strides_i=tuple_fn(k))
        return output
    return symbolic_fn


adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', "AveragePool", _single)
adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', "AveragePool", _pair)
adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', "AveragePool", _triple)

adaptive_max_pool1d = _adaptive_pool('adaptive_max_pool1d', "MaxPool", _single, max_pool1d_with_indices)
adaptive_max_pool2d = _adaptive_pool('adaptive_max_pool2d', "MaxPool", _pair, max_pool2d_with_indices)
adaptive_max_pool3d = _adaptive_pool('adaptive_max_pool3d', "MaxPool", _triple, max_pool3d_with_indices)
@parse_args('v', 'is', 'f')
def constant_pad_nd(g, input, padding, value):
    mode = "constant"
    paddings = prepare_onnx_paddings(input.type().dim(), padding)
    return g.op("Pad", input, pads_i=paddings, mode_s=mode, value_f=value)


@parse_args('v', 'is')
def reflection_pad(g, input, padding):
    mode = "reflect"
    paddings = prepare_onnx_paddings(input.type().dim(), padding)
    return g.op("Pad", input, pads_i=paddings, mode_s=mode)


@parse_args('v', 'is')
def replication_pad(g, input, padding):
    mode = "edge"
    paddings = prepare_onnx_paddings(input.type().dim(), padding)
    return g.op("Pad", input, pads_i=paddings, mode_s=mode)


reflection_pad1d = reflection_pad
reflection_pad2d = reflection_pad
reflection_pad3d = reflection_pad
replication_pad1d = replication_pad
replication_pad2d = replication_pad
replication_pad3d = replication_pad
@parse_args('v', 'is')
def upsample_nearest2d(g, input, output_size):
    height_scale = float(output_size[-2]) / input.type().sizes()[-2]
    width_scale = float(output_size[-1]) / input.type().sizes()[-1]
    scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
                                                    width_scale]))
    return g.op("Upsample", input, scales,
                mode_s="nearest")


@parse_args('v', 'is', 'i')
def upsample_bilinear2d(g, input, output_size, align_corners):
    if align_corners:
        return _unimplemented("upsample_bilinear2d", "align_corners == True")
    height_scale = float(output_size[-2]) / input.type().sizes()[-2]
    width_scale = float(output_size[-1]) / input.type().sizes()[-1]
    scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
                                                    width_scale]))
    return g.op("Upsample", input, scales,
                mode_s="linear")
def wrap_logical_op_with_cast_to_uint8(func):
    def wrap_with_cast(g, input, other):
        return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte'])
    return wrap_with_cast


def wrap_logical_op_with_negation(func):
    def wrap_with_not(g, input, other):
        return g.op("Not", func(g, input, other))
    return wrap_with_not
@wrap_logical_op_with_cast_to_uint8
def eq(g, self, other):
    return g.op("Equal", self, other)


@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def ne(g, self, other):
    return g.op("Equal", self, other)


@wrap_logical_op_with_cast_to_uint8
def gt(g, input, other):
    return gt_impl(g, input, other)


def gt_impl(g, input, other):
    other = _maybe_get_scalar(other)
    return g.op("Greater", input, _if_scalar_type_as(g, other, input))


@wrap_logical_op_with_cast_to_uint8
def lt(g, input, other):
    return lt_impl(g, input, other)


def lt_impl(g, input, other):
    other = _maybe_get_scalar(other)
    return g.op("Less", input, _if_scalar_type_as(g, other, input))


@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def ge(g, input, other):
    other = _maybe_get_scalar(other)
    return lt_impl(g, input, _if_scalar_type_as(g, other, input))


@wrap_logical_op_with_cast_to_uint8
@wrap_logical_op_with_negation
def le(g, input, other):
    other = _maybe_get_scalar(other)
    return gt_impl(g, input, _if_scalar_type_as(g, other, input))


def where(g, condition, self, other):
    return g.op("ATen", condition, self, other, operator_s="where")
@parse_args('v', 'i', 'i')
def log_softmax(g, input, dim=None, dtype=None):
    # same restriction as softmax: only the last dimension can be exported
    if dim < 0:
        dim = input.type().dim() + dim
    if input.type().dim() != dim + 1:
        return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
    return_op = g.op("LogSoftmax", input, axis_i=dim)
    if dtype:
        return_op = g.op("Cast", return_op, to_i=scalar_type_to_onnx[dtype])
    return return_op
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i')
def _convolution(g, input, weight, bias, stride, padding, dilation,
                 transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled):
    weight_size = weight.type().sizes()

    args = [input, weight]
    # ONNX only supports a 1-D bias as a direct input
    if not bias.node().mustBeNone() and bias.type().dim() == 1:
        args.append(bias)

    kwargs = {"kernel_shape_i": weight_size[2:],
              "strides_i": stride,
              # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
              # symmetric padding
              "pads_i": padding + padding,
              "dilations_i": dilation,
              "group_i": groups}

    if any(o != 0 for o in output_padding):
        # ONNX supports both output_shape and output_padding; output_padding is
        # more straightforward, so we use it
        assert transposed
        assert len(stride) == len(output_padding)
        kwargs["output_padding_i"] = output_padding

    n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)

    if not bias.node().mustBeNone() and bias.type().dim() != 1:
        return g.op("Add", n, bias)
    else:
        return n
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def batch_norm(g, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled):
    input_sizes = input.type().sizes()
    if len(input_sizes) == 2:
        # batch_norm1d accepts 2d and 3d inputs, but ONNX only accepts 3d
        input = g.op("Unsqueeze", input, axes_i=[2])

    if weight is None or weight.node().mustBeNone():
        assert len(input_sizes) > 1
        weight_value = torch.tensor([1.] * input_sizes[1]).type(
            'torch.' + input.type().scalarType() + 'Tensor')
        weight = g.op("Constant", value_t=weight_value)
    if bias is None or bias.node().mustBeNone():
        assert len(input_sizes) > 1
        bias_value = torch.tensor([0.] * input_sizes[1]).type(
            'torch.' + input.type().scalarType() + 'Tensor')
        bias = g.op("Constant", value_t=bias_value)
    out = g.op("BatchNormalization", input, weight, bias, running_mean, running_var,
               epsilon_f=eps,
               momentum_f=1 - momentum,
               outputs=1 if not training else 5)
    if not training:
        if len(input_sizes) == 2:
            out = g.op("Squeeze", out, axes_i=[2])
        return out
    else:
        res, new_running_mean, new_running_var, saved_mean, saved_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        saved_mean.setUniqueName("batch_norm_dead_output-" + saved_mean.uniqueName())
        saved_var.setUniqueName("batch_norm_dead_output-" + saved_var.uniqueName())
        if len(input_sizes) == 2:
            res = g.op("Squeeze", res, axes_i=[2])
        return res
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
    input_sizes = input.type().sizes()
    if weight is None or weight.node().mustBeNone():
        assert len(input_sizes) > 1
        weight_value = torch.tensor([1.] * input_sizes[1]).type(
            'torch.' + input.type().scalarType() + 'Tensor')
        weight = g.op("Constant", value_t=weight_value)
    if bias is None or bias.node().mustBeNone():
        assert len(input_sizes) > 1
        bias_value = torch.tensor([0.] * input_sizes[1]).type(
            'torch.' + input.type().scalarType() + 'Tensor')
        bias = g.op("Constant", value_t=bias_value)
    return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)
@parse_args('v', 'i', 'i', 'i')
def unfold(g, input, dimension, size, step):
    return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)


@parse_args('v', 'v', 'i')
def _weight_norm(graph, v, g, dim):
    return graph.op("ATen", v, g, dim_i=dim, operator_s="_weight_norm")
@parse_args('v', 't', 't', 't')
def elu(g, input, alpha, scale, input_scale):
    if scale and scale != 1.:
        return _unimplemented("scale", "does not support scale in Elu")
    if input_scale and input_scale != 1.:
        return _unimplemented("input_scale", "does not support input_scale in Elu")
    return g.op("Elu", input, alpha_f=_scalar(alpha))


def selu(g, input):
    return g.op("Selu", input)
@parse_args('v', 'i', 'v')
def index_select(g, self, dim, index):
    return g.op("Gather", self, index, axis_i=dim)


def index_put(g, self, indices_list_value, values, accumulate):
    indices_list = _unpack_list(indices_list_value)
    args = [self] + indices_list + [values, accumulate]
    return g.op("ATen", *args, operator_s='index_put')


def type_as(g, self, other):
    if self.isTensor() and other.isTensor() and self.type().scalarType() == other.type().scalarType():
        return self

    if other.isTensor():
        other_type_name = other.type().scalarType()
        return g.op("Cast", self, to_i=cast_pytorch_to_onnx[other_type_name])
    else:
        # we do not know the type of other, so fall back to an ATen op
        return g.op("ATen", self, other, operator_s="type_as")
@parse_args('v', 'is', 'v', 'v', 'f', 'i')
def layer_norm(g, self, normalized_shape, weight, bias, eps, cudnn_enable):
    return g.op("ATen", self, weight, bias, normalized_shape_i=normalized_shape,
                eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm")


def clone(g, input):
    return input


def abs(g, self):
    return g.op("Abs", self)


def log(g, self):
    return g.op("Log", self)
def pow(g, self, exponent):
    exponent = _maybe_get_scalar(exponent)
    return g.op("Pow", self, _if_scalar_type_as(g, exponent, self))


def clamp(g, self, min, max):
    # min or max may be None, which has to be dispatched to
    # clamp_max/clamp_min separately, since ONNX has no None syntax
    if min.node().mustBeNone():
        return clamp_max(g, self, max)
    elif max.node().mustBeNone():
        return clamp_min(g, self, min)
    else:
        min = _parse_arg(min, 'f')
        max = _parse_arg(max, 'f')
        return g.op("Clip", self, min_f=min, max_f=max)


@parse_args('v', 'f')
def clamp_min(g, self, min):
    return g.op("Clip", self, min_f=min)


@parse_args('v', 'f')
def clamp_max(g, self, max):
    return g.op("Clip", self, max_f=max)
# torch.max (and torch.min) has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
def max(g, self, dim_or_y=None, keepdim=None):
    if dim_or_y is None and keepdim is None:
        return g.op("ReduceMax", self, keepdims_i=0)
    if keepdim is None:
        return g.op("Max", self, dim_or_y)
    else:
        dim = _get_const(dim_or_y, 'i', 'dim')
        keepdim = _get_const(keepdim, 'i', 'keepdim')
        return g.op("ATen", self, operator_s="max", dim_i=dim, keepdim_i=keepdim, outputs=2)


def min(g, self, dim_or_y=None, keepdim=None):
    if dim_or_y is None and keepdim is None:
        return g.op("ReduceMin", self, keepdims_i=0)
    if keepdim is None:
        return g.op("Min", self, dim_or_y)
    else:
        dim = _get_const(dim_or_y, 'i', 'dim')
        keepdim = _get_const(keepdim, 'i', 'keepdim')
        return g.op("ATen", self, operator_s="min", dim_i=dim, keepdim_i=keepdim, outputs=2)
def exp(g, self):
    return g.op("Exp", self)
@parse_args('v', 'f', 'i')
def dropout(g, input, p, train):
    if not train:  # in eval mode, dropout is a no-op
        return input
    r, _ = g.op("Dropout", input, ratio_f=p, outputs=2)
    return r


def _unsupported_dropout(name):
    @parse_args('v', 'f', 'i')
    def feature_dropout(g, input, p, train):
        # in inference mode these dropout variants are exported as identity
        if train:
            return _unimplemented(name, "training mode")
        return input
    return feature_dropout


feature_dropout = _unsupported_dropout("feature_dropout")
alpha_dropout = _unsupported_dropout("alpha_dropout")
feature_alpha_dropout = _unsupported_dropout("feature_alpha_dropout")

# in-place variants map to the same symbolics
feature_dropout_ = feature_dropout
alpha_dropout_ = alpha_dropout
feature_alpha_dropout_ = feature_alpha_dropout
@parse_args('v', 't', 'i', 'i')
def norm(g, self, p, dim, keepdim):
    if p == 1:
        f = _reduce_op_symbolic("ReduceL1")
    elif p == 2:
        f = _reduce_op_symbolic("ReduceL2")
    else:
        raise RuntimeError("ONNX export only supports p-norms with p of 1 or 2")
    return f(g, self, dim=dim, keepdim=keepdim)


@parse_args('v', 'v', 'v', 'i')
def conv_tbc(g, input, weight, bias, pad):
    return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)


@parse_args('v', 'i', 'i')
def _unique(g, input, sorted, return_inverse):
    return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
                return_inverse_i=return_inverse, outputs=2)
cast_pytorch_to_onnx = {
    'Byte': torch.onnx.TensorProtoDataType.UINT8,
    'Char': torch.onnx.TensorProtoDataType.INT8,
    'Double': torch.onnx.TensorProtoDataType.DOUBLE,
    'Float': torch.onnx.TensorProtoDataType.FLOAT,
    'Half': torch.onnx.TensorProtoDataType.FLOAT16,
    'Int': torch.onnx.TensorProtoDataType.INT32,
    'Long': torch.onnx.TensorProtoDataType.INT64,
    'Short': torch.onnx.TensorProtoDataType.INT16,
}

scalar_name_to_pytorch = {
    'uint8_t': 'Byte',
    'int8_t': 'Char',
    'double': 'Double',
    'float': 'Float',
    'half': 'Half',
    'int': 'Int',
    'int64_t': 'Long',
    'int16_t': 'Short',
}

# indexed by the at::ScalarType enum value (Byte = 0, Char = 1, ...)
scalar_type_to_pytorch_type = [
    torch.uint8,    # 0
    torch.int8,     # 1
    torch.short,    # 2
    torch.int,      # 3
    torch.int64,    # 4
    torch.half,     # 5
    torch.float,    # 6
    torch.double,   # 7
]
def _cast_func_template(to_i, g, input, non_blocking):
    return g.op("Cast", input, to_i=to_i)


for k, v in cast_pytorch_to_onnx.items():
    name = '_cast_{}'.format(k)
    globals()[name] = parse_args('v', 'i')(partial(_cast_func_template, v))
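# For reference: the loop above defines one module-level helper per dict key,
# e.g. _cast_Byte, _cast_Char, _cast_Float, _cast_Int, _cast_Long, ...  Each is
# _cast_func_template partially applied to the ONNX dtype and wrapped with
# parse_args('v', 'i'), so a call such as _cast_Int(g, lengths, False) (used by
# _pack_padded_sequence below) emits a single Cast node with to_i=INT32.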
scalar_type_to_onnx = [
    cast_pytorch_to_onnx["Byte"],
    cast_pytorch_to_onnx["Char"],
    cast_pytorch_to_onnx["Short"],
    cast_pytorch_to_onnx["Int"],
    cast_pytorch_to_onnx["Long"],
    cast_pytorch_to_onnx["Half"],
    cast_pytorch_to_onnx["Float"],
    cast_pytorch_to_onnx["Double"],
]
@parse_args('v', 'i', 'v', 'v')
def zeros(g, sizes, dtype, layout, device):
    return g.op("ConstantOfShape", sizes,
                value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))


@parse_args('v', 'i', 'v', 'v')
def zeros_like(g, input, dtype, layout, device):
    shape = g.op("Shape", input)
    return g.op("ConstantOfShape", shape,
                value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))


@parse_args('v', 'i', 'v', 'v')
def ones(g, sizes, dtype, layout, device):
    return g.op("ConstantOfShape", sizes,
                value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))


@parse_args('v', 'i', 'v', 'v')
def ones_like(g, input, dtype, layout, device):
    shape = g.op("Shape", input)
    return g.op("ConstantOfShape", shape,
                value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))
def full(g, sizes, value, dtype, layout, device):
    const_value = _maybe_get_const(value, 't')
    if _is_value(const_value):
        tmp = zeros(g, sizes, dtype, layout, device)
        return add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
    else:
        dtype = _get_const(dtype, 'i', 'dtype')
        return g.op("ConstantOfShape", sizes,
                    value_t=torch.tensor(const_value, dtype=scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'f', 'i', 'v', 'v')
def full_like(g, input, fill_value, dtype, layout, device):
    shape = g.op("Shape", input)
    return g.op("ConstantOfShape", shape,
                value_t=torch.tensor(fill_value, dtype=scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'v', 'v', 'v', 'i')
def slice(g, self, dim, start, end, step):
    if step != 1:
        _unimplemented("slice", "step!=1 is currently not supported")
    if start.node().kind() != 'onnx::Constant' or \
            end.node().kind() != 'onnx::Constant' or dim.node().kind() != 'onnx::Constant':
        start_unsqueezed = g.op("Unsqueeze", start, axes_i=[0])
        end_unsqueezed = g.op("Unsqueeze", end, axes_i=[0])
        dim_unsqueezed = g.op("Unsqueeze", dim, axes_i=[0])
        return g.op("DynamicSlice", self, start_unsqueezed, end_unsqueezed, dim_unsqueezed)
    else:
        start = _parse_arg(start, 'i')
        end = _parse_arg(end, 'i')
        dim = _parse_arg(dim, 'i')
        return g.op("Slice", self, axes_i=[dim], starts_i=[start], ends_i=[end])
@parse_args('v', 'f', 'f')
def hardtanh(g, self, min_val, max_val):
    return g.op("Clip", self, min_f=min_val, max_f=max_val)


@parse_args('v', 'i')
def unsqueeze(g, self, dim):
    return g.op("Unsqueeze", self, axes_i=[dim])


@parse_args('v', 'i', 'i', 'i', 'i')
def topk(g, self, k, dim, largest, sorted, out=None):
    if out is not None:
        _unimplemented("TopK", "Out parameter is not supported for topk")
    if not largest:
        _unimplemented("TopK", "Ascending TopK is not supported")
    return g.op("TopK", self, k_i=k, axis_i=dim, outputs=2)
def to(g, self, *args):
    # ONNX has no concept of a device, so device-only casts are no-ops
    if len(args) == 3:
        if args[0].type().isSubtypeOf(ListType.ofInts()):
            # aten::to(Tensor, Device, bool, bool)
            return self
        else:
            # aten::to(Tensor, ScalarType, bool, bool)
            dtype = _get_const(args[0], 'i', 'dtype')
            return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
    elif len(args) == 4:
        # aten::to(Tensor, Device, ScalarType, bool, bool)
        dtype = _get_const(args[1], 'i', 'dtype')
        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
    elif len(args) == 5:
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool); layout and
        # device are ignored
        dtype = _get_const(args[0], 'i', 'dtype')
        return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
    else:
        raise NotImplementedError("Unknown aten::to signature")


def repeat(g, self, repeats):
    if not _is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    const_repeats = _maybe_get_const(repeats, 'is')

    if self.isTensor() and not _is_value(const_repeats):
        sizes = self.type().sizes()
        diff_dims = len(const_repeats) - len(sizes)
        if diff_dims > 0:
            self = view(g, self, [1] * diff_dims + sizes)
    return g.op("Tile", self, repeats)
@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
    dims = self.type().sizes()
    if len(dims) != 4:
        return _unimplemented("pixel_shuffle", "only support 4d input")
    output_channel = dims[1] // upscale_factor // upscale_factor
    after_view = view(g, self, [-1, upscale_factor, upscale_factor,
                                output_channel, dims[2], dims[3]])
    after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
    return view(g, after_transpose,
                [-1, output_channel, dims[2] * upscale_factor, dims[3] *
                 upscale_factor])


@parse_args('v', 'i', 'v', 'v', 'f', 'i')
def group_norm(g, input, num_groups, weight, bias, eps, cudnn_enabled):
    return g.op("ATen", input, weight, bias, num_groups_i=num_groups,
                eps_f=eps, cudnn_enabled_i=cudnn_enabled, operator_s="group_norm")
def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
                 num_layers, dropout, train, bidirectional, batch_first=None, batch_sizes=None):
    weights_per_layer = 4 if has_biases else 2
    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
    layer_weights = [all_weights[i:i + weights_per_layer] for i in range(0, len(all_weights), weights_per_layer)]
    if batch_first:
        return _unimplemented("RNN/GRU/LSTM", "batch_first")
    if dropout and train:
        return _unimplemented("RNN/GRU/LSTM", "dropout in training mode")

    if variant.startswith('RNN'):
        nonlinearity = variant[4:].lower()
        variant = 'RNN'

    w_hh = all_weights[1]
    hidden_size = w_hh.type().sizes()[1]

    unidirectional = not bidirectional

    prev_output = input

    h_outs = []
    if variant == 'RNN' or variant == 'GRU':
        h0 = initial_states
    elif variant == 'LSTM':
        h0, c0 = initial_states
        c_outs = []

    sequence_lens = unused(g) if batch_sizes is None else batch_sizes

    if variant == 'GRU':
        # pytorch gate order is reset, input, hidden
        # onnx gate order is    input, reset, hidden
        reform_permutation = [(1, 2), (0, 1), (2, 3)]
    elif variant == 'LSTM':
        # pytorch gate order is input, forget, cell, output
        # onnx gate order is    input, output, forget, cell
        reform_permutation = [(0, 1), (3, 4), (1, 3)]

    def reform_weights(g, w, n, intervals):
        slices = [g.op('Slice', w, axes_i=[0], starts_i=[x * n], ends_i=[y * n]) for x, y in intervals]
        return g.op('Concat', *slices, axis_i=0)

    def transform_weights(layer_index):
        if variant == 'RNN':
            weight_ih, weight_hh, bias_ih, bias_hh = layer_weights[layer_index]
        elif variant == 'GRU' or variant == 'LSTM':
            weight_ih, weight_hh, bias_ih, bias_hh = \
                [reform_weights(g, w, hidden_size, reform_permutation) for w in layer_weights[layer_index]]
        bias_concat = g.op('Concat', bias_ih, bias_hh, axis_i=0)

        return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat))

    def retrieve_state(x, start, end):
        return x if num_layers == 1 else g.op('Slice', x, axes_i=[0], starts_i=[start], ends_i=[end])

    for i in range(num_layers):
        if unidirectional:
            weight_ih, weight_hh, bias_concat = transform_weights(i)
            state_indices = i, i + 1
        else:
            weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
            weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)

            weight_ih = g.op('Concat', weight_ih_f, weight_ih_b, axis_i=0)
            weight_hh = g.op('Concat', weight_hh_f, weight_hh_b, axis_i=0)
            bias_concat = g.op('Concat', bias_f, bias_b, axis_i=0)

            state_indices = 2 * i, 2 * i + 2

        inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]

        inputs.append(retrieve_state(h0, *state_indices))
        if variant == 'LSTM':
            inputs.append(retrieve_state(c0, *state_indices))

        extra_kwargs = {} if unidirectional else {'direction_s': 'bidirectional'}
        if variant == 'RNN':
            prev_output, h_out = g.op('RNN', *inputs, outputs=2,
                                      hidden_size_i=hidden_size,
                                      activations_s=[nonlinearity],
                                      **extra_kwargs)
        elif variant == 'GRU':
            prev_output, h_out = g.op('GRU', *inputs, outputs=2,
                                      hidden_size_i=hidden_size,
                                      linear_before_reset_i=1,
                                      **extra_kwargs)
        elif variant == 'LSTM':
            prev_output, h_out, c_out = g.op('LSTM', *inputs, outputs=3,
                                             hidden_size_i=hidden_size,
                                             **extra_kwargs)

        if bidirectional:
            # ONNX RNN/GRU/LSTM produce an output of shape
            #   seq_len, num_directions, batch, hidden_size
            # which is converted to PyTorch's expected
            #   seq_len, batch, num_directions * hidden_size
            # by a Transpose followed by a Reshape.
            prev_output = g.op('Transpose', prev_output, perm_i=[0, 2, 1, 3])
            prev_output = g.op('Reshape', prev_output, g.op('Constant', value_t=torch.LongTensor([0, 0, -1])))
        else:
            prev_output = g.op('Squeeze', prev_output, axes_i=[1])

        h_outs.append(h_out)
        if variant == 'LSTM':
            c_outs.append(c_out)
    h_outs = h_out if num_layers == 1 else g.op('Concat', *h_outs, axis_i=0)
    if variant == 'RNN' or variant == 'GRU':
        return prev_output, h_outs
    elif variant == 'LSTM':
        c_outs = c_out if num_layers == 1 else g.op('Concat', *c_outs, axis_i=0)
        return prev_output, h_outs, c_outs
@parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
def _lstm_full(g, input, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
    hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
    return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
                        dropout, train, bidirectional, batch_first)


@parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
def _lstm_packed(g, input, batch_sizes, hidden_v, weight_v, has_biases, num_layers, dropout, train, bidirectional):
    hidden, weight = _unpack_list(hidden_v), _unpack_list(weight_v)
    return _generic_rnn(g, 'LSTM', input, hidden, weight, has_biases, num_layers,
                        dropout, train, bidirectional, batch_sizes=batch_sizes)


def lstm(g, *args):
    if _is_tensor_list(args[3]):
        return _lstm_packed(g, *args)
    else:
        return _lstm_full(g, *args)
def _one_hidden_rnn(kind):
    @parse_args('v', 'v', 'v', 'i', 'i', 'f', 'i', 'i', 'i')
    def _rnn_full(g, input, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional, batch_first):
        weight = _unpack_list(weight_v)
        return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
                            dropout, train, bidirectional, batch_first)

    @parse_args('v', 'v', 'v', 'v', 'i', 'i', 'f', 'i', 'i')
    def _rnn_packed(g, input, batch_sizes, hidden, weight_v, has_biases, num_layers, dropout, train, bidirectional):
        weight = _unpack_list(weight_v)
        return _generic_rnn(g, kind, input, hidden, weight, has_biases, num_layers,
                            dropout, train, bidirectional, batch_sizes=batch_sizes)

    def symbolic(g, *args):
        if _is_tensor_list(args[3]):
            return _rnn_packed(g, *args)
        else:
            return _rnn_full(g, *args)

    return symbolic


gru = _one_hidden_rnn('GRU')
rnn_tanh = _one_hidden_rnn('RNN_TANH')
rnn_relu = _one_hidden_rnn('RNN_RELU')
@parse_args('v', 'i')
def _dim_arange(g, like, dim):
    return g.op('ATen', like, dim_i=dim, operator_s='_dim_arange')


def detach(g, input):
    # erase aten::detach nodes because ONNX is inference only
    return input


def contiguous(g, input):
    return input
@parse_args('v', 'v', 'i')
def _pack_padded_sequence(g, input, lengths, batch_first):
    # There is currently no PackPadded operator in ONNX; a later optimization
    # pass is expected to remove this node.
    if batch_first:
        input = g.op('Transpose', input, perm_i=[1, 0, 2])
    if not lengths.type().isSubtypeOf(torch._C.TensorType.get()):
        raise RuntimeError("Lengths must be a Tensor for ONNX export")
    # downstream operators only handle int32 lengths
    if lengths.type().scalarType() != 'Int':
        lengths = _cast_Int(g, lengths, False)
    return g.op("prim::PackPadded", input, lengths, outputs=2)


@parse_args('v', 'v', 'i', 't', 'v')
def _pad_packed_sequence(g, data, batch_sizes, batch_first, padding_value, total_length):
    # total_length is only relevant for data_parallel training, so it is
    # ignored for ONNX export.
    data, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
    if batch_first:
        data = g.op('Transpose', data, perm_i=[1, 0, 2])
    return data, lengths
def randn(g, *shapes):
    shapes_list = list(shapes)
    shape = _maybe_get_const(shapes_list[0], "is")
    return g.op('RandomNormal', shape_i=shape)


@parse_args('v', 'f', 'f', 'i', 'none')
def rrelu(g, input, lower, upper, training, generator):
    p = g.op('RandomUniformLike', input, high_f=upper, low_f=lower)
    return g.op('PRelu', input, p)


def log_sigmoid(g, input):
    p = g.op('Sigmoid', input)
    return g.op('Log', p)


def erf(g, input):
    return g.op('Erf', input)
@parse_args('v', 'i', 'i')
def flatten(g, input, start_dim, end_dim):
    dim = input.type().dim()
    if end_dim < 0:
        end_dim = dim + end_dim
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim == 1 and end_dim == dim - 1:
        return g.op("Flatten", input, axis_i=start_dim)
    if start_dim == 0 and end_dim == dim - 2:
        return g.op("Flatten", input, axis_i=end_dim + 1)
    # use Reshape for cases where the output shape is not 2D
    if input.type().kind() != "CompleteTensorType":
        return _unimplemented("flatten", "input size not accessible")
    input_dims = input.type().sizes()
    output_dims = []
    for i in range(0, dim):
        if start_dim < i and end_dim >= i:
            output_dims[start_dim] = output_dims[start_dim] * input_dims[i]
        else:
            output_dims.append(input_dims[i])
    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
    p = _reshape_from_tensor(g, input, shape)
    return p
def nonzero(g, input):
    return t(g, g.op('NonZero', input))


def isnan(g, input):
    output = g.op('IsNaN', input)
    output = _cast_func_template(cast_pytorch_to_onnx['Byte'], g, output, None)
    return output
@parse_args('v', 'i', 'i', 'i')
def narrow(g, input, dim, start, length):
    return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])


@parse_args('v', 'i', 'i')
def argmax(g, input, dim, keepdim):
    return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)


@parse_args('v', 'i', 'i')
def argmin(g, input, dim, keepdim):
    return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)