9 from model
import Model, get_custom_op_library_path
17 def test_custom_library_is_loaded(self):
20 def test_calling_custom_op_string(self):
21 output = ops.custom.op2(
"abc",
"def")
22 self.assertLess(output, 0)
23 output = ops.custom.op2(
"abc",
"abc")
24 self.assertEqual(output, 0)
26 def test_calling_custom_op(self):
27 output = ops.custom.op(torch.ones(5), 2.0, 3)
28 self.assertEqual(type(output), list)
29 self.assertEqual(len(output), 3)
31 self.assertTrue(tensor.allclose(torch.ones(5) * 2))
33 output = ops.custom.op_with_defaults(torch.ones(5))
34 self.assertEqual(type(output), list)
35 self.assertEqual(len(output), 1)
36 self.assertTrue(output[0].allclose(torch.ones(5)))
38 def test_calling_custom_op_inside_script_module(self):
40 output = model.forward(torch.ones(5))
41 self.assertTrue(output.allclose(torch.ones(5) + 1))
43 def test_saving_and_loading_script_module_with_custom_op(self):
48 file = tempfile.NamedTemporaryFile(delete=
False)
56 output = loaded.forward(torch.ones(5))
57 self.assertTrue(output.allclose(torch.ones(5) + 1))
60 if __name__ ==
"__main__":
def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP)