Caffe2 - Python API
A deep learning, cross-platform ML framework
super_resolution.py
1 import torch
2 import torch.nn as nn
3 import torch.nn.init as init
4 
5 
class SuperResolutionNet(nn.Module):
    """Single-channel super-resolution network built on sub-pixel convolution.

    Three convolutions whose padding preserves height and width, each followed
    by ReLU, extract features. A final convolution then emits
    ``upscale_factor ** 2`` channels. ``nn.PixelShuffle`` rearranges those
    channels into one channel at ``upscale_factor`` times the input height and
    width.

    Args:
        upscale_factor: factor by which height and width are enlarged.
    """

    def __init__(self, upscale_factor):
        super(SuperResolutionNet, self).__init__()

        # Keep this construction order. Each Conv2d draws its default initial
        # parameters from the global RNG, so reordering would change seeded
        # results as well as the module / state_dict ordering.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(
            32, upscale_factor ** 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        """Upscale ``x``, a 1-channel image tensor such as (N, 1, H, W).

        Returns a 1-channel tensor with height and width multiplied by
        ``upscale_factor``.
        """
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.relu(conv(features))
        # No activation after conv4: its outputs are shuffled directly into pixels.
        return self.pixel_shuffle(self.conv4(features))

    def _initialize_weights(self):
        """Orthogonally initialize every conv weight; biases keep their defaults.

        The three ReLU-followed layers use the ReLU gain. conv4 feeds the pixel
        shuffle directly and uses the default gain of 1.
        """
        relu_gain = init.calculate_gain('relu')
        for conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)