1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
7 from caffe2.proto
import caffe2_pb2
def rewrite_init_net_simple(net):
    """Assign every operator of an init net to the IDEEP device.

    Args:
        net: net definition whose ``op`` sequence is updated in place; each
            operator's ``device_option.device_type`` is overwritten.

    Returns:
        None. ``net`` is mutated in place.
    """
    # NOTE(review): the loop header was not present in the reviewed text;
    # iterating ``net.op`` matches how the other rewrite helpers in this
    # file walk operators -- confirm against the full file.
    for op in net.op:
        op.device_option.device_type = caffe2_pb2.IDEEP
def last_producer(ops, blob):
    """Return the index of the last operator in ``ops`` that outputs ``blob``.

    Args:
        ops: sequence of operators; each must expose an ``output`` container
            that supports the ``in`` test.
        blob: blob name to look for.

    Returns:
        The index (into ``ops``) of the latest operator producing ``blob``.

    Raises:
        ValueError: if no operator in ``ops`` lists ``blob`` as an output.
    """
    # Walk from the end so the first hit is the *last* producer.
    # NOTE(review): the match/return body was missing from the reviewed
    # text and is reconstructed from the function name and the error
    # message below -- confirm against the full file.
    for i, op in reversed(list(enumerate(ops))):
        if blob in op.output:
            return i
    # Format the blob name into the message; passing it as a second
    # positional argument would leave a literal "%s" in the error text.
    raise ValueError(
        "Failed to find last producer of blob, {}".format(blob))
def fix_BoxWithNMSLimit(net):
    """Adjust CopyIDEEPToCPU ops that consume BoxWithNMSLimit outputs.

    The first three outputs of every ``BoxWithNMSLimit`` operator are
    collected. Any ``CopyIDEEPToCPU`` operator whose first input is one of
    those blobs is reported on stdout and moved to the CPU device.

    Args:
        net: net definition whose ``op`` sequence is inspected and updated
            in place.

    Returns:
        None. ``net`` is mutated in place.
    """
    # NOTE(review): the ``outputs`` initialisation and both loop headers
    # were missing from the reviewed text and are reconstructed here --
    # confirm against the full file.
    outputs = set()
    for op in net.op:
        if op.type == 'BoxWithNMSLimit':
            outputs.add(op.output[0])
            outputs.add(op.output[1])
            outputs.add(op.output[2])

    for op in net.op:
        if op.type == 'CopyIDEEPToCPU' and op.input[0] in outputs:
            print(
                "Chaning CopyIDEEPToCPU to Copy for {}".format(op.input[0]))
            # NOTE(review): retyping the op to 'Copy' is inferred from the
            # log message above; the assignment itself was not visible.
            op.type = 'Copy'
            op.device_option.device_type = caffe2_pb2.CPU
