"""The testing package contains testing-specific utilities."""

import random

import torch

FileCheck = torch._C.FileCheck
__all__ = [
    'assert_allclose',
    'make_non_contiguous',
    'rand_like',
    'randn_like',
]

rand_like = torch.rand_like
randn_like = torch.randn_like
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True):
    if not isinstance(actual, torch.Tensor):
        actual = torch.tensor(actual)
    if not isinstance(expected, torch.Tensor):
        expected = torch.tensor(expected)
    if expected.shape != actual.shape:
        expected = expected.expand_as(actual)
    if rtol is None or atol is None:
        if rtol is not None or atol is not None:
            raise ValueError(
                "rtol and atol must both be specified or both be unspecified")
        rtol, atol = _get_default_tolerance(actual, expected)
    close = torch.isclose(actual, expected, rtol, atol, equal_nan)
    if close.all():
        return
    # Find the worst offender among the mismatched elements.
    error = (expected - actual).abs()
    expected_error = atol + rtol * expected.abs()
    delta = error - expected_error

    _, index = delta.reshape(-1).max(0)
    def _unravel_index(index, shape):
        # Convert a flat index back into a per-dimension index tuple.
        res = []
        for size in shape[::-1]:
            res.append(int(index % size))
            index = int(index // size)
        return tuple(res[::-1])
    index = _unravel_index(index.item(), actual.shape)

    # Count how many elements failed the tolerance check.
    count = (~close).long().sum()
    msg = ('Not within tolerance rtol={} atol={} at input{} ({} vs. {}) and {}'
           ' other locations ({:2.2f}%)')

    raise AssertionError(msg.format(
        rtol, atol, list(index), actual[index].item(), expected[index].item(),
        count - 1, 100 * count / actual.numel()))
def make_non_contiguous(tensor):
    if tensor.numel() <= 1:  # can't make non-contiguous
        return tensor.clone()
    osize = list(tensor.size())

    # Randomly inflate one dimension so there is room to narrow later.
    dim = random.randint(0, len(osize) - 1)
    add = random.randint(4, 15)
    osize[dim] = osize[dim] + add

    # narrow() alone does not produce a non-contiguous tensor when only the
    # 0-th dimension is narrowed (which always happens for 1-D tensors), so
    # allocate an extra right-most dimension and slice it off with select().
    input = tensor.new(torch.Size(osize + [random.randint(2, 3)]))
    input = input.select(len(input.size()) - 1, random.randint(0, 1))
    # Extract a block of the correct size from the oversized 'input'.
    for i in range(len(osize)):
        if input.size(i) != tensor.size(i):
            bounds = random.randint(1, input.size(i) - tensor.size(i))
            input = input.narrow(i, bounds, tensor.size(i))

    input.copy_(tensor)
    return input
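
# Illustrative usage sketch (not part of the original module; the helper name
# `_demo_non_contiguous` is hypothetical): the result compares equal to the
# input element-wise but has non-trivial strides, which is handy for
# exercising kernels on non-contiguous memory.
def _demo_non_contiguous():
    t = torch.arange(12.0).reshape(3, 4)
    nc = make_non_contiguous(t)
    assert not nc.is_contiguous()
    assert torch.equal(nc, t)  # same values, different memory layout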
def get_all_dtypes():
    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
            torch.float16, torch.float32, torch.float64]
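
# Illustrative sketch (not part of the original module): get_all_dtypes is
# typically used to parametrize a test over every supported dtype. The loop
# body and helper name `_demo_dtype_loop` are hypothetical.
def _demo_dtype_loop():
    for dtype in get_all_dtypes():
        t = torch.ones(4, 5, dtype=dtype)
        nc = make_non_contiguous(t)
        assert nc.dtype == dtype and not nc.is_contiguous()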
# Maps 'dtype' -> (rtol, atol).
_default_tolerances = {
    'float64': (1e-5, 1e-8),
    'float32': (1e-4, 1e-5),
    'float16': (1e-3, 1e-3),
}
def _get_default_tolerance(a, b=None):
    if b is None:
        dtype = str(a.dtype).split('.')[-1]  # e.g. 'float32'
        return _default_tolerances.get(dtype, (0, 0))
    a_tol = _get_default_tolerance(a)
    b_tol = _get_default_tolerance(b)
    return (max(a_tol[0], b_tol[0]), max(a_tol[1], b_tol[1]))
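
# Illustrative sketch (not part of the original module; the helper name
# `_demo_default_tolerance` is hypothetical): when comparing tensors of
# different dtypes, the looser of the two per-dtype tolerances wins, and
# unlisted (e.g. integer) dtypes fall back to exact comparison (0, 0).
def _demo_default_tolerance():
    half = torch.zeros(1, dtype=torch.float16)
    double = torch.zeros(1, dtype=torch.float64)
    assert _get_default_tolerance(half, double) == (1e-3, 1e-3)
    assert _get_default_tolerance(torch.zeros(1, dtype=torch.int32)) == (0, 0)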