from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, flatten

import torch
import torch.onnx
from torch.autograd import Variable, Function
from torch.nn import Module, functional
import torch.nn as nn

import io
import itertools
import inspect
import glob
import os
import shutil
import unittest

import common_utils as common


'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
          --no-onnx: no onnx python dependence
          --produce-onnx-test-data: generate onnx test data
'''

# Defaults; both are overridden by the command-line flags parsed in __main__ below.
_onnx_test = False  # produce onnx test case data
_onnx_dep = True    # exercise the onnx python package


def export_to_pbtxt(model, inputs, *args, **kwargs):
    return torch.onnx.export_to_pretty_string(
        model, inputs, None, verbose=False, google_printer=True,
        *args, **kwargs)


def export_to_pb(model, inputs, *args, **kwargs):
    kwargs['operator_export_type'] = torch.onnx.OperatorExportTypes.ONNX
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, inputs, f, *args, **kwargs)
    return f.getvalue()
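

# A rough usage sketch of the two helpers above (illustrative only; the module
# and shapes are arbitrary, not taken from the tests below):
#
#     pbtxt = export_to_pbtxt(nn.Linear(4, 5), torch.randn(2, 4))  # readable dump
#     pb = export_to_pb(nn.Linear(4, 5), torch.randn(2, 4))        # serialized bytes
#
# export_to_pbtxt feeds the expect-file comparisons; export_to_pb feeds the
# onnx checker and the test-data generator.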


class FuncModule(Module):
    def __init__(self, f, params=None):
        if params is None:
            params = ()
        super(FuncModule, self).__init__()
        self.f = f
        self.params = nn.ParameterList(list(params))

    def forward(self, *args):
        # Apply f to the call arguments followed by any registered parameters.
        return self.f(*itertools.chain(args, self.params))
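

# FuncModule exists so that a bare function or lambda can travel through the
# nn.Module-based export path. A hypothetical example (names not from the
# test suite):
#
#     w = nn.Parameter(torch.ones(3))
#     m = FuncModule(lambda x, w: x * w, params=(w,))
#     y = m(torch.randn(3))  # equivalent to f(x, w)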


class TestOperators(TestCase):

    def assertONNX(self, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        # Pop subname first so it is not forwarded to the exporter.
        subname = kwargs.pop('subname', None)
        onnx_model_pbtxt = export_to_pbtxt(m, args, **kwargs)
        self.assertExpected(onnx_model_pbtxt, subname)
        if _onnx_dep:
            onnx_model_pb = export_to_pb(m, args, **kwargs)
            import onnx
            import onnx.checker
            import onnx.numpy_helper
            import test_onnx_common
            model_def = onnx.ModelProto.FromString(onnx_model_pb)
            onnx.checker.check_model(model_def)
            if _onnx_test:
                test_function = inspect.stack()[1][0].f_code.co_name
                test_name = test_function[0:4] + "_operator" + test_function[4:]
                output_dir = os.path.join(test_onnx_common.pytorch_operator_dir, test_name)
                # Each test must write into a fresh directory.
                assert not os.path.exists(output_dir), "{} should not exist!".format(output_dir)
                os.makedirs(output_dir)
                with open(os.path.join(output_dir, "model.onnx"), 'wb') as file:
                    file.write(model_def.SerializeToString())
                data_dir = os.path.join(output_dir, "test_data_set_0")
                os.makedirs(data_dir)
                if isinstance(args, Variable):
                    args = (args,)
                for index, var in enumerate(flatten(args)):
                    tensor = onnx.numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "input_{}.pb".format(index)), 'wb') as file:
                        file.write(tensor.SerializeToString())
                outputs = m(*args)
                if isinstance(outputs, Variable):
                    outputs = (outputs,)
                for index, var in enumerate(flatten(outputs)):
                    tensor = onnx.numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "output_{}.pb".format(index)), 'wb') as file:
                        file.write(tensor.SerializeToString())
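
    # With --produce-onnx-test-data, a test named test_basic leaves behind
    # (under test_onnx_common.pytorch_operator_dir) a layout shaped like:
    #
    #     test_operator_basic/
    #         model.onnx
    #         test_data_set_0/
    #             input_0.pb
    #             input_1.pb
    #             output_0.pb
    #
    # which is the format the ONNX backend-test runner consumes.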

    def assertONNXRaises(self, err, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        self.assertExpectedRaises(err, lambda: export_to_pbtxt(m, args, **kwargs))

    def assertONNXRaisesRegex(self, err, reg, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        with self.assertRaisesRegex(err, reg):
            export_to_pbtxt(m, args, **kwargs)

    def test_basic(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)
        self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), (x, y))

    def test_index(self):
        x = torch.tensor([[0.0]], requires_grad=True)
        self.assertONNX(lambda x: x[0], x)

    def test_type_as(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.type_as(x), x)

    def test_addconstant(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: x + 1, x)

    def test_add_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_add_left_broadcast(self):
        x = torch.randn(3, requires_grad=True).double()
        y = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_add_size1_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(2, 1, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_add_size1_right_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_add_size1_singleton_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(1, 3, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_rsub(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: 1 - x, x)

    def test_transpose(self):
        x = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
        self.assertONNX(lambda x: x.transpose(0, 1).transpose(1, 0), x)

    def test_chunk(self):
        x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
        self.assertONNX(lambda x: x.chunk(2), x)

    def test_split(self):
        x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
        self.assertONNX(lambda x: torch.split(x, 2, 1), x)

    def test_split_with_sizes(self):
        x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
        self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x)

    def test_concat2(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.assertONNX(lambda inputs: torch.cat(inputs, 1), ((x, y),))

    def test_mm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(torch.mm, (m1, m2))

    def test_addmm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        m3 = torch.randn(4, requires_grad=True)
        self.assertONNX(lambda x, y, z: torch.addmm(torch.addmm(z, x, y), x, y), (m1, m2, m3))
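
    # As an aside: torch.addmm(z, x, y) computes z + x @ y, so the nested call
    # above chains two matmul-with-bias ops; in ONNX terms this typically
    # lowers to a pair of Gemm nodes (the expect file, not this comment, is
    # authoritative).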

    def test_permute2(self):
        x = torch.tensor([[[[[[0.0]]]]]], requires_grad=True)
        self.assertONNX(lambda x: x.permute(0, 1, 4, 2, 5, 3), x)

    def test_pad(self):
        x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True)
        self.assertONNX(nn.ReflectionPad2d((2, 3, 0, 1)), x)

    def test_params(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
        self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), x, params=(y, ))

    def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # Deliberately declares one input while forward takes two,
                # so translation fails before this body is ever reached.
                assert False

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = torch.ones(2, 2)
        y = torch.ones(2, 2)
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            export_to_pbtxt(FuncModule(MyFun().apply), (x, y))

    def test_batchnorm(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x)

    def test_batchnorm_1d(self):
        x = torch.ones(2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm1d(2), x)

    def test_batchnorm_training(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x, training=True)

    def test_conv(self):
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x)

    def test_convtranspose(self):
        x = torch.ones(2, 3, 4, 5, requires_grad=True)
        self.assertONNX(nn.ConvTranspose2d(3, 3, 3, stride=3, bias=False,
                                           padding=1, output_padding=2), x)

    def test_maxpool(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2), x)

    def test_avg_pool2d(self):
        x = torch.randn(20, 16, 50, 32)
        self.assertONNX(nn.AvgPool2d(3, stride=2), x)

    def test_maxpool_indices(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)

    def test_at_op(self):
        x = torch.randn(3, 4)

        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                return g.at("add", x, x)

            @staticmethod
            def forward(ctx, x):
                return x + x

        class MyModule(Module):
            def forward(self, x):
                return MyFun.apply(x)

        self.assertONNX(MyModule(), x)
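
    # g.at("add", x, x) emits an ATen-domain fallback node (an "ATen" op with
    # operator="add") rather than a standard ONNX operator, so only backends
    # with ATen support can execute the resulting graph.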

    def test_clip(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.clamp(x, min=-0.5, max=0.5), x)

    def test_clip_min(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(min=-0.1), x)

    def test_clip_max(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(max=0.1), x)

    def test_hardtanh(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)

    def test_full(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full(x.shape, 2), x)

    def test_full_like(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full_like(x, 2), x)

    def test_max(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.max(x, y), (x, y))

    def test_min(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.min(x, y), (x, y))

    def test_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x), x)

    def test_reduced_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=2), x)

    def test_reduced_mean_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=2, keepdim=True), x)

    def test_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x), x)

    def test_reduced_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=2), x)

    def test_reduced_sum_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=2, keepdim=True), x)

    def test_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x), x)

    def test_reduced_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2), x)

    def test_reduced_prod_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2, keepdim=True), x)

    def test_sqrt(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sqrt(x), x)

    def test_equal(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x == y, (x, y))

    def test_lt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x < y, (x, y))

    def test_gt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x > y, (x, y))

    def test_le(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x <= y, (x, y))

    def test_ge(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x >= y, (x, y))

    def test_exp(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.exp(), x)

    def test_sin(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sin(), x)

    def test_cos(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.cos(), x)

    def test_tan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.tan(), x)

    def test_asin(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.asin(), x)

    def test_acos(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.acos(), x)

    def test_slice(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x[:, 1:2], x)

    def test_narrow(self):
        x = torch.randn(3, 3, requires_grad=True)
        self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)

    def test_atan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.atan(), x)

    def test_view_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.view(x.size()[0], x.numel() // x.size()[0]), x)

    def test_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x), x)

    def test_flatten2D(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x, 1), x)

    def test_isnan(self):
        x = torch.tensor([1.0, float('nan'), 2.0])
        self.assertONNX(lambda x: torch.isnan(x), x)

    def test_argmax(self):
        x = torch.randn(4, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.argmax(x, dim=1), x)

    def test_logsoftmax(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.LogSoftmax(dim=3), x)

    def test_pow(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: x.pow(y), (x, y))

    def test_elu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.ELU(), x)

    def test_selu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.SELU(), x)

    def test_repeat(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_repeat_dim_overflow(self):
        x = torch.randn(1, 2, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_norm(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.norm(p=2, dim=2), x)

    @unittest.skip("Temporary - waiting for https://github.com/onnx/onnx/pull/1773.")
    def test_upsample(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: nn.functional.interpolate(x, scale_factor=2., mode='bilinear'), x)

    def test_unsqueeze(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.unsqueeze(len(x.shape)), x)

    def test_batchnorm_noaffine(self):
        x = torch.randn(128, 128, 1, 1, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(128, affine=False), x)

    def test_embedding_bags(self):
        emb_bag = nn.EmbeddingBag(10, 8)
        input = torch.tensor([1, 2, 3, 4]).long()
        offset = torch.tensor([0]).long()
        self.assertONNX(emb_bag, (input, offset))

    def test_implicit_expand(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x + 1, x)

    def test_reduce_sum_negative_indices(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sum(-1), x)

    def test_randn(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.randn(1, 2, 3, 4) + x, x)

    def test_rrelu(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.RReLU(), x)

    def test_log_sigmoid(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.LogSigmoid(), x)

    def test_linear(self):
        x = torch.randn(3, 4)
        self.assertONNX(torch.nn.Linear(4, 5, bias=True), x)

    def test_zeros_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.zeros_like(x), x)

    def test_ones_like(self):
        x = torch.randn(6, 10, requires_grad=True)
        self.assertONNX(lambda x: torch.ones_like(x), x)

    def test_expand(self):
        x = torch.randn(6, 1, requires_grad=True)
        self.assertONNX(lambda x: x.expand(4, 6, 2), x)

    def test_ne(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: torch.ne(x, y), (x, y))

    def test_reducemax(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.max(x), x)

    def test_reducemin(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.min(x), x)

    def test_erf(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: x.erf(), x)

    def test_dropout(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)

    def test_nonzero(self):
        x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
        self.assertONNX(lambda x: torch.nonzero(x), x)

    def test_master_opset(self):
        x = torch.randn(2, 3).float()
        y = torch.randn(2, 3).float()
        self.assertONNX(lambda x, y: x + y, (x, y), opset_version=10)

    def test_retain_param_name_disabled(self):
        class MyModule(Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.fc1 = nn.Linear(4, 5, bias=False)
                self.fc1.weight.data.fill_(2.)
                self.fc2 = nn.Linear(5, 6, bias=False)
                self.fc2.weight.data.fill_(3.)

            def forward(self, x):
                return self.fc2(self.fc1(x))

        x = torch.randn(3, 4).float()
        self.assertONNX(MyModule(), (x,), _retain_param_name=False)


if __name__ == '__main__':
    no_onnx_dep_flag = '--no-onnx'
    _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS
    if no_onnx_dep_flag in common.UNITTEST_ARGS:
        common.UNITTEST_ARGS.remove(no_onnx_dep_flag)
    onnx_test_flag = '--produce-onnx-test-data'
    _onnx_test = onnx_test_flag in common.UNITTEST_ARGS
    if onnx_test_flag in common.UNITTEST_ARGS:
        common.UNITTEST_ARGS.remove(onnx_test_flag)
    if _onnx_test:
        _onnx_dep = True
        import test_onnx_common
        # Regenerating test data: clear out stale operator directories first.
        for d in glob.glob(os.path.join(test_onnx_common.pytorch_operator_dir, "test_operator_*")):
            shutil.rmtree(d)
    run_tests()