# Rewrites a run net so its operators target IDEEP, inserting a
# CopyCPUToIDEEP op for the first external input and CopyIDEEPToCPU ops for
# the external outputs.
# NOTE(review): several statements of this function are absent from the text
# under review (the helper that wraps the "__MKL__" return, the statement
# that raises with the message below, the opening of the copy_output_ops
# list, and whatever writes `ops` back into the net) -- confirm against the
# full file before relying on these comments.
38 def rewrite_run_net_simple(net):
# Temporary-name helper: the IDEEP-side copy of a blob is "<name>__MKL__".
43 return "{}__MKL__".format(name)
# Only the first external input is handled, and it must be the first input
# of the first operator.
45 input_blob = net.external_input[0]
46 if input_blob != net.op[0].input[0]:
48 "Input blob: {} is not consumed by first op: {}".format(
49 input_blob, net.op[0]))
# Copy the input to its "__MKL__" blob and point the first op at that copy.
51 from_cpu =
"CopyCPUToIDEEP" 52 to_cpu =
"CopyIDEEPToCPU" 53 copy_input_op = core.CreateOperator(
54 from_cpu, input_blob, mkl_tmp(input_blob))
55 net.op[0].input[0] = mkl_tmp(input_blob)
# One CopyIDEEPToCPU op per external output: "<out>__MKL__" -> "<out>".
58 core.CreateOperator(to_cpu, mkl_tmp(output_blob), output_blob)
59 for output_blob
in net.external_output]
# Make the last producer of each external output write the "__MKL__" name.
61 for output_blob
in net.external_output:
62 last_producer_idx = last_producer(net.op, output_blob)
63 renamed_outputs = [blob
if blob != output_blob
else mkl_tmp(blob)
64 for blob
in net.op[last_producer_idx].output]
65 net.op[last_producer_idx].output[:] = renamed_outputs
# Ops after that producer that read the output are switched to the
# "__MKL__" name as well.
67 for op
in net.op[last_producer_idx + 1:]:
68 renamed_input = [blob
if blob != output_blob
else mkl_tmp(blob)
70 op.input[:] = renamed_input
# Final op order: input copy, the original ops, then the output copies.
72 ops = [copy_input_op] + net.op[:] + copy_output_ops
# An IDEEP device option is merged into op.device_option.
# NOTE(review): the loop binding `op` is not visible here.
75 device = caffe2_pb2.IDEEP
77 op.device_option.MergeFrom(
78 core.DeviceOption(device_type=device))
# Post-pass over CopyIDEEPToCPU ops fed by BoxWithNMSLimit outputs.
84 fix_BoxWithNMSLimit(net)
# Variant of the run-net rewrite that splits the net at its first 'Shape'
# operator: ops before it are fed through IDEEP copies, and external inputs
# read from that point on are copied back to CPU under "__CPU__" names.
# NOTE(review): several statements of this function are absent or truncated
# in the text under review (nested helper definitions, the raising statement,
# some loop headers and conditional branches) -- confirm against the full
# file before relying on these comments.
87 def rewrite_run_net_simple_xrayocr_lstm(net):
# Two temporary-name helpers: "<name>__MKL__" and "<name>__CPU__".
97 return "{}__MKL__".format(name)
100 return "{}__CPU__".format(name)
# The first external input must be the first input of the first operator.
102 input_blob = net.external_input[0]
103 if input_blob != net.op[0].input[0]:
105 "Input blob: {} is not consumed by first op: {}".format(
106 input_blob, net.op[0]))
# Copy the input to its "__MKL__" blob and point the first op at that copy.
108 from_cpu =
"CopyCPUToIDEEP" 109 to_cpu =
"CopyIDEEPToCPU" 110 copy_input_op = core.CreateOperator(
111 from_cpu, input_blob, mkl_tmp(input_blob))
112 net.op[0].input[0] = mkl_tmp(input_blob)
# Sanity check that an external input is not also an op output.
# NOTE(review): the loop that binds `op` for this assert is not visible.
117 for input_blob
in net.external_input:
120 assert input_blob
not in op.output
# Scan the ops: the first 'Shape' op marks where the CPU part starts, and
# its inputs become the blobs that get copied back to CPU.
122 external_output =
None 123 external_inputs_to_cpu = set()
124 find_first_shape_op =
False 125 cpu_op_start_idx = -1
126 for op_idx, op
in enumerate(net.op):
128 if not find_first_shape_op:
129 if op.type ==
'Shape':
130 external_output = op.input
131 find_first_shape_op =
True 132 cpu_op_start_idx = op_idx
# Collect op inputs that are also external inputs of the net.
# NOTE(review): presumably only for ops from the Shape op onward; the
# enclosing branch is not visible -- verify.
135 for in_blob
in op.input:
136 if in_blob
in net.external_input:
137 external_inputs_to_cpu.add(in_blob)
# A 'Shape' op is required; without one external_output stays None.
140 assert external_output
is not None 143 copy_extra_input_ops = []
# For each collected external input: add a CopyIDEEPToCPU op and rename the
# blob to its "__CPU__" name in ops from cpu_op_start_idx onward.
# NOTE(review): the CreateOperator call is truncated; its destination blob
# is presumably cpu_tmp(in_blob) -- confirm.
144 for in_blob
in external_inputs_to_cpu:
145 copy_extra_input_ops.append(core.CreateOperator(to_cpu, in_blob,
148 for op
in net.op[cpu_op_start_idx:]:
149 renamed_input = [blob
if blob != in_blob
else cpu_tmp(in_blob)
150 for blob
in op.input]
151 op.input[:] = renamed_input
# One CopyIDEEPToCPU op per Shape-op input: "<out>__MKL__" -> "<out>".
154 core.CreateOperator(to_cpu, mkl_tmp(output_blob), output_blob)
155 for output_blob
in external_output]
# Make the last producer of each of those blobs write the "__MKL__" name.
157 for output_blob
in external_output:
158 last_producer_idx = last_producer(net.op, output_blob)
159 renamed_outputs = [blob
if blob != output_blob
else mkl_tmp(blob)
160 for blob
in net.op[last_producer_idx].output]
161 net.op[last_producer_idx].output[:] = renamed_outputs
# Final op order: input copy, ops before the Shape op, output copies,
# extra input copies, then the ops from the Shape op onward.
164 ops = [copy_input_op] + net.op[:cpu_op_start_idx] \
165 + copy_output_ops + copy_extra_input_ops + net.op[cpu_op_start_idx:]
# The device starts as IDEEP and is set to CPU when a 'Shape' op is seen;
# the current device is merged into op.device_option.
# NOTE(review): the loop binding `op` is not visible here.
169 device = caffe2_pb2.IDEEP
172 if op.type ==
'Shape':
174 device = caffe2_pb2.CPU
175 op.device_option.MergeFrom(
176 core.DeviceOption(device_type=device))
# RecurrentNetwork ops carry a nested net in their 'step_net' argument; its
# ops get the same device option and have their engine cleared.
180 if op.type ==
'RecurrentNetwork':
182 if arg.name ==
'step_net':
183 for nested_op
in arg.n.op:
185 nested_op.device_option.MergeFrom(
186 core.DeviceOption(device_type=device))
187 nested_op.engine =
"" 191 for blob
in nested_op.input:
# Inputs that are in external_inputs_to_cpu are replaced; others are kept.
# NOTE(review): the replacement expression is not visible -- presumably the
# "__CPU__" name; confirm.
192 renamed_input.append(blob
193 if blob
not in external_inputs_to_cpu
195 nested_op.input[:] = renamed_input
# The nested net's external_input list gets the same treatment.
198 new_external_input = []
199 for blob
in arg.n.external_input:
200 new_external_input.append(blob
201 if blob
not in external_inputs_to_cpu
203 arg.n.external_input[:] = new_external_input
# Post-pass over CopyIDEEPToCPU ops fed by BoxWithNMSLimit outputs.
208 fix_BoxWithNMSLimit(net)
# Applies the init-net and run-net rewrites to a deep copy of `model`, so
# the object passed in by the caller is not modified.
# NOTE(review): no return statement is visible in the text under review;
# presumably the rewritten copy is returned -- confirm against the full file.
211 def rewrite_model_helper_simple(model):
# Work on a copy; the rewrites below mutate the protos in place.
212 model = copy.deepcopy(model)
# Rewrite the parameter-initialisation net, then the main net.
214 rewrite_init_net_simple(model.param_init_net.Proto())
215 rewrite_run_net_simple(model.net.Proto())