from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import io
import os
import shutil
import traceback

import onnx
import torch
from onnx import numpy_helper
from torch.autograd import Variable

import test_pytorch_common
import test_onnx_common
from common_nn import module_tests
from test_nn import new_module_tests
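
# Converts the torch.nn module test definitions (module_tests from common_nn
# and new_module_tests from test_nn) into end-to-end ONNX test cases: for each
# case a serialized model.onnx plus protobuf input/output tensors are written
# under test_onnx_common.pytorch_converted_dir.


# Derive the test name for a test case dict: prefer its "fullname", otherwise
# combine the constructor name with the optional "desc" suffix.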
def get_test_name(testcase):
    if "fullname" in testcase:
        return "test_" + testcase["fullname"]

    test_name = "test_" + testcase["constructor"].__name__
    if "desc" in testcase:
        test_name += "_" + testcase["desc"]
    return test_name
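
# Build the sample input for a test case: a random tensor of "input_size"
# (promoting an empty () size to (1,) for *_scalar cases), or the result of
# the case's "input_fn", wrapped in a Variable if it is not one already.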
def gen_input(testcase):
    if "input_size" in testcase:
        if testcase["input_size"] == () and "desc" in testcase and testcase["desc"][-6:] == "scalar":
            testcase["input_size"] = (1,)
        return Variable(torch.randn(*testcase["input_size"]))
    elif "input_fn" in testcase:
        input = testcase["input_fn"]()
        if isinstance(input, Variable):
            return input
        return Variable(testcase["input_fn"]())
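
# Instantiate the module under test, forwarding "constructor_args" when the
# test case provides them, and put it in evaluation mode before export.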
def gen_module(testcase):
    if "constructor_args" in testcase:
        args = testcase["constructor_args"]
        module = testcase["constructor"](*args)
        module.train(False)  # export in inference mode
        return module
    module = testcase["constructor"]()
    module.train(False)  # export in inference mode
    return module
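
# Summarize export coverage. nn_module maps a module name to a bit mask:
# bit 0 is set when at least one of its test cases exported successfully and
# bit 1 when at least one failed, so 1 = fully supported, 3 = semi-supported
# and 2 = unsupported.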
def print_stats(FunctionalModule_nums, nn_module):
    print("{} functional modules detected.".format(FunctionalModule_nums))
    supported = []
    unsupported = []
    not_fully_supported = []
    for key, value in nn_module.items():
        if value == 1:
            supported.append(key)
        elif value == 2:
            unsupported.append(key)
        elif value == 3:
            not_fully_supported.append(key)

    for info, l in [["{} Fully Supported Operators:".format(len(supported)),
                     supported],
                    ["{} Semi-Supported Operators:".format(len(not_fully_supported)),
                     not_fully_supported],
                    ["{} Unsupported Operators:".format(len(unsupported)),
                     unsupported]]:
        print(info)
        for v in l:
            print(v)
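
# Export every test case to ONNX, dump `sets` input/output data sets per case,
# and record per-module success/failure bits for the final summary.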
def convert_tests(testcases, sets=1):
    print("Collect {} test cases from PyTorch.".format(len(testcases)))
    failed = 0
    FunctionalModule_nums = 0
    nn_module = {}
    for t in testcases:
        test_name = get_test_name(t)
        module = gen_module(t)
        module_name = str(module).split("(")[0]
        if module_name == "FunctionalModule":
            FunctionalModule_nums += 1
        input = gen_input(t)
        if module_name not in nn_module:
            nn_module[module_name] = 0
        try:
            # Export the module to an in-memory ONNX model and validate it.
            f = io.BytesIO()
            torch.onnx._export(module, input, f,
                               operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
            onnx_model = onnx.load_from_string(f.getvalue())
            onnx.checker.check_model(onnx_model)
            onnx.helper.strip_doc_string(onnx_model)
            output_dir = os.path.join(test_onnx_common.pytorch_converted_dir, test_name)

            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
            os.makedirs(output_dir)
            with open(os.path.join(output_dir, "model.onnx"), "wb") as file:
                file.write(onnx_model.SerializeToString())

            for i in range(sets):
                output = module(input)
                data_dir = os.path.join(output_dir, "test_data_set_{}".format(i))
                os.makedirs(data_dir)

                for index, var in enumerate([input]):
                    tensor = numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "input_{}.pb".format(index)), "wb") as file:
                        file.write(tensor.SerializeToString())
                for index, var in enumerate([output]):
                    tensor = numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "output_{}.pb".format(index)), "wb") as file:
                        file.write(tensor.SerializeToString())

            if module_name != "FunctionalModule":
                nn_module[module_name] |= 1  # exported successfully at least once
        except Exception:
            traceback.print_exc()
            if module_name != "FunctionalModule":
                nn_module[module_name] |= 2  # failed to export at least once
            failed += 1

    print("Collect {} test cases from PyTorch repo, failed to export {} cases.".format(
        len(testcases), failed))
    print("PyTorch converted cases are stored in {}.".format(test_onnx_common.pytorch_converted_dir))
    print_stats(FunctionalModule_nums, nn_module)
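
# Entry point: collect all module test definitions shipped with PyTorch and
# convert them.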
if __name__ == '__main__':
    testcases = module_tests + new_module_tests
    convert_tests(testcases)
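
# To regenerate the converted test data, run this file directly with the
# PyTorch test harness modules (test_pytorch_common, test_onnx_common,
# common_nn, test_nn) importable; the output lands in
# test_onnx_common.pytorch_converted_dir.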