8 from pytorch_helper
import PyTorchModule
12 from test_pytorch_common
import skipIfNoLapack
def test_helper(self):
    """Compare a PyTorch model run through a Caffe2 net with eager PyTorch.

    A small convolutional module is built, handed to ``PyTorchModule``
    between two ``Sigmoid`` ops on ``helper``, and the blob fetched from
    the Caffe2 workspace is asserted equal (to 3 decimals) to
    ``sigmoid(model(sigmoid(input)))`` computed directly in PyTorch.
    """

    class SuperResolutionNet(nn.Module):
        """Four stacked 2-D convolutions plus a ReLU activation module."""

        # NOTE(review): no ``forward`` method is visible in this view of
        # the file, yet the test calls ``torch_model(...)`` below --
        # presumably it is defined in lines not shown here; confirm.

        def __init__(self, upscale_factor, inplace=False):
            super(SuperResolutionNet, self).__init__()

            self.relu = nn.ReLU(inplace=inplace)
            # Positional Conv2d arguments: in_channels, out_channels,
            # kernel_size, stride, padding.
            self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
            self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
            self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
            # The last convolution emits upscale_factor ** 2 channels.
            self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))

            self._initialize_weights()

        def _initialize_weights(self):
            """Orthogonally initialise every convolution's weight tensor."""
            # ``init.orthogonal`` is the deprecated alias; the in-place
            # ``init.orthogonal_`` is the supported spelling.
            relu_gain = init.calculate_gain('relu')
            # Same order as before (conv1, conv2, conv3, then conv4) so the
            # random number stream is consumed identically.
            for conv in (self.conv1, self.conv2, self.conv3):
                init.orthogonal_(conv.weight, relu_gain)
            # conv4 uses the initialiser's default gain.
            init.orthogonal_(self.conv4.weight)

    torch_model = SuperResolutionNet(upscale_factor=3)

    fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)

    # NOTE(review): ``helper`` is not bound anywhere in the visible lines --
    # presumably it is created in a line not shown here; confirm.
    start = helper.Sigmoid(['the_input'])
    # PyTorchModule is given the helper, the model, an example input tuple
    # and the list of input blobs; its single result is unpacked here.
    toutput, = PyTorchModule(helper, torch_model, (fake_input,), [start])
    output = helper.Sigmoid(toutput)

    workspace.RunNetOnce(helper.InitProto())
    workspace.FeedBlob('the_input', fake_input.data.numpy())
    workspace.RunNetOnce(helper.Proto())
    c2_out = workspace.FetchBlob(str(output))

    # Reference result computed directly in PyTorch with the same
    # sigmoid -> model -> sigmoid pipeline.
    torch_out = torch.sigmoid(torch_model(torch.sigmoid(fake_input)))

    np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)
68 if __name__ ==
'__main__':