1 from torchvision.models.alexnet
import alexnet
2 from torchvision.models.inception
import inception_v3
3 from torchvision.models.densenet
import densenet121
4 from torchvision.models.resnet
import resnet50
5 from torchvision.models.vgg
import vgg16, vgg16_bn, vgg19, vgg19_bn
13 from model_defs.op_test
import DummyNet, ConcatNet, PermuteNet, PReluNet
15 from test_pytorch_common
import TestCase, run_tests, skipIfNoLapack
28 import google.protobuf.text_format
35 from verify
import verify
48 def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
50 torch._C._jit_pass_lint(trace.graph())
51 verify(model, inputs, backend, rtol=rtol, atol=atol)
55 torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
61 torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
65 def test_concat(self):
66 input_a = Variable(torch.randn(BATCH_SIZE, 3))
67 input_b = Variable(torch.randn(BATCH_SIZE, 3))
68 inputs = ((toC(input_a), toC(input_b)), )
71 def test_permute(self):
72 x = Variable(torch.randn(BATCH_SIZE, 3, 10, 12))
75 @unittest.skip(
"This model takes too much memory")
76 def test_srresnet(self):
77 x = Variable(torch.randn(1, 3, 224, 224).fill_(1.0))
81 def test_super_resolution(self):
83 torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0)
87 def test_alexnet(self):
89 torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
93 @unittest.skip(
"Waiting for https://github.com/pytorch/pytorch/pull/3100")
95 x = Variable(torch.randn(BATCH_SIZE, 1, 28, 28).fill_(1.0))
100 x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
103 def test_vgg16_bn(self):
105 x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
108 def test_vgg19(self):
110 x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
113 def test_vgg19_bn(self):
115 x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
118 def test_resnet(self):
120 x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
121 self.
exportTest(toC(resnet50()), toC(x), atol=1e-6)
123 def test_inception(self):
125 torch.randn(BATCH_SIZE, 3, 299, 299) + 1.)
128 def test_squeezenet(self):
131 x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
137 x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
141 @unittest.skip(
"Temporary - waiting for https://github.com/onnx/onnx/pull/1773.")
142 def test_densenet(self):
144 x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
147 def test_dcgan_netD(self):
149 netD.apply(weights_init)
150 input = Variable(torch.Tensor(bsz, 3, imgsz, imgsz).normal_(0, 1))
153 def test_dcgan_netG(self):
155 netG.apply(weights_init)
156 input = Variable(torch.Tensor(bsz, nz, 1, 1).normal_(0, 1))
159 if __name__ ==
'__main__':
def _trace(func, args, operator_export_type, return_outs=False)
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7)