10 from .._utils_internal
import get_file_path
12 THNN_H_PATH = get_file_path(
'torch',
'include',
'THNN',
'generic',
'THNN.h')
13 THCUNN_H_PATH = get_file_path(
'torch',
'include',
'THCUNN',
'generic',
'THCUNN.h')
14 except Exception
as e:
18 def _unpickle_backend(backend_name):
20 return torch._thnn.type2backend[backend_name]
28 def __getattr__(self, name):
29 method = self.methods.get(name,
None)
31 raise NotImplementedError
34 def register_method(self, name, ctypes_fn):
38 def library_state(self):
42 return (_unpickle_backend, (type(self).__name__,))
47 def __init__(self, name):
51 def add_argument(self, arg):
52 assert isinstance(arg, Argument)
53 self.arguments.append(arg)
56 return self.
name +
'(' +
', '.join(map(
lambda a: a.__repr__(), self.
arguments)) +
')' 61 def __init__(self, _type, name, is_optional):
70 def parse_header(path):
71 with open(path,
'r') as f: 72 lines = f.read().split('\n')
75 lines = filter(
lambda l: l
and not l.startswith(
'#'), lines)
77 lines = map(
lambda l: l.partition(
'//'), lines)
79 lines = map(
lambda l: (l[0].strip(), l[2].strip()), lines)
81 lines = map(
lambda l: (l[0].rstrip(
');').rstrip(
','), l[1]), lines)
83 lines = map(
lambda l: (l[0].
split(
','), l[1]), lines)
88 new_lines.append((split, c))
92 lines = map(
lambda l: (l[0].strip(), l[1]), lines)
94 lines = filter(
lambda l: l[0], lines)
95 generic_functions = []
97 if l.startswith(
'TH_API void THNN_'):
98 fn_name = l.lstrip(
'TH_API void THNN_')
99 if fn_name[0] ==
'(' and fn_name[-2] ==
')':
100 fn_name = fn_name[1:-2]
102 fn_name = fn_name[:-1]
103 generic_functions.append(
Function(fn_name))
104 elif l.startswith(
'THC_API void THNN_'):
105 fn_name = l.lstrip(
'THC_API void THNN_')
106 if fn_name[0] ==
'(' and fn_name[-2] ==
')':
107 fn_name = fn_name[1:-2]
109 fn_name = fn_name[:-1]
110 generic_functions.append(
Function(fn_name))
116 generic_functions[-1].add_argument(
Argument(t, name,
'[OPTIONAL]' in c))
117 return generic_functions
120 def load_backend(t, lib, generic_functions, mixins=tuple()):
121 backend_name =
'THNN{}Backend'.format(t)
122 backend = type(backend_name, mixins + (THNNBackendBase,), {})()
123 for function
in generic_functions:
124 full_fn_name =
'{}{}'.format(t, function.name)
125 fn = getattr(lib, full_fn_name)
126 backend.register_method(function.name, fn)
Module caffe2.python.layers.split.