Caffe2 - Python API
A deep learning, cross-platform ML framework
_functions.py
import torch
from torch.autograd.function import Function


class SyncBatchNorm(Function):

    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
        input = input.contiguous()

        # calculate mean/invstd for input.
        mean, invstd = torch.batch_norm_stats(input, eps)

        mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
        invstd_all = torch.empty(world_size, invstd.size(0), dtype=invstd.dtype, device=invstd.device)
        mean_l = list(mean_all.unbind(0))
        invstd_l = list(invstd_all.unbind(0))
        # using all_gather instead of all_reduce so we can calculate mean/var in one go
        mean_all_reduce = torch.distributed.all_gather(mean_l, mean, process_group, async_op=True)
        invstd_all_reduce = torch.distributed.all_gather(invstd_l, invstd, process_group, async_op=True)

        # wait on the async communication to finish
        mean_all_reduce.wait()
        invstd_all_reduce.wait()

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            int(input.numel() / input.size(1))
        )

        self.save_for_backward(input, weight, mean, invstd)
        self.process_group = process_group
        self.world_size = world_size

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out

    @staticmethod
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group
        world_size = self.world_size

        # calculate local stats as well as grad_weight / grad_bias
        mean_dy, mean_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            self.needs_input_grad[0],
            self.needs_input_grad[1],
            self.needs_input_grad[2]
        )

        if self.needs_input_grad[0]:
            # synchronizing stats used to calculate input gradient.
            # TODO: move div_ into batch_norm_backward_elemt kernel
            mean_dy_all_reduce = torch.distributed.all_reduce(
                mean_dy, torch.distributed.ReduceOp.SUM, process_group, async_op=True)
            mean_dy_xmu_all_reduce = torch.distributed.all_reduce(
                mean_dy_xmu, torch.distributed.ReduceOp.SUM, process_group, async_op=True)

            # wait on the async communication to finish
            mean_dy_all_reduce.wait()
            mean_dy_xmu_all_reduce.wait()

            mean_dy.div_(world_size)
            mean_dy_xmu.div_(world_size)
            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                mean_dy,
                mean_dy_xmu
            )

        # synchronizing grad_weight / grad_bias is not needed as distributed
        # training would handle the all_reduce.
        if weight is None or not self.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not self.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
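Usage sketch: the autograd Function above is not called directly; a wrapping module determines the world size from the active process group and dispatches through SyncBatchNorm.apply, which routes into forward() with autograd support. The snippet below is a minimal, hypothetical illustration of that call, assuming torch.distributed has already been initialized; the wrapper name sync_bn_forward and the default eps/momentum values are assumptions for the example, not code from _functions.py.

import torch.distributed as dist

# hypothetical wrapper around the SyncBatchNorm Function defined above;
# assumes dist.init_process_group(...) has already been called
def sync_bn_forward(input, weight, bias, running_mean, running_var,
                    eps=1e-5, momentum=0.1, process_group=None):
    if process_group is None:
        process_group = dist.group.WORLD
    world_size = dist.get_world_size(process_group)
    # forward() all_gathers per-rank mean/invstd across the group,
    # combines them into global stats, then normalizes element-wise
    return SyncBatchNorm.apply(input, weight, bias, running_mean, running_var,
                               eps, momentum, process_group, world_size)

As the in-code comment notes, gathering the per-rank mean/invstd with all_gather (rather than all_reducing partial sums) lets batch_norm_gather_stats produce the global statistics and update the running buffers in a single step.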