from itertools import repeat
from collections import defaultdict

import torch
from torch._thnn import type2backend
from torch._thnn.utils import parse_header, THNN_H_PATH
from torch.autograd.function import Function, InplaceFunction

from .auto_double_backwards import double_backwards_fns
from .auto_symbolic import symbolic_fns

from . import _all_functions
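
# This module auto-generates autograd Function classes from the THNN C header:
# every <Name>_updateOutput / <Name>_updateGradInput pair parsed out of THNN.h
# becomes a forward/backward Function pair, injected into this module's
# namespace by _generate_function_classes() at the bottom of the file.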


def _make_function_class_criterion(class_name, update_output, update_grad_input, acc_grad_parameters,
                                   double_backwards_fn, symbolic_fn):
    weight_arg_idx = -1
    for i, arg in enumerate(update_output.arguments):
        if arg.name.startswith('weight'):
            weight_arg_idx = i
            break

    reduce_arg_idx = -1
    for i, arg in enumerate(update_output.arguments):
        if arg.name == 'reduce':
            reduce_arg_idx = i
            break

    # Tensor arguments after (state, input, target, output) that are not
    # weights are scratch buffers the wrapper has to allocate itself.
    buffers_idx = []
    additional_arg_idx = 0
    for arg in update_output.arguments[4:]:
        if not arg.name.startswith('weight') and arg.type == 'THTensor*':
            buffers_idx.append(additional_arg_idx)
        additional_arg_idx += 1
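    # Worked example (hypothetical argument list, for illustration only): with
    # update_output.arguments[4:] == [sizeAverage, weights, total_weight],
    # where only the latter two are THTensor*, 'weights' is skipped as a
    # weight, so buffers_idx == [2] (the position of total_weight).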

    @staticmethod
    def symbolic(*args, **kwargs):
        return symbolic_fn(*args, **kwargs)

    @staticmethod
    def forward(ctx, input, target, *args):
        ctx._backend = type2backend[input.type()]
        ctx.save_for_backward(input, target)
        if weight_arg_idx >= 0:
            ctx.weight = args[0]
            args = args[1:]
            ctx.additional_args = list(args)
            # Offset by the four leading arguments: state, input, target, output.
            insert_idx = weight_arg_idx - 4
            ctx.additional_args.insert(insert_idx, ctx.weight)
        else:
            ctx.additional_args = list(args)

        ctx.forward_args_count = len(ctx.additional_args)
        for idx in buffers_idx:
            ctx.additional_args.insert(idx, input.new(1))

        output = input.new(1)
        getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, target,
                                                  output, *ctx.additional_args)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # One gradient per forward input: the input gradient, then None for the
        # target and for every additional argument.
        return ((backward_cls.apply(input, target, grad_output, ctx.additional_args, ctx._backend),) +
                (None,) * (ctx.forward_args_count + 1))

    @staticmethod
    def backward_cls_forward(ctx, input, target, grad_output, additional_args_ctx, backend_ctx):
        ctx.additional_args = additional_args_ctx
        ctx._backend = backend_ctx
        ctx.save_for_backward(input, target, grad_output)
        grad_input = grad_output.new().resize_as_(input).zero_()

        if reduce_arg_idx >= 0:
            getattr(ctx._backend, update_grad_input.name)(ctx._backend.library_state, input, target,
                                                          grad_output, grad_input, *ctx.additional_args)
            return grad_input

        getattr(ctx._backend, update_grad_input.name)(ctx._backend.library_state, input, target,
                                                      grad_input, *ctx.additional_args)
        grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
        grad_input.mul_(grad_output_expanded.expand_as(grad_input))
        return grad_input

    @staticmethod
    def backward_cls_backward(ctx, *grad_params):
        return double_backwards_fn(ctx, *grad_params)

    backward_cls = type(class_name + "Backward", (Function,),
                        dict(forward=backward_cls_forward, backward=backward_cls_backward))
    return type(class_name, (Function,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls
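
# A minimal usage sketch for a generated criterion pair, assuming update_output
# and update_grad_input were parsed from the MSECriterion_* header entries (the
# class name and trailing scalar arguments are illustrative, not the canonical
# API):
#
#     MSELoss, MSELossBackward = _make_function_class_criterion(
#         'MSELoss', update_output, update_grad_input, None,
#         double_backwards_fn, symbolic_fn)
#     loss = MSELoss.apply(input, target, size_average)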


def _find_buffers(args, ignored_args):
    additional_arg_idx = 0
    buffers = []
    for arg in args:
        if arg.name in ignored_args:
            continue
        if arg.type == 'THTensor*':
            buffers.append((additional_arg_idx, arg.name))
        additional_arg_idx += 1
    return buffers
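
# Worked example (hypothetical arguments, for illustration): for arguments
# named [weight, columns, ones], all of type THTensor*, with 'weight' in
# ignored_args, 'weight' is skipped without consuming an index, so the result
# is [(0, 'columns'), (1, 'ones')].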


def _make_function_class(class_name, update_output, update_grad_input, acc_grad_parameters,
                         double_backwards_fn, symbolic_fn):
    def has_argument(fn, name):
        for arg in fn.arguments:
            if arg.name == name:
                return True
        return False
    save_output = has_argument(update_grad_input, 'output')

    param_args = {'weight', 'bias'}
    ignored_args = {'weight', 'bias', 'gradWeight', 'gradBias', 'output'}
    expected_params = [arg for arg in update_output.arguments[3:]
                       if arg.name in param_args]
    buffers = {}
    buffers['update_output'] = _find_buffers(update_output.arguments[3:],
                                             ignored_args)
    buffers['update_grad_input'] = _find_buffers(
        update_grad_input.arguments[4:], ignored_args)
    if acc_grad_parameters is not None:
        buffers['acc_grad_parameters'] = _find_buffers(
            acc_grad_parameters.arguments[3:], ignored_args)

    # This assumes that only the last argument can be an inplace flag.
    is_inplace = update_output.arguments[-1].name == 'inplace'

    def _initialize_buffers(ctx, fn_name):
        additional_args = ctx.additional_args
        for idx, name in buffers[fn_name]:
            buffer = ctx.buffers[name]
            additional_args = additional_args[:idx] + [buffer] + additional_args[idx:]
        return tuple(additional_args)
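    # For example (illustrative values): with buffers[fn_name] ==
    # [(0, 'columns'), (1, 'ones')] and ctx.additional_args == [kW, kH], the
    # buffers are spliced in at their recorded positions, yielding
    # (columns_buf, ones_buf, kW, kH).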

    @staticmethod
    def symbolic(*args, **kwargs):
        return symbolic_fn(*args, **kwargs)

    @staticmethod
    def forward(ctx, input, *params):
        ctx._backend = type2backend[input.type()]

        ctx.additional_args = []
        tensor_param_list = []
        for param in params:
            if isinstance(param, torch.Tensor):
                if type(param) != type(input):
                    raise RuntimeError("input type ({}) doesn't match the type of "
                                       "a parameter tensor ({})".format(torch.typename(input),
                                                                        torch.typename(param)))
                tensor_param_list.append(param)
            else:
                ctx.additional_args.append(param)

        tensor_params = tuple(tensor_param_list)
        if is_inplace:
            ctx.inplace = params[-1]
        ctx.buffers = defaultdict(type(input))
        additional_args = _initialize_buffers(ctx, 'update_output')

        # Fill in optional params with None
        args = tensor_params
        for i in range(len(params), len(expected_params)):
            param = expected_params[i]
            if param.is_optional:
                args += (None,)
            else:
                raise ValueError("missing required argument '%s'" % param.name)

        args += tuple(additional_args)

        if is_inplace and ctx.inplace:
            ctx.mark_dirty(input)
            output = input
        else:
            output = input.new()

        if save_output:
            ctx.save_for_backward(input, output, *tensor_params)
        else:
            ctx.save_for_backward(input, *tensor_params)

        if not ctx.requires_grad:
            del ctx.buffers

        getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, output, *args)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        t = ctx.saved_tensors
        input, tensor_params = t[0], t[1:]
        # backward_cls.apply returns a gradient for the input and for each
        # tensor parameter; pad with None for the non-tensor additional args.
        return (backward_cls.apply(input, grad_output, ctx.additional_args, ctx._backend, ctx.buffers, *tensor_params) +
                (None,) * len(ctx.additional_args))
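
    # backward delegates to a separate Function (backward_cls) so that the
    # backward computation is itself recorded by autograd; differentiating it
    # a second time then dispatches to backward_cls_backward and the
    # registered double_backwards_fn.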

    @staticmethod
    def backward_cls_forward(ctx, input, grad_output, additional_args_ctx, backend_ctx, buffers_ctx, *params):
        ctx.additional_args = additional_args_ctx
        ctx.buffers = buffers_ctx
        ctx._backend = backend_ctx
        ctx.save_for_backward(input, grad_output, *params)
        if save_output:
            output = params[0]
            params = params[1:]

        grad_params = tuple(None for p in params)
        grad_input_tuple = (None,)

        if is_inplace:
            ctx.inplace = additional_args_ctx[-1]

        if ctx.needs_input_grad[0]:
            additional_args = _initialize_buffers(ctx, 'update_grad_input')
            if save_output:
                additional_args = (output,) + additional_args

            if is_inplace and ctx.inplace:
                assert additional_args[-1] is True
                tmp_args = list(additional_args)
                # The inplace flag only applies to the forward pass.
                tmp_args[-1] = False
                additional_args = tuple(tmp_args)
            grad_input = input.new(input.size())
            params_without_bias = params if len(params) < 2 else params[:1]
            update_grad_input_fn = getattr(ctx._backend, update_grad_input.name)
            gi_args = params_without_bias + additional_args
            update_grad_input_fn(ctx._backend.library_state, input, grad_output, grad_input, *gi_args)
            grad_input_tuple = (grad_input,)

        if acc_grad_parameters and any(ctx.needs_input_grad[1:]):
            additional_args = _initialize_buffers(ctx, 'acc_grad_parameters')
            grad_params = tuple(p.new(p.size()).zero_() for p in params)
            appended_grads = len(expected_params) - len(grad_params)
            grad_params += (None,) * appended_grads
            acc_grad_parameters_fn = getattr(ctx._backend, acc_grad_parameters.name)
            param_args = grad_params + additional_args + (1,)
            acc_grad_parameters_fn(ctx._backend.library_state, input, grad_output, *param_args)
            if appended_grads:
                grad_params = grad_params[:-appended_grads]

        return grad_input_tuple + grad_params

    @staticmethod
    def backward_cls_backward(ctx, *grad_params):
        return double_backwards_fn(ctx, *grad_params)

    base_class = Function if not is_inplace else InplaceFunction
    backward_cls = type(class_name + "Backward", (base_class,), dict(forward=backward_cls_forward,
                                                                     backward=backward_cls_backward))
    return type(class_name, (base_class,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls
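
# Minimal usage sketch for a generated non-criterion pair, assuming update_output
# and update_grad_input were parsed from the SoftPlus_* header entries (the
# class name and trailing scalar arguments are illustrative, not the canonical
# API):
#
#     Softplus, SoftplusBackward = _make_function_class(
#         'Softplus', update_output, update_grad_input, None,
#         double_backwards_fn, symbolic_fn)
#     output = Softplus.apply(input, beta, threshold)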


def _generate_function_classes(scope_dict):
    global function_list, function_by_name
    function_list = parse_header(THNN_H_PATH)
    function_by_name = {fn.name: fn for fn in function_list}
    classes_to_generate = {fn.name.partition('_')[0] for fn in function_list}
    exceptions = {
        'SpatialFullConvolution',
        'SpatialConvolutionMM',
        'TemporalConvolution',
        'SpatialAveragePooling',
        'SpatialDilatedMaxPooling',
        'SpatialMaxUnpooling',
        'SpatialAdaptiveMaxPooling',
        'VolumetricAveragePooling',
        'VolumetricMaxPooling',
        'VolumetricMaxUnpooling',
        'VolumetricAdaptiveAveragePooling',
        'VolumetricAdaptiveMaxPooling',
        'VolumetricConvolution',
        'VolumetricFullConvolution',
        'VolumetricConvolutionMM',
        'TemporalMaxPooling',
        'BatchNormalization',
    }
    name_remap = {
        'TemporalConvolution': 'Conv1d',
        'SpatialDilatedConvolution': 'DilatedConv2d',
        'SpatialMaxUnpooling': 'MaxUnpool2d',
        'VolumetricMaxUnpooling': 'MaxUnpool3d',
        'HardTanh': 'Hardtanh',
        'HardShrink': 'Hardshrink',
        'SoftPlus': 'Softplus',
        'SoftShrink': 'Softshrink',
        'MSECriterion': 'MSELoss',
        'AbsCriterion': 'L1Loss',
        'BCECriterion': 'BCELoss',
        'ClassNLLCriterion': 'NLLLoss',
        'DistKLDivCriterion': 'KLDivLoss',
        'SpatialClassNLLCriterion': 'NLLLoss2d',
        'MultiLabelMarginCriterion': 'MultiLabelMarginLoss',
        'MultiMarginCriterion': 'MultiMarginLoss',
        'SmoothL1Criterion': 'SmoothL1Loss',
        'SoftMarginCriterion': 'SoftMarginLoss',
    }

    classes_to_generate -= exceptions
    for fn in classes_to_generate:
        update_output = function_by_name[fn + '_updateOutput']
        update_grad_input = function_by_name[fn + '_updateGradInput']
        acc_grad_parameters = function_by_name.get(fn + '_accGradParameters')
        class_name = name_remap.get(fn, fn)
        double_backwards_fn = double_backwards_fns.get(class_name)
        if double_backwards_fn is None:
            # Use a factory function so the closure captures the right
            # class_name on each iteration.
            def make_default_double_backwards_fn(class_name):
                def default_double_backwards_fn(ctx, *grad_params):
                    raise ValueError(class_name + " can only be differentiated once.")
                return default_double_backwards_fn
            double_backwards_fn = make_default_double_backwards_fn(class_name)
        symbolic_fn = symbolic_fns.get(class_name)
        # This has to call a function to retain correct references to functions
        is_criterion_fn = 'Criterion' in fn
        if is_criterion_fn:
            cls, backward_cls = _make_function_class_criterion(class_name, update_output,
                                                               update_grad_input, acc_grad_parameters,
                                                               double_backwards_fn, symbolic_fn)
        else:
            cls, backward_cls = _make_function_class(class_name, update_output,
                                                     update_grad_input, acc_grad_parameters,
                                                     double_backwards_fn, symbolic_fn)
        scope_dict[class_name] = cls
        scope_dict[backward_cls.__name__] = backward_cls
        if not class_name.startswith('_'):
            _all_functions.append(cls)
            _all_functions.append(backward_cls)


_generate_function_classes(locals())
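
# After the call above, the generated classes live in this module's namespace,
# e.g. a Hardtanh/HardtanhBackward pair produced from the HardTanh_* header
# entries (illustrative; the exact set depends on the parsed header), and
# classes whose names don't start with '_' are also collected in _all_functions.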