1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
13 from caffe2.proto
import caffe2_pb2
def pairwise(iterable):
    """Yield overlapping pairs: s -> (s0, s1), (s1, s2), (s2, s3), ...

    Used to walk adjacent operator pairs of a net. An iterable with fewer
    than two items yields nothing.
    """
    from itertools import tee
    first, second = tee(iterable)
    # Advance the second iterator by one so zip pairs each item with its
    # successor; the default avoids StopIteration on empty input.
    next(second, None)
    return zip(first, second)
def last_producer(ops, blob):
    """Return the index of the last op in `ops` that lists `blob` as an output.

    Raises:
        ValueError: if no op in `ops` produces `blob`.
    """
    for (i, op) in reversed(list(enumerate(ops))):
        if blob in op.output:
            return i
    # Exception constructors do not apply logging-style "%s" arguments, so
    # the message must be formatted here rather than passed as a 2-tuple.
    raise ValueError(
        "Failed to find last producer of blob, {}".format(blob))
def blob_uses(net, blob):
    """Return the indices of ops in `net` that consume `blob`.

    An op counts as a consumer if `blob` appears in its data inputs or in
    its control inputs. Callers use the length of the result to check that
    a blob has exactly one consumer before fusing.
    """
    return [i for i, op in enumerate(net.op)
            if blob in op.input or blob in op.control_input]
def GetArgumentParser():
    """Build the command-line parser for the optimization script.

    `--init_net` and `--pred_net` are opened in binary mode (they are
    parsed as serialized protobufs) and are required, because the script
    reads both unconditionally. `--verify_input` is an optional text file
    (JSON) describing input shapes used to check numerics after fusion.
    The `--fuse_*` flags select which rewrites to apply.
    """
    parser = argparse.ArgumentParser(description="Caffe2 optimization")
    parser.add_argument("--init_net",
                        type=argparse.FileType('rb'),
                        required=True,
                        help="serialized init NetDef")
    parser.add_argument("--pred_net",
                        type=argparse.FileType('rb'),
                        required=True,
                        help="serialized predict NetDef")
    parser.add_argument("--verify_input",
                        type=argparse.FileType('r'),
                        help="input dims for verification")
    parser.add_argument("--fuse_bn", default=False, action='store_true')
    parser.add_argument("--fuse_mul_add", default=False, action='store_true')
    parser.add_argument("--fuse_conv_relu", default=False, action='store_true')
    return parser
