Caffe2 - Python API
A deep learning, cross-platform ML framework
lstm_comparison.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 from caffe2.proto import caffe2_pb2
6 from caffe2.python import workspace, core, lstm_benchmark, utils
7 from copy import copy
8 
@utils.debug
def Compare(args):
    """Run the 'own' and 'cudnn' LSTM implementations over a parameter sweep.

    For every combination of batch size, sequence length and hidden
    dimension listed below, ``lstm_benchmark.Benchmark`` is invoked once
    per implementation and the workspace is reset after each invocation.
    The values returned by ``Benchmark`` are reported as times: a ratio is
    printed after each configuration, followed by two summary listings and
    the average cudnn/own ratio.

    Args:
        args: Parsed benchmark arguments. The object is mutated in place:
            ``gpu``, ``batch_size``, ``seq_length``, ``hidden_dim``,
            ``data_size``, ``iters_to_report`` and ``implementation`` are
            overwritten. ``num_layers`` is only read, for reporting.
    """
    num_iters = 1000
    args.gpu = True

    # Same iteration order as three nested loops: batch size outermost,
    # hidden dimension innermost.
    sweep = [
        (batch_size, seq_length, hidden_dim)
        for batch_size in (64, 128, 256)
        for seq_length in (20, 100)
        for hidden_dim in (40, 100, 400, 800)
    ]

    results = []
    gpu_option = core.DeviceOption(workspace.GpuDeviceType, 0)
    with core.DeviceScope(gpu_option):
        for batch_size, seq_length, hidden_dim in sweep:
            args.batch_size = batch_size
            args.seq_length = seq_length
            args.hidden_dim = hidden_dim
            args.data_size = batch_size * seq_length * num_iters
            args.iters_to_report = num_iters // 3

            # 'own' first, then 'cudnn'; the workspace is reset after each
            # run so the next one starts from a clean state.
            measured = {}
            for implementation in ('own', 'cudnn'):
                args.implementation = implementation
                measured[implementation] = lstm_benchmark.Benchmark(args)
                workspace.ResetWorkspace()
            t_own = measured['own']
            t_cudnn = measured['cudnn']

            # Snapshot args: the same object is mutated on the next pass.
            results.append((copy(args), float(t_own), float(t_cudnn)))
            print(args)
            print("t_cudnn / t_own: {}".format(t_cudnn / t_own))

    # First listing: the full argument namespace for each configuration.
    for run_args, own_time, cudnn_time in results:
        print("{}: cudnn time: {}, own time: {}, ratio: {}".format(
            str(run_args), cudnn_time, own_time, cudnn_time / own_time))

    # Second listing: only the swept fields, while accumulating the ratios.
    ratio_sum = 0
    for run_args, own_time, cudnn_time in results:
        ratio = float(cudnn_time) / own_time
        ratio_sum += ratio
        print("hidden_dim: {}, seq_lengths: {}, batch_size: {}, num_layers: {}:"
              " cudnn time: {}, own time: {}, ratio: {}".format(
                  run_args.hidden_dim, run_args.seq_length,
                  run_args.batch_size, run_args.num_layers,
                  cudnn_time, own_time, ratio))

    print("Ratio average: {}".format(ratio_sum / len(results)))
48 
49 
if __name__ == '__main__':
    # Parse the benchmark's command-line options before touching Caffe2.
    args = lstm_benchmark.GetArgumentParser().parse_args()

    # Flags handed to workspace.GlobalInit, passed through unchanged.
    init_flags = [
        'caffe2',
        '--caffe2_log_level=0',
        '--caffe2_print_blob_sizes_at_exit=0',
        '--caffe2_gpu_memory_tracking=1',
    ]
    workspace.GlobalInit(init_flags)

    Compare(args)