1 from __future__ 
import absolute_import
     2 from __future__ 
import division
     3 from __future__ 
import print_function
     4 from __future__ 
import unicode_literals
    13 from caffe2.proto 
import caffe2_pb2
    19 def pairwise(iterable):
    20     from itertools 
import tee
    26 def last_producer(ops, blob):
    27     for (i, op) 
in reversed(list(enumerate(ops))):
    30     raise ValueError(
"Failed to find last producer of blob, %s", blob)
    33 def blob_uses(net, blob):
    35     for i, op 
in enumerate(net.op):
    36         if blob 
in op.input 
or blob 
in op.control_input:
    41 def GetArgumentParser():
    42     parser = argparse.ArgumentParser(description=
"Caffe2 optimization")
    43     parser.add_argument(
"--init_net",
    44                         type=argparse.FileType(
'rb'),
    46     parser.add_argument(
"--pred_net",
    47                         type=argparse.FileType(
'rb'),
    49     parser.add_argument(
"--verify_input",
    50                         type=argparse.FileType(
'r'),    51                         help="input dims for verification")
    52     parser.add_argument(
"--fuse_bn", default=
False, action=
'store_true')
    53     parser.add_argument(
"--fuse_mul_add", default=
False, action=
'store_true')
    54     parser.add_argument(
"--fuse_conv_relu", default=
False, action=
'store_true')
    58 def fuse_first_bn(net, params, removed_tensors):
    59     net = copy.deepcopy(net)
    60     params = copy.deepcopy(params)
    62     for ((i, current), (j, next_)) 
in pairwise(enumerate(net.op)):
    63         if next_.input[0] != current.output[0]:
    66         if current.type 
not in (
"Conv", 
"ConvTranspose") \
    67            or next_.type != 
"SpatialBN":
    69         if len(blob_uses(net, current.output[0])) != 1:
    76         fused_conv = copy.deepcopy(conv)
    77         fused_conv.output[0] = bn.output[0]
    80         if len(fused_conv.input) != 3:
    81             bias_name = 
"{}_bias".format(conv.input[1])
    82             net.external_input.extend([bias_name])
    83             fused_conv.input.extend([bias_name])
    84             for arg 
in fused_conv.arg:
    85                 if arg.name == 
"no_bias":
    88         conv_weight = params[conv.input[1]]
    89         conv_bias = params[conv.input[2]] 
if len(conv.input) == 3 \
    90             else np.zeros(shape=(conv_weight.shape[0])).astype(np.float32)
    92         bn_scale = params[bn.input[1]]
    93         bn_bias = params[bn.input[2]]
    94         bn_running_mean = params[bn.input[3]]
    95         bn_running_var = params[bn.input[4]]
   111             if arg.name == 
"epsilon":
   113         A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
   114         B = bn_bias - bn_running_mean * A
   138         A_ = A.reshape(-1, 1, 1, 1) 
if conv.type == 
"Conv" else \
   139             A.reshape(1, -1, 1, 1)
   141         C = conv_bias * A + B
   144         params[fused_conv.input[1]] = Q
   145         params[fused_conv.input[2]] = C
   146         new_ops = net.op[:i] + [fused_conv] + net.op[j + 1:]
   148         removed_tensors.append(bn.input[1])
   149         removed_tensors.append(bn.input[2])
   150         removed_tensors.append(bn.input[3])
   151         removed_tensors.append(bn.input[4])
   152         del params[bn.input[1]]
   153         del params[bn.input[2]]
   154         del params[bn.input[3]]
   155         del params[bn.input[4]]
   156         net.op.extend(new_ops)
   158     return net, params, removed_tensors
   161 def fuse_bn(net, params, ignore_failure):
   165         (next_net, next_params, removed_tensors) = \
   166             fuse_first_bn(net, params, removed_tensors)
   167         if len(next_net.op) == len(net.op):
   169                 any(op.type == 
"SpatialBN" for op 
in next_net.op) 
and   173                     "Model contains SpatialBN op after fusion: %s", next_net)
   174             return (next_net, next_params, removed_tensors)
   175         net, params, removed_tensors = (next_net, next_params, removed_tensors)
   178 def fuse_first_mul_add(net, params, removed_tensors):
   179     net = copy.deepcopy(net)
   180     params = copy.deepcopy(params)
   182     for ((i, current), (j, next_)) 
in pairwise(enumerate(net.op)):
   183         if current.type != 
"Mul" or next_.type != 
"Add":
   186         if next_.input[0] != current.output[0]:
   187             raise Exception(
"Failure to fuse")
   189         if len(blob_uses(net, current.output[0])) != 1:
   190             raise Exception(
"Failure to fuse")
   192         log.info(
"Fusing at index %s", i)
   195         batch_norm = copy.deepcopy(mul_)
   196         batch_norm.type = 
"SpatialBN"   197         batch_norm.arg.extend([utils.MakeArgument(
"is_test", 1)])
   198         batch_norm.arg.extend([utils.MakeArgument(
"epsilon", float(1e-9))])
   201             return "{}{}".format(add_.output[0], x)
   202         fake_mean = s(
"_mean")
   205         del batch_norm.input[:]
   206         batch_norm.input.extend([mul_.input[0],
   211         params[fake_mean] = np.zeros_like(params[mul_.input[1]])
   212         params[fake_var] = np.ones_like(params[mul_.input[1]])
   213         net.external_input.extend([fake_mean, fake_var])
   215         batch_norm.output[0] = add_.output[0]
   216         new_ops = net.op[:i] + [batch_norm] + net.op[j + 1:]
   218         net.op.extend(new_ops)
   220     return net, params, removed_tensors
   223 def fuse_mul_add(net, params):
   227         (next_net, next_params, removed_tensors) = \
   228             fuse_first_mul_add(net, params, removed_tensors)
   229         if len(next_net.op) == len(net.op):
   230             return (next_net, next_params, removed_tensors)
   231         net, params, removed_tensors = (next_net, next_params, removed_tensors)
   234 def add_tensor(net, name, blob):
   235     ''' Create an operator to store the tensor 'blob',   236         run the operator to put the blob to workspace.   237         uint8 is stored as an array of string with one element.   240         np.dtype(
'float32'): 
"GivenTensorFill",
   241         np.dtype(
'int32'): 
"GivenTensorIntFill",
   242         np.dtype(
'int64'): 
"GivenTensorInt64Fill",
   243         np.dtype(
'uint8'): 
"GivenTensorStringFill",
   250     if blob.dtype == np.dtype(
'uint8'):
   252         values = [str(blob.data)]
   254     op = core.CreateOperator(
   255         kTypeNameMapper[blob.dtype],
   258             utils.MakeArgument(
"shape", shape),
   259             utils.MakeArgument(
"values", values),
   265 def gen_init_net_from_blobs(blobs):
   266     ''' Generate an initialization net based on a blob dict '''   267     ret = caffe2_pb2.NetDef()
   268     for name, blob 
in blobs.items():
   269         add_tensor(ret, name, blob)
   273 def fuse_conv_relu(net):
   274     net = copy.deepcopy(net)
   275     device_option = core.DeviceOption(caffe2_pb2.IDEEP)
   277         op.device_option.CopyFrom(device_option)
   279     new_net = caffe2_pb2.NetDef()
   280     new_net.ParseFromString(C.transform_optimizeForIDEEP(net.SerializeToString()))
   285     init_net = caffe2_pb2.NetDef()
   286     predict_net = caffe2_pb2.NetDef()
   287     init_net.ParseFromString(args.init_net.read())
   288     predict_net.ParseFromString(args.pred_net.read())
   290     workspace.ResetWorkspace()
   291     workspace.RunNetOnce(init_net)
   292     param_dict = {p: workspace.FetchBlob(p) 
for p 
in workspace.Blobs()}
   295     external_outputs = {}
   296     if args.verify_input:
   297         value_info = json.load(args.verify_input)
   298         input_shapes = {k : v[-1] 
for (k, v) 
in value_info.items()}
   299         print(
"input info: {}".format(input_shapes))
   300         for k, v 
in input_shapes.items():
   301             external_inputs[k] = np.random.randn(*v).astype(np.float32)
   302             workspace.FeedBlob(k, external_inputs[k])
   303         workspace.RunNetOnce(predict_net)
   304         for o 
in predict_net.external_output:
   305             external_outputs[o] = workspace.FetchBlob(o)
   307     if args.fuse_mul_add:
   308         predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
   310         predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, 
False)
   311     if args.fuse_conv_relu:
   312         predict_net = fuse_conv_relu(predict_net)
   314     external_outputs_opt = {}
   315     if args.verify_input:
   316         workspace.ResetWorkspace()
   317         device_option = core.DeviceOption(caffe2_pb2.IDEEP) 
if args.fuse_conv_relu 
else core.DeviceOption(caffe2_pb2.CPU)
   318         with core.DeviceScope(device_option):
   319             for k, v 
in param_dict.items():
   320                 workspace.FeedBlob(k, v, device_option)
   321             for k, v 
in external_inputs.items():
   322                 workspace.FeedBlob(k, v, device_option)
   323             workspace.RunNetOnce(predict_net)
   324             for o 
in predict_net.external_output:
   325                 external_outputs_opt[o] = workspace.FetchBlob(o)
   326                 assert np.allclose(external_outputs[o],
   327                                    external_outputs_opt[o],
   331     for i, o 
in enumerate(predict_net.op):
   332         print(
"op[{}]: {}".format(i, o.type))
   333     init_net = gen_init_net_from_blobs(param_dict)
   334     with open(
'init_net.pb', 
'wb') 
as f:
   335         f.write(init_net.SerializeToString())
   336     with open(
'predict_net.pb', 
'wb') 
as f:
   337         f.write(predict_net.SerializeToString())
   339 if __name__ == 
'__main__':
   340     args = GetArgumentParser().parse_args()