Caffe2 - Python API
A deep learning, cross-platform ML framework
test_type_info.py
1 from common_utils import TestCase, run_tests, TEST_NUMPY, load_tests
2 
3 # load_tests from common_utils is used to automatically filter tests for
4 # sharding on sandcastle. This line silences flake warnings
5 load_tests = load_tests
6 
7 import torch
8 import unittest
9 
10 if TEST_NUMPY:
11  import numpy as np
12 
13 
15 
16  def test_invalid_input(self):
17  for dtype in [torch.float32, torch.float64]:
18  with self.assertRaises(TypeError):
19  _ = torch.iinfo(dtype)
20 
21  for dtype in [torch.int64, torch.int32, torch.int16, torch.uint8]:
22  with self.assertRaises(TypeError):
23  _ = torch.finfo(dtype)
24 
25  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
26  def test_iinfo(self):
27  for dtype in [torch.int64, torch.int32, torch.int16, torch.uint8]:
28  x = torch.zeros((2, 2), dtype=dtype)
29  xinfo = torch.iinfo(x.dtype)
30  xn = x.cpu().numpy()
31  xninfo = np.iinfo(xn.dtype)
32  self.assertEqual(xinfo.bits, xninfo.bits)
33  self.assertEqual(xinfo.max, xninfo.max)
34  self.assertEqual(xinfo.min, xninfo.min)
35 
36  @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
37  def test_finfo(self):
38  initial_default_type = torch.get_default_dtype()
39  for dtype in [torch.float32, torch.float64]:
40  x = torch.zeros((2, 2), dtype=dtype)
41  xinfo = torch.finfo(x.dtype)
42  xn = x.cpu().numpy()
43  xninfo = np.finfo(xn.dtype)
44  self.assertEqual(xinfo.bits, xninfo.bits)
45  self.assertEqual(xinfo.max, xninfo.max)
46  self.assertEqual(xinfo.min, xninfo.min)
47  self.assertEqual(xinfo.eps, xninfo.eps)
48  self.assertEqual(xinfo.tiny, xninfo.tiny)
50  self.assertEqual(torch.finfo(dtype), torch.finfo())
51  # Restore the default type to ensure that the test has no side effect
52  torch.set_default_dtype(initial_default_type)
53 
# Script entry point: when this file is executed directly, hand off to
# run_tests (imported from common_utils at the top of the file), which
# presumably wraps unittest.main for this module's test cases — verify.
if __name__ == '__main__':
    run_tests()
def assertEqual(self, x, y, prec=None, message='', allow_inf=False)
def set_default_dtype(d)
Definition: __init__.py:156