Caffe2 - Python API
A cross-platform deep learning framework
benchmark_generator.py
1 #!/usr/bin/env python
2 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import string
8 
9 import argparse
10 
11 import numpy as np
12 
13 from caffe2.python.model_helper import ModelHelper
14 from caffe2.python.predictor import mobile_exporter
15 from caffe2.python import core, workspace, brew, utils
16 
17 
def parse_kwarg(kwarg_str):
    """Parse a ``key=value`` command-line token into a ``(key, value)`` pair.

    Used as the argparse ``type`` for ``--kwargs``. Whitespace around the key
    and the value is stripped. Only the first ``=`` separates key from value,
    so the value may itself contain ``=``. The value is converted to ``int``
    if possible, otherwise to ``float``, otherwise it is kept as a string.

    Args:
        kwarg_str: Token of the form ``key=value``.

    Returns:
        Tuple ``(key, value)`` where ``value`` is an int, float or str.

    Raises:
        ValueError: If ``kwarg_str`` contains no ``=`` (argparse reports this
            as an invalid argument value).
    """
    # str methods work on Python 2 and 3; the old ``string.strip`` module
    # function no longer exists on Python 3.
    key, sep, value = kwarg_str.partition("=")
    if not sep:
        raise ValueError("expected key=value, got {!r}".format(kwarg_str))
    key = key.strip()
    value = value.strip()
    # Try the narrower numeric type first so "3" stays an int, not 3.0.
    for convert in (int, float):
        try:
            return key, convert(value)
        except ValueError:
            pass
    return key, value
28 
29 
def _make_input_blob_ops(blob_name, blob_data, context):
    """Return init-net ops that materialize ``blob_data`` under ``blob_name``.

    The data is always written on CPU by a ``GivenTensorFill`` op. For the CPU
    context that op writes ``blob_name`` directly. For any other context it
    writes ``<blob_name>_CPU``. For the OPENGL context a ``CopyToOpenGL`` op
    then copies that CPU blob into ``blob_name``.
    """
    is_cpu = context.upper() == "CPU"
    cpu_blob_name = blob_name if is_cpu else "{}_CPU".format(blob_name)

    ops = [core.CreateOperator(
        "GivenTensorFill", [], [cpu_blob_name],
        arg=[
            utils.MakeArgument("shape", blob_data.shape),
            utils.MakeArgument("values", blob_data),
        ],
    )]

    # We need to create CPU blobs and add some copy operations in the init_net.
    if context.upper() == "OPENGL":
        ops.append(core.CreateOperator(
            "CopyToOpenGL", [cpu_blob_name], [blob_name]))
    # NOTE(review): any other non-CPU context only gets "<name>_CPU" and no
    # copy op, so nothing produces ``blob_name`` itself -- confirm whether
    # such contexts are meant to be supported.
    return ops


def _print_net_ops(title, net):
    """Print ``title``, then one line per op in ``net``: type, inputs --> outputs."""
    print(title)
    for op in net.op:
        print(" ", op.type, op.input, "-->", op.output)


def main(args):
    """Generate and serialize a benchmark model for a single brew operator.

    Adds ``args.iters`` calls of the brew helper named ``args.operator`` to a
    fresh ``ModelHelper``. Each ``--blob name=d1,d2,...`` is filled with random
    float32 data by ops appended to the init net. The model is exported with
    ``mobile_exporter.Export``, and the predict and init nets are written to
    ``args.predict_net`` and ``args.init_net`` as serialized protobufs.

    Args:
        args: Parsed command-line namespace built by the argparse setup under
            ``__main__``.
    """
    # User defined keyword arguments; --kwargs may override "order".
    kwargs = {"order": "NCHW"}
    kwargs.update(dict(args.kwargs))

    model = ModelHelper(name=args.benchmark_name)

    op_type = args.operator  # assumes a brew type op name
    add_op = getattr(brew, op_type)  # loop-invariant: resolve once
    input_name = args.input_name
    output_name = args.output_name

    iters = int(args.iters)
    for i in range(iters):
        # Without --chain every op reads the same input blob. With --chain the
        # two base names are swapped after each op, so op i reads "<base>i",
        # which is exactly the blob op i-1 wrote
        # (data -> output1 -> data2 -> output3 -> ...).
        input_blob_name = input_name + (str(i) if i > 0 and args.chain else '')
        output_blob_name = output_name + str(i + 1)
        add_op(model, input_blob_name, output_blob_name, **kwargs)
        if args.chain:
            input_name, output_name = output_name, input_name

    # Initialize parameters in the workspace; presumably Export reads their
    # values from there -- confirm against mobile_exporter.
    workspace.RunNetOnce(model.param_init_net)

    extra_init_net_ops = []
    # --blob uses action='append' without a default, so it is None when no
    # blob was requested; treat that as "no blobs" instead of crashing.
    for blob_spec in args.blob or []:
        name, unparsed_dims = blob_spec.split('=', 1)
        dims = [int(d) for d in unparsed_dims.split(',')]
        np_input = np.random.rand(*dims).astype(np.float32)
        extra_init_net_ops.extend(
            _make_input_blob_ops(name, np_input, args.context))

    init_net, predict_net = mobile_exporter.Export(
        workspace, model.net, model.params
    )
    init_net.op.extend(extra_init_net_ops)

    # Handle manual rewrite: switch every predict op to its OpenGL variant.
    # The repeated field is edited in place; no need to detach and re-add ops.
    if args.context.upper() == "OPENGL":
        for op in predict_net.op:
            op.type = 'OpenGL{}'.format(op.type)

    if args.debug:
        _print_net_ops("init_net:", init_net)
        _print_net_ops("predict_net:", predict_net)

    with open(args.predict_net, 'wb') as f:
        f.write(predict_net.SerializeToString())
    with open(args.init_net, 'wb') as f:
        f.write(init_net.SerializeToString())
106 
if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Utility to generate Caffe2 benchmark models.")
    parser.add_argument("operator", help="Caffe2 operator to benchmark.")
    # May be repeated; stays None when no -b/--blob is given.
    parser.add_argument("-b", "--blob",
                        help="Instantiate a blob --blob name=dim1,dim2,dim3",
                        action='append')
    # Only "CPU" and "OPENGL" get special handling in main().
    parser.add_argument("--context", help="Context to run on.", default="CPU")
    # Each token is a key=value pair converted by parse_kwarg.
    parser.add_argument("--kwargs", help="kwargs to pass to operator.",
                        nargs="*", type=parse_kwarg, default=[])
    parser.add_argument("--init_net", help="Output initialization net.",
                        default="init_net.pb")
    parser.add_argument("--predict_net", help="Output prediction net.",
                        default="predict_net.pb")
    parser.add_argument("--benchmark_name",
                        help="Name of the benchmark network",
                        default="benchmark")
    parser.add_argument("--input_name", help="Name of the input blob.",
                        default="data")
    parser.add_argument("--output_name", help="Name of the output blob.",
                        default="output")
    # Parsed as int so a bad value is rejected by argparse up front.
    parser.add_argument("--iters",
                        help="Number of iterations to run the operator.",
                        type=int, default=1)
    parser.add_argument("-d", "--debug", help="Print debug information.",
                        action='store_true')
    parser.add_argument("-c", "--chain",
                        help="Chain ops together (create data dependencies)",
                        action='store_true')
    args = parser.parse_args()
    main(args)