Caffe2 - Python API
A deep learning, cross-platform ML framework
auto_double_backwards.py
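These routines implement second-order ("double") backward passes for the legacy THNN-generated autograd Functions. Each one receives the autograd context ctx and ggI, the incoming gradient with respect to the first backward's grad_input, and returns the resulting gradient contributions: gI flows back to the original input and ggO to the original grad_output, with trailing Nones padding the return value to match each Function's backward signature.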
import torch


def elu_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    input, grad_output = t[0], t[1]
    alpha = ctx.additional_args[0]

    # For x < 0, ELU(x) = alpha * (exp(x) - 1), so both the first and
    # second derivatives there are alpha * exp(x); for x >= 0 ELU is linear.
    negative_mask = (input < 0).type_as(ggI)
    exp_alpha = input.exp() * alpha * negative_mask
    gI = ggI * grad_output * exp_alpha

    non_negative_mask = (input >= 0).type_as(ggI)
    ggO = ggI * (exp_alpha + non_negative_mask)
    return gI, ggO, None, None, None, None


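As a quick sanity check (not part of the original file), the ELU expressions above can be compared against autograd's generic double backward; alpha=1.0 and the random input are illustrative choices:

import torch
import torch.nn.functional as F

x = torch.randn(6, dtype=torch.double, requires_grad=True)
y = F.elu(x, alpha=1.0)

# First backward, keeping the graph so it can be differentiated again.
(g,) = torch.autograd.grad(y.sum(), x, create_graph=True)
# Second backward: alpha * exp(x) where x < 0, and 0 elsewhere.
(gg,) = torch.autograd.grad(g.sum(), x)

expected = torch.where(x < 0, x.exp(), torch.zeros_like(x)).detach()
assert torch.allclose(gg, expected)

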
def gatedlinear_double_backwards(ctx, ggI):
    input, gO = ctx.saved_tensors
    dim = ctx.additional_args[0]

    # GLU splits the input in half along dim: output = first_half * sigmoid(second_half).
    input_size = input.size(dim) // 2

    first_half = input.narrow(dim, 0, input_size)
    second_half = input.narrow(dim, input_size, input_size)
    sig_second_half = second_half.sigmoid()
    one_sub_sig_second_half = 1 - sig_second_half
    sig_one_sub_sig = sig_second_half * one_sub_sig_second_half  # sigmoid'(x) = s * (1 - s)

    ggI_first_half = ggI.narrow(dim, 0, input_size)
    ggI_second_half = ggI.narrow(dim, input_size, input_size)
    ggI_second_half_times_first_half = ggI_second_half * first_half

    gI_first_half = ggI_second_half * gO * sig_one_sub_sig
    second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig
    gI_second_half = (ggI_second_half_times_first_half * gO * second_order_sh +
                      ggI_first_half * gO * sig_one_sub_sig)
    gI = torch.cat((gI_first_half, gI_second_half), dim)

    ggO = ggI_first_half * sig_second_half + ggI_second_half_times_first_half * sig_one_sub_sig

    return gI, ggO, None, None, None


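Since GLU's double backward is just products of the sigmoid and its derivatives, it is easy to cross-check numerically against finite differences; this snippet (not part of the original file, shapes chosen arbitrarily) does so:

import torch
import torch.nn.functional as F

x = torch.randn(3, 8, dtype=torch.double, requires_grad=True)
# gradgradcheck numerically verifies the analytic double backward.
assert torch.autograd.gradgradcheck(lambda i: F.glu(i, dim=1), (x,))

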
def hardshrink_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    input = t[0]
    lambd = ctx.additional_args[0]
    gI = None

    # Hardshrink is the identity outside [-lambd, lambd] and zero inside,
    # so the second derivative with respect to the input vanishes everywhere.
    mask = torch.zeros_like(input).masked_fill_(input > lambd, 1).masked_fill_(input < -lambd, 1)
    ggO = ggI * mask

    return gI, ggO, None, None, None


def hardtanh_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    input, grad_output = t[0], t[1]
    min_val, max_val = ctx.additional_args[0:2]

    # The first derivative is 1 on (min_val, max_val] and 0 elsewhere;
    # max_mask - min_mask selects exactly that band.
    max_mask = input <= max_val
    min_mask = input <= min_val
    gI = torch.zeros_like(ggI)
    ggO = ggI * (max_mask - min_mask).type_as(grad_output)
    return gI, ggO, None, None, None


def leakyrelu_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    input = t[0]
    negative_slope = ctx.additional_args[0]

    # LeakyReLU is piecewise linear, so gI is zero and ggO reapplies the slope mask.
    gI = torch.zeros_like(ggI)
    input_lt_0 = (input < 0).type_as(ggI)
    input_ge_0 = (input >= 0).type_as(ggI)
    ggO = ggI * (input_lt_0 * negative_slope + input_ge_0)
    return gI, ggO, None, None, None


def logsigmoid_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    # This could be computed more efficiently from the saved output, but save_output is False.
    input, gO = t[0], t[1]

    # d/dx logsigmoid(x) = 1 / (1 + exp(x)); differentiating again gives
    # -exp(x) / (1 + exp(x))^2.
    exp_input = input.exp()
    exp_input_plus_1 = exp_input + 1
    gI = ggI * gO * -1 * exp_input / (exp_input_plus_1.pow(2))
    ggO = ggI / exp_input_plus_1

    return gI, ggO, None, None, None, None


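The closed forms above follow from differentiating log(sigmoid(x)) = -log(1 + exp(-x)) twice; a numeric spot check (illustrative input, not from the original file):

import torch
import torch.nn.functional as F

x = torch.randn(8, dtype=torch.double, requires_grad=True)
(g,) = torch.autograd.grad(F.logsigmoid(x).sum(), x, create_graph=True)
(gg,) = torch.autograd.grad(g.sum(), x)

e = x.detach().exp()
assert torch.allclose(g.detach(), 1 / (e + 1))  # matches the ggO factor
assert torch.allclose(gg, -e / (e + 1).pow(2))  # matches the gI factor

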
def softplus_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    input, gO, output = t[0], t[1], t[2]
    beta, threshold = ctx.additional_args[0], ctx.additional_args[1]

    input_beta = input * beta
    above_threshold = torch.zeros_like(ggI).masked_fill_(input_beta > threshold, 1)
    below_threshold = torch.zeros_like(ggI).masked_fill_(input_beta <= threshold, 1)

    # Below the threshold, output = log(1 + exp(beta * input)) / beta, so
    # exp(beta * output) = 1 + exp(beta * input) and first_deriv reduces to
    # sigmoid(beta * input). Above the threshold softplus is treated as linear.
    exp_output_beta = (output * beta).exp()
    first_deriv = (exp_output_beta - 1) / exp_output_beta
    first_deriv_below_threshold = first_deriv * below_threshold

    gI = ggI * gO * first_deriv_below_threshold * beta / exp_output_beta
    ggO = ggI * (above_threshold + first_deriv_below_threshold)

    return gI, ggO, None, None, None, None


def softshrink_double_backwards(ctx, ggI):
    # Softshrink has the same zero/identity derivative structure as hardshrink.
    return hardshrink_double_backwards(ctx, ggI)


def threshold_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    input = t[0]
    threshold, value = ctx.additional_args[0:2]

    # Threshold is the identity above the threshold and constant below it.
    gI = torch.zeros_like(ggI)
    input_gt_threshold = (input > threshold).type_as(ggI)
    ggO = ggI * input_gt_threshold
    return gI, ggO, None, None, None


def klddivloss_double_backwards(ctx, ggI):
    size_average = ctx.additional_args[0]
    input, target, gO = ctx.saved_tensors
    div_factor = input.nelement() if size_average else 1

    # The loss is linear in the input, so the first backward's grad_input
    # has no input dependence and gI is None.
    gI = None
    ggO = (ggI * target).sum() / -div_factor

    return gI, None, ggO, None, None


def l1loss_double_backwards(ctx, ggI):
    size_average = ctx.additional_args[0]
    input, target, grad_output = ctx.saved_tensors
    gI = torch.zeros_like(ggI)

    # The first derivative is sign(input - target), which is piecewise constant.
    positive_mask = (input > target).type_as(ggI)
    negative_mask = (input < target).type_as(ggI)
    ggO = (ggI * (positive_mask - negative_mask)).sum()
    if size_average:
        ggO = ggO / input.nelement()
    return gI, None, ggO, None, None


def mseloss_double_backwards(ctx, ggI):
    size_average = ctx.additional_args[0]
    reduce = ctx.additional_args[1]
    input, target, gO = ctx.saved_tensors
    div_factor = input.nelement() if size_average and reduce else 1

    # The second derivative is the constant 2 / div_factor, so gI is linear in gO.
    gI = ggI * (gO * 2. / div_factor).expand_as(input)
    if reduce:
        ggO = (ggI * (input - target)).sum() * (2. / div_factor)
    else:
        ggO = (ggI * (input - target)) * 2.

    return gI, None, ggO, None, None


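Loss double backwards like the one above can be stress-tested with torch.autograd.gradgradcheck, which compares the analytic second-order gradients against finite differences; this standalone snippet (not from the original file) exercises MSE:

import torch
import torch.nn.functional as F
from torch.autograd import gradgradcheck

inp = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
tgt = torch.randn(4, 5, dtype=torch.double)

# The lambda closes over the fixed target; only inp is differentiated.
assert gradgradcheck(lambda i: F.mse_loss(i, tgt), (inp,))

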
def nllloss_double_backwards(ctx, ggI):
    t = ctx.saved_tensors
    target = t[1]
    weights = ctx.additional_args[1]
    size_average = ctx.additional_args[0]
    ignore_index = ctx.additional_args[3]
    reduce = ctx.additional_args[4]

    gI = None

    # We can't scatter/gather on indices outside the valid range, so put them in
    # range and zero out their weights later (it doesn't matter where in range they land).
    target_mask = target == ignore_index
    safe_target = target.clone()
    safe_target.masked_fill_(target_mask, 0)

    if weights.dim() == 0:
        weights_to_scatter = torch.ones_like(safe_target)
    else:
        weights_maybe_resized = weights
        while weights_maybe_resized.dim() < target.dim():
            weights_maybe_resized = weights_maybe_resized.unsqueeze(1)

        weights_maybe_resized = weights_maybe_resized.expand(weights.size()[0:1] + target.size()[1:])
        weights_to_scatter = weights_maybe_resized.gather(0, safe_target)

    weights_to_scatter.masked_fill_(target_mask, 0)
    divisor = weights_to_scatter.sum() if size_average and reduce else 1
    weights_to_scatter = -1 * weights_to_scatter / divisor
    zeros = torch.zeros_like(ggI)
    mask = zeros.scatter_(1, safe_target.unsqueeze(1), weights_to_scatter.unsqueeze(1))

    if reduce:
        ggO = (ggI * mask).sum()
    else:
        ggO = (ggI * mask).sum(dim=1)

    return gI, None, ggO, None, None, None


def smoothl1loss_double_backwards(ctx, ggI):
    size_average = ctx.additional_args[0]
    input, target, gO = ctx.saved_tensors
    div_factor = input.nelement() if size_average else 1

    # The loss is quadratic where |input - target| < 1 and linear elsewhere,
    # so the second derivative is 1 in the small-error region and 0 outside it.
    input_sub_target = input - target
    small_error_mask = (input_sub_target.abs() < 1)
    large_error_mask = (small_error_mask == 0)
    large_error_pos_mask = (((input_sub_target > 0) + large_error_mask) == 2).type_as(ggI)
    large_error_neg_mask = (((input_sub_target <= 0) + large_error_mask) == 2).type_as(ggI)
    small_error_mask = small_error_mask.type_as(ggI)

    gI = small_error_mask * ggI * gO / div_factor
    ggO = (ggI * (input_sub_target * small_error_mask + large_error_pos_mask - large_error_neg_mask)).sum() / div_factor

    return gI, None, ggO, None, None, None


def softmarginloss_double_backwards(ctx, ggI):
    size_average = ctx.additional_args[0]
    input, target, gO = ctx.saved_tensors
    div_factor = input.nelement() if size_average else 1

    # The loss is log(1 + exp(-target * input)) per element;
    # first_deriv is its derivative with respect to input.
    t0 = (1 + (-target * input).exp()).pow(-1)
    t1 = (-target * (-target * input).exp())
    first_deriv = t0 * t1

    gI = -1 * gO * ggI / div_factor * (first_deriv.pow(2) + first_deriv * target)
    ggO = (ggI * first_deriv).sum() / div_factor

    return gI, None, ggO, None, None, None


double_backwards_fns = {
    'ELU': elu_double_backwards,
    'GatedLinear': gatedlinear_double_backwards,
    'Hardshrink': hardshrink_double_backwards,
    'Hardtanh': hardtanh_double_backwards,
    'LeakyReLU': leakyrelu_double_backwards,
    'LogSigmoid': logsigmoid_double_backwards,
    'Softplus': softplus_double_backwards,
    'Softshrink': softshrink_double_backwards,
    'Threshold': threshold_double_backwards,
    'KLDivLoss': klddivloss_double_backwards,
    'L1Loss': l1loss_double_backwards,
    'MSELoss': mseloss_double_backwards,
    'NLLLoss': nllloss_double_backwards,
    'NLLLoss2d': nllloss_double_backwards,
    'SmoothL1Loss': smoothl1loss_double_backwards,
    'SoftMarginLoss': softmarginloss_double_backwards,
}
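
A minimal sketch of how this table might be consumed, assuming the caller looks functions up by their legacy class name; get_double_backwards is a hypothetical helper, not part of the original file:

def get_double_backwards(fn_name):
    # Map a legacy Function name (e.g. 'ELU') to its double-backwards routine.
    try:
        return double_backwards_fns[fn_name]
    except KeyError:
        raise ValueError(fn_name + " is not differentiable twice")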