Caffe2 - Python API
A deep learning, cross-platform ML framework
test_verify.py
1 import torch
2 from torch.autograd import Function
3 from torch.nn import Module, Parameter
4 import caffe2.python.onnx.backend as backend
5 from verify import verify
6 
7 from test_pytorch_common import TestCase, run_tests
8 
9 import unittest
10 
11 
class TestVerify(TestCase):
    """Negative tests for ``verify`` run against the Caffe2 ONNX backend.

    Each test builds a small model that ``verify`` should refuse to accept
    (an exported op that disagrees with ``forward``, parameters that change
    during ``forward``, control flow that depends on call count or input
    values) and checks that ``verify`` fails on it.
    """

    # Show complete diffs in assertion failure messages.
    maxDiff = None

    def assertVerifyExpectFail(self, *args, **kwargs):
        """Run ``verify(*args, **kwargs)`` and require it to fail.

        The check passes when ``verify`` raises an ``AssertionError`` with a
        non-empty message. An ``AssertionError`` with an empty message is
        re-raised unchanged. If ``verify`` returns normally, the test fails.
        """
        try:
            verify(*args, **kwargs)
        except AssertionError as e:
            # Only require that a message is present. Its exact text depends on
            # the system's / numpy's formatting, which keeps changing between
            # versions, so it is deliberately not compared to an expect file.
            if not str(e):
                raise
        else:
            # In the else clause rather than the try body: self.fail raises
            # AssertionError, which the handler above would otherwise catch.
            self.fail("verify() did not fail when expected to")

    def test_result_different(self):
        """verify() must fail when the exported op disagrees with forward."""

        class BrokenAdd(Function):
            @staticmethod
            def symbolic(g, a, b):
                # The exported graph computes a + b ...
                return g.op("Add", a, b)

            @staticmethod
            def forward(ctx, a, b):
                # ... but the PyTorch implementation computes a - b.
                return a.sub(b)

        class MyModel(Module):
            def forward(self, x, y):
                # Call apply on the class itself: autograd Functions with
                # static methods are not meant to be instantiated.
                return BrokenAdd.apply(x, y)

        x = torch.tensor([1, 2])
        y = torch.tensor([3, 4])
        self.assertVerifyExpectFail(MyModel(), (x, y), backend)

    def test_jumbled_params(self):
        """verify() must reject a model whose forward adds a parameter."""

        class MyModel(Module):
            def __init__(self):
                super(MyModel, self).__init__()

            def forward(self, x):
                y = x * x
                # Assigning a Parameter inside forward registers it on the
                # module, so the state_dict changes across the call.
                self.param = Parameter(torch.tensor([2.0]))
                return y

        x = torch.tensor([1, 2])
        with self.assertRaisesRegex(RuntimeError, "state_dict changed"):
            verify(MyModel(), x, backend)

    def test_modifying_params(self):
        """verify() must fail when forward mutates a parameter in place."""

        class MyModel(Module):
            def __init__(self):
                super(MyModel, self).__init__()
                self.param = Parameter(torch.tensor([2.0]))

            def forward(self, x):
                y = x * x
                # In-place update on every call; the output does not use it.
                self.param.data.add_(1.0)
                return y

        x = torch.tensor([1, 2])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    def test_dynamic_model_structure(self):
        """verify() must fail when the computation changes between calls."""

        class MyModel(Module):
            def __init__(self):
                super(MyModel, self).__init__()
                self.iters = 0

            def forward(self, x):
                # Alternate between multiply and add on successive calls.
                if self.iters % 2 == 0:
                    r = x * x
                else:
                    r = x + x
                self.iters += 1
                return r

        x = torch.tensor([1, 2])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    @unittest.skip("Indexing is broken by #3725")
    def test_embedded_constant_difference(self):
        """verify() must fail when an index baked into the graph changes."""

        class MyModel(Module):
            def __init__(self):
                super(MyModel, self).__init__()
                self.iters = 0

            def forward(self, x):
                # The selected row alternates between 0 and 1 across calls.
                r = x[self.iters % 2]
                self.iters += 1
                return r

        x = torch.tensor([[1, 2], [3, 4]])
        self.assertVerifyExpectFail(MyModel(), x, backend)

    def test_explicit_test_args(self):
        """verify() must fail when test_args take a different branch."""

        class MyModel(Module):
            def forward(self, x):
                if x.data.sum() == 1.0:
                    return x + x
                else:
                    return x * x

        # x sums to 8 and takes the x * x branch; y sums to 1 and takes the
        # x + x branch. Presumably the exported graph only records the branch
        # taken for x, so checking it on y should disagree with PyTorch.
        x = torch.tensor([[6, 2]])
        y = torch.tensor([[2, -1]])
        self.assertVerifyExpectFail(MyModel(), x, backend, test_args=[(y,)])
121 
122 
if __name__ == '__main__':
    # Executed only when this file is run as a script: hand control to the
    # shared runner imported from test_pytorch_common.
    run_tests()
def assertVerifyExpectFail(self, *args, **kwargs)
Definition: test_verify.py:15
Definition: verify.py:1