def __init__(self, upscale_factor):
    """Create the four convolution layers of the network.

    Args:
        upscale_factor: Sets the output width of the last layer, which
            emits ``upscale_factor ** 2`` channels.
    """
    super(SuperResolutionNet, self).__init__()
    # Layers are assigned in the same order as before so module
    # registration order (and hence state_dict key order) is unchanged.
    # Every layer uses stride 1 with padding = kernel_size // 2, so the
    # spatial size of the input is preserved through the stack.
    self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2)
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
    # NOTE(review): upscale_factor ** 2 channels suggests a pixel-shuffle
    # step elsewhere in the class (not visible here) — confirm.
    self.conv4 = nn.Conv2d(
        32, upscale_factor ** 2, kernel_size=3, stride=1, padding=1
    )
def _initialize_weights(self):
    """Orthogonally initialise the weights of ``conv1``..``conv4`` in place.

    ``conv1``, ``conv2`` and ``conv3`` are scaled by the gain returned by
    ``init.calculate_gain('relu')``. ``conv4`` uses the default gain of
    ``init.orthogonal_``. Only the ``weight`` tensors are modified; biases
    are left as they were.
    """
    # The ReLU gain is a constant, so compute it once rather than per layer.
    relu_gain = init.calculate_gain('relu')
    # Same call order as the per-layer form, so random-number consumption
    # (and thus seeded results) is unchanged.
    for conv in (self.conv1, self.conv2, self.conv3):
        init.orthogonal_(conv.weight, relu_gain)
    # NOTE(review): no ReLU gain here, presumably because conv4's output is
    # not passed through a ReLU — confirm against the forward pass.
    init.orthogonal_(self.conv4.weight)