Caffe2 - Python API
A deep learning, cross-platform ML framework
test_operators.py
from test_pytorch_common import TestCase, run_tests, skipIfNoLapack, flatten

import torch
import torch.onnx
from torch.autograd import Variable, Function
from torch.nn import Module, functional
import torch.nn as nn

import itertools
import io
import unittest
import inspect
import argparse
import glob
import os
import shutil
import sys
import common_utils as common


'''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
          --no-onnx: run without the onnx Python package dependency
          --produce-onnx-test-data: generate ONNX test data
'''

_onnx_test = False  # flag to produce ONNX test cases; set by --produce-onnx-test-data.
_onnx_dep = True  # flag to import the onnx package; cleared by --no-onnx.


def export_to_pbtxt(model, inputs, *args, **kwargs):
    return torch.onnx.export_to_pretty_string(
        model, inputs, None, verbose=False, google_printer=True,
        *args, **kwargs)


def export_to_pb(model, inputs, *args, **kwargs):
    kwargs['operator_export_type'] = torch.onnx.OperatorExportTypes.ONNX
    f = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, inputs, f, *args, **kwargs)
    return f.getvalue()


class FuncModule(Module):
    def __init__(self, f, params=None):
        if params is None:
            params = ()
        super(FuncModule, self).__init__()
        self.f = f
        self.params = nn.ParameterList(list(params))

    def forward(self, *args):
        return self.f(*itertools.chain(args, self.params))

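
# A minimal usage sketch (illustrative only, not part of the test suite):
# FuncModule lets a bare function or lambda be exported like any nn.Module,
# with any optional parameters appended to the positional arguments.
#
#   m = FuncModule(lambda a, b: a + b)
#   m.eval()
#   pbtxt = export_to_pbtxt(m, (torch.randn(2), torch.randn(2)))
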

class TestOperators(TestCase):

    def assertONNX(self, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        m.eval()
        onnx_model_pbtxt = export_to_pbtxt(m, args, **kwargs)
        subname = kwargs.pop('subname', None)
        self.assertExpected(onnx_model_pbtxt, subname)
        if _onnx_dep:
            onnx_model_pb = export_to_pb(m, args, **kwargs)
            import onnx
            import onnx.checker
            import onnx.numpy_helper
            import test_onnx_common
            model_def = onnx.ModelProto.FromString(onnx_model_pb)
            onnx.checker.check_model(model_def)
            if _onnx_test:
                test_function = inspect.stack()[1][0].f_code.co_name
                # "test_foo" -> "test_operator_foo"
                test_name = test_function[0:4] + "_operator" + test_function[4:]
                output_dir = os.path.join(test_onnx_common.pytorch_operator_dir, test_name)
                # Assumptions:
                # 1) the old test data has been deleted before the test runs.
                # 2) each test calls assertONNX at most once; a second call
                #    would overwrite the data.
                assert not os.path.exists(output_dir), "{} should not exist!".format(output_dir)
                os.makedirs(output_dir)
                with open(os.path.join(output_dir, "model.onnx"), 'wb') as file:
                    file.write(model_def.SerializeToString())
                data_dir = os.path.join(output_dir, "test_data_set_0")
                os.makedirs(data_dir)
                if isinstance(args, Variable):
                    args = (args,)
                for index, var in enumerate(flatten(args)):
                    tensor = onnx.numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "input_{}.pb".format(index)), 'wb') as file:
                        file.write(tensor.SerializeToString())
                outputs = m(*args)
                if isinstance(outputs, Variable):
                    outputs = (outputs,)
                for index, var in enumerate(flatten(outputs)):
                    tensor = onnx.numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "output_{}.pb".format(index)), 'wb') as file:
                        file.write(tensor.SerializeToString())

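    # When --produce-onnx-test-data is set, the method above lays out each case
    # on disk roughly as follows (a sketch; file counts depend on the test's
    # inputs and outputs):
    #
    #   test_operator_<name>/
    #       model.onnx
    #       test_data_set_0/
    #           input_0.pb, input_1.pb, ...
    #           output_0.pb, output_1.pb, ...
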
    def assertONNXRaises(self, err, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        self.assertExpectedRaises(err, lambda: export_to_pbtxt(m, args, **kwargs))

    def assertONNXRaisesRegex(self, err, reg, f, args, params=None, **kwargs):
        if params is None:
            params = ()
        if isinstance(f, nn.Module):
            m = f
        else:
            m = FuncModule(f, params)
        with self.assertRaisesRegex(err, reg):
            export_to_pbtxt(m, args, **kwargs)

    def test_basic(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)
        self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), (x, y))

    def test_view(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.view(1, 1), x)

    def test_index(self):
        x = torch.tensor([[0.0]], requires_grad=True)
        self.assertONNX(lambda x: x[0], x)

    def test_type_as(self):
        x = torch.tensor([0.0], requires_grad=True)
        self.assertONNX(lambda x: x.type_as(x), x)

    def test_addconstant(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: x + 1, x)

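    # The broadcasting tests below rely on numpy-style broadcasting between
    # operands of different shapes, which ONNX elementwise ops also use in
    # current opsets.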
    def test_add_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_add_left_broadcast(self):
        x = torch.randn(3, requires_grad=True).double()
        y = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_add_size1_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(2, 1, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_add_size1_right_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(3, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_add_size1_singleton_broadcast(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        y = torch.randn(1, 3, requires_grad=True).double()
        self.assertONNX(lambda x, y: x + y, (x, y))

    def test_rsub(self):
        x = torch.randn(2, 3, requires_grad=True).double()
        self.assertONNX(lambda x: 1 - x, (x,))

    def test_transpose(self):
        x = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
        self.assertONNX(lambda x: x.transpose(0, 1).transpose(1, 0), x)

    def test_chunk(self):
        x = torch.tensor([0.0, 1.0, 2.0], requires_grad=True)
        self.assertONNX(lambda x: x.chunk(2), x)

    def test_split(self):
        x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
        self.assertONNX(lambda x: torch.split(x, 2, 1), x)

    def test_split_with_sizes(self):
        x = torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
        self.assertONNX(lambda x: torch.split(x, [2, 1, 3], 1), x)

    def test_concat2(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.assertONNX(lambda inputs: torch.cat(inputs, 1), ((x, y),))

    def test_mm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(torch.mm, (m1, m2))

    def test_addmm(self):
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        m3 = torch.randn(4, requires_grad=True)
        self.assertONNX(lambda x, y, z: torch.addmm(torch.addmm(z, x, y), x, y), (m1, m2, m3))

    def test_permute2(self):
        x = torch.tensor([[[[[[0.0]]]]]], requires_grad=True)
        self.assertONNX(lambda x: x.permute(0, 1, 4, 2, 5, 3), x)

    def test_pad(self):
        x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True)
        self.assertONNX(nn.ReflectionPad2d((2, 3, 0, 1)), x)

    def test_params(self):
        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
        y = nn.Parameter(torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True))
        self.assertONNX(lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), x, params=(y,))

    def test_symbolic_mismatch(self):
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                # The body of this function should never be invoked, because
                # we will fail due to an argument mismatch first.
                assert False

            @staticmethod
            def forward(ctx, x, y):
                return x + y

        x = torch.ones(2, 2)
        y = torch.ones(2, 2)
        # NB: Don't use an expect test here; the exact TypeError message
        # varies across Python versions.
        with self.assertRaisesRegex(TypeError, "occurred when translating MyFun"):
            export_to_pbtxt(FuncModule(MyFun().apply), (x, y))

    # TODO: Do an nn-style test for these
    def test_batchnorm(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x)

    def test_batchnorm_1d(self):
        x = torch.ones(2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm1d(2), x)

    def test_batchnorm_training(self):
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(2), x, training=True)

    def test_conv(self):
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.assertONNX(nn.Conv2d(16, 13, 3, bias=False), x)

    def test_convtranspose(self):
        x = torch.ones(2, 3, 4, 5, requires_grad=True)
        self.assertONNX(nn.ConvTranspose2d(3, 3, 3, stride=3, bias=False,
                                           padding=1, output_padding=2), x)

    def test_maxpool(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2), x)

    def test_avg_pool2d(self):
        x = torch.randn(20, 16, 50, 32)
        self.assertONNX(nn.AvgPool2d(3, stride=2), x)

    def test_maxpool_indices(self):
        x = torch.randn(20, 16, 50)
        self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)

    def test_at_op(self):
        x = torch.randn(3, 4)

        class MyFun(Function):

            @staticmethod
            def symbolic(g, x):
                return g.at("add", x, x)

            @staticmethod
            def forward(ctx, x):
                return x + x

        class MyModule(Module):
            def forward(self, x):
                return MyFun.apply(x)

        self.assertONNX(MyModule(), x)

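    # Note: g.at("add", ...) above emits an ATen fallback node rather than a
    # standard ONNX op, so consuming the exported graph requires a backend that
    # understands ATen operators.
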
    def test_clip(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.clamp(x, min=-0.5, max=0.5), x)

    def test_clip_min(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(min=-0.1), x)

    def test_clip_max(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.clamp(max=0.1), x)

    def test_hardtanh(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)

    def test_full(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full(x.shape, 2), x)

    def test_full_like(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.full_like(x, 2), x)

    def test_max(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.max(x, y), (x, y))

    def test_min(self):
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: torch.min(x, y), (x, y))

    def test_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x), x)

    def test_reduced_mean(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=2), x)

    def test_reduced_mean_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.mean(x, dim=2, keepdim=True), x)

    def test_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x), x)

    def test_reduced_sum(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=2), x)

    def test_reduced_sum_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sum(x, dim=2, keepdim=True), x)

    def test_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x), x)

    def test_reduced_prod(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2), x)

    def test_reduced_prod_keepdim(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.prod(x, dim=2, keepdim=True), x)

    def test_sqrt(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.sqrt(x), x)

    def test_equal(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x == y, (x, y))

    def test_lt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x < y, (x, y))

    def test_gt(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x > y, (x, y))

    def test_le(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x <= y, (x, y))

    def test_ge(self):
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: x >= y, (x, y))

    def test_exp(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.exp(), x)

    def test_sin(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sin(), x)

    def test_cos(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.cos(), x)

    def test_tan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.tan(), x)

    def test_asin(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.asin(), x)

    def test_acos(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.acos(), x)

    def test_slice(self):
        x = torch.rand(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x[:, 1:2], x)

    def test_narrow(self):
        x = torch.randn(3, 3, requires_grad=True)
        self.assertONNX(lambda x: torch.narrow(x, 0, 0, 2), x)

    def test_atan(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.atan(), x)

    def test_view_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.view(x.size()[0], x.numel() // x.size()[0]), x)

    def test_flatten(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x), x)

    def test_flatten2D(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.flatten(x, 1), x)

    def test_isnan(self):
        x = torch.tensor([1, float('nan'), 2])
        self.assertONNX(lambda x: torch.isnan(x), x)

    def test_argmax(self):
        x = torch.randn(4, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.argmax(x, dim=1), x)

    def test_logsoftmax(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.LogSoftmax(dim=3), x)

    def test_pow(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x, y: x.pow(y), (x, y))

    def test_elu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.ELU(), x)

    def test_selu(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(nn.SELU(), x)

    def test_repeat(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_repeat_dim_overflow(self):
        x = torch.randn(1, 2, requires_grad=True)
        self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x)

    def test_norm(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.norm(p=2, dim=2), (x,))

    @unittest.skip("Temporary - waiting for https://github.com/onnx/onnx/pull/1773.")
    def test_upsample(self):
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.assertONNX(lambda x: nn.functional.interpolate(x, scale_factor=2., mode='bilinear'), x)

    def test_unsqueeze(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.unsqueeze(len(x.shape)), x)

    def test_batchnorm_noaffine(self):
        x = torch.randn(128, 128, 1, 1, requires_grad=True)
        self.assertONNX(nn.BatchNorm2d(128, affine=False), x)

    def test_embedding_bags(self):
        emb_bag = nn.EmbeddingBag(10, 8)
        input = torch.tensor([1, 2, 3, 4]).long()
        offset = torch.tensor([0]).long()
        self.assertONNX(emb_bag, (input, offset))

    def test_implicit_expand(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x + 1, x)

    def test_reduce_sum_negative_indices(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: x.sum(-1), x)

    def test_randn(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.randn(1, 2, 3, 4) + x, x)

    def test_rrelu(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.RReLU(), x)

    def test_log_sigmoid(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(torch.nn.LogSigmoid(), x)

    def test_linear(self):
        x = torch.randn(3, 4)
        self.assertONNX(torch.nn.Linear(4, 5, bias=True), x)

    def test_zeros_like(self):
        x = torch.randn(5, 8, requires_grad=True)
        self.assertONNX(lambda x: torch.zeros_like(x), x)

    def test_ones_like(self):
        x = torch.randn(6, 10, requires_grad=True)
        self.assertONNX(lambda x: torch.ones_like(x), x)

    def test_expand(self):
        x = torch.randn(6, 1, requires_grad=True)
        self.assertONNX(lambda x: x.expand(4, 6, 2), x)

    def test_ne(self):
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.assertONNX(lambda x, y: torch.ne(x, y), (x, y))

    def test_reducemax(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.max(x), x)

    def test_reducemin(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: torch.min(x), x)

    def test_erf(self):
        x = torch.randn(1, 2, 3, 4)
        self.assertONNX(lambda x: x.erf(), x)

    def test_dropout(self):
        x = torch.randn(3, 4, requires_grad=True)
        self.assertONNX(lambda x: torch.max(functional.dropout(x, training=False)), x)

    def test_nonzero(self):
        x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
        self.assertONNX(lambda x: torch.nonzero(x), x)

    def test_master_opset(self):
        x = torch.randn(2, 3).float()
        y = torch.randn(2, 3).float()
        self.assertONNX(lambda x, y: x + y, (x, y), opset_version=10)

    def test_retain_param_name_disabled(self):
        class MyModule(Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.fc1 = nn.Linear(4, 5, bias=False)
                self.fc1.weight.data.fill_(2.)
                self.fc2 = nn.Linear(5, 6, bias=False)
                self.fc2.weight.data.fill_(3.)

            def forward(self, x):
                return self.fc2(self.fc1(x))

        x = torch.randn(3, 4).float()
        self.assertONNX(MyModule(), (x,), _retain_param_name=False)


if __name__ == '__main__':
    no_onnx_dep_flag = '--no-onnx'
    _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS
    if no_onnx_dep_flag in common.UNITTEST_ARGS:
        common.UNITTEST_ARGS.remove(no_onnx_dep_flag)
    onnx_test_flag = '--produce-onnx-test-data'
    _onnx_test = onnx_test_flag in common.UNITTEST_ARGS
    if onnx_test_flag in common.UNITTEST_ARGS:
        common.UNITTEST_ARGS.remove(onnx_test_flag)
    if _onnx_test:
        _onnx_dep = True
        import test_onnx_common
        for d in glob.glob(os.path.join(test_onnx_common.pytorch_operator_dir, "test_operator_*")):
            shutil.rmtree(d)
    run_tests()
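
# Example invocations (mirroring the usage string at the top of this file):
#   python test/onnx/test_operators.py
#   python test/onnx/test_operators.py --no-onnx
#   python test/onnx/test_operators.py --produce-onnx-test-data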