3 from string
import Template
4 from copy
import deepcopy
5 from .plugins
import ArgcountChecker, OptionalArguments, ArgumentReferences, \
6 BeforeAfterCall, ConstantArguments, ReturnArguments, GILRelease
7 from ..shared
import cwrap_common
14 'void': Template(
'Py_RETURN_NONE;'),
15 'long': Template(
'return PyLong_FromLong($result);'),
16 'int64_t': Template(
'return PyLong_FromLong($result);'),
17 'bool': Template(
'return PyBool_FromLong($result);'),
18 'void*': Template(
'return PyLong_FromVoidPtr($result);'),
21 OPTION_TEMPLATE = Template(
""" 22 ${els}if ($arg_check) { 28 ARG_ASSIGN_TEMPLATE = Template(
"""${type} ${name} = ${unpack};""")
30 OPTION_CODE_TEMPLATE = [
35 FUNCTION_CALL_TEMPLATE = Template(
"$capture_result$cname($call_arg);")
37 DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments,
38 ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]
40 def __init__(self, source, destination=None, plugins=None, default_plugins=True, template_path=None):
41 if destination
is None:
42 destination = source.replace(
'.cwrap',
'.cpp')
44 self.
plugins = []
if plugins
is None else plugins
50 plugin.initialize(self)
52 self.
base_path = os.path.dirname(os.path.abspath(source))
53 with open(source,
'r') as f: 54 declarations = f.read() 61 wrapper = plugin.process_full_file(wrapper, template_path)
65 with open(destination,
'r') as f: 66 old_wrapper = f.read() 70 if old_wrapper != wrapper:
71 with open(destination,
'w')
as f:
72 print(
"Writing {}".format(destination))
75 print(
"Skipped writing {}".format(destination))
77 def wrap_declarations(self, declarations):
78 lines = declarations.split(
'\n')
79 declaration_lines = []
81 in_declaration =
False 87 declaration_lines = []
90 in_declaration =
False 91 declaration = yaml.load(
'\n'.join(declaration_lines))
92 cwrap_common.set_declaration_defaults(declaration)
96 declarations = [declaration]
98 declarations = plugin.process_declarations(declarations)
101 for declaration
in declarations:
104 wrapper = plugin.process_wrapper(wrapper, declaration)
105 output.append(wrapper)
107 declaration_lines.append(line)
108 elif '!!inc ' == line[:6]:
109 fname = os.path.join(self.
base_path, line[6:].strip())
110 with open(fname,
'r') as f: 111 included = f.read().split('\n')
113 lines[i + 1:i + 1] = included
118 return '\n'.join(output)
120 def parse_arguments(self, args):
124 if isinstance(arg, str):
125 t, _, name = arg.partition(
' ')
126 new_args.append({
'type': t,
'name': name})
127 elif isinstance(arg, dict):
129 arg[
'type'], _, arg[
'name'] = arg[
'arg'].partition(
' ')
137 """Search plugins for the given function to call with args. 139 If not found, call fallback with args. 142 wrapper = getattr(plugin, fnname)(*args)
143 if wrapper
is not None:
145 return fallback(*args)
def get_type_check(self, arg, option):
    """Look up the argument type-check template supplied by a plugin.

    The lookup is delegated to ``search_plugins`` under the hook name
    ``'get_type_check'`` with ``(arg, option)`` as the call arguments.

    Returns:
        The first non-None result a plugin produces for ``arg``, or ``None``
        when ``search_plugins`` falls back because no plugin handled it.
    """
    def _no_check(_arg, _option):
        # Fallback: no plugin knows how to type-check this argument.
        return None

    hook_args = (arg, option)
    return self.search_plugins('get_type_check', hook_args, _no_check)
def get_type_unpack(self, arg, option):
    """Look up the argument unpack template supplied by a plugin.

    The lookup is delegated to ``search_plugins`` under the hook name
    ``'get_type_unpack'`` with ``(arg, option)`` as the call arguments.

    Returns:
        The first non-None result a plugin produces for ``arg``, or ``None``
        when ``search_plugins`` falls back because no plugin handled it.
    """
    def _no_unpack(_arg, _option):
        # Fallback: no plugin knows how to unpack this argument.
        return None

    hook_args = (arg, option)
    return self.search_plugins('get_type_unpack', hook_args, _no_unpack)
153 def get_return_wrapper(self, option):
def get_wrapper_template(self, declaration):
    """Look up the wrapper template a plugin supplies for ``declaration``.

    The lookup is delegated to ``search_plugins`` under the hook name
    ``'get_wrapper_template'`` with the single-element tuple
    ``(declaration,)``.

    Returns:
        The first non-None result a plugin produces, or ``None`` when
        ``search_plugins`` falls back because no plugin handled it.
    """
    def _no_template(_declaration):
        # Fallback: no plugin provides a wrapper template.
        return None

    hook_args = (declaration,)
    return self.search_plugins('get_wrapper_template', hook_args, _no_template)
def get_assign_args(self, arguments):
    """Let plugins rewrite the list of arguments to be assigned.

    The lookup is delegated to ``search_plugins`` under the hook name
    ``'get_assign_args'`` with ``(arguments,)``.

    Returns:
        The first non-None result a plugin produces, or ``arguments``
        itself, unchanged, when ``search_plugins`` falls back.
    """
    def _keep_arguments(_ignored):
        # Fallback: hand back the caller's arguments untouched.
        return arguments

    hook_args = (arguments,)
    return self.search_plugins('get_assign_args', hook_args, _keep_arguments)
def get_arg_accessor(self, arg, option):
    """Return the C expression used to read ``arg`` from the Python args.

    Plugins get the first chance via the ``'get_arg_accessor'`` hook of
    ``search_plugins``. If none answers, the default accessor indexes the
    ``args`` tuple at ``arg['idx']``.

    Returns:
        A C expression string, e.g. ``'PyTuple_GET_ITEM(args, 0)'`` for the
        default accessor when ``arg['idx'] == 0``.

    Raises:
        RuntimeError: if the default accessor is needed but ``arg`` has no
            ``'idx'`` (or it is ``None``).
    """
    def _tuple_item_accessor(argument, _option):
        position = argument.get('idx')
        if position is not None:
            return 'PyTuple_GET_ITEM(args, {})'.format(position)
        # Without a tuple position there is no way to build a default accessor.
        raise RuntimeError("Missing accessor for '{} {}'".format(
            argument['type'], argument['name']))

    hook_args = (arg, option)
    return self.search_plugins('get_arg_accessor', hook_args,
                               _tuple_item_accessor)
171 def generate_wrapper(self, declaration):
173 for i, option
in enumerate(declaration[
'options']):
176 option_wrapper = plugin.process_option_code(option_wrapper, option)
177 wrapper += option_wrapper
178 return self.
get_wrapper_template(declaration).substitute(name=declaration[
'name'], options=wrapper)
180 def map_selected_arguments(self, base_fn_name, plugin_fn_name, option, arguments):
182 for arg
in arguments:
184 tmpl = getattr(self, base_fn_name)(arg, option)
186 fn =
'check' if base_fn_name ==
'get_type_check' else 'unpack' 187 raise RuntimeError(
"Missing type {} for '{} {}'".format(
188 fn, arg[
'type'], arg[
'name']))
189 res = tmpl.substitute(arg=accessor, idx=arg.get(
'idx'))
191 res = getattr(plugin, plugin_fn_name)(res, arg, accessor)
196 def build_option_args(self, arguments, arg_unpack):
201 for arg, unpack
in zip(arguments, arg_unpack):
202 if arg[
'type'] ==
'CONSTANT':
203 call_arg.append(unpack)
205 var_name =
"arg_" + str(arg.get(
'assign_name', arg[
'name']))
206 res = self.ARG_ASSIGN_TEMPLATE.substitute(
211 if var_name
not in call_arg:
212 assignement.append(res)
213 call_arg.append(var_name)
214 return assignement, call_arg
216 def indent_code(self, code):
219 code_lines = map(
lambda s: s.strip(), code.split(
'\n'))
222 for line
in code_lines:
223 depth -= line.count(
'}') * 2
224 code +=
' ' * depth + line +
'\n' 225 depth += line.count(
'{') * 2
226 depth += line.count(
'(') * 4
227 depth -= line.count(
')') * 4
230 def generate_option(self, option, is_first):
231 checked_args = list(filter(
232 lambda arg:
'ignore_check' not in arg
or not arg[
'ignore_check'],
233 option[
'arguments']))
234 option[
'num_checked_args'] = len(checked_args)
235 idx_args = list(filter(
236 lambda arg:
not arg.get(
'ignore_check')
and not arg.get(
'no_idx'),
237 option[
'arguments']))
238 for i, arg
in enumerate(idx_args):
243 'process_single_check', option, checked_args)
244 arg_checks =
' &&\n '.join(arg_checks)
246 arg_checks = plugin.process_all_checks(arg_checks, option)
251 pre_arg_assign = plugin.process_pre_arg_assign(pre_arg_assign, option)
255 'process_single_unpack', option, option[
'arguments'])
258 call_arg =
', '.join(call_arg)
260 call_arg = plugin.process_all_call_arg(call_arg, option)
265 call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result=
'',
266 cname=option[
'cname'], call_arg=call_arg)
269 call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result=(option[
'return'] +
' __result = '),
270 cname=option[
'cname'], call_arg=call_arg)
274 code_template = plugin.process_option_code_template(code_template,
276 code_template = Template(
'\n'.join(code_template))
277 code = code_template.substitute(call=call, return_result=return_result)
279 pre_arg_assign = self.
indent_code(
'\n'.join(pre_arg_assign))
280 arg_assign = self.
indent_code(
'\n'.join(arg_assign))
283 return self.OPTION_TEMPLATE.substitute(
284 els=(
'} else ' if not is_first
else ''),
285 arg_check=arg_checks,
286 pre_arg_assign=pre_arg_assign,
287 arg_assign=arg_assign,
Module caffe2.python.layers.split.