Caffe2 - Python API
A deep learning, cross-platform ML framework
srresnet.py
1 import math
2 
3 from torch import nn
4 from torch.nn import init
5 
6 
def _initialize_orthogonal(conv):
    """Orthogonally initialise ``conv.weight`` (gain sqrt(2)) and zero its bias.

    sqrt(2) is the value ``init.calculate_gain('relu')`` returns; the original
    author labels it the PReLU gain.

    Args:
        conv: a module exposing a ``weight`` parameter and an optional
            ``bias`` parameter (e.g. ``nn.Conv2d``). Modified in place.
    """
    prelu_gain = math.sqrt(2)
    # ``init.orthogonal`` (no trailing underscore) is a deprecated alias that
    # emits a warning on every call; ``orthogonal_`` is the supported API.
    init.orthogonal_(conv.weight, gain=prelu_gain)
    if conv.bias is not None:
        # ``zeros_`` performs the in-place fill under no_grad instead of
        # reaching through ``.data``.
        init.zeros_(conv.bias)
12 
13 
class ResidualBlock(nn.Module):
    """Residual unit: conv-BN-PReLU-conv-BN added back onto the input.

    Both convolutions are 3x3 with padding 1 and map ``n_filters`` channels
    to ``n_filters`` channels, so the transform keeps the spatial size and
    the skip connection is a plain addition (the input is expected to have
    ``n_filters`` channels).
    """

    def __init__(self, n_filters):
        super(ResidualBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        # Keep this construction order: it determines parameter/state_dict
        # ordering and, under a fixed seed, which random draws each conv gets.
        self.conv1 = nn.Conv2d(n_filters, n_filters, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(n_filters)
        self.prelu = nn.PReLU(n_filters)
        self.conv2 = nn.Conv2d(n_filters, n_filters, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(n_filters)

        # Override the default initialisation of both convolutions.
        for conv in (self.conv1, self.conv2):
            _initialize_orthogonal(conv)

    def forward(self, x):
        """Return ``x`` plus the conv-BN-PReLU-conv-BN transform of ``x``."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        return x + out
31 
32 
class UpscaleBlock(nn.Module):
    """2x spatial upsampler: conv to 4x channels, PixelShuffle(2), then PReLU.

    The 3x3 convolution expands ``n_filters`` channels to ``4 * n_filters``.
    ``PixelShuffle(2)`` rearranges those into a feature map that is twice as
    tall and wide with ``n_filters`` channels. The PReLU then activates it.
    """

    def __init__(self, n_filters):
        super(UpscaleBlock, self).__init__()
        shuffle_factor = 2
        expanded_channels = shuffle_factor * shuffle_factor * n_filters
        self.upscaling_conv = nn.Conv2d(n_filters, expanded_channels, kernel_size=3, padding=1)
        self.upscaling_shuffler = nn.PixelShuffle(shuffle_factor)
        # Despite its name this attribute is the activation; the name is kept
        # because it is part of the module's state_dict keys.
        self.upscaling = nn.PReLU(n_filters)
        _initialize_orthogonal(self.upscaling_conv)

    def forward(self, x):
        """Upsample ``x`` by a factor of two in height and width."""
        expanded = self.upscaling_conv(x)
        shuffled = self.upscaling_shuffler(expanded)
        return self.upscaling(shuffled)
43 
44 
class SRResNet(nn.Module):
    """Super-resolution residual network.

    Pipeline:
    1. 9x9 input convolution (3 -> ``n_filters`` channels) followed by PReLU.
    2. ``n_blocks`` residual blocks.
    3. A 3x3 conv + BN whose output is added to the step-1 features.
    4. One 2x ``UpscaleBlock`` per power of two in ``rescale_factor``.
    5. A 9x9 output convolution back to 3 channels.

    Args:
        rescale_factor: overall upscaling factor; must be a positive power
            of two (1, 2, 4, 8, ...).
        n_filters: number of feature channels used throughout the trunk.
        n_blocks: number of residual blocks; 0 is allowed.

    Raises:
        ValueError: if ``rescale_factor`` is not a positive power of two.
    """

    def __init__(self, rescale_factor, n_filters, n_blocks):
        super(SRResNet, self).__init__()
        # Each UpscaleBlock doubles the resolution, so only exact powers of two
        # can be realised. ``int(math.log(f, 2))`` silently truncated e.g. 3 to
        # one level (a 2x model) and relied on float rounding; ``math.log2`` is
        # exact for powers of two, and anything else is rejected up front.
        levels = math.log2(rescale_factor)
        if levels < 0 or not levels.is_integer():
            raise ValueError(
                'rescale_factor must be a positive power of two, got {!r}'.format(rescale_factor))
        self.rescale_levels = int(levels)
        self.n_filters = n_filters
        self.n_blocks = n_blocks

        self.conv1 = nn.Conv2d(3, n_filters, kernel_size=9, padding=4)
        self.prelu1 = nn.PReLU(n_filters)

        # Each block is wrapped in a one-element Sequential; this adds a ".0."
        # segment to the state_dict keys, so it is kept for checkpoint
        # compatibility.
        for residual_block_num in range(1, n_blocks + 1):
            residual_block = ResidualBlock(self.n_filters)
            self.add_module('residual_block{}'.format(residual_block_num), nn.Sequential(residual_block))

        self.skip_conv = nn.Conv2d(n_filters, n_filters, kernel_size=3, padding=1, bias=False)
        self.skip_bn = nn.BatchNorm2d(n_filters)

        for upscale_block_num in range(1, self.rescale_levels + 1):
            upscale_block = UpscaleBlock(self.n_filters)
            self.add_module('upscale_block{}'.format(upscale_block_num), nn.Sequential(upscale_block))

        self.output_conv = nn.Conv2d(n_filters, 3, kernel_size=9, padding=4)

        # Orthogonal initialisation of the convolutions owned directly by this
        # module (the sub-blocks initialise their own).
        _initialize_orthogonal(self.conv1)
        _initialize_orthogonal(self.skip_conv)
        _initialize_orthogonal(self.output_conv)

    def forward(self, x):
        """Super-resolve ``x``.

        ``x`` must have 3 channels (``conv1`` is ``nn.Conv2d(3, ...)``).
        Assuming the usual (N, 3, H, W) layout, the result is
        (N, 3, H * rescale_factor, W * rescale_factor).
        """
        x_init = self.prelu1(self.conv1(x))
        # Start from x_init and iterate from block 1, so n_blocks == 0 works.
        # The previous version always called self.residual_block1.
        x = x_init
        for residual_block_num in range(1, self.n_blocks + 1):
            x = getattr(self, 'residual_block{}'.format(residual_block_num))(x)
        # Long skip connection around the whole residual trunk.
        x = self.skip_bn(self.skip_conv(x)) + x_init
        for upscale_block_num in range(1, self.rescale_levels + 1):
            x = getattr(self, 'upscale_block{}'.format(upscale_block_num))(x)
        return self.output_conv(x)