r"""
To run this file by hand from the root of the PyTorch repository, run:

    python -m tools.autograd.gen_autograd \
        build/aten/src/ATen/Declarations.yaml \
        $OUTPUT_DIR \
        tools/autograd

where $OUTPUT_DIR is the directory in which the generated files should be
written. In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/
"""
from collections
import defaultdict
29 from .utils
import YamlLoader, split_name_params
59 'sparse_coo_tensor_with_dims_and_tensors':
'values',
# Names of ops whose outputs are views of their inputs: every key of
# VIEW_FUNCTIONS plus the additional ops 'chunk' and 'split'.
RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()) | {'chunk', 'split'}
def format_return_type(returns):
    """Return the C++ return type spelled by a declaration's ``returns`` list.

    Args:
        returns: list of return descriptors, each a dict with a ``'type'`` key.

    Returns:
        ``'void'`` for an empty list, the single entry's ``'type'`` for a
        one-element list, and otherwise ``'std::tuple<T1,T2,...>'`` with the
        types joined by bare commas (no spaces).
    """
    count = len(returns)
    if count == 0:
        return 'void'
    if count == 1:
        return returns[0]['type']
    return 'std::tuple<{}>'.format(','.join(ret['type'] for ret in returns))
def get_simple_type(arg):
    """Reduce an argument's C++ ``type`` string to a simplified type name.

    The reference marker ``' &'`` and the ``'const '`` qualifier are removed,
    ``'Generator *'`` becomes ``'Generator'``, and a type of the form
    ``c10::optional<X>`` is rewritten as ``'X?'``.

    Args:
        arg: argument/return descriptor dict with a ``'type'`` key.

    Returns:
        The simplified type string.
    """
    simplified = (arg['type']
                  .replace(' &', '')
                  .replace('const ', '')
                  .replace('Generator *', 'Generator'))
    # re.match anchors at the start only; the greedy group runs to the last '>'.
    optional = re.match(r'c10::optional<(.+)>', simplified)
    return '{}?'.format(optional.group(1)) if optional else simplified
def load_aten_declarations(path):
    """Load ATen declarations from ``path`` and enrich them for code generation.

    Declarations whose ``'deprecated'`` entry is truthy are skipped.  Every
    remaining declaration dict is updated in place with:

    * ``'simple_type'`` on each argument and each return (``get_simple_type``);
    * ``'formals'`` and ``'type_method_formals'``: ``"<type> <name>"`` strings;
    * ``'args'`` and ``'type_method_args'``: the argument names;
    * ``'api_name'`` and ``'base_name'``: copies of ``'name'``;
    * ``'return_type'``: the C++ return type from ``format_return_type``.

    Args:
        path: path to the ATen Declarations.yaml file.

    Returns:
        The list of kept declarations, in file order.
    """
    with open(path, 'r') as f:
        declarations = yaml.load(f, Loader=YamlLoader)

    kept = []
    for declaration in declarations:
        if declaration.get('deprecated'):
            continue

        arguments = declaration['arguments']
        for group in (arguments, declaration['returns']):
            for entry in group:
                entry['simple_type'] = get_simple_type(entry)

        formals = [a['type'] + ' ' + a['name'] for a in arguments]
        names = [a['name'] for a in arguments]
        declaration['formals'] = formals
        declaration['args'] = names
        # Separate list objects, so a consumer that mutates one of these
        # fields does not silently change its twin.
        declaration['type_method_formals'] = list(formals)
        declaration['type_method_args'] = list(names)
        declaration['api_name'] = declaration['name']
        declaration['return_type'] = format_return_type(declaration['returns'])

        declaration['base_name'] = declaration['name']
        kept.append(declaration)

    return kept
