Caffe2 - Python API
A deep learning, cross-platform ML framework
bench_gen.py
1 #!/usr/bin/env python
2 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import argparse
9 import ast
10 
11 from caffe2.python.model_helper import ModelHelper
12 from caffe2.python.predictor import mobile_exporter
13 from caffe2.python import workspace, brew
14 
15 
def parse_kwarg(kwarg_str):
    """Parse a ``key=value`` command-line token into a ``(key, value)`` pair.

    The value is interpreted as a Python literal when possible (``3`` becomes
    an int, ``[1, 2]`` a list, ``True`` a bool). Anything that is not a valid
    literal, e.g. ``NCHW`` or ``foo bar``, is kept as the raw string.

    Only the first ``=`` separates key from value, so values may themselves
    contain ``=``.

    Args:
        kwarg_str: Token of the form ``key=value``.

    Returns:
        Tuple ``(key, value)``.

    Raises:
        ValueError: if ``kwarg_str`` contains no ``=``. When this function is
            used as an argparse ``type=`` callable, argparse turns this into
            an "invalid value" usage error.
    """
    key, sep, value = kwarg_str.partition('=')
    if not sep:
        raise ValueError("expected key=value, got %r" % (kwarg_str,))
    try:
        value = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        # Not a Python literal: literal_eval raises ValueError for bare names
        # (e.g. NCHW) and SyntaxError for text such as "foo bar", "a=b" or an
        # empty string. Keep the raw string in those cases.
        pass
    return key, value
23 
24 
def _print_net(label, net):
    """Print ``label`` then one line per operator of ``net``: type, inputs --> outputs."""
    print(label)
    for op in net.op:
        print(" ", op.type, op.input, "-->", op.output)


def main(args):
    """Build a benchmark model of repeated operator instances and export it.

    Adds ``args.instances`` copies of the brew helper named ``args.operator``
    to a ``ModelHelper``. Runs the model's ``param_init_net`` in the global
    workspace. Then exports the model with ``mobile_exporter.Export`` and
    writes the serialized init and predict nets to ``args.init_net`` and
    ``args.predict_net``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``operator``, ``kwargs``, ``benchmark_name``, ``input_name``,
            ``output_name``, ``instances``, ``chain``, ``debug``,
            ``init_net`` and ``predict_net``.

    NOTE(review): ``args.blob`` and ``args.context`` are accepted on the
    command line but are not used here.
    """
    # User defined keyword arguments; entries from --kwargs override these
    # defaults before being forwarded to every operator instance.
    kwargs = {"order": "NCHW", "use_cudnn": False}
    kwargs.update(dict(args.kwargs))

    model = ModelHelper(name=args.benchmark_name)

    op_type = args.operator  # assumes a brew type op name
    # Resolve the helper once: it does not change between iterations, and an
    # unknown operator name now fails before any op is added to the model.
    add_op = getattr(brew, op_type)
    input_name = args.input_name
    output_name = args.output_name

    iters = int(args.instances)
    for i in range(iters):
        input_blob_name = input_name + (str(i) if i > 0 and args.chain else '')
        output_blob_name = output_name + str(i + 1)
        add_op(model, input_blob_name, output_blob_name, **kwargs)
        if args.chain:
            # Swap base names so the next op's input equals this op's output.
            # With the defaults this yields data -> output1 -> data2 -> output3 ...
            input_name, output_name = output_name, input_name

    workspace.RunNetOnce(model.param_init_net)

    init_net, predict_net = mobile_exporter.Export(
        workspace, model.net, model.params
    )

    if args.debug:
        _print_net("init_net:", init_net)
        _print_net("predict_net:", predict_net)

    with open(args.predict_net, 'wb') as f:
        f.write(predict_net.SerializeToString())
    with open(args.init_net, 'wb') as f:
        f.write(init_net.SerializeToString())
64 
if __name__ == "__main__":
    # Command-line entry point: parse arguments and hand them to main().
    parser = argparse.ArgumentParser(
        description="Utility to generate Caffe2 benchmark models.")
    parser.add_argument("operator", help="Caffe2 operator to benchmark.")
    # NOTE(review): --blob and --context are parsed but not read by main().
    parser.add_argument("-b", "--blob",
                        help="Instantiate a blob --blob name=dim1,dim2,dim3",
                        action='append')
    parser.add_argument("--context", help="Context to run on.", default="CPU")
    # Each token is parsed by parse_kwarg into a (key, value) pair.
    parser.add_argument("--kwargs", help="kwargs to pass to operator.",
                        nargs="*", type=parse_kwarg, default=[])
    parser.add_argument("--init_net", help="Output initialization net.",
                        default="init_net.pb")
    parser.add_argument("--predict_net", help="Output prediction net.",
                        default="predict_net.pb")
    parser.add_argument("--benchmark_name",
                        help="Name of the benchmark network",
                        default="benchmark")
    parser.add_argument("--input_name", help="Name of the input blob.",
                        default="data")
    parser.add_argument("--output_name", help="Name of the output blob.",
                        default="output")
    # type=int makes argparse reject non-integers with a usage error instead
    # of main() failing later with a traceback.
    parser.add_argument("--instances",
                        help="Number of instances to run the operator.",
                        type=int, default=1)
    parser.add_argument("-d", "--debug", help="Print debug information.",
                        action='store_true')
    parser.add_argument("-c", "--chain",
                        help="Chain ops together (create data dependencies)",
                        action='store_true')
    args = parser.parse_args()
    main(args)