# Caffe2 - Python API
# A deep learning, cross platform ML framework
# benchmark.py — micro-benchmarks for torch.distributed point-to-point and collective operations
1 import argparse
2 import os
3 from timeit import default_timer as timer
4 import torch
5 import torch.distributed as dist
6 
7 
def print_header(title, num_tensors_width=None):
    """Print a section title followed by the column-header row.

    Args:
        title: name of the benchmark section (e.g. "broadcast").
        num_tensors_width: width of the '#' (op count) column; defaults to
            the module-level MAX_NUM_TENSORS, which equals the digit count
            of the largest op count printed.

    Returns:
        The header row string (also printed), to ease testing/reuse.
    """
    if num_tensors_width is None:
        num_tensors_width = MAX_NUM_TENSORS
    header = "{:>8}\t{:>5}\t{:<{num_tensors_width}}\t{:>11}\t{:>11}".format(
        "MB/s", "MB", "#", "s", "ms/op",
        num_tensors_width=num_tensors_width)
    print(title)
    print(header)
    return header
13 
14 
def print_stats(bytes, num_tensors, time, num_tensors_width=None):
    """Print one result row (throughput / size / count / total / per-op time).

    Args:
        bytes: size of a single tensor in bytes.  (Parameter name shadows the
            builtin; kept unchanged for backward compatibility with callers.)
        num_tensors: number of operations performed during the timed run.
        time: total elapsed wall time in seconds for all operations.
        num_tensors_width: width of the '#' column; defaults to the
            module-level MAX_NUM_TENSORS.

    Returns:
        The formatted row string (also printed), to ease testing/reuse.
    """
    if num_tensors_width is None:
        num_tensors_width = MAX_NUM_TENSORS
    row = "{:>8.3f}\t{:>5.1f}\t{:<{num_tensors_width}}\t{:>11.3f}\t{:>11.3f}".format(
        bytes * num_tensors / (2**20 * time),  # aggregate throughput in MB/s
        bytes / 2**20,                         # per-tensor size in MB
        num_tensors,
        time,
        1000 * time / num_tensors,             # mean latency per op in ms
        num_tensors_width=num_tensors_width)
    print(row)
    return row
23 
24 
# Command-line configuration: tensor sizes sweep powers of two
# (2**min_bytes .. 2**max_bytes) and op counts sweep powers of ten
# (10**min_num_tensors .. 10**max_num_tensors), both inclusive.
parser = argparse.ArgumentParser(description='Benchmark torch.distributed.')
parser.add_argument('--max-bytes', dest='max_bytes', action='store', default=28,
                    type=int,
                    help='set the inclusive upper limit for tensor size; ' +
                    # Fixed: help text previously claimed the default was 22,
                    # but the actual default is 28.
                    'default: 28 (2**28 = 256 MB)')
parser.add_argument('--max-num-tensors', dest='max_num_tensors', action='store',
                    default=3, type=int,
                    help='set the inclusive upper limit for the number of ' +
                    'tensors to be sent during one test run; ' +
                    'default: 3 (10**3 = 1000)')
parser.add_argument('--min-bytes', dest='min_bytes', action='store', default=19,
                    type=int,
                    help='set the inclusive lower limit for tensor size; ' +
                    'default: 19 (2**19 = 512 KB)')
parser.add_argument('--min-num-tensors', dest='min_num_tensors', action='store',
                    default=2, type=int,
                    help='set the inclusive lower limit for the number of ' +
                    'tensors to be sent during one test run; ' +
                    'default: 2 (10**2 = 100)')

args = parser.parse_args()

# Convert the inclusive CLI bounds into exclusive upper bounds for range().
MIN_NUM_TENSORS = args.min_num_tensors
MIN_BYTES = args.min_bytes
MAX_NUM_TENSORS = args.max_num_tensors + 1
MAX_BYTES = args.max_bytes + 1
51 
# Select the process-group backend from the BACKEND environment variable;
# raises KeyError if it is not set.
dist.init_process_group(backend=os.environ['BACKEND'])

