nn_parse.py
import copy
import re
import common_with_cwrap
import yaml
from collections import OrderedDict, defaultdict

try:
    # use faster C loader if available
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader


# matches `name`, `params` in `name(params)`
NAME_PARAM_REGEX = r'(\w+)\((.*)\)'


def argument_to_declaration(param, func=None):
    arg = {}
    arg['type'], name = param.split(' ')
    if arg['type'].endswith('?'):
        arg['is_nullable'] = True
        arg['type'] = arg['type'].rstrip('?')
    if arg['type'] == 'Tensor':
        arg['type'] = 'THTensor*'
    elif arg['type'] == 'LongTensor':
        arg['type'] = 'THIndexTensor*'
    elif arg['type'] == 'Scalar':
        arg['type'] = 'accreal'
    elif arg['type'] == 'Generator*':
        arg['type'] = 'THGenerator*'

    match = re.match(r'IntArrayRef\[(\d+)\]', arg['type'])
    if match:
        arg['type'] = 'IntArrayRef'
        arg['size'] = int(match.group(1))

    if '=' in name:
        name, default = name.split('=')
        arg['optional'] = True
        arg['default'] = default
    arg['name'] = name

    if func is not None:
        default_inits = func.get('default_init', {})
        wrap_dims = func.get('wrap_dim', {})
        if name in default_inits:
            # non-constexpr defaults
            arg['default_init'] = default_inits[name]
        if name in wrap_dims:
            arg['wrap_dim'] = wrap_dims[name]

    return arg
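

# Illustrative example (not part of the original file): what
# argument_to_declaration produces for a single parameter string taken
# from an nn.yaml-style signature:
#
#     argument_to_declaration('IntArrayRef[2] kernel_size=1')
#     # => {'type': 'IntArrayRef', 'size': 2, 'optional': True,
#     #     'default': '1', 'name': 'kernel_size'}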


def output_arguments(thnn_function):
    cname = thnn_function.name
    output_args = []

    # function_wrapper expects everything in a declaration to be in
    # the base type (i.e. THTensor*), but if we pull a THCUNN-only
    # implementation, it will have THCTensor* as the arg type. So we
    # strip the THC here before returning
    def map_to_th_type(t):
        if t.startswith('THC'):
            t = t.replace('THC', 'TH')
        return t

    def is_output_arg(arg_name, func_name):
        if arg_name == 'output' and 'updateOutput' in cname:
            return True
        if arg_name in {'gradInput', 'gradWeight', 'gradBias', 'gradGrid'}:
            return True
        if arg_name == 'indices' and 'updateOutput' in cname and 'Unpool' not in cname:
            # indices is an output argument in pooling and an input in unpooling
            return True
        return False

    for arg in thnn_function.arguments:
        name = arg.name
        if is_output_arg(name, cname):
            desc = {
                'type': map_to_th_type(arg.type),
                'name': camel_to_snake(name),
                'output': True,
            }
            # note: the snake-cased name is checked here; the raw THNN
            # name is camelCase (e.g. gradInput)
            if desc['name'].startswith('grad_'):
                desc['is_nullable'] = True
            output_args.append(desc)
    return output_args
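

# Illustrative example (not part of the original file): for a THNN function
# named 'SpatialMaxPooling_updateOutput', the arguments named 'output' and
# 'indices' are treated as outputs, yielding declarations such as
#     {'type': 'THTensor*', 'name': 'output', 'output': True}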


def get_return(args):
    indices = [str(idx) for idx, arg in enumerate(args) if arg.get('output')]
    return 'argument {}'.format(','.join(indices))
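

# For instance (illustrative): if the arguments at positions 2 and 3 are
# marked as outputs, get_return(args) returns the string 'argument 2,3'.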


ARGUMENT_MAPPINGS = {
    'k': 'kernel_size',
    'd': 'stride',
    'pad': 'padding',
    'p': 'padding',
    'o': 'output_size',
    'osize': 'output_size',
    'output': 'output_size',  # as a prefix e.g. outputW
    'isize': 'input_size',
    'dilation': 'dilation',
    'adj': 'output_padding',
    'a': 'output_padding',
}

DIMENSION_OFFSET = {
    'width': -1,
    'height': -2,
    'B': 0,
    'C': 1,
    'W': -1,
    'H': -2,
    'T': -3,
    'left': 0,
    'right': 1,
    'top': 2,
    'bottom': 3,
    'front': 4,
    'back': 5,
}

SUBSTITUTIONS = {
    'input': 'self',
    'weights': 'weight',
    'train': 'training',
    'val': 'value',
    'lambda': 'lambd',
    'negval': 'negative_slope',
}


def camel_to_snake(name):
    # from https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
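

# e.g. (illustrative): camel_to_snake('gradInput') == 'grad_input' and
# camel_to_snake('kernelW') == 'kernel_w'.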


def get_thnn_args(thnn_function, params, inplace):
    params_by_name = {p['name']: p for p in params}

    def arg_expr(prefix, suffix):
        # e.g. kW, kH
        name = ARGUMENT_MAPPINGS[prefix]
        if name not in params_by_name:
            raise RuntimeError('missing arg "{}" in {}'.format(name, thnn_function.name))
        param = params_by_name[name]
        if param['type'] == 'IntArrayRef' and 'size' in param:
            name = name + '_'
        # NB: We calculate the dimension based on the name of
        # the argument, not its positional order. This means
        # that we may reorder arguments to get them in
        # the right place; e.g., if a THNN implementation
        # has arguments in the order kernelW, kernelH, we
        # will generate a caller that is kernel[1], kernel[0]
        # to order them in the correct way.
        index = DIMENSION_OFFSET[suffix]
        if index < 0:
            index += param['size']
        expr = '{}[{}]'.format(name, index)
        return {'type': 'EXPRESSION', 'name': expr}

    thnn_args = []
    for arg in thnn_function.arguments:
        name = arg.name
        if name == 'state':
            continue
        if inplace and name == 'output':
            name = 'self'
        aten_name = camel_to_snake(SUBSTITUTIONS.get(name, name))
        parts = aten_name.split('_')
        if aten_name in params_by_name:
            param = params_by_name[aten_name]
            if arg.is_optional:
                param['is_nullable'] = True
            thnn_args.append(copy.deepcopy(param))
        elif len(parts) == 2 and parts[0] in ARGUMENT_MAPPINGS and parts[1] in DIMENSION_OFFSET:
            # e.g. pad_left
            thnn_args.append(arg_expr(parts[0], parts[1]))
        elif name[-1] in DIMENSION_OFFSET and name[:-1] in ARGUMENT_MAPPINGS:
            # e.g. kW, kH
            thnn_args.append(arg_expr(name[:-1], name[-1]))
        elif name == 'owidth' or name == 'oheight':
            thnn_args.append(arg_expr(name[0], name[1:]))
        elif name == 'scale':
            thnn_args.append({'type': 'EXPRESSION', 'name': '1'})
        elif name == 'inplace':
            thnn_args.append({'type': 'EXPRESSION', 'name': str(inplace).lower()})
        else:
            raise RuntimeError("{}: can't find binding for '{}'"
                               .format(thnn_function.name, name))
    return thnn_args
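

# Illustrative mapping (not part of the original file): given an ATen
# parameter {'name': 'kernel_size', 'type': 'IntArrayRef', 'size': 2},
# the THNN arguments kW and kH are bound to the expressions
# 'kernel_size_[1]' and 'kernel_size_[0]' (W is the last dimension,
# H the second-to-last, per DIMENSION_OFFSET).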


def remove_unused_args(args, thnn_args):
    """Returns the subset of args whose name appears in thnn_args"""
    def clean_name(name):
        name = name[:name.index('[')] if '[' in name else name
        if name.endswith('_'):
            name = name[:-1]
        return name
    uses = set([clean_name(arg['name']) for arg in thnn_args])
    uses.add('output_mask')
    args = [arg for arg in args if arg['name'] in uses]
    for arg in args:
        if 'default' in arg:
            del arg['default']
    return args


def unique_args(argslist):
    result = []
    seen = set()
    for args in argslist:
        for arg in args:
            if arg['name'] in seen:
                continue
            seen.add(arg['name'])
            result.append(arg)
    return result
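

# e.g. (illustrative): unique_args([[{'name': 'grad_input'}],
# [{'name': 'grad_input'}, {'name': 'grad_weight'}]]) keeps only the first
# occurrence of each name, returning
# [{'name': 'grad_input'}, {'name': 'grad_weight'}].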


def function_info(name, arguments, cimpls, buffers, backends, inplace, scalar_check):
    """
    cimpls contains information used to call into THNN:
        cname: THNN function name
        arguments: arguments to functional call
        condition: [optional] guard around call
    """
    return {
        'mode': 'NN',
        'name': name,
        'types': ['Float', 'Double', 'Half'],  # Half will be stripped for CPU backend
        'arguments': arguments,
        'return': 'argument 0' if inplace else get_return(arguments),
        'buffers': buffers,
        'backends': backends,
        'cimpls': cimpls,
        'scalar_check': scalar_check,
        'variants': ['function'],
    }


def base_declaration(func, thnn_function, backends, inplace=False):
    """Creates the NN function without any buffers in its signature"""
    name, params = re.match(NAME_PARAM_REGEX, func['name']).groups()
    if inplace:
        name += '_'
    params = params.split(', ')
    arguments = [argument_to_declaration(a, func) for a in params]
    if not inplace:
        arguments += output_arguments(thnn_function)
    buffers = [argument_to_declaration('Tensor ' + buf)
               for buf in func.get('buffers', [])]

    return function_info(name, arguments, None, buffers, backends, inplace,
                         func.get('scalar_check'))


def forward_declaration(base, thnn_function, inplace=False):
    name = '{}_forward'.format(base['name'])
    if inplace:
        name += '_'

    arguments = [copy.deepcopy(arg) for arg in base['arguments']
                 if not arg.get('output')]

    arguments += output_arguments(thnn_function)
    for buffer in base['buffers']:
        buffer = copy.deepcopy(buffer)
        buffer['output'] = True
        arguments.append(buffer)

    thnn_args = get_thnn_args(thnn_function, arguments, inplace)
    arguments = remove_unused_args(arguments, thnn_args)
    cimpl = {'cname': thnn_function.name, 'arguments': thnn_args}

    scalar_check = base['scalar_check']
    if scalar_check is not None:
        output_arg_names = [arg['name'] for arg in arguments if arg.get('output', False)]
        scalar_check = {k: v for (k, v) in scalar_check.items() if k in output_arg_names}

    return function_info(name, arguments, [cimpl], [], base['backends'], inplace, scalar_check)


def backward_declaration(base, thnn_functions):
    name = '{}_backward'.format(base['name'])

    arguments = []
    arguments.append({'type': 'THTensor*', 'name': 'grad_output'})
    arguments += [copy.deepcopy(arg) for arg in base['arguments']
                  if arg['name'] != 'inplace']
    arguments += base['buffers']

    if 'upsample' in base['name']:
        # Add input_size as a parameter to upsample backwards functions.
        # Note that input_size is 4-dim for upsample_xxx2d.
        size = 2 + int(re.search(r'(\d+)d', base['name']).group(1))
        input_size_arg = {'type': 'IntArrayRef', 'name': 'input_size', 'size': size}
        for output_size_idx, arg in enumerate(arguments):
            if arg['name'] == 'output_size':
                break
        arguments.insert(output_size_idx + 1, input_size_arg)

    if 'im2col' in base['name']:
        # Add input_size as a parameter to the im2col backwards function
        input_size_arg = {'type': 'IntArrayRef', 'name': 'input_size', 'size': 2}
        arguments.insert(2, input_size_arg)

    # outputs from the forward may be inputs to the backwards
    for arg in arguments:
        if 'output' in arg:
            del arg['output']

    arguments += unique_args([output_arguments(f) for f in thnn_functions])

    def initialize_output_arg(arg):
        # the mask array<bool, N> specifies which return values to compute
        arg['mask'] = True
        arg['is_nullable'] = True

        # grad_weight and grad_bias need to be resized and zeroed
        if arg['name'] == 'grad_weight':
            arg['resize'] = 'weight'
            arg['zero'] = True
        if arg['name'] == 'grad_bias':
            dim = 1 if 'transpose' in name else 0
            arg['resize'] = [('weight', dim)]
            arg['zero'] = True

    is_batch_norm_backward = '_backward' in thnn_functions[0].name
    grad_params = []
    if len(thnn_functions) > 1 or is_batch_norm_backward:
        for arg in arguments:
            if arg.get('output', False):
                initialize_output_arg(arg)
            if 'Tensor' in arg['type'] and arg['name'].startswith('grad_') and \
                    'input' not in arg['name'] and 'output' not in arg['name']:
                grad_params.append(arg['name'])

    thnn_args = [get_thnn_args(f, arguments, False) for f in thnn_functions]
    arguments = remove_unused_args(arguments, unique_args(thnn_args))
    cimpls = []

    def get_condition(func):
        # only call into the THNN functions if the output args are not null
        if '_updateGradInput' in func.name:
            return 'grad_input_'
        if '_accGradParameters' in func.name:
            return ' || '.join(p + '_' for p in grad_params)
        return None

    for func, args in zip(thnn_functions, thnn_args):
        cimpl = {'cname': func.name, 'arguments': args}
        if len(thnn_functions) > 1:
            cimpl['condition'] = get_condition(func)
        cimpls.append(cimpl)

    output_args = [arg for arg in arguments if arg.get('output', False)]
    scalar_check_arg = base['scalar_check'] if base['scalar_check'] is not None else dict()
    scalar_check = {k: v for (k, v) in scalar_check_arg.items() if k in [a['name'] for a in output_args]}
    for arg in output_args:
        # resize automatically sets scalar_check
        if scalar_check.get(arg['name']) is not None or arg.get('resize', False):
            pass
        else:
            base_name = arg['name'][len('grad_'):] if arg['name'] != 'grad_input' else 'self'
            if base_name in [a['name'] for a in arguments]:
                scalar_check[arg['name']] = base_name + '_->dim() == 0'
            else:
                raise ValueError(("Could not infer scalar_check for {} argument of func {} because {} "
                                  "does not exist. Please explicitly specify scalar_check."
                                  .format(arg['name'], name, base_name)))

    return function_info(name, arguments, cimpls, [], base['backends'], False, scalar_check)
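

# Illustrative result (not part of the original file): for a module with both
# updateGradInput and accGradParameters THNN functions, each generated cimpl
# carries a guard, e.g. {'cname': 'SpatialConvolutionMM_updateGradInput',
# 'condition': 'grad_input_'}, so THNN is only invoked when the corresponding
# gradient output was requested via the output mask.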


def parse_nn_yaml(filename):
    with open(filename, 'r') as f:
        return yaml.load(f, Loader=Loader)
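

# An illustrative nn.yaml entry (hypothetical values; the field names match
# what run() and base_declaration() read below):
#
#     - name: threshold(Tensor self, Scalar threshold, Scalar value)
#       cname: Threshold
#       has_inplace: True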


include_only = '(updateOutput|updateGradInput|accGradParameters|backward)$'
exclude = 'LookupTable'


def run(paths):
    function_backends = defaultdict(list)
    header_functions = OrderedDict()

    headers = [p for p in paths if p.endswith('.h')]
    yamls = [p for p in paths if p.endswith('.yaml')]

    for path in headers:
        backend = 'CUDA' if re.search('THCU', path) else 'CPU'
        for func in common_with_cwrap.parse_header(path):
            if re.search(include_only, func.name) is None or re.search(exclude, func.name) is not None:
                continue
            function_backends[func.name].append(backend)
            if func.name not in header_functions:
                header_functions[func.name] = func

    bwd_suffixes = ['_updateGradInput', '_accGradParameters', '_backward']

    declarations = []
    for path in yamls:
        for func in parse_nn_yaml(path):
            cname = func['cname']
            backends = function_backends[cname + '_updateOutput']

            fwd_function = header_functions[cname + '_updateOutput']
            bwd_functions = []
            for suffix in bwd_suffixes:
                if cname + suffix in header_functions:
                    bwd_functions.append(header_functions[cname + suffix])

            base = base_declaration(func, fwd_function, backends)
            declarations.append(forward_declaration(base, fwd_function))
            declarations.append(backward_declaration(base, bwd_functions))

            if func.get('has_inplace', False):
                declarations.append(base_declaration(func, fwd_function, backends, True))
                declarations.append(forward_declaration(base, fwd_function, True))

    return declarations
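

# Sketch of a typical invocation (hypothetical paths; in practice this module
# is driven by ATen's code-generation scripts):
#
#     declarations = run(['THNN.h', 'THCUNN.h', 'nn.yaml'])
#     # -> a list of declaration dicts consumed by the NN function_wrapper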