# Caffe2 - Python API
# A deep learning, cross-platform ML framework
# load_derivatives.py
1 # Parses derivatives.yaml into autograd functions
2 #
3 # Each autograd function is represented by dictionary containing a list of
4 # derivatives (also a dictionary). See `create_autograd_function` and
5 # `create_derivative` for the keys.
6 from collections import defaultdict
7 import copy
8 import re
9 import yaml
10 from .utils import YamlLoader
11 from .utils import IDENT_REGEX, split_name_params
12 
13 
def load_derivatives(path, declarations):
    """Parse derivatives.yaml at `path` into a list of autograd functions.

    Also attaches each resulting autograd function to its matching entry in
    `declarations` (under the 'derivative' key) and de-duplicates op names.
    """
    with open(path, 'r') as f:
        definitions = yaml.load(f, Loader=YamlLoader)

    # Group the ATen declarations by their signature so each derivatives.yaml
    # entry can be matched against the right overload set.
    declarations_by_signature = defaultdict(list)
    for declaration in declarations:
        signature = get_signature(declaration)
        declarations_by_signature[signature].append(declaration)

    autograd_functions = []
    for defn in definitions:
        autograd_functions.append(process_definition(defn, declarations_by_signature))

    ensure_unique_names(autograd_functions)
    match_declarations_with_autograd_functions(declarations, autograd_functions)

    return autograd_functions
29 
30 
# NOTE: the full ATen declaration is pasted into the autograd function dict;
# downstream codegen reads it directly.
def create_autograd_function(name, derivatives, args_with_derivatives, not_differentiable_args_names,
                             signature, declaration, output_differentiability):
    """Build the dictionary that represents one autograd Function.

    The 'op' entry is the generated C++ class name, e.g. 'add' -> 'AddBackward'
    (a trailing 'Forward' in the name is folded into 'Backward').
    """
    op_name = to_camel_case(name) + 'Backward'
    op_name = op_name.replace('ForwardBackward', 'Backward')
    return {
        'name': name,
        'op': op_name,
        'declaration': declaration,
        'args_with_derivatives': args_with_derivatives,
        'not_differentiable_args_names': not_differentiable_args_names,
        'signature': signature,
        'derivatives': derivatives,
        'saved_inputs': all_saved_variables(derivatives, 'saved_inputs'),
        'saved_outputs': all_saved_variables(derivatives, 'saved_outputs'),
        'output_differentiability': output_differentiability,
    }
48 
49 
def create_derivative(arguments, returns, name, formula, var_names):
    """Create the dictionary describing a single derivative formula.

    Rewrites `formula` in terms of saved inputs/outputs and validates that
    every grads[i] reference is within the number of forward outputs.
    """
    def rename_self_return(ret):
        # In-place functions take in and return self. Call the modified version
        # "output" so that it can be referred to in derivative definitions.
        if ret['name'] != 'self':
            return ret
        renamed = copy.deepcopy(ret)
        renamed['name'] = 'result'
        return renamed

    transformed_returns = [rename_self_return(r) for r in returns]
    formula, saved_inputs = saved_variables(formula, arguments)
    formula, saved_outputs = saved_variables(formula, transformed_returns)

    # Check that the referenced derivatives in the formula are in bounds
    for index in used_gradient_indices(formula):
        if index < len(transformed_returns):
            continue
        raise RuntimeError(
            "Out of bounds grads access: derivative formula for {} "
            "used grads[{}], but the forward only returns {} outputs."
            .format(name, index, len(transformed_returns)))

    return {
        'formula': formula,
        'saved_inputs': saved_inputs,
        'saved_outputs': saved_outputs,
        'var_names': var_names,
    }