def fuse_first_bn(net, params, removed_tensors):
    """Fold the first fusible Conv/ConvTranspose -> SpatialBN pair in `net`.

    Walks adjacent op pairs. For the first Conv or ConvTranspose whose output
    is consumed only by the immediately following SpatialBN, the conv's
    weight and bias are rescaled so that the conv alone produces the BN's
    output blob, and the SpatialBN op is dropped. At most one pair is fused
    per call; `fuse_bn` calls this repeatedly until nothing changes.

    Args:
        net: NetDef to rewrite. It is deep-copied; the argument is untouched.
        params: dict of blob name -> numpy array. Also deep-copied.
        removed_tensors: list to which the names of the deleted BN
            parameter blobs are appended (mutated in place and returned).

    Returns:
        (net, params, removed_tensors). When nothing was fused, the returned
        net has the same op count as the input.
    """
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if next_.input[0] != current.output[0]:
            continue
        if current.type not in ("Conv", "ConvTranspose") \
                or next_.type != "SpatialBN":
            continue
        if len(blob_uses(net, current.output[0])) != 1:
            # Other ops read the conv output; folding the BN into the conv
            # would change what they see.
            continue

        conv = current
        bn = next_
        fused_conv = copy.deepcopy(conv)
        fused_conv.output[0] = bn.output[0]

        # The fused conv needs a bias input to carry the BN shift.
        if len(fused_conv.input) != 3:
            bias_name = "{}_bias".format(conv.input[1])
            net.external_input.extend([bias_name])
            fused_conv.input.extend([bias_name])
            for arg in fused_conv.arg:
                if arg.name == "no_bias":
                    arg.i = 0

        bn_scale = params[bn.input[1]]
        bn_bias = params[bn.input[2]]
        bn_running_mean = params[bn.input[3]]
        bn_running_var = params[bn.input[4]]

        conv_weight = params[conv.input[1]]
        if len(conv.input) == 3:
            conv_bias = params[conv.input[2]]
        else:
            # One bias entry per BN channel (= conv output channel).
            # weight.shape[0] is the output-channel count only for Conv;
            # for ConvTranspose it is the input-channel count, which made
            # the broadcast below fail whenever C_in != C_out.
            conv_bias = np.zeros(bn_scale.shape[0], dtype=np.float32)

        # Inference-time BN: (X - mean) / sqrt(var + eps) * scale + bias,
        # i.e. the affine map X * A + B with
        #   A = scale / sqrt(var + eps),  B = bias - mean * A.
        # NOTE(review): 1e-5 is assumed to match SpatialBN's default
        # epsilon when the op carries no explicit "epsilon" arg — confirm.
        eps = 1.0e-5
        for arg in bn.arg:
            if arg.name == "epsilon":
                eps = arg.f
        A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
        B = bn_bias - bn_running_mean * A

        # ((X conv W) + b) * A + B == (X conv (W * A)) + (b * A + B).
        # A scales output channels: axis 0 of Conv weights, axis 1 of
        # ConvTranspose weights (assumes NCHW-style layout and group == 1
        # for ConvTranspose — TODO confirm for grouped deconvs).
        A_ = A.reshape(-1, 1, 1, 1) if conv.type == "Conv" else \
            A.reshape(1, -1, 1, 1)
        Q = conv_weight * A_
        C = conv_bias * A + B

        params[fused_conv.input[1]] = Q
        params[fused_conv.input[2]] = C

        # Capture the BN parameter names before clearing net.op, so we do
        # not read from `bn` after its container has been emptied.
        bn_param_names = list(bn.input[1:5])
        new_ops = net.op[:i] + [fused_conv] + net.op[j + 1:]
        del net.op[:]
        for param_name in bn_param_names:
            removed_tensors.append(param_name)
            del params[param_name]
        net.op.extend(new_ops)
        # net.op was rebuilt, so the pairwise iterator is stale; stop here
        # and let the caller re-scan.
        break
    return net, params, removed_tensors
