import copy
import re

import yaml

from collections import defaultdict

from .utils import IDENT_REGEX, YamlLoader, split_name_params
def load_derivatives(path, declarations):
    """Parse derivatives.yaml at `path` and match entries against `declarations`.

    Returns the list of autograd function dicts built by `process_definition`,
    after de-duplicating their op names and attaching each one to its matching
    declaration (via the 'derivative' key).
    """
    with open(path, 'r') as f:
        definitions = yaml.load(f, Loader=YamlLoader)

    # Group declarations by their type signature so each YAML entry can be
    # resolved to the set of overloads it applies to.
    declarations_by_signature = defaultdict(list)
    for declaration in declarations:
        declarations_by_signature[get_signature(declaration)].append(declaration)

    autograd_functions = [process_definition(defn, declarations_by_signature)
                          for defn in definitions]
    ensure_unique_names(autograd_functions)
    match_declarations_with_autograd_functions(declarations, autograd_functions)

    return autograd_functions
def create_autograd_function(name, derivatives, args_with_derivatives, not_differentiable_args_names,
                             signature, declaration, output_differentiability):
    """Assemble the dict describing one autograd Function for code generation.

    `op` is the generated C++ class name, e.g. 'add' -> 'AddBackward'; the
    'ForwardBackward' -> 'Backward' replacement collapses names of functions
    that already end in '_forward'.
    """
    op = to_camel_case(name) + 'Backward'
    op = op.replace('ForwardBackward', 'Backward')
    return {
        'name': name,
        'op': op,
        'declaration': declaration,
        'args_with_derivatives': args_with_derivatives,
        'not_differentiable_args_names': not_differentiable_args_names,
        'signature': signature,
        'derivatives': derivatives,
        # De-duplicated variables saved for the backward pass, split by
        # whether they come from forward inputs or forward outputs.
        'saved_inputs': all_saved_variables(derivatives, 'saved_inputs'),
        'saved_outputs': all_saved_variables(derivatives, 'saved_outputs'),
        'output_differentiability': output_differentiability,
    }
def create_derivative(arguments, returns, name, formula, var_names):
    """Build the derivative dict for one formula in a derivatives.yaml entry.

    Rewrites the formula so that saved variables are referenced by their
    saved names, and validates that every grads[i] access is in bounds.
    """
    def transform_return(r):
        # In-place functions take in and return self; the derivative formula
        # refers to the modified value as "output" rather than "self".
        r = copy.deepcopy(r)
        if r['name'] == 'self':
            r['name'] = 'output'
        return r

    returns = [transform_return(r) for r in returns]
    formula, saved_inputs = saved_variables(formula, arguments)
    formula, saved_outputs = saved_variables(formula, returns)

    # Check that the gradient indices referenced by the formula are in bounds.
    for i in used_gradient_indices(formula):
        if i >= len(returns):
            raise RuntimeError(
                "Out of bounds grads access: derivative formula for {} "
                "used grads[{}], but the forward only returns {} outputs."
                .format(name, i, len(returns)))

    return {
        'formula': formula,
        'saved_inputs': saved_inputs,
        'saved_outputs': saved_outputs,
        'var_names': var_names,
    }
