1 from common_utils
import TestCase, run_tests, TEST_NUMPY, load_tests
5 load_tests = load_tests
16 def test_invalid_input(self):
17 for dtype
in [torch.float32, torch.float64]:
18 with self.assertRaises(TypeError):
19 _ = torch.iinfo(dtype)
21 for dtype
in [torch.int64, torch.int32, torch.int16, torch.uint8]:
22 with self.assertRaises(TypeError):
23 _ = torch.finfo(dtype)
25 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
27 for dtype
in [torch.int64, torch.int32, torch.int16, torch.uint8]:
28 x = torch.zeros((2, 2), dtype=dtype)
29 xinfo = torch.iinfo(x.dtype)
31 xninfo = np.iinfo(xn.dtype)
36 @unittest.skipIf(
not TEST_NUMPY,
"Numpy not found")
38 initial_default_type = torch.get_default_dtype()
39 for dtype
in [torch.float32, torch.float64]:
40 x = torch.zeros((2, 2), dtype=dtype)
41 xinfo = torch.finfo(x.dtype)
43 xninfo = np.finfo(xn.dtype)
54 if __name__ ==
'__main__':
def assertEqual(self, x, y, prec=None, message='', allow_inf=False)