def elu_double_backwards(ctx, ggI):
    """Double-backward pass for ELU.

    Args:
        ctx: object exposing ``saved_tensors`` (indexed here as
            ``[0]`` = input, ``[1]`` = grad_output) and ``additional_args``
            (``[0]`` = alpha).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 6-tuple ``(gI, ggO, None, None, None, None)``; ``gI`` and ``ggO``
        have the shape of ``ggI``.
        # NOTE(review): presumably the tuple slots line up with the inputs of
        # the first-order backward function -- confirm against the caller.
    """
    # Bind the saved tensors before indexing them (same convention as the
    # other helpers in this file that read ``ctx.saved_tensors``).
    t = ctx.saved_tensors
    input, grad_output = t[0], t[1]
    alpha = ctx.additional_args[0]

    negative_mask = (input < 0).type_as(ggI)
    # Clamp before exponentiating: for large positive inputs exp() overflows
    # to inf and inf * 0 (the mask) would yield NaN. On the negative side,
    # where the mask is 1, clamp(max=0) leaves the value untouched.
    exp_alpha = input.clamp(max=0).exp() * alpha * negative_mask
    gI = ggI * grad_output * exp_alpha

    non_negative_mask = (input >= 0).type_as(ggI)
    ggO = ggI * (exp_alpha + non_negative_mask)
    return gI, ggO, None, None, None, None
def gatedlinear_double_backwards(ctx, ggI):
    """Double-backward pass for the gated linear unit.

    The saved input is split in two equal halves along ``dim``
    (``ctx.additional_args[0]``): the first half is the value part, the
    second half is passed through a sigmoid to act as the gate.

    Args:
        ctx: object exposing ``saved_tensors`` (unpacked as ``input, gO``)
            and ``additional_args`` (``[0]`` = split dimension).
        ggI: gradient flowing into the first backward's grad_input; split
            along the same dimension as the input.

    Returns:
        A 5-tuple ``(gI, ggO, None, None, None)``.
    """
    saved_input, grad_out = ctx.saved_tensors
    split_dim = ctx.additional_args[0]
    half = saved_input.size(split_dim) // 2

    def _halves(tensor):
        # (first half, second half) views along the split dimension.
        return tensor.narrow(split_dim, 0, half), tensor.narrow(split_dim, half, half)

    values, gates = _halves(saved_input)
    gg_values, gg_gates = _halves(ggI)

    # Sigmoid of the gate half plus its first and second derivatives.
    sig = gates.sigmoid()
    one_minus_sig = 1 - sig
    dsig = sig * one_minus_sig
    d2sig = dsig * one_minus_sig - sig * dsig

    gg_gates_values = gg_gates * values

    grad_values = gg_gates * grad_out * dsig
    grad_gates = gg_gates_values * grad_out * d2sig + gg_values * grad_out * dsig
    gI = torch.cat((grad_values, grad_gates), split_dim)

    ggO = gg_values * sig + gg_gates_values * dsig

    return gI, ggO, None, None, None
def hardshrink_double_backwards(ctx, ggI):
    """Double-backward pass for Hardshrink.

    Args:
        ctx: object exposing ``saved_tensors`` and ``additional_args``
            (``[0]`` = lambd).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 5-tuple ``(gI, ggO, None, None, None)``.
    """
    # NOTE(review): assumes saved_tensors[0] is the forward input, matching
    # the other activation helpers in this file -- confirm.
    input = ctx.saved_tensors[0]
    lambd = ctx.additional_args[0]

    # 1 where |input| > lambd (the region where the function has slope 1),
    # 0 elsewhere.
    mask = torch.zeros_like(input).masked_fill_(input > lambd, 1).masked_fill_(input < -lambd, 1)

    # The function is piecewise linear, so the second-order term is zero.
    # Explicit zeros follow what the other piecewise-linear helpers in this
    # file return for gI.
    gI = torch.zeros_like(ggI)
    ggO = ggI * mask

    return gI, ggO, None, None, None
def hardtanh_double_backwards(ctx, ggI):
    """Double-backward pass for Hardtanh.

    Args:
        ctx: object exposing ``saved_tensors`` (``[0]`` = input,
            ``[1]`` = grad_output) and ``additional_args``
            (``[0:2]`` = min_val, max_val).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 5-tuple ``(gI, ggO, None, None, None)``.
    """
    t = ctx.saved_tensors
    input, grad_output = t[0], t[1]
    min_val, max_val = ctx.additional_args[0:2]

    # Pass-through region is min_val < input <= max_val. A logical AND is
    # used instead of subtracting the two comparison masks: subtraction is
    # rejected for bool tensors and wraps around on unsigned byte masks when
    # min_val > max_val.
    pass_through = (input > min_val) & (input <= max_val)

    # Piecewise-linear function: no second-order contribution.
    gI = torch.zeros_like(ggI)
    ggO = ggI * pass_through.type_as(grad_output)
    return gI, ggO, None, None, None
