import re

from .utils import CodeTemplate, write
from .gen_variable_type import format_trace
# C++ source for one generated factory wrapper. It calls the ATen function
# at::${name}, wraps the resulting tensor with
# autograd::make_variable_consuming (requires_grad taken from
# ${requires_grad}), and brackets the call with the trace fragments that
# process_function obtains from format_trace(decl).
FUNCTION_TEMPLATE = CodeTemplate("""\
inline at::Tensor ${name}(${formals}) {
  ${pre_record_trace}
  at::Tensor tensor = at::${name}(${actuals});
  at::Tensor result =
    autograd::make_variable_consuming(std::move(tensor), /*requires_grad=*/${requires_grad});
  ${post_record_trace}
  return result;
}
""")


# Matches an optional leading "const " followed by a capitalized identifier
# (e.g. "Tensor", "TensorOptions"); group 1 is that identifier, which is where
# fully_qualified_type inserts the "at::" qualifier.
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")
def fully_qualified_type(argument_type):
    """Qualify the C++ class name in *argument_type* with ``at::``.

    ``TYPE_PATTERN`` matches an optional leading ``const`` followed by a
    capitalized identifier; ``at::`` is inserted directly before that
    identifier, e.g. ``"const Tensor &"`` -> ``"const at::Tensor &"``.

    Types that ``TYPE_PATTERN`` does not match at the start of the string
    (e.g. lowercase primitives such as ``int64_t`` or ``bool``) are returned
    unchanged instead of raising ``AttributeError`` on a ``None`` match.
    """
    match = TYPE_PATTERN.match(argument_type)
    if match is None:
        # Nothing capitalized to qualify at the start of the type string.
        return argument_type
    index = match.start(1)
    return "{}at::{}".format(argument_type[:index], argument_type[index:])
def gen_variable_factories(out, declarations, template_path):
    """Render ``variable_factories.h`` and pass it to ``write``.

    A wrapper (see ``process_function``) is emitted for every declaration
    whose ``method_of`` contains ``'namespace'`` and that either takes a
    ``TensorOptions`` argument or has a name ending in ``_like``. The
    wrappers are substituted into ``<template_path>/variable_factories.h``
    under the ``function_definitions`` key; ``write`` receives *out* and the
    output name (presumably *out* is the destination directory — confirm
    against ``utils.write``).
    """
    definitions = []
    for decl in declarations:
        takes_options = any(
            arg["simple_type"] == "TensorOptions" for arg in decl["arguments"]
        )
        in_namespace = 'namespace' in decl['method_of']
        # Only factories: something that accepts TensorOptions, or a *_like op.
        if not (takes_options or decl["name"].endswith("_like")):
            continue
        if not in_namespace:
            continue
        definitions.append(process_function(decl, takes_options))

    header_template = CodeTemplate.from_file(template_path + "/variable_factories.h")
    write(out, "variable_factories.h", header_template,
          {"function_definitions": definitions})
def process_function(decl, has_tensor_options):
    """Render the C++ wrapper for a single factory declaration.

    Builds the wrapper's formal parameter list (types qualified with ``at::``
    via ``fully_qualified_type``, defaults preserved) and the argument list
    forwarded to ``at::<name>``, then substitutes both into
    ``FUNCTION_TEMPLATE`` together with the trace fragments returned by
    ``format_trace(decl)``.

    Args:
        decl: declaration dict. Reads ``"name"`` and ``"arguments"``; each
            argument provides ``"type"``, ``"name"``, ``"simple_type"`` and
            optionally ``"default"``.
        has_tensor_options: whether any argument has simple type
            ``TensorOptions``.

    Returns:
        The result of ``FUNCTION_TEMPLATE.substitute`` for this declaration.
    """
    formals = []
    actuals = []
    options_name = None
    for argument in decl["arguments"]:
        arg_type = fully_qualified_type(argument["type"])
        default = " = {}".format(argument["default"]) if "default" in argument else ""
        formals.append("{} {}{}".format(arg_type, argument["name"], default))
        actual = argument["name"]
        if argument["simple_type"] == "TensorOptions":
            if options_name is None:
                options_name = argument["name"]
            # Forward a copy of the options with is_variable(false) to the
            # ATen call; the template wraps the result into a Variable itself
            # via make_variable_consuming.
            actual = "at::TensorOptions({}).is_variable(false)".format(actual)
        actuals.append(actual)

    if has_tensor_options:
        # Read requires_grad from the caller's own TensorOptions parameter.
        # Previously the name "options" was hard-coded; it remains the
        # fallback so output is unchanged for the usual naming.
        requires_grad = "{}.requires_grad()".format(options_name or "options")
    else:
        requires_grad = "false"

    if decl["name"].endswith("_like") and not has_tensor_options:
        # *_like overload without TensorOptions: derive options from the first
        # argument (presumably the source tensor — TODO confirm) and pass them
        # explicitly to the ATen call.
        actuals.append("{}.options().is_variable(false)".format(actuals[0]))

    pre_record_trace, post_record_trace = format_trace(decl)

    return FUNCTION_TEMPLATE.substitute(
        name=decl["name"],
        formals=formals,
        actuals=actuals,
        requires_grad=requires_grad,
        pre_record_trace=pre_record_trace,
        post_record_trace=post_record_trace,
    )