def fuse_bn(net, params, ignore_failure=False):
    """Fold every fusible Conv/ConvTranspose -> SpatialBN pair in `net`.

    Repeatedly applies `fuse_first_bn` until a pass removes no op (a fixed
    point).

    Args:
        net: NetDef to rewrite (not mutated).
        params: dict of blob name -> numpy array (not mutated).
        ignore_failure: if False, raise when any SpatialBN op survives.

    Returns:
        (net, params, removed_tensors) after fusion.

    Raises:
        Exception: if SpatialBN ops remain and `ignore_failure` is False.
    """
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_bn(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            if (any(op.type == "SpatialBN" for op in next_net.op)
                    and not ignore_failure):
                # Format eagerly; Exception does not interpolate "%s" args.
                raise Exception(
                    "Model contains SpatialBN op after fusion: {}".format(
                        next_net))
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)
def fuse_first_mul_add(net, params, removed_tensors):
    """Rewrite the first adjacent Mul -> Add pair in `net` as one SpatialBN.

    The SpatialBN reuses the Mul's second input as scale and the Add's second
    input as bias. It gets a synthetic zero mean, unit variance and
    epsilon=1e-9, so it computes (approximately) x * scale + bias. At most
    one pair is rewritten per call.

    Args:
        net: NetDef to rewrite (deep-copied, not mutated).
        params: dict of blob name -> numpy array (deep-copied). The fake
            mean/variance blobs are added to the returned copy.
        removed_tensors: passed through unchanged and returned.

    Returns:
        (net, params, removed_tensors). The op count is unchanged when
        nothing was fused.

    Raises:
        Exception: "Failure to fuse" if a Mul is directly followed by an Add
            that does not read the Mul's output, or if the Mul's output has
            more than one consumer.
    """
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if current.type != "Mul" or next_.type != "Add":
            continue

        if next_.input[0] != current.output[0]:
            raise Exception("Failure to fuse")

        if len(blob_uses(net, current.output[0])) != 1:
            raise Exception("Failure to fuse")

        log.info("Fusing at index %s", i)
        mul_ = current
        add_ = next_
        batch_norm = copy.deepcopy(mul_)
        batch_norm.type = "SpatialBN"
        batch_norm.arg.extend([utils.MakeArgument("is_test", 1)])
        batch_norm.arg.extend([utils.MakeArgument("epsilon", float(1e-9))])

        # Synthetic statistics blobs, named after the Add's output.
        fake_mean = "{}{}".format(add_.output[0], "_mean")
        fake_var = "{}{}".format(add_.output[0], "_var")

        del batch_norm.input[:]
        batch_norm.input.extend([mul_.input[0],
                                 mul_.input[1],
                                 add_.input[1],
                                 fake_mean,
                                 fake_var])
        # assumes the Mul's second input is a per-channel parameter present
        # in `params` (a KeyError is raised otherwise) — TODO confirm.
        params[fake_mean] = np.zeros_like(params[mul_.input[1]])
        params[fake_var] = np.ones_like(params[mul_.input[1]])
        net.external_input.extend([fake_mean, fake_var])

        batch_norm.output[0] = add_.output[0]
        new_ops = net.op[:i] + [batch_norm] + net.op[j + 1:]
        del net.op[:]
        net.op.extend(new_ops)
        # net.op was rebuilt; the pairwise iterator is stale.
        break
    return net, params, removed_tensors
def fuse_mul_add(net, params):
    """Rewrite every adjacent Mul -> Add pair in `net` as a SpatialBN.

    Repeatedly applies `fuse_first_mul_add` until a pass removes no op.
    Exceptions raised by `fuse_first_mul_add` propagate.

    Returns:
        (net, params, removed_tensors) after fusion.
    """
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_mul_add(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)
def add_tensor(net, name, blob):
    """Append to `net` a GivenTensor*Fill op that recreates `blob` as `name`.

    The op is only appended; it materializes the blob when `net` is run.
    Supported dtypes are float32, int32, int64 and uint8. Any other dtype
    (e.g. float64) raises KeyError from the type lookup. A uint8 array is
    stored as a string tensor with a single element that holds the array's
    raw bytes.
    """
    kTypeNameMapper = {
        np.dtype('float32'): "GivenTensorFill",
        np.dtype('int32'): "GivenTensorIntFill",
        np.dtype('int64'): "GivenTensorInt64Fill",
        np.dtype('uint8'): "GivenTensorStringFill",
    }

    shape = blob.shape
    values = blob
    if blob.dtype == np.dtype('uint8'):
        # Pack the whole array into one byte string to save storage.
        # tobytes() returns the raw buffer on Python 2 and 3, whereas
        # str(blob.data) on Python 3 yields "<memory at 0x...>".
        shape = [1]
        values = [blob.tobytes()]

    op = core.CreateOperator(
        kTypeNameMapper[blob.dtype],
        [], [name],
        arg=[
            utils.MakeArgument("shape", shape),
            utils.MakeArgument("values", values),
        ]
    )
    net.op.extend([op])
def gen_init_net_from_blobs(blobs):
    """Build and return a NetDef with one fill op per entry of `blobs`.

    `blobs` maps blob name -> numpy array. Names are visited in sorted
    order so the serialized init net is reproducible regardless of dict
    ordering.
    """
    ret = caffe2_pb2.NetDef()
    for name in sorted(blobs):
        add_tensor(ret, name, blobs[name])
    return ret
