def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
    """Synchronized batch-norm forward across all processes in *process_group*.

    Computes local per-channel mean/invstd, all-gathers them from every
    rank, reduces them to global statistics, and applies element-wise
    normalization. Stashes what backward() needs on the autograd context.

    Args:
        input: activation tensor, channels on dim 1 (N, C, ...).
        weight, bias: per-channel affine parameters (may be None).
        running_mean, running_var: running stats, updated in place by
            batch_norm_gather_stats using *momentum*.
        eps: numerical-stability epsilon.
        momentum: running-stats momentum.
        process_group: torch.distributed group to synchronize over.
        world_size: number of ranks in *process_group*.

    Returns:
        Normalized output tensor, same shape as *input*.
    """
    input = input.contiguous()

    # Local per-channel statistics for this rank's shard of the batch.
    mean, invstd = torch.batch_norm_stats(input, eps)

    # Pre-allocate (world_size, C) buffers; unbind gives per-rank views
    # that all_gather fills in place.
    mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
    invstd_all = torch.empty(world_size, invstd.size(0), dtype=invstd.dtype, device=invstd.device)
    mean_l = list(mean_all.unbind(0))
    invstd_l = list(invstd_all.unbind(0))
    # all_gather (not all_reduce) so global mean/var can be combined in
    # one pass; launched async so the two collectives overlap.
    mean_all_reduce = torch.distributed.all_gather(mean_l, mean, process_group, async_op=True)
    invstd_all_reduce = torch.distributed.all_gather(invstd_l, invstd, process_group, async_op=True)

    # Wait for the async communication to finish before reading buffers.
    mean_all_reduce.wait()
    invstd_all_reduce.wait()

    # Combine per-rank stats into global mean/invstd and update running
    # stats. Assumes every rank contributes the same element count
    # (numel / C) — TODO confirm against caller.
    mean, invstd = torch.batch_norm_gather_stats(
        input,
        mean_all,
        invstd_all,
        running_mean,
        running_var,
        momentum,
        eps,
        int(input.numel() / input.size(1))
    )

    # Stash tensors and group info for backward(); backward unpacks
    # exactly (saved_input, weight, mean, invstd) from saved_tensors.
    self.save_for_backward(input, weight, mean, invstd)
    self.process_group = process_group
    self.world_size = world_size

    # Apply element-wise normalization with the global statistics.
    out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
    return out
def backward(self, grad_output):
    """Backward pass for synchronized batch norm.

    Reduces local gradient statistics, all-reduces them across the
    process group saved by forward(), and computes the input gradient.

    Args:
        grad_output: gradient w.r.t. forward()'s output.

    Returns:
        Tuple matching forward()'s 9 inputs: (grad_input, grad_weight,
        grad_bias) followed by None for the 6 non-differentiable
        arguments (running stats, eps, momentum, process_group,
        world_size).
    """
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd = self.saved_tensors
    grad_input = grad_weight = grad_bias = None
    # Group info stashed on the context by forward().
    process_group = self.process_group
    world_size = self.world_size

    # Local reduction: per-channel stats for grad_input plus local
    # grad_weight / grad_bias, gated by which inputs need gradients.
    mean_dy, mean_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        self.needs_input_grad[0],
        self.needs_input_grad[1],
        self.needs_input_grad[2]
    )

    if self.needs_input_grad[0]:
        # Synchronize the stats used for the input gradient; async so
        # the two all_reduces overlap.
        mean_dy_all_reduce = torch.distributed.all_reduce(
            mean_dy, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
        mean_dy_xmu_all_reduce = torch.distributed.all_reduce(
            mean_dy_xmu, torch.distributed.ReduceOp.SUM, process_group, async_op=True)

        # Wait for the async communication to finish.
        mean_dy_all_reduce.wait()
        mean_dy_xmu_all_reduce.wait()

        # SUM-reduced stats -> mean across ranks.
        mean_dy.div_(world_size)
        mean_dy_xmu.div_(world_size)
        # Element-wise backward using the globally averaged stats.
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            mean_dy,
            mean_dy_xmu
        )

    # grad_weight / grad_bias need no synchronization here; distributed
    # training all-reduces parameter gradients separately. Drop them
    # when the affine params are absent or don't need grad.
    if weight is None or not self.needs_input_grad[1]:
        grad_weight = None

    if weight is None or not self.needs_input_grad[2]:
        grad_bias = None

    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
def save_for_backward(self, tensors)