Caffe2 - Python API
A deep learning, cross-platform ML framework
auto.py
from itertools import repeat
from collections import defaultdict

import torch
from torch._thnn.utils import parse_header, THNN_H_PATH
from torch.autograd.function import Function, InplaceFunction, once_differentiable
from torch._thnn import type2backend
from .auto_double_backwards import double_backwards_fns
from .auto_symbolic import symbolic_fns

from . import _all_functions

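# Everything below is driven by the parsed THNN header: parse_header(THNN_H_PATH)
# returns one description per THNN entry point, each exposing .name and an
# .arguments list whose entries carry .name, .type and .is_optional.  A hedged
# sketch of the shape this module relies on (the concrete field values are
# illustrative and depend on the THNN version being parsed):
#
#     functions = parse_header(THNN_H_PATH)
#     mse = next(fn for fn in functions if fn.name == 'MSECriterion_updateOutput')
#     [a.name for a in mse.arguments]       # e.g. ['state', 'input', 'target', 'output', 'sizeAverage', 'reduce']
#     [a.type for a in mse.arguments][:4]   # e.g. ['THNNState*', 'THTensor*', 'THTensor*', 'THTensor*']
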
def _make_function_class_criterion(class_name, update_output, update_grad_input, acc_grad_parameters,
                                   double_backwards_fn, symbolic_fn):
    """Create the autograd Function class (and its backward class) for a THNN criterion (loss)."""
    weight_arg_idx = -1
    for i, arg in enumerate(update_output.arguments):
        if arg.name.startswith('weight'):
            weight_arg_idx = i
            break

    reduce_arg_idx = -1
    for i, arg in enumerate(update_output.arguments):
        if arg.name == 'reduce':
            reduce_arg_idx = i
            break

    buffers_idx = []
    additional_arg_idx = 0
    for arg in update_output.arguments[4:]:
        if not arg.name.startswith('weight') and arg.type == 'THTensor*':
            buffers_idx.append(additional_arg_idx)
        additional_arg_idx += 1

    @staticmethod
    def symbolic(*args, **kwargs):
        a = symbolic_fn(*args, **kwargs)
        return a

    @staticmethod
    def forward(ctx, input, target, *args):
        ctx._backend = type2backend[input.type()]
        ctx.save_for_backward(input, target)
        if weight_arg_idx >= 0:
            ctx.weight = args[0]
            args = args[1:]
            ctx.additional_args = list(args)
            insert_idx = weight_arg_idx - 4  # state, input, target, output
            ctx.additional_args.insert(insert_idx, ctx.weight)
        else:
            ctx.additional_args = list(args)

        ctx.forward_args_count = len(ctx.additional_args)
        for idx in buffers_idx:
            ctx.additional_args.insert(idx, input.new(1))
        output = input.new(1)
        getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, target,
                                                  output, *ctx.additional_args)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # apply returns grad_input, so we need to return Nones for target (1) plus one
        # for each extra arg passed to forward.
        return ((backward_cls.apply(input, target, grad_output, ctx.additional_args, ctx._backend),) +
                (None,) * (ctx.forward_args_count + 1))

    @staticmethod
    def backward_cls_forward(ctx, input, target, grad_output, additional_args_ctx, backend_ctx):
        ctx.additional_args = additional_args_ctx
        ctx._backend = backend_ctx
        ctx.save_for_backward(input, target, grad_output)
        grad_input = grad_output.new().resize_as_(input).zero_()

        if reduce_arg_idx >= 0:
            getattr(ctx._backend, update_grad_input.name)(ctx._backend.library_state, input, target,
                                                          grad_output, grad_input, *ctx.additional_args)
            return grad_input

        getattr(ctx._backend, update_grad_input.name)(ctx._backend.library_state, input, target,
                                                      grad_input, *ctx.additional_args)
        grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
        grad_input.mul_(grad_output_expanded.expand_as(grad_input))
        return grad_input

    @staticmethod
    def backward_cls_backward(ctx, *grad_params):
        return double_backwards_fn(ctx, *grad_params)

    backward_cls = type(class_name + "Backward", (Function,),
                        dict(forward=backward_cls_forward, backward=backward_cls_backward))
    return type(class_name, (Function,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls

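# A hedged sketch of what the criterion factory yields: two torch.autograd.Function
# subclasses created with type(), the user-facing loss plus the backward class that
# makes double backward possible.  Names and the trailing flag arguments below are
# illustrative; they must line up with the parsed *_updateOutput signature:
#
#     MSELoss, MSELossBackward = _make_function_class_criterion(
#         'MSELoss',
#         function_by_name['MSECriterion_updateOutput'],
#         function_by_name['MSECriterion_updateGradInput'],
#         None, double_backwards_fn, symbolic_fn)
#     loss = MSELoss.apply(input, target, True, True)   # e.g. (sizeAverage, reduce)
#     loss.backward()                                   # dispatches MSELossBackward.apply
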
def _find_buffers(args, ignored_args):
    """Return (position, name) pairs for THTensor* buffer arguments, skipping ``ignored_args``."""
    additional_arg_idx = 0
    buffers = []
    for arg in args:
        if arg.name in ignored_args:
            continue
        if arg.type == 'THTensor*':
            buffers.append((additional_arg_idx, arg.name))
        additional_arg_idx += 1
    return buffers

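# Worked example for _find_buffers (the argument descriptions are made up for
# illustration): only non-ignored arguments advance the position counter, and
# only THTensor* arguments are reported as buffers.
#
#     args:  weight (THTensor*), columns (THTensor*), scale (accreal), ones (THTensor*)
#     _find_buffers(args, {'weight'})  ->  [(0, 'columns'), (2, 'ones')]
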
def _make_function_class(class_name, update_output, update_grad_input, acc_grad_parameters,
                         double_backwards_fn, symbolic_fn):
    """Create the autograd Function class (and its backward class) for a non-criterion THNN module."""
    def has_argument(fn, name):
        for arg in fn.arguments:
            if arg.name == name:
                return True
        return False
    save_output = has_argument(update_grad_input, 'output')

    param_args = {'weight', 'bias'}
    ignored_args = {'weight', 'bias', 'gradWeight', 'gradBias', 'output'}
    expected_params = [arg for arg in update_output.arguments[3:]
                       if arg.name in param_args]
    buffers = {}
    buffers['update_output'] = _find_buffers(update_output.arguments[3:],
                                             ignored_args)
    buffers['update_grad_input'] = _find_buffers(
        update_grad_input.arguments[4:], ignored_args)
    if acc_grad_parameters is not None:
        buffers['acc_grad_parameters'] = _find_buffers(
            acc_grad_parameters.arguments[3:], ignored_args)

    # This assumes that only the last argument can be
    # an inplace flag
    is_inplace = update_output.arguments[-1].name == 'inplace'

    def _initialize_buffers(ctx, fn_name):
        additional_args = ctx.additional_args
        for idx, name in buffers[fn_name]:
            # TODO: some buffers are necessary only for update output and can be
            # freed right afterwards
            buffer = ctx.buffers[name]
            additional_args = additional_args[:idx] + [buffer] + additional_args[idx:]
        return tuple(additional_args)

    @staticmethod
    def symbolic(*args, **kwargs):
        return symbolic_fn(*args, **kwargs)

    @staticmethod
    def forward(ctx, input, *params):
        ctx._backend = type2backend[input.type()]

        ctx.additional_args = []
        tensor_param_list = []
        for param in params:
            if isinstance(param, torch.Tensor):
                if type(param) != type(input):
                    raise RuntimeError("input type ({}) doesn't match the type of "
                                       "a parameter tensor ({})".format(torch.typename(input),
                                                                        torch.typename(param)))
                tensor_param_list.append(param)
            else:
                ctx.additional_args.append(param)

        tensor_params = tuple(tensor_param_list)
        if is_inplace:
            ctx.inplace = params[-1]
        # Allocate temporary buffers and insert them into additional_args
        ctx.buffers = defaultdict(type(input))
        additional_args = _initialize_buffers(ctx, 'update_output')

        # Fill in optional params with None
        args = tensor_params
        for i in range(len(params), len(expected_params)):
            param = expected_params[i]
            if param.is_optional:
                args += (None,)
            else:
                raise ValueError("missing required argument '%s'" % param.name)

        args += tuple(additional_args)

        # If the module is working in-place its output will be set to the
        # same storage as input, but its tensor won't be dirty.
        if is_inplace and ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.new()

        if save_output:
            ctx.save_for_backward(input, output, *tensor_params)
        else:
            ctx.save_for_backward(input, *tensor_params)

        if not ctx.requires_grad:
            del ctx.buffers

        getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, output, *args)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        t = ctx.saved_tensors
        input, tensor_params = t[0], t[1:]
        # Some notes on this function call:
        # 1) We need to pass params as *params so they are unwrapped correctly in backward_cls_forward.
        # 2) apply returns the grad_input / grad_tensor_params, so we need to append Nones equal to the
        #    number of non tensor_params, i.e. the additional_args
        # 3) it may be simpler to recalculate some of these parameters (e.g. ctx._backend) in backward_cls_forward?
        return (backward_cls.apply(input, grad_output, ctx.additional_args, ctx._backend, ctx.buffers, *tensor_params) +
                (None,) * len(ctx.additional_args))

    @staticmethod
    def backward_cls_forward(ctx, input, grad_output, additional_args_ctx, backend_ctx, buffers_ctx, *params):
        ctx.additional_args = additional_args_ctx
        ctx.buffers = buffers_ctx
        ctx._backend = backend_ctx
        ctx.save_for_backward(input, grad_output, *params)
        if save_output:
            output = params[0]
            params = params[1:]

        grad_params = tuple(None for p in params)
        grad_input_tuple = (None,)
        if is_inplace:
            ctx.inplace = additional_args_ctx[-1]

        if ctx.needs_input_grad[0]:
            additional_args = _initialize_buffers(ctx, 'update_grad_input')
            if save_output:
                additional_args = (output,) + additional_args

            if is_inplace and ctx.inplace:
                assert additional_args[-1] is True
                tmp_args = list(additional_args)
                tmp_args[-1] = False
                additional_args = tuple(tmp_args)
            grad_input = input.new(input.size())
            params_without_bias = params if len(params) < 2 else params[:1]
            update_grad_input_fn = getattr(ctx._backend, update_grad_input.name)
            gi_args = params_without_bias + additional_args
            update_grad_input_fn(ctx._backend.library_state, input, grad_output, grad_input, *gi_args)
            grad_input_tuple = (grad_input,)

        if acc_grad_parameters and any(ctx.needs_input_grad[1:]):
            additional_args = _initialize_buffers(ctx, 'acc_grad_parameters')
            grad_params = tuple(p.new(p.size()).zero_() for p in params)
            appended_grads = len(expected_params) - len(grad_params)
            grad_params += (None,) * appended_grads
            acc_grad_parameters_fn = getattr(ctx._backend, acc_grad_parameters.name)
            param_args = grad_params + additional_args + (1,)
            acc_grad_parameters_fn(ctx._backend.library_state, input, grad_output, *param_args)
            if appended_grads:
                grad_params = grad_params[:-appended_grads]

        return grad_input_tuple + grad_params

    @staticmethod
    def backward_cls_backward(ctx, *grad_params):
        return double_backwards_fn(ctx, *grad_params)

    base_class = Function if not is_inplace else InplaceFunction
    backward_cls = type(class_name + "Backward", (base_class,), dict(forward=backward_cls_forward,
                                                                     backward=backward_cls_backward))

    return type(class_name, (base_class,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls

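# A hedged sketch for the non-criterion factory: the generated forward takes
# (input, *params), tensor params such as weight/bias are checked against the
# input type and forwarded to THNN, and the remaining scalars become
# additional_args.  For a parameter-free op like Softplus (generated from
# SoftPlus_updateOutput / SoftPlus_updateGradInput), usage would look roughly like:
#
#     out = Softplus.apply(input, beta, threshold)   # scalars land in additional_args
#     out.sum().backward()                           # runs SoftplusBackward / SoftPlus_updateGradInput
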
def _generate_function_classes(scope_dict):
    """Parse THNN.h and register a generated Function class pair in ``scope_dict`` for each supported entry point."""
    global function_list, function_by_name
    function_list = parse_header(THNN_H_PATH)
    function_by_name = {fn.name: fn for fn in function_list}
    classes_to_generate = {fn.name.partition('_')[0] for fn in function_list}
    exceptions = {
        'Linear',
        'IndexLinear',
        'SpatialFullConvolution',
        'SpatialConvolutionMM',
        'TemporalConvolution',
        'SpatialAveragePooling',
        'SpatialMaxPooling',
        'SpatialDilatedMaxPooling',
        'SpatialMaxUnpooling',
        'SpatialAdaptiveMaxPooling',
        'VolumetricAveragePooling',
        'VolumetricMaxPooling',
        'VolumetricMaxUnpooling',
        'VolumetricAdaptiveAveragePooling',
        'VolumetricAdaptiveMaxPooling',
        'VolumetricConvolution',
        'VolumetricFullConvolution',
        'VolumetricConvolutionMM',
        'TemporalMaxPooling',
        'BatchNormalization',
        'LookupTable',
        'LookupTableBag',
        'PReLU',
        'RReLU',
        'SoftMax',
        'LogSoftMax',
        'GRUFused',
        'LSTMFused',
        'unfolded',
    }
    name_remap = {
        'TemporalConvolution': 'Conv1d',
        'SpatialDilatedConvolution': 'DilatedConv2d',
        'SpatialMaxUnpooling': 'MaxUnpool2d',
        'VolumetricMaxUnpooling': 'MaxUnpool3d',
        'HardTanh': 'Hardtanh',
        'HardShrink': 'Hardshrink',
        'SoftPlus': 'Softplus',
        'SoftShrink': 'Softshrink',
        'MSECriterion': 'MSELoss',
        'AbsCriterion': 'L1Loss',
        'BCECriterion': 'BCELoss',
        'ClassNLLCriterion': 'NLLLoss',
        'DistKLDivCriterion': 'KLDivLoss',
        'SpatialClassNLLCriterion': 'NLLLoss2d',
        'MultiLabelMarginCriterion': 'MultiLabelMarginLoss',
        'MultiMarginCriterion': 'MultiMarginLoss',
        'SmoothL1Criterion': 'SmoothL1Loss',
        'SoftMarginCriterion': 'SoftMarginLoss',
    }

    classes_to_generate -= exceptions
    for fn in classes_to_generate:
        update_output = function_by_name[fn + '_updateOutput']
        update_grad_input = function_by_name[fn + '_updateGradInput']
        acc_grad_parameters = function_by_name.get(fn + '_accGradParameters')
        class_name = name_remap.get(fn, fn)
        double_backwards_fn = double_backwards_fns.get(class_name)
        if double_backwards_fn is None:
            def make_default_double_backwards_fn(class_name):
                def default_double_backwards_fn(ctx, *grad_params):
                    raise ValueError(class_name + " can only be differentiated once.")
                return default_double_backwards_fn
            double_backwards_fn = make_default_double_backwards_fn(class_name)
        symbolic_fn = symbolic_fns.get(class_name)
        # This has to call a function to retain correct references to functions
        is_criterion_fn = 'Criterion' in fn
        if is_criterion_fn:
            cls, backward_cls = _make_function_class_criterion(class_name, update_output,
                                                               update_grad_input, acc_grad_parameters,
                                                               double_backwards_fn, symbolic_fn)
        else:
            cls, backward_cls = _make_function_class(class_name, update_output,
                                                     update_grad_input, acc_grad_parameters,
                                                     double_backwards_fn, symbolic_fn)
        scope_dict[class_name] = cls
        scope_dict[backward_cls.__name__] = backward_cls
        if not class_name.startswith('_'):
            _all_functions.append(cls)
            _all_functions.append(backward_cls)

_generate_function_classes(locals())
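# After the call above, every generated pair of classes is bound both into this
# module's namespace (via locals()) and into _all_functions, so the rest of
# torch.nn can look them up.  A rough usage sketch, assuming this file lives at
# torch/nn/_functions/thnn/auto.py as in the PyTorch tree; the exact trailing
# flags depend on the THNN header that was parsed:
#
#     from torch.nn._functions.thnn import auto
#     loss = auto.MSELoss.apply(input, target, True, True)   # e.g. (sizeAverage, reduce)
#     loss.backward()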