def leakyrelu_double_backwards(ctx, ggI):
    """Double-backward pass for LeakyReLU.

    Args:
        ctx: object exposing ``saved_tensors`` and ``additional_args``
            (``[0]`` = negative_slope).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 5-tuple ``(gI, ggO, None, None, None)``.
    """
    # NOTE(review): assumes saved_tensors[0] is the forward input, as in the
    # other activation helpers in this file -- confirm.
    input = ctx.saved_tensors[0]
    negative_slope = ctx.additional_args[0]

    # Piecewise-linear function: no second-order contribution.
    gI = torch.zeros_like(ggI)
    input_lt_0 = (input < 0).type_as(ggI)
    input_ge_0 = (input >= 0).type_as(ggI)
    # Slope is negative_slope on the negative side and 1 otherwise.
    ggO = ggI * (input_lt_0 * negative_slope + input_ge_0)
    return gI, ggO, None, None, None
def logsigmoid_double_backwards(ctx, ggI):
    """Double-backward pass for LogSigmoid.

    Args:
        ctx: object exposing ``saved_tensors`` (``[0]`` = input,
            ``[1]`` = grad_output).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 6-tuple ``(gI, ggO, None, None, None, None)``.
    """
    t = ctx.saved_tensors
    input, gO = t[0], t[1]

    # 1 / (1 + exp(x)) == sigmoid(-x) and
    # exp(x) / (1 + exp(x))**2 == sigmoid(x) * sigmoid(-x).
    # The sigmoid forms stay finite where exp(x) would overflow and turn the
    # quotient into inf / inf = NaN.
    sig_pos = torch.sigmoid(input)
    sig_neg = torch.sigmoid(-input)

    gI = ggI * gO * -1 * sig_pos * sig_neg
    ggO = ggI * sig_neg

    return gI, ggO, None, None, None, None
def softplus_double_backwards(ctx, ggI):
    """Double-backward pass for Softplus.

    Args:
        ctx: object exposing ``saved_tensors`` (``[0]`` = input,
            ``[1]`` = grad_output; the forward output at ``[2]`` is not
            needed here) and ``additional_args`` (``[0]`` = beta,
            ``[1]`` = threshold).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 6-tuple ``(gI, ggO, None, None, None, None)``.
    """
    t = ctx.saved_tensors
    input, gO = t[0], t[1]
    beta, threshold = ctx.additional_args[0], ctx.additional_args[1]

    input_beta = input * beta
    above_threshold = torch.zeros_like(ggI).masked_fill_(input_beta > threshold, 1)
    below_threshold = torch.zeros_like(ggI).masked_fill_(input_beta <= threshold, 1)

    # With y = log(1 + exp(beta * x)) / beta:
    #   (exp(beta*y) - 1) / exp(beta*y) == sigmoid(beta * x)
    #   1 / exp(beta*y)                 == sigmoid(-beta * x)
    # The sigmoid forms avoid exp() overflowing above the threshold, where
    # inf / inf = NaN would survive multiplication by the zero mask.
    first_deriv_below_threshold = torch.sigmoid(input_beta) * below_threshold

    gI = ggI * gO * first_deriv_below_threshold * beta * torch.sigmoid(-input_beta)
    # Above the threshold the function is treated as the identity (slope 1).
    ggO = ggI * (above_threshold + first_deriv_below_threshold)

    return gI, ggO, None, None, None, None
def softshrink_double_backwards(ctx, ggI):
    """Double-backward pass for Softshrink.

    Delegates unchanged to ``hardshrink_double_backwards`` with the same
    ``ctx`` and ``ggI`` and returns whatever that helper returns.
    """
    grads = hardshrink_double_backwards(ctx, ggI)
    return grads
def threshold_double_backwards(ctx, ggI):
    """Double-backward pass for Threshold.

    Args:
        ctx: object exposing ``saved_tensors`` and ``additional_args``
            (``[0:2]`` = threshold, value).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 5-tuple ``(gI, ggO, None, None, None)``.
    """
    t = ctx.saved_tensors
    # NOTE(review): assumes t[0] is the forward input, as in the other
    # activation helpers in this file -- confirm.
    input = t[0]
    # ``value`` is unpacked for symmetry with the forward arguments; it does
    # not affect the gradients computed here.
    threshold, value = ctx.additional_args[0:2]

    # Piecewise-linear function: no second-order contribution.
    gI = torch.zeros_like(ggI)
    input_gt_threshold = (input > threshold).type_as(ggI)
    ggO = ggI * input_gt_threshold
    return gI, ggO, None, None, None
