3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
def parse_kwarg(kwarg_str):
    """Parse a ``key=value`` command-line string into a ``(key, value)`` pair.

    Used as the argparse ``type`` for ``--kwargs``; the resulting pairs are
    turned into a dict of keyword arguments for the benchmarked operator.

    The value is interpreted as a Python literal when possible (``3`` becomes
    an int, ``[1, 2]`` a list, ``True`` a bool). Anything that is not a valid
    literal, such as a bare word like ``NHWC`` or an empty value, is kept as
    the original string.

    Args:
        kwarg_str: Text of the form ``key=value``. Only the first ``=``
            separates key from value, so the value may itself contain ``=``.

    Returns:
        A ``(key, value)`` tuple.

    Raises:
        ValueError: If ``kwarg_str`` contains no ``=``. argparse reports a
            ValueError from a ``type`` callable as an invalid argument value.
    """
    import ast  # local import keeps this helper self-contained

    key, sep, value = kwarg_str.partition('=')
    if not sep:
        raise ValueError("expected key=value, got {!r}".format(kwarg_str))
    try:
        value = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        # Not a Python literal (e.g. a bare word or empty text): keep the
        # raw string so options like order=NHWC still work.
        pass
    return key, value
def main(args):
    """Build a benchmark model for one Caffe2 operator and export it.

    Adds ``args.instances`` copies of the ``brew`` helper named by
    ``args.operator`` to a ``ModelHelper`` and runs its parameter-init net
    once. It then exports the model with ``mobile_exporter.Export`` and writes
    the serialized nets to ``args.init_net`` and ``args.predict_net``.

    Args:
        args: Parsed command-line arguments (the namespace produced by
            ``parser.parse_args()`` in this script).
    """
    # Defaults passed to every operator; --kwargs entries override them.
    kwargs = {"order": "NCHW", "use_cudnn": False}
    kwargs.update(dict(args.kwargs))

    model = ModelHelper(name=args.benchmark_name)

    # Resolve the brew helper once instead of on every iteration.
    add_op = getattr(brew, args.operator)
    input_name = args.input_name
    output_name = args.output_name

    for i in range(int(args.instances)):
        # Without --chain, every instance reads the same input blob.
        # With --chain, instance i > 0 reads the blob written by instance i-1.
        # After the swap below, input_name holds the previous output prefix,
        # and str(i) is the suffix that instance appended.
        input_blob_name = input_name + (str(i) if i > 0 and args.chain else '')
        output_blob_name = output_name + str(i + 1)
        add_op(model, input_blob_name, output_blob_name, **kwargs)
        if args.chain:
            # Alternate the two prefixes so consecutive ops form a chain.
            # This must not happen in unchained mode, where it would make
            # instance 1 read a blob nobody writes.
            input_name, output_name = output_name, input_name

    workspace.RunNetOnce(model.param_init_net)

    init_net, predict_net = mobile_exporter.Export(
        workspace, model.net, model.params
    )

    if args.debug:
        # Label each listing so the two nets' ops can be told apart.
        for label, net in (("init_net:", init_net),
                           ("predict_net:", predict_net)):
            print(label)
            for op in net.op:
                print(" ", op.type, op.input, "-->", op.output)

    with open(args.predict_net, 'wb') as f:
        f.write(predict_net.SerializeToString())
    with open(args.init_net, 'wb') as f:
        f.write(init_net.SerializeToString())
# Command-line entry point: declare the script's options and parse them into ``args``.
65 if __name__ ==
"__main__":
# NOTE(review): "Utilitity" in the description string below is a typo for "Utility".
66 parser = argparse.ArgumentParser(
67 description=
"Utilitity to generate Caffe2 benchmark models.")
# Positional: name of a ``brew`` helper, looked up with getattr(brew, ...) when building the model.
68 parser.add_argument(
"operator", help=
"Caffe2 operator to benchmark.")
# NOTE(review): args.blob is not read by the model-building code in this file, and this
# call has no visible closing parenthesis — presumably more keyword arguments
# (e.g. action/default) follow; confirm.
69 parser.add_argument(
"-b",
"--blob",
70 help=
"Instantiate a blob --blob name=dim1,dim2,dim3",
# NOTE(review): args.context is not read by the model-building code in this file; confirm it is still needed.
72 parser.add_argument(
"--context", help=
"Context to run on.", default=
"CPU")
# Each --kwargs entry is a key=value string converted by parse_kwarg; the resulting
# pairs override the default operator kwargs ({"order": "NCHW", "use_cudnn": False}).
73 parser.add_argument(
"--kwargs", help=
"kwargs to pass to operator.",
74 nargs=
"*", type=parse_kwarg, default=[])
# Output paths for the serialized init and predict nets (opened in binary write mode).
75 parser.add_argument(
"--init_net", help=
"Output initialization net.",
76 default=
"init_net.pb")
77 parser.add_argument(
"--predict_net", help=
"Output prediction net.",
78 default=
"predict_net.pb")
# Name given to the ModelHelper.
# NOTE(review): this call and the remaining add_argument calls below have no visible
# closing parenthesis — their defaults/actions are not shown here; confirm them.
79 parser.add_argument(
"--benchmark_name",
80 help=
"Name of the benchmark network",
# Blob-name prefixes: instance i writes <output_name><i+1>.
82 parser.add_argument(
"--input_name", help=
"Name of the input blob.",
84 parser.add_argument(
"--output_name", help=
"Name of the output blob.",
# Converted with int() before being used as the number of operator instances added.
86 parser.add_argument(
"--instances",
87 help=
"Number of instances to run the operator.",
# -d/--debug and -c/--chain are presumably boolean flags (action not visible here — confirm).
# With args.chain set, each operator instance after the first reads the previous instance's output.
89 parser.add_argument(
"-d",
"--debug", help=
"Print debug information.",
91 parser.add_argument(
"-c",
"--chain",
92 help=
"Chain ops together (create data dependencies)",
94 args = parser.parse_args()