Caffe2 - Python API
A deep learning, cross-platform ML framework
test_pytorch_helper.py
1 # Some standard imports
2 import numpy as np
3 from torch import nn
4 from torch.autograd import Variable
5 import torch.onnx
6 import torch.nn.init as init
7 from caffe2.python.model_helper import ModelHelper
8 from pytorch_helper import PyTorchModule
9 import unittest
10 from caffe2.python.core import workspace
11 
12 from test_pytorch_common import skipIfNoLapack
13 
14 
class TestCaffe2Backend(unittest.TestCase):
    """Tests embedding a PyTorch module inside a Caffe2 ``ModelHelper`` net."""

    @skipIfNoLapack
    def test_helper(self):
        """Compare a Caffe2 net wrapping a PyTorch model against pure PyTorch.

        The Caffe2 net applies ``Sigmoid`` to the ``the_input`` blob, feeds the
        result through the model embedded via ``PyTorchModule``, and applies
        ``Sigmoid`` again. The same expression is evaluated directly in PyTorch
        and both results must agree to three decimal places.
        """

        class SuperResolutionNet(nn.Module):
            """Four convolutions (ReLU after the first three) plus a pixel shuffle."""

            def __init__(self, upscale_factor, inplace=False):
                super(SuperResolutionNet, self).__init__()

                self.relu = nn.ReLU(inplace=inplace)
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

                self._initialize_weights()

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

            def _initialize_weights(self):
                # Use the in-place ``orthogonal_`` initializer; the un-suffixed
                # ``init.orthogonal`` name is a deprecated alias for it.
                relu_gain = init.calculate_gain('relu')
                for conv in (self.conv1, self.conv2, self.conv3):
                    init.orthogonal_(conv.weight, relu_gain)
                # conv4 is not followed by a ReLU, so it keeps the default gain.
                init.orthogonal_(self.conv4.weight)

        torch_model = SuperResolutionNet(upscale_factor=3)

        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)

        # use ModelHelper to create a C2 net
        helper = ModelHelper(name="test_model")
        start = helper.Sigmoid(['the_input'])
        # Embed the ONNX-converted pytorch net inside it
        toutput, = PyTorchModule(helper, torch_model, (fake_input,), [start])
        output = helper.Sigmoid(toutput)

        workspace.RunNetOnce(helper.InitProto())
        # detach() yields a tensor that does not require grad, which is
        # needed before converting to a NumPy array.
        workspace.FeedBlob('the_input', fake_input.detach().numpy())
        workspace.RunNetOnce(helper.Proto())
        c2_out = workspace.FetchBlob(str(output))

        # Reference result: the same sigmoid -> model -> sigmoid chain in PyTorch.
        torch_out = torch.sigmoid(torch_model(torch.sigmoid(fake_input)))

        np.testing.assert_almost_equal(
            torch_out.detach().cpu().numpy(), c2_out, decimal=3)
66 
67 
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()