def process_definition(defn, declarations_by_signature):
    """Processes a single entry `defn` in derivatives.yaml"""

    def canonical_declaration(declarations, name):
        # Prefer the declaration whose name matches exactly; some functions
        # only have an in-place variant, in which case that is canonical.
        for declaration in declarations:
            if declaration['name'] == name:
                return declaration
        assert name + '_' == declarations[0]['name']
        return declarations[0]

    def split_names(raw_names):
        """Given "foo, bar", return ["foo", "bar"]."""
        return [x.strip() for x in raw_names.split(',')]

    def lookup_pred(pred, xs):
        """Return the index of the first element of xs matching pred."""
        return next((i, x) for i, x in enumerate(xs) if pred(x))

    def check_grad_usage(defn_name, declaration, derivatives):
        """Check for some subtle mistakes one might make when writing
        derivatives.  These mistakes will compile, but will be latent until
        a function is used with double backwards.
        """
        used_grad = 0
        used_grads = 0
        fully_implemented = True
        used_grads_indices = []
        for d in derivatives:
            formula = d['formula']
            used_grad += len(re.findall(IDENT_REGEX.format('grad'), formula))
            used_grads += len(re.findall(IDENT_REGEX.format('grads'), formula))
            fully_implemented = \
                fully_implemented and \
                not re.search(IDENT_REGEX.format('not_implemented'), formula)
            used_grads_indices.extend(used_gradient_indices(formula))
        # Every grads[i] access is also counted as a use of 'grads'.
        assert used_grads >= len(used_grads_indices)
        only_used_grads_indices = used_grads == len(used_grads_indices)

        if used_grad and used_grads:
            raise RuntimeError(
                "Derivative definition of {} in derivatives.yaml illegally "
                "mixes use of 'grad' and 'grads'. Consider replacing "
                "occurrences of 'grad' with 'grads[0]'".format(defn_name))

        # NOTE(review): fully_implemented is computed but not referenced by
        # this check — confirm whether it should gate the condition below.
        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError(
                "Derivative definition of {} in derivatives.yaml solely "
                "refers to 'grads[0]'. If the first output is indeed the "
                "only differentiable output, replace 'grads[0]' with 'grad'; "
                "otherwise, there is a likely error in your derivatives "
                "declaration.".format(defn_name))

    def set_up_derivatives(defn_name, defn, declaration):
        # Determine the set of argument names that have derivative formulas.
        args_with_derivatives_set = set()
        for raw_names in defn:
            args_with_derivatives_set |= set(split_names(raw_names))

        # Collect those arguments in declaration order.
        args_with_derivatives = []
        for arg in declaration['arguments']:
            if arg['name'] not in args_with_derivatives_set:
                continue
            args_with_derivatives.append(arg)

        # Build the derivative info for each formula, separating out
        # arguments explicitly marked as not differentiable.
        derivatives = []
        not_differentiable_args_names = []
        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)
            derivative = create_derivative(declaration['arguments'], declaration['returns'],
                                           declaration['name'], formula, names)
            if formula.lower().strip() == 'not_differentiable':
                assert not sum([type(var_name) == list
                                for var_name in derivative['var_names']]), \
                    "Variable names associated to a formula should be a flat list"
                not_differentiable_args_names += derivative['var_names']
            else:
                derivatives.append(derivative)
        args_with_derivatives = list(filter(lambda x: x['name'] not in not_differentiable_args_names,
                                            args_with_derivatives))

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, declaration, derivatives)

        return derivatives, args_with_derivatives, not_differentiable_args_names

    # NB: pop() removes 'name' and 'output_differentiability' from defn, so
    # the remaining keys are exactly the derivative formulas.
    defn_name, params = split_name_params(defn.pop('name'))
    output_differentiability = defn.pop('output_differentiability', None)
    param_types, param_names = unzip([p.split(' ') for p in params if p != '*'])
    if 'grad_input_mask' in param_names:
        raise RuntimeError(
            "Signature for {} has an argument named grad_input_mask, "
            "but this name would be shadowed by our codegen. "
            "Please use a different name in Declarations.cwrap."
            .format(defn_name))
    signature = '{}({})'.format(defn_name, ', '.join(param_types))

    declarations = declarations_by_signature[signature]
    if len(declarations) == 0:
        avail = [k for k, v in declarations_by_signature.items()
                 if k.startswith(defn_name + '(') and len(v) > 0]
        raise RuntimeError(
            'no ATen declaration found for: {}. '
            'Available signatures: {}'.format(signature, ', '.join(avail)))
    canonical = canonical_declaration(declarations, defn_name)

    # Check that the argument names in derivatives.yaml line up with the
    # canonical declaration, position by position.
    if len(param_names) != len(canonical['args']):
        raise RuntimeError(
            'Signature for {} has {} arguments ({}), but '
            'Declarations.yaml records {} arguments ({})'
            .format(defn_name,
                    len(param_names),
                    ', '.join(param_names),
                    len(canonical['args']),
                    ', '.join(canonical['args'])))
    for i, (x, y) in enumerate(zip(param_names, canonical['args'])):
        if x != y:
            raise RuntimeError(
                'Argument {} of {} has different names in '
                'derivatives.yaml ({}) and '
                'Declarations.yaml ({})'
                .format(i, defn_name, x, y))

    derivatives, args_with_derivatives, not_differentiable_args_names = set_up_derivatives(defn_name, defn, canonical)
    return create_autograd_function(defn_name, derivatives, args_with_derivatives, not_differentiable_args_names,
                                    signature, canonical, output_differentiability)
def ensure_unique_names(autograd_functions):
    """De-duplicate op names in place by suffixing overloads with an index.

    When several autograd functions share an op name you end up with e.g.
    AddBackward0, AddBackward1 — one per overload. Unique names are left
    untouched.
    """
    functions_by_name = defaultdict(list)
    for func in autograd_functions:
        functions_by_name[func['op']].append(func)
    for op in functions_by_name.keys():
        overloads = functions_by_name[op]
        if len(overloads) > 1:
            for i, func in enumerate(overloads):
                # Mutate the shared 'op' entry; indices follow insertion order.
                func['op'] = '{}{}'.format(func['op'], i)
