# Caffe2 - Python API
# A deep learning, cross-platform ML framework
# export_onnx_tests_generator.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 from torch.autograd import Variable
7 from onnx import numpy_helper
8 
9 import io
10 import onnx
11 import os
12 import shutil
13 import torch
14 import traceback
15 
16 import test_pytorch_common
17 import test_onnx_common
18 from common_nn import module_tests
19 from test_nn import new_module_tests
20 
21 
22 # Take a test case (a dict) as input, return the test name.
# Take a test case (a dict) as input, return the test name.
def get_test_name(testcase):
    """Return the name of the exported test for *testcase*.

    An explicit ``fullname`` wins; otherwise the name is derived from the
    constructor's class name plus the optional ``desc`` qualifier.
    """
    if "fullname" in testcase:
        return "test_{}".format(testcase["fullname"])

    suffix = "_{}".format(testcase["desc"]) if "desc" in testcase else ""
    return "test_{}{}".format(testcase["constructor"].__name__, suffix)
31 
32 
33 # Take a test case (a dict) as input, return the input for the module.
# Take a test case (a dict) as input, return the input for the module.
def gen_input(testcase):
    """Build the input ``Variable`` for a test case.

    ``input_size`` takes precedence: a random tensor of that shape is
    returned (a zero-dim "scalar" case is promoted to shape ``(1,)``, and
    the testcase dict is mutated accordingly so later calls agree).
    Otherwise ``input_fn`` is called exactly once and its result is wrapped
    in a ``Variable`` if it is not one already.

    Returns None when the test case carries neither key.
    """
    if "input_size" in testcase:
        # torch.randn(*()) cannot build a usable scalar input here, so
        # promote "scalar"-described empty shapes to a 1-element tensor.
        if testcase["input_size"] == () and "desc" in testcase and testcase["desc"][-6:] == "scalar":
            testcase["input_size"] = (1,)
        return Variable(torch.randn(*testcase["input_size"]))
    elif "input_fn" in testcase:
        # BUG FIX: the original called input_fn() a second time when the
        # result was not a Variable, discarding the first (random) tensor.
        # Call it once and wrap the single result.
        input = testcase["input_fn"]()
        if isinstance(input, Variable):
            return input
        return Variable(input)
44 
45 
def gen_module(testcase):
    """Instantiate the test case's module in inference mode.

    The two original branches (with and without ``constructor_args``) were
    duplicated; collapsed via a ``()`` default — behavior is identical.
    """
    args = testcase.get("constructor_args", ())
    module = testcase["constructor"](*args)
    # Inference mode: freezes dropout / batch-norm style behavior so the
    # exported model and the recorded outputs are deterministic.
    module.train(False)
    return module
55 
56 
def print_stats(FunctionalModule_nums, nn_module):
    """Print an export-coverage summary.

    ``nn_module`` maps module class name -> status bitmask: 1 means every
    test case of that op exported, 2 means none did, 3 means only some did.
    """
    print("{} functional modules detected.".format(FunctionalModule_nums))

    buckets = {1: [], 2: [], 3: []}
    for name, status in nn_module.items():
        if status in buckets:
            buckets[status].append(name)

    def report(header, names):
        print(header)
        for name in names:
            print(name)

    # Fully Supported Ops: All related test cases of these ops have been exported
    # Semi-Supported Ops: Part of related test cases of these ops have been exported
    # Unsupported Ops: None of related test cases of these ops have been exported
    report("{} Fully Supported Operators:".format(len(buckets[1])), buckets[1])
    report("{} Semi-Supported Operators:".format(len(buckets[3])), buckets[3])
    report("{} Unsupported Operators:".format(len(buckets[2])), buckets[2])
85 
86 
def convert_tests(testcases, sets=1):
    """Export each PyTorch module test case to an ONNX model plus data sets.

    For every test case a directory ``<pytorch_converted_dir>/<test_name>``
    is (re)created containing ``model.onnx`` and ``sets`` sub-directories
    ``test_data_set_<i>`` holding protobuf-serialized input/output tensors.

    Args:
        testcases: list of test-case dicts (from common_nn / test_nn).
        sets: number of input/output data sets to generate per model.

    Side effects: writes files under ``test_onnx_common.pytorch_converted_dir``,
    prints progress, and finishes with a coverage summary via ``print_stats``.
    """
    print("Collect {} test cases from PyTorch.".format(len(testcases)))
    failed = 0
    FunctionalModule_nums = 0
    # Maps module class name -> status bitmask shared with print_stats:
    # bit 0 set = at least one case exported, bit 1 set = at least one failed
    # (so 1 = fully supported, 2 = unsupported, 3 = partially supported).
    nn_module = {}
    for t in testcases:
        test_name = get_test_name(t)
        module = gen_module(t)
        # The class name is the text before the first "(" of the module repr.
        module_name = str(module).split("(")[0]
        if (module_name == "FunctionalModule"):
            FunctionalModule_nums += 1
        else:
            if (module_name not in nn_module):
                nn_module[module_name] = 0
        try:
            input = gen_input(t)
            f = io.BytesIO()
            # ATen fallback lets the export proceed for ops without a native
            # ONNX symbolic instead of raising.
            torch.onnx._export(module, input, f,
                               operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
            onnx_model = onnx.load_from_string(f.getvalue())
            onnx.checker.check_model(onnx_model)
            onnx.helper.strip_doc_string(onnx_model)
            output_dir = os.path.join(test_onnx_common.pytorch_converted_dir, test_name)

            # Start from a clean per-test directory on every run.
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
            os.makedirs(output_dir)
            with open(os.path.join(output_dir, "model.onnx"), "wb") as file:
                file.write(onnx_model.SerializeToString())

            for i in range(sets):
                output = module(input)
                data_dir = os.path.join(output_dir, "test_data_set_{}".format(i))
                os.makedirs(data_dir)

                # Serialize the input and output tensors as TensorProto files.
                for index, var in enumerate([input]):
                    tensor = numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "input_{}.pb".format(index)), "wb") as file:
                        file.write(tensor.SerializeToString())
                for index, var in enumerate([output]):
                    tensor = numpy_helper.from_array(var.data.numpy())
                    with open(os.path.join(data_dir, "output_{}.pb".format(index)), "wb") as file:
                        file.write(tensor.SerializeToString())
                # Regenerate a fresh (random) input for the next data set.
                input = gen_input(t)
            if (module_name != "FunctionalModule"):
                nn_module[module_name] |= 1  # at least one case exported
        except: # noqa: E722
            # Best-effort: any failure (export, checker, file I/O) counts the
            # case as failed but does not abort the remaining test cases.
            traceback.print_exc()
            if (module_name != "FunctionalModule"):
                nn_module[module_name] |= 2  # at least one case failed
            failed += 1

    print("Collect {} test cases from PyTorch repo, failed to export {} cases.".format(
        len(testcases), failed))
    print("PyTorch converted cases are stored in {}.".format(test_onnx_common.pytorch_converted_dir))
    print_stats(FunctionalModule_nums, nn_module)
143 
if __name__ == '__main__':
    # Export every module test case gathered from the PyTorch test suites.
    all_cases = module_tests + new_module_tests
    convert_tests(all_cases)
# (doc-extraction residue — unrelated to this script)
# Module caffe2.python.layers.split.
# def _export(args, kwargs) — Definition: __init__.py:20