# Directory holding this file, with symlinks resolved via realpath so the
# relative path below is anchored at the file's real on-disk location.
# (os.path.split(p)[0] is the head component, i.e. the same as dirname(p).)
path = os.path.split(os.path.realpath(__file__))[0]
# Location of ATen's native function schema file, relative to that directory.
aten_native_yaml = os.path.join(
    path,
    '../aten/src/ATen/native/native_functions.yaml',
)
10 'max',
'min',
'median',
'mode',
'kthvalue',
'svd',
'symeig',
'eig',
11 'pstrf',
'qr',
'geqrf',
17 def test_field_name(self):
18 regex = re.compile(
r"^(\w*)\(")
19 file = open(aten_native_yaml,
'r') 20 for f
in yaml.load(file.read()):
22 ret = f.split(
'->')[1].strip()
23 name = regex.findall(f)[0]
24 if name
in whitelist
or name.endswith(
'_backward')
or \
25 name.endswith(
'_forward'):
27 if not ret.startswith(
'('):
29 ret = ret[1:-1].
split(
',')
32 self.assertEqual(len(r.split()), 1,
33 'only whitelisted operators are allowed to have named return type, got ' + name)
37 if __name__ ==
'__main__':
Module caffe2.python.layers.split.