from functools import reduce
def maybe_view(tensor, size, check_same_size=True):
    """Return ``tensor`` viewed with shape ``size``, skipping the view when possible.

    Args:
        tensor: tensor-like object providing ``size()``, ``contiguous()`` and ``view()``.
        size: target shape.
        check_same_size: when True and ``tensor.size()`` already equals ``size``,
            ``tensor`` itself is returned and no view is created.

    Returns:
        ``tensor`` unchanged (same-size shortcut) or ``tensor.contiguous().view(size)``.
    """
    if check_same_size and tensor.size() == size:
        return tensor
    # contiguous() first because view() needs a compatible memory layout.
    return tensor.contiguous().view(size)
def maybe_unexpand(tensor, old_size, check_same_size=True):
    """Sum ``tensor`` back down to ``old_size``, undoing an expand/broadcast.

    Leading dimensions of ``tensor`` that ``old_size`` does not have are summed
    away entirely.  Remaining dimensions whose size differs from ``old_size``
    are summed with ``keepdim=True`` so the rank matches ``old_size``.

    Args:
        tensor: tensor-like object providing ``size()``, ``dim()`` and
            ``sum(dim, keepdim=...)``.
        old_size: the shape to reduce to.
        check_same_size: when True and ``tensor.size()`` already equals
            ``old_size``, ``tensor`` is returned unchanged.

    Returns:
        The reduced tensor (or ``tensor`` itself when nothing needs summing).
    """
    if check_same_size and tensor.size() == old_size:
        return tensor
    # Number of extra leading dimensions that were prepended by the expand.
    # NOTE(review): assumes tensor.dim() >= len(old_size) — confirm at callers.
    num_unsqueezed = tensor.dim() - len(old_size)
    # Indices are relative to the trailing dims; they become absolute once the
    # leading dims are summed away below, which is why the loops run in this order.
    expanded_dims = [dim
                     for dim, (expanded, original)
                     in enumerate(zip(tensor.size()[num_unsqueezed:], old_size))
                     if expanded != original]
    for _ in range(num_unsqueezed):
        tensor = tensor.sum(0, keepdim=False)
    for dim in expanded_dims:
        tensor = tensor.sum(dim, keepdim=True)
    return tensor
def prepare_onnx_paddings(dim, pad):
    """Convert a pad spec to ONNX ``pads`` order.

    ``pad`` is read as (begin, end) pairs starting from the LAST dimension and
    moving toward the first.  Dimensions not covered by ``pad`` receive zero
    padding.  The result is ordered
    ``[dim_0_begin, ..., dim_{n-1}_begin, dim_0_end, ..., dim_{n-1}_end]``.

    Args:
        dim: number of dimensions of the input (must be an int).
        pad: sequence of at most ``2 * dim`` padding amounts.

    Returns:
        A list of length ``2 * dim``.

    Raises:
        AssertionError: if ``dim`` is not an int or ``pad`` is too long.
    """
    assert isinstance(dim, int)
    assert len(pad) <= dim * 2
    # Missing (leading) dimensions get zero padding; they sit at the end of the
    # pair list because pad is ordered from the last dimension backwards.
    paddings = list(pad) + [0] * (dim * 2 - len(pad))
    # Walk the pairs backwards: even slots are begins, odd slots are ends.
    paddings = paddings[-2::-2] + paddings[-1::-2]
    assert len(paddings) == dim * 2
    return paddings
def check_onnx_broadcast(dims1, dims2):
    """Check whether combining shapes ``dims1`` and ``dims2`` is ONNX-expressible.

    Only two kinds of broadcasting are accepted:
      1) ``dims2`` holds exactly one element (e.g. ``[1, 1]``, or ``[]``);
      2) ``dims2`` is a trailing suffix of ``dims1`` (e.g. ``[2, 3, 4]`` and ``[3, 4]``).

    Args:
        dims1: shape (sequence of ints) of the first operand.
        dims2: shape (sequence of ints) of the second operand.

    Returns:
        True if the shapes differ (broadcasting is needed), False if identical.

    Raises:
        ValueError: if the shapes differ in a way not covered by 1) or 2).
    """
    # Compare as lists so a list shape and a tuple shape with equal values
    # are not treated as different.
    shape1 = list(dims1)
    shape2 = list(dims2)
    len1 = len(shape1)
    len2 = len(shape2)
    # Initial value 1: an empty (scalar) shape has one element, and reduce()
    # would otherwise raise TypeError on an empty sequence.
    numel2 = reduce(lambda x, y: x * y, shape2, 1)
    broadcast = shape1 != shape2
    is_suffix = len2 <= len1 and shape1[len1 - len2:] == shape2
    if numel2 != 1 and not is_suffix:
        raise ValueError("Numpy style broadcasting is not supported in ONNX. "
                         "Input dims are: {}, {}".format(dims1, dims2))
    return broadcast