Caffe2 - Python API
A deep learning, cross-platform ML framework
export_onnx_tests_filter.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 import argparse
7 import glob
8 import numpy as np
9 import onnx.backend.test
10 import caffe2.python.onnx.backend as c2
11 import os
12 import shutil
13 from onnx import numpy_helper
14 from test_caffe2_common import run_generated_test
15 import google.protobuf.text_format
16 import test_onnx_common
17 import traceback
18 
# Failed generated test cases are moved under <script dir>/fail/generated
# (unless the script is run with --delete, in which case they are removed).
_fail_test_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    "fail",
    "generated",
)

# Text-format ("pbtxt") snapshots of the exported models are written under
# <script dir>/expect.
_expect_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    "expect",
)
25 
26 
def _write_expect_file(model_file, case_name):
    """Check the ONNX model in *model_file* and snapshot it as text into ``_expect_dir``.

    The output file is ``PyTorch-generated-<case_name>.expect``.
    """
    model = onnx.load(model_file)
    onnx.checker.check_model(model)
    onnx.helper.strip_doc_string(model)
    # Serialize before opening the output file so that a failing load/check
    # does not truncate a previously good expect file.
    text = google.protobuf.text_format.MessageToString(model)
    expect_file = os.path.join(_expect_dir,
                               "PyTorch-generated-{}.expect".format(case_name))
    with open(expect_file, "w") as text_file:
        text_file.write(text)


def _discard_failed_case(case_dir, case_name, fail_dir):
    """Delete *case_dir*, or move it to ``fail_dir/case_name`` replacing any older copy."""
    if fail_dir is None:
        shutil.rmtree(case_dir)
        return
    target_dir = os.path.join(fail_dir, case_name)
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    shutil.move(case_dir, target_dir)


def collect_generated_testcases(root_dir=test_onnx_common.pytorch_converted_dir,
                                verbose=False, fail_dir=None, expect=True):
    """Run every generated ONNX test case under *root_dir* and filter out failures.

    Each sub-directory of *root_dir* is one test case containing a
    ``model.onnx`` file and ``test_data_set_*`` data directories. Every data
    set is run with ``run_generated_test`` once per device that the Caffe2
    ONNX backend reports as supported.

    Args:
        root_dir: directory holding one sub-directory per test case.
        verbose: if true, print each failing directory and its traceback.
        fail_dir: directory that failed cases are moved into. When ``None``,
            failed case directories are deleted instead.
        expect: if true, also write a text-format ``.expect`` snapshot of
            each passing model into ``_expect_dir``.
    """
    # Only exercise devices this Caffe2 build can run on. Otherwise every case
    # would fail on a CPU-only host and be deleted or moved away.
    devices = [device for device in ('CPU', 'CUDA') if c2.supports_device(device)]
    total_pass = 0
    total_fail = 0
    # Sorted for a deterministic processing (and log) order.
    for case_name in sorted(os.listdir(root_dir)):
        case_dir = os.path.join(root_dir, case_name)
        if not os.path.isdir(case_dir):
            continue
        try:
            model_file = os.path.join(case_dir, "model.onnx")
            data_dir_pattern = os.path.join(case_dir, "test_data_set_*")
            for data_dir in glob.glob(data_dir_pattern):
                for device in devices:
                    # NOTE(review): assumes run_generated_test takes the device
                    # as its third argument — confirm in test_caffe2_common.
                    run_generated_test(model_file, data_dir, device)
            if expect:
                _write_expect_file(model_file, case_name)
            total_pass += 1
        except Exception:
            # Deliberately broad: any failure (output mismatch, checker error,
            # I/O error) marks the whole case as failed.
            if verbose:
                print("The test case in {} failed!".format(case_dir))
                traceback.print_exc()
            _discard_failed_case(case_dir, case_name, fail_dir)
            total_fail += 1
    print("Successfully generated/updated {} test cases from PyTorch.".format(total_pass))
    if expect:
        print("Expected pbtxt files are generated in {}.".format(_expect_dir))
    # Report where failures actually went: the caller's fail_dir, or deleted.
    if fail_dir is None:
        print("Failed {} testcases are deleted.".format(total_fail))
    else:
        print("Failed {} testcases are moved to {}.".format(total_fail, fail_dir))
66 
67 
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Check and filter the failed test cases.')
    parser.add_argument('-v', action="store_true", default=False, help="verbose")
    parser.add_argument('--delete', action="store_true", default=False, help="delete failed test cases")
    # The flag disables expect-file generation, so the help text says so.
    parser.add_argument('--no-expect', action="store_true", default=False,
                        help="do not generate expect txt files")
    args = parser.parse_args()
    verbose = args.v
    expect = not args.no_expect
    # With --delete, failing cases are removed outright. Otherwise they are
    # kept for inspection under _fail_test_dir.
    fail_dir = None if args.delete else _fail_test_dir
    if fail_dir is not None and not os.path.exists(fail_dir):
        os.makedirs(fail_dir)

    collect_generated_testcases(verbose=verbose, fail_dir=fail_dir, expect=expect)
    # We already generate the expect files for test_operators.py.
    collect_generated_testcases(root_dir=test_onnx_common.pytorch_operator_dir,
                                verbose=verbose, fail_dir=fail_dir, expect=False)