import torch
from torch.autograd.function import Function
from torch._thnn import type2backend

from . import _all_functions


class CrossMapLRN2d(Function):
    """Cross-map (across-channel) local response normalization for 4D
    (batch, channels, height, width) inputs, with a hand-written backward."""

    def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
        super(CrossMapLRN2d, self).__init__()
        self.size = size
        self.alpha = alpha
        self.beta = beta
        self.k = k
        self._backend = None
        self.scale = None

    def forward(self, input):
        assert input.dim() == 4

        if self.scale is None:
            self.scale = input.new()
        output = input.new()
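        # Both the THNN path and the Python fallback compute, per channel c:
        #   scale[c] = k + (alpha / size) * sum of input[c']^2 over a window
        #              of `size` channels centered on c (clipped at the edges)
        #   output[c] = input[c] * scale[c]^(-beta)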
        backend = type2backend[input.type()]
        if backend is not None:
            try:
                # attribute access raises NotImplementedError when the
                # backend lacks this kernel
                backend.SpatialCrossMapLRN_updateOutput
                self._backend = backend
            except NotImplementedError:
                pass

        if self._backend is not None:
            self._backend.SpatialCrossMapLRN_updateOutput(
                self._backend.library_state,
                input,
                output,
                self.scale,
                self.size,
                self.alpha,
                self.beta,
                self.k
            )
        else:
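            # Pure-tensor fallback, used when no THNN backend provides the
            # kernel for this input type.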
            batch_size = input.size(0)
            channels = input.size(1)
            input_height = input.size(2)
            input_width = input.size(3)

            output.resize_as_(input)
            self.scale.resize_as_(input)
            # use the output tensor's storage as a scratch buffer for
            # input^2; it is overwritten with the real output at the end
            input_square = output
            torch.pow(input, 2, out=input_square)
            pre_pad = int((self.size - 1) / 2 + 1)
            pre_pad_crop = channels if pre_pad > channels else pre_pad
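            # e.g. size=5 gives pre_pad=3: the first channel's window covers
            # the squares of channels 0..2 (its half-window, clipped at the
            # lower edge), and every later window slides that running sum.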
            scale_first = self.scale.select(1, 0)
            scale_first.zero_()
            # window sum for the first feature map
            for c in range(pre_pad_crop):
                scale_first.add_(input_square.select(1, c))
            # reuse the previous channel's sum: slide the window by adding
            # the squared map that enters it and removing the one that leaves
            for c in range(1, channels):
                scale_previous = self.scale.select(1, c - 1)
                scale_current = self.scale.select(1, c)
                scale_current.copy_(scale_previous)
                if c < channels - pre_pad + 1:
                    square_next = input_square.select(1, c + pre_pad - 1)
                    scale_current.add_(1, square_next)
                if c >= pre_pad:
                    square_previous = input_square.select(1, c - pre_pad)
                    scale_current.add_(-1, square_previous)
            # scale = k + alpha / size * (window sum of squares)
            self.scale.mul_(self.alpha / self.size).add_(self.k)

            torch.pow(self.scale, -self.beta, out=output)
            output.mul_(input)

        self.save_for_backward(input, output)
        return output

    def backward(self, grad_output):
        input, output = self.saved_tensors
        grad_input = grad_output.new()

        if self._backend is not None:
            self._backend.SpatialCrossMapLRN_updateGradInput(
                self._backend.library_state,
                input,
                grad_output,
                grad_input,
                self.scale,
                output,
                self.size,
                self.alpha,
                self.beta,
                self.k
            )
        else:
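            # Python fallback mirroring the THNN updateGradInput kernel.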
            batch_size = input.size(0)
            channels = input.size(1)
            input_height = input.size(2)
            input_width = input.size(3)
            padded_ratio = input.new(channels + self.size - 1, input_height,
                                     input_width)
            accum_ratio = input.new(input_height, input_width)

            cache_ratio_value = 2 * self.alpha * self.beta / self.size
            # 0-based row at which the real channels start inside
            # padded_ratio (pre_pad - 1 rows of zero padding in front)
            inversePrePad = int(self.size - (self.size - 1) / 2) - 1
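            # Chain rule for out = in * scale^(-beta):
            #   grad_in[c] = scale[c]^(-beta) * grad_out[c]
            #       - (2 * alpha * beta / size) * in[c]
            #         * sum over the window around c of
            #           grad_out[c'] * out[c'] / scale[c']
            # cache_ratio_value holds the constant factor; padded_ratio and
            # accum_ratio implement the windowed sum as a running total.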
            grad_input.resize_as_(input)
            torch.pow(self.scale, -self.beta, out=grad_input).mul_(grad_output)

            padded_ratio.zero_()
            padded_ratio_center = padded_ratio.narrow(0, inversePrePad,
                                                      channels)
            for n in range(batch_size):
                torch.mul(grad_output[n], output[n], out=padded_ratio_center)
                padded_ratio_center.div_(self.scale[n])
                # seed the running window sum with the first size - 1 rows
                torch.sum(padded_ratio.narrow(0, 0, self.size - 1), 0,
                          keepdim=False, out=accum_ratio)
                for c in range(channels):
                    # slide the window: add the row entering it, apply the
                    # update, then drop the row leaving it
                    accum_ratio.add_(padded_ratio[c + self.size - 1])
                    grad_input[n][c].addcmul_(-cache_ratio_value, input[n][c],
                                              accum_ratio)
                    accum_ratio.add_(-1, padded_ratio[c])

        return grad_input


_all_functions.append(CrossMapLRN2d)
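
# Example usage (assumes the legacy one-instance-per-call Function API,
# where the instance is applied directly to a Variable; shapes are
# illustrative only):
#
#   from torch.autograd import Variable
#   lrn = CrossMapLRN2d(size=5, alpha=1e-4, beta=0.75, k=1)
#   x = Variable(torch.randn(2, 8, 16, 16), requires_grad=True)
#   y = lrn(x)             # runs forward() and saves tensors
#   y.sum().backward()     # runs the hand-written backward()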