def klddivloss_double_backwards(ctx, ggI):
    """Double-backward pass for KLDivLoss.

    Args:
        ctx: object exposing ``additional_args`` (``[0]`` = size_average)
            and ``saved_tensors`` (unpacked as ``input, target, gO``).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 5-tuple ``(gI, None, ggO, None, None)``.
    """
    size_average = ctx.additional_args[0]
    input, target, gO = ctx.saved_tensors
    div_factor = input.nelement() if size_average else 1

    # ggO below does not depend on ``input``, i.e. the first-order gradient
    # is constant in the input, so the second-order term is zero.
    # NOTE(review): explicit zeros are returned for consistency with the
    # other helpers in this file; confirm the caller does not expect None.
    gI = torch.zeros_like(ggI)
    ggO = (ggI * target).sum() / -div_factor

    return gI, None, ggO, None, None
def l1loss_double_backwards(ctx, ggI):
    """Double-backward pass for L1Loss.

    Args:
        ctx: object exposing ``additional_args`` (``[0]`` = size_average)
            and ``saved_tensors`` (unpacked as ``input, target, grad_output``).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 5-tuple ``(gI, None, ggO, None, None)``; ``ggO`` is a scalar
        tensor.
    """
    size_average = ctx.additional_args[0]
    input, target, grad_output = ctx.saved_tensors
    # |input - target| is piecewise linear: no second-order contribution.
    gI = torch.zeros_like(ggI)

    positive_mask = (input > target).type_as(ggI)
    negative_mask = (input < target).type_as(ggI)
    # sign(input - target), with 0 where the two are equal.
    ggO = (ggI * (positive_mask - negative_mask)).sum()
    # Only average when the loss itself was averaged; dividing
    # unconditionally would ignore the size_average flag read above.
    if size_average:
        ggO = ggO / input.nelement()
    return gI, None, ggO, None, None
def mseloss_double_backwards(ctx, ggI):
    """Double-backward pass for MSELoss.

    Args:
        ctx: object exposing ``additional_args`` (``[0]`` = size_average,
            ``[1]`` = reduce) and ``saved_tensors`` (unpacked as
            ``input, target, gO``).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 5-tuple ``(gI, None, ggO, None, None)``. ``ggO`` is a scalar
        tensor when ``reduce`` is truthy and element-wise otherwise.
    """
    size_average = ctx.additional_args[0]
    reduce = ctx.additional_args[1]
    input, target, gO = ctx.saved_tensors
    # Averaging only applies to the reduced loss.
    div_factor = input.nelement() if size_average and reduce else 1

    gI = ggI * (gO * 2. / div_factor).expand_as(input)

    # Select exactly one form of ggO; without the branch the element-wise
    # expression would always overwrite the reduced one.
    if reduce:
        ggO = (ggI * (input - target)).sum() * (2. / div_factor)
    else:
        ggO = (ggI * (input - target)) * 2.

    return gI, None, ggO, None, None
def nllloss_double_backwards(ctx, ggI):
    """Double-backward pass for NLLLoss / NLLLoss2d.

    Args:
        ctx: object exposing ``saved_tensors`` and ``additional_args``
            (``[0]`` = size_average, ``[1]`` = weights,
            ``[3]`` = ignore_index, ``[4]`` = reduce).
        ggI: gradient flowing into the first backward's grad_input; class
            scores are indexed along dimension 1.

    Returns:
        A 6-tuple ``(gI, None, ggO, None, None, None)``. ``ggO`` is a scalar
        tensor when ``reduce`` is truthy, otherwise summed over dimension 1.
    """
    t = ctx.saved_tensors
    # NOTE(review): assumes the target is saved at index 1, matching the
    # ``input, target, ...`` order the other loss helpers unpack -- confirm.
    target = t[1]
    weights = ctx.additional_args[1]
    size_average = ctx.additional_args[0]
    ignore_index = ctx.additional_args[3]
    reduce = ctx.additional_args[4]

    # The mask built below depends only on target and weights, not on the
    # input values, so the second-order term w.r.t. the input is zero.
    # NOTE(review): explicit zeros are returned; confirm the caller does not
    # expect None here.
    gI = torch.zeros_like(ggI)

    # scatter/gather cannot take out-of-range indices: move ignored targets
    # to class 0 and zero their weights afterwards, so the placeholder class
    # never contributes.
    target_mask = target == ignore_index
    safe_target = target.clone()
    safe_target.masked_fill_(target_mask, 0)

    if weights.dim() == 0:
        # NOTE(review): a 0-dim ``weights`` appears to mean "no class
        # weights" -- confirm against the caller.
        weights_to_scatter = torch.ones_like(safe_target)
    else:
        # Add trailing singleton dims so the per-class weights can be
        # expanded over the non-batch target dims and gathered per target.
        weights_maybe_resized = weights
        while weights_maybe_resized.dim() < target.dim():
            weights_maybe_resized = weights_maybe_resized.unsqueeze(1)

        weights_maybe_resized = weights_maybe_resized.expand(weights.size()[0:1] + target.size()[1:])
        weights_to_scatter = weights_maybe_resized.gather(0, safe_target)

    weights_to_scatter.masked_fill_(target_mask, 0)
    divisor = weights_to_scatter.sum() if size_average and reduce else 1
    weights_to_scatter = -1 * weights_to_scatter / divisor
    zeros = torch.zeros_like(ggI)
    mask = zeros.scatter_(1, safe_target.unsqueeze(1), weights_to_scatter.unsqueeze(1))

    if reduce:
        ggO = (ggI * mask).sum()
    else:
        ggO = (ggI * mask).sum(dim=1)

    return gI, None, ggO, None, None, None