77 
78 
def process_definition(defn, declarations_by_signature):
    """Processes a single entry `defn` in derivatives.yaml

    Looks up the matching ATen declaration(s) by signature, validates the
    argument names against Declarations.yaml, sets up the per-argument
    derivative formulas, and returns the autograd-function dictionary built
    by `create_autograd_function`.

    NB: consumes (pops) the 'name' and 'output_differentiability' keys from
    `defn`; the remaining keys are argument-name -> formula pairs.
    """

    def canonical_declaration(declarations, name):
        # Prefer the declaration whose name matches exactly.
        for declaration in declarations:
            if declaration['name'] == name:
                return declaration
        # some functions only have in-place variants
        assert name + '_' == declarations[0]['name']
        return declarations[0]

    def split_names(raw_names):
        """Given "foo, bar", return ["foo", "bar"]."""
        return [x.strip() for x in raw_names.split(',')]

    def check_grad_usage(defn_name, declaration, derivatives):
        """
        Check for some subtle mistakes one might make when writing derivatives.
        These mistakes will compile, but will be latent until a function is
        used with double backwards.
        """
        used_grad = 0
        used_grads = 0
        used_grads_indices = []
        for d in derivatives:
            formula = d['formula']
            used_grad += len(re.findall(IDENT_REGEX.format('grad'), formula))
            used_grads += len(re.findall(IDENT_REGEX.format('grads'), formula))
            used_grads_indices.extend(used_gradient_indices(formula))
        # Every grads[i] access is also an occurrence of the 'grads' identifier.
        assert used_grads >= len(used_grads_indices)
        only_used_grads_indices = used_grads == len(used_grads_indices)

        if used_grad and used_grads:
            raise RuntimeError("Derivative definition of {} in derivatives.yaml illegally "
                               "mixes use of 'grad' and 'grads'. Consider replacing "
                               "occurrences of 'grad' with 'grads[0]'".format(defn_name))

        if only_used_grads_indices and set(used_grads_indices) == {0}:
            raise RuntimeError("Derivative definition of {} in derivatives.yaml solely "
                               "refers to 'grads[0]'. If the first output is indeed the "
                               "only differentiable output, replace 'grads[0]' with 'grad'; "
                               "otherwise, there is a likely error in your derivatives "
                               "declaration.".format(defn_name))

    def set_up_derivatives(defn_name, defn, declaration):
        """Build (derivatives, args_with_derivatives, not_differentiable_args_names)."""
        # Determine the set of inputs which have derivatives
        args_with_derivatives_set = set()
        for raw_names in defn:
            args_with_derivatives_set |= set(split_names(raw_names))

        # Next, let us determine the list of inputs in declaration order.
        args_with_derivatives = [arg for arg in declaration['arguments']
                                 if arg['name'] in args_with_derivatives_set]

        # Set up the derivative information; iterate deterministically.
        derivatives = []
        not_differentiable_args_names = []
        for raw_names in sorted(defn.keys()):
            formula = defn[raw_names]
            names = split_names(raw_names)
            derivative = create_derivative(declaration['arguments'], declaration['returns'],
                                           declaration['name'], formula, names)
            if formula.lower().strip() == 'not_differentiable':
                # 'not_differentiable' marks arguments, so the names must be flat.
                assert not sum([type(var_name) == list
                                for var_name in derivative['var_names']]), \
                    "Variable names associated to a formula should be a flat list"
                not_differentiable_args_names += derivative['var_names']
            else:
                derivatives.append(derivative)
        # Arguments explicitly marked not differentiable get no derivative slot.
        args_with_derivatives = [arg for arg in args_with_derivatives
                                 if arg['name'] not in not_differentiable_args_names]

        # Test to see if the use of 'grads' makes sense.
        check_grad_usage(defn_name, declaration, derivatives)

        return derivatives, args_with_derivatives, not_differentiable_args_names

    def unzip(xs):
        return zip(*xs)

    # NB: Removes 'name' from defn dictionary
    defn_name, params = split_name_params(defn.pop('name'))
    # NB: Removes 'output_differentiability' from defn dictionary
    # `None` means all differentiable.
    output_differentiability = defn.pop('output_differentiability', None)
    # Each param is "<type> <name>"; '*' is the keyword-only marker.
    param_types, param_names = unzip([p.split(' ') for p in params if p != '*'])

    if 'grad_input_mask' in param_names:
        raise RuntimeError("Signature for {} has an argument named grad_input_mask, "
                           "but this name would be shadowed by our codegen. "
                           "Please use a different name in Declarations.cwrap."
                           .format(defn_name))
    signature = '{}({})'.format(defn_name, ', '.join(param_types))

    declarations = declarations_by_signature[signature]
    if len(declarations) == 0:
        avail = [k for k, v in declarations_by_signature.items()
                 if k.startswith(defn_name + '(') and len(v) > 0]
        raise RuntimeError('no ATen declaration found for: {}. '
                           'Available signatures: {}'.format(signature, ', '.join(avail)))
    canonical = canonical_declaration(declarations, defn_name)

    # TODO: Check the types line up
    if len(param_names) != len(canonical['args']):
        raise RuntimeError('Signature for {} has {} arguments ({}), but '
                           'Declarations.yaml records {} arguments ({})'
                           .format(defn_name,
                                   len(param_names),
                                   ', '.join(param_names),
                                   len(canonical['args']),
                                   ', '.join(canonical['args'])))
    for i, (x, y) in enumerate(zip(param_names, canonical['args'])):
        if x != y:
            raise RuntimeError('Argument {} of {} has different names in '
                               'derivatives.yaml ({}) and '
                               'Declarations.yaml ({})'
                               .format(i, defn_name, x, y))

    derivatives, args_with_derivatives, not_differentiable_args_names = \
        set_up_derivatives(defn_name, defn, canonical)
    return create_autograd_function(defn_name, derivatives, args_with_derivatives,
                                    not_differentiable_args_names, signature, canonical,
                                    output_differentiability)