def fuse_conv_relu(net):
    """Return a copy of `net` rewritten by the IDEEP optimization transform.

    Every op in a deep copy of `net` is first assigned the IDEEP device
    option. The serialized net is then passed through
    `C.transform_optimizeForIDEEP`, and the result is parsed into a new
    NetDef. The input `net` is not mutated.
    """
    net = copy.deepcopy(net)
    device_option = core.DeviceOption(caffe2_pb2.IDEEP)
    for op in net.op:
        op.device_option.CopyFrom(device_option)

    new_net = caffe2_pb2.NetDef()
    new_net.ParseFromString(C.transform_optimizeForIDEEP(net.SerializeToString()))
    return new_net
def main(args):
    """Load the nets, apply the selected fusions, optionally verify, and save.

    Reads the init and predict NetDefs from `args`, runs the init net to
    collect parameters, and applies the fusions selected by the --fuse_*
    flags. When --verify_input is given, the predict net is run before and
    after optimization on the same random inputs and the outputs are
    compared. The results are written to init_net.pb and predict_net.pb in
    the current directory.
    """
    init_net = caffe2_pb2.NetDef()
    predict_net = caffe2_pb2.NetDef()
    with args.init_net as f:
        init_net.ParseFromString(f.read())
    with args.pred_net as f:
        predict_net.ParseFromString(f.read())

    workspace.ResetWorkspace()
    workspace.RunNetOnce(init_net)
    # At this point the workspace holds only what the init net created.
    param_dict = {p: workspace.FetchBlob(p) for p in workspace.Blobs()}

    external_inputs = {}
    external_outputs = {}
    if args.verify_input:
        with args.verify_input as f:
            value_info = json.load(f)
        input_shapes = {k: v[-1] for (k, v) in value_info.items()}
        print("input info: {}".format(input_shapes))
        for k, v in input_shapes.items():
            external_inputs[k] = np.random.randn(*v).astype(np.float32)
            workspace.FeedBlob(k, external_inputs[k])
        workspace.RunNetOnce(predict_net)
        for o in predict_net.external_output:
            external_outputs[o] = workspace.FetchBlob(o)

    if args.fuse_mul_add:
        predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
    if args.fuse_bn:
        predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, False)
    if args.fuse_conv_relu:
        predict_net = fuse_conv_relu(predict_net)

    external_outputs_opt = {}
    if args.verify_input:
        workspace.ResetWorkspace()
        if args.fuse_conv_relu:
            device_option = core.DeviceOption(caffe2_pb2.IDEEP)
        else:
            device_option = core.DeviceOption(caffe2_pb2.CPU)
        with core.DeviceScope(device_option):
            for k, v in param_dict.items():
                workspace.FeedBlob(k, v, device_option)
            for k, v in external_inputs.items():
                workspace.FeedBlob(k, v, device_option)
            workspace.RunNetOnce(predict_net)
            for o in predict_net.external_output:
                external_outputs_opt[o] = workspace.FetchBlob(o)
                # Explicit check instead of `assert`, which `python -O`
                # strips. NOTE(review): tolerance is an assumed bound on
                # float32 fusion error — confirm the intended threshold.
                if not np.allclose(external_outputs[o],
                                   external_outputs_opt[o],
                                   atol=1e-3, rtol=1e-3):
                    raise AssertionError(
                        "Output {} differs after optimization".format(o))

    for i, o in enumerate(predict_net.op):
        print("op[{}]: {}".format(i, o.type))
    init_net = gen_init_net_from_blobs(param_dict)
    with open('init_net.pb', 'wb') as f:
        f.write(init_net.SerializeToString())
    with open('predict_net.pb', 'wb') as f:
        f.write(predict_net.SerializeToString())


if __name__ == '__main__':
    args = GetArgumentParser().parse_args()
    main(args)