import re

from .gen_autograd import VIEW_FUNCTIONS
from .utils import CodeTemplate, IDENT_REGEX, nested_dict, write
# C++ code templates expanded once per differentiable function.
# NOTE(review): interior lines of these templates were garbled in the source
# dump; the bodies below were restored from the surviving fragments — verify
# against the placeholders filled in by process_function() below.

# Declaration of the generated torch::autograd::Function subclass.
FUNCTION_DECLARATION = CodeTemplate("""\
struct ${op} : public ${superclass} {
  using ${superclass}::${superclass};
  variable_list apply(variable_list&& grads) override;
  std::string name() const override { return "${op}"; }
  void release_variables() override {
    ${release_variables}
  }
  ${will_release_variables}
  ${saved_variables}
  ${saved_list_sizes}
};
""")

# Opt-in hook emitted only for functions whose formulas reference
# `retain_variables` (see uses_retain_variables()).
WILL_RELEASE_VARIABLES = CodeTemplate("""\
bool retain_variables = true;
void will_release_variables() override {
  retain_variables = false;
}
""")

# Out-of-line definition of apply(); ${body} holds the per-derivative code.
FUNCTION_DEFINITION = CodeTemplate("""\
variable_list ${op}::apply(variable_list&& grads) {
  IndexRangeGenerator gen;
  ${compute_index_ranges}
  variable_list grad_inputs(gen.size());
  ${body}
  return grad_inputs;
}
""")

# Registers the generated class with the Python bindings.
PY_FUNCTION_DEFINITION = CodeTemplate("""\
static PyTypeObject ${op}Class;
addClass<${op}>(${op}Class, "${op}");
""")

# Boolean mask passed to formulas that reference `grad_input_mask`.
GRAD_INPUT_MASK = CodeTemplate("""\
  auto grad_input_mask = std::array<bool, ${n}>{
    ${masks}
  };\
""")

# Derivative with a single output variable.
DERIVATIVE_SINGLE = CodeTemplate("""\
if (should_compute_output({ ${name}_ix })) {
  auto grad_result = ${derivative};
  copy_range(grad_inputs, ${name}_ix, grad_result);
}
""")

# One copy_range per output of a multi-output derivative formula.
DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate("""\
  if (should_compute_output({ ${name}_ix })) {
    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
  }
""")

# Derivative whose formula produces a tuple covering several inputs.
DERIVATIVE_MULTI = CodeTemplate("""\
if (should_compute_output({ ${idx_ranges} })) {
  ${grad_input_mask}
  auto grad_result = ${derivative};
  ${copy_ranges}
}
""")

# View functions cannot be traced; their backward classes use the plain
# `Function` superclass instead of `TraceableFunction`.
UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
def gen_autograd_functions(out, autograd_functions, template_path):
    """Generate Functions.h/.cpp and python_functions.h/.cpp bodies.

    These contain the auto-generated subclasses of torch::autograd::Function
    for every differentiable torch function.

    Args:
        out: output directory passed through to write().
        autograd_functions: iterable of per-function metadata dicts consumed
            by process_function().
        template_path: directory containing the .h/.cpp template files.
    """
    FUNCTIONS_H = CodeTemplate.from_file(template_path + '/Functions.h')
    FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/Functions.cpp')
    PY_FUNCTIONS_H = CodeTemplate.from_file(template_path + '/python_functions.h')
    PY_FUNCTIONS_CPP = CodeTemplate.from_file(template_path + '/python_functions.cpp')

    function_definitions = []
    function_declarations = []
    py_function_initializers = []

    for func in autograd_functions:
        env = process_function(func)
        function_declarations.append(FUNCTION_DECLARATION.substitute(env))
        function_definitions.append(FUNCTION_DEFINITION.substitute(env))
        py_function_initializers.append(PY_FUNCTION_DEFINITION.substitute(env))

    # NOTE(review): the `top_env = {` opener was lost in the garbled source;
    # reconstructed from the visible dict entries and the write() calls below.
    top_env = {
        'autograd_function_definitions': function_definitions,
        'autograd_function_declarations': function_declarations,
        'py_function_initializers': py_function_initializers,
    }

    write(out, 'Functions.h', FUNCTIONS_H, top_env)
    write(out, 'Functions.cpp', FUNCTIONS_CPP, top_env)
    write(out, 'python_functions.h', PY_FUNCTIONS_H, top_env)
    write(out, 'python_functions.cpp', PY_FUNCTIONS_CPP, top_env)
