Caffe2 - Python API
A deep learning, cross platform ML framework
test_models.py
1 from torchvision.models.alexnet import alexnet
2 from torchvision.models.inception import inception_v3
3 from torchvision.models.densenet import densenet121
4 from torchvision.models.resnet import resnet50
5 from torchvision.models.vgg import vgg16, vgg16_bn, vgg19, vgg19_bn
6 
7 from model_defs.mnist import MNIST
8 from model_defs.word_language_model import RNNModel
9 from model_defs.squeezenet import SqueezeNet
10 from model_defs.super_resolution import SuperResolutionNet
11 from model_defs.srresnet import SRResNet
12 from model_defs.dcgan import _netD, _netG, weights_init, bsz, imgsz, nz
13 from model_defs.op_test import DummyNet, ConcatNet, PermuteNet, PReluNet
14 
15 from test_pytorch_common import TestCase, run_tests, skipIfNoLapack
16 
17 import torch
18 import torch.onnx
19 import torch.onnx.utils
20 from torch.autograd import Variable, Function
21 from torch.nn import Module
22 from torch.onnx import OperatorExportTypes
23 
24 import onnx
25 import onnx.checker
26 import onnx.helper
27 
28 import google.protobuf.text_format
29 
30 import io
31 import unittest
32 
33 import caffe2.python.onnx.backend as backend
34 
35 from verify import verify
36 
# ``toC`` moves a tensor or module onto the GPU when CUDA is available and is
# the identity otherwise. Tests wrap both models and inputs in it so the two
# always end up on the same device.
if torch.cuda.is_available():
    def toC(x):
        """Return ``x`` moved to the current CUDA device (tensor or module)."""
        return x.cuda()
else:
    def toC(x):
        """Return ``x`` unchanged (no CUDA device available)."""
        return x


# Batch dimension used by most of the example inputs below.
BATCH_SIZE = 2
45 
46 
class TestModels(TestCase):
    """ONNX export tests: trace each model, lint the graph, then run ``verify``.

    Every test builds a model plus an example input and funnels both through
    :meth:`exportTest`. Models and inputs are both passed through ``toC`` so
    they are placed on the same device.
    """

    def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
        """Trace ``model`` for ONNX export and check it with ``verify``.

        Args:
            model: module under test.
            inputs: example input(s) used both for tracing and for ``verify``.
            rtol: relative tolerance forwarded to ``verify``.
            atol: absolute tolerance forwarded to ``verify``.
        """
        trace = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
        # Run the JIT lint pass on the traced graph before verification;
        # presumably this raises early on a malformed graph.
        torch._C._jit_pass_lint(trace.graph())
        # ``verify`` receives the Caffe2 ONNX backend imported at file top.
        verify(model, inputs, backend, rtol=rtol, atol=atol)

    def test_ops(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(DummyNet()), toC(x))

    def test_prelu(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        # Use toC like every other test so CUDA hosts exercise the GPU path.
        self.exportTest(toC(PReluNet()), toC(x))

    def test_concat(self):
        input_a = Variable(torch.randn(BATCH_SIZE, 3))
        input_b = Variable(torch.randn(BATCH_SIZE, 3))
        # The outer 1-tuple is the positional-argument tuple; its single
        # element is itself a tuple of two tensors handed to ConcatNet.
        inputs = ((toC(input_a), toC(input_b)), )
        self.exportTest(toC(ConcatNet()), inputs)

    def test_permute(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
        # Use toC like every other test so CUDA hosts exercise the GPU path.
        self.exportTest(toC(PermuteNet()), toC(x))

    @unittest.skip("This model takes too much memory")
    def test_srresnet(self):
        x = Variable(torch.randn(1, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x))

    @skipIfNoLapack
    def test_super_resolution(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0))
        self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

    def test_alexnet(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(alexnet()), toC(x))

    @unittest.skip("Waiting for https://github.com/pytorch/pytorch/pull/3100")
    def test_mnist(self):
        x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
        self.exportTest(toC(MNIST()), toC(x))

    def test_vgg16(self):
        # VGG 16-layer model (configuration "D")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16()), toC(x))

    def test_vgg16_bn(self):
        # VGG 16-layer model (configuration "D") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg16_bn()), toC(x))

    def test_vgg19(self):
        # VGG 19-layer model (configuration "E")
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19()), toC(x))

    def test_vgg19_bn(self):
        # VGG 19-layer model (configuration "E") with batch normalization
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(vgg19_bn()), toC(x))

    def test_resnet(self):
        # ResNet50 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

    def test_inception(self):
        x = Variable(torch.randn(BATCH_SIZE, 3, 299, 299) + 1.)
        self.exportTest(toC(inception_v3()), toC(x))

    def test_squeezenet(self):
        # SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and
        # <0.5MB model size
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        # Previously this built version=1.1, so v1.0 was never exported.
        sqnet_v1_0 = SqueezeNet(version=1.0)
        self.exportTest(toC(sqnet_v1_0), toC(x))

        # SqueezeNet 1.1 has 2.4x less computation and slightly fewer params
        # than SqueezeNet 1.0, without sacrificing accuracy.
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        sqnet_v1_1 = SqueezeNet(version=1.1)
        self.exportTest(toC(sqnet_v1_1), toC(x))

    @unittest.skip("Temporary - waiting for https://github.com/onnx/onnx/pull/1773.")
    def test_densenet(self):
        # Densenet-121 model
        x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
        self.exportTest(toC(densenet121()), toC(x))

    def test_dcgan_netD(self):
        netD = _netD(1)
        netD.apply(weights_init)
        # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
        x = Variable(torch.Tensor(bsz, 3, imgsz, imgsz).normal_(0, 1))
        self.exportTest(toC(netD), toC(x))

    def test_dcgan_netG(self):
        netG = _netG(1)
        netG.apply(weights_init)
        # Named ``x`` rather than ``input`` to avoid shadowing the builtin.
        x = Variable(torch.Tensor(bsz, nz, 1, 1).normal_(0, 1))
        self.exportTest(toC(netG), toC(x))
158 
# Running this file directly delegates to the shared PyTorch test runner.
if __name__ == '__main__':
    run_tests()
def is_available()
Definition: __init__.py:45
def _trace(func, args, operator_export_type, return_outs=False)
Definition: utils.py:180
Definition: verify.py:1
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7)
Definition: test_models.py:48