8 from pytorch_helper 
import PyTorchModule
    12 from test_pytorch_common 
import skipIfNoLapack
    @skipIfNoLapack
    def test_helper(self):
        """Check that a PyTorch model embedded in a Caffe2 net matches PyTorch.

        Builds the Caffe2 net ``Sigmoid -> <exported SuperResolutionNet> -> Sigmoid``
        via ``PyTorchModule``, runs it on a random 1x1x224x224 input, and compares
        the result with ``sigmoid(torch_model(sigmoid(x)))`` to 3 decimal places.
        """

        class SuperResolutionNet(nn.Module):
            """Small sub-pixel convolution network used as the model under test.

            Maps a 1-channel image to a 1-channel image upscaled by
            ``upscale_factor`` along each spatial dimension.
            """

            def __init__(self, upscale_factor, inplace=False):
                super(SuperResolutionNet, self).__init__()

                self.relu = nn.ReLU(inplace=inplace)
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
                # conv4 emits upscale_factor**2 channels; PixelShuffle folds them
                # back into a single channel at upscale_factor x the spatial size.
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

                self._initialize_weights()

            def forward(self, x):
                # nn.Module requires forward(); without it, calling the model
                # (or exporting it) raises NotImplementedError.
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                return self.pixel_shuffle(self.conv4(x))

            def _initialize_weights(self):
                # orthogonal_ is the in-place replacement for the deprecated
                # init.orthogonal. conv1-conv3 are followed by ReLU, so they use
                # the ReLU gain; conv4 is not, so it keeps the default gain of 1.
                relu_gain = init.calculate_gain('relu')
                init.orthogonal_(self.conv1.weight, relu_gain)
                init.orthogonal_(self.conv2.weight, relu_gain)
                init.orthogonal_(self.conv3.weight, relu_gain)
                init.orthogonal_(self.conv4.weight)

        torch_model = SuperResolutionNet(upscale_factor=3)

        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)

        # Build the Caffe2 net around the PyTorch model.
        # NOTE(review): ModelHelper (caffe2.python.model_helper) is assumed to be
        # imported at module level alongside `workspace` -- confirm.
        helper = ModelHelper(name="test_model")
        start = helper.Sigmoid(['the_input'])
        toutput, = PyTorchModule(helper, torch_model, (fake_input,), [start])
        output = helper.Sigmoid(toutput)

        workspace.RunNetOnce(helper.InitProto())
        # detach() instead of .data: numpy() refuses tensors that require grad.
        workspace.FeedBlob('the_input', fake_input.detach().numpy())
        workspace.RunNetOnce(helper.Proto())
        c2_out = workspace.FetchBlob(str(output))

        # Reference result computed directly in PyTorch.
        torch_out = torch.sigmoid(torch_model(torch.sigmoid(fake_input)))

        np.testing.assert_almost_equal(torch_out.detach().cpu().numpy(), c2_out, decimal=3)
    68 if __name__ == 
'__main__':