from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import glob
import os
import shutil
import traceback

import google.protobuf.text_format
import onnx.backend.test
from onnx import numpy_helper

import test_onnx_common
from test_caffe2_common import run_generated_test
# Default destination for failed test cases:
# "<directory of this script>/fail/generated".
_fail_test_dir = os.path.join(os.path.split(os.path.realpath(__file__))[0],
                              os.path.join("fail", "generated"))

# Directory receiving the generated ".expect" text dumps:
# "<directory of this script>/expect".
_expect_dir = os.path.join(os.path.split(os.path.realpath(__file__))[0],
                           "expect")
def _run_testcase(model_file, dir_name, devices):
    """Run ``model_file`` against every ``test_data_set_*`` in ``dir_name`` on each device."""
    data_dir_pattern = os.path.join(dir_name, "test_data_set_*")
    for data_dir in glob.glob(data_dir_pattern):
        for device in devices:
            # Forward the device so every entry in ``devices`` is exercised.
            # NOTE(review): assumes run_generated_test(model_file, data_dir,
            # device) as defined in test_caffe2_common -- confirm.
            run_generated_test(model_file, data_dir, device)


def _write_expect_file(model_file, case_name):
    """Check ``model_file`` and dump it, doc strings stripped, to its ``.expect`` file."""
    model = onnx.load(model_file)
    onnx.checker.check_model(model)
    onnx.helper.strip_doc_string(model)
    # Serialize before opening the output so a failing check cannot leave a
    # truncated .expect file behind.
    text = google.protobuf.text_format.MessageToString(model)
    expect_file = os.path.join(_expect_dir,
                               "PyTorch-generated-{}.expect".format(case_name))
    with open(expect_file, "w") as text_file:
        text_file.write(text)


def _discard_testcase(dir_name, case_name, fail_dir):
    """Delete a failed test case, or move it to ``fail_dir/case_name`` when given."""
    if fail_dir is None:
        shutil.rmtree(dir_name)
        return
    target_dir = os.path.join(fail_dir, case_name)
    if os.path.exists(target_dir):
        # shutil.move into an existing directory would nest dir_name inside
        # it, so clear out the stale copy first.
        shutil.rmtree(target_dir)
    shutil.move(dir_name, target_dir)


def collect_generated_testcases(root_dir=test_onnx_common.pytorch_converted_dir,
                                verbose=False, fail_dir=None, expect=True,
                                devices=('CPU', 'CUDA')):
    """Run every generated test case under ``root_dir`` and filter out failures.

    Each sub-directory of ``root_dir`` is one test case containing a
    ``model.onnx`` file and ``test_data_set_*`` directories. Every data set is
    run through ``run_generated_test`` once per entry in ``devices``. When
    ``expect`` is true, the checked model (doc strings stripped) is also
    written as protobuf text to ``<_expect_dir>/PyTorch-generated-<case>.expect``.

    Any exception while running a case or writing its expect file marks the
    case as failed. A failed case is deleted when ``fail_dir`` is None and
    otherwise moved to ``<fail_dir>/<case>``, replacing any earlier copy.

    Args:
        root_dir: directory whose sub-directories are the test cases.
        verbose: print the failing directory and its traceback on failure.
        fail_dir: destination for failed cases; None deletes them instead.
        expect: whether to (re)generate the ``.expect`` files.
        devices: device names forwarded to ``run_generated_test``.

    Prints a pass/fail summary when finished; returns None.
    """
    total_pass = 0
    total_fail = 0
    if expect and not os.path.isdir(_expect_dir):
        # Otherwise every open() of an expect file would raise and every test
        # case would be discarded as "failed".
        os.makedirs(_expect_dir)

    for d in sorted(os.listdir(root_dir)):
        dir_name = os.path.join(root_dir, d)
        if not os.path.isdir(dir_name):
            continue
        try:
            model_file = os.path.join(dir_name, "model.onnx")
            _run_testcase(model_file, dir_name, devices)
            if expect:
                _write_expect_file(model_file, d)
            total_pass += 1
        except Exception:
            # Deliberately broad: any failure means the case is filtered out.
            if verbose:
                print("The test case in {} failed!".format(dir_name))
                traceback.print_exc()
            _discard_testcase(dir_name, d, fail_dir)
            total_fail += 1

    print("Successfully generated/updated {} test cases from PyTorch.".format(total_pass))
    if expect:
        print("Expected pbtxt files are generated in {}.".format(_expect_dir))
    if fail_dir is None:
        print("Failed {} testcases are deleted.".format(total_fail))
    else:
        print("Failed {} testcases are moved to {}.".format(total_fail, fail_dir))
if __name__ == '__main__':
    # Command-line entry point: filter the converted and the operator test
    # cases, moving (or, with --delete, removing) the ones that fail.
    parser = argparse.ArgumentParser(description='Check and filter the failed test cases.')
    parser.add_argument('-v', action="store_true", default=False, help="verbose")
    parser.add_argument('--delete', action="store_true", default=False,
                        help="delete failed test cases")
    parser.add_argument('--no-expect', action="store_true", default=False,
                        help="do not generate expect txt files")
    args = parser.parse_args()

    verbose = args.v
    expect = not args.no_expect
    # With --delete, fail_dir=None tells collect_generated_testcases to remove
    # failed cases instead of moving them.
    fail_dir = None if args.delete else _fail_test_dir
    if fail_dir is not None and not os.path.exists(fail_dir):
        os.makedirs(fail_dir)

    collect_generated_testcases(verbose=verbose, fail_dir=fail_dir, expect=expect)
    # Expect files are never (re)generated for the operator test cases;
    # presumably they are produced elsewhere -- confirm.
    collect_generated_testcases(root_dir=test_onnx_common.pytorch_operator_dir,
                                verbose=verbose, fail_dir=fail_dir, expect=False)