def load_deprecated_signatures(aten_decls, deprecated_path):
    """Create declarations for the deprecated overloads in ``deprecated_path``.

    Each entry of the YAML file has a ``'name'`` (the deprecated signature)
    and an ``'aten'`` expression (the ATen call it maps to); both are split
    with ``split_name_params``.  The ATen call is turned into a signature
    string ``name(T1, T2, ...)`` and looked up among ``aten_decls``.  For every
    matching declaration, a deep copy is flagged ``'deprecated'``, given the
    entry's ``'call_args'``, and has its ``'arguments'`` rebuilt in the
    deprecated parameter order and names, keeping the ATen argument types.
    Entries whose signature matches no declaration contribute nothing.

    Args:
        aten_decls: declarations from ``load_aten_declarations``; each needs
            ``'name'``, ``'inplace'`` and per-argument ``'simple_type'``.
        deprecated_path: path to the deprecated-signature YAML file.

    Returns:
        A new list of deprecated declarations; ``aten_decls`` is not modified.

    Raises:
        KeyError: if a deprecated parameter name does not appear among the
            arguments of its ATen call.
    """
    def format_signature(base_name, type_names):
        return '{}({})'.format(base_name, ', '.join(type_names))

    def group_declarations_by_signature():
        grouped = defaultdict(list)
        for decl in aten_decls:
            full_name = decl['name']
            # In-place declarations drop their final character (presumably the
            # trailing '_') so they share a key with the out-of-place variant.
            base_name = full_name[:-1] if decl['inplace'] else full_name
            arg_types = [a['simple_type'] for a in decl['arguments']]
            grouped[format_signature(base_name, arg_types)].append(decl)
        return grouped

    def get_signature(name, params, call_args):
        # Parameter name -> type; the bare '*' separator carries no type.
        param_types = {}
        for param in params:
            if param == '*':
                continue
            param_type, param_name = param.split(' ')
            param_types[param_name] = param_type
        # A call argument that is not a parameter name is assumed to be a
        # literal Scalar.
        call_types = [param_types.get(arg, 'Scalar') for arg in call_args]
        return format_signature(name, call_types)

    def remap_arguments(original_args, params, call_args):
        # Order and names come from the deprecated overload; type information
        # comes from the ATen argument at the same position in the call.  Only
        # the keys listed below are carried over from the ATen argument.
        call_position = {arg: i for i, arg in enumerate(call_args)}
        remapped = []
        kwarg_only = False
        for param in params:
            if param == '*':
                # Every parameter after the bare '*' is keyword-only.
                kwarg_only = True
                continue
            _, param_name = param.split(' ')
            source = original_args[call_position[param_name]]
            remapped.append({
                'name': param_name,
                'kwarg_only': kwarg_only,
                'type': source['type'],
                'simple_type': source['simple_type'],
                'dynamic_type': source['dynamic_type'],
                'output': source.get('output', False),
            })
        return remapped

    with open(deprecated_path, 'r') as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)
    declarations = []
    declarations_by_signature = group_declarations_by_signature()

    for deprecated in deprecated_defs:
        aten_name, call_args = split_name_params(deprecated['aten'])
        _, params = split_name_params(deprecated['name'])
        signature = get_signature(aten_name, params, call_args)

        for match in declarations_by_signature[signature]:
            declaration = copy.deepcopy(match)
            declaration['deprecated'] = True
            declaration['call_args'] = call_args
            declaration['arguments'] = remap_arguments(
                declaration['arguments'], params, call_args)
            declarations.append(declaration)
    return declarations
def gen_autograd(aten_path, out, autograd_dir):
    """Run every autograd code generator, writing the results into ``out``.

    Args:
        aten_path: path to ATen's Declarations.yaml.
        out: output directory handed to each generator.
        autograd_dir: directory containing ``derivatives.yaml``,
            ``deprecated.yaml`` and the ``templates`` subdirectory.
    """
    aten_decls = load_aten_declarations(aten_path)

    # Each generator module is imported right before the step that uses it.
    from .load_derivatives import load_derivatives
    derivatives_yaml = os.path.join(autograd_dir, 'derivatives.yaml')
    autograd_functions = load_derivatives(derivatives_yaml, aten_decls)

    templates = os.path.join(autograd_dir, 'templates')

    from .gen_variable_type import gen_variable_type
    gen_variable_type(out, aten_decls, templates)

    from .gen_autograd_functions import gen_autograd_functions
    gen_autograd_functions(out, autograd_functions, templates)

    deprecated_yaml = os.path.join(autograd_dir, 'deprecated.yaml')
    deprecated = load_deprecated_signatures(aten_decls, deprecated_yaml)

    from . import gen_python_functions as python_gen
    # Deprecated overloads go to the variable-method and torch-function
    # generators only.  Each call gets its own concatenated list.
    python_gen.gen_py_variable_methods(out, aten_decls + deprecated, templates)
    python_gen.gen_py_torch_functions(out, aten_decls + deprecated, templates)
    python_gen.gen_py_nn_functions(out, aten_decls, templates)

    from .gen_variable_factories import gen_variable_factories
    gen_variable_factories(out, aten_decls, templates)
def main():
    """Parse the command-line paths and run ``gen_autograd`` with them."""
    parser = argparse.ArgumentParser(
        description='Generate autograd C++ files script')
    positionals = (
        ('declarations', 'DECL', 'path to Declarations.yaml'),
        ('out', 'OUT', 'path to output directory'),
        ('autograd', 'AUTOGRAD', 'path to autograd directory'),
    )
    for dest, metavar, help_text in positionals:
        parser.add_argument(dest, metavar=metavar, help=help_text)
    args = parser.parse_args()
    gen_autograd(args.declarations, args.out, args.autograd)
232 if __name__ ==
'__main__':