Caffe2 - Python API
A deep-learning, cross-platform ML framework
test_namedtuple_return_api.py
1 import os
2 import re
3 import yaml
4 import unittest
5 
6 
# Absolute directory containing this test file; native_functions.yaml is
# located relative to it.
path = os.path.dirname(os.path.realpath(__file__))
aten_native_yaml = os.path.join(
    path,
    '../aten/src/ATen/native/native_functions.yaml',
)

# Operators that are permitted to declare named fields in a tuple return type.
whitelist = [
    'max',
    'min',
    'median',
    'mode',
    'kthvalue',
    'svd',
    'symeig',
    'eig',
    'pstrf',
    'qr',
    'geqrf',
]
13 
14 
class TestNamedTupleAPI(unittest.TestCase):
    """Check that named tuple return types stay confined to approved operators.

    Every ``func`` declaration in ``native_functions.yaml`` is scanned. If the
    declared return type is a parenthesized tuple whose entries carry a field
    name (e.g. ``(Tensor values, Tensor indices)``), the operator must be in
    ``whitelist`` or be a ``*_forward`` / ``*_backward`` helper.
    """

    def test_field_name(self):
        # Capture the operator name, stopping at either the argument list "("
        # or an overload suffix such as "max.dim(...)"; matching only "("
        # would fail to parse overloaded declarations.
        name_re = re.compile(r"^(\w*)[(.]")
        # The YAML is plain data, so safe_load suffices; yaml.load without an
        # explicit Loader is deprecated (rejected by PyYAML >= 6) and unsafe.
        # The with-block guarantees the file is closed even if loading fails.
        with open(aten_native_yaml, 'r') as yaml_file:
            declarations = yaml.safe_load(yaml_file)
        for entry in declarations:
            decl = entry['func']
            ret = decl.split('->')[1].strip()
            match = name_re.match(decl)
            self.assertIsNotNone(
                match, 'cannot parse operator name from: ' + decl)
            name = match.group(1)
            if name in whitelist or name.endswith(('_backward', '_forward')):
                continue
            # Only tuple returns such as "(Tensor, Tensor)" can carry names.
            if not ret.startswith('('):
                continue
            for field in ret[1:-1].split(','):
                # "Tensor" is a single token; "Tensor values" adds a field name.
                self.assertEqual(
                    len(field.split()), 1,
                    'only whitelisted operators are allowed to have named return type, got ' + name)
35 
36 
# When executed as a script, run the test cases defined in this module.
if __name__ == "__main__":
    unittest.main()
Module caffe2.python.layers.split.