from copy import deepcopy
from itertools import product
def parse_arguments(args):
    """Normalize a list of argument declarations into a list of dicts.

    Each entry of *args* is either:
      * a string ``"<type> <name>"`` (split at the first space), which becomes
        a new dict ``{'type': <type>, 'name': <name>}``; or
      * a dict. If it has an ``'arg'`` key holding ``"<type> <name>"``, that
        string is split and stored under ``'type'`` and ``'name'``. The dict is
        updated in place, so its other keys (e.g. defaults) are preserved.

    Returns:
        A new list with one dict per input entry, in input order.

    Raises:
        TypeError: if an entry is neither a str nor a dict.
    """
    new_args = []
    for arg in args:
        if isinstance(arg, str):
            arg_type, _, name = arg.partition(' ')
            new_args.append({'type': arg_type, 'name': name})
        elif isinstance(arg, dict):
            # Dicts that already carry 'type'/'name' are passed through as-is.
            if 'arg' in arg:
                arg['type'], _, arg['name'] = arg['arg'].partition(' ')
            new_args.append(arg)
        else:
            raise TypeError(
                'argument declaration must be a str or dict, got %r' % (arg,))
    return new_args
def set_declaration_defaults(declaration):
    """Fill in default fields of a declaration dict and normalize its options.

    Mutates *declaration* in place:
      * when absent, sets 'schema_string' to '', 'matches_jit_signature' to
        False, 'arguments' to [], 'return' to 'void', 'backends' to
        ['CPU', 'CUDA'], and 'cname' / 'api_name' to declaration['name'];
      * if there is no 'options' list, wraps 'arguments' into a single option
        and removes the top-level 'arguments' key;
      * replaces every option's 'arguments' with parse_arguments() of it;
      * copies every top-level key except 'options' into each option that
        does not already define it (values are shared, not copied).

    Returns None.
    """
    if 'schema_string' not in declaration:
        declaration['schema_string'] = ''
    if 'matches_jit_signature' not in declaration:
        declaration['matches_jit_signature'] = False
    declaration.setdefault('arguments', [])
    declaration.setdefault('return', 'void')
    # Explicit membership tests (not setdefault) so declaration['name'] is
    # only read when a fallback is actually needed.
    if 'cname' not in declaration:
        declaration['cname'] = declaration['name']
    if 'backends' not in declaration:
        declaration['backends'] = ['CPU', 'CUDA']
    if 'api_name' not in declaration:
        declaration['api_name'] = declaration['name']

    # A declaration without explicit options becomes a single option that
    # holds its arguments.
    if 'options' not in declaration:
        declaration['options'] = [{'arguments': declaration['arguments']}]
        del declaration['arguments']

    options = declaration['options']
    for option in options:
        option['arguments'] = parse_arguments(option['arguments'])

    # Propagate declaration-level fields into each option. 'options' itself is
    # skipped: copying it would make every option reference the list that
    # contains it.
    for option in options:
        for key, value in declaration.items():
            if key != 'options':
                option.setdefault(key, value)
def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
    """Return the options whose call signatures can be told apart.

    Options are visited in order. For each one, the trailing 0, 1, ...,
    len(arguments) arguments are tentatively treated as keyword-only (only 0
    when *allow_kwarg* is false). The first resulting signature not claimed
    by an earlier option is taken: those trailing arguments get
    ``'kwarg_only': True`` (mutating the argument dicts) and the option is
    kept. Options for which every candidate signature is already taken are
    dropped.

    Args:
        options: list of option dicts, each with an 'arguments' list of
            argument dicts that have at least 'type' and 'name'.
        allow_kwarg: whether trailing arguments may be made keyword-only to
            disambiguate options.
        type_to_signature: mapping from argument type to the token used in
            the positional part of the signature; unmapped types stand for
            themselves. A falsy value (None, {} or []) means no mapping.
        remove_self: if true, an argument named 'self' is left out of the
            positional part of the signature.

    Returns:
        A new list of the retained options, in their original order.
    """
    # A falsy value ([] in particular) has no .get(); treat it as "no mapping".
    type_map = type_to_signature or {}

    def exclude_arg(arg):
        # Constant / unchecked arguments do not take part in the signature.
        return arg.get('ignore_check') or arg['type'] == 'CONSTANT'

    def exclude_arg_with_self_check(arg):
        return exclude_arg(arg) or (remove_self and arg['name'] == 'self')

    def signature(option, kwarg_only_count):
        args = option['arguments']
        if kwarg_only_count == 0:
            positional, kwarg_only = args, None
        else:
            positional = args[:-kwarg_only_count]
            kwarg_only = args[-kwarg_only_count:]
        arg_signature = '#'.join(
            type_map.get(arg['type'], arg['type'])
            for arg in positional
            if not exclude_arg_with_self_check(arg))
        if kwarg_only is None:
            return arg_signature
        # Keyword-only arguments are matched by name, so names are included.
        kwarg_only_signature = '#'.join(
            arg['name'] + '#' + arg['type']
            for arg in kwarg_only
            if not exclude_arg(arg))
        return arg_signature + "#-#" + kwarg_only_signature

    seen_signatures = set()
    unique = []
    for option in options:
        limit = len(option['arguments']) if allow_kwarg else 0
        for num_kwarg_only in range(limit + 1):
            sig = signature(option, num_kwarg_only)
            if sig in seen_signatures:
                continue
            if num_kwarg_only > 0:
                for arg in option['arguments'][-num_kwarg_only:]:
                    arg['kwarg_only'] = True
            unique.append(option)
            seen_signatures.add(sig)
            break
    return unique
