3 from functools
import reduce, wraps
4 from operator
import mul, itemgetter
7 from common_utils
import (skipIfNoLapack,
8 prod_single_zero, random_square_matrix_of_rank,
9 random_symmetric_matrix, random_symmetric_psd_matrix,
10 random_symmetric_pd_matrix, make_nonzero_det,
11 random_fullrank_matrix_distinct_singular_value, set_rng_seed)
def index_variable(shape, max_indices):
    """Return a random int64 index tensor with values in ``[0, max_indices)``.

    Indices are drawn independently, so repeated values are possible.

    Args:
        shape: an int or a tuple of ints giving the output shape; a bare int
            is treated as a 1-d shape.
        max_indices: exclusive upper bound for the generated indices.

    Returns:
        An int64 tensor of the requested shape.
    """
    if not isinstance(shape, tuple):
        shape = (shape,)
    # rand() samples [0, 1); scaling then flooring gives integers in
    # [0, max_indices).
    index = torch.rand(*shape).mul_(max_indices).floor_().long()
    return index
def index_perm_variable(shape, max_indices):
    """Return an int64 index tensor of pairwise-distinct values in ``[0, max_indices)``.

    The values are the leading entries of ``torch.randperm(max_indices)``,
    reshaped to ``shape``, so no index appears twice.

    Args:
        shape: an int or a tuple of ints giving the output shape; a bare int
            is treated as a 1-d shape, and ``()`` yields a 0-dim index.
        max_indices: size of the permutation the indices are taken from; it
            must be at least the number of elements in ``shape`` (torch raises
            from ``narrow`` otherwise).

    Returns:
        An int64 tensor of the requested shape.
    """
    if not isinstance(shape, tuple):
        shape = (shape,)
    # Start the product at 1 so a 0-d shape ``()`` asks for one element
    # instead of making reduce() raise TypeError on an empty sequence.
    numel = reduce(mul, shape, 1)
    index = torch.randperm(max_indices).narrow(0, 0, numel).view(shape)
    return index
def gather_variable(shape, index_dim, max_indices, duplicate=False):
    """Build a 2-d int64 index tensor for gather/scatter-style tests.

    Each of the ``shape[index_dim]`` slices taken along ``index_dim`` is
    filled with distinct values from ``torch.randperm(max_indices)``.

    Args:
        shape: 2-element shape of the result.
        index_dim: 0 or 1; the dimension along which slices are taken.
        max_indices: exclusive upper bound on the generated indices; it must
            be at least ``shape[1 - index_dim]``.
        duplicate: when True, position 0 along ``1 - index_dim`` is
            overwritten with position 1, so every slice contains a repeated
            index (requires ``shape[1 - index_dim] >= 2``).

    Returns:
        An int64 tensor of the given shape.
    """
    assert len(shape) == 2
    # batch_dim below is only a valid dimension when index_dim is 0 or 1.
    assert index_dim in (0, 1)
    batch_dim = 1 - index_dim
    index = torch.LongTensor(*shape)
    for i in range(shape[index_dim]):
        index.select(index_dim, i).copy_(
            torch.randperm(max_indices)[:shape[batch_dim]])
    if duplicate:
        index.select(batch_dim, 0).copy_(index.select(batch_dim, 1))
    return index
def bernoulli_scalar():
    """Return a 0-dim uint8 tensor holding one Bernoulli(0.5) draw (0 or 1).

    ``bernoulli_()`` with no argument fills the tensor in place using
    probability 0.5.
    """
    return torch.tensor(0, dtype=torch.uint8).bernoulli_()
def mask_not_all_zeros(shape):
    """Return a random mask of ``shape`` that has at least one set element.

    Each element is set independently with probability 1/2 (``randn > 0``).
    The draw is repeated until the mask is not all zeros.

    Args:
        shape: an int or tuple of ints accepted by ``torch.randn``.

    Returns:
        ``torch.randn(shape).gt(0)`` for which ``any()`` is true. The dtype is
        whatever the installed torch returns from comparisons.

    Raises:
        ValueError: if ``shape`` has no elements, because no such mask exists
            and the retry loop would never terminate.
    """
    result = torch.randn(shape).gt(0)
    if result.numel() == 0:
        raise ValueError(
            "mask_not_all_zeros: shape {} has no elements".format(shape))
    while not result.any():
        result = torch.randn(shape).gt(0)
    return result
def uniform_scalar(offset=0, requires_grad=False):
    """Return a 0-dim tensor equal to ``torch.rand(())`` shifted by ``offset``.

    Args:
        offset: added to a sample from ``[0, 1)``, so the result lies in
            ``[offset, offset + 1)`` up to float rounding.
        requires_grad: value assigned to the result's ``requires_grad``.

    Returns:
        A 0-dim floating-point tensor.
    """
    v = torch.rand(()) + offset
    # v has no autograd history (rand() does not require grad), so the flag
    # can be set on it directly.
    v.requires_grad = requires_grad
    return v
def normal_scalar_clamp(amin, amax, requires_grad=False):
    """Return a 0-dim standard-normal sample clamped into ``[amin, amax]``.

    Args:
        amin: lower clamp bound.
        amax: upper clamp bound.
        requires_grad: value assigned to the result's ``requires_grad``.

    Returns:
        A 0-dim floating-point tensor.
    """
    v = torch.randn(()).clamp(amin, amax)
    # v has no autograd history, so the flag can be set on it directly.
    v.requires_grad = requires_grad
    return v
def prod_zeros(dim_size, dim_select):
    """Return a cubic standard-normal tensor with three zeroed lines.

    The result has shape ``(dim_size, dim_size, dim_size)``. Taking
    ``dim_select = [a, b]``, every element whose (dim ``a``, dim ``b``)
    coordinates are (0, 1), (2, 3) or (4, 3) is set to zero.

    Args:
        dim_size: size of each of the three dimensions; it must be >= 5
            because index 4 is zeroed.
        dim_select: the two dimensions (0, 1 or 2) that the zeroed
            coordinates refer to.

    Returns:
        The float tensor with the zeroed lines.
    """
    assert len(dim_select) == 2
    assert dim_size >= 5, "prod_zeros zeroes index 4, so dim_size must be >= 5"
    result = torch.randn(dim_size, dim_size, dim_size)
    # narrow() returns views, so zero_() writes through to result.
    result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
    result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
    result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
    return result
class non_differentiable(object):
    """Wrap a tensor argument in the method-test tables.

    Used in entries such as ``non_differentiable(torch.rand(S * S, S))``.
    It only stores the wrapped tensor as ``self.tensor``. Presumably the test
    harness unwraps it and skips gradient checks for it; TODO confirm against
    the consumer.
    """

    def __init__(self, tensor):
        self.tensor = tensor


class dont_convert(tuple):
    """Tuple subclass marking a literal tuple argument in the test tables.

    Used in entries such as ``dont_convert(())`` as the target shape for
    ``view``/``reshape``/``expand``. Presumably it tells the harness not to
    turn the tuple into a tensor of that shape; TODO confirm against the
    consumer.
    """
class NoArgsClass(object):
    """Empty, reusable argument list.

    ``NO_ARGS`` stands in for the argument tuple of methods called without
    arguments in the method-test tables, e.g. ``('t', (1, 2), NO_ARGS)``.
    It iterates over nothing and has length 0, so ``f(*NO_ARGS)`` calls
    ``f()``. Iteration can be repeated any number of times.
    """

    def __iter__(self):
        return iter(())

    def __len__(self):
        return 0


