Caffe2 - Python API
A deep learning, cross-platform ML framework
All Classes Namespaces Functions
test_ideep_net.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 from caffe2.proto import caffe2_pb2
7 from caffe2.python import core, workspace
8 from caffe2.python.models.download import ModelDownloader
9 import numpy as np
10 import argparse
11 import time
12 import os.path
13 
14 
def GetArgumentParser():
    """Build the command-line parser used by this benchmark script.

    Options are registered in the order they appear in ``option_specs``,
    which is also the order ``--help`` lists them.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_specs = [
        ("--batch_size",
         dict(type=int, default=128, help="The batch size.")),
        ("--model",
         dict(type=str, help="The model to benchmark.")),
        ("--order",
         dict(type=str, default="NCHW", help="The order to evaluate.")),
        ("--device",
         dict(type=str, default="CPU", help="device to evaluate on.")),
        ("--cudnn_ws",
         dict(type=int, help="The cudnn workspace size.")),
        ("--iterations",
         dict(type=int, default=10,
              help="Number of iterations to run the network.")),
        ("--warmup_iterations",
         dict(type=int, default=10,
              help="Number of warm-up iterations before benchmarking.")),
        ("--forward_only",
         dict(action='store_true',
              help="If set, only run the forward pass.")),
        ("--layer_wise_benchmark",
         dict(action='store_true',
              help="If True, run the layer-wise benchmark as well.")),
        ("--engine",
         dict(type=str, default="",
              help="If set, blindly prefer the given engine(s) for every op.")),
        ("--dump_model",
         dict(action='store_true',
              help="If True, dump the model prototxts to disk.")),
        ("--net_type", dict(type=str, default="simple")),
        ("--num_workers", dict(type=int, default=2)),
        ("--use-nvtx", dict(default=False, action='store_true')),
        ("--htrace_span_log_path", dict(type=str)),
    ]

    parser = argparse.ArgumentParser(description="Caffe2 benchmark.")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
78 
79 
def benchmark(args):
    """Download a Caffe2 model-zoo model and benchmark its predict net.

    Fetches ``args.model`` with ``ModelDownloader``. Random float32 inputs
    are built from the model's value info, with the leading dimension
    replaced by ``args.batch_size``. Every predict-net op is pinned to the
    device selected by ``args.device``. The net is run through
    ``workspace.BenchmarkNet`` and the resulting throughput is printed.

    Args:
        args: Parsed command-line namespace (see ``GetArgumentParser``).
            Reads ``batch_size``, ``model``, ``device``,
            ``warmup_iterations``, ``iterations`` and
            ``layer_wise_benchmark``.

    Raises:
        ValueError: If ``args.device`` is not one of CPU, MKL or IDEEP.
            Checked before the model download, so a typo fails fast.
    """
    # Resolve the target device first: there is no point downloading a
    # model only to reject the device afterwards.
    if args.device == 'CPU':
        device_option = core.DeviceOption(caffe2_pb2.CPU)
    elif args.device == 'MKL':
        device_option = core.DeviceOption(caffe2_pb2.MKLDNN)
    elif args.device == 'IDEEP':
        device_option = core.DeviceOption(caffe2_pb2.IDEEP)
    else:
        raise ValueError("Unknown device: {}".format(args.device))

    print('Batch size: {}'.format(args.batch_size))
    downloader = ModelDownloader()
    init_net, pred_net, value_info = downloader.get_c2_model(args.model)

    # Each value_info entry's last element is used as the input shape.
    # Presumably entries look like [dtype, shape]; TODO confirm. The
    # leading (batch) dimension is replaced by the requested batch size.
    # list() keeps this working whether the stored shape is a list or a
    # tuple.
    input_shapes = {
        name: [args.batch_size] + list(info[-1][1:])
        for name, info in value_info.items()
    }
    print("input info: {}".format(input_shapes))
    external_inputs = {
        name: np.random.randn(*shape).astype(np.float32)
        for name, shape in input_shapes.items()
    }

    print("Device option: {}, {}".format(args.device, device_option))
    pred_net.device_option.CopyFrom(device_option)
    for op in pred_net.op:
        op.device_option.CopyFrom(device_option)

    # Workaround to get the weights into the MKL/IDEEP context. The init net
    # is not retargeted, so run it as-is and snapshot every blob it leaves in
    # the workspace. Add the random inputs, then wipe the workspace so that
    # everything can be re-fed under the target device option below.
    workspace.RunNetOnce(init_net)
    weights = {name: workspace.FetchBlob(name) for name in workspace.Blobs()}
    weights.update(external_inputs)
    workspace.ResetWorkspace()

    with core.DeviceScope(device_option):
        for name, blob in weights.items():
            workspace.FeedBlob(name, blob, device_option)
        workspace.CreateNet(pred_net)
        res = workspace.BenchmarkNet(pred_net.name,
                                     args.warmup_iterations,
                                     args.iterations,
                                     args.layer_wise_benchmark)
        # The *1000 factor implies res[0] is the per-iteration time in
        # milliseconds (assumption carried over from the original formula).
        print("FPS: {:.2f}".format(1000.0 * args.batch_size / res[0]))
124 
if __name__ == '__main__':
    import sys

    parser = GetArgumentParser()
    # Unrecognized command-line arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()
    # batch_size and order have non-empty defaults, so in practice this
    # catches a missing --model (or an explicit --batch_size 0 / --order "").
    if not args.batch_size or not args.model or not args.order:
        parser.print_help()
        # Stop here rather than calling benchmark() with an unusable
        # configuration (e.g. model=None).
        sys.exit(1)
    benchmark(args)