Caffe2 - Python API
A deep learning, cross platform ML framework
convnet_benchmarks.py
1 ## @package convnet_benchmarks
2 # Module caffe2.python.convnet_benchmarks
3 """
4 Benchmark for common convnets.
5 
Speed on a Titan X, with 10 warmup steps, 10 main steps and different
versions of cuDNN, is as follows (the time reported below is the per-batch
time, forward / forward+backward):

                  cuDNN v3          cuDNN v4
AlexNet           32.5 / 108.0      27.4 / 90.1
OverFeat         113.0 / 342.3      91.7 / 276.5
Inception        134.5 / 485.8     125.7 / 450.6
VGG (batch 64)   200.8 / 650.0     164.1 / 551.7
15 
Speed on Inception with varied batch sizes and cuDNN v4 is as follows:

Batch Size   Speed per batch   Speed per image
    16        22.8 / 72.7       1.43 / 4.54
    32        38.0 / 127.5      1.19 / 3.98
    64        67.2 / 233.6      1.05 / 3.65
   128       125.7 / 450.6      0.98 / 3.52
23 
Speed on a Tesla M40, with 10 warmup steps, 10 main steps and cuDNN v4, is
as follows:

AlexNet            68.4 / 218.1
OverFeat          210.5 / 630.3
Inception         300.2 / 1122.2
VGG (batch 64)    405.8 / 1327.7
31 
32 (Note that these numbers involve a "full" backprop, i.e. the gradient
33 with respect to the input image is also computed.)
34 
35 To get the numbers, simply run:
36 
for MODEL in AlexNet OverFeat Inception; do
    PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
        --batch_size 128 --model $MODEL --forward_only
done
for MODEL in AlexNet OverFeat Inception; do
    PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
        --batch_size 128 --model $MODEL
done
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
    --batch_size 64 --model VGGA --forward_only
PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
    --batch_size 64 --model VGGA

for BS in 16 32 64 128; do
    PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
        --batch_size $BS --model Inception --forward_only
    PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \
        --batch_size $BS --model Inception
done

(--forward_only is a switch and takes no value.)
56 
57 Note that VGG needs to be run at batch 64 due to memory limit on the backward
58 pass.
59 """
60 
61 import argparse
62 
63 from caffe2.python import workspace, brew, model_helper
64 
65 
def MLP(order, cudnn_ws):
    """Build a wide, deep stack of fully-connected layers for benchmarking.

    Three parallel branches each run 20 FC layers of size 256 -> 256, all
    starting from the ``data`` blob. Layer ``k`` of branch ``b`` is named
    ``fc_<k>_<b>`` (k = 1..20). The last layer of every branch is summed into
    ``sum``, projected to 1000 outputs (``last``) and fed into
    LabelCrossEntropy + AveragedLoss (``loss``).

    NOTE(review): unlike the conv models, no softmax is applied before
    LabelCrossEntropy, so the loss value is probably not meaningful; it only
    gives the backward pass a scalar to start from.

    Args:
        order: unused; accepted so every model builder shares one signature.
        cudnn_ws: unused; accepted for the same reason.

    Returns:
        (model, 256): the ModelHelper and the feature size of the 2-D
        ``data`` blob that Benchmark creates for this model.
    """
    model = model_helper.ModelHelper(name="MLP")
    hidden = 256
    n_layers = 20
    n_branches = 3

    for layer in range(1, n_layers + 1):
        for branch in range(n_branches):
            if layer == 1:
                src = "data"
            else:
                src = "fc_{}_{}".format(layer - 1, branch)
            brew.fc(
                model,
                src,
                "fc_{}_{}".format(layer, branch),
                dim_in=hidden,
                dim_out=hidden,
                weight_init=('XavierFill', {}),
                bias_init=('XavierFill', {}),
            )

    branch_tops = ["fc_{}_{}".format(n_layers, b) for b in range(n_branches)]
    brew.sum(model, branch_tops, ["sum"])
    brew.fc(
        model,
        "sum",
        "last",
        dim_in=hidden,
        dim_out=1000,
        weight_init=('XavierFill', {}),
        bias_init=('XavierFill', {}),
    )
    xent = model.net.LabelCrossEntropy(["last", "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, hidden
99 
100 
def AlexNet(order, cudnn_ws):
    """Build the AlexNet benchmark network.

    Five conv + ReLU stages, with 3x3 / stride-2 max pooling after conv1,
    conv2 and conv5. Then fc6 and fc7 (each followed by ReLU), fc8, a softmax
    (``pred``) and a label cross-entropy averaged into ``loss``.

    Args:
        order: storage order placed in the arg scope of every op.
        cudnn_ws: optional cuDNN workspace limit; when truthy it is placed in
            the arg scope as ``ws_nbytes_limit``.

    Returns:
        (model, 224): the ModelHelper and the spatial input size.
    """
    scope = {
        'order': order,
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
    }
    if cudnn_ws:
        scope['ws_nbytes_limit'] = cudnn_ws
    model = model_helper.ModelHelper(name="alexnet", arg_scope=scope)

    # (blob name, channels in, channels out, kernel, extra conv kwargs,
    #  name of the max-pool that follows the ReLU, or None)
    conv_stages = [
        ("conv1", 3, 64, 11, {"stride": 4, "pad": 2}, "pool1"),
        ("conv2", 64, 192, 5, {"pad": 2}, "pool2"),
        ("conv3", 192, 384, 3, {"pad": 1}, None),
        ("conv4", 384, 256, 3, {"pad": 1}, None),
        ("conv5", 256, 256, 3, {"pad": 1}, "pool5"),
    ]
    blob = "data"
    for name, c_in, c_out, kernel, conv_kwargs, pool_name in conv_stages:
        blob = brew.conv(
            model, blob, name, c_in, c_out, kernel,
            ('XavierFill', {}), ('ConstantFill', {}), **conv_kwargs
        )
        # The ReLU output reuses the conv blob name (in-place activation).
        blob = brew.relu(model, blob, name)
        if pool_name is not None:
            blob = brew.max_pool(model, blob, pool_name, kernel=3, stride=2)

    # fc6 consumes pool5 flattened: 256 channels x 6 x 6 for a 224 input.
    hidden_fcs = (("fc6", 256 * 6 * 6, 4096), ("fc7", 4096, 4096))
    for name, dim_in, dim_out in hidden_fcs:
        blob = brew.fc(
            model, blob, name, dim_in, dim_out,
            ('XavierFill', {}), ('ConstantFill', {})
        )
        blob = brew.relu(model, blob, name)

    logits = brew.fc(
        model, blob, "fc8", 4096, 1000, ('XavierFill', {}), ('ConstantFill', {})
    )
    pred = brew.softmax(model, logits, "pred")
    xent = model.net.LabelCrossEntropy([pred, "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 224
192 
193 
def OverFeat(order, cudnn_ws):
    """Build the OverFeat (fast) benchmark network.

    Five conv + ReLU stages, with 2x2 / stride-2 max pooling after conv1,
    conv2 and conv5. Then fc6 and fc7 (each followed by ReLU), fc8, a softmax
    (``pred``) and a label cross-entropy averaged into ``loss``.

    Args:
        order: storage order placed in the arg scope of every op.
        cudnn_ws: optional cuDNN workspace limit; when truthy it is placed in
            the arg scope as ``ws_nbytes_limit``.

    Returns:
        (model, 231): the ModelHelper and the spatial input size.
    """
    scope = {
        'order': order,
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
    }
    if cudnn_ws:
        scope['ws_nbytes_limit'] = cudnn_ws
    model = model_helper.ModelHelper(name="overfeat", arg_scope=scope)

    # (blob name, channels in, channels out, kernel, extra conv kwargs,
    #  name of the max-pool that follows the ReLU, or None)
    conv_stages = [
        ("conv1", 3, 96, 11, {"stride": 4}, "pool1"),
        ("conv2", 96, 256, 5, {}, "pool2"),
        ("conv3", 256, 512, 3, {"pad": 1}, None),
        ("conv4", 512, 1024, 3, {"pad": 1}, None),
        ("conv5", 1024, 1024, 3, {"pad": 1}, "pool5"),
    ]
    blob = "data"
    for name, c_in, c_out, kernel, conv_kwargs, pool_name in conv_stages:
        blob = brew.conv(
            model, blob, name, c_in, c_out, kernel,
            ('XavierFill', {}), ('ConstantFill', {}), **conv_kwargs
        )
        # The ReLU output reuses the conv blob name (in-place activation).
        blob = brew.relu(model, blob, name)
        if pool_name is not None:
            blob = brew.max_pool(model, blob, pool_name, kernel=2, stride=2)

    # fc6 consumes pool5 flattened: 1024 channels x 6 x 6 for a 231 input.
    hidden_fcs = (("fc6", 1024 * 6 * 6, 3072), ("fc7", 3072, 4096))
    for name, dim_in, dim_out in hidden_fcs:
        blob = brew.fc(
            model, blob, name, dim_in, dim_out,
            ('XavierFill', {}), ('ConstantFill', {})
        )
        blob = brew.relu(model, blob, name)

    logits = brew.fc(
        model, blob, "fc8", 4096, 1000, ('XavierFill', {}), ('ConstantFill', {})
    )
    pred = brew.softmax(model, logits, "pred")
    xent = model.net.LabelCrossEntropy([pred, "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 231
278 
279 
def VGGA(order, cudnn_ws):
    """Build the VGG "A" benchmark network.

    Eight 3x3 / pad-1 conv + ReLU stages, with 2x2 / stride-2 max pooling
    after conv1, conv2, conv4, conv6 and conv8. Then fcix and fcx (each
    followed by ReLU), fcxi, a softmax (``pred``) and a label cross-entropy
    averaged into ``loss``.

    Args:
        order: storage order placed in the arg scope of every op.
        cudnn_ws: optional cuDNN workspace limit; when truthy it is placed in
            the arg scope as ``ws_nbytes_limit``.

    Returns:
        (model, 231): the ModelHelper and the spatial input size.
    """
    scope = {
        'order': order,
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
    }
    if cudnn_ws:
        scope['ws_nbytes_limit'] = cudnn_ws
    model = model_helper.ModelHelper(name="vgga", arg_scope=scope)

    # (blob name, channels in, channels out,
    #  name of the max-pool that follows the ReLU, or None)
    conv_stages = [
        ("conv1", 3, 64, "pool1"),
        ("conv2", 64, 128, "pool2"),
        ("conv3", 128, 256, None),
        ("conv4", 256, 256, "pool4"),
        ("conv5", 256, 512, None),
        ("conv6", 512, 512, "pool6"),
        ("conv7", 512, 512, None),
        ("conv8", 512, 512, "pool8"),
    ]
    blob = "data"
    for name, c_in, c_out, pool_name in conv_stages:
        blob = brew.conv(
            model, blob, name, c_in, c_out, 3,
            ('XavierFill', {}), ('ConstantFill', {}), pad=1
        )
        # The ReLU output reuses the conv blob name (in-place activation).
        blob = brew.relu(model, blob, name)
        if pool_name is not None:
            blob = brew.max_pool(model, blob, pool_name, kernel=2, stride=2)

    # fcix consumes pool8 flattened: 512 channels x 7 x 7.
    hidden_fcs = (("fcix", 512 * 7 * 7, 4096), ("fcx", 4096, 4096))
    for name, dim_in, dim_out in hidden_fcs:
        blob = brew.fc(
            model, blob, name, dim_in, dim_out,
            ('XavierFill', {}), ('ConstantFill', {})
        )
        blob = brew.relu(model, blob, name)

    logits = brew.fc(
        model, blob, "fcxi", 4096, 1000,
        ('XavierFill', {}), ('ConstantFill', {})
    )
    pred = brew.softmax(model, logits, "pred")
    xent = model.net.LabelCrossEntropy([pred, "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 231
412 
413 
def _InceptionModule(
    model, input_blob, input_depth, output_name, conv1_depth, conv3_depths,
    conv5_depths, pool_depth
):
    """Append one GoogLeNet-style inception module to ``model``.

    Four branches read ``input_blob``:
      1. 1x1 conv (``conv1_depth`` channels);
      2. 1x1 reduce to ``conv3_depths[0]`` then 3x3 conv to ``conv3_depths[1]``;
      3. 1x1 reduce to ``conv5_depths[0]`` then 5x5 conv to ``conv5_depths[1]``;
      4. 3x3 / stride-1 max pool then 1x1 projection to ``pool_depth``.
    Every conv is followed by an in-place ReLU. Intermediate blobs are named
    ``output_name + ":<branch>"``.

    Args:
        model: the ModelHelper being built.
        input_blob: blob (name or reference) feeding all four branches.
        input_depth: channel count of ``input_blob``.
        output_name: name of the concatenated output and prefix of all
            intermediate blobs.
        conv1_depth, conv3_depths, conv5_depths, pool_depth: branch widths
            as described above.

    Returns:
        The blob produced by ``brew.concat`` over the four branch outputs.
        Its channel count is presumably conv1_depth + conv3_depths[1] +
        conv5_depths[1] + pool_depth (assuming concat joins on the channel
        axis).
    """
    def conv_relu(blob_in, suffix, dim_in, dim_out, kernel, **conv_kwargs):
        # Conv followed by a ReLU that writes back into the conv output blob.
        conv_out = brew.conv(
            model, blob_in, output_name + suffix, dim_in, dim_out, kernel,
            ('XavierFill', {}), ('ConstantFill', {}), **conv_kwargs
        )
        return brew.relu(model, conv_out, conv_out)

    branch_1x1 = conv_relu(input_blob, ":conv1", input_depth, conv1_depth, 1)

    reduced_3x3 = conv_relu(
        input_blob, ":conv3_reduce", input_depth, conv3_depths[0], 1
    )
    branch_3x3 = conv_relu(
        reduced_3x3, ":conv3", conv3_depths[0], conv3_depths[1], 3, pad=1
    )

    reduced_5x5 = conv_relu(
        input_blob, ":conv5_reduce", input_depth, conv5_depths[0], 1
    )
    branch_5x5 = conv_relu(
        reduced_5x5, ":conv5", conv5_depths[0], conv5_depths[1], 5, pad=2
    )

    pooled = brew.max_pool(
        model, input_blob, output_name + ":pool", kernel=3, stride=1, pad=1
    )
    branch_pool = conv_relu(pooled, ":pool_proj", input_depth, pool_depth, 1)

    return brew.concat(
        model, [branch_1x1, branch_3x3, branch_5x5, branch_pool], output_name
    )
476 
477 
def Inception(order, cudnn_ws):
    """Build the GoogLeNet (Inception v1) benchmark network.

    Stem: 7x7 / stride-2 conv, max pool, 1x1 conv, 3x3 conv, max pool.
    Then nine inception modules (inc3 .. inc11) with extra stride-2 max pools
    after inc4 and inc9. Then a 7x7 average pool, an FC to 1000 outputs, a
    softmax (``pred``) and a label cross-entropy averaged into ``loss``.

    Args:
        order: storage order placed in the arg scope of every op.
        cudnn_ws: optional cuDNN workspace limit; when truthy it is placed in
            the arg scope as ``ws_nbytes_limit``.

    Returns:
        (model, 224): the ModelHelper and the spatial input size.
    """
    scope = {
        'order': order,
        'use_cudnn': True,
        'cudnn_exhaustive_search': True,
    }
    if cudnn_ws:
        scope['ws_nbytes_limit'] = cudnn_ws
    model = model_helper.ModelHelper(name="inception", arg_scope=scope)

    # --- stem ---
    stem = brew.conv(
        model, "data", "conv1", 3, 64, 7,
        ('XavierFill', {}), ('ConstantFill', {}), stride=2, pad=3
    )
    stem = brew.relu(model, stem, "conv1")
    stem = brew.max_pool(model, stem, "pool1", kernel=3, stride=2, pad=1)
    reduced = brew.conv(
        model, stem, "conv2a", 64, 64, 1,
        ('XavierFill', {}), ('ConstantFill', {})
    )
    reduced = brew.relu(model, reduced, reduced)
    stem = brew.conv(
        model, reduced, "conv2", 64, 192, 3,
        ('XavierFill', {}), ('ConstantFill', {}), pad=1
    )
    stem = brew.relu(model, stem, "conv2")
    blob = brew.max_pool(model, stem, "pool2", kernel=3, stride=2, pad=1)

    # --- inception modules ---
    # (name, input depth, 1x1 depth, [3x3 reduce, 3x3 out],
    #  [5x5 reduce, 5x5 out], pool projection depth,
    #  name of a stride-2 max pool that follows the module, or None)
    modules = [
        ("inc3", 192, 64, [96, 128], [16, 32], 32, None),
        ("inc4", 256, 128, [128, 192], [32, 96], 64, "pool5"),
        ("inc5", 480, 192, [96, 208], [16, 48], 64, None),
        ("inc6", 512, 160, [112, 224], [24, 64], 64, None),
        ("inc7", 512, 128, [128, 256], [24, 64], 64, None),
        ("inc8", 512, 112, [144, 288], [32, 64], 64, None),
        ("inc9", 528, 256, [160, 320], [32, 128], 128, "pool9"),
        ("inc10", 832, 256, [160, 320], [32, 128], 128, None),
        ("inc11", 832, 384, [192, 384], [48, 128], 128, None),
    ]
    for name, depth_in, d1, d3, d5, d_pool, pool_after in modules:
        blob = _InceptionModule(model, blob, depth_in, name, d1, d3, d5, d_pool)
        if pool_after is not None:
            blob = brew.max_pool(
                model, blob, pool_after, kernel=3, stride=2, pad=1
            )

    # --- classifier ---
    pooled = brew.average_pool(model, blob, "pool11", kernel=7, stride=1)
    logits = brew.fc(
        model, pooled, "fc", 1024, 1000,
        ('XavierFill', {}), ('ConstantFill', {})
    )
    # Soumith's reference benchmark appears to have no softmax on top for
    # Inception; one is added here anyway so there is a proper backward pass.
    pred = brew.softmax(model, logits, "pred")
    xent = model.net.LabelCrossEntropy([pred, "label"], "xent")
    model.net.AveragedLoss(xent, "loss")
    return model, 224
564 
565 
def AddParameterUpdate(model):
    """Append a plain SGD step for every parameter to ``model.net``.

    Not tuned to actually train the models; it only makes the
    forward-backward benchmark include update ops. For each parameter ``p``
    with gradient ``g`` the net computes ``p = ONE * p + LR * g`` via
    WeightedSum. ``LR`` comes from a "step" LearningRate schedule driven by
    an ``iter`` counter. ``ONE`` is a constant [1.0] created in the init net.
    The negative base_lr (-1e-8) makes the update move against the gradient.
    """
    iteration = brew.iter(model, "iter")
    learning_rate = model.net.LearningRate(
        iteration, "LR", base_lr=-1e-8, policy="step", stepsize=10000,
        gamma=0.999,
    )
    one = model.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
    updates = [(param, model.param_to_grad[param]) for param in model.params]
    for param, grad in updates:
        model.net.WeightedSum([param, one, grad, learning_rate], param)
575 
576 
def Benchmark(model_gen, arg):
    """Build one benchmark model and time it with ``workspace.BenchmarkNet``.

    Args:
        model_gen: a model builder from this module, called as
            ``model_gen(arg.order, arg.cudnn_ws)`` and returning
            ``(model, input_size)``.
        arg: the namespace produced by ``GetArgumentParser()``.

    Side effects: prints which pass is run, optionally writes
    ``<model>_init_batch_<batch_size>.pbtxt`` and ``<model>.pbtxt`` to the
    current directory, then runs the init net and benchmarks the main net in
    the global workspace.
    """
    model, input_size = model_gen(arg.order, arg.cudnn_ws)
    model.Proto().type = arg.net_type
    model.Proto().num_workers = arg.num_workers

    # Random "data" and "label" blobs are produced by the init net itself, so
    # the benchmark runs without any external feeding.
    batch = arg.batch_size
    if arg.model == "MLP":
        data_shape = [batch, input_size]
    elif arg.order == "NCHW":
        data_shape = [batch, 3, input_size, input_size]
    else:
        data_shape = [batch, input_size, input_size, 3]
    model.param_init_net.GaussianFill(
        [], "data", shape=data_shape, mean=0.0, std=1.0
    )
    model.param_init_net.UniformIntFill(
        [], "label", shape=[batch], min=0, max=999
    )

    if not arg.forward_only:
        print('{}: running forward-backward.'.format(arg.model))
        model.AddGradientOperators(["loss"])
        AddParameterUpdate(model)
        if arg.order == 'NHWC':
            print(
                '==WARNING==\n'
                'NHWC order with CuDNN may not be supported yet, so I might\n'
                'exit suddenly.'
            )
    else:
        print('{}: running forward only.'.format(arg.model))

    if not arg.cpu:
        model.param_init_net.RunAllOnGPU()
        model.net.RunAllOnGPU()

    if arg.engine:
        # Force the requested engine on every op of the main net.
        for op_def in model.net.Proto().op:
            op_def.engine = arg.engine

    if arg.dump_model:
        # Text protos, e.g. for running the same benchmark on Android.
        init_path = "{0}_init_batch_{1}.pbtxt".format(arg.model, arg.batch_size)
        with open(init_path, "w") as fid:
            fid.write(str(model.param_init_net.Proto()))
        with open("{0}.pbtxt".format(arg.model), "w") as fid:
            fid.write(str(model.net.Proto()))

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)
    workspace.BenchmarkNet(
        model.net.Proto().name,
        arg.warmup_iterations,
        arg.iterations,
        arg.layer_wise_benchmark,
    )
641 
642 
def GetArgumentParser():
    """Return the argparse parser for the benchmark's command-line flags.

    Boolean options (``--forward_only``, ``--layer_wise_benchmark``,
    ``--cpu``, ``--dump_model``, ``--use-nvtx``/``--use_nvtx``) are
    ``store_true`` switches and take no value. Flags this parser does not
    know are returned by ``parse_known_args`` as extras.
    """
    parser = argparse.ArgumentParser(description="Caffe2 benchmark.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="The batch size."
    )
    parser.add_argument("--model", type=str, help="The model to benchmark.")
    parser.add_argument(
        "--order",
        type=str,
        default="NCHW",
        # Benchmark() treats anything other than "NCHW" as NHWC, so reject
        # other spellings up front instead of silently switching layout.
        choices=["NCHW", "NHWC"],
        help="The storage order to evaluate."
    )
    parser.add_argument(
        "--cudnn_ws",
        type=int,
        help="The cudnn workspace size limit (passed to ops as "
             "ws_nbytes_limit)."
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=10,
        help="Number of iterations to run the network."
    )
    parser.add_argument(
        "--warmup_iterations",
        type=int,
        default=10,
        help="Number of warm-up iterations before benchmarking."
    )
    parser.add_argument(
        "--forward_only",
        action='store_true',
        help="If set, only run the forward pass."
    )
    parser.add_argument(
        "--layer_wise_benchmark",
        action='store_true',
        help="If set, run the layer-wise benchmark as well."
    )
    parser.add_argument(
        "--cpu",
        action='store_true',
        help="If set, run on CPU instead of GPU."
    )
    parser.add_argument(
        "--engine",
        type=str,
        default="",
        help="If set, blindly prefer the given engine(s) for every op.")
    parser.add_argument(
        "--dump_model",
        action='store_true',
        help="If set, dump the model prototxts to disk."
    )
    parser.add_argument(
        "--net_type",
        type=str,
        default="dag",
        help="Net type assigned to the model's NetDef."
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=2,
        help="num_workers assigned to the model's NetDef."
    )
    # "--use-nvtx" is kept for compatibility; "--use_nvtx" matches the
    # underscore style of every other flag. Both set args.use_nvtx.
    parser.add_argument(
        "--use-nvtx",
        "--use_nvtx",
        dest="use_nvtx",
        default=False,
        action='store_true',
        help="If set, pass --caffe2_use_nvtx to caffe2 GlobalInit."
    )
    parser.add_argument(
        "--htrace_span_log_path",
        type=str,
        help="If set, forwarded to caffe2 GlobalInit as "
             "--caffe2_htrace_span_log_path."
    )
    return parser
705 
706 
if __name__ == '__main__':
    # Build the parser once and reuse it for parsing, help and errors.
    parser = GetArgumentParser()
    args, extra_args = parser.parse_known_args()

    # Maps the --model flag to its builder function.
    model_map = {
        'AlexNet': AlexNet,
        'OverFeat': OverFeat,
        'VGGA': VGGA,
        'Inception': Inception,
        'MLP': MLP,
    }

    if (
        args.batch_size is None or args.batch_size <= 0
        or not args.model or not args.order
    ):
        parser.print_help()
    elif args.model not in model_map:
        # Report an unknown model as a usage error (exit status 2) instead of
        # a KeyError raised after caffe2 has already been initialised.
        parser.error(
            "unknown --model {!r}; choose one of: {}".format(
                args.model, ", ".join(sorted(model_map))
            )
        )
    else:
        # Flags not recognised by our parser are forwarded to GlobalInit.
        workspace.GlobalInit(
            ['caffe2', '--caffe2_log_level=0'] + extra_args +
            (['--caffe2_use_nvtx'] if args.use_nvtx else []) +
            (['--caffe2_htrace_span_log_path=' + args.htrace_span_log_path]
             if args.htrace_span_log_path else []))
        Benchmark(model_map[args.model], args)