# File extensions treated as code-generator sources when scanning the tree.
source_files = {'.py', '.cpp', '.h'}

# Fallback declarations file handed to the autograd / JIT generators when no
# explicit declarations path is supplied.
DECLARATIONS_PATH = 'torch/share/ATen/Declarations.yaml'


def all_generator_source(root='tools'):
    """Return the paths of all generator source files under *root*, sorted.

    Walks *root* recursively and keeps every file whose extension is listed
    in ``source_files``.

    Args:
        root: Directory to scan. Defaults to ``'tools'`` (relative to the
            current working directory), the previously hard-coded location.

    Returns:
        A sorted list of paths, each built as ``os.path.join(directory, name)``
        for the directories yielded by ``os.walk``. ``os.walk`` yields entries
        in filesystem-dependent order, so the result is sorted to keep it
        deterministic. A non-existent *root* yields an empty list, because
        ``os.walk`` ignores the listing error by default.
    """
    found = []
    for directory, _, filenames in os.walk(root):
        # The extension check needs each individual file name, so iterate
        # over the files reported for this directory.
        for name in filenames:
            if os.path.splitext(name)[1] in source_files:
                found.append(os.path.join(directory, name))
    return sorted(found)
# Paths relevant to code generation: first YAML files (presumably generator
# inputs), then files under torch/csrc/{autograd,jit}/generated.
# NOTE(review): this excerpt starts mid-literal. The opening of the list(s),
#   some entries, and the closing bracket(s) are not visible here. Confirm
#   against the full file before editing.
25 'torch/share/ATen/Declarations.yaml',
26 'tools/autograd/derivatives.yaml',
27 'tools/autograd/deprecated.yaml',
# NOTE(review): gap in this excerpt (embedded numbering jumps 27 -> 31).
#   The switch from .yaml paths to generated/ paths suggests two separate
#   lists (inputs vs. outputs) — confirm.
31 'torch/csrc/autograd/generated/Functions.cpp',
32 'torch/csrc/autograd/generated/Functions.h',
33 'torch/csrc/autograd/generated/python_functions.cpp',
34 'torch/csrc/autograd/generated/python_functions.h',
35 'torch/csrc/autograd/generated/python_nn_functions.cpp',
36 'torch/csrc/autograd/generated/python_nn_functions.h',
37 'torch/csrc/autograd/generated/python_nn_functions_dispatch.h',
38 'torch/csrc/autograd/generated/python_variable_methods.cpp',
39 'torch/csrc/autograd/generated/python_variable_methods_dispatch.h',
40 'torch/csrc/autograd/generated/variable_factories.h',
41 'torch/csrc/autograd/generated/VariableType_0.cpp',
42 'torch/csrc/autograd/generated/VariableType_1.cpp',
43 'torch/csrc/autograd/generated/VariableType_2.cpp',
44 'torch/csrc/autograd/generated/VariableType_3.cpp',
45 'torch/csrc/autograd/generated/VariableType_4.cpp',
46 'torch/csrc/autograd/generated/VariableType.h',
47 'torch/csrc/jit/generated/register_aten_ops_0.cpp',
48 'torch/csrc/jit/generated/register_aten_ops_1.cpp',
49 'torch/csrc/jit/generated/register_aten_ops_2.cpp',
# generate_code: runs the NN-wrapper, autograd and JIT-dispatch generators
# (visible calls: generate_nn_wrappers, gen_autograd, gen_jit_dispatch).
# - Inserts the grandparent of this file's directory at the front of
#   sys.path before importing from tools.nnwrap.
# - Generated autograd / JIT files go to install_dir when it is truthy.
#   Otherwise they go to torch/csrc/autograd/generated and
#   torch/csrc/jit/generated respectively.
# - declarations_path, when truthy, overrides DECLARATIONS_PATH for both
#   gen_autograd and gen_jit_dispatch.
# NOTE(review): this excerpt is missing lines of the function:
#   - the rest of the signature (nn_path and install_dir are used below but
#     not declared in the visible lines);
#   - the imports of gen_autograd / gen_jit_dispatch;
#   - any use of ninja_global.
#   Confirm against the full file before editing.
53 def generate_code(ninja_global=None,
54 declarations_path=
None,
58 root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
59 sys.path.insert(0, root)
63 from tools.nnwrap import generate_wrappers
as generate_nn_wrappers
67 generate_nn_wrappers(nn_path, install_dir,
'tools/cwrap/plugins/templates')
70 autograd_gen_dir = install_dir
or 'torch/csrc/autograd/generated' 71 jit_gen_dir = install_dir
or 'torch/csrc/jit/generated' 72 for d
in (autograd_gen_dir, jit_gen_dir):
73 if not os.path.exists(d):
# NOTE(review): the body of this exists() check (embedded line 74) is not in
#   this excerpt; presumably it creates d — confirm.
75 gen_autograd(declarations_path
or DECLARATIONS_PATH, autograd_gen_dir,
'tools/autograd')
76 gen_jit_dispatch(declarations_path
or DECLARATIONS_PATH, jit_gen_dir,
'tools/jit/templates')
# Command-line entry-point body.
# - Parses the --declarations-path, --nn-path, --ninja-global and
#   --install_dir options.
# - Passes the parsed values to generate_code. Only the first two arguments
#   of that call are visible in this excerpt.
# - argparse derives the attribute names by turning '-' into '_', giving
#   options.declarations_path, options.nn_path, options.ninja_global and
#   options.install_dir.
# NOTE(review): the enclosing `def` line and the remaining generate_code(...)
#   arguments are missing from this excerpt — confirm against the full file.
# NOTE(review): '--install_dir' uses an underscore while the other flags use
#   hyphens. Renaming it would break existing invocations; consider adding a
#   hyphenated alias instead.
80 parser = argparse.ArgumentParser(description=
'Autogenerate code')
81 parser.add_argument(
'--declarations-path')
82 parser.add_argument(
'--nn-path')
83 parser.add_argument(
'--ninja-global')
84 parser.add_argument(
'--install_dir')
85 options = parser.parse_args()
86 generate_code(options.ninja_global,
87 options.declarations_path,
# Script entry guard: code under it runs only when this file is executed
# directly, not when it is imported.
# NOTE(review): the body of this `if` is not present in this excerpt.
92 if __name__ ==
"__main__":