def get_signature(declaration, use_base_variant=False):
    """Return 'name(SimpleType, ...)' for a declaration.

    With use_base_variant=True, an in-place ('foo_') or out ('foo_out')
    variant is mapped to the signature of its base function 'foo' — output
    arguments are dropped for the '_out' case — so in-place/out declarations
    can be matched against an out-of-place derivative definition.
    """
    name = declaration['name']
    arguments = declaration['arguments']
    if use_base_variant:
        if declaration['inplace']:
            assert name.endswith('_')
            name = name[:-1]
        elif name.endswith('_out'):
            name = name[:-4]
            arguments = [arg for arg in arguments if not arg.get('output', False)]
    simple_types = [arg['simple_type'] for arg in arguments]
    return '{}({})'.format(name, ', '.join(simple_types))
# Matches grads[i] accesses that are not part of a longer identifier;
# the capture group is the integer index.
GRAD_INDEX_REGEX = r'(?:^|\W)grads\[(\d+)\]'


def used_gradient_indices(formula):
    """Determine a list of gradient indices (the i in grads[i]) that
    are used by the formula.

    >>> used_gradient_indices("foo(grads[0], grads[1])")
    [0, 1]
    """
    return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)]
def saved_variables(formula, args):
    """Rewrite `formula` to reference saved values and collect what to save.

    For each named arg, expressions like `self.sizes()` are replaced by a
    cheaper saved quantity (`self_sizes`); any arg still referenced by name
    afterwards is saved as a whole. Returns (new_formula, saved_list).
    """
    # Each entry maps a regex (with `{}` for the arg name) to how the matched
    # expression is saved: the suffix appended to the arg name, the C++ type
    # of the saved value, and optional save-time ('expr') / eval-time ('res')
    # transformations.
    REPLACEMENTS = [
        # replace self.sizes() with self_sizes
        (r'{}.sizes\(\)', {
            'suffix': '_sizes',
            'type': 'IntArrayRef',
        }),
        # replace zeros_like(self) with self_info
        (r'zeros_like\({}\)', {
            'suffix': '_info',
            'type': 'TypeAndSize',
            'expr': lambda name: name,  # at save-time
            'res': lambda name: name + '_info.zeros()',  # at eval-time
        }),
        # replace self.size(2) with self_argsize_2
        (r'{}.size\((\w+)\)', {
            'suffix': lambda m: '_argsize_{}'.format(*m.groups()),
            'type': 'int64_t',
        }),
        # replace to_args_sizes(self) with self_args_sizes
        (r'to_args_sizes\({}\)', {
            'suffix': '_args_sizes',
            'type': 'std::vector<std::vector<int64_t>>',
        }),
        # replace TensorGeometry(self) with self_geometry
        (r'TensorGeometry\({}\)', {
            'suffix': '_geometry',
            'type': 'TensorGeometry',
        }),
    ]

    saved = []
    for arg in args:
        if 'name' not in arg:
            # some returned arguments do not have names
            continue

        name = arg['name']

        # First replace any expressions that can be saved more cheaply than
        # the whole variable.
        for regex, info in REPLACEMENTS:
            def repl(m):
                suffix = info['suffix']
                suffix = suffix(m) if callable(suffix) else suffix
                expr = info['expr'](name) if 'expr' in info else m.group(0)
                saved.append({
                    'name': name + suffix,
                    'type': info['type'],
                    'expr': expr,
                })
                if 'res' in info:
                    return info['res'](name)
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # If the formula still mentions the variable itself, save it whole,
        # stripping reference/const qualifiers from the declared type.
        if re.search(IDENT_REGEX.format(name), formula):
            arg = copy.deepcopy(arg)
            arg['type'] = arg['type'].replace('const ', '').replace(' &', '')
            saved.append(arg)

    return formula, saved
def all_saved_variables(derivatives, key):
    """Collect d[key] entries across all derivatives, de-duplicated by 'name'.

    Order follows first occurrence; `key` is e.g. 'saved_inputs' or
    'saved_outputs'.
    """
    seen = set()
    saved = []
    for d in derivatives:
        for saved_arg in d[key]:
            if saved_arg['name'] in seen:
                continue
            seen.add(saved_arg['name'])
            saved.append(saved_arg)
    return saved
def to_camel_case(name):
    """Convert a snake_case name to CamelCase, e.g. 'add_out' -> 'AddOut'."""
    return ''.join(map(str.title, name.split('_')))
def match_declarations_with_autograd_functions(declarations, autograd_functions):
    """Sets the "derivative" key on declarations to matching autograd functions

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """
    functions_by_signature = {f['signature']: f for f in autograd_functions}

    def find_function(declaration):
        signature = get_signature(declaration)
        if signature in functions_by_signature:
            return functions_by_signature[signature]

        # No exact match: fall back to the out-of-place signature,
        # i.e. mul() for mul_() or mul_out().  May return None.
        signature = get_signature(declaration, use_base_variant=True)
        return functions_by_signature.get(signature)

    for declaration in declarations:
        declaration['derivative'] = find_function(declaration)