Caffe2 - Python API
A deep learning, cross-platform ML framework
gen_autograd.py
"""
To run this file by hand from the root of the PyTorch
repository, run:

python -m tools.autograd.gen_autograd \
       build/aten/src/ATen/Declarations.yaml \
       $OUTPUT_DIR \
       tools/autograd

Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/

The last argument is the autograd directory, which must contain
derivatives.yaml, deprecated.yaml and the templates/ directory.
"""
13 
# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
# load_derivatives.py: parses and loads derivatives.yaml
# gen_autograd_functions.py: generates subclasses of torch::autograd::Functions
# gen_variable_type.py: generates VariableType.h which contains all tensor methods
# gen_python_functions.py: generates Python bindings to THPVariable
# gen_variable_factories.py: generates variable_factories.h
#
22 
23 import argparse
24 import copy
25 import os
26 import yaml
27 import re
28 from collections import defaultdict
29 from .utils import YamlLoader, split_name_params
30 
# See NOTE [ Autograd View Variables ] in variable.h for details.
# A map from function name to one of two options:
#   1. the name of the argument that all outputs are views of, or
#   2. a map: output idx => name of the argument that this result is a view of
VIEW_FUNCTIONS = {
    'alias': 'self',
    'as_strided': 'self',
    'diagonal': 'self',
    'expand': 'self',
    'narrow': 'self',
    'permute': 'self',
    'select': 'self',
    'slice': 'self',
    'squeeze': 'self',
    't': 'self',
    'transpose': 'self',
    'unfold': 'self',
    'unsqueeze': 'self',
    'view': 'self',
    'unbind': 'self',
    '_indices': 'self',
    '_values': 'self',
    'indices': 'self',
    'values': 'self',
    # The sparse_coo ctor output should really be a view of both indices and
    # values, but we only support making a view of a single variable, and
    # indices is discrete anyway.
    # FIXME: clone indices on construction.
    'sparse_coo_tensor_with_dims_and_tensors': 'values',
}

# Note: some functions (e.g. 'chunk', 'split') are not listed above but are
# purely compositions of the view functions above. This set contains both the
# root view functions and those purely-composed ones. The JIT uses it to
# determine when an operator returns a view of its inputs.
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS) | {'chunk', 'split'}
67 
68 
def format_return_type(returns):
    """Return the C++ return-type string for a list of return entries.

    Each entry is a dict with a ``'type'`` key. An empty list yields
    ``'void'``, a single entry yields its own type, and several entries yield
    ``std::tuple<T1,T2,...>`` (types joined by commas, no spaces).
    """
    count = len(returns)
    if count == 1:
        return returns[0]['type']
    if count:
        joined = ','.join(entry['type'] for entry in returns)
        return 'std::tuple<{}>'.format(joined)
    return 'void'
77 
78 
def get_simple_type(arg):
    """Reduce an argument's C++ ``'type'`` string to its simplified form.

    Removes every ``' &'`` and ``'const '`` substring, rewrites
    ``'Generator *'`` as ``'Generator'``, and turns a leading
    ``c10::optional<T>`` into ``T?``.
    """
    result = arg['type']
    # Order matters: the same sequence of substring rewrites as always.
    for old, new in ((' &', ''), ('const ', ''), ('Generator *', 'Generator')):
        result = result.replace(old, new)

    optional = re.match(r'c10::optional<(.+)>', result)
    return optional.group(1) + '?' if optional else result
88 
89 
def load_aten_declarations(path):
    """Load the ATen declarations YAML at ``path`` and enrich each entry.

    Declarations with a truthy ``'deprecated'`` field are dropped. Every
    remaining declaration is annotated in place with:
      * ``simple_type`` on each argument and each return (via get_simple_type),
      * ``formals`` / ``type_method_formals``: ``"<type> <name>"`` strings,
      * ``args`` / ``type_method_args``: argument names,
      * ``api_name`` and ``base_name``: copies of ``name``,
      * ``return_type``: the C++ return type (via format_return_type).

    Returns the list of kept declarations in file order.

    NOTE(review): ``yaml.load`` with the project's YamlLoader is used on a
    build-generated file; it must not be pointed at untrusted input.
    """
    with open(path, 'r') as f:
        declarations = yaml.load(f, Loader=YamlLoader)

    def enrich(decl):
        # Keys are assigned in the same order as before so that dict
        # iteration order seen by downstream generators is unchanged.
        params = decl['arguments']
        for item in params:
            item['simple_type'] = get_simple_type(item)
        for item in decl['returns']:
            item['simple_type'] = get_simple_type(item)

        decl['formals'] = [p['type'] + ' ' + p['name'] for p in params]
        decl['args'] = [p['name'] for p in params]
        decl['type_method_formals'] = [p['type'] + ' ' + p['name'] for p in params]
        decl['type_method_args'] = [p['name'] for p in params]
        decl['api_name'] = decl['name']
        decl['return_type'] = format_return_type(decl['returns'])

        decl['base_name'] = decl['name']
        return decl

    return [enrich(d) for d in declarations if not d.get('deprecated')]
