3 from torch.nn import Module, Parameter
5 from verify
import verify
7 from test_pytorch_common
import TestCase, run_tests
15 def assertVerifyExpectFail(self, *args, **kwargs):
18 except AssertionError
as e:
30 self.assertTrue(
False, msg=
"verify() did not fail when expected to")
32 def test_result_different(self):
33 class BrokenAdd(Function):
35 def symbolic(g, a, b):
36 return g.op(
"Add", a, b)
39 def forward(ctx, a, b):
42 class MyModel(Module):
43 def forward(self, x, y):
44 return BrokenAdd().apply(x, y)
50 def test_jumbled_params(self):
51 class MyModel(Module):
53 super(MyModel, self).__init__()
61 with self.assertRaisesRegex(RuntimeError,
"state_dict changed"):
62 verify(MyModel(), x, backend)
64 def test_modifying_params(self):
65 class MyModel(Module):
67 super(MyModel, self).__init__()
72 self.param.data.add_(1.0)
78 def test_dynamic_model_structure(self):
79 class MyModel(Module):
81 super(MyModel, self).__init__()
85 if self.
iters % 2 == 0:
95 @unittest.skip(
"Indexing is broken by #3725")
96 def test_embedded_constant_difference(self):
97 class MyModel(Module):
99 super(MyModel, self).__init__()
102 def forward(self, x):
103 r = x[self.
iters % 2]
110 def test_explicit_test_args(self):
111 class MyModel(Module):
112 def forward(self, x):
113 if x.data.sum() == 1.0:
123 if __name__ ==
'__main__':
def assertVerifyExpectFail(self, args, kwargs)