Caffe2 - Python API
A deep learning, cross-platform ML framework
transform_ideep_net.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import copy
import json
import logging
import os.path

import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace, utils
# C extension bindings; provide transform_optimizeForIDEEP used below.
import caffe2.python._import_c_extension as C

log = logging.getLogger(__name__)


def pairwise(iterable):
    # s -> (s0, s1), (s1, s2), (s2, s3), ...
    from itertools import tee
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)


def last_producer(ops, blob):
    # Index of the last op in `ops` that writes `blob`.
    for (i, op) in reversed(list(enumerate(ops))):
        if blob in op.output:
            return i
    raise ValueError("Failed to find last producer of blob, %s", blob)


def blob_uses(net, blob):
    # Indices of all ops in `net` that consume `blob` as a data or control input.
    u = []
    for i, op in enumerate(net.op):
        if blob in op.input or blob in op.control_input:
            u.append(i)
    return u


def GetArgumentParser():
    parser = argparse.ArgumentParser(description="Caffe2 optimization")
    parser.add_argument("--init_net",
                        type=argparse.FileType('rb'),
                        help="init net")
    parser.add_argument("--pred_net",
                        type=argparse.FileType('rb'),
                        help="predict net")
    parser.add_argument("--verify_input",
                        type=argparse.FileType('r'),
                        help="input dims for verification")
    parser.add_argument("--fuse_bn", default=False, action='store_true')
    parser.add_argument("--fuse_mul_add", default=False, action='store_true')
    parser.add_argument("--fuse_conv_relu", default=False, action='store_true')
    return parser


def fuse_first_bn(net, params, removed_tensors):
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if next_.input[0] != current.output[0]:
            continue

        if current.type not in ("Conv", "ConvTranspose") \
           or next_.type != "SpatialBN":
            continue
        if len(blob_uses(net, current.output[0])) != 1:
            # Can't fuse if more than one user
            continue

        # else, can fuse
        conv = current
        bn = next_
        fused_conv = copy.deepcopy(conv)
        fused_conv.output[0] = bn.output[0]

        # Fix fused_conv to ensure we have a bias passed.
        if len(fused_conv.input) != 3:
            bias_name = "{}_bias".format(conv.input[1])
            net.external_input.extend([bias_name])
            fused_conv.input.extend([bias_name])
            for arg in fused_conv.arg:
                if arg.name == "no_bias":
                    arg.i = 0

        conv_weight = params[conv.input[1]]
        conv_bias = params[conv.input[2]] if len(conv.input) == 3 \
            else np.zeros(shape=(conv_weight.shape[0])).astype(np.float32)

        bn_scale = params[bn.input[1]]
        bn_bias = params[bn.input[2]]
        bn_running_mean = params[bn.input[3]]
        bn_running_var = params[bn.input[4]]

        # First, the BN computation can be phrased as follows:
        #   (X - running_mean) * (1.0 / sqrt(running_var + eps)) * bn_scale + bias
        # so it can be rewritten as the affine transform
        #   X * A + B
        # where
        #   A = bn_scale * 1.0 / sqrt(running_var + eps)
        #   B = bias - running_mean * (1.0 / sqrt(running_var + eps)) * bn_scale
        eps = 1.0e-5
        for arg in bn.arg:
            if arg.name == "epsilon":
                eps = arg.f
        A = bn_scale * 1.0 / (np.sqrt(bn_running_var + eps))
        B = bn_bias - bn_running_mean * A

        # This identity should hold if we have correctly fused
        # np.testing.assert_array_equal(
        #     params[conv.output[0]] * A + B,
        #     params[bn.output[0]])

        # Now, the computation being performed is:
        #   ((X `conv` W) + b) * A + B
        # which we can fuse as
        #   (X `conv` (W * A)) + b * A + B
        # i.e. simply
        #   (X `conv` Q) + C
        # where
        #   Q = W * A
        #   C = b * A + B

        # For ConvTranspose, from the view of convolutions as a
        # Toeplitz multiplication, we have W_ = W^T, so the weights
        # are laid out as (R, S, K, K) (vs (S, R, K, K) for a Conv),
        # so the weights broadcast slightly differently. Remember, our
        # BN scale 'B' is of size (S,)

        A_ = A.reshape(-1, 1, 1, 1) if conv.type == "Conv" else \
            A.reshape(1, -1, 1, 1)

        C = conv_bias * A + B
        Q = conv_weight * A_

        params[fused_conv.input[1]] = Q
        params[fused_conv.input[2]] = C
        new_ops = net.op[:i] + [fused_conv] + net.op[j + 1:]
        del net.op[:]
        removed_tensors.append(bn.input[1])
        removed_tensors.append(bn.input[2])
        removed_tensors.append(bn.input[3])
        removed_tensors.append(bn.input[4])
        del params[bn.input[1]]
        del params[bn.input[2]]
        del params[bn.input[3]]
        del params[bn.input[4]]
        net.op.extend(new_ops)
        break
    return net, params, removed_tensors
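

# Quick numerical sketch of the affine rewrite used above (an illustrative
# addition; this helper is not called anywhere in the script): with
# A = scale / sqrt(var + eps) and B = bias - mean * A, the per-channel BN
# output equals X * A + B.
def _check_bn_affine_rewrite():
    rng = np.random.RandomState(0)
    x = rng.randn(8).astype(np.float32)            # pre-BN activations, 8 channels
    scale = rng.rand(8).astype(np.float32) + 0.5
    bias = rng.randn(8).astype(np.float32)
    mean = rng.randn(8).astype(np.float32)
    var = rng.rand(8).astype(np.float32) + 0.1
    eps = 1.0e-5
    bn_out = (x - mean) / np.sqrt(var + eps) * scale + bias
    A = scale / np.sqrt(var + eps)
    B = bias - mean * A
    np.testing.assert_allclose(bn_out, x * A + B, rtol=1e-5, atol=1e-6)

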
def fuse_bn(net, params, ignore_failure):
    # Run until we hit a fixed point
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_bn(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            if (
                any(op.type == "SpatialBN" for op in next_net.op) and
                not ignore_failure
            ):
                raise Exception(
                    "Model contains SpatialBN op after fusion: %s", next_net)
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)


def fuse_first_mul_add(net, params, removed_tensors):
    net = copy.deepcopy(net)
    params = copy.deepcopy(params)

    for ((i, current), (j, next_)) in pairwise(enumerate(net.op)):
        if current.type != "Mul" or next_.type != "Add":
            continue

        if next_.input[0] != current.output[0]:
            raise Exception("Failure to fuse")

        if len(blob_uses(net, current.output[0])) != 1:
            raise Exception("Failure to fuse")

        log.info("Fusing at index %s", i)
        mul_ = current
        add_ = next_
        batch_norm = copy.deepcopy(mul_)
        batch_norm.type = "SpatialBN"
        batch_norm.arg.extend([utils.MakeArgument("is_test", 1)])
        batch_norm.arg.extend([utils.MakeArgument("epsilon", float(1e-9))])

        def s(x):
            return "{}{}".format(add_.output[0], x)
        fake_mean = s("_mean")
        fake_var = s("_var")

        del batch_norm.input[:]
        batch_norm.input.extend([mul_.input[0],
                                 mul_.input[1],
                                 add_.input[1],
                                 fake_mean,
                                 fake_var])
        params[fake_mean] = np.zeros_like(params[mul_.input[1]])
        params[fake_var] = np.ones_like(params[mul_.input[1]])
        net.external_input.extend([fake_mean, fake_var])

        batch_norm.output[0] = add_.output[0]
        new_ops = net.op[:i] + [batch_norm] + net.op[j + 1:]
        del net.op[:]
        net.op.extend(new_ops)
        break
    return net, params, removed_tensors
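

# Minimal numerical sketch of why the Mul + Add pair can be replaced by a
# SpatialBN op above (an illustrative addition, not called by the script):
# with the fake zero mean, unit variance and a tiny epsilon, SpatialBN
# degenerates to x * scale + bias.
def _check_mul_add_as_bn():
    rng = np.random.RandomState(0)
    x = rng.randn(4).astype(np.float32)
    scale = rng.rand(4).astype(np.float32) + 0.5
    bias = rng.randn(4).astype(np.float32)
    eps = 1e-9
    bn_out = (x - 0.0) / np.sqrt(1.0 + eps) * scale + bias
    np.testing.assert_allclose(bn_out, x * scale + bias, rtol=1e-6)

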
def fuse_mul_add(net, params):
    # Run until we hit a fixed point
    removed_tensors = []
    while True:
        (next_net, next_params, removed_tensors) = \
            fuse_first_mul_add(net, params, removed_tensors)
        if len(next_net.op) == len(net.op):
            return (next_net, next_params, removed_tensors)
        net, params, removed_tensors = (next_net, next_params, removed_tensors)


def add_tensor(net, name, blob):
    ''' Create an operator to store the tensor 'blob',
        run the operator to put the blob to workspace.
        uint8 is stored as an array of string with one element.
    '''
    kTypeNameMapper = {
        np.dtype('float32'): "GivenTensorFill",
        np.dtype('int32'): "GivenTensorIntFill",
        np.dtype('int64'): "GivenTensorInt64Fill",
        np.dtype('uint8'): "GivenTensorStringFill",
    }

    shape = blob.shape
    values = blob
    # pass array of uint8 as a string to save storage
    # storing uint8_t has a large overhead for now
    if blob.dtype == np.dtype('uint8'):
        shape = [1]
        values = [str(blob.data)]

    op = core.CreateOperator(
        kTypeNameMapper[blob.dtype],
        [], [name],
        arg=[
            utils.MakeArgument("shape", shape),
            utils.MakeArgument("values", values),
        ]
    )
    net.op.extend([op])


def gen_init_net_from_blobs(blobs):
    ''' Generate an initialization net based on a blob dict '''
    ret = caffe2_pb2.NetDef()
    for name, blob in blobs.items():
        add_tensor(ret, name, blob)
    return ret
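

# Example use of gen_init_net_from_blobs (hypothetical blob names; this helper
# is an illustrative addition and is not called by the script): it packs a dict
# of numpy arrays into GivenTensorFill-style ops that refill the blobs whenever
# the returned net is run.
def _example_init_net():
    blobs = {
        "conv1_w": np.zeros((8, 3, 3, 3), dtype=np.float32),
        "conv1_b": np.zeros(8, dtype=np.float32),
    }
    return gen_init_net_from_blobs(blobs)

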
def fuse_conv_relu(net):
    net = copy.deepcopy(net)
    device_option = core.DeviceOption(caffe2_pb2.IDEEP)
    for op in net.op:
        op.device_option.CopyFrom(device_option)

    new_net = caffe2_pb2.NetDef()
    new_net.ParseFromString(C.transform_optimizeForIDEEP(net.SerializeToString()))
    return new_net


def Optimize(args):
    init_net = caffe2_pb2.NetDef()
    predict_net = caffe2_pb2.NetDef()
    init_net.ParseFromString(args.init_net.read())
    predict_net.ParseFromString(args.pred_net.read())

    workspace.ResetWorkspace()
    workspace.RunNetOnce(init_net)
    param_dict = {p: workspace.FetchBlob(p) for p in workspace.Blobs()}

    external_inputs = {}
    external_outputs = {}
    if args.verify_input:
        value_info = json.load(args.verify_input)
        input_shapes = {k: v[-1] for (k, v) in value_info.items()}
        print("input info: {}".format(input_shapes))
        for k, v in input_shapes.items():
            external_inputs[k] = np.random.randn(*v).astype(np.float32)
            workspace.FeedBlob(k, external_inputs[k])
        workspace.RunNetOnce(predict_net)
        for o in predict_net.external_output:
            external_outputs[o] = workspace.FetchBlob(o)

    if args.fuse_mul_add:
        predict_net, param_dict, _ = fuse_mul_add(predict_net, param_dict)
    if args.fuse_bn:
        predict_net, param_dict, _ = fuse_bn(predict_net, param_dict, False)
    if args.fuse_conv_relu:
        predict_net = fuse_conv_relu(predict_net)

    external_outputs_opt = {}
    if args.verify_input:
        workspace.ResetWorkspace()
        device_option = core.DeviceOption(caffe2_pb2.IDEEP) \
            if args.fuse_conv_relu else core.DeviceOption(caffe2_pb2.CPU)
        with core.DeviceScope(device_option):
            for k, v in param_dict.items():
                workspace.FeedBlob(k, v, device_option)
            for k, v in external_inputs.items():
                workspace.FeedBlob(k, v, device_option)
            workspace.RunNetOnce(predict_net)
            for o in predict_net.external_output:
                external_outputs_opt[o] = workspace.FetchBlob(o)
                assert np.allclose(external_outputs[o],
                                   external_outputs_opt[o],
                                   atol=1e-3,
                                   rtol=1e-3)

    for i, o in enumerate(predict_net.op):
        print("op[{}]: {}".format(i, o.type))
    init_net = gen_init_net_from_blobs(param_dict)
    with open('init_net.pb', 'wb') as f:
        f.write(init_net.SerializeToString())
    with open('predict_net.pb', 'wb') as f:
        f.write(predict_net.SerializeToString())


if __name__ == '__main__':
    args = GetArgumentParser().parse_args()
    Optimize(args)
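
# Example invocation (hypothetical file names; the nets are ordinary Caffe2
# NetDef protobufs, and the --verify_input JSON maps each external input name
# to a list whose last element is that input's shape):
#
#   python transform_ideep_net.py \
#       --init_net resnet50_init_net.pb \
#       --pred_net resnet50_predict_net.pb \
#       --verify_input value_info.json \
#       --fuse_bn --fuse_mul_add --fuse_conv_relu
#
# The transformed nets are written to init_net.pb and predict_net.pb in the
# current directory.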