import copy
import re

import common_with_cwrap
import yaml

from collections import OrderedDict, defaultdict

try:
    # prefer the faster C implementation of the YAML loader when available
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# matches `name` and `params` in a signature of the form `name(params)`
NAME_PARAM_REGEX = r'(\w+)\((.*)\)'


def argument_to_declaration(param, func=None):
    arg = {}
    arg['type'], name = param.split(' ')
    if arg['type'].endswith('?'):
        arg['is_nullable'] = True
        arg['type'] = arg['type'].rstrip('?')
    if arg['type'] == 'Tensor':
        arg['type'] = 'THTensor*'
    elif arg['type'] == 'LongTensor':
        arg['type'] = 'THIndexTensor*'
    elif arg['type'] == 'Scalar':
        arg['type'] = 'accreal'
    elif arg['type'] == 'Generator*':
        arg['type'] = 'THGenerator*'

    # fixed-size IntArrayRef parameters are written as IntArrayRef[k]
    match = re.match(r'IntArrayRef\[(\d+)\]', arg['type'])
    if match:
        arg['type'] = 'IntArrayRef'
        arg['size'] = int(match.group(1))

    if '=' in name:
        name, default = name.split('=')
        arg['optional'] = True
        arg['default'] = default
    arg['name'] = name

    if func is not None:
        default_inits = func.get('default_init', {})
        wrap_dims = func.get('wrap_dim', {})
        if name in default_inits:
            # non-constexpr defaults
            arg['default_init'] = default_inits[name]
        if name in wrap_dims:
            arg['wrap_dim'] = wrap_dims[name]

    return arg
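
# Illustrative example (added; the parameter string is hypothetical, the
# resulting keys follow the code above):
#   argument_to_declaration('IntArrayRef[2] kernel_size=1')
#   -> {'type': 'IntArrayRef', 'size': 2, 'optional': True,
#       'default': '1', 'name': 'kernel_size'}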


def output_arguments(thnn_function):
    cname = thnn_function.name
    output_args = []

    # function_wrapper expects everything in a declaration to use the base
    # type (i.e. THTensor*), but a THCUNN-only implementation will have
    # THCTensor* as the type, so strip the THC prefix back to TH
    def map_to_th_type(t):
        if t.startswith('THC'):
            t = t.replace('THC', 'TH')
        return t

    def is_output_arg(arg_name, func_name):
        if arg_name == 'output' and 'updateOutput' in cname:
            return True
        if arg_name in {'gradInput', 'gradWeight', 'gradBias', 'gradGrid'}:
            return True
        if arg_name == 'indices' and 'updateOutput' in cname and 'Unpool' not in cname:
            # indices is an output argument in pooling and an input in unpooling
            return True
        return False

    for arg in thnn_function.arguments:
        name = arg.name
        if is_output_arg(name, cname):
            desc = {
                'type': map_to_th_type(arg.type),
                'name': camel_to_snake(name),
                'output': True,
            }
            if name.startswith('grad_'):
                desc['is_nullable'] = True
            output_args.append(desc)
    return output_args


def get_return(args):
    indices = [str(idx) for idx, arg in enumerate(args) if arg.get('output')]
    return 'argument {}'.format(','.join(indices))
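
# For example, if arguments 0 and 1 are the only entries flagged with
# 'output', get_return produces the string 'argument 0,1'.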


# maps abbreviated THNN argument prefixes to ATen argument names
ARGUMENT_MAPPINGS = {
    # ... (earlier entries elided)
    'osize': 'output_size',
    'output': 'output_size',
    'isize': 'input_size',
    'dilation': 'dilation',
    'adj': 'output_padding',
    'a': 'output_padding',
}

# maps a dimension suffix such as 'width' or 'height' to an index into the
# corresponding IntArrayRef argument
DIMENSION_OFFSET = {
    # ... (entries elided)
}

# maps THNN argument names to their ATen spellings
SUBSTITUTIONS = {
    # ... (other entries elided)
    'negval': 'negative_slope',
}


def camel_to_snake(name):
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
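
# Illustrative behaviour (added):
#   camel_to_snake('gradOutput') -> 'grad_output'
#   camel_to_snake('kW')         -> 'k_w'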


def get_thnn_args(thnn_function, params, inplace):
    params_by_name = {p['name']: p for p in params}

    def arg_expr(prefix, suffix):
        # binds a dimension-suffixed argument, e.g. kW, to an element of an
        # IntArrayRef parameter, e.g. kernel_size
        name = ARGUMENT_MAPPINGS[prefix]
        if name not in params_by_name:
            raise RuntimeError('missing arg "{}" in {}'.format(name, thnn_function.name))
        param = params_by_name[name]
        if param['type'] == 'IntArrayRef' and 'size' in param:
            name = name + '_'
        index = DIMENSION_OFFSET[suffix]
        if index < 0:
            index += param['size']
        expr = '{}[{}]'.format(name, index)
        return {'type': 'EXPRESSION', 'name': expr}

    thnn_args = []
    for arg in thnn_function.arguments:
        name = arg.name
        if name == 'state':
            # every TH(CU)NN function takes a state pointer first; skip it
            continue
        if inplace and name == 'output':
            name = 'self'
        aten_name = camel_to_snake(SUBSTITUTIONS.get(name, name))
        parts = aten_name.split('_')
        if aten_name in params_by_name:
            param = params_by_name[aten_name]
            if arg.is_optional:
                param['is_nullable'] = True
            thnn_args.append(copy.deepcopy(param))
        elif len(parts) == 2 and parts[0] in ARGUMENT_MAPPINGS and parts[1] in DIMENSION_OFFSET:
            # e.g. pad_left
            thnn_args.append(arg_expr(parts[0], parts[1]))
        elif name[-1] in DIMENSION_OFFSET and name[:-1] in ARGUMENT_MAPPINGS:
            # e.g. kW, kH
            thnn_args.append(arg_expr(name[:-1], name[-1]))
        elif name == 'owidth' or name == 'oheight':
            thnn_args.append(arg_expr(name[0], name[1:]))
        elif name == 'scale':
            thnn_args.append({'type': 'EXPRESSION', 'name': '1'})
        elif name == 'inplace':
            thnn_args.append({'type': 'EXPRESSION', 'name': str(inplace).lower()})
        else:
            raise RuntimeError("{}: can't find binding for '{}'"
                               .format(thnn_function.name, name))
    return thnn_args


def remove_unused_args(args, thnn_args):
    """Returns the subset of args whose name appears in thnn_args"""
    def clean_name(name):
        name = name[:name.index('[')] if '[' in name else name
        if name.endswith('_'):
            name = name[:-1]
        return name

    uses = set(clean_name(arg['name']) for arg in thnn_args)
    uses.add('output_mask')
    args = [arg for arg in args if arg['name'] in uses]
    for arg in args:
        if 'default' in arg:
            del arg['default']
    return args
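
# clean_name strips the indexing and trailing-underscore decoration that
# arg_expr adds, e.g. clean_name('kernel_size_[1]') -> 'kernel_size'.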


def unique_args(argslist):
    result = []
    seen = set()
    for args in argslist:
        for arg in args:
            if arg['name'] in seen:
                continue
            seen.add(arg['name'])
            result.append(arg)
    return result
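
# For example (added):
#   unique_args([[{'name': 'grad_input'}],
#                [{'name': 'grad_input'}, {'name': 'grad_weight'}]])
#   -> [{'name': 'grad_input'}, {'name': 'grad_weight'}]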


def function_info(name, arguments, cimpls, buffers, backends, inplace, scalar_check):
    """
    cimpls contains the information used to call into THNN:
        cname: THNN function name
        arguments: arguments to the functional call
        condition: [optional] guard around the call
    """
    return {
        'mode': 'NN',
        'name': name,
        'types': ['Float', 'Double', 'Half'],  # Half is stripped for the CPU backend
        'arguments': arguments,
        'return': 'argument 0' if inplace else get_return(arguments),
        'buffers': buffers,
        'backends': backends,
        'cimpls': cimpls,
        'scalar_check': scalar_check,
        'variants': ['function'],
    }
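
# Note on the 'return' key: in-place variants return their first argument
# ('argument 0'); otherwise get_return reports the indices of all output
# arguments, e.g. 'argument 0,1,2' for a three-output backward.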


def base_declaration(func, thnn_function, backends, inplace=False):
    """Creates the NN function without any buffers in its signature"""
    name, params = re.match(NAME_PARAM_REGEX, func['name']).groups()
    if inplace:
        name += '_'
    params = params.split(', ')
    arguments = [argument_to_declaration(a, func) for a in params]
    if not inplace:
        arguments += output_arguments(thnn_function)
    buffers = [argument_to_declaration('Tensor ' + buf)
               for buf in func.get('buffers', [])]

    return function_info(name, arguments, None, buffers, backends, inplace,
                         func.get('scalar_check'))
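
# Illustrative parse (hypothetical function name): NAME_PARAM_REGEX splits
#   'avg_pool2d(Tensor self, IntArrayRef[2] kernel_size)'
# into name 'avg_pool2d' and params 'Tensor self, IntArrayRef[2] kernel_size'.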


def forward_declaration(base, thnn_function, inplace=False):
    name = '{}_forward'.format(base['name'])
    if inplace:
        name += '_'

    arguments = [copy.deepcopy(arg) for arg in base['arguments']
                 if not arg.get('output')]
    arguments += output_arguments(thnn_function)
    for buffer in base['buffers']:
        buffer = copy.deepcopy(buffer)
        buffer['output'] = True
        arguments.append(buffer)

    thnn_args = get_thnn_args(thnn_function, arguments, inplace)
    arguments = remove_unused_args(arguments, thnn_args)
    cimpl = {'cname': thnn_function.name, 'arguments': thnn_args}

    scalar_check = base['scalar_check']
    if scalar_check is not None:
        output_arg_names = [arg['name'] for arg in arguments if arg.get('output', False)]
        scalar_check = {k: v for (k, v) in scalar_check.items()
                        if k in output_arg_names}

    return function_info(name, arguments, [cimpl], [], base['backends'], inplace, scalar_check)
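
# Sketch of the buffer handling above (buffer name hypothetical): a base
# declaration with buffers ['columns'] gains 'columns' as an extra output
# argument in the *_forward declaration, so THNN can write into it.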


def backward_declaration(base, thnn_functions):
    name = '{}_backward'.format(base['name'])

    arguments = []
    arguments.append({'type': 'THTensor*', 'name': 'grad_output'})
    arguments += [copy.deepcopy(arg) for arg in base['arguments']
                  if arg['name'] != 'inplace']
    arguments += base['buffers']

    if 'upsample' in base['name']:
        # add input_size as a parameter to the upsample backwards functions;
        # it has (2 + number of spatial dims) elements, e.g. 4 for upsample*2d
        size = 2 + int(re.search(r'(\d+)d', base['name']).group(1))
        input_size_arg = {'type': 'IntArrayRef', 'name': 'input_size', 'size': size}
        for output_size_idx, arg in enumerate(arguments):
            if arg['name'] == 'output_size':
                break
        arguments.insert(output_size_idx + 1, input_size_arg)

    if 'im2col' in base['name']:
        # add input_size as a parameter to the im2col backwards function
        input_size_arg = {'type': 'IntArrayRef', 'name': 'input_size', 'size': 2}
        arguments.insert(2, input_size_arg)

    # outputs from the forward may be inputs to the backwards
    for arg in arguments:
        if 'output' in arg:
            del arg['output']

    arguments += unique_args([output_arguments(f) for f in thnn_functions])

    def initialize_output_arg(arg):
        # the output mask specifies which return values to compute
        arg['mask'] = True
        arg['is_nullable'] = True

        # grad_weight and grad_bias need to be resized and zeroed
        if arg['name'] == 'grad_weight':
            arg['resize'] = 'weight'
            arg['zero'] = True
        if arg['name'] == 'grad_bias':
            dim = 1 if 'transpose' in name else 0
            arg['resize'] = [('weight', dim)]
            arg['zero'] = True

    is_batch_norm_backward = '_backward' in thnn_functions[0].name
    grad_params = []
    if len(thnn_functions) > 1 or is_batch_norm_backward:
        for arg in arguments:
            if arg.get('output', False):
                initialize_output_arg(arg)
            if 'Tensor' in arg['type'] and arg['name'].startswith('grad_') and \
                    'input' not in arg['name'] and 'output' not in arg['name']:
                grad_params.append(arg['name'])

    thnn_args = [get_thnn_args(f, arguments, False) for f in thnn_functions]
    arguments = remove_unused_args(arguments, unique_args(thnn_args))
    cimpls = []

    def get_condition(func):
        # only call into the THNN functions if the output args are not null
        if '_updateGradInput' in func.name:
            return 'grad_input_'
        if '_accGradParameters' in func.name:
            return ' || '.join(p + '_' for p in grad_params)
        return None

    for func, args in zip(thnn_functions, thnn_args):
        cimpl = {'cname': func.name, 'arguments': args}
        if len(thnn_functions) > 1:
            cimpl['condition'] = get_condition(func)
        cimpls.append(cimpl)

    output_args = [arg for arg in arguments if arg.get('output', False)]
    scalar_check_arg = base['scalar_check'] if base['scalar_check'] is not None else dict()
    scalar_check = {k: v for (k, v) in scalar_check_arg.items()
                    if k in [a['name'] for a in output_args]}
    for arg in output_args:
        # resize automatically sets the scalar check
        if scalar_check.get(arg['name']) is not None or arg.get('resize', False):
            pass
        else:
            base_name = arg['name'][len('grad_'):] if arg['name'] != 'grad_input' else 'self'
            if base_name in [a['name'] for a in arguments]:
                scalar_check[arg['name']] = base_name + '_->dim() == 0'
            else:
                raise ValueError(("Could not infer scalar_check for {} argument of func {} "
                                  "because {} does not exist. Please explicitly specify "
                                  "scalar_check.").format(arg['name'], name, base_name))

    return function_info(name, arguments, cimpls, [], base['backends'], False, scalar_check)
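
# Example of the scalar_check inference above: for an output 'grad_weight'
# with no explicit entry and no 'resize', base_name is 'weight', so the
# generated check becomes "weight_->dim() == 0" (provided 'weight' is among
# the arguments).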


def parse_nn_yaml(filename):
    with open(filename, 'r') as f:
        return yaml.load(f, Loader=Loader)


include_only = '(updateOutput|updateGradInput|accGradParameters|backward)$'
exclude = 'LookupTable'


def run(paths):
    function_backends = defaultdict(list)
    header_functions = OrderedDict()

    headers = [p for p in paths if p.endswith('.h')]
    yamls = [p for p in paths if p.endswith('.yaml')]

    for path in headers:
        backend = 'CUDA' if re.search('THCU', path) else 'CPU'
        # the header parser (assumed here to come from common_with_cwrap)
        # yields one function object per TH(CU)NN declaration in the header
        for func in common_with_cwrap.parse_header(path):
            if re.search(include_only, func.name) is None or re.search(exclude, func.name) is not None:
                continue
            function_backends[func.name].append(backend)
            if func.name not in header_functions:
                header_functions[func.name] = func

    bwd_suffixes = ['_updateGradInput', '_accGradParameters', '_backward']

    declarations = []
    for path in yamls:
        for func in parse_nn_yaml(path):
            cname = func['cname']
            backends = function_backends[cname + '_updateOutput']

            fwd_function = header_functions[cname + '_updateOutput']
            bwd_functions = []
            for suffix in bwd_suffixes:
                if cname + suffix in header_functions:
                    bwd_functions.append(header_functions[cname + suffix])

            base = base_declaration(func, fwd_function, backends)
            declarations.append(forward_declaration(base, fwd_function))
            declarations.append(backward_declaration(base, bwd_functions))

            if func.get('has_inplace', False):
                declarations.append(base_declaration(func, fwd_function, backends, True))
                declarations.append(forward_declaration(base, fwd_function, True))

    return declarations
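

# Added usage sketch (not part of the original file): run() takes the list of
# TH(CU)NN header paths plus nn.yaml paths and returns the generated
# declarations. A hypothetical manual invocation:
if __name__ == '__main__':
    import sys
    for declaration in run(sys.argv[1:]):
        print(declaration['name'])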