def enumerate_options_due_to_default(declaration,
                                     allow_kwarg=True, type_to_signature=None, remove_self=True):
    """Expand each option into variants for every subset of defaulted arguments.

    For each option in declaration['options'], every argument that has a
    'default' key is optional. One deep copy of the option is made for each
    keep/omit assignment of those arguments (itertools.product order, so the
    "keep everything" variant comes first). In every copy:
      * a 'default' of None on an optional argument becomes the string 'NULL';
      * an omitted argument keeps its original type under 'declared_type',
        gets type 'CONSTANT' and is marked 'ignore_check';
      * 'has_full_argument_list' is True only for the variant that keeps all
        optional arguments.

    The variants are deduplicated with filter_unique_options() and the result
    replaces declaration['options'] (mutating *declaration*).

    Args:
        declaration: dict with an 'options' list of option dicts, each holding
            an 'arguments' list of argument dicts.
        allow_kwarg, remove_self: forwarded to filter_unique_options().
        type_to_signature: type -> signature-token mapping forwarded to
            filter_unique_options(); None (the default) or any falsy value
            means no mapping.
    """
    # Never use a shared mutable default; normalize falsy values (including a
    # caller-supplied []) to an empty mapping that supports .get().
    if not type_to_signature:
        type_to_signature = {}

    new_options = []
    for option in declaration['options']:
        optional_args = [i for i, arg in enumerate(option['arguments'])
                         if 'default' in arg]
        for permutation in product((True, False), repeat=len(optional_args)):
            option_copy = deepcopy(option)
            option_copy['has_full_argument_list'] = all(permutation)
            for i, keep in zip(optional_args, permutation):
                arg = option_copy['arguments'][i]
                if arg['default'] is None:
                    arg['default'] = 'NULL'
                if not keep:
                    # The omitted argument is passed as its default constant.
                    arg['declared_type'] = arg['type']
                    arg['type'] = 'CONSTANT'
                    arg['ignore_check'] = True
            new_options.append(option_copy)
    declaration['options'] = filter_unique_options(
        new_options, allow_kwarg, type_to_signature, remove_self)
def sort_by_number_of_options(declaration, reverse=True):
    """Sort declaration['options'] in place by number of checked arguments.

    An argument counts as checked unless it has a truthy 'ignore_check'.
    With the default reverse=True, options with the most checked arguments
    come first. list.sort is stable (also with reverse=True), so ties keep
    their relative order.
    """
    def num_checked_args(option):
        return sum(1 for arg in option['arguments']
                   if not arg.get('ignore_check', False))

    declaration['options'].sort(key=num_checked_args, reverse=reverse)
class Function(object):
    """A parsed function prototype: a name plus an ordered list of Arguments."""

    def __init__(self, name):
        self.name = name
        self.arguments = []

    def add_argument(self, arg):
        """Append *arg*, which must be an Argument instance."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self):
        return self.name + '(' + ', '.join(map(repr, self.arguments)) + ')'


class Argument(object):
    """One parameter of a Function: its type string, name and optional flag."""

    def __init__(self, _type, name, is_optional):
        self.type = _type
        self.name = name
        self.is_optional = is_optional

    def __repr__(self):
        # Rendered inside Function.__repr__, e.g. "THTensor* input".
        return self.type + ' ' + self.name
def parse_header(path):
    """Parse THNN-style function prototypes from a C header file.

    Lines that are empty or start with '#' are ignored. Each remaining line is
    split at '//' into code and comment; trailing ')', ';' and ',' are stripped
    from the code, which is then split on ',' into pieces (empty pieces are
    dropped). A piece starting with 'TH_API void THNN_' or
    'THC_API void THNN_' opens a new Function, named either 'Name' from the
    macro form 'THNN_(Name)(' or from the plain form 'THNN_Name('. Every other
    piece is an argument '<type> <name>' of the most recent Function; leading
    '*' on the name are moved onto the type, and the argument is optional when
    the line's comment contains '[OPTIONAL]'.

    Returns:
        list of Function, in declaration order.
    """
    prefixes = ('TH_API void THNN_', 'THC_API void THNN_')

    with open(path, 'r') as f:
        raw_lines = f.read().split('\n')

    # Flatten the file into (code piece, comment) pairs.
    pieces = []
    for line in raw_lines:
        if not line or line.startswith('#'):
            continue
        code, _, comment = line.partition('//')
        code = code.strip().rstrip(');').rstrip(',')
        comment = comment.strip()
        for piece in code.split(','):
            piece = piece.strip()
            if piece:
                pieces.append((piece, comment))

    generic_functions = []
    for piece, comment in pieces:
        prefix = next((p for p in prefixes if piece.startswith(p)), None)
        if prefix is not None:
            # Slice the prefix off exactly: str.lstrip(prefix) would strip any
            # of its *characters*, mangling names such as 'Abs_updateOutput'.
            fn_name = piece[len(prefix):]
            if fn_name[:1] == '(' and fn_name[-2:-1] == ')':
                fn_name = fn_name[1:-2]  # macro form: (Name)(
            else:
                fn_name = fn_name[:-1]   # plain form: Name(
            generic_functions.append(Function(fn_name))
        else:
            # rsplit keeps multi-word types such as 'const char*' intact.
            arg_type, name = piece.rsplit(None, 1)
            stars = len(name) - len(name.lstrip('*'))
            if stars:
                # 'THTensor *input' -> type 'THTensor*', name 'input'
                arg_type += '*' * stars
                name = name[stars:]
            generic_functions[-1].add_argument(
                Argument(arg_type, name, '[OPTIONAL]' in comment))
    return generic_functions
# NOTE(review): stray text left over from source extraction, kept for reference:
# "Module caffe2.python.layers.split."