1 from __future__
import absolute_import, division, print_function, unicode_literals
3 from collections
import defaultdict
9 def transpose_network(nn):
11 Convert all Convolutions operators which are in the NCHW order 12 to NHWC order and also transform their inputs and outputs so that the 13 rest of the graph is not affected. 18 outgoing = defaultdict(
lambda: [])
20 orig_nodes = [x
for x
in nn.nodes]
21 for node
in orig_nodes:
22 if node.isOperator()
and node.name ==
"Conv":
23 arg_dict = utils.ArgsToDict(node.annotation.operator_def.arg)
25 if "order" in arg_dict
and arg_dict[
"order"] !=
"NCHW":
27 inputs = [x
for x
in node.inputs]
28 assert len(inputs) >= 2,
"Conv operator should have two inputs" 29 outputs = [x
for x
in node.outputs]
30 assert len(outputs) >= 1,
"Conv operator should have an output" 32 nn.deleteEdge(inp, node)
34 nn.deleteEdge(node, outp)
38 new_inp = nn.createUniqueDataNode(inputs[idx].name)
39 transp = dfg.createNode(ng.NeuralNetOperator(
"NCHW2NHWC"))
40 nn.createEdge(inputs[idx], transp)
41 nn.createEdge(transp, new_inp)
42 outgoing[inputs[idx]].append(transp)
44 for idx
in range(len(outputs)):
45 new_outp = nn.createUniqueDataNode(outputs[idx].name)
46 transp = dfg.createNode(ng.NeuralNetOperator(
"NHWC2NCHW"))
47 nn.createEdge(transp, outputs[idx])
48 nn.createEdge(new_outp, transp)
49 incoming[outputs[idx]] = new_outp
50 outputs[idx] = new_outp
53 arg_dict[
"order"] =
"NHWC" 54 new_node = nn.createNode(core.CreateOperator(
"Conv", [], [],
57 nn.createEdge(inp, new_node)
59 nn.createEdge(new_node, outp)
74 for orig_tensor
in outgoing:
76 if orig_tensor
in incoming:
77 new_tensor = incoming[orig_tensor]
79 out_ops = outgoing[orig_tensor]
80 new_tensor = out_ops[0].outputs[0]
81 outgoing[orig_tensor] = out_ops[1:]
83 for opnode
in outgoing[orig_tensor]:
85 for out
in opnode.outputs:
86 nn.replaceAllUsesWith(out, new_tensor)