212 
213 
def ensure_unique_names(autograd_functions):
    """De-duplicate generated op names in place.

    When several overloads share an op name, each gets a numeric suffix,
    one per overload, e.g.:
        AddBackward0
        AddBackward1
    """
    grouped = defaultdict(list)
    for fn in autograd_functions:
        grouped[fn['op']].append(fn)
    for overloads in grouped.values():
        if len(overloads) <= 1:
            continue
        for index, fn in enumerate(overloads):
            fn['op'] = fn['op'] + str(index)
228 
229 
def get_signature(declaration, use_base_variant=False):
    """Return the signature string 'name(Type1, Type2, ...)' for a declaration.

    With use_base_variant=True, strip the in-place trailing '_' or the
    '_out' suffix from the name; for '_out' variants, output arguments are
    also dropped from the signature.
    """
    name = declaration['name']
    arguments = declaration['arguments']
    if use_base_variant:
        if declaration['inplace']:
            assert name.endswith('_')
            name = name[:-1]
        elif name.endswith('_out'):
            name = name[:-len('_out')]
            arguments = [a for a in arguments if not a.get('output', False)]
    joined_types = ', '.join(a['simple_type'] for a in arguments)
    return '{}({})'.format(name, joined_types)
242 
243 
GRAD_INDEX_REGEX = r'(?:^|\W)grads\[(\d+)\]'


def used_gradient_indices(formula):
    """Determine a list of gradient indices (the i in grads[i]) that
    are used by the formula.

    >>> used_gradient_indices("foo(grads[0], grads[1])")
    [0, 1]
    """
    return list(map(int, re.findall(GRAD_INDEX_REGEX, formula)))