def process_function(func):
    """Build the template-substitution environment for one function.

    Returns a nested_dict layering the computed entries (saved variables,
    release/unpack code, index ranges, derivative bodies, superclass) over
    the raw `func` metadata dict.

    NOTE(review): the accumulator initializations and a few short branches
    were missing from the garbled source; they are reconstructed here from
    how the visible code uses them.
    """
    env = {}
    saved_variables = []
    release_variables = []
    saved_list_sizes = []
    unpack = []

    # One IndexRange per differentiable arg; TensorLists have a runtime size
    # saved as `<name>_size_`, everything else occupies a single slot.
    env['compute_index_ranges'] = []
    for arg in func['args_with_derivatives']:
        if arg['type'] == 'TensorList':
            size = '{}_size_'.format(arg['name'])
            saved_list_sizes.append('size_t {}_size_;'.format(arg['name']))
        else:
            size = '1'
        env['compute_index_ranges'].append('auto {}_ix = gen.range({});'.format(arg['name'], size))

    def save_arg(arg, is_output):
        # Emit the member declaration, release code, and unpack code for one
        # saved input/output, dispatching on its C++ type.
        name = arg['name']
        if arg['type'] == 'Tensor' or (arg['type'] == 'Scalar' and is_output):
            saved_variables.append('SavedVariable {}_;'.format(name))
            release_variables.append('{}_.reset_data();'.format(name))
            release_variables.append('{}_.reset_grad_function();'.format(name))
            # Saved outputs need the owning grad_fn to detect use-after-release.
            ptr = 'shared_from_this()' if is_output else ''
            unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
        elif arg['type'] == 'TensorList':
            saved_variables.append('std::vector<SavedVariable> {}_;'.format(name))
            release_variables.append('{}_.clear();'.format(name))
            unpack.append('auto {} = unpack_list({}_);'.format(name, name))
        elif arg['type'] == 'IntArrayRef':
            saved_variables.append('std::vector<int64_t> {};'.format(name))
        elif arg['type'] == 'int64_t':
            saved_variables.append('{} {} = 0;'.format(arg['type'], name))
        else:
            # Plain value types are stored as-is with no release/unpack code.
            saved_variables.append('{} {};'.format(arg['type'], name))

    for arg in func['saved_inputs']:
        save_arg(arg, is_output=False)
    for arg in func['saved_outputs']:
        save_arg(arg, is_output=True)
    env['saved_variables'] = saved_variables
    env['release_variables'] = release_variables
    env['saved_list_sizes'] = saved_list_sizes

    # Only emit will_release_variables() when a formula actually reads
    # `retain_variables`.
    if uses_retain_variables(func):
        env['will_release_variables'] = WILL_RELEASE_VARIABLES.substitute()
    else:
        env['will_release_variables'] = ''

    body = []

    if uses_single_grad(func):
        body.append('auto& grad = grads[0];')

    def emit_derivative(derivative):
        # Render the C++ snippet computing one derivative formula.
        formula = derivative['formula']
        var_names = derivative['var_names']
        if len(var_names) == 1:
            return DERIVATIVE_SINGLE.substitute(name=var_names[0], derivative=formula)

        # Multi-output formula: optionally build grad_input_mask, then copy
        # each tuple element into its input's index range.
        grad_input_mask = ''
        if 'grad_input_mask' in formula:
            masks = ['should_compute_output({{ {}_ix }}),'.format(n) for n in var_names]
            grad_input_mask = GRAD_INPUT_MASK.substitute(masks=masks, n=len(var_names))

        idx_ranges = ', '.join("{}_ix".format(n) for n in var_names)
        copy_ranges = []
        for i, n in enumerate(var_names):
            copy_ranges.append(DERIVATIVE_MULTI_COPY_RANGE.substitute(name=n, i=i))
        return DERIVATIVE_MULTI.substitute(
            idx_ranges=idx_ranges, copy_ranges=copy_ranges,
            derivative=formula,
            grad_input_mask=grad_input_mask)

    body.extend(unpack)
    for derivative in func['derivatives']:
        body.append(emit_derivative(derivative))
    env['body'] = body

    # View functions are untraceable and keep the plain Function superclass.
    if func['name'] in UNTRACEABLE_FUNCTIONS:
        env['superclass'] = 'Function'
    else:
        env['superclass'] = 'TraceableFunction'
    return nested_dict(env, func)
def uses_ident(func, ident):
    """Return True if any derivative formula of *func* references *ident*.

    Uses IDENT_REGEX so that only whole-identifier matches count (e.g. `grad`
    does not match inside `grad_input`).
    """
    for derivative in func['derivatives']:
        formula = derivative['formula']
        if re.search(IDENT_REGEX.format(ident), formula):
            return True
    return False
def uses_retain_variables(func):
    """Return True if any derivative formula references `retain_variables`."""
    return uses_ident(func, 'retain_variables')
def uses_single_grad(func):
    """Return True if any derivative formula references the single `grad`."""
    return uses_ident(func, 'grad')