Caffe2 - Python API
A deep-learning, cross-platform ML framework
native_parse.py
1 from __future__ import print_function
2 import re
3 import yaml
4 import pprint
5 import sys
6 import copy
7 
8 try:
9  # use faster C loader if available
10  from yaml import CLoader as Loader
11 except ImportError:
12  from yaml import Loader
13 
14 
# [temp translations]
# We're incrementally moving from the custom func schema to the JIT signature
# schema. This will reduce overall complexity and increase compliance between
# these components. So for now we do simple type translations to continue to
# emit the legacy func schema for further processing by downstream tools. This
# helps us avoid having to prematurely change all downstream tools to detect
# these new types.
def type_argument_translations(arg):
    """Translate one JIT-schema argument declaration into its legacy form.

    ``arg`` is a single argument string such as ``"Tensor(a!) self"``,
    ``"int[2] stride=1"`` or ``"Generator? generator=None"``: a type,
    optionally followed by a space and a name, optionally followed by
    ``=default``.

    Returns a tuple ``(type, name, default, nullable, size, annotation)``:

    * ``type`` -- the legacy type string (e.g. ``int`` -> ``int64_t``,
      ``int[]`` -> ``IntArrayRef``, ``Generator?`` -> ``Generator *``). It may
      still carry a trailing ``?`` (e.g. ``int?`` -> ``int64_t?``).
    * ``name`` -- the argument name, or ``''`` if none was given.
    * ``default`` -- ``None`` if absent, otherwise the translated default,
      which may be a ``bool``, ``int``, ``float`` or a legacy C++ string
      such as ``'{}'``, ``'nullptr'`` or ``'c10::nullopt'``.
    * ``nullable`` -- True if the (annotation-stripped) type contains ``?``,
      except for ``Generator?``.
    * ``size`` -- the fixed length N for ``int[N]`` arguments, else ``None``.
    * ``annotation`` -- the alias annotation inside ``Tensor(...)``
      (e.g. ``'a!'``), or ``None``.

    Raises:
        RuntimeError: if a legacy spelling (``int64_t``, ``int64_t?``,
            ``double``, ``std::array``, ``true``/``false``) is used where the
            JIT spelling is required. See [temp translations].
    """
    type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
    t = type_and_name[0]
    name = type_and_name[1] if len(type_and_name) > 1 else ''
    # Split on the first '=' only, so a default value that itself contains
    # '=' is kept whole instead of being silently truncated.
    name_and_default = name.split('=', 1)
    name = name_and_default[0]
    default = name_and_default[1] if len(name_and_default) > 1 else None
    size = None  # Only set for int[N] arguments

    # Peel an alias annotation off Tensor types: "Tensor(a!)" -> "Tensor", "a!".
    annotation = None
    match = re.match(r'(Tensor.*)\((.+)\)(.*)', t)
    if match:
        t = match.group(1) + match.group(3)
        annotation = match.group(2)

    # XXX: is_nullable flag can only annotate entire type as optional type,
    # need to special case Generator? logic to make ? only available in jit
    # TODO: deprecate is_nullable global flag, and parse the type
    # to support annotating complicated types with optional annotation
    nullable = (t != 'Generator?' and '?' in t)

    if t == 'Generator?':
        # Enables Generator? by translating to legacy Generator*, and
        # "Generator? x=None" to legacy "Generator* x=nullptr".
        # See [temp translations].
        t = 'Generator*'
        if default == 'None':
            default = 'nullptr'
    # Enables Tensor[] by translating to legacy TensorList.
    elif t == 'Tensor[]' or t == 'Tensor?[]':
        t = 'TensorList'
    # Enables int[] by translating to legacy IntArrayRef.
    elif t == 'int[]':
        t = 'IntArrayRef'
    # Enables int by translating to legacy int64_t.
    elif t == 'int':
        t = 'int64_t'
    elif t == 'int?':
        t = 'int64_t?'
    elif t == 'int64_t':
        raise RuntimeError("Please use int and not int64_t. "
                           "See [temp translations] for details.")
    elif t == 'int64_t?':
        raise RuntimeError("Please use int? and not int64_t?. "
                           "See [temp translations] for details.")
    # Enables float by translating to legacy double.
    elif t == 'float':
        t = 'double'
    # Enables str by translating to legacy std::string.
    elif t == 'str':
        t = 'std::string'
    elif t == 'double':
        raise RuntimeError("Please use float and not double. "
                           "See [temp translations] for details.")
    else:
        # Fixed-size array notation; each pattern is matched only once.
        int_array = re.match(r'int\[(\d+)\]', t)
        bool_array = re.match(r'bool\[(\d+)\]', t)
        if int_array:
            # Enables int[x] by translating to legacy IntArrayRef with size x.
            t = 'IntArrayRef'
            size = int(int_array.group(1))
        elif bool_array:
            # Enables bool[x] by translating to legacy std::array<bool,x>.
            t = 'std::array<bool,{}>'.format(bool_array.group(1))
        elif re.match(r'std::array', t):
            raise RuntimeError("Please use array notation, e.g. bool[3] and not std::array. "
                               "See [temp translations] for details.")

    # Legacy type sanitization. TODO: Do we really need this?
    if t == 'Generator*':
        t = 'Generator *'

    if not default:
        pass
    # This enables Tensor? x=None and translates to legacy
    # "Tensor? x={}". See [temp translations].
    elif t.startswith('Tensor?') and default == 'None':
        default = "{}"
    elif default == 'True':
        default = True
    elif default == 'False':
        default = False
    elif default == 'true':
        raise RuntimeError("Please use True and not true. "
                           "See [temp translations] for details.")
    elif default == 'false':
        raise RuntimeError("Please use False and not false. "
                           "See [temp translations] for details.")
    # Enables default argument [] by translating to legacy {}.
    # See [temp translations]
    elif default == '[]':
        default = '{}'
    # Enables lists by translating to legacy {.*}.
    # See [temp translations]
    elif re.match(r'\[.*\]', default):
        default = "{" + default[1:-1] + "}"
    elif default == 'None':
        default = 'c10::nullopt'
    # The JIT signature schema uses Mean, but in particular C++ needs
    # the legacy Reduction::Mean. So we'll continue emitting that until
    # we change this at either a JIT schema or C++ level.
    elif default == 'Mean':
        default = 'Reduction::Mean'
    else:
        # Numeric literals become Python numbers; anything else stays a string.
        try:
            default = int(default)
        except ValueError:
            try:
                default = float(default)
            except ValueError:
                pass

    return t, name, default, nullable, size, annotation