def smoothl1loss_double_backwards(ctx, ggI):
    """Double-backward pass for SmoothL1Loss.

    Args:
        ctx: object exposing ``additional_args`` (``[0]`` = size_average)
            and ``saved_tensors`` (unpacked as ``input, target, gO``).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 6-tuple ``(gI, None, ggO, None, None, None)``; ``ggO`` is a scalar
        tensor.
    """
    size_average = ctx.additional_args[0]
    input, target, gO = ctx.saved_tensors
    div_factor = input.nelement() if size_average else 1

    input_sub_target = input - target
    # |error| < 1 is the quadratic region; everything else is linear.
    small_error_mask = (input_sub_target.abs() < 1)
    large_error_mask = ~small_error_mask
    # Combine conditions with logical AND. Adding two comparison masks and
    # testing ``== 2`` only works for integer masks; bool + bool stays bool,
    # so that test can never be true.
    large_error_pos_mask = ((input_sub_target > 0) & large_error_mask).type_as(ggI)
    large_error_neg_mask = ((input_sub_target <= 0) & large_error_mask).type_as(ggI)
    small_error_mask = small_error_mask.type_as(ggI)

    # Second derivative is 1 in the quadratic region and 0 in the linear one.
    gI = small_error_mask * ggI * gO / div_factor
    # First derivative: the error itself when small, its sign when large.
    ggO = (ggI * (input_sub_target * small_error_mask + large_error_pos_mask - large_error_neg_mask)).sum() / div_factor

    return gI, None, ggO, None, None, None
def softmarginloss_double_backwards(ctx, ggI):
    """Double-backward pass for SoftMarginLoss.

    Args:
        ctx: object exposing ``additional_args`` (``[0]`` = size_average)
            and ``saved_tensors`` (unpacked as ``input, target, gO``).
        ggI: gradient flowing into the first backward's grad_input.

    Returns:
        A 6-tuple ``(gI, None, ggO, None, None, None)``; ``ggO`` is a scalar
        tensor.
    """
    averaged = ctx.additional_args[0]
    x, y, grad_out = ctx.saved_tensors
    if averaged:
        scale = x.nelement()
    else:
        scale = 1

    # exp(-target * input) appears in both factors of the first derivative.
    decay = (-y * x).exp()
    slope = (1 + decay).pow(-1) * (-y * decay)
    curvature_term = slope.pow(2) + slope * y

    gI = -1 * grad_out * ggI / scale * curvature_term
    ggO = (ggI * slope).sum() / scale

    return gI, None, ggO, None, None, None
None 239 double_backwards_fns = {
240 'ELU': elu_double_backwards,
241 'GatedLinear': gatedlinear_double_backwards,
242 'Hardshrink': hardshrink_double_backwards,
243 'Hardtanh': hardtanh_double_backwards,
244 'LeakyReLU': leakyrelu_double_backwards,
245 'LogSigmoid': logsigmoid_double_backwards,
246 'Softplus': softplus_double_backwards,
247 'Softshrink': softshrink_double_backwards,
248 'Threshold': threshold_double_backwards,
249 'KLDivLoss': klddivloss_double_backwards,
250 'L1Loss': l1loss_double_backwards,
251 'MSELoss': mseloss_double_backwards,
252 'NLLLoss': nllloss_double_backwards,
253 'NLLLoss2d': nllloss_double_backwards,
254 'SmoothL1Loss': smoothl1loss_double_backwards,
255 'SoftMarginLoss': softmarginloss_double_backwards,