Caffe2 - Python API
A deep learning, cross-platform ML framework
gen_variable_factories.py
1 # Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables.
2 #
3 # This writes one file: variable_factories.h
4 
5 import re
6 
7 from .utils import CodeTemplate, write
8 from .gen_variable_type import format_trace
9 
10 
# C++ source template for one generated factory wrapper.
#
# The generated inline function takes ${formals}. It:
# - calls at::${name}(${actuals});
# - moves the returned tensor into autograd::make_variable_consuming, passing
#   ${requires_grad} as the requires_grad flag;
# - returns the result.
#
# ${pre_record_trace} / ${post_record_trace} are filled with the pair of
# snippets returned by format_trace for the declaration.
FUNCTION_TEMPLATE = CodeTemplate("""\
inline at::Tensor ${name}(${formals}) {
  ${pre_record_trace}
  at::Tensor tensor = at::${name}(${actuals});
  at::Tensor result =
    autograd::make_variable_consuming(std::move(tensor), /*requires_grad=*/${requires_grad});
  ${post_record_trace}
  return result;
}
""")
21 
22 
# Matches, at the start of a type spelling, an optional `const` qualifier
# followed by an identifier beginning with an uppercase letter (group 1).
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")


def fully_qualified_type(argument_type):
    """Insert an ``at::`` namespace prefix before the leading capitalized type name.

    Examples: ``"const Tensor &"`` becomes ``"const at::Tensor &"``, while
    spellings that do not start with ``[const ]<Uppercase>...`` (such as
    ``"int64_t"`` or an already-qualified ``"at::Tensor"``) are returned
    unchanged.

    Args:
        argument_type: C++ type spelling of a declaration argument.

    Returns:
        The spelling with ``at::`` inserted before the matched type name, or
        the input unchanged when ``TYPE_PATTERN`` does not match.
    """
    found = TYPE_PATTERN.match(argument_type)
    if not found:
        return argument_type
    # Split right before the type name so any `const ` prefix is preserved.
    split_at = found.start(1)
    head, tail = argument_type[:split_at], argument_type[split_at:]
    return "{}at::{}".format(head, tail)
32 
33 
def gen_variable_factories(out, declarations, template_path):
    """Write ``variable_factories.h`` containing Variable-returning factory wrappers.

    A wrapper (rendered by ``process_function``) is emitted for every
    declaration whose ``"method_of"`` contains ``'namespace'``, and which
    either takes a ``TensorOptions`` argument or has a name ending in
    ``_like``.

    Args:
        out: output location, passed straight through to ``write``.
        declarations: iterable of declaration dicts. Each provides ``"name"``,
            ``"method_of"`` and ``"arguments"``; every argument dict has a
            ``"simple_type"`` key.
        template_path: directory that holds the ``variable_factories.h``
            template.
    """
    import os

    function_definitions = []
    for decl in declarations:
        # Only namespace-level functions get a wrapper; skip the rest early.
        if 'namespace' not in decl['method_of']:
            continue
        has_tensor_options = any(a["simple_type"] == "TensorOptions" for a in decl["arguments"])
        if has_tensor_options or decl["name"].endswith("_like"):
            function_definitions.append(process_function(decl, has_tensor_options))
    # os.path.join avoids a doubled or missing separator around template_path.
    write(out,
          "variable_factories.h",
          CodeTemplate.from_file(os.path.join(template_path, "variable_factories.h")),
          {"function_definitions": function_definitions})
45 
46 
def process_function(decl, has_tensor_options):
    """Render the C++ wrapper for one factory declaration using ``FUNCTION_TEMPLATE``.

    Args:
        decl: declaration dict with ``"name"`` and ``"arguments"``. Each
            argument dict provides ``"type"``, ``"name"``, ``"simple_type"``
            and, optionally, ``"default"``. ``decl`` is also passed to
            ``format_trace``.
        has_tensor_options: whether ``decl`` has a ``TensorOptions`` argument.
            When true, the wrapper's ``requires_grad`` is read from that
            argument. When false, ``requires_grad`` is ``false``.

    Returns:
        The C++ source text produced by ``FUNCTION_TEMPLATE.substitute``.
    """
    formals = []
    actuals = []
    options_name = None
    for argument in decl["arguments"]:
        arg_name = argument["name"]
        cpp_type = fully_qualified_type(argument["type"])
        default = " = {}".format(argument["default"]) if "default" in argument else ""
        formals.append("{} {}{}".format(cpp_type, arg_name, default))
        actual = arg_name
        if argument["simple_type"] == "TensorOptions":
            # We want to make `at::{name}` always return a
            # tensor and not a variable, since we create a variable right after.
            actual = "at::TensorOptions({}).is_variable(false)".format(arg_name)
            if options_name is None:
                options_name = arg_name
        actuals.append(actual)

    if has_tensor_options:
        # Read requires_grad from the TensorOptions formal that was actually
        # declared. Fall back to the historical name `options` if none was
        # found among the arguments.
        requires_grad = "{}.requires_grad()".format(options_name or "options")
    else:
        requires_grad = "false"

    if decl['name'].endswith('_like') and not has_tensor_options:
        # No TensorOptions argument: forward the first argument's options with
        # is_variable(false). This assumes the first argument is the source
        # tensor of the `_like` op (TODO confirm).
        actuals.append('{}.options().is_variable(false)'.format(actuals[0]))

    pre_record_trace, post_record_trace = format_trace(decl)

    return FUNCTION_TEMPLATE.substitute(
        name=decl["name"], formals=formals, actuals=actuals, requires_grad=requires_grad,
        pre_record_trace=pre_record_trace, post_record_trace=post_record_trace
    )