from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto import caffe2_pb2
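
# This module generates the backward (gradient) operators for Caffe2's
# control-flow ops: Do (a subnet run in its own workspace), If and While.
# Each gen_*_gradient function takes the forward OperatorDef together with the
# gradient blobs of the op's outputs, and returns the gradient ops plus the
# gradient blobs corresponding to the forward op's inputs.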


def gen_do_gradient(op, g_output):
    """
    Generates gradient Do operator, given forward Do op and a list
    of gradient blobs corresponding to forward op's outputs.
    Returns a gradient op and a list of blobs corresponding to input gradients.
    """
    from caffe2.python.core import BlobReference  # local import avoids a circular dependency
    subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name = \
        _do_op_sanity_check_and_process(op)

    assert len(g_output) == len(op.output), \
        "Different number of gradient blobs and Do op outputs"

    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output
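
    # From the outer net's point of view, Do is an operator with a number of
    # inputs and outputs; the gradient operator has to write into the
    # corresponding input-gradient blobs and has access to the inputs, outputs
    # and output-gradient blobs.
    # From the inner net's point of view, Do is an operator with a subnet and
    # blob bindings: Do's output-blob gradients are forwarded into the inner
    # workspace, the backward pass is generated there, and Do's input-blob
    # gradients are forwarded back into the outer workspace.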
    op_output = [str(o) for o in op.output]
    op_output = op_output[:-1]  # the last output is the workspace pointer blob
    op_input = [str(i) for i in op.input]
    op_input = op_input[:-1]  # the last input is the workspace pointer blob

    ordered_inner_output_blob_names = [outer_to_inner_map[o] for o in op_output]

    backward_pass_initial_grad_map = {}
    initial_grad_map = {}
    for inner_output_name, outer_grad_output_name in \
            zip(ordered_inner_output_blob_names, g_output):
        # link the inner output to an inner gradient blob name for backward
        # pass generation, and remember which outer gradient blob it maps to
        if outer_grad_output_name:
            inner_grad_output_name = inner_output_name + "/_DO_OPERATOR_INNER_GRAD_"
            backward_pass_initial_grad_map[BlobReference(inner_output_name)] = \
                BlobReference(inner_grad_output_name)
            initial_grad_map[inner_grad_output_name] = str(outer_grad_output_name)
    assert len(initial_grad_map) > 0, "Empty initial gradient map for Do op"

    inner_grad_ops, inner_grad_names_map = _gen_subgradient_pass(
        subnet, backward_pass_initial_grad_map)

    if len(inner_grad_ops) == 0:
        # no gradient ops were generated for the subnet, hence no input
        # gradients
        return [], []

    grad_copy_ops = []
    g_input = []
    new_op_outputs = []
    new_blob_bindings = {}
    for outer_input_name in op_input:
        inner_input_name = outer_to_inner_map[outer_input_name]
        if inner_input_name in inner_grad_names_map:
            inner_grad_input_name = inner_grad_names_map[inner_input_name]
            outer_grad_input_name = outer_input_name + "_grad"

            # the inner gradient blob may have to be exported under a new
            # outer name; copy it into a dedicated blob so that the
            # inner-to-outer blob binding remains a bijection
            new_inner_grad_input_name = \
                inner_input_name + "/_DO_OPERATOR_INNER_GRAD_COPY_"
            grad_copy_ops.append(_prepare_blob_copy_op(
                inner_grad_input_name, new_inner_grad_input_name))

            new_blob_bindings[new_inner_grad_input_name] = outer_grad_input_name
            new_op_outputs.append(outer_grad_input_name)
            g_input.append(outer_grad_input_name)
        else:
            g_input.append(None)
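
    # collect the outer blobs read by the inner gradient ops: blobs visible
    # from the outer workspace become inputs of the gradient Do op, while
    # purely local blobs have to be restored from the saved forward workspace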
    new_op_inputs = []
    overwritten_names = set()
    saved_local_blob_names = set()
    for grad_op in inner_grad_ops:
        grad_op_input = [str(i) for i in grad_op.input]
        grad_op_output = [str(o) for o in grad_op.output]
        for grad_op_input_name in grad_op_input:
            if grad_op_input_name in overwritten_names:
                continue
            # check whether this is an external blob
            outer_name = inner_to_outer_map.get(grad_op_input_name, None)
            if not outer_name:
                # check whether this is an external gradient blob
                outer_name = initial_grad_map.get(grad_op_input_name, None)
            if outer_name:
                outer_name = str(outer_name)
                if outer_name not in new_op_inputs:
                    new_op_inputs.append(outer_name)
                new_blob_bindings[grad_op_input_name] = outer_name
            else:
                # local blob, its value comes from the saved forward workspace
                saved_local_blob_names.add(grad_op_input_name)
        overwritten_names.update(grad_op_output)

    # append the gradient copy ops generated above
    inner_grad_ops += grad_copy_ops

    gradient_do_def = _prepare_gradient_do_op(
        fwd_op=op,
        fwd_net=subnet,
        grad_ops=inner_grad_ops,
        inputs=new_op_inputs,
        outputs=new_op_outputs,
        blob_bindings=new_blob_bindings,
        saved_fwd_blobs=saved_local_blob_names,
        workspace_blob_name=workspace_blob_name)
    grad_ops.append(gradient_do_def)

    _do_op_sanity_check_and_process(gradient_do_def)

    return grad_ops, g_input
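
# A minimal usage sketch (hypothetical blob names): for a forward Do op with
# outputs ["y", "ws"], where "ws" is the workspace pointer blob, and
# g_output == ["y_grad", None], gen_do_gradient returns the gradient Do op
# (plus any dedup Copy ops) and the gradient blobs of the forward op's inputs,
# with None wherever no gradient flows.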


def dedupe_g_output(op, g_output):
    # When generating a gradient op it's possible to receive the same gradient
    # blob for different forward op outputs; Do requires a bijection between
    # inner and outer names, so deduplicate the gradient blobs
    grad_ops = []
    deduped_g_output = []
    init_grad_map = {}
    for output_name, grad_name in zip(op.output, g_output):
        if not grad_name:
            deduped_g_output.append(grad_name)
            continue

        if output_name in init_grad_map:
            deduped_g_output.append(init_grad_map[output_name])
        else:
            if grad_name not in init_grad_map.values():
                init_grad_map[output_name] = grad_name
                deduped_g_output.append(grad_name)
            else:
                deduped_grad_name = output_name + "_" + grad_name + "_DEDUP"
                assert deduped_grad_name not in init_grad_map.values()
                grad_copy_op = caffe2_pb2.OperatorDef()
                grad_copy_op.type = "Copy"
                grad_copy_op.input.extend([grad_name])
                grad_copy_op.output.extend([deduped_grad_name])
                grad_ops.append(grad_copy_op)
                deduped_g_output.append(deduped_grad_name)
                init_grad_map[output_name] = deduped_grad_name
    return grad_ops, deduped_g_output
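
# For example (hypothetical names): with op.output == ["y1", "y2"] and
# g_output == ["g", "g"], the second occurrence is rerouted through a Copy op
# into "y2_g_DEDUP", so every output ends up with a distinct gradient blob.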


def gen_while_gradient(op, g_output):
    """
    Generates gradient While operator
    """
    from caffe2.python.core import BlobReference  # local import avoids a circular dependency
    assert op.type == "While", "Expected While op"
    assert len(op.input) > 0, "Expected at least one input in While op"

    assert len(op.output) == len(g_output), \
        "Different number of gradient blobs and While op outputs"

    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output

    init_grad_map = {}
    op_output = [str(o) for o in op.output]
    for output_name, grad_output_name in zip(op_output, g_output):
        if grad_output_name:
            init_grad_map[BlobReference(output_name)] = \
                BlobReference(grad_output_name)
    assert len(init_grad_map) > 0, "Empty initial gradient map for While op"

    loop_net = _get_net_argument(op, "loop_net")
    assert loop_net, "Expected loop subnet in While op"
    assert len(loop_net.op) == 1 and loop_net.op[0].type == "Do", \
        "Gradient While op requires single Do op as a loop body"
    do_op = loop_net.op[0]
    do_args = _get_do_arguments(do_op)
    assert "reuse_workspace" not in do_args or not do_args["reuse_workspace"], \
        "Gradient While op requires Do loop body op without reuse_workspace set"

    assert len(do_op.output) > 0, "Expected Do op with at least one output"
    workspace_blob = do_op.output[-1]

    loop_grad_net, loop_grad_map, loop_input_names, loop_output_names = \
        _gen_subnet_gradient(loop_net, init_grad_map)
    assert loop_grad_net, "Failed to get gradient net for loop body in While op"

    grad_ops += _prepare_gradient_while_ops(
        fwd_op=op,
        input_names=loop_input_names,
        output_names=loop_output_names,
        loop_grad_net=loop_grad_net,
        workspace_blob=workspace_blob,
        init_grad_map=init_grad_map,
        loop_grad_map=loop_grad_map)

    op_input = [str(i) for i in op.input]
    g_input = [loop_grad_map.get(i, None) for i in op_input]
    return grad_ops, g_input
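

# Constructs the gradient While op. Arguments:
#  fwd_op - forward While op
#  input_names - input blob names for the gradient op
#  output_names - output blob names for the gradient op
#  loop_grad_net - gradient net for the loop's body
#  workspace_blob - blob that holds the stack of saved forward workspaces
#  init_grad_map - initial gradient map (forward blob -> gradient blob)
#  loop_grad_map - gradient blob map for the loop's body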
def _prepare_gradient_while_ops(
        fwd_op, input_names, output_names, loop_grad_net, workspace_blob,
        init_grad_map, loop_grad_map):
    gradient_while_def = caffe2_pb2.OperatorDef()
    gradient_while_def.CopyFrom(fwd_op)
    if gradient_while_def.name:
        gradient_while_def.name += "_grad"

    loop_net_arg = caffe2_pb2.Argument()
    loop_net_arg.name = "loop_net"
    loop_net_arg.n.CopyFrom(loop_grad_net)

    cond_net_arg = caffe2_pb2.Argument()
    cond_net_arg.name = "cond_net"
    from caffe2.python.core import Net, BlobReference  # local import avoids a circular dependency
    # condition net: check that there are still forward workspaces left on the
    # stack
    cond_net = Net('gradient_loop_cond_net')
    cond_init_net = Net('gradient_loop_cond_net_init')
    cond_blob = cond_net.NextScopedBlob(cond_net.Name() + '/cond')
    cond_init_net.HasScope(workspace_blob, cond_blob)
    cond_net.HasScope(workspace_blob, cond_blob)
    for blob, init_grad_blob in init_grad_map.items():
        blob_name = str(blob)
        init_grad_blob_name = str(init_grad_blob)
        if blob_name in loop_grad_map and \
                loop_grad_map[blob_name] != init_grad_blob_name:
            cond_net.Copy(
                BlobReference(loop_grad_map[blob_name]), init_grad_blob)
            cond_init_net.Copy(
                init_grad_blob, BlobReference(loop_grad_map[blob_name]))
    cond_net_arg.n.CopyFrom(cond_net.Proto())

    del gradient_while_def.arg[:]
    gradient_while_def.arg.extend([loop_net_arg, cond_net_arg])

    del gradient_while_def.control_input[:]
    del gradient_while_def.input[:]
    gradient_while_def.input.extend(
        [str(cond_blob).encode('utf-8')] + list(input_names))
    del gradient_while_def.output[:]
    gradient_while_def.output.extend(output_names)
    gradient_while_def.is_gradient_op = True
    return [o for o in cond_init_net.Proto().op] + [gradient_while_def]
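
# Note: the gradient loop's condition is a HasScope check on the workspace
# blob, so the backward While keeps iterating while saved forward workspaces
# remain on the stack, i.e. once per forward iteration.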


def _get_do_arguments(do_op):
    assert do_op.type == "Do", "Expected Do op"
    args = {}
    for arg in do_op.arg:
        if not arg.name:
            continue
        if arg.name == "net":
            assert arg.n, "Expected non empty net argument"
            args["net"] = arg.n
        elif arg.name == "reuse_workspace":
            assert arg.i, "Expected non empty reuse_workspace argument"
            args["reuse_workspace"] = bool(arg.i)
        elif arg.name == "inner_blobs":
            assert arg.strings, "Expected non empty inner_blobs argument"
            args["inner_blobs"] = arg.strings
        elif arg.name == "outer_blobs_idx":
            assert arg.ints, "Expected non empty outer_blobs_idx argument"
            args["outer_blobs_idx"] = arg.ints
    return args


def gen_if_gradient(op, g_output):
    """
    Generates gradient If operator, given forward If op and a list
    of gradient blobs corresponding to forward op's outputs.
    Returns a gradient op and a list of blobs corresponding to input gradients.
    """
    from caffe2.python.core import BlobReference  # local import avoids a circular dependency
    assert op.type == "If", "Expected If op"
    assert len(op.input) > 0, "Expected at least one input in If op"

    assert len(op.output) == len(g_output), \
        "Different number of gradient blobs and If op outputs"

    grad_ops, deduped_g_output = dedupe_g_output(op, g_output)
    g_output = deduped_g_output

    init_grad_map = {}  # map from If's output blob to output gradient blob
    op_input = [str(i) for i in op.input]
    op_output = [str(o) for o in op.output]
    for output_name, grad_output_name in zip(op_output, g_output):
        if grad_output_name:
            init_grad_map[BlobReference(output_name)] = \
                BlobReference(grad_output_name)
    assert len(init_grad_map) > 0, "Empty initial gradient map for If op"

    grad_map = {}  # map from blob to gradient blob
    then_net = _get_net_argument(op, "then_net")
    assert then_net, "Expected then subnet in If op"
    then_grad_net, then_grad_map, then_input_names, then_output_names = \
        _gen_subnet_gradient(then_net, init_grad_map)
    assert then_grad_net, "Failed to get gradient net for then in If op"
    grad_map.update(then_grad_map)

    else_input_names = set()
    else_output_names = set()
    else_grad_map = {}
    else_grad_net = None
    else_net = _get_net_argument(op, "else_net")
    if else_net:
        else_grad_net, else_grad_map, else_input_names, else_output_names = \
            _gen_subnet_gradient(else_net, init_grad_map)
        assert else_grad_net, "Failed to get gradient net for else in If op"
        # if both branches map the same blob to different gradient names, then
        # one branch left the original gradient name from init_grad_map
        # untouched while the other wrote a new <blob_name>_grad blob; prefer
        # the name written by the branch that actually computed the gradient
        for else_blob, else_grad_blob in else_grad_map.items():
            if else_blob in then_grad_map:
                then_grad_blob = then_grad_map[else_blob]
                if then_grad_blob != else_grad_blob:
                    init_grad_name = init_grad_map[else_blob] \
                        if else_blob in init_grad_map else None

                    if then_grad_blob == init_grad_name:
                        grad_map[else_blob] = else_grad_blob
                    elif else_grad_blob == init_grad_name:
                        grad_map[else_blob] = then_grad_blob
                    else:
                        raise Exception(
                            "Unexpected grad blob name " + else_blob + ", " +
                            else_grad_blob + ", " + then_grad_blob)
            else:
                grad_map[else_blob] = else_grad_blob
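
    # gradients of blobs that were not computed by the taken If branch must be
    # initialized with zeros by the other branch's gradient net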
    then_other_output_names = \
        then_output_names - (then_output_names & else_output_names)
    then_other_grad_output_names = set(
        [o for o in then_other_output_names if o in then_grad_map.values()])
    zero_then = _gen_grad_zero_init_ops(
        init_grad_map, then_grad_map, then_other_grad_output_names)
    if else_grad_net:
        else_grad_net.op.extend(zero_then)
    elif len(zero_then) > 0:
        # no else net, but some gradients have to be zero-initialized:
        # create an else gradient net that only does the initialization
        else_grad_net = caffe2_pb2.NetDef()
        else_grad_net.CopyFrom(then_grad_net)
        if else_grad_net.name:
            else_grad_net.name += "_auto_else_zero_blobs_"
        del else_grad_net.op[:]
        else_grad_net.op.extend(zero_then)
        del else_grad_net.external_input[:]
        del else_grad_net.external_output[:]

    else_other_output_names = \
        else_output_names - (then_output_names & else_output_names)
    else_other_grad_output_names = set(
        [o for o in else_other_output_names if o in else_grad_map.values()])
    zero_else = _gen_grad_zero_init_ops(
        init_grad_map, else_grad_map, else_other_grad_output_names)
    then_grad_net.op.extend(zero_else)

    output_names = list(then_output_names | else_output_names)
    input_names = then_input_names | else_input_names
    # make sure the condition blob comes first and is not duplicated
    input_names = [op_input[0]] + list(input_names - {op_input[0]})
    gradient_if_def = _prepare_gradient_if_op(
        fwd_op=op,
        input_names=input_names,
        output_names=output_names,
        then_grad_net=then_grad_net,
        else_grad_net=else_grad_net)
    g_input = [grad_map.get(i, None) for i in op_input]
    return grad_ops + [gradient_if_def], g_input


def _gen_subnet_gradient(subnet, init_grad):
    grad_ops, grad_names_map = _gen_subgradient_pass(
        subnet, init_grad)

    output_names = set()
    input_names = set()
    for grad_op in grad_ops:
        for grad_op_input in grad_op.input:
            if str(grad_op_input) not in output_names:
                input_names.add(str(grad_op_input))
        for grad_op_output in grad_op.output:
            output_names.add(str(grad_op_output))

    gradient_net_def = caffe2_pb2.NetDef()
    gradient_net_def.CopyFrom(subnet)
    if gradient_net_def.name:
        gradient_net_def.name += "_grad"
    del gradient_net_def.op[:]
    gradient_net_def.op.extend(grad_ops)
    del gradient_net_def.external_input[:]
    del gradient_net_def.external_output[:]

    return gradient_net_def, grad_names_map, input_names, output_names


def _get_net_argument(op, net_name):
    for arg in op.arg:
        if arg.name and arg.name == net_name:
            assert arg.n, "Expected non empty net argument " + net_name
            return arg.n
    return None


def getNetArgument(op, net_name):
    """A wrapper around _get_net_argument for external callers"""
    return _get_net_argument(op, net_name)
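

# _gen_subgradient_pass delegates the actual autodiff to caffe2's IR class
# (caffe2.python.core): IR.GetBackwardPass takes the initial gradient map and
# returns the gradient operators together with a map from forward blobs to
# their gradient blobs.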
def _gen_subgradient_pass(subnet, init_grad):
    from caffe2.python.core import IR  # local import avoids a circular dependency
    subnet_ir = IR(subnet.op)
    grad_ops, grad_blob_map = \
        subnet_ir.GetBackwardPass(init_grad)
    grad_names_map = {}
    for b, g in grad_blob_map.items():
        grad_names_map[str(b)] = str(g)
    return grad_ops, grad_names_map


def _do_op_sanity_check_and_process(op):
    assert op.type == "Do", "Expected Do op"

    subnet = _get_net_argument(op, "net")
    assert subnet, "No net argument found in Do op"

    inner_blobs = None
    outer_blobs_idx = None
    for arg in op.arg:
        if arg.name and arg.name == "inner_blobs":
            assert not inner_blobs, "inner_blobs redefinition"
            assert arg.strings and len(arg.strings) > 0, \
                "Empty inner_blobs argument in Do op"
            inner_blobs = [s.decode('utf-8') for s in arg.strings]
        if arg.name and arg.name == "outer_blobs_idx":
            assert not outer_blobs_idx, "outer_blobs_idx redefinition"
            assert arg.ints and len(arg.ints) > 0, \
                "Empty outer_blobs_idx argument in Do op"
            outer_blobs_idx = arg.ints
        if inner_blobs and outer_blobs_idx:
            break

    assert inner_blobs, "No inner_blobs argument found in Do op"
    assert outer_blobs_idx, "No outer_blobs_idx argument found in Do op"

    assert len(inner_blobs) == len(outer_blobs_idx), \
        "Arguments inner_blobs and outer_blobs_idx of different length in Do op"

    all_inner_blobs = set(inner_blobs)
    assert len(all_inner_blobs) == len(inner_blobs), \
        "Found duplicates in inner_blobs in Do op"

    op_input = [str(i) for i in op.input]
    assert len(op_input) > 0, "Expected at least one input blob"
    # remove the last input blob that holds the pointer to the workspace
    input_workspace_blob_name = op_input[-1]
    op_input = op_input[:-1]

    op_output = [str(o) for o in op.output]
    assert len(op_output) > 0, "Expected at least one output blob"
    # remove the last output blob that holds the pointer to the workspace
    workspace_blob_name = op_output[-1]
    assert input_workspace_blob_name == workspace_blob_name, \
        "Expected same input/output workspace blob"
    op_output = op_output[:-1]

    all_op_input_blob_names = set(op_input)
    assert len(all_op_input_blob_names) == len(op_input), \
        "Found duplicates in Do op inputs"
    all_op_output_blob_names = set(op_output)
    assert len(all_op_output_blob_names) == len(op_output), \
        "Found duplicates in Do op outputs"

    ordered_outer_blob_names = op_input + op_output
    all_outer_blob_names = set(ordered_outer_blob_names)
    used_outer_blob_names = set()
    outer_to_inner_map = {}
    inner_to_outer_map = {}
    for inner_name, outer_blob_idx in zip(inner_blobs, outer_blobs_idx):
        assert outer_blob_idx >= 0 and \
            outer_blob_idx < len(ordered_outer_blob_names), \
            "Outer blob index is out of bounds in Do op"
        outer_name = ordered_outer_blob_names[outer_blob_idx]
        assert outer_name not in used_outer_blob_names, \
            "Reuse of outer blob name " + outer_name + " in Do op"
        used_outer_blob_names.add(outer_name)
        outer_to_inner_map[outer_name] = inner_name
        inner_to_outer_map[inner_name] = outer_name

    assert len(used_outer_blob_names) == len(all_outer_blob_names), \
        "Not all outer blob names are used in blob bindings in Do op"

    return subnet, outer_to_inner_map, inner_to_outer_map, workspace_blob_name
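
# Binding layout example (hypothetical names): for a Do op with
# op_input == ["x", "ws"] and op_output == ["y", "ws"], where "ws" is the
# workspace pointer, ordered_outer_blob_names == ["x", "y"]; together with
# inner_blobs == ["in/x", "in/y"] and outer_blobs_idx == [0, 1] this yields
# the bijection {"x": "in/x", "y": "in/y"}.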


def _prepare_blob_copy_op(from_name, to_name):
    copy_op_def = caffe2_pb2.OperatorDef()
    copy_op_def.type = "Copy"
    copy_op_def.input.extend([from_name])
    copy_op_def.output.extend([to_name])
    return copy_op_def


def _prepare_gradient_do_op(
        fwd_op, fwd_net, grad_ops, inputs, outputs, blob_bindings, saved_fwd_blobs,
        workspace_blob_name):
    gradient_net_def = caffe2_pb2.NetDef()
    gradient_net_def.CopyFrom(fwd_net)
    if gradient_net_def.name:
        gradient_net_def.name += "_grad"
    del gradient_net_def.op[:]
    gradient_net_def.op.extend(grad_ops)
    del gradient_net_def.external_input[:]
    del gradient_net_def.external_output[:]

    gradient_do_def = caffe2_pb2.OperatorDef()
    gradient_do_def.CopyFrom(fwd_op)
    if gradient_do_def.name and len(gradient_do_def.name) > 0:
        gradient_do_def.name += "_grad"

    del gradient_do_def.input[:]
    gradient_do_def.input.extend(inputs)
    # workspace pointer blob
    gradient_do_def.input.append(workspace_blob_name)
    del gradient_do_def.output[:]
    gradient_do_def.output.extend(outputs)
    # workspace pointer blob
    gradient_do_def.output.append(workspace_blob_name)

    net_arg = caffe2_pb2.Argument()
    net_arg.name = "net"
    net_arg.n.CopyFrom(gradient_net_def)

    ordered_new_outer_names = inputs + outputs
    inner_blobs = blob_bindings.keys()
    new_outer_blobs_idx = [ordered_new_outer_names.index(blob_bindings[b])
                           for b in inner_blobs]

    inner_blobs_arg = caffe2_pb2.Argument()
    inner_blobs_arg.name = "inner_blobs"
    inner_blobs_arg.strings.extend([b.encode('utf-8') for b in inner_blobs])

    outer_blobs_idx_arg = caffe2_pb2.Argument()
    outer_blobs_idx_arg.name = "outer_blobs_idx"
    outer_blobs_idx_arg.ints.extend(new_outer_blobs_idx)

    saved_blobs_arg = caffe2_pb2.Argument()
    saved_blobs_arg.name = "saved_fwd_blobs"
    saved_blobs_arg.strings.extend(
        [b.encode('utf-8') for b in saved_fwd_blobs])

    del gradient_do_def.arg[:]
    gradient_do_def.arg.extend([
        net_arg, inner_blobs_arg, outer_blobs_idx_arg, saved_blobs_arg])
    del gradient_do_def.control_input[:]

    gradient_do_def.is_gradient_op = True

    return gradient_do_def


def _gen_grad_zero_init_ops(init_grad_map, grad_map, grad_output_names):
    grad_init_ops = []
    for grad_output in grad_output_names:
        # find the corresponding output blob to use as the ConstantFill input,
        # so that the gradient blob gets the same shape
        output_name = None
        for o, g in grad_map.items():
            if g == grad_output:
                output_name = o
                break
        assert output_name, "Unknown gradient output " + grad_output

        grad_init_op = None
        # make sure we do not overwrite an existing gradient with zeros
        if output_name in init_grad_map:
            init_grad_name = init_grad_map[output_name]
            # if a different gradient blob name is used, copy the gradient
            if init_grad_name != grad_output:
                grad_init_op = caffe2_pb2.OperatorDef()
                grad_init_op.type = "Copy"
                grad_init_op.input.extend([str(init_grad_name)])
                grad_init_op.output.extend([str(grad_output)])
        else:
            grad_init_op = caffe2_pb2.OperatorDef()
            grad_init_op.type = "ConstantFill"
            grad_init_op.input.extend([output_name])
            grad_init_op.output.extend([grad_output])
            value_arg = caffe2_pb2.Argument()
            value_arg.name = "value"
            value_arg.f = 0.0
            grad_init_op.arg.extend([value_arg])

        if grad_init_op:
            grad_init_ops.append(grad_init_op)
    return grad_init_ops


def _prepare_gradient_if_op(
        fwd_op, input_names, output_names, then_grad_net, else_grad_net):
    gradient_if_def = caffe2_pb2.OperatorDef()
    gradient_if_def.CopyFrom(fwd_op)
    del gradient_if_def.input[:]
    gradient_if_def.input.extend(input_names)
    del gradient_if_def.output[:]
    gradient_if_def.output.extend(output_names)

    then_net_arg = caffe2_pb2.Argument()
    then_net_arg.name = "then_net"
    then_net_arg.n.CopyFrom(then_grad_net)
    gradient_args = [then_net_arg]
    if else_grad_net:
        else_net_arg = caffe2_pb2.Argument()
        else_net_arg.name = "else_net"
        else_net_arg.n.CopyFrom(else_grad_net)
        gradient_args.append(else_net_arg)

    del gradient_if_def.arg[:]
    gradient_if_def.arg.extend(gradient_args)
    if gradient_if_def.name:
        gradient_if_def.name += "_grad"
    del gradient_if_def.control_input[:]
    gradient_if_def.is_gradient_op = True
    return gradient_if_def
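

# Renames the gradient If op's output at position idx to new_grad_output and
# rewrites the matching output name inside both branch subnets, keeping the
# inner nets consistent with the renamed outer blob.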
def disambiguate_grad_if_op_output(grad_op, idx, new_grad_output):
    then_net = _get_net_argument(grad_op, "then_net")
    old_grad_out_match = grad_op.output[idx]
    for op in then_net.op:
        for i, out in enumerate(op.output):
            if out == old_grad_out_match:
                op.output[i] = new_grad_output

    else_net = _get_net_argument(grad_op, "else_net")
    if else_net:
        for op in else_net.op:
            for i, out in enumerate(op.output):
                if out == old_grad_out_match:
                    op.output[i] = new_grad_output

    grad_op.output[idx] = new_grad_output