3 from __future__ 
import absolute_import
     4 from __future__ 
import division
     5 from __future__ 
import print_function
     6 from __future__ 
import unicode_literals
    def parse_kwarg(kwarg_str):
        """Parse one ``--kwargs`` command-line item of the form ``key=value``.

        Used as the argparse ``type=`` converter for ``--kwargs``; the returned
        pairs are later turned into the operator's keyword arguments with
        ``dict(...)``, so a ``(key, value)`` tuple must always be returned.

        Only the first ``=`` separates key from value, so the value may itself
        contain ``=``. Surrounding whitespace is stripped from both parts.
        A value that parses as ``int`` (tried first) or ``float`` is converted,
        so numeric operator arguments such as ``kernel=3`` are not passed as
        strings; any other value is kept as a string.

        Args:
            kwarg_str: The raw command-line item, e.g. ``"kernel=3"``.

        Returns:
            A ``(key, value)`` tuple where ``value`` is an ``int``, ``float``
            or ``str``.

        Raises:
            argparse.ArgumentTypeError: If ``kwarg_str`` has no ``=`` or an
                empty key; argparse reports this message to the user.
        """
        import argparse

        # Use str methods: the module-level string.strip() helper only exists
        # on Python 2 and was removed in Python 3.
        key, sep, value = kwarg_str.partition("=")
        key = key.strip()
        value = value.strip()
        if not sep or not key:
            raise argparse.ArgumentTypeError(
                "expected key=value, got {!r}".format(kwarg_str))

        try:
            value = int(value)
        except ValueError:
            try:
                value = float(value)
            except ValueError:
                pass  # Not numeric: keep the string value as-is.
        return key, value
    # NOTE(review): these statements read like the body of the script's driver
    # function; `args` is presumably the argparse namespace built under
    # `if __name__ == "__main__":` -- confirm against the enclosing def.
    # Operator keyword arguments: NCHW layout by default, overridden/extended
    # by the (key, value) pairs given via --kwargs (for duplicate keys the
    # last pair wins, per dict() semantics).
    32     kwargs = {
"order": 
"NCHW"}
    33     kwargs.update(dict(args.kwargs))
    # Model that accumulates the generated operators.
    35     model = ModelHelper(name=args.benchmark_name)
    # op_type is resolved with getattr(brew, op_type) below, so it must name a
    # helper available on the `brew` module.
    37     op_type = args.operator  
    38     input_name = args.input_name
    39     output_name = args.output_name
    # Add the operator `iters` times. Op i writes `output_name + str(i + 1)`;
    # its input gets the `str(i)` suffix only when chaining and i > 0.
    41     iters = int(args.iters)
    42     for i 
in range(iters):
    43         input_blob_name = input_name + (str(i) 
if i > 0 
and args.chain 
else '')
    44         output_blob_name = output_name + str(i + 1)
    45         add_op = getattr(brew, op_type)
    46         add_op(model, input_blob_name, output_blob_name, **kwargs)
            # Swap the two name prefixes so the next op reads the blob this op
            # just wrote (presumably guarded by `if args.chain:` -- confirm).
    48             input_name, output_name = output_name, input_name
    # Run the parameter-initialization net once in the workspace (presumably
    # so the parameter blobs exist before exporting below).
    50     workspace.RunNetOnce(model.param_init_net)
    # Input-creation ops collected by make_blob_on_context; appended to the
    # exported init_net further down.
    51     extra_init_net_ops = []
    # Queue init-net ops that create blob `blob_name` holding `blob_data`.
    # For any context other than CPU the data is filled into
    # "<blob_name>_CPU"; only the OPENGL context then gets a CopyToOpenGL op
    # reading that blob.
    # NOTE(review): other non-CPU contexts end up with only the "_CPU" blob and
    # no copy into `blob_name` -- confirm that is intended.
    53     def make_blob_on_context(blob_name, blob_data, context):
    54         if context.upper() != 
"CPU":
    55             blob_name_modified = 
