Caffe2 - Python API
A deep learning, cross platform ML framework
test_custom_ops.py
1 import argparse
2 import os.path
3 import tempfile
4 import unittest
5 
6 import torch
7 from torch import ops
8 
9 from model import Model, get_custom_op_library_path
10 
11 
class TestCustomOperators(unittest.TestCase):
    """Exercise custom operators loaded from an external shared library.

    Each test relies on ``setUp`` having loaded the library whose path is
    returned by ``get_custom_op_library_path()``; the operators are then
    reached through the ``torch.ops.custom`` namespace.
    """

    def setUp(self):
        """Load the custom operator library before every test."""
        self.library_path = get_custom_op_library_path()
        ops.load_library(self.library_path)

    def test_custom_library_is_loaded(self):
        """The loaded library path is recorded in ``ops.loaded_libraries``."""
        self.assertIn(self.library_path, ops.loaded_libraries)

    def test_calling_custom_op_string(self):
        """``custom.op2`` is negative for ("abc", "def") and 0 for equal strings."""
        output = ops.custom.op2("abc", "def")
        self.assertLess(output, 0)
        output = ops.custom.op2("abc", "abc")
        self.assertEqual(output, 0)

    def test_calling_custom_op(self):
        """``custom.op`` and ``custom.op_with_defaults`` return lists of tensors."""
        output = ops.custom.op(torch.ones(5), 2.0, 3)
        # assertIsInstance rather than comparing type() for equality: it is
        # the idiomatic check and produces a clearer failure message.
        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 3)
        for tensor in output:
            self.assertTrue(tensor.allclose(torch.ones(5) * 2))

        # Same operator family, relying on the op's own default arguments.
        output = ops.custom.op_with_defaults(torch.ones(5))
        self.assertIsInstance(output, list)
        self.assertEqual(len(output), 1)
        self.assertTrue(output[0].allclose(torch.ones(5)))

    def test_calling_custom_op_inside_script_module(self):
        """``Model.forward`` on ones(5) yields ones(5) + 1."""
        model = Model()
        output = model.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))

    def test_saving_and_loading_script_module_with_custom_op(self):
        """A saved and re-loaded ``Model`` produces the same output."""
        model = Model()
        # Ideally we would like to not have to manually delete the file, but
        # NamedTemporaryFile opens the file, and it cannot be opened multiple
        # times in Windows. To support Windows, close the file after creation
        # and try to remove it manually.
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            temp_file.close()
            model.save(temp_file.name)
            loaded = torch.jit.load(temp_file.name)
        finally:
            # Always clean up, even when saving or loading raised.
            os.unlink(temp_file.name)

        output = loaded.forward(torch.ones(5))
        self.assertTrue(output.allclose(torch.ones(5) + 1))
58 
59 
if __name__ == "__main__":
    # Run the test suite through unittest's command-line runner only when
    # this file is executed directly, not when it is imported.
    unittest.main()
def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP)
Definition: __init__.py:82