118 
119 
def load_deprecated_signatures(aten_decls, deprecated_path):
    """Build declarations for deprecated overloads listed in ``deprecated_path``.

    Each YAML entry has a ``'name'`` (the deprecated signature, whose
    parameters are ``"Type name"`` strings with an optional ``'*'``
    kwarg-only marker) and an ``'aten'`` call expression. The call is turned
    into a ``base_name(type, ...)`` signature and matched against
    ``aten_decls``. For every match, a deep copy of the ATen declaration is
    produced with:
      * ``deprecated`` set to True,
      * ``call_args`` set to the call's argument list,
      * ``arguments`` rebuilt in the deprecated signature's order and names.
    The rebuilt arguments take their types from the original ATen arguments
    and drop any defaults.

    Entries whose signature matches no ATen declaration contribute nothing.
    A parameter that is not of the form ``"Type name"`` raises ValueError.

    Returns the list of new (copied) declarations; ``aten_decls`` is not
    modified.
    """
    def group_declarations_by_signature():
        # Index declarations by "base_name(simple_type, ...)". In-place
        # variants drop their last character (the in-place suffix) so they
        # share a key with the out-of-place overload.
        d = defaultdict(list)
        for declaration in aten_decls:
            name = declaration['name']
            base_name = name[:-1] if declaration['inplace'] else name
            simple_types = [arg['simple_type'] for arg in declaration['arguments']]
            signature = '{}({})'.format(base_name, ', '.join(simple_types))
            d[signature].append(declaration)
        return d

    # NOTE(review): yaml.load with the project's YamlLoader; deprecated.yaml
    # is a file from the autograd directory, not untrusted input.
    with open(deprecated_path, 'r') as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)
    declarations = []
    declarations_by_signature = group_declarations_by_signature()

    def get_signature(name, params, call_args):
        # Map each "Type name" parameter to its type, skipping the '*'
        # kwarg-only marker.
        types = {param_name: param_type
                 for param_type, param_name in
                 (param.split(' ') for param in params if param != '*')}
        # If the name in the call is not in the parameter list, assume it's
        # a literal Scalar.
        rearranged_types = [types.get(arg, 'Scalar') for arg in call_args]
        return '{}({})'.format(name, ', '.join(rearranged_types))

    for deprecated in deprecated_defs:
        aten_name, call_args = split_name_params(deprecated['aten'])
        # Only the deprecated signature's parameters are needed, not its name.
        _, params = split_name_params(deprecated['name'])
        signature = get_signature(aten_name, params, call_args)

        # .get() avoids inserting empty lists into the defaultdict for
        # signatures that match nothing.
        for original_decl in declarations_by_signature.get(signature, []):
            declaration = copy.deepcopy(original_decl)
            declaration['deprecated'] = True
            declaration['call_args'] = call_args

            call_arg_to_idx = {arg: i for i, arg in enumerate(call_args)}
            original_args = declaration['arguments']

            # Create an arguments list that uses the types from the original
            # ATen declaration, but the ordering and parameter names from
            # the deprecated overload. Any default parameter values from the
            # original ATen declaration are ignored.
            arguments = []
            kwarg_only = False
            for param in params:
                if param == '*':
                    kwarg_only = True
                    continue
                _, param_name = param.split(' ')
                original = original_args[call_arg_to_idx[param_name]]
                arguments.append({
                    'name': param_name,
                    'kwarg_only': kwarg_only,
                    'type': original['type'],
                    'simple_type': original['simple_type'],
                    'dynamic_type': original['dynamic_type'],
                    'output': original.get('output', False),
                })
            declaration['arguments'] = arguments
            declarations.append(declaration)
    return declarations
180 
181 
def gen_autograd(aten_path, out, autograd_dir):
    """Generate the autograd C++ sources and Python bindings into ``out``.

    ``aten_path`` is the Declarations.yaml to read. ``autograd_dir`` must
    contain ``derivatives.yaml``, ``deprecated.yaml`` and a ``templates``
    directory. Outputs are produced in this order:
      1. VariableType.h/cpp
      2. Functions.h/cpp
      3. Python bindings for variable methods, torch functions and nn functions
      4. variable_factories.h
    """
    aten_decls = load_aten_declarations(aten_path)

    # Parse and load derivatives.yaml into autograd function descriptions.
    from .load_derivatives import load_derivatives
    derivatives_yaml = os.path.join(autograd_dir, 'derivatives.yaml')
    autograd_functions = load_derivatives(derivatives_yaml, aten_decls)

    templates = os.path.join(autograd_dir, 'templates')

    # VariableType.h/cpp
    from .gen_variable_type import gen_variable_type
    gen_variable_type(out, aten_decls, templates)

    # Functions.h/cpp
    from .gen_autograd_functions import gen_autograd_functions
    gen_autograd_functions(out, autograd_functions, templates)

    # Within this function, deprecated overloads only feed the method and
    # torch-function bindings; nn bindings use the plain ATen declarations.
    deprecated_yaml = os.path.join(autograd_dir, 'deprecated.yaml')
    deprecated = load_deprecated_signatures(aten_decls, deprecated_yaml)

    # Python bindings. Each generator gets its own freshly concatenated
    # list, so one generator cannot observe list changes made by another.
    from . import gen_python_functions as py_gen
    for emit in (py_gen.gen_py_variable_methods, py_gen.gen_py_torch_functions):
        emit(out, aten_decls + deprecated, templates)
    py_gen.gen_py_nn_functions(out, aten_decls, templates)

    # variable_factories.h
    from .gen_variable_factories import gen_variable_factories
    gen_variable_factories(out, aten_decls, templates)
217 
218 
def main():
    """Command-line entry point.

    Parses three positional arguments (the Declarations.yaml path, the output
    directory and the autograd directory) and passes them to gen_autograd.
    """
    parser = argparse.ArgumentParser(
        description='Generate autograd C++ files script')
    positionals = (
        ('declarations', 'DECL', 'path to Declarations.yaml'),
        ('out', 'OUT', 'path to output directory'),
        ('autograd', 'AUTOGRAD', 'path to autograd directory'),
    )
    for dest, metavar, help_text in positionals:
        parser.add_argument(dest, metavar=metavar, help=help_text)
    options = parser.parse_args()
    gen_autograd(options.declarations, options.out, options.autograd)


if __name__ == '__main__':
    main()