"{}_CPU".format(blob_name)
    57             blob_name_modified = blob_name
        # GivenTensorFill with explicit "shape" and "values" arguments taken
        # from blob_data.
    59         fill_op = core.CreateOperator(
    60             "GivenTensorFill", [], [blob_name_modified],
    62                 utils.MakeArgument(
"shape", blob_data.shape),
    63                 utils.MakeArgument(
"values", blob_data)
    66         extra_init_net_ops.append(fill_op)
    70         if context.upper() == 
"OPENGL":
    71             copy_op = core.CreateOperator(
"CopyToOpenGL", [blob_name_modified],
    73             extra_init_net_ops.append(copy_op)
    # Each --blob entry is "name=d1,d2,..."; build a float32 array of that
    # shape filled by np.random.rand (uniform on [0, 1)) and queue its
    # creation for the requested context.
    # NOTE(review): if --blob is omitted and the option's default is None,
    # iterating args.blob raises TypeError -- confirm the option's default.
    # NOTE(review): split('=') has no maxsplit, so a name containing '='
    # makes the two-name unpacking fail.
    75     for unparsed_blob 
in args.blob:
    76         name, unparsed_dims = unparsed_blob.split(
'=')
    77         dims = [int(d) 
for d 
in unparsed_dims.split(
',')]
    78         np_input = np.random.rand(*dims).astype(np.float32)
    79         make_blob_on_context(name, np_input, args.context)
    # Export the model into (init_net, predict_net) with mobile_exporter, then
    # append the queued input-creation ops to the init net.
    81     init_net, predict_net = mobile_exporter.Export(
    82         workspace, model.net, model.params
    84     init_net.op.extend(extra_init_net_ops)
    # OPENGL context: prefix every predict-net op type with "OpenGL" and put
    # the ops back into predict_net.
    # NOTE(review): old_ops holds the existing op messages; unless
    # predict_net.op is cleared before extend(), every op ends up duplicated
    # -- confirm.
    87     if args.context.upper() == 
"OPENGL":
    88         old_ops = [op 
for op 
in predict_net.op]
    91             op.type = 
'OpenGL{}'.format(op.type)
    92         predict_net.op.extend(old_ops)
    # Print each op's type, inputs and outputs for both nets (presumably only
    # when --debug is set -- confirm).
    96         for op 
in init_net.op:
    97             print(
" ", op.type, op.input, 
"-->", op.output)
    99         for op 
in predict_net.op:
   100             print(
" ", op.type, op.input, 
"-->", op.output)
    # Write both nets to the requested paths as serialized protobuf bytes
    # (files opened in binary mode).
   102     with open(args.predict_net, 
'wb') 
as f:
   103         f.write(predict_net.SerializeToString())
   104     with open(args.init_net, 
'wb') 
as f:
   105         f.write(init_net.SerializeToString())
   # Command-line entry point: declare the generator's options and parse them.
   107 if __name__ == 
"__main__":
   108     parser = argparse.ArgumentParser(
   109         description=
"Utilitity to generate Caffe2 benchmark models.")
    # NOTE(review): "Utilitity" in the description above is a typo for
    # "Utility" in user-visible help text.
    # Positional argument: name of the Caffe2 operator to benchmark.
   110     parser.add_argument(
"operator", help=
"Caffe2 operator to benchmark.")
    # -b/--blob: input blob spec of the form name=dim1,dim2,dim3.
   111     parser.add_argument(
"-b", 
"--blob",
   112                         help=
"Instantiate a blob --blob name=dim1,dim2,dim3",
    # --context: execution context name; defaults to "CPU".
   114     parser.add_argument(
"--context", help=
"Context to run on.", default=
"CPU")
    # --kwargs: zero or more key=value items, each converted by parse_kwarg;
    # defaults to an empty list.
   115     parser.add_argument(
"--kwargs", help=
"kwargs to pass to operator.",
   116                         nargs=
"*", type=parse_kwarg, default=[])
    # Output paths for the serialized init and predict nets.
   117     parser.add_argument(
"--init_net", help=
"Output initialization net.",
   118                         default=
"init_net.pb")
   119     parser.add_argument(
"--predict_net", help=
"Output prediction net.",
   120                         default=
"predict_net.pb")
    # Naming options for the generated network and its input/output blobs.
   121     parser.add_argument(
"--benchmark_name",
   122                         help=
"Name of the benchmark network",
   124     parser.add_argument(
"--input_name", help=
"Name of the input blob.",
   126     parser.add_argument(
"--output_name", help=
"Name of the output blob.",
    # --iters: how many times the operator is instantiated.
   128     parser.add_argument(
"--iters",
   129                         help=
"Number of iterations to run the operator.",
    # -d/--debug and -c/--chain are presumably store_true flags -- confirm.
   131     parser.add_argument(
"-d", 
"--debug", help=
"Print debug information.",
   133     parser.add_argument(
"-c", 
"--chain",
   134                         help=
"Chain ops together (create data dependencies)",
   136     args = parser.parse_args()