import argparse
import os
from timeit import default_timer as timer

import torch
import torch.distributed as dist
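
# Benchmarks the throughput of torch.distributed primitives (broadcast,
# send/recv, reduce, all_reduce, scatter, gather, all_gather) over a range of
# tensor sizes and repetition counts. The backend is taken from the BACKEND
# environment variable; the rest of the process-group setup relies on
# init_process_group's defaults.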


def print_header(title):
    print(title)
    print(
        "{:>8}\t{:>5}\t{:<{num_tensors_width}}\t{:>11}\t{:>11}".
        format(
            "MB/s",
            "MB",
            "#",
            "s",
            "ms/op",
            num_tensors_width=MAX_NUM_TENSORS))


def print_stats(bytes, num_tensors, time):
    print(
        "{:>8.3f}\t{:>5.1f}\t{:<{num_tensors_width}}\t{:>11.3f}\t{:>11.3f}".
        format(
            bytes * num_tensors / (2**20 * time),
            bytes / 2**20,
            num_tensors,
            time,
            1000 * time / num_tensors,
            num_tensors_width=MAX_NUM_TENSORS))


parser = argparse.ArgumentParser(description='Benchmark torch.distributed.')
parser.add_argument(
    '--max-bytes', dest='max_bytes', action='store', type=int, default=22,
    help='set the inclusive upper limit for tensor size; ' +
    'default: 22 (2**22 = 4 MB)')
parser.add_argument(
    '--max-num-tensors', dest='max_num_tensors', action='store', type=int,
    default=3,
    help='set the inclusive upper limit for the number of ' +
    'tensors to be sent during one test run; ' +
    'default: 3 (10**3 = 1000)')
parser.add_argument(
    '--min-bytes', dest='min_bytes', action='store', type=int, default=19,
    help='set the inclusive lower limit for tensor size; ' +
    'default: 19 (2**19 = 512 KB)')
parser.add_argument(
    '--min-num-tensors', dest='min_num_tensors', action='store', type=int,
    default=2,
    help='set the inclusive lower limit for the number of ' +
    'tensors to be sent during one test run; ' +
    'default: 2 (10**2 = 100)')

args = parser.parse_args()

MIN_NUM_TENSORS = args.min_num_tensors
MIN_BYTES = args.min_bytes
MAX_NUM_TENSORS = args.max_num_tensors + 1
MAX_BYTES = args.max_bytes + 1

dist.init_process_group(backend=os.environ['BACKEND'])

rank = dist.get_rank()
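
# Rank 0 drives every benchmark below: it fills a tensor with a marker value,
# times num_tensors back-to-back operations, and reports the throughput.
# The remaining ranks only participate so that each operation can complete.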
if rank == 0:
    print_header("broadcast")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.broadcast(tensor, 0)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.broadcast(tensor, 0)
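
# Point-to-point bandwidth: rank 0 times blocking sends to rank 1, which
# posts a matching receive for every send.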
if rank == 0:
    print_header("send from 0 to 1")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.send(tensor, 1)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
elif rank == 1:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.recv(tensor, 0)
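
# reduce accumulates the result onto rank 0, while all_reduce leaves the
# result on every rank; both default to summation.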
if rank == 0:
    print_header("reduce")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.reduce(tensor, 0)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.reduce(tensor, 0)

if rank == 0:
    print_header("all reduce")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.all_reduce(tensor)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.all_reduce(tensor)
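
# scatter and gather take a world_size-long list of per-rank tensors, but
# only on the root rank; every other rank passes just its own tensor.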
if rank == 0:
    print_header("scatter")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_world_size())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.scatter(tensor, scatter_list=tensors)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.scatter(tensor, src=0)

if rank == 0:
    print_header("gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_world_size())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.gather(tensor, gather_list=tensors)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.gather(tensor, dst=0)
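
# all_gather needs the full output list on every rank, so both branches build
# the per-rank tensor list.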
if rank == 0:
    print_header("all gather")
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_world_size())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            start = timer()
            for i in range(0, num_tensors):
                dist.all_gather(tensors, tensor)
            end = timer()
            print_stats(bytes, num_tensors, end - start)
else:
    for bytes in [2**n for n in range(MIN_BYTES, MAX_BYTES)]:
        tensor = torch.ByteTensor(bytes).fill_(42)
        tensors = [tensor for n in range(0, dist.get_world_size())]
        for num_tensors in [10**n for n in range(MIN_NUM_TENSORS, MAX_NUM_TENSORS)]:
            for i in range(0, num_tensors):
                dist.all_gather(tensors, tensor)