140 
141 
def parse_arguments(args, func_variants, declaration, func_return):
    """Parse the argument list of a native function schema into legacy dicts.

    Args:
        args: the text between the parentheses of the schema, with arguments
            separated by ``', '`` (e.g. ``"Tensor self, *, Tensor(a!) out"``).
            A bare ``*`` marks the remaining arguments as keyword-only.
        func_variants: the entry's ``variants`` value; functions with output
            arguments must be function-only.
        declaration: the declaration being built. Its ``name`` and ``inplace``
            keys are read, and ``name`` gets an ``_out`` suffix when the
            function has output arguments.
        func_return: the return dicts produced by ``parse_return_arguments``.

    Returns:
        A list of argument dicts with keys ``type``, ``name``,
        ``is_nullable`` and ``annotation``, plus optionally ``size``,
        ``default``, ``kwarg_only`` and ``output``. Output Tensor arguments
        are moved to the front, and a ``dtype, layout, device`` triple is
        collapsed into a single ``TensorOptions options`` argument. An empty
        ``args`` returns ``[]`` without touching ``declaration``.

    Raises:
        AssertionError, RuntimeError: when the schema violates the conventions
            checked in the sanity-check section below.
    """
    arguments = []
    kwarg_only = False

    if len(args.strip()) == 0:
        return arguments

    # TODO: Use a real parser here; this will get bamboozled
    # by signatures that contain things like std::array<bool, 2> (note the space)
    for arg in args.split(', '):
        type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
        if type_and_name == ['*']:
            # Everything after a bare '*' is keyword-only.
            assert not kwarg_only
            kwarg_only = True
            continue

        t, name, default, nullable, size, annotation = type_argument_translations(arg)

        argument_dict = {'type': t.rstrip('?'), 'name': name, 'is_nullable': nullable, 'annotation': annotation}
        if size:
            argument_dict['size'] = size
        if default is not None:
            argument_dict['default'] = default
        if kwarg_only:
            argument_dict['kwarg_only'] = True
        arguments.append(argument_dict)

    # Keyword-only Tensor arguments whose annotation ends in '!' (mutable) are
    # outputs: they are flagged, un-marked as kwarg-only and moved to the front.
    arguments_out = []
    arguments_other = []
    for argument in arguments:
        if argument['type'] == "Tensor" and \
                argument['annotation'] and \
                re.match(r'^(.*!)$', argument['annotation']) and \
                argument.get('kwarg_only'):
            argument['output'] = True
            argument['kwarg_only'] = False
            arguments_out.append(argument)
        else:
            arguments_other.append(argument)
    is_out_fn = len(arguments_out) > 0

    arguments = arguments_out + arguments_other

    name = declaration['name']
    if is_out_fn:
        declaration['name'] += "_out"

    # Reverse splat of TensorOptions
    # As we move towards the JIT function schema for native_functions.yaml we need to support
    # the expanded version of TensorOptions. For now we discover whether there are three
    # types and names of keyword arguments: "ScalarType dtype", "Layout layout" and "Device device"
    # Each, if set, must have default arguments set to long or float, strided and "cpu" respectively.
    # They must appear in this order and in this order only in order for us to be able to process them.
    # In the future we will get rid of this specific processing as downstream consumers start relying
    # less on the content of Declarations.yaml. If you want to support more than this you'll
    # potentially have to extend the JIT.

    # Variant 0: positional, non-nullable.
    supported_topt_arguments = [
        [
            {'name': 'dtype', 'type': 'ScalarType', 'is_nullable': False, 'annotation': None},
            {'name': 'layout', 'type': 'Layout', 'is_nullable': False, 'annotation': None},
            {'name': 'device', 'type': 'Device', 'is_nullable': False, 'annotation': None},
        ]
    ]
    # Variant 1: the same three arguments, keyword-only.
    kwarg_topts = copy.deepcopy(supported_topt_arguments[0])
    for topt in kwarg_topts:
        topt['kwarg_only'] = True
    supported_topt_arguments.append(kwarg_topts)
    # Variant 2: keyword-only, nullable and defaulting to c10::nullopt.
    optional_topts = copy.deepcopy(kwarg_topts)
    for topt in optional_topts:
        topt['default'] = 'c10::nullopt'
        topt['is_nullable'] = True
    supported_topt_arguments.append(optional_topts)

    # The merged TensorOptions argument emitted for each variant above.
    corresponding_topts = [
        {'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None},
    ]
    corresponding_topts.append(dict(corresponding_topts[0], kwarg_only=True))
    corresponding_topts.append(dict(corresponding_topts[1], default='{}'))

    def check_topt_representation(topt_representation):
        # Return the merged argument for a supported triple, else None.
        for supported_topt, merged in zip(supported_topt_arguments, corresponding_topts):
            if topt_representation == supported_topt:
                return merged
        return None

    def is_tensor_option(argument):
        return argument['name'] in ['dtype', 'layout', 'device']

    topt_count = len(supported_topt_arguments[0])
    new_arguments = []
    idx = 0
    while idx < len(arguments):
        argument = arguments[idx]
        if is_tensor_option(argument) and len(arguments) - idx >= topt_count:
            # Greedily collect up to topt_count consecutive tensor options.
            topt_representation = []
            for _ in range(topt_count):
                argument = arguments[idx]
                if not is_tensor_option(argument):
                    break
                topt_representation.append(argument)
                idx += 1
            if len(topt_representation) == topt_count:
                merged_argument = check_topt_representation(topt_representation)
                assert merged_argument, \
                    "Unsupported combination of TensorOptions {}, the only currently supported combinations are {}"\
                    .format(str(topt_representation), str(supported_topt_arguments))
                new_arguments.append(merged_argument)
            else:
                # Incomplete run: keep the arguments unmerged; idx now points
                # at the first non-option argument, handled next iteration.
                new_arguments += topt_representation
        else:
            new_arguments.append(argument)
            idx += 1

    arguments = new_arguments

    # Sanity checks

    # Check the count first: the per-index check below would otherwise fail
    # with an opaque IndexError before this clearer message could be reported.
    assert len(arguments_out) <= len(func_return), "func {} must return at least as many Tensors " \
        "as can be passed as output.".format(name)

    # TODO: convention is that the ith-argument correspond to the i-th return, but it would
    # be better if we just named everything and matched by name.
    for arg_idx, argument in enumerate(arguments_out):
        assert argument['annotation'] == func_return[arg_idx]['annotation'], \
            "For func {} writeable keyword Tensor arguments need to have a matching return Tensor. Further, " \
            "the ith-argument needs to correspond to the i-th return.".format(name)

    if name.endswith('_out'):
        raise RuntimeError("Native function {} may not be suffixed with _out as we transition to a unified schema. "
                           "Otherwise you will cause confusion amongst consumers of native functions.".format(name))

    if is_out_fn and func_variants not in [[], 'function', ['function']]:
        raise RuntimeError("Native functions with output MUST be declared with only the function variant; "
                           "e.g., variants: function; otherwise you will tickle a Python argument binding bug "
                           "(which usually manifests itself as the result variable being undefined.) "
                           "The culprit was: {}".format(name))

    # TODO: Explicit checking for void is a hack and should disappear after a more
    # functionally complete implementation of Tensor aliases.
    if declaration['inplace'] and len(func_return) > 0 and func_return[0]['type'] != "void":
        found_self = False
        for arg_idx, argument in enumerate(arguments):
            if argument['name'] == "self":
                assert argument['annotation'] and argument['annotation'].endswith("!"), \
                    "Inplace function \"{}\" needs to annotate Tensor argument named self " \
                    "as mutable.".format(name)
                found_self = True
                assert argument['annotation'] == func_return[arg_idx]['annotation'], \
                    "Inplace function annotations of function {} need to match between " \
                    "input and correponding output.".format(name)
                assert argument['name'] == func_return[arg_idx]['name'] or \
                    argument['name'] == func_return[arg_idx]['name'] + "_return"
                assert argument['type'] == func_return[arg_idx]['type']
        assert found_self, "Inplace function \"{}\" needs Tensor argument named self.".format(name)

    return arguments
310 
311 
def parse_return_arguments(return_decl, inplace, func_decl):
    """Parse the return part of a native function schema into legacy dicts.

    Args:
        return_decl: the text after ``->``, either a single return such as
            ``"Tensor(a!)"`` or a parenthesized ``', '``-separated tuple.
        inplace: whether the function was detected as in-place.
        func_decl: the YAML entry; its ``func`` string is read.

    Returns:
        One dict per return with keys ``type``, ``name``, ``annotation`` and
        ``output`` (always True). Named returns also get ``field_name``; their
        ``name`` is suffixed with ``_return`` when the name occurs as a
        substring of the text before ``->``. Unnamed returns are called
        ``self`` (in-place Tensor returns), otherwise ``result`` or
        ``result<i>`` when there are several.
    """
    # TODO: Use a real parser here; this will get bamboozled
    # by signatures that contain things like std::array<bool, 2> (note the space)
    if (return_decl[0], return_decl[-1]) == ('(', ')'):
        return_decl = return_decl[1:-1]
    pieces = return_decl.split(', ')
    multiple_args = len(pieces) > 1

    parsed = []
    for position, piece in enumerate(pieces):
        t, name, _, _, _, annotation = type_argument_translations(piece)
        # name of arguments and name of return sometimes have collision
        # in this case, we rename the return name to <name>_return.
        signature = func_decl['func'].split('->')[0]
        entry = {
            'type': t,
            'name': name + "_return" if name in signature else name,
            'annotation': annotation,
        }
        if not name:
            if t == "Tensor" and inplace:
                assert annotation and annotation.endswith("!"), \
                    "Return Tensor of function \"{}\" flagged as inplace needs to be " \
                    "annotated as mutable".format(func_decl['func'])
                entry['name'] = 'self'
            elif multiple_args:
                entry['name'] = 'result' + str(position)
            else:
                entry['name'] = 'result'
        else:
            # See Note [field_name versus name]
            entry['field_name'] = name
        entry['output'] = True
        parsed.append(entry)
    return parsed
342 
343 
def has_sparse_dispatches(dispatches):
    """Return True if any entry of ``dispatches`` contains ``'Sparse'``."""
    return any('Sparse' in dispatch for dispatch in dispatches)
349 
350 
def parse_native_yaml(path):
    """Load and return the parsed YAML document stored at ``path``.

    NOTE(review): ``Loader`` is PyYAML's full (C)Loader, which can construct
    arbitrary Python objects. That is presumably acceptable only because the
    input is a trusted file such as native_functions.yaml; switch to a safe
    loader if this is ever pointed at external data.
    """
    with open(path, 'r') as yaml_file:
        contents = yaml.load(yaml_file, Loader=Loader)
    return contents
354 
355 
def propagate_field_names(output_arguments, return_arguments):
    """Copy each return's ``field_name`` onto the output argument at the same position.

    Output arguments and returns are paired by index. Returns beyond the
    last output argument have no counterpart and are ignored. A function may
    legitimately return more values than it has output arguments. Does
    nothing when ``output_arguments`` is empty or None.
    """
    if not output_arguments:
        return
    # zip stops at the shorter list, so extra returns cannot index past the
    # end of output_arguments.
    for output_argument, return_argument in zip(output_arguments, return_arguments):
        if 'field_name' in return_argument:
            output_argument['field_name'] = return_argument['field_name']
361 
362 
def run(paths):
    """Parse native function YAML files into a list of declaration dicts.

    For every entry of every file in ``paths``, the entry's ``func`` schema
    string is split into name, arguments and returns. These are translated
    with ``parse_return_arguments`` / ``parse_arguments`` and combined with
    the entry's optional flags (falling back to the defaults below).

    On any error, the offending entry and the partially built declaration are
    printed to stderr and the original exception is re-raised with its
    traceback intact.
    """
    declarations = []
    for path in paths:
        for func in parse_native_yaml(path):
            declaration = {'mode': 'native'}
            try:
                declaration['schema_string'] = "aten::" + func['func']
                if '->' not in func['func']:
                    raise RuntimeError('Expected return declaration')
                func_decl, return_decl = [x.strip() for x in func['func'].split('->')]
                fn_name, arguments = func_decl.split('(', 1)
                assert arguments.endswith(")"), "Expecting closing ) for {}".format(func['func'])
                arguments = arguments[:-1]  # Drop the closing )
                declaration['name'] = func.get('name', fn_name)
                # Names starting with "__i" or ending in a single "_" are in-place.
                declaration['inplace'] = re.search('(^__i|[^_]_$)', fn_name) is not None
                return_arguments = parse_return_arguments(return_decl, declaration['inplace'], func)
                arguments = parse_arguments(arguments, func.get('variants', []), declaration, return_arguments)
                output_arguments = [x for x in arguments if x.get('output')]
                propagate_field_names(output_arguments, return_arguments)
                # Out functions report their output arguments as the returns.
                declaration['return'] = return_arguments if len(output_arguments) == 0 else output_arguments
                declaration['variants'] = func.get('variants', ['function'])
                declaration['requires_tensor'] = func.get('requires_tensor', False)
                declaration['matches_jit_signature'] = func.get('matches_jit_signature', False)
                declaration['cpu_half'] = func.get('cpu_half', False)
                declaration['cpu_bool'] = func.get('cpu_bool', False)
                declaration['deprecated'] = func.get('deprecated', False)
                declaration['device_guard'] = func.get('device_guard', True)
                declaration['arguments'] = func.get('arguments', arguments)
                declaration['type_method_definition_dispatch'] = func.get('dispatch', declaration['name'])
                declaration['python_module'] = func.get('python_module', '')
                declarations.append(declaration)
            except Exception:
                msg = '''Exception raised in processing function:
{func}
Generated partial declaration:
{decl}'''.format(func=pprint.pformat(func), decl=pprint.pformat(declaration))
                print(msg, file=sys.stderr)
                # A bare raise keeps the original traceback (``raise e`` drops
                # it on Python 2).
                raise

    return declarations
Module caffe2.python.layers.split.