Caffe2 - Python API
A deep learning, cross platform ML framework
cwrap.py
1 import os
2 import yaml
3 from string import Template
4 from copy import deepcopy
5 from .plugins import ArgcountChecker, OptionalArguments, ArgumentReferences, \
6  BeforeAfterCall, ConstantArguments, ReturnArguments, GILRelease
7 from ..shared import cwrap_common
8 
9 
class cwrap(object):
    """Generate C++ Python-binding wrappers from a ``.cwrap`` declaration file.

    The source file is copied through line by line, except that:

    * blocks delimited by lines that are exactly ``[[`` and ``]]`` are parsed
      as YAML declarations and replaced by generated wrapper code, and
    * lines starting with ``!!inc `` are replaced by the contents of the named
      file (resolved relative to the source file's directory).

    Most of the generated code can be customised by plugins: each hook is
    tried on every plugin in order and the first non-``None`` result wins
    (see ``search_plugins``).
    """

    # Base indentation (in spaces) applied to generated option bodies.
    BASE_INDENT_SIZE = 6

    # Fallback C return statements keyed by an option's ``return`` type.
    # ``$result`` is substituted with the variable holding the call result.
    RETURN_WRAPPERS = {
        'void': Template('Py_RETURN_NONE;'),
        'long': Template('return PyLong_FromLong($result);'),
        # PyLong_FromLong takes a C ``long``, which is only 32 bits on LLP64
        # platforms (e.g. 64-bit Windows); the ``long long`` variant keeps
        # all 64 bits of an int64_t.
        'int64_t': Template('return PyLong_FromLongLong($result);'),
        'bool': Template('return PyBool_FromLong($result);'),
        'void*': Template('return PyLong_FromVoidPtr($result);'),
    }

    # One overload ("option"). The `if` is left open: the next option's
    # "} else " closes it; presumably the wrapper template closes the last
    # one -- NOTE(review): confirm against the plugins' wrapper templates.
    OPTION_TEMPLATE = Template("""
    ${els}if ($arg_check) {
      $pre_arg_assign
      $arg_assign
      $code
    """)

    ARG_ASSIGN_TEMPLATE = Template("""${type} ${name} = ${unpack};""")

    # Lines of an option body; plugins may edit this list through
    # process_option_code_template before it is joined and substituted.
    OPTION_CODE_TEMPLATE = [
        '$call',
        '$return_result',
    ]

    FUNCTION_CALL_TEMPLATE = Template("$capture_result$cname($call_arg);")

    DEFAULT_PLUGIN_CLASSES = [ArgcountChecker, ConstantArguments, OptionalArguments,
                              ArgumentReferences, BeforeAfterCall, ReturnArguments, GILRelease]

    def __init__(self, source, destination=None, plugins=None, default_plugins=True, template_path=None):
        """Process ``source`` and write the generated wrapper to ``destination``.

        Args:
            source: path of the ``.cwrap`` file to read.
            destination: output path. Defaults to ``source`` with a ``.cwrap``
                extension replaced by ``.cpp``, or with ``.cpp`` appended for
                any other extension, so the input is never overwritten.
            plugins: extra plugin instances, run after the default ones.
            default_plugins: if true, instantiate ``DEFAULT_PLUGIN_CLASSES``
                and place them ahead of ``plugins``.
            template_path: forwarded to each plugin's ``process_full_file``.
        """
        if destination is None:
            root, ext = os.path.splitext(source)
            # Only swap the extension: a blanket str.replace would also rewrite
            # directory names, and would return ``source`` itself (clobbering
            # the input) when the extension is not ``.cwrap``.
            destination = (root if ext == '.cwrap' else source) + '.cpp'

        self.plugins = [] if plugins is None else plugins
        if default_plugins:
            defaults = [cls() for cls in self.DEFAULT_PLUGIN_CLASSES]
            self.plugins = defaults + self.plugins

        for plugin in self.plugins:
            plugin.initialize(self)

        # !!inc paths are resolved relative to the source file's directory.
        self.base_path = os.path.dirname(os.path.abspath(source))
        with open(source, 'r') as f:
            declarations = f.read()

        # wrap all the declarations in the source .cwrap file
        wrapper = self.wrap_declarations(declarations)

        # let each plugin do any post-processing of the wrapped file
        for plugin in self.plugins:
            wrapper = plugin.process_full_file(wrapper, template_path)

        # See Note [Unchanging results for ninja]
        # Only rewrite the destination when its content actually changes,
        # so an unchanged output keeps its old mtime.
        try:
            with open(destination, 'r') as f:
                old_wrapper = f.read()
        except IOError:
            old_wrapper = None

        if old_wrapper != wrapper:
            with open(destination, 'w') as f:
                print("Writing {}".format(destination))
                f.write(wrapper)
        else:
            print("Skipped writing {}".format(destination))

    def wrap_declarations(self, declarations):
        """Expand every ``[[ ... ]]`` block and ``!!inc`` line of a cwrap text.

        Args:
            declarations: full text of a ``.cwrap`` file.

        Returns:
            The text with each declaration block replaced by its generated
            wrappers; every other line is passed through unchanged.

        Raises:
            RuntimeError: on a ``]]`` without an opening ``[[``, or a ``[[``
                that is never closed.
        """
        lines = declarations.split('\n')
        declaration_lines = []
        output = []
        in_declaration = False
        i = 0

        # A while loop rather than a for loop, because !!inc splices the
        # included lines into `lines` right after position i.
        while i < len(lines):
            line = lines[i]
            if line == '[[':
                declaration_lines = []
                in_declaration = True
            elif line == ']]':
                if not in_declaration:
                    raise RuntimeError(
                        "Unexpected ']]' outside of a declaration block (line {})".format(i + 1))
                in_declaration = False
                output.extend(self._expand_declaration('\n'.join(declaration_lines)))
            elif in_declaration:
                declaration_lines.append(line)
            elif line.startswith('!!inc '):
                fname = os.path.join(self.base_path, line[6:].strip())
                with open(fname, 'r') as f:
                    included = f.read().split('\n')
                # insert it into lines at position i+1
                lines[i + 1:i + 1] = included
            else:
                output.append(line)
            i += 1

        if in_declaration:
            raise RuntimeError("Unterminated '[[' declaration block (missing ']]')")
        return '\n'.join(output)

    def _expand_declaration(self, declaration_text):
        """Parse one YAML declaration and return the list of wrapper strings for it."""
        # safe_load: declarations are plain YAML data, and yaml.load without
        # an explicit Loader is deprecated (PyYAML 5.1) / an error (PyYAML 6).
        declaration = yaml.safe_load(declaration_text)
        cwrap_common.set_declaration_defaults(declaration)

        # Pass the declaration in a list - plugins may expand one declaration
        # into multiple wrappers.
        expanded = [declaration]
        for plugin in self.plugins:
            expanded = plugin.process_declarations(expanded)

        wrappers = []
        for decl in expanded:
            wrapper = self.generate_wrapper(decl)
            for plugin in self.plugins:
                wrapper = plugin.process_wrapper(wrapper, decl)
            wrappers.append(wrapper)
        return wrappers

    def parse_arguments(self, args):
        """Normalise argument declarations into dicts with 'type' and 'name'.

        A string ``"<type> <name>"`` becomes ``{'type': ..., 'name': ...}``.
        A dict with an ``'arg'`` key has that key split into 'type'/'name' in
        place; other dicts are passed through unchanged.

        Raises:
            TypeError: if an entry is neither a string nor a dict.
        """
        new_args = []
        for arg in args:
            # Simple arg declaration of form "<type> <name>"
            if isinstance(arg, str):
                t, _, name = arg.partition(' ')
                new_args.append({'type': t, 'name': name})
            elif isinstance(arg, dict):
                if 'arg' in arg:
                    arg['type'], _, arg['name'] = arg['arg'].partition(' ')
                    del arg['arg']
                new_args.append(arg)
            else:
                # An explicit raise, unlike `assert False`, survives `python -O`.
                raise TypeError("Unsupported argument declaration: {!r}".format(arg))
        return new_args

    def search_plugins(self, fnname, args, fallback):
        """Search plugins for the given function to call with args.

        Returns the first non-None result of ``getattr(plugin, fnname)(*args)``
        over the plugins in order. If no plugin answers, returns
        ``fallback(*args)``.
        """
        for plugin in self.plugins:
            wrapper = getattr(plugin, fnname)(*args)
            if wrapper is not None:
                return wrapper
        return fallback(*args)

    def get_type_check(self, arg, option):
        """Return the type-check Template for ``arg``, or None if no plugin has one."""
        return self.search_plugins('get_type_check', (arg, option), lambda arg, _: None)

    def get_type_unpack(self, arg, option):
        """Return the unpack Template for ``arg``, or None if no plugin has one."""
        return self.search_plugins('get_type_unpack', (arg, option), lambda arg, _: None)

    def get_return_wrapper(self, option):
        """Return the Template used to return the call result of ``option``.

        Raises:
            RuntimeError: if no plugin provides one and ``option['return']``
                is not a key of RETURN_WRAPPERS.
        """
        def default_wrapper(option):
            return_type = option.get('return')
            if return_type not in self.RETURN_WRAPPERS:
                raise RuntimeError("Missing return wrapper for type '{}'".format(return_type))
            return self.RETURN_WRAPPERS[return_type]

        return self.search_plugins('get_return_wrapper', (option,), default_wrapper)

    def get_wrapper_template(self, declaration):
        """Return the Template for a whole declaration, or None if no plugin has one."""
        return self.search_plugins('get_wrapper_template', (declaration,), lambda _: None)

    def get_assign_args(self, arguments):
        """Let plugins rewrite arguments before assignment; defaults to ``arguments``."""
        return self.search_plugins('get_assign_args', (arguments,), lambda _: arguments)

    def get_arg_accessor(self, arg, option):
        """Return the C expression that reads ``arg`` from the Python arguments.

        The fallback reads positional slot ``arg['idx']`` of the ``args``
        tuple.

        Raises:
            RuntimeError: if no plugin answers and ``arg`` has no 'idx'.
        """
        def wrap_accessor(arg, _):
            if arg.get('idx') is None:
                raise RuntimeError("Missing accessor for '{} {}'".format(
                    arg['type'], arg['name']))
            return 'PyTuple_GET_ITEM(args, {})'.format(arg['idx'])

        return self.search_plugins('get_arg_accessor', (arg, option), wrap_accessor)

    def generate_wrapper(self, declaration):
        """Generate the code of every option and substitute it into the wrapper template.

        Raises:
            RuntimeError: if no plugin supplies a wrapper template.
        """
        option_wrappers = []
        for i, option in enumerate(declaration['options']):
            option_wrapper = self.generate_option(option, is_first=(i == 0))
            for plugin in self.plugins:
                option_wrapper = plugin.process_option_code(option_wrapper, option)
            option_wrappers.append(option_wrapper)

        template = self.get_wrapper_template(declaration)
        if template is None:
            raise RuntimeError("Missing wrapper template for '{}'".format(declaration['name']))
        return template.substitute(name=declaration['name'], options=''.join(option_wrappers))

    def map_selected_arguments(self, base_fn_name, plugin_fn_name, option, arguments):
        """Render one code snippet per argument.

        For each argument, look up its Template with ``self.<base_fn_name>``
        ('get_type_check' or 'get_type_unpack'). Substitute the accessor and
        index into it, then pass the result through every plugin's
        ``<plugin_fn_name>`` hook.

        Raises:
            RuntimeError: if no Template exists for an argument's type.
        """
        result = []
        for arg in arguments:
            accessor = self.get_arg_accessor(arg, option)
            tmpl = getattr(self, base_fn_name)(arg, option)
            if tmpl is None:
                fn = 'check' if base_fn_name == 'get_type_check' else 'unpack'
                raise RuntimeError("Missing type {} for '{} {}'".format(
                    fn, arg['type'], arg['name']))
            res = tmpl.substitute(arg=accessor, idx=arg.get('idx'))
            for plugin in self.plugins:
                res = getattr(plugin, plugin_fn_name)(res, arg, accessor)
            result.append(res)
        return result

    def build_option_args(self, arguments, arg_unpack):
        """Build the local-variable assignments and call arguments for one option.

        Returns:
            ``(assignments, call_arg)``. ``assignments`` holds C declarations
            ``<type> arg_<name> = <unpack>;``. ``call_arg`` holds the
            expressions passed to the wrapped function. CONSTANT arguments are
            passed inline. A variable that is already declared is reused
            rather than declared twice.
        """
        assignments = []
        call_arg = []
        # If types or names needs to be changed
        arguments = self.get_assign_args(arguments)
        for arg, unpack in zip(arguments, arg_unpack):
            if arg['type'] == 'CONSTANT':
                call_arg.append(unpack)
                continue
            var_name = "arg_" + str(arg.get('assign_name', arg['name']))
            if var_name not in call_arg:
                assignments.append(self.ARG_ASSIGN_TEMPLATE.substitute(
                    type=arg['type'],
                    name=var_name,
                    unpack=unpack))
            call_arg.append(var_name)
        return assignments, call_arg

    def indent_code(self, code):
        """Strip and re-indent each line of ``code`` with a bracket-counting heuristic.

        Indentation starts at BASE_INDENT_SIZE. It changes by 2 per
        ``{``/``}`` and by 4 per ``(``/``)``; brackets inside string literals
        are counted too. The result starts with a newline and has no trailing
        newline. Empty input is returned unchanged.
        """
        if code == '':
            return code
        pieces = ['']  # leading empty piece => result begins with '\n'
        depth = self.BASE_INDENT_SIZE
        for line in (s.strip() for s in code.split('\n')):
            depth -= line.count('}') * 2
            pieces.append(' ' * depth + line)
            depth += line.count('{') * 2
            depth += line.count('(') * 4
            depth -= line.count(')') * 4
        return '\n'.join(pieces)

    def generate_option(self, option, is_first):
        """Generate the C code for one overload of a declaration.

        Side effects: sets ``option['num_checked_args']`` and assigns
        consecutive ``arg['idx']`` values to checked arguments not marked
        'no_idx'.

        Args:
            option: one entry of ``declaration['options']``.
            is_first: false for later options, which are chained with
                ``} else if``.
        """
        # Arguments without a truthy 'ignore_check' get a runtime type check.
        checked_args = [arg for arg in option['arguments']
                        if not arg.get('ignore_check')]
        option['num_checked_args'] = len(checked_args)
        # Positional indices (used by the default PyTuple_GET_ITEM accessor).
        idx_args = [arg for arg in checked_args if not arg.get('no_idx')]
        for i, arg in enumerate(idx_args):
            arg['idx'] = i

        # Generate checks
        arg_checks = self.map_selected_arguments('get_type_check',
                                                 'process_single_check', option, checked_args)
        arg_checks = ' &&\n          '.join(arg_checks)
        for plugin in self.plugins:
            arg_checks = plugin.process_all_checks(arg_checks, option)

        # Generate pre_arg assign
        pre_arg_assign = []
        for plugin in self.plugins:
            pre_arg_assign = plugin.process_pre_arg_assign(pre_arg_assign, option)

        # Generate arg assignment and call arguments
        arg_unpack = self.map_selected_arguments('get_type_unpack',
                                                 'process_single_unpack', option, option['arguments'])
        arg_assign, call_arg = self.build_option_args(option['arguments'], arg_unpack)

        call_arg = ', '.join(call_arg)
        for plugin in self.plugins:
            call_arg = plugin.process_all_call_arg(call_arg, option)

        # Generate call. Only the $result probe is inside the try, so a
        # missing 'cname' or an unknown return type surfaces directly
        # instead of being retried.
        return_wrapper = self.get_return_wrapper(option)
        try:
            # Wrappers without $result (e.g. void) need no captured value.
            return_result = return_wrapper.substitute()
            capture_result = ''
        except KeyError:
            # The wrapper references $result: store the call's return value.
            return_result = return_wrapper.substitute(result='__result')
            capture_result = option['return'] + ' __result = '
        call = self.FUNCTION_CALL_TEMPLATE.substitute(capture_result=capture_result,
                                                      cname=option['cname'], call_arg=call_arg)

        # Copy so plugins cannot mutate the class-level template list.
        code_template = list(self.OPTION_CODE_TEMPLATE)
        for plugin in self.plugins:
            code_template = plugin.process_option_code_template(code_template,
                                                                option)
        code_template = Template('\n'.join(code_template))
        code = code_template.substitute(call=call, return_result=return_result)
        code = self.indent_code(code)
        pre_arg_assign = self.indent_code('\n'.join(pre_arg_assign))
        arg_assign = self.indent_code('\n'.join(arg_assign))

        # Put everything together
        return self.OPTION_TEMPLATE.substitute(
            els=('} else ' if not is_first else ''),
            arg_check=arg_checks,
            pre_arg_assign=pre_arg_assign,
            arg_assign=arg_assign,
            code=code,
        )
Module caffe2.python.layers.split.
def generate_wrapper(self, declaration)
Definition: cwrap.py:171
def indent_code(self, code)
Definition: cwrap.py:216
def get_wrapper_template(self, declaration)
Definition: cwrap.py:156
def generate_option(self, option, is_first)
Definition: cwrap.py:230
def map_selected_arguments(self, base_fn_name, plugin_fn_name, option, arguments)
Definition: cwrap.py:180
def build_option_args(self, arguments, arg_unpack)
Definition: cwrap.py:196
def get_assign_args(self, arguments)
Definition: cwrap.py:159
def wrap_declarations(self, declarations)
Definition: cwrap.py:77
def get_return_wrapper(self, option)
Definition: cwrap.py:153
def get_arg_accessor(self, arg, option)
Definition: cwrap.py:162
def search_plugins(self, fnname, args, fallback)
Definition: cwrap.py:136