255 
256 
def saved_variables(formula, args):
    """Rewrite `formula` in terms of values saved for the backwards pass.

    Sub-expressions that can be evaluated when the autograd Function is
    created (e.g. ``self.sizes()``) are replaced by references to cheap
    saved values (e.g. ``self_sizes``), so the full variable need not be
    saved.  Any argument that still appears in the formula afterwards is
    saved whole.

    Returns a pair ``(formula, saved)`` where ``saved`` is a list of dicts
    describing what must be saved ('name', 'type', and optionally 'expr').
    """
    # find which arguments need to be saved
    saved = []

    # Each entry maps a regex template (formatted with the argument name) to
    # the saved value that replaces it.  The dot before each method name is
    # escaped so that e.g. 'fooXsizes()' cannot accidentally match the
    # pattern intended for 'foo.sizes()'.
    REPLACEMENTS = [
        # replace self.sizes() with self_sizes
        (r'{}\.sizes\(\)', {
            'suffix': '_sizes',
            'type': 'IntArrayRef',
        }),
        # replace zeros_like(self) with self_info
        (r'zeros_like\({}\)', {
            'suffix': '_info',
            'type': 'TypeAndSize',
            'expr': lambda name: name,  # at save-time
            'res': lambda name: name + '_info.zeros()',  # at eval-time
        }),
        # replace self.size(2) with self_size_2
        (r'{}\.size\((\w+)\)', {
            'suffix': lambda m: '_argsize_{}'.format(*m.groups()),
            'type': 'int64_t',
        }),
        # replace self.numel() with self_numel
        (r'{}\.numel\(\)', {
            'suffix': '_numel',
            'type': 'int64_t',
        }),
        # replace to_args_sizes(self) with self_args_sizes
        (r'to_args_sizes\({}\)', {
            'suffix': '_args_sizes',
            'type': 'std::vector<std::vector<int64_t>>',
        }),
        # replace TensorGeometry(self) with self_geometry
        (r'TensorGeometry\({}\)', {
            'suffix': '_geometry',
            'type': 'TensorGeometry',
        }),
    ]

    for arg in args:
        if 'name' not in arg:
            # some returned arguments do not have names
            continue

        name = arg['name']

        # First search the formula for expressions which can be evaluated
        # when the autograd Function is created to avoid saving variables
        for regex, info in REPLACEMENTS:
            # NB: `repl` closes over `name` and `info`, but it is only
            # invoked by the re.sub call below within this same iteration,
            # so the late binding is safe.
            def repl(m):
                suffix = info['suffix']
                suffix = suffix(m) if callable(suffix) else suffix
                expr = info['expr'](name) if 'expr' in info else m.group(0)
                saved.append({
                    'name': name + suffix,
                    'type': info['type'],
                    'expr': expr,
                })
                if 'res' in info:
                    return info['res'](name)
                return name + suffix

            formula = re.sub(regex.format(name), repl, formula)

        # Find any variables which remain in the formula and save them
        if re.search(IDENT_REGEX.format(name), formula):
            arg = copy.deepcopy(arg)
            # Strip C++ const/reference decoration: the saved copy is by value.
            arg['type'] = arg['type'].replace('const ', '').replace(' &', '')
            saved.append(arg)

    return formula, saved
328 
329 
def all_saved_variables(derivatives, key):
    """Collect the saved variables stored under `key` across all derivatives.

    De-duplicates by variable name; the first occurrence wins and original
    order is preserved.
    """
    result = []
    seen_names = set()
    for derivative in derivatives:
        for var in derivative[key]:
            name = var['name']
            if name not in seen_names:
                seen_names.add(name)
                result.append(var)
    return result
340 
341 
def to_camel_case(name):
    """Convert snake_case to CamelCase, e.g. 'add_backward' -> 'AddBackward'."""
    # NB: str.title (not str.capitalize) is intentional: it also uppercases
    # letters after digits, matching the original behavior.
    return ''.join(piece.title() for piece in name.split('_'))
344 
345 
def match_declarations_with_autograd_functions(declarations, autograd_functions):
    """Sets the "derivative" key on declarations to matching autograd functions

    In-place functions will use the out-of-place derivative definition if there
    is no in-place specific derivative.
    """

    by_signature = {fn['signature']: fn for fn in autograd_functions}

    def lookup(declaration):
        exact = by_signature.get(get_signature(declaration))
        if exact is not None:
            return exact
        # if there is no exact match look for the out-of-place signature.
        # i.e mul() for mul_() or mul_out()
        base_signature = get_signature(declaration, use_base_variant=True)
        return by_signature.get(base_signature)

    for declaration in declarations:
        declaration['derivative'] = lookup(declaration)