NO_ARGS = NoArgsClass()
112 (
'add', (S, S, S), ((S, S, S),)),
113 (
'add', (S, S, S), ((S, S),),
'broadcast_rhs'),
114 (
'add', (S, S), ((S, S, S),),
'broadcast_lhs'),
115 (
'add', (S, 1, S), ((M, S),),
'broadcast_all'),
116 (
'add', (), ((),),
'scalar'),
117 (
'add', (S, S, S), ((),),
'scalar_broadcast_rhs'),
118 (
'add', (), ((S, S, S),),
'scalar_broadcast_lhs'),
119 (
'add', (S, S, S), (3.14,),
'constant'),
120 (
'add', (), (3.14,),
'scalar_constant'),
121 (
'__radd__', (S, S, S), (3.14,),
'constant'),
122 (
'__radd__', (), (3.14,),
'scalar_constant'),
123 (
'sub', (S, S, S), ((S, S, S),)),
124 (
'sub', (S, S, S), ((S, S),),
'broadcast_rhs'),
125 (
'sub', (S, S), ((S, S, S),),
'broadcast_lhs'),
126 (
'sub', (S, 1, S), ((M, S),),
'broadcast_all'),
127 (
'sub', (S, S, S), ((),),
'scalar_broadcast_rhs'),
128 (
'sub', (), ((S, S, S),),
'scalar_broadcast_lhs'),
129 (
'sub', (S, S, S), (3.14,),
'constant'),
130 (
'sub', (), (3.14,),
'scalar_constant'),
131 (
'__rsub__', (S, S, S), (3.14,),
'constant'),
132 (
'__rsub__', (), (3.14,),
'scalar_constant'),
133 (
'mul', (S, S, S), ((S, S, S),)),
134 (
'mul', (), ((),),
'scalar'),
135 (
'mul', (S, S, S), ((S, S),),
'broadcast_rhs'),
136 (
'mul', (S, S), ((S, S, S),),
'broadcast_lhs'),
137 (
'mul', (S, 1, S), ((M, S),),
'broadcast_all'),
138 (
'mul', (S, S, S), ((),),
'scalar_broadcast_rhs'),
139 (
'mul', (), ((S, S, S),),
'scalar_broadcast_lhs'),
140 (
'mul', (S, S, S), (3.14,),
'constant'),
141 (
'mul', (), (3.14,),
'scalar_constant'),
142 (
'__rmul__', (S, S, S), (3.14,),
'constant'),
143 (
'__rmul__', (), (3.14,),
'scalar_constant'),
144 (
'div', (S, S, S), (torch.rand(S, S, S) + 0.1,)),
145 (
'div', (S, S, S), (torch.rand(S, S) + 0.1,),
'broadcast_rhs'),
146 (
'div', (S, S), (torch.rand(S, S, S) + 0.1,),
'broadcast_lhs'),
147 (
'div', (S, 1, S), (torch.rand(M, S) + 0.1,),
'broadcast_all'),
148 (
'div', (), (uniform_scalar(0.1),),
'scalar'),
149 (
'div', (S, S, S), (uniform_scalar(0.1),),
'scalar_broadcast_rhs'),
150 (
'div', (), (uniform_scalar(0.1),),
'scalar_broadcast_lhs'),
151 (
'div', torch.rand(S, S, S) + 1e-1, (3.14,),
'constant'),
152 (
'__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,),
'constant'),
153 (
'div', uniform_scalar(1e-1, requires_grad=
True), (3.14,),
'scalar_constant'),
154 (
'__rdiv__', uniform_scalar(1e-1, requires_grad=
True), (3.14,),
'scalar_constant'),
155 (
'pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,)),
156 (
'pow', torch.rand(S, S, S) + 1e-3, (torch.rand(1,) + 0.1,),
'broadcast_rhs'),
157 (
'pow', torch.rand(1,) + 1e-3, (torch.rand(S, S, S) + 0.1,),
'broadcast_lhs'),
158 (
'pow', torch.rand(S, 1, S) + 1e-3, (torch.rand(1, S, 1) + 0.1,),
'broadcast_all'),
159 (
'pow', uniform_scalar(1e-3, requires_grad=
True), (uniform_scalar(0.1),),
'scalar'),
160 (
'pow', torch.rand(S, S, S) + 1e-3, (uniform_scalar(0.1),),
'scalar_broadcast_rhs'),
161 (
'pow', uniform_scalar(1e-3, requires_grad=
True), (torch.rand(S, S, S) + 0.1,),
'scalar_broadcast_lhs'),
162 (
'pow', torch.rand(S, S, S) + 1e-3, (3.14,),
'constant'),
163 (
'__rpow__', torch.rand(S, S, S) + 1e-3, (3.14,),
'constant'),
164 (
'pow', uniform_scalar(1e-3, requires_grad=
True), (3.14,),
'scalar_constant'),
165 (
'__rpow__', uniform_scalar(1e-3, requires_grad=
True), (3.14,),
'scalar_constant'),
166 (
'transpose', (1, 2, 3), (1, 2),
'dim', [0, 1]),
167 (
'transpose', (), (0, 0),
'scalar'),
168 (
'transpose', (1,), (0, 0),
'1d'),
169 (
'transpose', torch.rand(L, L), (0, 1),
'2d'),
170 (
'transpose', torch.rand(S, S, S), (2, 0),
'3d'),
171 (
't', (1, 2), NO_ARGS),
172 (
'view', (S, S, S), (S * S, S),),
173 (
'view', (S, S, S), (torch.Size([S * S, S]),),
'size'),
174 (
'view', (S,), (S,),
'1d'),
175 (
'view', (), (dont_convert(()),),
'scalar_to_scalar'),
176 (
'view', (), (1,),
'scalar_to_1d'),
177 (
'reshape', (S, S, S), (S * S, S),),
178 (
'reshape', (S, S, S), (torch.Size([S * S, S]),),
'size'),
179 (
'reshape', (S,), (S,),
'1d'),
180 (
'reshape', (), (dont_convert(()),),
'scalar_to_scalar'),
181 (
'reshape', (), (1,),
'scalar_to_1d'),
182 (
'reshape_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
183 (
'reshape_as', (), (non_differentiable(
torch.tensor(42.)),),
'scalar'),
184 (
'reshape_as', (), (non_differentiable(torch.rand(1, 1)),),
'scalar_to_dims'),
185 (
'flip', (S, S, S), ([0],),
'd0'),
186 (
'flip', (S, S, S), ([0, 1, 2],),
'd012'),
187 (
'flip', (S, S, S), ([0, 2],),
'd02'),
188 (
'flip', (S, S, S), ([2, 0],),
'd20'),
189 (
'flip', (S, S, S), ([-1],),
'neg_d'),
190 (
'roll', (S, S, S), (0, 0),
'd0'),
191 (
'roll', (S, S, S), (1, 2),
'd12'),
192 (
'roll', (S, S, S), (0, 2,),
'd02'),
193 (
'roll', (S, S, S), (2, 0,),
'd20'),
194 (
'roll', (S, S, S), (-1, 0),
'neg_shift'),
195 (
'roll', (S, S, S), (10000, 1),
'loop_shift'),
196 (
'roll', (S, S, S), (2,),
'flattened'),
197 (
'roll', (S, S, S), ([1, 2, -1], [0, 1, 2]),
'three_dims'),
198 (
'rot90', (S, S, S), (1, [0, 1],),
'k1_d01'),
199 (
'rot90', (S, S, S), (1, [1, 2],),
'k1_d12'),
200 (
'rot90', (S, S, S), (1, [1, -1],),
'k1_neg_d'),
201 (
'rot90', (S, S, S), (),
'default'),
202 (
'view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
203 (
'view_as', (), (non_differentiable(
torch.tensor(5.5)),),
'scalar'),
204 (
'view_as', (), (non_differentiable(torch.rand(1, 1)),),
'scalar_to_dims'),
205 (
'expand', (S, 1, 1), (S, S, S)),
206 (
'expand', (torch.Size([S, 1, S]),), (S, S, S),
'size'),
207 (
'expand', (S, 1), (S, S, S),
'new_dim'),
208 (
'expand', (1,), (S, S, S),
'1_element'),
209 (
'expand', (1, S), (1, 1, S),
'new_dim_front_old_front_1'),
210 (
'expand', (), (dont_convert(()),),
'scalar_to_scalar'),
211 (
'expand', (), (1, 3, 2),
'scalar_to_dims'),
212 (
'expand_as', (S, 1, 1), (torch.rand(S, S, S),)),
213 (
'exp', (S, S, S), NO_ARGS),
214 (
'exp', (), NO_ARGS,
'scalar'),
215 (
'expm1', (S, S, S), NO_ARGS),
216 (
'expm1', (), NO_ARGS,
'scalar'),
217 (
'erf', torch.rand(S, S, S), NO_ARGS),
218 (
'erf', uniform_scalar(requires_grad=
True), NO_ARGS,
'scalar'),
219 (
'erfc', torch.rand(S, S, S), NO_ARGS),
220 (
'erfc', uniform_scalar(requires_grad=
True), NO_ARGS,
'scalar'),
221 (
'erfinv', torch.rand(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
222 (
'erfinv', normal_scalar_clamp(-0.9, 0.9, requires_grad=
True), NO_ARGS,
'scalar'),
223 (
'log', torch.rand(S, S, S) + 1e-2, NO_ARGS),
224 (
'log', uniform_scalar(1e-2, requires_grad=
True), NO_ARGS,
'scalar'),
225 (
'log10', torch.rand(S, S, S) + 1e-2, NO_ARGS),
226 (
'log10', uniform_scalar(1e-2, requires_grad=
True), NO_ARGS,
'scalar'),
227 (
'log1p', torch.rand(S, S, S), NO_ARGS),
228 (
'log1p', uniform_scalar(requires_grad=
True), NO_ARGS,
'scalar'),
229 (
'log2', torch.rand(S, S, S) + 1e-2, NO_ARGS),
230 (
'log2', uniform_scalar(1e-2, requires_grad=
True), NO_ARGS,
'scalar'),
231 (
'tanh', (S, S, S), NO_ARGS),
232 (
'tanh', (), NO_ARGS,
'scalar'),
233 (
'sigmoid', (S, S, S), NO_ARGS),
234 (
'sigmoid', (), NO_ARGS,
'scalar'),
235 (
'sinh', (S, S, S), NO_ARGS),
236 (
'sinh', (), NO_ARGS,
'scalar'),
237 (
'cosh', (S, S, S), NO_ARGS),
238 (
'cosh', (), NO_ARGS,
'scalar'),
239 (
'abs', (S, S, S), NO_ARGS),
240 (
'abs', (), NO_ARGS,
'scalar'),
241 (
'clamp', (S, S, S), (0, 1)),
242 (
'clamp', (S, S, S), (
None, 0.5),
'min'),
243 (
'clamp', (S, S, S), (0.5,
None),
'max'),
244 (
'clamp', (), (0, 1),
'scalar'),
245 (
'clamp', (), (
None, 0.5),
'min_scalar'),
246 (
'clamp', (), (0.5,
None),
'max_scalar'),
247 (
'sqrt', torch.rand(S, S, S) + 5e-4, NO_ARGS),
248 (
'sqrt', uniform_scalar(5e-4, requires_grad=
True), NO_ARGS,
'scalar'),
249 (
'sin', (S, S, S), NO_ARGS),
250 (
'sin', (), NO_ARGS,
'scalar'),
251 (
'cos', (S, S, S), NO_ARGS),
252 (
'cos', (), NO_ARGS,
'scalar'),
253 (
'tan', torch.randn(S, S, S).clamp(-1, 1), NO_ARGS),
254 (
'asin', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
255 (
'acos', torch.randn(S, S, S).clamp(-0.9, 0.9), NO_ARGS),
256 (
'atan', (S, S, S), NO_ARGS),
257 (
'atan', (), NO_ARGS,
'scalar'),
258 (
'atan2', (S, S, S), ((S, S, S),)),
259 (
'atan2', (), ((),),
'scalar'),
260 (
'atan2', (S, S, S), ((S,),),
'broadcast_rhs'),
261 (
'atan2', (S,), ((S, S, S),),
'broadcast_lhs'),
262 (
'atan2', (S, 1, S), ((S, S),),
'broadcast_all'),
263 (
'reciprocal', torch.rand(S, S, S) + 0.1, NO_ARGS),
264 (
'reciprocal', uniform_scalar(0.1, requires_grad=
True), NO_ARGS,
'scalar'),
265 (
'round', (S, S, S), NO_ARGS),
266 (
'round', (), NO_ARGS,
'scalar'),
267 (
'sign', (S, S, S), NO_ARGS),
268 (
'sign', (), NO_ARGS,
'scalar'),
269 (
'trunc', (S, S, S), NO_ARGS),
270 (
'trunc', (), NO_ARGS,
'scalar'),
271 (
'floor', (S, S, S), NO_ARGS),
272 (
'floor', (), NO_ARGS,
'scalar'),
273 (
'ceil', (S, S, S), NO_ARGS),
274 (
'ceil', (), NO_ARGS,
'scalar'),
275 (
'rsqrt', torch.rand(S, S, S) + 1e-2, NO_ARGS),
276 (
'rsqrt', uniform_scalar(1e-2, requires_grad=
True), NO_ARGS,
'scalar'),
277 (
'frac', (S, S, S), NO_ARGS),
278 (
'frac', (), NO_ARGS,
'scalar'),
279 (
'fmod', (S, S, S), (1.5,)),
280 (
'fmod', (), (1.5,),
'scalar'),
281 (
'fmod', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),),
'tensor'),
282 (
'fmod', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),),
'tensor_broadcast_lhs'),
283 (
'fmod', (S, S, S), (non_differentiable(torch.rand(S) + 1.5),),
'tensor_broadcast_rhs'),
284 (
'fmod', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),),
'tensor_broadcast_all'),
285 (
'fmod', (), (non_differentiable(uniform_scalar(1.5)),),
'scalar_tensor'),
286 (
'fmod', (), (non_differentiable(torch.rand(S, S, S) + 1.5),),
'scalar_tensor_broadcast_lhs'),
287 (
'fmod', (S, S, S), (non_differentiable(uniform_scalar(1.5)),),
'scalar_tensor_broadcast_rhs'),
288 (
'remainder', (S, S, S), (1.5,)),
289 (
'remainder', (), (1.5,),
'scalar'),
290 (
'remainder', (S, S, S), (non_differentiable(torch.rand(S, S, S) + 1.5),),
'tensor'),
291 (
'remainder', (S,), (non_differentiable(torch.rand(S, S, S) + 1.5),),
'tensor_broadcast_lhs'),
292 (
'remainder', (S, 1, S), (non_differentiable(torch.rand(S, S) + 1.5),),
'tensor_broadcast_all'),
293 (
'remainder', (), (non_differentiable(uniform_scalar(1.5)),),
'scalar_tensor'),
294 (
'remainder', (), (non_differentiable(torch.rand(S, S, S) + 1.5),),
'scalar_tensor_broadcast_lhs'),
295 (
'lerp', (S, S, S), ((S, S, S), 0.4),
'scalar_no_broadcast'),
296 (
'lerp', (S, S, S), ((S,), 0.4),
'broadcast_rhs'),
297 (
'lerp', (S,), ((S, S, S), 0.4),
'broadcast_lhs'),
298 (
'lerp', (S, 1, S), ((S, S), 0.4),
'broadcast_all'),
299 (
'lerp', (), ((), 0.4),
'scalar'),
300 (
'lerp', (S, S, S), ((), 0.4),
'scalar_broadcast_rhs'),
301 (
'lerp', (), ((S, S, S), 0.4),
'scalar_broadcast_lhs'),
302 (
'max', (S, S, S), NO_ARGS),
303 (
'max', (S, S, S), (1,),
'dim', [0]),
304 (
'max', (S, S, S), (1,
True,),
'keepdim_dim', [0]),
305 (
'max', (), NO_ARGS,
'scalar'),
306 (
'max', (), (0,),
'scalar_dim', [0]),
307 (
'max', (), (0,
True,),
'scalar_keepdim_dim', [0]),
308 (
'max', (S, S, S), ((S, S, S),),
'elementwise'),
309 (
'max', (S, S, S), ((S,),),
'elementwise_broadcast_rhs'),
310 (
'max', (S,), ((S, S, S),),
'elementwise_broadcast_lhs'),
311 (
'max', (S, 1, S), ((S, S),),
'elementwise_broadcast_all'),
312 (
'max', (), ((),),
'scalar_elementwise'),
313 (
'max', (S, S, S), ((),),
'scalar_elementwise_broadcast_rhs'),
314 (
'max', (), ((S, S, S),),
'scalar_elementwise_broadcast_lhs'),
315 (
'min', (S, S, S), NO_ARGS),
316 (
'min', (S, S, S), (1,),
'dim', [0]),
317 (
'min', (S, S, S), (1,
True,),
'keepdim_dim', [0]),
318 (
'min', (), NO_ARGS,
'scalar'),
319 (
'min', (), (0,),
'scalar_dim', [0]),
320 (
'min', (), (0,
True,),
'scalar_keepdim_dim', [0]),
321 (
'min', (S, S, S), ((S, S, S),),
'elementwise'),
322 (
'min', (S, S, S), ((S,),),
'elementwise_broadcast_rhs'),
323 (
'min', (S,), ((S, S, S),),
'elementwise_broadcast_lhs'),
324 (
'min', (S, 1, S), ((S, S),),
'elementwise_broadcast_all'),
325 (
'min', (), ((),),
'scalar_elementwise'),
326 (
'min', (S, S, S), ((),),
'scalar_elementwise_broadcast_rhs'),
327 (
'min', (), ((S, S, S),),
'scalar_elementwise_broadcast_lhs'),
328 (
'mean', (S, S, S), NO_ARGS),
329 (
'mean', (S, S, S), (1,),
'dim', [0]),
330 (
'mean', (S, S, S), (1,
True,),
'keepdim_dim', [0]),
331 (
'mean', (), NO_ARGS,
'scalar'),
332 (
'mean', (), (0,),
'scalar_dim', [0]),
333 (
'mean', (), (0,
True,),
'scalar_keepdim_dim', [0]),
334 (
'kthvalue', (S, S, S), (2,)),
335 (
'kthvalue', (), (1,),
'scalar'),
336 (
'kthvalue', (S, S, S), (2, 1,),
'dim', [1]),
337 (
'kthvalue', (), (1, 0,),
'scalar_dim', [1]),
338 (
'kthvalue', (S, S, S), (2, 1,
True,),
'keepdim_dim', [1]),
339 (
'kthvalue', (), (1, 0,
True),
'scalar_keepdim_dim', [1]),
340 (
'kthvalue', (S,), (2, 0,),
'dim_1d', [1]),
341 (
'kthvalue', (S,), (2, 0,
True,),
'keepdim_dim_1d', [1]),
342 (
'median', (S, S, S), NO_ARGS),
343 (
'median', (S, S, S), (1,),
'dim', [0]),
344 (
'median', (S, S, S), (1,
True,),
'keepdim_dim', [0]),
345 (
'median', (), NO_ARGS,
'scalar'),
346 (
'median', (), (0,),
'scalar_dim', [0]),
347 (
'median', (), (0,
True,),
'scalar_keepdim_dim', [0]),
348 (
'mode', (S, S, S), NO_ARGS),
349 (
'mode', (S, S, S), (1,),
'dim', [0]),
350 (
'mode', (S, S, S), (1,
True,),
'keepdim_dim', [0]),
351 (
'mode', (), NO_ARGS,
'scalar'),
352 (
'mode', (), (0,),
'scalar_dim', [0]),
353 (
'mode', (), (0,
True,),
'scalar_keepdim_dim', [0]),
354 (
'sum', (S, S, S), NO_ARGS),
355 (
'sum', (S, S, S), (1,),
'dim', [0]),
356 (
'sum', (S, S, S), (1,
True,),
'keepdim_dim', [0]),
357 (
'sum', (), NO_ARGS,
'scalar'),
358 (
'sum', (), (0,),
'scalar_dim', [0]),
359 (
'sum', (), (0,
True,),
'scalar_keepdim_dim', [0]),
360 (
'sum', (S, S, S), ([1, 2],),
'multi_dim'),
361 (
'sum', (S, S, S), ([1, 2],
True,),
'multi_dim_keepdim'),
362 (
'prod', (S, S, S), NO_ARGS),
363 (
'prod', (S, S, S), (1,),
'dim', [0]),
364 (
'prod', (S, S, S), (1,
True,),
'keepdim_dim', [0]),
365 (
'prod', (), NO_ARGS,
'scalar'),
366 (
'prod', (), (0,),
'scalar_dim', [0]),
367 (
'prod', (), (0,
True,),
'scalar_keepdim_dim', [0]),
368 (
'prod', prod_zeros(S, [0, 1]), NO_ARGS,
'zerodims2'),
369 (
'prod', prod_zeros(S, [0, 2]), NO_ARGS,
'zerodims1'),
370 (
'prod', prod_zeros(S, [1, 2]), NO_ARGS,
'zerodims0'),
371 (
'prod', prod_zeros(S, [0, 1]), (1,),
'zeros_dims2', [0]),
372 (
'prod', prod_zeros(S, [0, 2]), (1,),
'zeros_dims1', [0]),
373 (
'prod', prod_zeros(S, [1, 2]), (1,),
'zeros_dims0', [0]),
374 (
'prod', prod_zeros(S, [0, 1]), (1,
True),
'keepdim_zeros_dims2', [0]),
375 (
'prod', prod_zeros(S, [0, 2]), (1,
True),
'keepdim_zeros_dims1', [0]),
376 (
'prod', prod_zeros(S, [1, 2]), (1,
True),
'keepdim_zeros_dims0', [0]),
377 (
'prod', prod_single_zero(S), NO_ARGS,
'single_zero'),
378 (
'prod', (
torch.tensor(0., requires_grad=
True)), NO_ARGS,
'scalar_zero'),
379 (
'prod', (
torch.tensor(0., requires_grad=
True)), (0,),
'scalar_dim_zero', [0]),
380 (
'prod', (
torch.tensor(0., requires_grad=
True)), (0,
True,),
'scalar_keepdim_dim_zero', [0]),
381 (
'var', (S, S, S), NO_ARGS),
382 (
'var', (S, S, S), (1,),
'dim', [0]),
383 (
'var', (S, S, S), (1,
True,
True),
'keepdim_dim', [0]),
384 (
'var', (S,), (0,),
'dim_1d', [0]),
385 (
'var', (S,), (0,
True,
True),
'keepdim_dim_1d', [0]),
386 (
'std', (S, S, S), NO_ARGS),
387 (
'std', (S, S, S), (1,),
'dim', [0]),
388 (
'std', (S, S, S), (1,
True,
True),
'keepdim_dim', [0]),
389 (
'std', (S,), (0,),
'dim_1d', [0]),
390 (
'std', (S,), (0,
True,
True),
'keepdim_dim_1d', [0]),
391 (
'renorm', (S, S, S), (2, 1, 0.5),
'dim', [1]),
392 (
'renorm', (S, S, S), (1, 2, 3),
'norm_1'),
393 (
'renorm', (S, S, S), (inf, 2, 0.5),
'norm_inf'),
394 (
'repeat', (S,), (2,),
'single_number'),
395 (
'repeat', (), (2, 3),
'scalar'),
396 (
'repeat', (2, 2), (3, 2)),
397 (
'repeat', (2, 2), (1, 3, 1, 2),
'unsqueeze'),
398 (
'cumsum', (S, S, S), (0,),
'dim0', [0]),
399 (
'cumsum', (S, S, S), (1,),
'dim1', [0]),
400 (
'cumsum', (S, S, S), (1,),
'dim1_cast', [0], (),
lambda x: x, {
'dtype': torch.float64}),
401 (
'cumsum', (), (0,),
'dim0_scalar', [0]),
402 (
'cumprod', (S, S, S), (0,)),
403 (
'cumprod', (S, S, S), (1,),
'dim1', [0]),
404 (
'cumprod', (), (0,),
'scalar'),
405 (
'cumprod', (
torch.tensor(0., requires_grad=
True)), (0,),
'scalar_zeros'),
406 (
'cumprod', prod_zeros(S, [0, 1]), (1,),
'zeros_dim2', [0]),
407 (
'cumprod', prod_zeros(S, [0, 2]), (1,),
'zeros_dim1', [0]),
408 (
'cumprod', prod_zeros(S, [1, 2]), (1,),
'zeros_dim0', [0]),
409 (
'cumprod', prod_zeros(S, [1, 2]), (1,),
'zeros_dim0_cast', [0], (),
lambda x: x, {
'dtype': torch.float64}),
410 (
'unfold', (), (0, 1, 1),
'scalar', [0]),
411 (
'unfold', (S, S, S, S), (1, 3, 1),
'', [0]),
412 (
'unfold', (S, S, S), (2, 3, 2),
'lastdim', [0]),
413 (
'addmm', (S, M), ((S, S), (S, M)),),
414 (
'addmm', (1,), ((S, S), (S, M)),
'broadcast_lhs'),
415 (
'addmm', (S, M), ((S, S), (S, M)),
'coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
416 (
'addmm', (1,), ((S, S), (S, M)),
'broadcast_lhs_coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
417 (
'addmm', (), ((S, S), (S, M)),
'scalar_broadcast_lhs'),
418 (
'addmm', (), ((S, S), (S, M)),
'scalar_broadcast_lhs_coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
419 (
'addbmm', (S, M), ((S, S, S), (S, S, M)),),
420 (
'addbmm', (1,), ((S, S, S), (S, S, M)),
'broadcast_lhs'),
421 (
'addbmm', (S, M), ((S, S, S), (S, S, M)),
'coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
422 (
'addbmm', (1,), ((S, S, S), (S, S, M)),
'broadcast_lhs_coef',
423 (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
424 (
'addbmm', (), ((S, S, S), (S, S, M)),
'scalar_broadcast_lhs'),
425 (
'addbmm', (), ((S, S, S), (S, S, M)),
'scalar_broadcast_lhs_coef', (), (),
lambda x: x,
426 {
'beta': 0.2,
'alpha': 0.6}),
427 (
'baddbmm', (S, S, M), ((S, S, S), (S, S, M)),),
428 (
'baddbmm', (1,), ((S, S, S), (S, S, M)),
'broadcast_lhs'),
429 (
'baddbmm', (S, S, M), ((S, S, S), (S, S, M)),
'coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
430 (
'baddbmm', (1,), ((S, S, S), (S, S, M)),
'broadcast_lhs_coef',
431 (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
432 (
'baddbmm', (), ((S, S, S), (S, S, M)),
'scalar_broadcast_lhs'),
433 (
'baddbmm', (), ((S, S, S), (S, S, M)),
'scalar_broadcast_lhs_coef', (), (),
lambda x: x,
434 {
'beta': 0.2,
'alpha': 0.6}),
435 (
'addmv', (S,), ((S, M), (M,)),),
436 (
'addmv', (1,), ((S, M), (M,)),
'broadcast_lhs'),
437 (
'addmv', (S,), ((S, M), (M,)),
'coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
438 (
'addmv', (1,), ((S, M), (M,)),
'broadcast_lhs_coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
439 (
'addmv', (), ((S, M), (M,)),
'scalar_broadcast_lhs'),
440 (
'addmv', (), ((S, M), (M,)),
'scalar_broadcast_lhs_coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
441 (
'addr', (S, M), ((S,), (M,)),),
442 (
'addr', (), ((S,), (M,)),
'broadcast_lhs'),
443 (
'addr', (S, M), ((S,), (M,)),
'coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
444 (
'addr', (), ((S,), (M,)),
'broadcast_lhs_coef', (), (),
lambda x: x, {
'beta': 0.2,
'alpha': 0.6}),
445 (
'dot', (L,), ((L,),),),
446 (
'mm', (S, M), ((M, S),)),
447 (
'bmm', (M, S, M), ((M, M, S),)),
448 (
'mv', (S, M), ((M,),)),
449 (
'ger', (S,), ((M,),)),
450 (
'matmul', (L,), ((L,),),),
451 (
'matmul', (S, M), ((M,),),
"2d_1d"),
452 (
'matmul', (M, ), ((M, S),),
"1d_2d"),
453 (
'matmul', (S, M), ((M, S),),
"2d_2d"),
454 (
'matmul', (S, S, M, M), ((S, S, M, S),),
"4d_4d"),
455 (
'matmul', (S, S, M, M), ((M,),),
"4d_1d"),
456 (
'matmul', (M,), ((S, S, M, S),),
"1d_4d"),
457 (
'matrix_power', (S, S), [2],
"n=2"),
458 (
'matrix_power', (S, S, S), [3],
"n=3"),
459 (
'matrix_power', (S, S, S), [1],
"n=1"),
460 (
'matrix_power', (S, S, S), [0],
"n=0"),
461 (
'matrix_power',
lambda: random_fullrank_matrix_distinct_singular_value(S), [-1],
"n=-1",
462 NO_ARGS, [skipIfNoLapack]),
463 (
'matrix_power',
lambda: random_fullrank_matrix_distinct_singular_value(S), [-3],
"n=-3",
464 NO_ARGS, [skipIfNoLapack]),
465 (
'matrix_power',
lambda: random_fullrank_matrix_distinct_singular_value(S, S), [-2],
"n=-2",
466 NO_ARGS, [skipIfNoLapack]),
467 (
'mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1],
"p=1"),
468 (
'mvlgamma', torch.empty(S,).uniform_(1, 2), [2],
"p=2"),
469 (
'mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3],
"p=3"),
470 (
'mvlgamma', torch.empty(S, S).uniform_(2.5, 5), [5],
"p=5"),
471 (
'addcmul', (S, S), ((S, S), (S, S))),
472 (
'addcmul', (S, S), ((S, 1), (1, S)),
'broadcast_rhs'),
473 (
'addcmul', (1,), ((S, S, 1), (1, S)),
'broadcast_all'),
474 (
'addcmul', (S, S), ((S, S), (S, S)),
'scale', (), (),
lambda x: x, {
'value': 0.5}),
475 (
'addcmul', (S, S), ((S, 1), (1, S)),
'scale_broadcast_rhs', (), (),
lambda x: x, {
'value': 0.5}),
476 (
'addcmul', (1,), ((S, S, 1), (1, S)),
'scale_broadcast_all', (), (),
lambda x: x, {
'value': 0.5}),
477 (
'addcmul', (), ((), ()),
'scalar'),
478 (
'addcmul', (S, S), ((), ()),
'scalar_broadcast_rhs'),
479 (
'addcmul', (), ((S, S, 1), (1, S)),
'scalar_broadcast_lhs'),
480 (
'addcmul', (), ((), ()),
'scalar_scale', (), (),
lambda x: x, {
'value': 0.5}),
481 (
'addcmul', (S, S), ((), ()),
'scalar_scale_broadcast_rhs', (), (),
lambda x: x, {
'value': 0.5}),
482 (
'addcmul', (), ((S, S, 1), (1, S)),
'scalar_scale_broadcast_lhs', (), (),
lambda x: x, {
'value': 0.5}),
483 (
'addcdiv', (S, S), ((S, S), (S, S))),
484 (
'addcdiv', (S, S), ((S, 1), (1, S)),
'broadcast_rhs'),
485 (
'addcdiv', (1,), ((S, S, 1), (1, S)),
'broadcast_all'),
486 (
'addcdiv', (S, S), ((S, S), (S, S)),
'scale', (), (),
lambda x: x, {
'value': 0.5}),
487 (
'addcdiv', (S, S), ((S, 1), (1, S)),
'scale_broadcast_rhs', (), (),
lambda x: x, {
'value': 0.5}),
488 (
'addcdiv', (1,), ((S, S, 1), (1, S)),
'scale_broadcast_all', (), (),
lambda x: x, {
'value': 0.5}),
489 (
'addcdiv', (), ((), ()),
'scalar'),
490 (
'addcdiv', (S, S), ((), ()),
'scalar_broadcast_rhs'),
491 (
'addcdiv', (), ((S, S, 1), (1, S)),
'scalar_broadcast_lhs'),
492 (
'addcdiv', (), ((), ()),
'scalar_scale', (), (),
lambda x: x, {
'value': 0.5}),
493 (
'addcdiv', (S, S), ((), ()),
'scalar_scale_broadcast_rhs', (), (),
lambda x: x, {
'value': 0.5}),
494 (
'addcdiv', (), ((S, S, 1), (1, S)),
'scalar_scale_broadcast_lhs', (), (),
lambda x: x, {
'value': 0.5}),
495 (
'zero_', (S, S, S), NO_ARGS),
496 (
'zero_', (), NO_ARGS,
'scalar'),
497 (
'logsumexp', (S, S), (1,)),
498 (
'logsumexp', (), (0,),
'scalar'),
499 (
'norm', (S, S), (),
'default'),
500 (
'norm', (S, S), (2,),
'2'),
501 (
'norm', (S, S), (0,),
'0'),
502 (
'norm', (S, S), (0.5,),
'0_5'),
503 (
'norm', (S, S), (1,),
'1'),
504 (
'norm', (S, S), (3,),
'3'),
505 (
'norm', (S, S), (inf,),
'inf'),
506 (
'norm', (S, S), (-inf,),
'-inf'),
507 (
'norm', (S, S), (
'fro',),
'fro_default'),
508 (
'norm', (S, S), (
'fro', [0, 1],),
'fro'),
509 (
'norm', (S, S), (
'nuc',),
'nuc', NO_ARGS, [skipIfNoLapack]),
510 (
'norm', (S, S), (-1,),
'neg_1'),
511 (
'norm', (S, S), (-2,),
'neg_2'),
512 (
'norm', (S, S), (-0.5,),
'neg_0_5'),
513 (
'norm', (S, S), (-1.5,),
'neg_1_5'),
514 (
'norm', (S, S), (-2, 1,),
'neg_2_2_dim', [1]),
515 (
'norm', (S, S), (-1, 1,),
'neg_1_2_dim', [1]),
516 (
'norm', (S, S), (0, 1,),
'0_2_dim', [1]),
517 (
'norm', (S, S), (1, 1,),
'1_2_dim', [1]),
518 (
'norm', (S, S), (2, 1,),
'2_2_dim', [1]),
519 (
'norm', (S, S), (3, 1,),
'3_2_dim', [1]),
520 (
'norm', (S, S), (inf, 1,),
'inf_2_dim'),
521 (
'norm', torch.rand(S, S, S) + 5e-2, (1.5,),
'1_5_default'),
522 (
'norm', (S, S, S), (2, 1),
'2_dim', [1]),
523 (
'norm', (S, S, S), (3, 1),
'3_dim', [1]),
524 (
'norm', torch.rand(S, S, S) + 5e-2, (1.5, 1),
'1_5_dim', [1]),
525 (
'norm', (S, S, S), (2, 1,
True),
'keepdim_2_dim', [1]),
526 (
'norm', (S, S, S), (3, 1,
True),
'keepdim_3_dim', [1]),
527 (
'norm', torch.rand(S, S, S) + 5e-2, (1.5, 1,
True),
'keepdim_1_5_dim', [1]),
528 (
'norm', (), (2, 0),
'2_dim_scalar', [1]),
529 (
'norm', (), (3, 0),
'3_dim_scalar', [1]),
530 (
'norm', (), (2, 0,
True),
'keepdim_2_dim_scalar', [1]),
531 (
'norm', (), (3, 0,
True),
'keepdim_3_dim_scalar', [1]),
532 (
'clone', (S, M, S), NO_ARGS),
533 (
'clone', (), NO_ARGS,
'scalar'),
534 (
'dist', (S, S, S), ((S, S, S),)),
535 (
'dist', (S, S, S), ((S,),),
'broadcast_rhs'),
536 (
'dist', (S,), ((S, S, S),),
'broadcast_lhs'),
537 (
'dist', (S, 1, S), ((S, S),),
'broadcast_all'),
538 (
'dist', (), ((),),
'scalar'),
539 (
'dist', (S, S, S), ((),),
'scalar_broadcast_rhs'),
540 (
'dist', (), ((S, S, S),),
'scalar_broadcast_lhs'),
541 (
'dist', (S, S, S), ((S, S, S), 4),
'4'),
542 (
'dist', (S, S, S), ((S,), 4),
'4_broadcast_rhs'),
543 (
'dist', (S,), ((S, S, S), 4),
'4_broadcast_lhs'),
544 (
'dist', (S, 1, S), ((S, S), 4),
'4_broadcast_all'),
545 (
'dist', (), ((), 4),
'scalar_4'),
546 (
'dist', (S, S, S), ((), 4),
'scalar_4_broadcast_rhs'),
547 (
'dist', (), ((S, S, S), 4),
'scalar_4_broadcast_lhs'),
548 (
'diag', (M, M), NO_ARGS,
'2d'),
549 (
'diag', (3, 5), NO_ARGS,
'2d_wide'),
550 (
'diag', (3, 5), (2,),
'2d_wide_pos'),
551 (
'diag', (3, 5), (-2,),
'2d_wide_neg'),
552 (
'diag', (5, 3), NO_ARGS,
'2d_tall'),
553 (
'diag', (5, 3), (2,),
'2d_tall_pos'),
554 (
'diag', (5, 3), (-2,),
'2d_tall_neg'),
555 (
'diag', (M,), NO_ARGS,
'1d'),
556 (
'diag', (M, M), (1,),
'2d_1'),
557 (
'diag', (M, M), (2,),
'2d_2'),
558 (
'diag_embed', (S, S), NO_ARGS),
559 (
'diagonal', (M, M), NO_ARGS,
'2d'),
560 (
'diagonal', (3, 5), NO_ARGS,
'2d_wide'),
561 (
'diagonal', (3, 5), (2,),
'2d_wide_pos'),
562 (
'diagonal', (3, 5), (-2,),
'2d_wide_neg'),
563 (
'diagonal', (5, 3), NO_ARGS,
'2d_tall'),
564 (
'diagonal', (5, 3), (2,),
'2d_tall_pos'),
565 (
'diagonal', (5, 3), (-2,),
'2d_tall_neg'),
566 (
'diagonal', (M, M), (1,),
'2d_1'),
567 (
'diagonal', (M, M), (2,),
'2d_2'),
568 (
'diagonal', (M, M, M), (1, 1, 2),
'3d_1'),
569 (
'diagonal', (M, M, M), (2, 0, 1),
'3d_2'),
570 (
'diagonal', (M, M, M), (-2, 0, 1),
'3d_3'),
571 (
'tril', (M, M), NO_ARGS),
572 (
'tril', (M, M), (2,),
'idx'),
573 (
'tril', (S, M, M), NO_ARGS,
'batched'),
574 (
'tril', (S, M, M), (2,),
'batched_idx'),
575 (
'tril', (3, 3, S, S), NO_ARGS,
'more_batched'),
576 (
'triu', (M, M), NO_ARGS),
577 (
'triu', (M, M), (2,),
'idx'),
578 (
'triu', (S, M, M), NO_ARGS,
'batched'),
579 (
'triu', (S, M, M), (2,),
'batched_idx'),
580 (
'triu', (3, 3, S, S), NO_ARGS,
'more_batched'),
581 (
'trace', (M, M), NO_ARGS),
582 (
'cross', (S, 3), ((S, 3),)),
583 (
'cross', (S, 3, S), ((S, 3, S), 1),
'dim'),
584 (
'index_select', (S, S, S), (0, index_variable(2, S)),
'dim', [0]),
585 (
'index_select', (), (0,
torch.tensor([0], dtype=torch.int64)),
'scalar_mixed_dim', [0]),
586 (
'index_select', (), (0,
torch.tensor(0, dtype=torch.int64)),
'scalar_dim', [0]),
587 (
'index_add', (S, S), (0, index_variable(2, S), (2, S)),
'dim', [0]),
588 (
'index_add', (), (0,
torch.tensor([0], dtype=torch.int64), (1,)),
'scalar_input_dim', [0]),
589 (
'index_add', (), (0,
torch.tensor(0, dtype=torch.int64), ()),
'scalar_all_dim', [0]),
590 (
'index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)),
'dim', [0]),
591 (
'index_copy', (), (0,
torch.tensor([0], dtype=torch.int64), (1,)),
'scalar_input_dim', [0]),
592 (
'index_copy', (), (0,
torch.tensor(0, dtype=torch.int64), ()),
'scalar_all_dim', [0]),
593 (
'index_fill', (S, S), (0, index_variable(2, S), 2),
'dim', [0]),
594 (
'index_fill', (S, S), (0, index_variable(2, S), ()),
'variable_dim', [0]),
595 (
'index_fill', (S, S), (0,
torch.tensor(0, dtype=torch.int64), 2),
'scalar_index_dim', [0]),
596 (
'index_fill', (), (0,
torch.tensor([0], dtype=torch.int64), 2),
'scalar_input_dim', [0]),
597 (
'index_fill', (), (0,
torch.tensor(0, dtype=torch.int64), 2),
'scalar_both_dim', [0]),
598 (
'inverse',
lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
'', NO_ARGS, [skipIfNoLapack]),
599 (
'inverse',
lambda: random_fullrank_matrix_distinct_singular_value(S, 2, 3),
600 NO_ARGS,
'batched', NO_ARGS, [skipIfNoLapack]),
601 (
'det', (S, S), NO_ARGS,
'', NO_ARGS, [skipIfNoLapack]),
602 (
'det', (1, 1), NO_ARGS,
'1x1', NO_ARGS, [skipIfNoLapack]),
603 (
'det',
lambda: random_symmetric_matrix(S), NO_ARGS,
'symmetric', NO_ARGS, [skipIfNoLapack]),
604 (
'det',
lambda: random_symmetric_psd_matrix(S), NO_ARGS,
'symmetric_psd', NO_ARGS, [skipIfNoLapack]),
605 (
'det',
lambda: random_symmetric_pd_matrix(S), NO_ARGS,
'symmetric_pd', NO_ARGS, [skipIfNoLapack]),
606 (
'det',
lambda: random_square_matrix_of_rank(S, S - 2), NO_ARGS,
'dim2_null', NO_ARGS, [skipIfNoLapack]),
607 (
'det',
lambda: random_square_matrix_of_rank(S, 1), NO_ARGS,
'rank1', NO_ARGS, [skipIfNoLapack]),
608 (
'det',
lambda: random_square_matrix_of_rank(S, 2), NO_ARGS,
'rank2', NO_ARGS, [skipIfNoLapack]),
609 (
'det',
lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
610 'distinct_singular_values', NO_ARGS, [skipIfNoLapack]),
616 (
'logdet',
lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS,
'', NO_ARGS, [skipIfNoLapack]),
617 (
'logdet',
lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS,
'1x1', NO_ARGS, [skipIfNoLapack]),
618 (
'logdet',
lambda: make_nonzero_det(random_symmetric_matrix(S), 1), NO_ARGS,
619 'symmetric', NO_ARGS, [skipIfNoLapack]),
620 (
'logdet',
lambda: make_nonzero_det(random_symmetric_pd_matrix(S), 1), NO_ARGS,
621 'symmetric_pd', NO_ARGS, [skipIfNoLapack]),
622 (
'logdet',
lambda: make_nonzero_det(random_fullrank_matrix_distinct_singular_value(S), 1, 0), NO_ARGS,
623 'distinct_singular_values', NO_ARGS, [skipIfNoLapack]),
624 (
'slogdet',
lambda: make_nonzero_det(torch.randn(1, 1), 1), NO_ARGS,
625 '1x1_pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
626 (
'slogdet',
lambda: make_nonzero_det(torch.randn(1, 1), -1), NO_ARGS,
627 '1x1_neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
628 (
'slogdet',
lambda: make_nonzero_det(torch.randn(S, S), 1), NO_ARGS,
629 'pos_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
630 (
'slogdet',
lambda: make_nonzero_det(torch.randn(S, S), -1), NO_ARGS,
631 'neg_det', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
632 (
'slogdet',
lambda: make_nonzero_det(random_symmetric_matrix(S)), NO_ARGS,
633 'symmetric', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
634 (
'slogdet',
lambda: random_symmetric_pd_matrix(S), NO_ARGS,
635 'symmetric_pd', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
636 (
'slogdet',
lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
637 'distinct_singular_values', NO_ARGS, [skipIfNoLapack], itemgetter(1)),
638 (
'symeig',
lambda: random_symmetric_matrix(S), (
True,
False),
'lower', NO_ARGS, [skipIfNoLapack]),
639 (
'symeig',
lambda: random_symmetric_matrix(S), (
True,
True),
'upper', NO_ARGS, [skipIfNoLapack]),
640 (
'symeig',
lambda: random_symmetric_matrix(M), (
True,
True),
'large', NO_ARGS, [skipIfNoLapack]),
641 (
'svd',
lambda: random_fullrank_matrix_distinct_singular_value(S), NO_ARGS,
'', NO_ARGS, [skipIfNoLapack]),
642 (
'svd',
lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], NO_ARGS,
643 'wide', NO_ARGS, [skipIfNoLapack]),
644 (
'svd',
lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], NO_ARGS,
645 'tall', NO_ARGS, [skipIfNoLapack]),
646 (
'svd',
lambda: random_fullrank_matrix_distinct_singular_value(S)[:(S - 2)], (
False,),
647 'wide_all', NO_ARGS, [skipIfNoLapack],
lambda usv: (usv[0], usv[1], usv[2][:, :(S - 2)])),
648 (
'svd',
lambda: random_fullrank_matrix_distinct_singular_value(S)[:, :(S - 2)], (
False,),
649 'tall_all', NO_ARGS, [skipIfNoLapack],
lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
650 (
'svd',
lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
651 'large', NO_ARGS, [skipIfNoLapack]),
652 (
'solve', (S, S), (random_fullrank_matrix_distinct_singular_value(
653 S, silent=
True),),
'', NO_ARGS, [skipIfNoLapack]),
654 (
'solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=
True),),
655 'batched', NO_ARGS, [skipIfNoLapack]),
656 (
'solve', (2, 3, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 3, silent=
True),),
657 'batched_dims', NO_ARGS, [skipIfNoLapack]),
658 (
'solve', (2, 2, S, S), (random_fullrank_matrix_distinct_singular_value(S, 1, silent=
True),),
659 'batched_broadcast_A', NO_ARGS, [skipIfNoLapack]),
660 (
'solve', (1, S, S), (random_fullrank_matrix_distinct_singular_value(S, 2, 2, silent=
True),),
661 'batched_broadcast_b', NO_ARGS, [skipIfNoLapack]),
662 (
'fill_', (S, S, S), (1,),
'number'),
663 (
'fill_', (), (1,),
'number_scalar'),
664 (
'fill_', (S, S, S), ((),),
'variable'),
665 (
'eq_', (S, S, S), ((S, S, S),)),
666 (
'eq_', (S, S, S), ((1,),),
'broadcast_rhs'),
667 (
'eq_', (), ((),),
'scalar'),
668 (
'eq_', (S, S, S), ((),),
'scalar_broadcast_rhs'),
669 (
'ne_', (S, S, S), ((S, S, S),)),
670 (
'ne_', (S, S, S), ((1,),),
'broadcast_rhs'),
671 (
'ne_', (), ((),),
'scalar'),
672 (
'ne_', (S, S, S), ((),),
'scalar_broadcast_rhs'),
673 (
'gt_', (S, S, S), ((S, S, S),)),
674 (
'gt_', (S, S, S), ((1,),),
'broadcast_rhs'),
675 (
'gt_', (), ((),),
'scalar'),
676 (
'gt_', (S, S, S), ((),),
'scalar_broadcast_rhs'),
677 (
'ge_', (S, S, S), ((S, S, S),)),
678 (
'ge_', (S, S, S), ((1,),),
'broadcast_rhs'),
679 (
'ge_', (), ((),),
'scalar'),
680 (
'ge_', (S, S, S), ((),),
'scalar_broadcast_rhs'),
681 (
'lt_', (S, S, S), ((S, S, S),)),
682 (
'lt_', (S, S, S), ((1,),),
'broadcast_rhs'),
683 (
'lt_', (), ((),),
'scalar'),
684 (
'lt_', (S, S, S), ((),),
'scalar_broadcast_rhs'),
685 (
'le_', (S, S, S), ((S, S, S),)),
686 (
'le_', (S, S, S), ((1,),),
'broadcast_rhs'),
687 (
'le_', (), ((),),
'scalar'),
688 (
'le_', (S, S, S), ((),),
'scalar_broadcast_rhs'),
689 (
'eq_', (S, S, S), (0,),
'pyscalar'),
690 (
'ne_', (S, S, S), (0,),
'pyscalar'),
691 (
'gt_', (S, S, S), (0,),
'pyscalar'),
692 (
'ge_', (S, S, S), (0,),
'pyscalar'),
693 (
'le_', (S, S, S), (0,),
'pyscalar'),
694 (
'lt_', (), (0,),
'pyscalar'),
695 (
'eq_', (), (0,),
'pyscalar_scalar'),
696 (
'ne_', (), (0,),
'pyscalar_scalar'),
697 (
'gt_', (), (0,),
'pyscalar_scalar'),
698 (
'ge_', (), (0,),
'pyscalar_scalar'),
699 (
'lt_', (), (0,),
'pyscalar_scalar'),
700 (
'le_', (), (0,),
'pyscalar_scalar'),
701 (
'permute', (1, 2, 3, 4), (0, 2, 3, 1)),
702 (
'permute', (1, 2, 3, 4), (0, -2, -1, 1),
'neg_dim'),
703 (
'permute', (), (dont_convert(()),),
'scalar'),
704 (
'select', (S, S, S), (1, 2),
'dim', [0]),
705 (
'select', (S, S, S), (1, -1),
'wrap_dim', [0]),
706 (
'select', (S,), (0, 2),
'1d'),
707 (
'narrow', (S, S, S), (1, 2, 2),
'dim', [0]),
708 (
'narrow', (S, S, S), (1, 0, 0),
'empty_dim', [0]),
709 (
'squeeze', (S, 1, S, 1), NO_ARGS),
710 (
'squeeze', (1, 1, 1, 1), NO_ARGS,
'input_sizes_are_ones'),
711 (
'squeeze', (S, 1, S, 1), (1,),
'1_dim', [0]),
712 (
'squeeze', (S, 1, S, 1), (2,),
'not_1_dim', [0]),
713 (
'squeeze', (), (0,),
'scalar', [0]),
714 (
'unsqueeze', (S, S, S), (0,),
'first', [0]),
715 (
'unsqueeze', (S, S, S), (1,),
'middle', [0]),
716 (
'unsqueeze', (S, S, S), (3,),
'last', [0]),
717 (
'unsqueeze', (), (0,),
'scalar', [0]),
718 (
'chunk', (S, S, S), (2,)),
719 (
'chunk', (S, S, S), (S, 1),
'dim', [1]),
720 (
'split', (S, S, S), (2,)),
721 (
'split', (S, S, S), (S, 1),
'dim', [1]),
722 (
'split', (S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],),
'size_list'),
723 (
'split', (S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2),
'size_list_dim', [1]),
724 (
'gather', (M, S), (0, gather_variable((S, S), 1, M,
True)),
'dim0', [0]),
725 (
'gather', (M, S), (1, gather_variable((M, S // 2), 0, S,
True)),
'dim1', [0]),
726 (
'gather', (), (0,
torch.tensor([0], dtype=torch.int64)),
'scalar_input', [0]),
727 (
'gather', (S,), (0,
torch.tensor(0, dtype=torch.int64)),
'scalar_index', [0]),
728 (
'gather', (), (0,
torch.tensor(0, dtype=torch.int64)),
'scalar_both', [0]),
729 (
'scatter', (M, S), (0, gather_variable((S, S), 1, M), (S, S)),
'dim0', [0]),
730 (
'scatter', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)),
'dim1', [0]),
731 (
'scatter', (), (0,
torch.tensor(0, dtype=torch.int64), ()),
'scalartensor_all_dim0', [0]),
732 (
'scatter', (), (0,
torch.tensor(0, dtype=torch.int64), 2.5),
'scalar_all_dim0', [0]),
733 (
'scatter_add', (M, S), (0, gather_variable((S, S), 1, M), (S, S)),
'dim0', [0]),
734 (
'scatter_add', (M, S), (1, gather_variable((M, S // 2), 0, S), (M, S // 2)),
'dim1', [0]),
735 (
'scatter_add', (), (0,
torch.tensor(0, dtype=torch.int64), ()),
'scalar_all_dim0', [0]),
736 (
'masked_select', (M, M), (mask_not_all_zeros((M, M)),)),
737 (
'masked_select', (M, M), (mask_not_all_zeros((M,)),),
'broadcast_rhs'),
738 (
'masked_select', (M,), (mask_not_all_zeros((M, M)),),
'broadcast_lhs'),
739 (
'masked_select', (M, 1, M), (mask_not_all_zeros((M, M)),),
741 (
'masked_select', (), (
torch.tensor(1, dtype=torch.uint8),),
'scalar'),
742 (
'masked_select', (M, M), (
torch.tensor(1, dtype=torch.uint8),),
'scalar_broadcast_rhs'),
743 (
'masked_select', (), (mask_not_all_zeros((M, M)),),
'scalar_broadcast_lhs'),
744 (
'masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), 10)),
745 (
'masked_fill', (M, M), (torch.ByteTensor(M, M).bernoulli_(), ()),
'tensor'),
746 (
'masked_fill', (M,), (torch.ByteTensor(M, M).bernoulli_(), 10),
'broadcast_lhs'),
747 (
'masked_fill', (M, M), (torch.ByteTensor(M,).bernoulli_(), 10),
'broadcast_rhs'),
748 (
'masked_fill', (), (
torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10),
'scalar'),
749 (
'masked_fill', (), (
torch.tensor(0, dtype=torch.uint8).bernoulli_(), ()),
751 (
'masked_fill', (M, M), (
torch.tensor(0, dtype=torch.uint8).bernoulli_(), 10),
752 'scalar_broadcast_rhs'),
753 (
'masked_scatter', (M, M), (torch.ByteTensor(M, M).bernoulli_(), (M, M))),
754 (
'masked_scatter', (M,), (torch.ByteTensor(M, M).bernoulli_(), (M, M)),
756 (
'masked_scatter', (M, M), (torch.ByteTensor(M,).bernoulli_(), (M, M)),
758 (
'masked_scatter', (M, M), (bernoulli_scalar(), (M, M)),
'scalar'),
759 (
'masked_scatter', (M, M), (bernoulli_scalar(), (M, M)),
760 'scalar_broadcast_rhs'),
761 (
'resize_', (S, S, S), (torch.Size([S * S, S])),
'fewer_dims'),
762 (
'resize_', (), (dont_convert(()),),
'scalar'),
763 (
'resize_', (), (torch.Size([1, 1, 1])),
'scalar_to_dims'),
764 (
'resize_as_', (), (non_differentiable(
torch.tensor(5.)),),
'scalar'),
765 (
'resize_as_', (), (non_differentiable(torch.randn((1, 1, 1))),),
'scalar_to_dims'),
766 (
'resize_as_', (S, S, S), (non_differentiable(torch.randn(S * S, S)),)),
767 (
'sort', (S, M, S), NO_ARGS),
768 (
'sort', (S, M, S), (1,),
'dim'),
769 (
'sort', (S, M, S), (1,
True),
'dim_desc'),
770 (
'sort', (), NO_ARGS,
'scalar'),
771 (
'sort', (), (0,),
'dim_scalar'),
772 (
'sort', (), (0,
True),
'dim_desc_scalar'),
773 (
'topk', (S, M, S), (3,)),
774 (
'topk', (S, M, S), (3, 1),
'dim', [1]),
775 (
'topk', (S, M, S), (3, 1,
True),
'dim_desc', [1]),
776 (
'topk', (S, M, S), (3, 1,
True,
True),
'dim_desc_sort', [1]),
777 (
'topk', (), (1,),
'scalar'),
778 (
'topk', (), (1, 0),
'dim_scalar', [1]),
779 (
'topk', (), (1, 0,
True),
'dim_desc_scalar', [1]),
780 (
'topk', (), (1, 0,
True,
True),
'dim_desc_sort_scalar', [1]),
781 (
'take', (S, S, S), (torch.LongTensor([[-3, 2], [20, 2]]),)),
782 (
'take', (S, S, S), (
torch.tensor(0, dtype=torch.int64),),
'scalar_index'),
783 (
'take', (), (torch.LongTensor([0]),),
'scalar_data'),
784 (
'take', (), (
torch.tensor(0, dtype=torch.int64),),
'scalar_both'),
785 (
'where', (M, M), (mask_not_all_zeros((M, M)), (M, M))),
786 (
'where', (M, 1, M), (mask_not_all_zeros((M, M)), (M, M, 1)),
'broadcast_all'),
787 (
'where', (), (bernoulli_scalar(), ()),
'scalar'),
788 (
'where', (M, 1, M), (bernoulli_scalar(), (M, M, 1)),
'scalar_broadcast_mask'),
789 (
'where', (), (mask_not_all_zeros((M, M)), ()),
'scalar_broadcast_non_mask'),
790 (
'__getitem__', torch.randn(S, S, S), (dont_convert([1, 2]),)),
791 (
'__getitem__', torch.randn(S, S, S), (slice(0, 3),),
'slice'),
792 (
'__getitem__', torch.randn(S, S, S), (dont_convert([slice(0, 3), 1]),),
'slice_index'),
793 (
'__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], [0, 0, 2]]),),
'adv_index'),
794 (
'__getitem__', torch.randn(S, S, S), (dont_convert([[0, 0, 3], [1, 1, 3], [0, 0, 2]]),),
'adv_index_dup'),
795 (
'__getitem__', torch.randn(S, S, S), (dont_convert([slice(
None), slice(
None), [0, 3]]),),
'adv_index_end'),
796 (
'__getitem__', torch.randn(S, S, S), (dont_convert([slice(
None), [0, 3], slice(
None)]),),
'adv_index_mid'),
797 (
'__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], slice(
None), slice(
None)]),),
'adv_index_beg'),
798 (
'__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], [1, 2], slice(
None)]),),
'adv_index_comb'),
799 (
'__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], ]),),
'adv_index_sub'),
800 (
'__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], slice(
None)]),),
'adv_index_sub_2'),
801 (
'__getitem__', torch.randn(S, S, S), (dont_convert([[0, 3], Ellipsis]),),
'adv_index_sub_3'),
802 (
'__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3],
803 torch.LongTensor([0, 0, 2])]),),
'adv_index_var'),
# Build the (args, kwargs) pair used to invoke a method under test.
#
# A non-tuple `call_args` is first wrapped in a 1-tuple (809-810). Every
# positional argument and every `call_kwargs` value is then converted by
# `map_arg`:
#   * torch.Size / dont_convert      -> branch body (817) dropped; presumably
#                                       returned unchanged - confirm
#   * ()                             -> 0-dim double randn with
#                                       requires_grad=requires_grad (819-820)
#   * tuple not starting with Tensor -> double randn of that shape, wrapped in
#                                       Variable(requires_grad=requires_grad)
#   * non_differentiable             -> its `.tensor`, passed through
#                                       maybe_non_contig
#   * torch.Tensor                   -> detached clone; requires_grad is set
#                                       only when the clone is floating point
#   * presumably callable (835 lost) -> map_arg applied to `arg()`
# maybe_non_contig returns the tensor unchanged unless `non_contiguous` is
# true, in which case it returns make_non_contiguous(tensor).
# Returns (args_out, kwargs_out); kwargs_out is {} when call_kwargs is falsy.
#
# NOTE(review): this span is extraction residue. Original line numbers are
# fused into the text, and lines 811-812, 815, 817, 821, 830-831, 834-835 and
# 837-838 are missing. Line 812 is presumably the `def map_arg(arg):` header,
# since map_arg is called at 836 and 839-840. Restore from version control
# rather than editing this text.
808 def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None):
809 if not isinstance(call_args, tuple):
810 call_args = (call_args,)
813 def maybe_non_contig(tensor):
814 return tensor
if not non_contiguous
else make_non_contiguous(tensor)
816 if isinstance(arg, torch.Size)
or isinstance(arg, dont_convert):
818 elif isinstance(arg, tuple)
and len(arg) == 0:
819 var = torch.randn((), dtype=torch.double)
820 var.requires_grad = requires_grad
822 elif isinstance(arg, tuple)
and not isinstance(arg[0], torch.Tensor):
823 return Variable(maybe_non_contig(torch.randn(*arg, dtype=torch.double)), requires_grad=requires_grad)
824 elif isinstance(arg, non_differentiable):
# NOTE(review): both outcomes of the isinstance check below return the same
# expression, so the check appears redundant.
825 if isinstance(arg.tensor, torch.Tensor):
826 return maybe_non_contig(arg.tensor)
827 return maybe_non_contig(arg.tensor)
828 elif isinstance(arg, torch.Tensor):
829 if arg.dtype == torch.float:
# NOTE(review): the body of the `if` above (lines 830-831) was dropped;
# its content is unknown.
# Detach and clone so the test input does not alias the caller's tensor.
# Gradients are only requested for floating-point results.
832 v = maybe_non_contig(arg).detach().clone()
833 v.requires_grad = requires_grad
and v.is_floating_point()
836 return map_arg(arg())
839 args_out = tuple(map_arg(arg)
for arg
in call_args)
840 kwargs_out = {k: map_arg(v)
for k, v
in call_kwargs.items()}
if call_kwargs
else {}
841 return args_out, kwargs_out
def _compare_trilu_indices(
        self, row, col, offset=0, dtype=torch.long, device='cpu'):
    """Check ``torch.tril_indices`` and ``torch.triu_indices`` against a reference.

    For non-empty shapes, the reference indices are built on the CPU. They are
    the ``nonzero()`` positions of a ``row x col`` matrix of ones masked with
    ``tril(offset)`` (lower triangle) or ``triu(offset)`` (upper triangle).
    Each reference is transposed to ``2 x N`` and moved to ``device``.

    Args:
        self: test case providing ``assertEqual``.
        row: number of rows passed to the index functions.
        col: number of columns passed to the index functions.
        offset: diagonal offset forwarded to tril/triu and the index functions.
        dtype: dtype of the reference matrix and of the requested indices.
        device: device on which the index functions are evaluated.
    """
    if row == 0 or col == 0:
        # Empty shapes are compared directly against an empty 2 x 0 index
        # tensor for both the lower and the upper triangle.
        self.assertEqual(
            torch.empty(0, 2, dtype=dtype, device=device).transpose(0, 1),
            torch.tril_indices(row, col, offset, dtype=dtype, device=device))

        self.assertEqual(
            torch.empty(0, 2, dtype=dtype, device=device).transpose(0, 1),
            torch.triu_indices(row, col, offset, dtype=dtype, device=device))

    else:
        self.assertEqual(
            torch.ones(row, col, dtype=dtype, device='cpu')
                 .tril(offset).nonzero().transpose(0, 1).to(device),
            torch.tril_indices(row, col, offset, dtype=dtype, device=device))

        # Upper triangle: must use triu/triu_indices (not a second tril check)
        # so that triu_indices is actually exercised for non-empty shapes.
        self.assertEqual(
            torch.ones(row, col, dtype=dtype, device='cpu')
                 .triu(offset).nonzero().transpose(0, 1).to(device),
            torch.triu_indices(row, col, offset, dtype=dtype, device=device))
# Spot-check torch.tril_indices / torch.triu_indices for large shapes on
# `device`.
#
# For each of tril and triu, the reference is the nonzero() positions of a CPU
# `row x col` matrix of ones masked with tril(offset) / triu(offset). It is
# transposed to 2 x N, moved to `device`, and compared via self.assertEqual
# with the corresponding torch.*_indices result. Only the slice [-100:-1] is
# compared on both sides: up to 99 index pairs ending just before the final
# pair, which is excluded. NOTE(review): confirm that excluding the last pair
# is intentional.
#
# NOTE(review): extraction residue. Lines 873-874, 878-879, 882-883 and any
# lines after 886 are missing from this text; their contents are unknown.
869 def _compare_large_trilu_indices(
870 self, row, col, offset=0, dtype=torch.long, device=
'cpu'):
871 l = torch.ones(row, col, dtype=dtype, device=
'cpu').tril(offset) \
872 .nonzero()[-100:-1, :].transpose(0, 1).to(device)
875 r = torch.tril_indices(
876 row, col, offset, dtype=dtype, device=device)[:, -100:-1]
877 self.assertEqual(l, r)
880 l = torch.ones(row, col, dtype=dtype, device=
'cpu').triu(offset) \
881 .nonzero()[-100:-1, :].transpose(0, 1).to(device)
884 r = torch.triu_indices(
885 row, col, offset, dtype=dtype, device=device)[:, -100:-1]
886 self.assertEqual(l, r)
927 (258, 253, 1, torch.float32),
928 (257, 258, 1, torch.float64),
929 (258, 258, 1, torch.short),
930 (3, 513, 1, torch.long),
931 (513, 3, 1, torch.int),
932 (513, 0, 1, torch.double),
934 (1024, 1024, 500, torch.float32),
948 tri_large_tests_args = [
# Extra tril_indices / triu_indices checks on `device` for a 3 x 3 case.
#
# `x` supplies the reference indices via x.tril(0) / x.triu(0)
# .nonzero().transpose(0, 1). Its construction at line 968 is partly dropped;
# the surviving continuation requests a 3 x 3 torch.long, strided tensor on
# `device`. The references are compared with torch.tril_indices /
# torch.triu_indices, first with the default layout and then with an explicit
# layout=torch.strided.
# Lines 982-988 build lambdas that request layout=torch.sparse_coo. The
# surrounding calls (lines 979-981 and 984-986) were dropped; they are
# presumably assertions that these requests raise - confirm.
#
# NOTE(review): extraction residue. Lines 968, 973, 975, 977, 979-981 and
# 984-986 are missing.
967 def run_additional_tri_tests(self, device):
969 3, 3, dtype=torch.long, device=device, layout=torch.strided)
970 l = x.tril(0).nonzero().transpose(0, 1)
971 u = x.triu(0).nonzero().transpose(0, 1)
972 self.assertEqual(l, torch.tril_indices(3, 3, device=device))
974 l, torch.tril_indices(3, 3, device=device, layout=torch.strided))
976 self.assertEqual(u, torch.triu_indices(3, 3, device=device))
978 u, torch.triu_indices(3, 3, device=device, layout=torch.strided))
982 lambda: torch.triu_indices(
983 1, 1, device=device, layout=torch.sparse_coo))
987 lambda: torch.tril_indices(
988 1, 1, device=device, layout=torch.sparse_coo))
# Recursively unpack `args`. The surviving line applies unpack_variables to
# each element of `args` and returns the results as a tuple.
# NOTE(review): lines 992 and 994-997 were dropped by extraction. They are
# presumably a tuple check guarding the return below and a fallthrough that
# returns `args` unchanged - confirm against version control.
991 def unpack_variables(args):
993 return tuple(unpack_variables(elem)
for elem
in args)
998 EXCLUDE_FUNCTIONAL = {
1010 EXCLUDE_GRADCHECK = {
1012 EXCLUDE_GRADGRADCHECK = {
1014 EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
1022 'test_det_symmetric',
1023 'test_det_symmetric_psd',
1024 'test_det_dim2_null',
1031 'test_logdet_symmetric',
1032 'test_slogdet_1x1_neg_det',
1033 'test_slogdet_neg_det',
1034 'test_slogdet_symmetric',
1039 def exclude_tensor_method(name, test_name):
1041 exclude_all_tensor_method_by_test_name = {
1044 'test_clamp_min_scalar',
1045 'test_clamp_max_scalar',
1048 'test_where_broadcast_all',
1049 'test_where_scalar',
1050 'test_where_scalar_broadcast_mask',
1051 'test_where_scalar_broadcast_non_mask',
1054 exclude_outplace_tensor_method = {
1064 if test_name
in exclude_all_tensor_method_by_test_name:
1066 is_magic_method = name[:2] ==
'__' and name[-2:] ==
'__' 1067 is_inplace = name[-1] ==
"_" and not is_magic_method
1068 if not is_inplace
and name
in exclude_outplace_tensor_method: