4 from __future__
import absolute_import
5 from __future__
import division
6 from __future__
import print_function
7 from __future__
import unicode_literals
# Seed NumPy's global random number generator with 0 so that any
# np.random draws made afterwards are reproducible from run to run.
# NOTE(review): the definition enclosing this statement is not visible in
# this listing; confirm where (and how often) it runs.
19 np.random.seed(seed=0)
def assertSameOutputs(self, outputs1, outputs2, decimal=7):
    """Assert that two sequences of output arrays match element-wise.

    The sequences must have equal length. Each corresponding pair must
    have exactly the same dtype and values that agree to ``decimal``
    decimal places (via ``np.testing.assert_almost_equal``).

    Args:
        outputs1: first sequence of arrays.
        outputs2: second sequence of arrays, compared position by position.
        decimal: precision passed through to ``assert_almost_equal``.

    Raises:
        AssertionError: on a length, dtype, or value mismatch.
    """
    count1, count2 = len(outputs1), len(outputs2)
    self.assertEqual(count1, count2)
    for left, right in zip(outputs1, outputs2):
        # dtypes must match exactly; only the values are compared loosely.
        left_dtype = left.dtype
        right_dtype = right.dtype
        self.assertEqual(left_dtype, right_dtype)
        np.testing.assert_almost_equal(left, right, decimal=decimal)
def add_test_case(self, name, test_func):
    """Attach ``test_func`` to this object as the attribute ``name``.

    Args:
        name: attribute name to use; must begin with ``test_``.
        test_func: the callable to store under ``name``.

    Raises:
        ValueError: if ``name`` lacks the ``test_`` prefix, or if this
            object already has an attribute called ``name``.
    """
    problem = None
    if not name.startswith('test_'):
        problem = 'Test name must start with test_: {}'
    elif hasattr(self, name):
        # Refuse to overwrite anything already reachable under this name.
        problem = 'Duplicated test name: {}'
    if problem is not None:
        raise ValueError(problem.format(name))
    setattr(self, name, test_func)
# Download 'predict_net.pb', 'init_net.pb' and 'value_info.json' for `model`
# into the directory returned by self._model_dir(model).
# The directory must not exist beforehand; it is created here. If any
# download raises, the directory is handed to deleteDirectory and an
# AssertionError("Test model downloading failed") is raised.
# NOTE(review): this listing is non-contiguous (original lines 44-45 and
# 47-51 are absent), so the full try/except structure around the two
# downloadFromURLToFile calls below is not visible here.
37 def _download(self, model):
38 model_dir = self._model_dir(model)
# NOTE(review): `assert` is skipped under `python -O`; in that case a
# pre-existing directory would instead make os.makedirs below raise.
39 assert not os.path.exists(model_dir)
40 os.makedirs(model_dir)
41 for f
in [
'predict_net.pb',
'init_net.pb',
'value_info.json']:
# Resolve the remote URL for this file and its local destination path.
42 url = getURLFromName(model, f)
43 dest = os.path.join(model_dir, f)
# NOTE(review): the first call below is truncated in this view; the second
# passes only (url, dest) and is presumably a fallback for when the first
# form fails -- confirm against the full source.
46 downloadFromURLToFile(url, dest,
52 downloadFromURLToFile(url, dest)
# On any failure: report the reason, call deleteDirectory(model_dir)
# (presumably removing the directory and anything already downloaded into
# it), then fail with an AssertionError.
53 except Exception
as e:
54 print(
"Abort: {reason}".format(reason=e))
55 print(
"Cleaning up...")
56 deleteDirectory(model_dir)
57 raise AssertionError(
"Test model downloading failed")