def _initialize_orthogonal(conv):
    """Orthogonally initialise ``conv.weight`` in place and zero its bias.

    The weight is filled by ``init.orthogonal_`` with gain ``sqrt(2)``.
    If the layer has a bias, it is set to zero. Layers built with
    ``bias=False`` have ``conv.bias is None`` and their bias is left alone.

    Args:
        conv: a layer exposing ``weight`` and ``bias`` attributes
            (e.g. ``nn.Conv2d``).
    """
    # sqrt(2) is the value init.calculate_gain('relu') returns; the name
    # indicates it was chosen for the PReLU activations used in this model.
    prelu_gain = math.sqrt(2)
    # Use the in-place initialiser: the un-suffixed ``init.orthogonal`` alias
    # is deprecated and emits a warning on every call.
    init.orthogonal_(conv.weight, gain=prelu_gain)
    if conv.bias is not None:
        conv.bias.data.zero_()
def __init__(self, n_filters):
    """Create the two convolution / batch-norm stages of a residual block.

    Args:
        n_filters: channel count; each conv maps ``n_filters`` channels to
            ``n_filters`` channels, and the batch-norms and PReLU (one slope
            per channel) are sized to match.
    """
    super(ResidualBlock, self).__init__()

    def conv3x3():
        # 3x3 convolution with padding 1 and no bias term.
        return nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1,
                         bias=False)

    # Assignment order is significant: it fixes submodule registration order
    # and the order in which parameter initialisation consumes the RNG.
    self.conv1 = conv3x3()
    self.bn1 = nn.BatchNorm2d(n_filters)
    self.prelu = nn.PReLU(n_filters)
    self.conv2 = conv3x3()
    self.bn2 = nn.BatchNorm2d(n_filters)

    # Re-initialise both convolutions orthogonally (conv1 first, then conv2).
    for conv in (self.conv1, self.conv2):
        _initialize_orthogonal(conv)
# Second stage of the residual branch: apply conv2, then bn2, to the running
# `residual` tensor (conv2 is built with bias=False).
29 residual = self.
bn2(self.
conv2(residual))
34 def __init__(self, n_filters):
# Build the layers of one upscaling stage.
#
# Args:
#     n_filters: number of channels entering the block.
35 super(UpscaleBlock, self).__init__()
# 3x3 conv with padding 1 (spatial size unchanged) that expands
# n_filters -> 4 * n_filters channels.
# NOTE(review): the 4x channel factor presumably feeds a 2x pixel shuffle
# (4 = 2 ** 2) elsewhere in this class — confirm.
36 self.
upscaling_conv = nn.Conv2d(n_filters, 4 * n_filters, kernel_size=3, padding=1)
46 def __init__(self, rescale_factor, n_filters, n_blocks):
# Assemble the SRResNet network.
#
# Args:
#     rescale_factor: overall upscaling factor. NOTE(review): not referenced
#         on the lines visible here; presumably it determines how many upscale
#         blocks are registered — confirm.
#     n_filters: channel width of the feature trunk.
#     n_blocks: number of residual blocks, registered as residual_block1 ..
#         residual_block<n_blocks>. forward() calls residual_block1
#         unconditionally, so this must be >= 1.
47 super(SRResNet, self).__init__()
# Input stem: 3 -> n_filters channels, 9x9 conv with padding 4 (spatial size
# unchanged), followed by a per-channel PReLU.
52 self.
conv1 = nn.Conv2d(3, n_filters, kernel_size=9, padding=4)
53 self.
prelu1 = nn.PReLU(n_filters)
# Register residual blocks under numbered names so forward() can fetch them
# with getattr.
# NOTE(review): `residual_block` is created inside this loop on a line not
# visible here (presumably ResidualBlock(n_filters) — confirm).
55 for residual_block_num
in range(1, n_blocks + 1):
# Wrapping in nn.Sequential adds a ".0" level to parameter names
# (e.g. residual_block1.0.conv1.weight). Removing the wrapper would rename
# those parameters and break loading of checkpoints saved with this layout.
57 self.add_module(
'residual_block' + str(residual_block_num), nn.Sequential(residual_block))
# 3x3 conv (no bias) + batch-norm applied after the residual trunk; the name
# suggests its output is combined with a skip connection — confirm in forward.
59 self.
skip_conv = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1, bias=
False)
60 self.
skip_bn = nn.BatchNorm2d(n_filters)
# Register upscale blocks as upscale_block<k>, also wrapped in nn.Sequential
# (same parameter-naming consequence as above).
# NOTE(review): `upscale_block_num` and `upscale_block` come from an enclosing
# loop — confirm its bounds against rescale_factor.
64 self.add_module(
'upscale_block' + str(upscale_block_num), nn.Sequential(upscale_block))
# Output head: n_filters -> 3 channels, 9x9 conv with padding 4.
66 self.
output_conv = nn.Conv2d(n_filters, 3, kernel_size=9, padding=4)
# Orthogonal initialisation (gain sqrt(2), bias zeroed) of the input conv.
69 _initialize_orthogonal(self.
conv1)
# Residual trunk: apply residual_block1 to x_init, then residual_block2 ..
# residual_block<n_blocks> in order, fetching each by name.
# NOTE(review): residual_block1 is called unconditionally, so a model built
# with n_blocks == 0 raises AttributeError here. Presumably self.n_blocks
# holds the constructor's n_blocks and x_init is the stem output — confirm.
75 x = self.residual_block1(x_init)
76 for residual_block_num
in range(2, self.
n_blocks + 1):
77 x = getattr(self,
'residual_block' + str(residual_block_num))(x)
# Apply the registered upscale blocks by name.
# NOTE(review): `upscale_block_num` comes from an enclosing loop; its range
# must match the one used in __init__.
80 x = getattr(self,
'upscale_block' + str(upscale_block_num))(x)