from __future__ import print_function

import copy
import pprint
import re
import sys

import yaml

try:
    # Use the faster C loader if it is available.
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader


# [temp translations]
# The helpers below temporarily translate the native_functions.yaml syntax into
# the legacy declaration syntax while the schema transition is in progress; the
# error messages that say "See [temp translations]" refer back to this note.


def type_argument_translations(arg):
    type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
    name = None
    if len(type_and_name) > 1:
        name = type_and_name[1]
    t = type_and_name[0]
    default = None
    nullable = False
    size = None
    annotation = None

    # Split off a default value if the name carries one, e.g. "alpha=1".
    if name and '=' in name:
        name, default = name.split('=')

    # A parenthesized suffix on a Tensor type, e.g. "Tensor(a!)", is an
    # alias/mutation annotation; strip it off and keep it separately.
    match = re.match(r'(Tensor.*)\((.+)\)(.*)', t)
    if match:
        t = match.group(1) + match.group(3)
        annotation = match.group(2)

    nullable = (t != 'Generator?' and '?' in t)

    # Translate the new schema types into the legacy ones. See [temp translations].
    if t == 'Generator?' and default == 'None':
        t = 'Generator*'
        default = 'nullptr'
    elif t == 'Generator?':
        t = 'Generator*'
    elif t == 'Tensor[]' or t == 'Tensor?[]':
        t = 'TensorList'
    elif t == 'int64_t':
        raise RuntimeError("Please use int and not int64_t. "
                           "See [temp translations] for details.")
    elif t == 'int64_t?':
        raise RuntimeError("Please use int? and not int64_t?. "
                           "See [temp translations] for details.")
    elif t == 'double':
        raise RuntimeError("Please use float and not double. "
                           "See [temp translations] for details.")
    elif re.match(r'int\[(\d+)\]', t):
        match = re.match(r'int\[(\d+)\]', t)
        size = int(match.group(1))
    elif re.match(r'bool\[(\d+)\]', t):
        match = re.match(r'bool\[(\d+)\]', t)
        t = 'std::array<bool,{}>'.format(match.group(1))
    elif re.match(r'std::array', t):
        raise RuntimeError("Please use array notation, e.g. bool[3] and not std::array. "
                           "See [temp translations] for details.")

    # Translate default values into the legacy spellings. See [temp translations].
    if default is not None:
        if t.startswith('Tensor?') and default == 'None':
            default = '{}'
        elif default == 'True':
            default = True
        elif default == 'False':
            default = False
        elif default == 'true':
            raise RuntimeError("Please use True and not true. "
                               "See [temp translations] for details.")
        elif default == 'false':
            raise RuntimeError("Please use False and not false. "
                               "See [temp translations] for details.")
        elif default == '[]':
            default = '{}'
        elif re.match(r'\[.*\]', default):
            default = "{" + default[1:-1] + "}"
        elif default == 'None':
            default = 'c10::nullopt'
        elif default == 'Mean':
            default = 'Reduction::Mean'
        else:
            try:
                default = int(default)
            except ValueError:
                try:
                    default = float(default)
                except ValueError:
                    pass

    return t, name, default, nullable, size, annotation
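
# Illustrative sketch of the translations above (examples chosen here, not
# taken from the source):
#
#   type_argument_translations('Tensor(a!) self')
#     -> ('Tensor', 'self', None, False, None, 'a!')
#   type_argument_translations('bool[3] output_mask')
#     -> ('std::array<bool,3>', 'output_mask', None, False, None, None)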


def parse_arguments(args, func_variants, declaration, func_return):
    arguments = []
    kwarg_only = False

    if len(args.strip()) == 0:
        return arguments

    for arg_idx, arg in enumerate(args.split(', ')):
        type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
        # A bare '*' marks the start of the keyword-only arguments.
        if type_and_name == ['*']:
            assert not kwarg_only
            kwarg_only = True
            continue

        t, name, default, nullable, size, annotation = type_argument_translations(arg)

        argument_dict = {'type': t.rstrip('?'), 'name': name, 'is_nullable': nullable, 'annotation': annotation}
        if size:
            argument_dict['size'] = size
        if default is not None:
            argument_dict['default'] = default
        if kwarg_only:
            argument_dict['kwarg_only'] = True

        arguments.append(argument_dict)

    # Keyword-only Tensor arguments annotated as mutable (e.g. "Tensor(a!)")
    # are treated as output arguments and moved to the front.
    arguments_out = []
    arguments_other = []
    for argument in arguments:
        if argument['type'] == "Tensor" and \
                argument['annotation'] and \
                re.match(r'^(.*!)$', argument['annotation']) and \
                argument.get('kwarg_only'):
            argument['output'] = True
            argument['kwarg_only'] = False
            arguments_out.append(argument)
        else:
            arguments_other.append(argument)

    arguments = arguments_out + arguments_other

    is_out_fn = len(arguments_out) != 0

    name = declaration['name']
    if is_out_fn:
        declaration['name'] += "_out"
"_out" 200 supported_topt_arguments = [
202 {
'name':
'dtype',
'type':
'ScalarType',
'is_nullable':
False,
'annotation':
None},
203 {
'name':
'layout',
'type':
'Layout',
'is_nullable':
False,
'annotation':
None},
204 {
'name':
'device',
'type':
'Device',
'is_nullable':
False,
'annotation':
None},
207 supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[0]))
208 supported_topt_arguments[1][0][
'kwarg_only'] =
True 209 supported_topt_arguments[1][1][
'kwarg_only'] =
True 210 supported_topt_arguments[1][2][
'kwarg_only'] =
True 211 supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[1]))
212 supported_topt_arguments[2][0][
'default'] =
'c10::nullopt' 213 supported_topt_arguments[2][1][
'default'] =
'c10::nullopt' 214 supported_topt_arguments[2][2][
'default'] =
'c10::nullopt' 215 supported_topt_arguments[2][0][
'is_nullable'] =
True 216 supported_topt_arguments[2][1][
'is_nullable'] =
True 217 supported_topt_arguments[2][2][
'is_nullable'] =
True 219 corresponding_topts = [
220 {
'type':
'TensorOptions',
'name':
'options',
'is_nullable':
False,
'annotation':
None},
222 corresponding_topts.append(corresponding_topts[0].copy())
223 corresponding_topts[1][
'kwarg_only'] =
True 224 corresponding_topts.append(corresponding_topts[1].copy())
225 corresponding_topts[2][
'default'] =
'{}' 227 def check_topt_representation(topt_representation):
        for idx, supported_topt in enumerate(supported_topt_arguments):
            matches = True
            matches = matches and topt_representation[0] == supported_topt[0]
            matches = matches and topt_representation[1] == supported_topt[1]
            matches = matches and topt_representation[2] == supported_topt[2]
            if matches:
                return corresponding_topts[idx]
        return None

    def is_tensor_option(argument):
        return argument['name'] in ['dtype', 'layout', 'device']
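
    # The loop below scans for a consecutive dtype/layout/device triple,
    # matches it against supported_topt_arguments and, if it matches, replaces
    # it with the corresponding single TensorOptions argument.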

    new_arguments = []
    idx = 0
    while idx < len(arguments):
        argument = arguments[idx]
        if is_tensor_option(argument) and len(arguments) - idx >= 3:
            topt_representation = []
            for i in range(3):
                argument = arguments[idx]
                if not is_tensor_option(argument):
                    break
                topt_representation.append(argument)
                idx += 1
            if len(topt_representation) == 3:
                merged_argument = check_topt_representation(topt_representation)
                assert merged_argument, \
                    "Unsupported combination of TensorOptions {}, the only currently supported combinations are {}"\
                    .format(str(topt_representation), str(supported_topt_arguments))
                new_arguments.append(merged_argument)
            else:
                new_arguments += topt_representation
        else:
            new_arguments.append(argument)
            idx += 1

    arguments = new_arguments

    for arg_idx, argument in enumerate(arguments_out):
        assert argument['annotation'] == func_return[arg_idx]['annotation'], \
            "For func {} writeable keyword Tensor arguments need to have a matching return Tensor. Further, " \
            "the i-th argument needs to correspond to the i-th return.".format(name)

    assert len(arguments_out) <= len(func_return), "func {} must return at least as many Tensors " \
        "as can be passed as output.".format(name)

    if name.endswith('_out'):
        raise RuntimeError("Native function {} may not be suffixed with _out as we transition to a unified schema. "
                           "Otherwise you will cause confusion amongst consumers of native functions.".format(name))

    if is_out_fn and func_variants not in [[], 'function', ['function']]:
        raise RuntimeError("Native functions with output MUST be declared with only the function variant; "
                           "e.g., variants: function; otherwise you will tickle a Python argument binding bug "
                           "(which usually manifests itself as the result variable being undefined.) "
                           "The culprit was: {}".format(name))

    if not is_out_fn:
        assert len(arguments_out) == 0, "func {} is not marked as output yet contains output " \
            "keyword arguments".format(name)

    # An in-place function must take a mutable Tensor argument named self that
    # matches the corresponding return.
    if declaration['inplace'] and len(func_return) > 0 and func_return[0]['type'] != "void":
        found_self = False
        for arg_idx, argument in enumerate(arguments):
            if argument['name'] == "self":
                assert argument['annotation'] and argument['annotation'].endswith("!"), \
                    "Inplace function \"{}\" needs to annotate Tensor argument named self " \
                    "as mutable.".format(name)
                found_self = True
                assert argument['annotation'] == func_return[arg_idx]['annotation'], \
                    "Inplace function annotations of function {} need to match between " \
                    "input and corresponding output.".format(name)
                assert argument['name'] == func_return[arg_idx]['name'] or \
                    argument['name'] == func_return[arg_idx]['name'] + "_return"
                assert argument['type'] == func_return[arg_idx]['type']
        assert found_self, "Inplace function \"{}\" needs Tensor argument named self.".format(name)

    return arguments


def parse_return_arguments(return_decl, inplace, func_decl):
    arguments = []
    if return_decl[0] == '(' and return_decl[-1] == ')':
        return_decl = return_decl[1:-1]
    multiple_args = len(return_decl.split(', ')) > 1

    for arg_idx, arg in enumerate(return_decl.split(', ')):
        t, name, default, nullable, size, annotation = type_argument_translations(arg)
        # If a named return collides with an argument name, rename it to
        # <name>_return.
        return_name = name
        if name in func_decl['func'].split('->')[0]:
            return_name = name + "_return"
        argument_dict = {'type': t, 'name': return_name, 'annotation': annotation}
        if name:
            argument_dict['field_name'] = name
        if t == "Tensor" and inplace:
            assert annotation and annotation.endswith("!"), \
                "Return Tensor of function \"{}\" flagged as inplace needs to be " \
                "annotated as mutable".format(func_decl['func'])
            argument_dict['name'] = 'self'
        else:
            argument_dict['name'] = 'result' if not multiple_args else 'result' + str(arg_idx)
        argument_dict['output'] = True
        arguments.append(argument_dict)

    return arguments


def has_sparse_dispatches(dispatches):
    for dispatch in dispatches:
        if 'Sparse' in dispatch:
            return True
    return False


def parse_native_yaml(path):
    with open(path, 'r') as f:
        return yaml.load(f, Loader=Loader)


def propagate_field_names(output_arguments, return_arguments):
    if len(output_arguments) > 0:
        for i, r in enumerate(return_arguments):
            if 'field_name' in r:
                output_arguments[i]['field_name'] = r['field_name']


def run(paths):
    declarations = []
    for path in paths:
        for func in parse_native_yaml(path):
            declaration = {'mode': 'native'}
            try:
                declaration['schema_string'] = "aten::" + func['func']
                if '->' in func['func']:
                    func_decl, return_decl = [x.strip() for x in func['func'].split('->')]
                else:
                    raise Exception('Expected return declaration')
                fn_name, arguments = func_decl.split('(', 1)
                assert arguments[-1] == ")", "Expecting closing ) for {}".format(func['func'])
                arguments = arguments[:-1]
                declaration['name'] = func.get('name', fn_name)
                declaration['inplace'] = re.search('(^__i|[^_]_$)', fn_name) is not None
                return_arguments = parse_return_arguments(return_decl, declaration['inplace'], func)
                arguments = parse_arguments(arguments, func.get('variants', []), declaration, return_arguments)
                output_arguments = [x for x in arguments if x.get('output')]
                propagate_field_names(output_arguments, return_arguments)
                declaration['return'] = return_arguments if len(output_arguments) == 0 else output_arguments
                declaration['variants'] = func.get('variants', ['function'])
                declaration['requires_tensor'] = func.get('requires_tensor', False)
                declaration['matches_jit_signature'] = func.get('matches_jit_signature', False)
                declaration['cpu_half'] = func.get('cpu_half', False)
                declaration['cpu_bool'] = func.get('cpu_bool', False)
                declaration['deprecated'] = func.get('deprecated', False)
                declaration['device_guard'] = func.get('device_guard', True)
                declaration['arguments'] = func.get('arguments', arguments)
                declaration['type_method_definition_dispatch'] = func.get('dispatch', declaration['name'])
                declaration['python_module'] = func.get('python_module', '')
                declarations.append(declaration)
            except Exception as e:
                msg = '''Exception raised in processing function:
{func}
Generated partial declaration:
{decl}'''.format(func=pprint.pformat(func), decl=pprint.pformat(declaration))
                print(msg, file=sys.stderr)
                raise e
    return declarations
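
# Usage sketch (hypothetical invocation, for illustration): a
# native_functions.yaml entry such as
#
#     - func: add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#
# is split on '->' into the argument and return declarations, which are parsed
# into a declaration dict with keys like 'name', 'arguments', 'return' and
# 'variants', e.g.:
#
#     declarations = run(['aten/src/ATen/native/native_functions.yaml'])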