rank = dist.get_rank()
# Ensure every process has finished initialization before any timing starts.
dist.barrier()
56 
# Benchmark dist.broadcast: rank 0 broadcasts tensors of each size/count
# combination and reports timings; every other rank only participates.
if rank == 0:
    print_header("broadcast")
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for _ in range(count):
                dist.broadcast(tensor, 0)
            print_stats(size, count, timer() - start)
    print()
else:
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for _ in range(count):
                dist.broadcast(tensor, 0)
# Keep all ranks in lock-step before the next section.
dist.barrier()
75 
# Point-to-point benchmark: rank 0 sends, rank 1 receives; all other ranks
# sit this section out and wait at the barrier.
if rank == 0:
    print_header("send from 0 to 1")
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for _ in range(count):
                dist.send(tensor, 1)
            print_stats(size, count, timer() - start)
    print()
elif rank == 1:
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for _ in range(count):
                dist.recv(tensor, 0)
# Keep all ranks in lock-step before the next section.
dist.barrier()
94 
# Benchmark dist.reduce with rank 0 as destination; rank 0 times the runs,
# the remaining ranks contribute their tensors.
if rank == 0:
    print_header("reduce")
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for _ in range(count):
                dist.reduce(tensor, 0)
            print_stats(size, count, timer() - start)
    print()
else:
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for _ in range(count):
                dist.reduce(tensor, 0)
# Keep all ranks in lock-step before the next section.
dist.barrier()
113 
# Benchmark dist.all_reduce: every rank participates in each reduction;
# only rank 0 measures and prints.
if rank == 0:
    print_header("all reduce")
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for _ in range(count):
                dist.all_reduce(tensor)
            print_stats(size, count, timer() - start)
    print()
else:
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for _ in range(count):
                dist.all_reduce(tensor)
# Keep all ranks in lock-step before the next section.
dist.barrier()
132 
# Benchmark dist.scatter from rank 0: the source supplies one (shared)
# tensor per rank in scatter_list; the other ranks only receive.
if rank == 0:
    print_header("scatter")
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        # One entry per rank; all entries alias the same tensor on purpose.
        tensors = [tensor] * dist.get_world_size()
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for _ in range(count):
                dist.scatter(tensor, scatter_list=tensors)
            print_stats(size, count, timer() - start)
    print()
else:
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for _ in range(count):
                dist.scatter(tensor, src=0)
# Keep all ranks in lock-step before the next section.
dist.barrier()
152 
# Benchmark dist.gather to rank 0: the destination supplies one (shared)
# output slot per rank in gather_list; the other ranks only send.
if rank == 0:
    print_header("gather")
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        # One slot per rank; all slots alias the same tensor on purpose.
        tensors = [tensor] * dist.get_world_size()
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for _ in range(count):
                dist.gather(tensor, gather_list=tensors)
            print_stats(size, count, timer() - start)
    print()
else:
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for _ in range(count):
                dist.gather(tensor, dst=0)
# Keep all ranks in lock-step before the next section.
dist.barrier()
172 
# Benchmark dist.all_gather: every rank supplies an input tensor and a
# (shared-slot) output list; only rank 0 measures and prints.
if rank == 0:
    print_header("all gather")
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        # One slot per rank; all slots alias the same tensor on purpose.
        tensors = [tensor] * dist.get_world_size()
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for _ in range(count):
                dist.all_gather(tensors, tensor)
            print_stats(size, count, timer() - start)
    print()
else:
    for size in [2**exp for exp in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(size).fill_(42)
        tensors = [tensor] * dist.get_world_size()
        for count in [10**exp for exp in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for _ in range(count):
                dist.all_gather(tensors, tensor)
# Final synchronization before the script exits.
dist.barrier()