Caffe2 - Python API
A deep learning, cross-platform ML framework
common_methods_invocations.py
1 import torch
2 from torch._six import inf, nan, istuple
3 from functools import reduce, wraps
4 from operator import mul, itemgetter
5 from torch.autograd import Variable, Function, detect_anomaly
6 from torch.testing import make_non_contiguous
7 from common_utils import (skipIfNoLapack,
8  prod_single_zero, random_square_matrix_of_rank,
9  random_symmetric_matrix, random_symmetric_psd_matrix,
10  random_symmetric_pd_matrix, make_nonzero_det,
11  random_fullrank_matrix_distinct_singular_value, set_rng_seed)
12 
13 
14 def index_variable(shape, max_indices):
15  if not isinstance(shape, tuple):
16  shape = (shape,)
17  index = torch.rand(*shape).mul_(max_indices).floor_().long()
18  return index
19 
20 
21 def index_perm_variable(shape, max_indices):
22  if not isinstance(shape, tuple):
23  shape = (shape,)
24 
25  index = torch.randperm(max_indices).narrow(0, 0, reduce(mul, shape)).view(shape)
26  return index
27 
28 
29 def gather_variable(shape, index_dim, max_indices, duplicate=False):
30  assert len(shape) == 2
31  assert index_dim < 2
32  batch_dim = 1 - index_dim
33  index = torch.LongTensor(*shape)
34  for i in range(shape[index_dim]):
35  index.select(index_dim, i).copy_(
36  torch.randperm(max_indices)[:shape[batch_dim]])
37  if duplicate:
38  index.select(batch_dim, 0).copy_(index.select(batch_dim, 1))
39  return index
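
# Illustration (not part of the original module): the index tensor produced by
# gather_variable is always a valid argument for torch.gather along index_dim.
# A quick standalone check, assuming the helper above is in scope:
#
#     idx = gather_variable((4, 3), index_dim=0, max_indices=5)  # LongTensor (4, 3), values in [0, 5)
#     src = torch.randn(5, 3)
#     out = torch.gather(src, 0, idx)                            # shape (4, 3), no out-of-range indices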
40 
41 
42 def bernoulli_scalar():
43  return torch.tensor(0, dtype=torch.uint8).bernoulli_()
44 
45 
46 def mask_not_all_zeros(shape):
47  assert len(shape) > 0
48  while True:
49  result = torch.randn(shape).gt(0)
50  if result.sum() > 0:
51  return result
52 
53 
54 def uniform_scalar(offset=0, requires_grad=False):
55  v = torch.rand(()) + offset
56  v.requires_grad = requires_grad
57  return v
58 
59 
60 def normal_scalar_clamp(amin, amax, requires_grad=False):
61  v = torch.randn(()).clamp(amin, amax)
62  v.requires_grad = requires_grad
63  return v
64 
65 
66 def prod_zeros(dim_size, dim_select):
67  assert len(dim_select) == 2
68  result = torch.randn(dim_size, dim_size, dim_size)
69  result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
70  result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
71  result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
72  return result
73 
74 
75 class non_differentiable(object):
76  def __init__(self, tensor):
77  self.tensor = tensor
78 
79 
80 class dont_convert(tuple):
81  pass
82 
83 
84 class NoArgsClass(object):
85  def __iter__(self):
86  return self
87 
88  def __next__(self):
89  raise StopIteration()
90  next = __next__ # Python 2 compatibility
91 
92  def __len__(self):
93  return 0
94 
95 NO_ARGS = NoArgsClass()
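
# Illustration (not part of the original module): NO_ARGS is the sentinel used
# throughout the table below for "no positional arguments"; it behaves like an
# empty, iterable collection:
#
#     assert len(NO_ARGS) == 0
#     assert list(NO_ARGS) == []   # iterating yields nothing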
96 L = 20
97 M = 10
98 S = 5
99 
100 
101 # (
102 # method name,
103 # input size/constructing fn,
104 # args (tuple represents shape of a tensor arg),
105 # test variant name (will be used as the test name suffix), // optional
106 # indices for possible dim arg, // optional
107 # fn mapping output to part that should be gradcheck'ed, // optional
108 # )
# A usage sketch for these entries follows the list below.
109 def method_tests():
110  set_rng_seed(0)
111  return [
112  ('add', (S, S, S), ((S, S, S),)),
113  ('add', (S, S, S), ((S, S),), 'broadcast_rhs'),
114  ('add', (S, S), ((S, S, S),), 'broadcast_lhs'),
115  ('add', (S, 1, S), ((M, S),), 'broadcast_all'),
116  ('add', (), ((),), 'scalar'),
117  ('add', (S, S, S), ((),), 'scalar_broadcast_rhs'),
118  ('add', (), ((S, S, S),), 'scalar_broadcast_lhs'),
119  ('add', (S, S, S), (3.14,), 'constant'),
120  ('add', (), (3.14,), 'scalar_constant'),
121  ('__radd__', (S, S, S), (3.14,), 'constant'),
122  ('__radd__', (), (3.14,), 'scalar_constant'),
123  ('sub', (S, S, S), ((S, S, S),)),
124  ('sub', (S, S, S), ((S, S),), 'broadcast_rhs'),
125  ('sub', (S, S), ((S, S, S),), 'broadcast_lhs'),
126  ('sub', (S, 1, S), ((M, S),), 'broadcast_all'),
127  ('sub', (S, S, S), ((),), 'scalar_broadcast_rhs'),
128  ('sub', (), ((S, S, S),), 'scalar_broadcast_lhs'),
129  ('sub', (S, S, S), (3.14,), 'constant'),
130  ('sub', (), (3.14,), 'scalar_constant'),
131  ('__rsub__', (S, S, S), (3.14,), 'constant'),
132  ('__rsub__', (), (3.14,), 'scalar_constant'),
133  ('mul', (S, S, S), ((S, S, S),)),
134  ('mul', (), ((),), 'scalar'),
135  ('mul', (S, S, S), ((S, S),), 'broadcast_rhs'),
136  ('mul', (S, S), ((S, S, S),), 'broadcast_lhs'),
137  ('mul', (S, 1, S), ((M, S),), 'broadcast_all'),
138  ('mul', (S, S, S), ((),), 'scalar_broadcast_rhs'),
139  ('mul', (), ((S, S, S),), 'scalar_broadcast_lhs'),
140  ('mul', (S, S, S), (3.14,), 'constant'),
141  ('mul', (), (3.14,), 'scalar_constant'),
142  ('__rmul__', (S, S, S), (3.14,), 'constant'),
143  ('__rmul__', (), (3.14,), 'scalar_constant'),
144  ('div', (S, S, S), (torch.rand(S, S, S) + 0.1,)),
145  ('div', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs'),
146  ('div', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
147  ('div', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all'),
148  ('div', (), (uniform_scalar(0.1),), 'scalar'),
149  ('div', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
150  ('div', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs'),
151  ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'),
152  ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant'),
153  ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant'),
154  ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant'),
155  ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,)),
156  ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,), 'broadcast_rhs'),
157  ('pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs'),
158  ('pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,), 'broadcast_all'),
159  ('pow', uniform_scalar(1e-3, requires_grad=True), (uniform_scalar(0.1),), 'scalar'),
160  ('pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),), 'scalar_broadcast_rhs'),
161  ('pow', uniform_scalar(1e-3, requires_grad=True), (torch.rand(S, S, S) + 0.1,), 'scalar_broadcast_lhs'),
162  ('pow', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
163  ('__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,), 'constant'),
164  ('pow', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant'),
165  ('__rpow__', uniform_scalar(1e-3, requires_grad=True), (3.14,), 'scalar_constant'),
166  ('transpose', (1, 2, 3), (1, 2), 'dim', [0, 1]),
167  ('transpose', (), (0, 0), 'scalar'),
168  ('transpose', (1,), (0, 0), '1d'),
169  ('transpose', torch.rand(L, L), (0, 1), '2d'),
170  ('transpose', torch.rand(S, S, S), (2, 0), '3d'),
171  ('t', (1, 2), NO_ARGS),
172  ('view', (S, S, S), (S * S, S),),
173  ('view', (S, S, S), (torch.Size([S * S, S]),), 'size'),
174  ('view', (S,), (S,), '1d'),
175  ('view', (), (dont_convert(()),), 'scalar_to_scalar'),
176  ('view', (), (1,), 'scalar_to_1d'),
177  ('reshape', (S, S, S), (S * S, S),),
178  ('reshape', (S, S, S), (torch.Size([S * S, S]),), 'size'),
179  ('reshape', (S,), (S,), '1d'),
180  ('reshape', (), (dont_convert(()),), 'scalar_to_scalar'),
181  ('reshape', (), (1,), 'scalar_to_1d'),
182  ('reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
183  ('reshape_as', (), (non_differentiable(torch.tensor(42.)),), 'scalar'),
184  ('reshape_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
185  ('flip', (S, S, S), ([0],), 'd0'),
186  ('flip', (S, S, S), ([0, 1, 2],), 'd012'),
187  ('flip', (S, S, S), ([0, 2],), 'd02'),
188  ('flip', (S, S, S), ([2, 0],), 'd20'),
189  ('flip', (S, S, S), ([-1],), 'neg_d'),
190  ('roll', (S, S, S), (0, 0), 'd0'),
191  ('roll', (S, S, S), (1, 2), 'd12'),
192  ('roll', (S, S, S), (0, 2,), 'd02'),
193  ('roll', (S, S, S), (2, 0,), 'd20'),
194  ('roll', (S, S, S), (-1, 0), 'neg_shift'),
195  ('roll', (S, S, S), (10000, 1), 'loop_shift'),
196  ('roll', (S, S, S), (2,), 'flattened'),
197  ('roll', (S, S, S), ([1, 2, -1], [0, 1, 2]), 'three_dims'),
198  ('rot90', (S, S, S), (1, [0, 1],), 'k1_d01'),
199  ('rot90', (S, S, S), (1, [1, 2],), 'k1_d12'),
200  ('rot90', (S, S, S), (1, [1, -1],), 'k1_neg_d'),
201  ('rot90', (S, S, S), (), 'default'),
202  ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
203  ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
204  ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
205  ('expand', (S, 1, 1), (S, S, S)),
206  ('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'),
207  ('expand', (S, 1), (S, S, S), 'new_dim'),
208  ('expand', (1,), (S, S, S), '1_element'),
209  ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'),
210  ('expand', (), (dont_convert(()),), 'scalar_to_scalar'),
211  ('expand', (), (1, 3, 2), 'scalar_to_dims'),
212  ('expand_as', (S, 1, 1), (torch.rand(S, S, S),)),
213  ('exp', (S, S, S), NO_ARGS),
214  ('exp', (), NO_ARGS, 'scalar'),
215  ('expm1', (S, S, S), NO_ARGS),
216  ('expm1', (), NO_ARGS, 'scalar'),
217  ('erf', torch.rand(S, S, S), NO_ARGS),
218  ('erf', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
219  ('erfc', torch.rand(S, S, S), NO_ARGS),
220  ('erfc', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
221  ('erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
222  ('erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=True), NO_ARGS, 'scalar'),
223  ('log', torch.rand(S, S, S) + 1e-2, NO_ARGS),
224  ('log', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
225  ('log10', torch.rand(S, S, S) + 1e-2, NO_ARGS),
226  ('log10', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
227  ('log1p', torch.rand(S, S, S), NO_ARGS),
228  ('log1p', uniform_scalar(requires_grad=True), NO_ARGS, 'scalar'),
229  ('log2', torch.rand(S, S, S) + 1e-2, NO_ARGS),
230  ('log2', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
231  ('tanh', (S, S, S), NO_ARGS),
232  ('tanh', (), NO_ARGS, 'scalar'),
233  ('sigmoid', (S, S, S), NO_ARGS),
234  ('sigmoid', (), NO_ARGS, 'scalar'),
235  ('sinh', (S, S, S), NO_ARGS),
236  ('sinh', (), NO_ARGS, 'scalar'),
237  ('cosh', (S, S, S), NO_ARGS),
238  ('cosh', (), NO_ARGS, 'scalar'),
239  ('abs', (S, S, S), NO_ARGS),
240  ('abs', (), NO_ARGS, 'scalar'),
241  ('clamp', (S, S, S), (0, 1)),
242  ('clamp', (S, S, S), (None, 0.5), 'min'),
243  ('clamp', (S, S, S), (0.5, None), 'max'),
244  ('clamp', (), (0, 1), 'scalar'),
245  ('clamp', (), (None, 0.5), 'min_scalar'),
246  ('clamp', (), (0.5, None), 'max_scalar'),
247  ('sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS),
248  ('sqrt', uniform_scalar(5e-4, requires_grad=True), NO_ARGS, 'scalar'),
249  ('sin', (S, S, S), NO_ARGS),
250  ('sin', (), NO_ARGS, 'scalar'),
251  ('cos', (S, S, S), NO_ARGS),
252  ('cos', (), NO_ARGS, 'scalar'),
253  ('tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS),
254  ('asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
255  ('acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
256  ('atan', (S, S, S), NO_ARGS),
257  ('atan', (), NO_ARGS, 'scalar'),
258  ('atan2', (S, S, S), ((S, S, S),)),
259  ('atan2', (), ((),), 'scalar'),
260  ('atan2', (S, S, S), ((S,),), 'broadcast_rhs'),
261  ('atan2', (S,), ((S, S, S),), 'broadcast_lhs'),
262  ('atan2', (S, 1, S), ((S, S),), 'broadcast_all'),
263  ('reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS),
264  ('reciprocal', uniform_scalar(0.1, requires_grad=True), NO_ARGS, 'scalar'),
265  ('round', (S, S, S), NO_ARGS),
266  ('round', (), NO_ARGS, 'scalar'),
267  ('sign', (S, S, S), NO_ARGS),
268  ('sign', (), NO_ARGS, 'scalar'),
269  ('trunc', (S, S, S), NO_ARGS),
270  ('trunc', (), NO_ARGS, 'scalar'),
271  ('floor', (S, S, S), NO_ARGS),
272  ('floor', (), NO_ARGS, 'scalar'),
273  ('ceil', (S, S, S), NO_ARGS),
274  ('ceil', (), NO_ARGS, 'scalar'),
275  ('rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS),
276  ('rsqrt', uniform_scalar(1e-2, requires_grad=True), NO_ARGS, 'scalar'),
277  ('frac', (S, S, S), NO_ARGS),
278  ('frac', (), NO_ARGS, 'scalar'),
279  ('fmod', (S, S, S), (1.5,)),
280  ('fmod', (), (1.5,), 'scalar'),
281  ('fmod', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),
282  ('fmod', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'),
283  ('fmod', (S, S, S), (non_differentiable(torch.rand(S) + 1.5),), 'tensor_broadcast_rhs'),
284  ('fmod', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'),
285  ('fmod', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
286  ('fmod', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
287  ('fmod', (S, S, S), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor_broadcast_rhs'),
288  ('remainder', (S, S, S), (1.5,)),
289  ('remainder', (), (1.5,), 'scalar'),
290  ('remainder', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor'),
291  ('remainder', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'tensor_broadcast_lhs'),
292  ('remainder', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),), 'tensor_broadcast_all'),
293  ('remainder', (), (non_differentiable(uniform_scalar(1.5)),), 'scalar_tensor'),
294  ('remainder', (), (non_differentiable(torch.rand(S, S, S) + 1.5),), 'scalar_tensor_broadcast_lhs'),
295  ('lerp', (S, S, S), ((S, S, S), 0.4), 'scalar_no_broadcast'),
296  ('lerp', (S, S, S), ((S,), 0.4), 'broadcast_rhs'),
297  ('lerp', (S,), ((S, S, S), 0.4), 'broadcast_lhs'),
298  ('lerp', (S, 1, S), ((S, S), 0.4), 'broadcast_all'),
299  ('lerp', (), ((), 0.4), 'scalar'),
300  ('lerp', (S, S, S), ((), 0.4), 'scalar_broadcast_rhs'),
301  ('lerp', (), ((S, S, S), 0.4), 'scalar_broadcast_lhs'),
302  ('max', (S, S, S), NO_ARGS),
303  ('max', (S, S, S), (1,), 'dim', [0]),
304  ('max', (S, S, S), (1, True,), 'keepdim_dim', [0]),
305  ('max', (), NO_ARGS, 'scalar'),
306  ('max', (), (0,), 'scalar_dim', [0]),
307  ('max', (), (0, True,), 'scalar_keepdim_dim', [0]),
308  ('max', (S, S, S), ((S, S, S),), 'elementwise'),
309  ('max', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'),
310  ('max', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'),
311  ('max', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'),
312  ('max', (), ((),), 'scalar_elementwise'),
313  ('max', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs'),
314  ('max', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs'),
315  ('min', (S, S, S), NO_ARGS),
316  ('min', (S, S, S), (1,), 'dim', [0]),
317  ('min', (S, S, S), (1, True,), 'keepdim_dim', [0]),
318  ('min', (), NO_ARGS, 'scalar'),
319  ('min', (), (0,), 'scalar_dim', [0]),
320  ('min', (), (0, True,), 'scalar_keepdim_dim', [0]),
321  ('min', (S, S, S), ((S, S, S),), 'elementwise'),
322  ('min', (S, S, S), ((S,),), 'elementwise_broadcast_rhs'),
323  ('min', (S,), ((S, S, S),), 'elementwise_broadcast_lhs'),
324  ('min', (S, 1, S), ((S, S),), 'elementwise_broadcast_all'),
325  ('min', (), ((),), 'scalar_elementwise'),
326  ('min', (S, S, S), ((),), 'scalar_elementwise_broadcast_rhs'),
327  ('min', (), ((S, S, S),), 'scalar_elementwise_broadcast_lhs'),
328  ('mean', (S, S, S), NO_ARGS),
329  ('mean', (S, S, S), (1,), 'dim', [0]),
330  ('mean', (S, S, S), (1, True,), 'keepdim_dim', [0]),
331  ('mean', (), NO_ARGS, 'scalar'),
332  ('mean', (), (0,), 'scalar_dim', [0]),
333  ('mean', (), (0, True,), 'scalar_keepdim_dim', [0]),
334  ('kthvalue', (S, S, S), (2,)),
335  ('kthvalue', (), (1,), 'scalar'),
336  ('kthvalue', (S, S, S), (2, 1,), 'dim', [1]),
337  ('kthvalue', (), (1, 0,), 'scalar_dim', [1]),
338  ('kthvalue', (S, S, S), (2, 1, True,), 'keepdim_dim', [1]),
339  ('kthvalue', (), (1, 0, True), 'scalar_keepdim_dim', [1]),
340  ('kthvalue', (S,), (2, 0,), 'dim_1d', [1]),
341  ('kthvalue', (S,), (2, 0, True,), 'keepdim_dim_1d', [1]),
342  ('median', (S, S, S), NO_ARGS),
343  ('median', (S, S, S), (1,), 'dim', [0]),
344  ('median', (S, S, S), (1, True,), 'keepdim_dim', [0]),
345  ('median', (), NO_ARGS, 'scalar'),
346  ('median', (), (0,), 'scalar_dim', [0]),
347  ('median', (), (0, True,), 'scalar_keepdim_dim', [0]),
348  ('mode', (S, S, S), NO_ARGS),
349  ('mode', (S, S, S), (1,), 'dim', [0]),
350  ('mode', (S, S, S), (1, True,), 'keepdim_dim', [0]),
351  ('mode', (), NO_ARGS, 'scalar'),
352  ('mode', (), (0,), 'scalar_dim', [0]),
353  ('mode', (), (0, True,), 'scalar_keepdim_dim', [0]),
354  ('sum', (S, S, S), NO_ARGS),
355  ('sum', (S, S, S), (1,), 'dim', [0]),
356  ('sum', (S, S, S), (1, True,), 'keepdim_dim', [0]),
357  ('sum', (), NO_ARGS, 'scalar'),
358  ('sum', (), (0,), 'scalar_dim', [0]),
359  ('sum', (), (0, True,), 'scalar_keepdim_dim', [0]),
360  ('sum', (S, S, S), ([1, 2],), 'multi_dim'),
361  ('sum', (S, S, S), ([1, 2], True,), 'multi_dim_keepdim'),
362  ('prod', (S, S, S), NO_ARGS),
363  ('prod', (S, S, S), (1,), 'dim', [0]),
364  ('prod', (S, S, S), (1, True,), 'keepdim_dim', [0]),
365  ('prod', (), NO_ARGS, 'scalar'),
366  ('prod', (), (0,), 'scalar_dim', [0]),
367  ('prod', (), (0, True,), 'scalar_keepdim_dim', [0]),
368  ('prod', prod_zeros(S, [0, 1]), NO_ARGS, 'zerodims2'),
369  ('prod', prod_zeros(S, [0, 2]), NO_ARGS, 'zerodims1'),
370  ('prod', prod_zeros(S, [1, 2]), NO_ARGS, 'zerodims0'),
371  ('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', [0]),
372  ('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', [0]),
373  ('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', [0]),
374  ('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', [0]),
375  ('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', [0]),
376  ('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', [0]),
377  ('prod', prod_single_zero(S), NO_ARGS, 'single_zero'),
378  ('prod', (torch.tensor(0., requires_grad=True)), NO_ARGS, 'scalar_zero'),
379  ('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', [0]),
380  ('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', [0]),
381  ('var', (S, S, S), NO_ARGS),
382  ('var', (S, S, S), (1,), 'dim', [0]),
383  ('var', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
384  ('var', (S,), (0,), 'dim_1d', [0]),
385  ('var', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
386  ('std', (S, S, S), NO_ARGS),
387  ('std', (S, S, S), (1,), 'dim', [0]),
388  ('std', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
389  ('std', (S,), (0,), 'dim_1d', [0]),
390  ('std', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
391  ('renorm', (S, S, S), (2, 1, 0.5), 'dim', [1]),
392  ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
393  ('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'),
394  ('repeat', (S,), (2,), 'single_number'),
395  ('repeat', (), (2, 3), 'scalar'),
396  ('repeat', (2, 2), (3, 2)),
397  ('repeat', (2, 2), (1, 3, 1, 2), 'unsqueeze'),
398  ('cumsum', (S, S, S), (0,), 'dim0', [0]),
399  ('cumsum', (S, S, S), (1,), 'dim1', [0]),
400  ('cumsum', (S, S, S), (1,), 'dim1_cast', [0], (), lambda x: x, {'dtype': torch.float64}),
401  ('cumsum', (), (0,), 'dim0_scalar', [0]),
402  ('cumprod', (S, S, S), (0,)),
403  ('cumprod', (S, S, S), (1,), 'dim1', [0]),
404  ('cumprod', (), (0,), 'scalar'),
405  ('cumprod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_zeros'),
406  ('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', [0]),
407  ('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', [0]),
408  ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', [0]),
409  ('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', [0], (), lambda x: x, {'dtype': torch.float64}),
410  ('unfold', (), (0, 1, 1), 'scalar', [0]),
411  ('unfold', (S, S, S, S), (1, 3, 1), '', [0]),
412  ('unfold', (S, S, S), (2, 3, 2), 'lastdim', [0]),
413  ('addmm', (S, M), ((S, S), (S, M)),),
414  ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs'),
415  ('addmm', (S, M), ((S, S), (S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
416  ('addmm', (1,), ((S, S), (S, M)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
417  ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs'),
418  ('addmm', (), ((S, S), (S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
419  ('addbmm', (S, M), ((S, S, S), (S, S, M)),),
420  ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'),
421  ('addbmm', (S, M), ((S, S, S), (S, S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
422  ('addbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef',
423  (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
424  ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
425  ('addbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x,
426  {'beta': 0.2, 'alpha': 0.6}),
427  ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)),),
428  ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs'),
429  ('baddbmm', (S, S, M), ((S, S, S), (S, S, M)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
430  ('baddbmm', (1,), ((S, S, S), (S, S, M)), 'broadcast_lhs_coef',
431  (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
432  ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs'),
433  ('baddbmm', (), ((S, S, S), (S, S, M)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x,
434  {'beta': 0.2, 'alpha': 0.6}),
435  ('addmv', (S,), ((S, M), (M,)),),
436  ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs'),
437  ('addmv', (S,), ((S, M), (M,)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
438  ('addmv', (1,), ((S, M), (M,)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
439  ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs'),
440  ('addmv', (), ((S, M), (M,)), 'scalar_broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
441  ('addr', (S, M), ((S,), (M,)),),
442  ('addr', (), ((S,), (M,)), 'broadcast_lhs'),
443  ('addr', (S, M), ((S,), (M,)), 'coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
444  ('addr', (), ((S,), (M,)), 'broadcast_lhs_coef', (), (), lambda x: x, {'beta': 0.2, 'alpha': 0.6}),
445  ('dot', (L,), ((L,),),),
446  ('mm', (S, M), ((M, S),)),
447  ('bmm', (M, S, M), ((M, M, S),)),
448  ('mv', (S, M), ((M,),)),
449  ('ger', (S,), ((M,),)),
450  ('matmul', (L,), ((L,),),),
451  ('matmul', (S, M), ((M,),), "2d_1d"),
452  ('matmul', (M, ), ((M, S),), "1d_2d"),
453  ('matmul', (S, M), ((M, S),), "2d_2d"),
454  ('matmul', (S, S, M, M), ((S, S, M, S),), "4d_4d"),
455  ('matmul', (S, S, M, M), ((M,),), "4d_1d"),
456  ('matmul', (M,), ((S, S, M, S),), "1d_4d"),
457  ('matrix_power', (S, S), [2], "n=2"),
458  ('matrix_power', (S, S, S), [3], "n=3"),
459  ('matrix_power', (S, S, S), [1], "n=1"),
460  ('matrix_power', (S, S, S), [0], "n=0"),
461  ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-1], "n=-1",
462  NO_ARGS, [skipIfNoLapack]),
463  ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S), [-3], "n=-3",
464  NO_ARGS, [skipIfNoLapack]),
465  ('matrix_power', lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2], "n=-2",
466  NO_ARGS, [skipIfNoLapack]),
467  ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
468  ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
469  ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),
470  ('mvlgamma', torch.empty(S, S).uniform_(2.5, 5), [5], "p=5"),
471  ('addcmul', (S, S), ((S, S), (S, S))),
472  ('addcmul', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
473  ('addcmul', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
474  ('addcmul', (S, S), ((S, S), (S, S)), 'scale', (), (), lambda x: x, {'value': 0.5}),
475  ('addcmul', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
476  ('addcmul', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), lambda x: x, {'value': 0.5}),
477  ('addcmul', (), ((), ()), 'scalar'),
478  ('addcmul', (S, S), ((), ()), 'scalar_broadcast_rhs'),
479  ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs'),
480  ('addcmul', (), ((), ()), 'scalar_scale', (), (), lambda x: x, {'value': 0.5}),
481  ('addcmul', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
482  ('addcmul', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), lambda x: x, {'value': 0.5}),
483  ('addcdiv', (S, S), ((S, S), (S, S))),
484  ('addcdiv', (S, S), ((S, 1), (1, S)), 'broadcast_rhs'),
485  ('addcdiv', (1,), ((S, S, 1), (1, S)), 'broadcast_all'),
486  ('addcdiv', (S, S), ((S, S), (S, S)), 'scale', (), (), lambda x: x, {'value': 0.5}),
487  ('addcdiv', (S, S), ((S, 1), (1, S)), 'scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
488  ('addcdiv', (1,), ((S, S, 1), (1, S)), 'scale_broadcast_all', (), (), lambda x: x, {'value': 0.5}),
489  ('addcdiv', (), ((), ()), 'scalar'),
490  ('addcdiv', (S, S), ((), ()), 'scalar_broadcast_rhs'),
491  ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_broadcast_lhs'),
492  ('addcdiv', (), ((), ()), 'scalar_scale', (), (), lambda x: x, {'value': 0.5}),
493  ('addcdiv', (S, S), ((), ()), 'scalar_scale_broadcast_rhs', (), (), lambda x: x, {'value': 0.5}),
494  ('addcdiv', (), ((S, S, 1), (1, S)), 'scalar_scale_broadcast_lhs', (), (), lambda x: x, {'value': 0.5}),
495  ('zero_', (S, S, S), NO_ARGS),
496  ('zero_', (), NO_ARGS, 'scalar'),
497  ('logsumexp', (S, S), (1,)),
498  ('logsumexp', (), (0,), 'scalar'),
499  ('norm', (S, S), (), 'default'),
500  ('norm', (S, S), (2,), '2'),
501  ('norm', (S, S), (0,), '0'),
502  ('norm', (S, S), (0.5,), '0_5'),
503  ('norm', (S, S), (1,), '1'),
504  ('norm', (S, S), (3,), '3'),
505  ('norm', (S, S), (inf,), 'inf'),
506  ('norm', (S, S), (-inf,), '-inf'),
507  ('norm', (S, S), ('fro',), 'fro_default'),
508  ('norm', (S, S), ('fro', [0, 1],), 'fro'),
509  ('norm', (S, S), ('nuc',), 'nuc', NO_ARGS, [skipIfNoLapack]),
510  ('norm', (S, S), (-1,), 'neg_1'),
511  ('norm', (S, S), (-2,), 'neg_2'),
512  ('norm', (S, S), (-0.5,), 'neg_0_5'),
513  ('norm', (S, S), (-1.5,), 'neg_1_5'),
514  ('norm', (S, S), (-2, 1,), 'neg_2_2_dim', [1]),
515  ('norm', (S, S), (-1, 1,), 'neg_1_2_dim', [1]),
516  ('norm', (S, S), (0, 1,), '0_2_dim', [1]),
517  ('norm', (S, S), (1, 1,), '1_2_dim', [1]),
518  ('norm', (S, S), (2, 1,), '2_2_dim', [1]),
519  ('norm', (S, S), (3, 1,), '3_2_dim', [1]),
520  ('norm', (S, S), (inf, 1,), 'inf_2_dim'),
521  ('norm', torch.rand(S, S, S) + 5e-2, (1.5,), '1_5_default'),
522  ('norm', (S, S, S), (2, 1), '2_dim', [1]),
523  ('norm', (S, S, S), (3, 1), '3_dim', [1]),
524  ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', [1]),
525  ('norm', (S, S, S), (2, 1, True), 'keepdim_2_dim', [1]),
526  ('norm', (S, S, S), (3, 1, True), 'keepdim_3_dim', [1]),
527  ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1, True), 'keepdim_1_5_dim', [1]),
528  ('norm', (), (2, 0), '2_dim_scalar', [1]),
529  ('norm', (), (3, 0), '3_dim_scalar', [1]),
530  ('norm', (), (2, 0, True), 'keepdim_2_dim_scalar', [1]),
531  ('norm', (), (3, 0, True), 'keepdim_3_dim_scalar', [1]),
532  ('clone', (S, M, S), NO_ARGS),
533  ('clone', (), NO_ARGS, 'scalar'),
534  ('dist', (S, S, S), ((S, S, S),)),
535  ('dist', (S, S, S), ((S,),), 'broadcast_rhs'),
536  ('dist', (S,), ((S, S, S),), 'broadcast_lhs'),
537  ('dist', (S, 1, S), ((S, S),), 'broadcast_all'),
538  ('dist', (), ((),), 'scalar'),
539  ('dist', (S, S, S), ((),), 'scalar_broadcast_rhs'),
540  ('dist', (), ((S, S, S),), 'scalar_broadcast_lhs'),
541  ('dist', (S, S, S), ((S, S, S), 4), '4'),
542  ('dist', (S, S, S), ((S,), 4), '4_broadcast_rhs'),
543  ('dist', (S,), ((S, S, S), 4), '4_broadcast_lhs'),
544  ('dist', (S, 1, S), ((S, S), 4), '4_broadcast_all'),
545  ('dist', (), ((), 4), 'scalar_4'),
546  ('dist', (S, S, S), ((), 4), 'scalar_4_broadcast_rhs'),
547  ('dist', (), ((S, S, S), 4), 'scalar_4_broadcast_lhs'),
548  ('diag', (M, M), NO_ARGS, '2d'),
549  ('diag', (3, 5), NO_ARGS, '2d_wide'),
550  ('diag', (3, 5), (2,), '2d_wide_pos'),
551  ('diag', (3, 5), (-2,), '2d_wide_neg'),
552  ('diag', (5, 3), NO_ARGS, '2d_tall'),
553  ('diag', (5, 3), (2,), '2d_tall_pos'),
554  ('diag', (5, 3), (-2,), '2d_tall_neg'),
555  ('diag', (M,), NO_ARGS, '1d'),
556  ('diag', (M, M), (1,), '2d_1'),
557  ('diag', (M, M), (2,), '2d_2'),
558  ('diag_embed', (S, S), NO_ARGS),
559  ('diagonal', (M, M), NO_ARGS, '2d'),
560  ('diagonal', (3, 5), NO_ARGS, '2d_wide'),
561  ('diagonal', (3, 5), (2,), '2d_wide_pos'),
562  ('diagonal', (3, 5), (-2,), '2d_wide_neg'),
563  ('diagonal', (5, 3), NO_ARGS, '2d_tall'),
564  ('diagonal', (5, 3), (2,), '2d_tall_pos'),
565  ('diagonal', (5, 3), (-2,), '2d_tall_neg'),
566  ('diagonal', (M, M), (1,), '2d_1'),
567  ('diagonal', (M, M), (2,), '2d_2'),
568  ('diagonal', (M, M, M), (1, 1, 2), '3d_1'),
569  ('diagonal', (M, M, M), (2, 0, 1), '3d_2'),
570  ('diagonal', (M, M, M), (-2, 0, 1), '3d_3'),
571  ('tril', (M, M), NO_ARGS),
572  ('tril', (M, M), (2,), 'idx'),
573  ('tril', (S, M, M), NO_ARGS, 'batched'),
574  ('tril', (S, M, M), (2,), 'batched_idx'),
575  ('tril', (3, 3, S, S), NO_ARGS, 'more_batched'),
576  ('triu', (M, M), NO_ARGS),
577  ('triu', (M, M), (2,), 'idx'),
578  ('triu', (S, M, M), NO_ARGS, 'batched'),
579  ('triu', (S, M, M), (2,), 'batched_idx'),
580  ('triu', (3, 3, S, S), NO_ARGS, 'more_batched'),
581  ('trace', (M, M), NO_ARGS),
582  ('cross', (S, 3), ((S, 3),)),
583  ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'),
584  ('index_select', (S, S, S), (0, index_variable(2, S)), 'dim', [0]),
585  ('index_select', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_mixed_dim', [0]),
586  ('index_select', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_dim', [0]),
587  ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', [0]),
588  ('index_add', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', [0]),
589  ('index_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', [0]),
590  ('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]),
591  ('index_copy', (), (0, torch.tensor([0], dtype=torch.int64), (1,)), 'scalar_input_dim', [0]),
592  ('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim', [0]),
593  ('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]),
594  ('index_fill', (S, S), (0, index_variable(2, S), ()), 'variable_dim', [0]),
595  ('index_fill', (S, S), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_index_dim', [0]),
596  ('index_fill', (), (0, torch.tensor([0], dtype=torch.int64), 2), 'scalar_input_dim', [0]),
597  ('index_fill', (), (0, torch.tensor(0, dtype=torch.int64), 2), 'scalar_both_dim', [0]),
598  ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
599  ('inverse', lambda: random_fullrank_matrix_distinct_singular_value(S, 2, 3),
600  NO_ARGS, 'batched', NO_ARGS, [skipIfNoLapack]),
601  ('det', (S, S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
602  ('det', (1, 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]),
603  ('det', lambda: random_symmetric_matrix(S), NO_ARGS, 'symmetric', NO_ARGS, [skipIfNoLapack]),
604  ('det', lambda: random_symmetric_psd_matrix(S), NO_ARGS, 'symmetric_psd', NO_ARGS, [skipIfNoLapack]),
605  ('det', lambda: random_symmetric_pd_matrix(S), NO_ARGS, 'symmetric_pd', NO_ARGS, [skipIfNoLapack]),
606  ('det', lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS, 'dim2_null', NO_ARGS, [skipIfNoLapack]),
607  ('det', lambda: random_square_matrix_of_rank(S, 1), NO_ARGS, 'rank1', NO_ARGS, [skipIfNoLapack]),
608  ('det', lambda: random_square_matrix_of_rank(S, 2), NO_ARGS, 'rank2', NO_ARGS, [skipIfNoLapack]),
609  ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
610  'distinct_singular_values', NO_ARGS, [skipIfNoLapack]),
611  # For `logdet` and `slogdet`, the function at det=0 is not smooth.
612  # We need to exclude tests with det=0 (e.g. dim2_null, rank1, rank2) and use
613  # `make_nonzero_det` to make the random matrices have nonzero det. For
614  # `logdet`, we also set `make_nonzero_det(matrix, sign=1)` to make the
615  # matrix have positive det.
616  ('logdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
617  ('logdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS, '1x1', NO_ARGS, [skipIfNoLapack]),
618  ('logdet', lambda: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS,
619  'symmetric', NO_ARGS, [skipIfNoLapack]),
620  ('logdet', lambda: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS,
621  'symmetric_pd', NO_ARGS, [skipIfNoLapack]),
622  ('logdet', lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS,
623  'distinct_singular_values', NO_ARGS, [skipIfNoLapack]),
624  ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS,
625  '1x1_pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
626  ('slogdet', lambda: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS,
627  '1x1_neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
628  ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS,
629  'pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
630  ('slogdet', lambda: make_nonzero_det(torch.randn(S, S), -1), NO_ARGS,
631  'neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
632  ('slogdet', lambda: make_nonzero_det(random_symmetric_matrix(S)), NO_ARGS,
633  'symmetric', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
634  ('slogdet', lambda: random_symmetric_pd_matrix(S), NO_ARGS,
635  'symmetric_pd', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
636  ('slogdet', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
637  'distinct_singular_values', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
638  ('symeig', lambda: random_symmetric_matrix(S), (True, False), 'lower', NO_ARGS, [skipIfNoLapack]),
639  ('symeig', lambda: random_symmetric_matrix(S), (True, True), 'upper', NO_ARGS, [skipIfNoLapack]),
640  ('symeig', lambda: random_symmetric_matrix(M), (True, True), 'large', NO_ARGS, [skipIfNoLapack]),
641  ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS, '', NO_ARGS, [skipIfNoLapack]),
642  ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS,
643  'wide', NO_ARGS, [skipIfNoLapack]),
644  ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS,
645  'tall', NO_ARGS, [skipIfNoLapack]),
646  ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (False,),
647  'wide_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])),
648  ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (False,),
649  'tall_all', NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
650  ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
651  'large', NO_ARGS, [skipIfNoLapack]),
652  ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value(
653  S, silent=True),), '', NO_ARGS, [skipIfNoLapack]),
654  ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),),
655  'batched', NO_ARGS, [skipIfNoLapack]),
656  ('solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=True),),
657  'batched_dims', NO_ARGS, [skipIfNoLapack]),
658  ('solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=True),),
659  'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
660  ('solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=True),),
661  'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
662  ('fill_', (S, S, S), (1,), 'number'),
663  ('fill_', (), (1,), 'number_scalar'),
664  ('fill_', (S, S, S), ((),), 'variable'),
665  ('eq_', (S, S, S), ((S, S, S),)),
666  ('eq_', (S, S, S), ((1,),), 'broadcast_rhs'),
667  ('eq_', (), ((),), 'scalar'),
668  ('eq_', (S, S, S), ((),), 'scalar_broadcast_rhs'),
669  ('ne_', (S, S, S), ((S, S, S),)),
670  ('ne_', (S, S, S), ((1,),), 'broadcast_rhs'),
671  ('ne_', (), ((),), 'scalar'),
672  ('ne_', (S, S, S), ((),), 'scalar_broadcast_rhs'),
673  ('gt_', (S, S, S), ((S, S, S),)),
674  ('gt_', (S, S, S), ((1,),), 'broadcast_rhs'),
675  ('gt_', (), ((),), 'scalar'),
676  ('gt_', (S, S, S), ((),), 'scalar_broadcast_rhs'),
677  ('ge_', (S, S, S), ((S, S, S),)),
678  ('ge_', (S, S, S), ((1,),), 'broadcast_rhs'),
679  ('ge_', (), ((),), 'scalar'),
680  ('ge_', (S, S, S), ((),), 'scalar_broadcast_rhs'),
681  ('lt_', (S, S, S), ((S, S, S),)),
682  ('lt_', (S, S, S), ((1,),), 'broadcast_rhs'),
683  ('lt_', (), ((),), 'scalar'),
684  ('lt_', (S, S, S), ((),), 'scalar_broadcast_rhs'),
685  ('le_', (S, S, S), ((S, S, S),)),
686  ('le_', (S, S, S), ((1,),), 'broadcast_rhs'),
687  ('le_', (), ((),), 'scalar'),
688  ('le_', (S, S, S), ((),), 'scalar_broadcast_rhs'),
689  ('eq_', (S, S, S), (0,), 'pyscalar'),
690  ('ne_', (S, S, S), (0,), 'pyscalar'),
691  ('gt_', (S, S, S), (0,), 'pyscalar'),
692  ('ge_', (S, S, S), (0,), 'pyscalar'),
693  ('le_', (S, S, S), (0,), 'pyscalar'),
694  ('lt_', (), (0,), 'pyscalar'),
695  ('eq_', (), (0,), 'pyscalar_scalar'),
696  ('ne_', (), (0,), 'pyscalar_scalar'),
697  ('gt_', (), (0,), 'pyscalar_scalar'),
698  ('ge_', (), (0,), 'pyscalar_scalar'),
699  ('lt_', (), (0,), 'pyscalar_scalar'),
700  ('le_', (), (0,), 'pyscalar_scalar'),
701  ('permute', (1, 2, 3, 4), (0, 2, 3, 1)),
702  ('permute', (1, 2, 3, 4), (0, -2, -1, 1), 'neg_dim'),
703  ('permute', (), (dont_convert(()),), 'scalar'),
704  ('select', (S, S, S), (1, 2), 'dim', [0]),
705  ('select', (S, S, S), (1, -1), 'wrap_dim', [0]),
706  ('select', (S,), (0, 2), '1d'),
707  ('narrow', (S, S, S), (1, 2, 2), 'dim', [0]),
708  ('narrow', (S, S, S), (1, 0, 0), 'empty_dim', [0]),
709  ('squeeze', (S, 1, S, 1), NO_ARGS),
710  ('squeeze', (1, 1, 1, 1), NO_ARGS, 'input_sizes_are_ones'),
711  ('squeeze', (S, 1, S, 1), (1,), '1_dim', [0]),
712  ('squeeze', (S, 1, S, 1), (2,), 'not_1_dim', [0]),
713  ('squeeze', (), (0,), 'scalar', [0]),
714  ('unsqueeze', (S, S, S), (0,), 'first', [0]),
715  ('unsqueeze', (S, S, S), (1,), 'middle', [0]),
716  ('unsqueeze', (S, S, S), (3,), 'last', [0]),
717  ('unsqueeze', (), (0,), 'scalar', [0]),
718  ('chunk', (S, S, S), (2,)),
719  ('chunk', (S, S, S), (S, 1), 'dim', [1]),
720  ('split', (S, S, S), (2,)),
721  ('split', (S, S, S), (S, 1), 'dim', [1]),
722  ('split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],), 'size_list'),
723  ('split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2), 'size_list_dim', [1]),
724  ('gather', (M, S), (0, gather_variable((S, S), 1, M, True)), 'dim0', [0]),
725  ('gather', (M, S), (1, gather_variable((M, S // 2), 0, S, True)), 'dim1', [0]),
726  ('gather', (), (0, torch.tensor([0], dtype=torch.int64)), 'scalar_input', [0]),
727  ('gather', (S,), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_index', [0]),
728  ('gather', (), (0, torch.tensor(0, dtype=torch.int64)), 'scalar_both', [0]),
729  ('scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
730  ('scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
731  ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalartensor_all_dim0', [0]),
732  ('scatter', (), (0, torch.tensor(0, dtype=torch.int64), 2.5), 'scalar_all_dim0', [0]),
733  ('scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)), 'dim0', [0]),
734  ('scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)), 'dim1', [0]),
735  ('scatter_add', (), (0, torch.tensor(0, dtype=torch.int64), ()), 'scalar_all_dim0', [0]),
736  ('masked_select', (M, M), (mask_not_all_zeros((M, M)),)),
737  ('masked_select', (M, M), (mask_not_all_zeros((M,)),), 'broadcast_rhs'),
738  ('masked_select', (M,), (mask_not_all_zeros((M, M)),), 'broadcast_lhs'),
739  ('masked_select', (M, 1, M), (mask_not_all_zeros((M, M)),),
740  'broadcast_all'),
741  ('masked_select', (), (torch.tensor(1, dtype=torch.uint8),), 'scalar'),
742  ('masked_select', (M, M), (torch.tensor(1, dtype=torch.uint8),), 'scalar_broadcast_rhs'),
743  ('masked_select', (), (mask_not_all_zeros((M, M)),), 'scalar_broadcast_lhs'),
744  ('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), 10)),
745  ('masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), ()), 'tensor'),
746  ('masked_fill', (M,), (torch.ByteTensor(M, M).bernoulli_(), 10), 'broadcast_lhs'),
747  ('masked_fill', (M, M), (torch.ByteTensor(M,).bernoulli_(), 10), 'broadcast_rhs'),
748  ('masked_fill', (), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10), 'scalar'),
749  ('masked_fill', (), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), ()),
750  'scalar_variable'),
751  ('masked_fill', (M, M), (torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10),
752  'scalar_broadcast_rhs'),
753  ('masked_scatter', (M, M), (torch.ByteTensor(M, M).bernoulli_(), (M, M))),
754  ('masked_scatter', (M,), (torch.ByteTensor(M, M).bernoulli_(), (M, M)),
755  'broadcast_lhs'),
756  ('masked_scatter', (M, M), (torch.ByteTensor(M,).bernoulli_(), (M, M)),
757  'broadcast_rhs'),
758  ('masked_scatter', (M, M), (bernoulli_scalar(), (M, M)), 'scalar'),
759  ('masked_scatter', (M, M), (bernoulli_scalar(), (M, M)),
760  'scalar_broadcast_rhs'),
761  ('resize_', (S, S, S), (torch.Size([S * S, S])), 'fewer_dims'),
762  ('resize_', (), (dont_convert(()),), 'scalar'),
763  ('resize_', (), (torch.Size([1, 1, 1])), 'scalar_to_dims'),
764  ('resize_as_', (), (non_differentiable(torch.tensor(5.)),), 'scalar'),
765  ('resize_as_', (), (non_differentiable(torch.randn((1, 1, 1))),), 'scalar_to_dims'),
766  ('resize_as_', (S, S, S), (non_differentiable(torch.randn(S * S, S)),)),
767  ('sort', (S, M, S), NO_ARGS),
768  ('sort', (S, M, S), (1,), 'dim'),
769  ('sort', (S, M, S), (1, True), 'dim_desc'),
770  ('sort', (), NO_ARGS, 'scalar'),
771  ('sort', (), (0,), 'dim_scalar'),
772  ('sort', (), (0, True), 'dim_desc_scalar'),
773  ('topk', (S, M, S), (3,)),
774  ('topk', (S, M, S), (3, 1), 'dim', [1]),
775  ('topk', (S, M, S), (3, 1, True), 'dim_desc', [1]),
776  ('topk', (S, M, S), (3, 1, True, True), 'dim_desc_sort', [1]),
777  ('topk', (), (1,), 'scalar'),
778  ('topk', (), (1, 0), 'dim_scalar', [1]),
779  ('topk', (), (1, 0, True), 'dim_desc_scalar', [1]),
780  ('topk', (), (1, 0, True, True), 'dim_desc_sort_scalar', [1]),
781  ('take', (S, S, S), (torch.LongTensor([[-3, 2], [20, 2]]),)),
782  ('take', (S, S, S), (torch.tensor(0, dtype=torch.int64),), 'scalar_index'),
783  ('take', (), (torch.LongTensor([0]),), 'scalar_data'),
784  ('take', (), (torch.tensor(0, dtype=torch.int64),), 'scalar_both'),
785  ('where', (M, M), (mask_not_all_zeros((M, M)), (M, M))),
786  ('where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)), 'broadcast_all'),
787  ('where', (), (bernoulli_scalar(), ()), 'scalar'),
788  ('where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)), 'scalar_broadcast_mask'),
789  ('where', (), (mask_not_all_zeros((M, M)), ()), 'scalar_broadcast_non_mask'),
790  ('__getitem__', torch.randn(S, S, S), (dont_convert([1, 2]),)),
791  ('__getitem__', torch.randn(S, S, S), (slice(0, 3),), 'slice'),
792  ('__getitem__', torch.randn(S, S, S), (dont_convert([slice(0, 3), 1]),), 'slice_index'),
793  ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], [0, 0, 2]]),), 'adv_index'),
794  ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 0, 3], [1, 1, 3], [0, 0, 2]]),), 'adv_index_dup'),
795  ('__getitem__', torch.randn(S, S, S), (dont_convert([slice(None), slice(None), [0, 3]]),), 'adv_index_end'),
796  ('__getitem__', torch.randn(S, S, S), (dont_convert([slice(None), [0, 3], slice(None)]),), 'adv_index_mid'),
797  ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], slice(None), slice(None)]),), 'adv_index_beg'),
798  ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], [1, 2], slice(None)]),), 'adv_index_comb'),
799  ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], ]),), 'adv_index_sub'),
800  ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], slice(None)]),), 'adv_index_sub_2'),
801  ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], Ellipsis]),), 'adv_index_sub_3'),
802  ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3],
803  torch.LongTensor([0, 0, 2])]),), 'adv_index_var'),
804  ]
805 # TODO: clamp with min/max
806 
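# Usage sketch (illustrative only; the actual harness lives in the autograd
# tests): an entry such as ('add', (S, S, S), ((S, S, S),)) means the size
# tuples are turned into random double-precision tensors and the named method
# is run through gradcheck, roughly like this:
#
#     from torch.autograd import gradcheck
#
#     x = torch.randn(S, S, S, dtype=torch.double, requires_grad=True)  # input size
#     y = torch.randn(S, S, S, dtype=torch.double, requires_grad=True)  # tensor arg
#     assert gradcheck(lambda a, b: a.add(b), (x, y))                   # method name: 'add'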
807 
808 def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None):
809  if not isinstance(call_args, tuple):
810  call_args = (call_args,)
811 
812  def map_arg(arg):
813  def maybe_non_contig(tensor):
814  return tensor if not non_contiguous else make_non_contiguous(tensor)
815 
816  if isinstance(arg, torch.Size) or isinstance(arg, dont_convert):
817  return arg
818  elif isinstance(arg, tuple) and len(arg) == 0:
819  var = torch.randn((), dtype=torch.double)
820  var.requires_grad = requires_grad
821  return var
822  elif isinstance(arg, tuple) and not isinstance(arg[0], torch.Tensor):
823  return Variable(maybe_non_contig(torch.randn(*arg, dtype=torch.double)), requires_grad=requires_grad)
824  elif isinstance(arg, non_differentiable):
825  if isinstance(arg.tensor, torch.Tensor):
826  return maybe_non_contig(arg.tensor)
827  return maybe_non_contig(arg.tensor)
828  elif isinstance(arg, torch.Tensor):
829  if arg.dtype == torch.float:
830  arg = arg.double()
831  # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
832  v = maybe_non_contig(arg).detach().clone()
833  v.requires_grad = requires_grad and v.is_floating_point()
834  return v
835  elif callable(arg):
836  return map_arg(arg())
837  else:
838  return arg
839  args_out = tuple(map_arg(arg) for arg in call_args)
840  kwargs_out = {k: map_arg(v) for k, v in call_kwargs.items()} if call_kwargs else {}
841  return args_out, kwargs_out
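
# Illustration of create_input (follows directly from the definition above):
# shape tuples become random double tensors, the empty tuple becomes a 0-dim
# tensor, and plain Python scalars pass through unchanged.
#
#     args, kwargs = create_input(((2, 3), (), 3.14), requires_grad=True)
#     assert args[0].shape == (2, 3) and args[0].dtype == torch.double and args[0].requires_grad
#     assert args[1].dim() == 0        # () produces a 0-dim tensor
#     assert args[2] == 3.14           # non-tuple, non-tensor args are returned as-is
#     assert kwargs == {}              # no call_kwargs were given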
842 
843 
844 def _compare_trilu_indices(
845  self, row, col, offset=0, dtype=torch.long, device='cpu'):
846  if row == 0 or col == 0:
847  # have to handle this separately as tril and triu do not take an
848  # empty matrix as input
849  self.assertEqual(
850  torch.empty(0, 2, dtype=dtype, device=device).transpose(0, 1),
851  torch.tril_indices(row, col, offset, dtype=dtype, device=device))
852 
853  self.assertEqual(
854  torch.empty(0, 2, dtype=dtype, device=device).transpose(0, 1),
855  torch.triu_indices(row, col, offset, dtype=dtype, device=device))
856 
857  else:
858  self.assertEqual(
859  torch.ones(row, col, dtype=dtype, device='cpu')
860  .tril(offset).nonzero().transpose(0, 1).to(device),
861  torch.tril_indices(row, col, offset, dtype=dtype, device=device))
862 
863  self.assertEqual(
864  torch.ones(row, col, dtype=dtype, device='cpu')
865  .triu(offset).nonzero().transpose(0, 1).to(device),
866  torch.triu_indices(row, col, offset, dtype=dtype, device=device))
867 
868 
869 def _compare_large_trilu_indices(
870  self, row, col, offset=0, dtype=torch.long, device='cpu'):
871  l = torch.ones(row, col, dtype=dtype, device='cpu').tril(offset) \
872  .nonzero()[-100:-1, :].transpose(0, 1).to(device)
874 
875  r = torch.tril_indices(
876  row, col, offset, dtype=dtype, device=device)[:, -100:-1]
877  self.assertEqual(l, r)
879 
880  l = torch.ones(row, col, dtype=dtype, device='cpu').triu(offset) \
881  .nonzero()[-100:-1, :].transpose(0, 1).to(device)
883 
884  r = torch.triu_indices(
885  row, col, offset, dtype=dtype, device=device)[:, -100:-1]
886  self.assertEqual(l, r)
888 
889 # (
890 # row
891 # col
892 # offset (optional)
893 # dtype (optional)
894 # )
895 tri_tests_args = [
896  (1, 1),
897  (3, 3),
898  (3, 3, 1),
899  (3, 3, 2),
900  (3, 3, 200),
901  (3, 3, -1),
902  (3, 3, -2),
903  (3, 3, -200),
904  (0, 3, 0),
905  (0, 3, 1),
906  (0, 3, -1),
907  (3, 0, 0),
908  (3, 0, 1),
909  (3, 0, -1),
910  (0, 0, 0),
911  (0, 0, 1),
912  (0, 0, -1),
913  (3, 6, 0),
914  (3, 6, 1),
915  (3, 6, 3),
916  (3, 6, 9),
917  (3, 6, -1),
918  (3, 6, -3),
919  (3, 6, -9),
920  (6, 3, 0),
921  (6, 3, 1),
922  (6, 3, 3),
923  (6, 3, 9),
924  (6, 3, -1),
925  (6, 3, -3),
926  (6, 3, -9),
927  (258, 253, 1, torch.float32),
928  (257, 258, 1, torch.float64),
929  (258, 258, 1, torch.short),
930  (3, 513, 1, torch.long),
931  (513, 3, 1, torch.int),
932  (513, 0, 1, torch.double),
933  (1024, 1024),
934  (1024, 1024, 500, torch.float32),
935  (1024, 1024, 1023),
936  (1024, 1024, -500),
937  (1023, 1025),
938  (1025, 1023, 1022),
939  (1024, 1024, -500),
940  (3, 2028),
941  (3, 2028, 1),
942  (3, 2028, -1),
943  (2028, 3),
944  (2028, 1),
945  (2028, 1, -1)
946 ]
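
# Each tuple above is unpacked as (row, col, offset, dtype) and fed to the
# comparison helpers; the driver itself lives in the test files. A standalone
# illustration (plain torch, no test harness assumed) of the property being
# exercised:
#
#     row, col, offset = 3, 6, 1
#     expected = torch.ones(row, col).tril(offset).nonzero().transpose(0, 1)
#     assert torch.equal(expected, torch.tril_indices(row, col, offset))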
947 
948 tri_large_tests_args = [
949  # Large test cases below are deliberately commented out to speed up CI
950  # tests and to avoid OOM error. When modifying implementations of
951  # tril_indices and triu_indices, please enable these tests and make sure
952  # they pass.
953  #
954  # (1, 268435455),
955  # (5000, 5000),
956  # (10000, 10000),
957  # (268435455, 1),
958  # (134217727, 2, 1),
959  # (2, 134217727, 1),
960  # (536870901, 1),
961  # (1, 536870901),
962  # (268435455, 2, 1),
963  # (2, 268435455, 1)
964 ]
965 
966 
967 def run_additional_tri_tests(self, device):
968  x = torch.ones(
969  3, 3, dtype=torch.long, device=device, layout=torch.strided)
970  l = x.tril(0).nonzero().transpose(0, 1)
971  u = x.triu(0).nonzero().transpose(0, 1)
972  self.assertEqual(l, torch.tril_indices(3, 3, device=device))
973  self.assertEqual(
974  l, torch.tril_indices(3, 3, device=device, layout=torch.strided))
975 
976  self.assertEqual(u, torch.triu_indices(3, 3, device=device))
977  self.assertEqual(
978  u, torch.triu_indices(3, 3, device=device, layout=torch.strided))
979 
980  self.assertRaises(
981  RuntimeError,
982  lambda: torch.triu_indices(
983  1, 1, device=device, layout=torch.sparse_coo))
984 
985  self.assertRaises(
986  RuntimeError,
987  lambda: torch.tril_indices(
988  1, 1, device=device, layout=torch.sparse_coo))
989 
990 
991 def unpack_variables(args):
992  if istuple(args):
993  return tuple(unpack_variables(elem) for elem in args)
994  else:
995  return args
996 
997 
998 EXCLUDE_FUNCTIONAL = {
999  'addmm',
1000  'addmm_',
1001  'addbmm',
1002  'baddbmm',
1003  'addmv',
1004  'addmv_',
1005  'addr',
1006  'addr_',
1007  'reshape',
1008  'where' # argument order
1009 }
1010 EXCLUDE_GRADCHECK = {
1011 }
1012 EXCLUDE_GRADGRADCHECK = {
1013 }
1014 EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
1015  # *det methods use svd in backward when the matrix is not invertible. However,
1016  # svd backward is unstable unless the matrix has positive distinct singular
1017  # values. Generated random matrices satisfy this with high probability, but
1018  # we can't rely on it. So gradgrad is only tested on invertible test cases and
1019  # on the _distinct_singular_values variants.
1020  'test_det',
1021  'test_det_1x1',
1022  'test_det_symmetric',
1023  'test_det_symmetric_psd',
1024  'test_det_dim2_null',
1025  'test_det_rank1',
1026  'test_det_rank2',
1027  # `other` in expand_as(self, other) is not used in autograd.
1028  'test_expand_as',
1029  'test_logdet',
1030  'test_logdet_1x1',
1031  'test_logdet_symmetric',
1032  'test_slogdet_1x1_neg_det',
1033  'test_slogdet_neg_det',
1034  'test_slogdet_symmetric',
1035  'test_cdist',
1036 }
1037 
1038 
1039 def exclude_tensor_method(name, test_name):
1040  # there are no tensor equivalents for these (inplace or out)
1041  exclude_all_tensor_method_by_test_name = {
1042  'test_clamp_min',
1043  'test_clamp_max',
1044  'test_clamp_min_scalar',
1045  'test_clamp_max_scalar',
1046  'test_slice',
1047  'test_where',
1048  'test_where_broadcast_all',
1049  'test_where_scalar',
1050  'test_where_scalar_broadcast_mask',
1051  'test_where_scalar_broadcast_non_mask',
1052  }
1053  # there are no out-of-place tensor equivalents for these
1054  exclude_outplace_tensor_method = {
1055  'index_add',
1056  'index_copy',
1057  'index_fill',
1058  'masked_fill',
1059  'masked_scatter',
1060  'scatter',
1061  'scatter_add',
1062  'det',
1063  }
1064  if test_name in exclude_all_tensor_method_by_test_name:
1065  return True
1066  is_magic_method = name[:2] == '__' and name[-2:] == '__'
1067  is_inplace = name[-1] == "_" and not is_magic_method
1068  if not is_inplace and name in exclude_outplace_tensor_method:
1069  return True
1070  return False
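
# Hedged sketch of how a test generator might consult the exclusion sets and
# helper above (the actual generator lives in the autograd test suite; the
# names used here are illustrative only):
#
#     name, test_name = 'addmm', 'test_addmm'
#     run_tensor_method = not exclude_tensor_method(name, test_name)
#     run_functional = name not in EXCLUDE_FUNCTIONAL          # torch.addmm(...) variant
#     run_gradgradcheck = (name not in EXCLUDE_GRADGRADCHECK
#                          and test_name not in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME)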