4 from copy
import deepcopy
5 from itertools
import product
6 from functools
import reduce
7 from operator
import mul
15 from common_utils
import TestCase, to_gpu, freeze_rng_state, is_iterable, \
16 TEST_WITH_ROCM, skipIfRocm
17 from common_cuda
import TEST_CUDA
24 if sys.version_info[:2] == (3, 3):
25 TemporaryFile = tempfile.NamedTemporaryFile
27 TemporaryFile = tempfile.TemporaryFile
32 result = getattr(m,
'reduction',
None)
34 result = _Reduction.legacy_get_string(getattr(m,
'sizeAverage',
None),
True, emit_warning=
False)
35 assert result
is not None 40 result = getattr(m,
'weight',
None)
41 if result
is not None:
43 return getattr(m,
'weights',
None)
48 constructor_args=(10, 8),
50 reference_fn=
lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
54 constructor_args=(10, 8,
False),
57 reference_fn=
lambda i, p: torch.mm(i, p[0].t())
60 module_name=
'Threshold',
61 constructor_args=(2., 1.),
62 input_size=(2, 3, 4, 5),
64 desc=
'threshold_value' 67 module_name=
'Threshold',
68 constructor_args=(2., 10.),
69 input_size=(2, 3, 4, 5),
74 input_size=(2, 3, 4, 5),
79 input_size=(2, 3, 4, 5),
89 constructor_args=(0.1, 0.9),
95 module_name=
'Hardtanh',
97 reference_fn=
lambda i, _: i.clamp(-1, 1),
100 module_name=
'Sigmoid',
101 input_size=(2, 3, 4, 5)
105 input_size=(2, 3, 4, 5)
108 module_name=
'Softmax',
109 constructor_args=(1,),
111 reference_fn=
lambda i, _: torch.exp(i).div(torch.exp(i).sum(1,
True).expand(10, 20)),
114 module_name=
'Softmax2d',
115 input_size=(1, 3, 10, 20),
116 reference_fn=
lambda i, _: torch.exp(i).div(torch.exp(i).sum(1,
False)),
119 module_name=
'LogSoftmax',
120 constructor_args=(1,),
122 reference_fn=
lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1,
True).expand(10, 20)).log_(),
125 module_name=
'LogSoftmax',
126 constructor_args=(1,),
127 input_size=(1, 3, 10, 20),
128 reference_fn=
lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1,
False)).log_(),
133 constructor_args=(2.,),
134 input_size=(3, 2, 5),
135 reference_fn=
lambda x, _: torch.where(x >= 0, x, 2 * (x.exp() - 1)),
139 module_name=
'Hardshrink',
140 constructor_args=(2.,),
141 input_size=(4, 3, 2, 4),
144 module_name=
'LeakyReLU',
145 input_size=(3, 2, 5),
149 module_name=
'LeakyReLU',
150 constructor_args=(0.5,),
151 input_size=(3, 2, 5),
156 module_name=
'LogSigmoid',
157 input_size=(2, 3, 4),
158 reference_fn=
lambda i, _: i.sigmoid().log(),
161 module_name=
'Softplus',
163 reference_fn=
lambda i, _: torch.log(1 + torch.exp(i)),
166 module_name=
'Softplus',
167 constructor_args=(2,),
169 reference_fn=
lambda i, _: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
173 module_name=
'Softplus',
174 constructor_args=(2, -100),
176 reference_fn=(
lambda i, _: ((i * 2) > -100).type_as(i) * i +
177 ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
178 desc=
'beta_threshold',
181 module_name=
'Softshrink',
182 input_size=(3, 2, 5),
185 module_name=
'Softshrink',
186 constructor_args=(1,),
187 input_size=(3, 2, 5),
191 module_name=
'CrossMapLRN2d',
192 constructor_args=(5, 5e-3, 1e-3, 2),
193 input_size=(2, 3, 6, 6),
194 check_gradgrad=
False,
198 input_size=(2, 3, 4),
199 reference_fn=
lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
204 constructor_args=(3,),
205 input_size=(2, 3, 4),
206 desc=
'1d_multiparam',
207 reference_fn=
lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
211 input_size=(2, 3, 4, 5),
213 reference_fn=
lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
217 constructor_args=(3,),
218 input_size=(2, 3, 4, 5),
219 desc=
'2d_multiparam',
220 reference_fn=
lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
224 input_size=(2, 3, 4, 5, 6),
225 reference_fn=
lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
230 constructor_args=(3,),
231 input_size=(2, 3, 4, 5, 6),
232 desc=
'3d_multiparam',
233 reference_fn=
lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
236 module_name=
'Softsign',
237 input_size=(3, 2, 5),
238 reference_fn=
lambda i, _: i.div(1 + torch.abs(i)),
241 module_name=
'Softmin',
242 constructor_args=(1,),
246 module_name=
'Softmin',
247 constructor_args=(1,),
248 input_size=(2, 3, 5, 10),
252 module_name=
'Tanhshrink',
253 input_size=(2, 3, 4, 5),
261 def _rand_tensor_non_equal(*size):
262 total = reduce(mul, size, 1)
263 return torch.randperm(total).view(*size).double()
266 def wrap_functional(fn, **kwargs):
267 class FunctionalModule(nn.Module):
268 def forward(self, *args):
269 return fn(*args, **kwargs)
270 return FunctionalModule
273 def poissonnllloss_no_reduce_test():
274 t = torch.randn(10, 10)
276 fullname=
'PoissonNLLLLoss_no_reduce',
277 constructor=wrap_functional(
278 lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction=
'none')),
279 input_fn=
lambda: torch.rand(10, 10),
283 def bceloss_no_reduce_test():
284 t = Variable(torch.randn(15, 10).gt(0).double())
286 fullname=
'BCELoss_no_reduce',
287 constructor=wrap_functional(
288 lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction=
'none')),
289 input_fn=
lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
290 reference_fn=
lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
291 check_gradgrad=
False,
295 def bceloss_no_reduce_scalar_test():
296 t = torch.randn(()).gt(0).double()
298 fullname=
'BCELoss_no_reduce_scalar',
299 constructor=wrap_functional(
300 lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction=
'none')),
301 input_fn=
lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
302 reference_fn=
lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
303 check_gradgrad=
False,
307 def bceloss_weights_no_reduce_test():
308 t = Variable(torch.randn(15, 10).gt(0).double())
309 weights = torch.rand(10)
311 fullname=
'BCELoss_weights_no_reduce',
312 constructor=wrap_functional(
313 lambda i: F.binary_cross_entropy(i, t.type_as(i),
314 weight=weights.type_as(i), reduction=
'none')),
315 input_fn=
lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
316 reference_fn=
lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
317 check_gradgrad=
False,
322 def bceloss_weights_no_reduce_scalar_test():
323 t = torch.randn(()).double()
324 weights = torch.rand(())
326 fullname=
'BCELoss_weights_no_reduce_scalar',
327 constructor=wrap_functional(
328 lambda i: F.binary_cross_entropy(i, t.type_as(i),
329 weight=weights.type_as(i), reduction=
'none')),
330 input_fn=
lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
331 reference_fn=
lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
332 check_gradgrad=
False,
337 def bce_with_logistic_legacy_enum_test():
338 t = Variable(torch.randn(15, 10).gt(0).double())
339 sigmoid = nn.Sigmoid()
341 fullname=
'BCEWithLogitsLoss_legacy_enum',
342 constructor=wrap_functional(
343 lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=
False)),
344 input_fn=
lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
345 reference_fn=
lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
346 check_gradgrad=
False,
351 def bce_with_logistic_no_reduce_test():
352 t = Variable(torch.randn(15, 10).gt(0).double())
353 sigmoid = nn.Sigmoid()
355 fullname=
'BCEWithLogitsLoss_no_reduce',
356 constructor=wrap_functional(
357 lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction=
'none')),
358 input_fn=
lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
359 reference_fn=
lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
360 check_gradgrad=
False,
365 def bce_with_logistic_no_reduce_scalar_test():
366 t = torch.randn(()).gt(0).double()
367 sigmoid = nn.Sigmoid()
369 fullname=
'BCEWithLogitsLoss_no_reduce_scalar',
370 constructor=wrap_functional(
371 lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction=
'none')),
372 input_fn=
lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
373 reference_fn=
lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
374 check_gradgrad=
False,
379 def kldivloss_with_target_no_reduce_test():
380 i = torch.rand(10, 10).log()
382 fullname=
'KLDivLoss_with_target_no_reduce',
383 constructor=wrap_functional(
384 lambda t: F.kl_div(i.type_as(t), t, reduction=
'none')),
385 input_fn=
lambda: torch.rand(10, 10),
386 reference_fn=
lambda t, _:
387 loss_reference_fns[
'KLDivLoss'](i.type_as(t), t, reduction=
'none'),
391 def kldivloss_no_reduce_test():
392 t = torch.randn(10, 10)
394 fullname=
'KLDivLoss_no_reduce',
395 constructor=wrap_functional(
396 lambda i: F.kl_div(i, t.type_as(i), reduction=
'none')),
397 input_fn=
lambda: torch.rand(10, 10).log(),
398 reference_fn=
lambda i, _:
399 loss_reference_fns[
'KLDivLoss'](i, t.type_as(i), reduction=
'none'),
404 def kldivloss_no_reduce_scalar_test():
407 fullname=
'KLDivLoss_no_reduce_scalar',
408 constructor=wrap_functional(
409 lambda i: F.kl_div(i, t.type_as(i), reduction=
'none')),
410 input_fn=
lambda: torch.rand(()).log(),
411 reference_fn=
lambda i, _:
412 loss_reference_fns[
'KLDivLoss'](i, t.type_as(i), reduction=
'none'),
416 def l1loss_no_reduce_test():
417 t = torch.randn(2, 3, 4)
419 fullname=
'L1Loss_no_reduce',
420 constructor=wrap_functional(
421 lambda i: F.l1_loss(i, t.type_as(i), reduction=
'none')),
422 input_fn=
lambda: torch.randn(2, 3, 4),
423 reference_fn=
lambda i, m: (i - t.type_as(i)).abs(),
427 def l1loss_no_reduce_scalar_test():
430 fullname=
'L1Loss_no_reduce_scalar',
431 constructor=wrap_functional(
432 lambda i: F.l1_loss(i, t.type_as(i), reduction=
'none')),
433 input_fn=
lambda: torch.randn(()),
434 reference_fn=
lambda i, m: (i - t.type_as(i)).abs(),
438 def mseloss_no_reduce_test():
439 input_size = (2, 3, 4, 5)
440 target = torch.randn(*input_size)
442 fullname=
'MSELoss_no_reduce',
443 constructor=wrap_functional(
444 lambda i: F.mse_loss(i, target.type_as(i), reduction=
'none')),
445 input_size=input_size,
446 reference_fn=
lambda i, m: (i - target).pow(2),
450 def mseloss_no_reduce_scalar_test():
452 target = torch.randn(input_size)
454 fullname=
'MSELoss_no_reduce_scalar',
455 constructor=wrap_functional(
456 lambda i: F.mse_loss(i, target.type_as(i), reduction=
'none')),
457 input_size=input_size,
458 reference_fn=
lambda i, m: (i - target).pow(2),
462 def nllloss_no_reduce_test():
463 t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
464 kwargs = {
'reduction':
'none'}
466 fullname=
'NLLLoss_no_reduce',
467 constructor=wrap_functional(
468 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
469 input_fn=
lambda: torch.rand(15, 10).log(),
470 reference_fn=
lambda i, _:
471 loss_reference_fns[
'NLLLoss'](i, t.type_as(i).long(), **kwargs),
475 def nllloss_no_reduce_ignore_index_test():
476 t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
477 kwargs = {
'ignore_index': 2,
'reduction':
'none'}
479 fullname=
'NLLLoss_no_reduce_ignore_index',
480 constructor=wrap_functional(
481 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
482 input_fn=
lambda: torch.rand(15, 10).log(),
483 reference_fn=
lambda i, _:
484 loss_reference_fns[
'NLLLoss'](i, t.type_as(i).long(), **kwargs),
488 def nllloss_no_reduce_weights_test():
489 t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
490 weight = torch.rand(10)
493 return {
'weight': weight.type_as(i),
'reduction':
'none'}
496 fullname=
'NLLLoss_no_reduce_weights',
497 constructor=wrap_functional(
498 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
499 input_fn=
lambda: torch.rand(15, 10).add(1e-2).log(),
500 reference_fn=
lambda i, _:
501 loss_reference_fns[
'NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
505 def nllloss_no_reduce_weights_ignore_index_test():
506 t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
507 weight = torch.rand(10)
510 return {
'weight': weight.type_as(i),
'reduction':
'none',
514 fullname=
'NLLLoss_no_reduce_weights_ignore_index',
515 constructor=wrap_functional(
516 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
517 input_fn=
lambda: torch.rand(15, 10).add(1e-2).log(),
518 reference_fn=
lambda i, _:
519 loss_reference_fns[
'NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
523 def nllloss_no_reduce_weights_ignore_index_neg_test():
524 t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
525 weight = torch.rand(10)
528 return {
'weight': weight.type_as(i),
'reduction':
'none',
532 fullname=
'NLLLoss_no_reduce_weights_ignore_index_neg',
533 constructor=wrap_functional(
534 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
535 input=torch.rand(15, 10).add(1e-2).log(),
536 reference_fn=
lambda i, _:
537 loss_reference_fns[
'NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
541 def nllloss2d_no_reduce_test():
542 t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
543 kwargs = {
'reduction':
'none'}
545 fullname=
'NLLLoss2d_no_reduce',
546 constructor=wrap_functional(
547 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
548 input_fn=
lambda: torch.rand(2, 3, 5, 5).log(),
549 reference_fn=
lambda i, _:
550 loss_reference_fns[
'NLLLossNd'](i, t.type_as(i).long(), **kwargs),
554 def nllloss2d_no_reduce_ignore_index_test():
555 t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
556 kwargs = {
'ignore_index': 1,
'reduction':
'none'}
558 fullname=
'NLLLoss2d_no_reduce_ignore_index',
559 constructor=wrap_functional(
560 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
561 input_fn=
lambda: torch.rand(2, 3, 5, 5).log(),
562 reference_fn=
lambda i, _:
563 loss_reference_fns[
'NLLLossNd'](i, t.type_as(i).long(), **kwargs),
567 def nllloss2d_no_reduce_weights_test():
568 t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
569 weight = torch.rand(3)
572 return {
'weight': weight.type_as(i),
'reduction':
'none'}
575 fullname=
'NLLLoss2d_no_reduce_weights',
576 constructor=wrap_functional(
577 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
578 input_fn=
lambda: torch.rand(2, 3, 5, 5).log(),
579 reference_fn=
lambda i, _:
580 loss_reference_fns[
'NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
584 def nlllossNd_no_reduce_test():
585 t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
586 kwargs = {
'reduction':
'none'}
588 fullname=
'NLLLossNd_no_reduce',
589 constructor=wrap_functional(
590 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
591 input_fn=
lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
592 reference_fn=
lambda i, _:
593 loss_reference_fns[
'NLLLossNd'](i, t.type_as(i).long(), **kwargs),
597 def nlllossNd_no_reduce_ignore_index_test():
598 t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
599 kwargs = {
'ignore_index': 1,
'reduction':
'none'}
601 fullname=
'NLLLossNd_no_reduce_ignore_index',
602 constructor=wrap_functional(
603 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
604 input_fn=
lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
605 reference_fn=
lambda i, _:
606 loss_reference_fns[
'NLLLossNd'](i, t.type_as(i).long(), **kwargs),
610 def nlllossNd_no_reduce_weights_test():
611 t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
612 weight = torch.rand(3)
615 return {
'weight': weight.type_as(i),
'reduction':
'none'}
618 fullname=
'NLLLossNd_no_reduce_weights',
619 constructor=wrap_functional(
620 lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
621 input_fn=
lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
622 reference_fn=
lambda i, _:
623 loss_reference_fns[
'NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
627 def smoothl1loss_no_reduce_test():
628 t = torch.randn(2, 3, 4)
630 fullname=
'SmoothL1Loss_no_reduce',
631 constructor=wrap_functional(
632 lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction=
'none')),
633 input_fn=
lambda: torch.randn(2, 3, 4),
634 reference_fn=
lambda i, _:
635 loss_reference_fns[
'SmoothL1Loss'](i, t.type_as(i), reduction=
'none'),
639 def smoothl1loss_no_reduce_scalar_test():
642 fullname=
'SmoothL1Loss_no_reduce_scalar',
643 constructor=wrap_functional(
644 lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction=
'none')),
645 input_fn=
lambda: torch.randn(()),
646 reference_fn=
lambda i, _:
647 loss_reference_fns[
'SmoothL1Loss'](i, t.type_as(i), reduction=
'none'),
651 def multilabelmarginloss_1d_no_reduce_test():
652 t = Variable(torch.rand(10).mul(10).floor().long())
654 fullname=
'MultiLabelMarginLoss_1d_no_reduce',
655 constructor=wrap_functional(
656 lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction=
'none')),
657 input_fn=
lambda: torch.randn(10),
658 reference_fn=
lambda i, _:
659 loss_reference_fns[
'MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction=
'none'),
660 check_sum_reduction=
True,
661 check_gradgrad=
False,
665 def multilabelmarginloss_index_neg_test():
666 t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1))
668 fullname=
'MultiLabelMarginLoss_index_neg',
669 constructor=wrap_functional(
670 lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction=
'none')),
671 input_fn=
lambda: torch.randn(5, 10),
672 reference_fn=
lambda i, _:
673 loss_reference_fns[
'MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction=
'none'),
674 check_sum_reduction=
True,
675 check_gradgrad=
False,
679 def multilabelmarginloss_no_reduce_test():
680 t = Variable(torch.rand(5, 10).mul(10).floor().long())
682 fullname=
'MultiLabelMarginLoss_no_reduce',
683 constructor=wrap_functional(
684 lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction=
'none')),
685 input_fn=
lambda: torch.randn(5, 10),
686 reference_fn=
lambda i, _:
687 loss_reference_fns[
'MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction=
'none'),
688 check_sum_reduction=
True,
689 check_gradgrad=
False,
693 def hingeembeddingloss_no_reduce_test():
694 t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
696 fullname=
'HingeEmbeddingLoss_no_reduce',
697 constructor=wrap_functional(
698 lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction=
'none')),
699 input_fn=
lambda: torch.randn(10),
700 reference_fn=
lambda i, _:
701 loss_reference_fns[
'HingeEmbeddingLoss'](i, t.type_as(i), reduction=
'none'),
702 check_sum_reduction=
True,
706 def hingeembeddingloss_margin_no_reduce_test():
707 t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
709 fullname=
'HingeEmbeddingLoss_margin_no_reduce',
710 constructor=wrap_functional(
711 lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction=
'none')),
712 input_fn=
lambda: torch.randn(10),
713 reference_fn=
lambda i, _:
714 loss_reference_fns[
'HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction=
'none'),
715 check_sum_reduction=
True,
719 def softmarginloss_no_reduce_test():
720 t = torch.randn(5, 5)
722 fullname=
'SoftMarginLoss_no_reduce',
723 constructor=wrap_functional(
724 lambda i: F.soft_margin_loss(i, t.type_as(i), reduction=
'none')),
725 input_fn=
lambda: torch.randn(5, 5),
726 reference_fn=
lambda i, _:
727 loss_reference_fns[
'SoftMarginLoss'](i, t.type_as(i), reduction=
'none'),
731 def multilabelsoftmarginloss_no_reduce_test():
732 t = torch.rand(5, 10).mul(2).floor()
734 fullname=
'MultiLabelSoftMarginLoss_no_reduce',
735 constructor=wrap_functional(
736 lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction=
'none')),
737 input_fn=
lambda: torch.randn(5, 10),
738 reference_fn=
lambda i, m:
739 (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
740 check_gradgrad=
False,
744 def multilabelsoftmarginloss_weights_no_reduce_test():
745 t = torch.rand(5, 10).mul(2).floor()
746 weights = torch.rand(10)
748 fullname=
'MultiLabelSoftMarginLoss_weights_no_reduce',
749 constructor=wrap_functional(
750 lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
751 weight=weights.type_as(i), reduction=
'none')),
752 input_fn=
lambda: torch.randn(5, 10),
753 reference_fn=
lambda i, m:
754 (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
755 check_sum_reduction=
True,
756 check_gradgrad=
False,
760 def multimarginloss_no_reduce_test():
761 t = torch.rand(5).mul(8).floor().long()
763 fullname=
'MultiMarginLoss_no_reduce',
764 constructor=wrap_functional(
765 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction=
'none')),
766 input_fn=
lambda: torch.randn(5, 10),
767 reference_fn=
lambda i, _:
768 loss_reference_fns[
'MultiMarginLoss'](i, t.data.type_as(i).long(), reduction=
'none'),
769 check_sum_reduction=
True,
770 check_gradgrad=
False,
774 def multimarginloss_1d_no_reduce_test():
775 t = torch.rand(1).mul(8).floor().long()
777 fullname=
'MultiMarginLoss_1d_no_reduce',
778 constructor=wrap_functional(
779 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction=
'none')),
780 input_fn=
lambda: torch.randn(10),
781 reference_fn=
lambda i, _:
782 loss_reference_fns[
'MultiMarginLoss'](i, t.data.type_as(i).long(), reduction=
'none'),
783 check_sum_reduction=
True,
784 check_gradgrad=
False,
788 def multimarginloss_p_no_reduce_test():
789 t = torch.rand(5).mul(8).floor().long()
791 fullname=
'MultiMarginLoss_p_no_reduce',
792 constructor=wrap_functional(
793 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction=
'none')),
794 input_fn=
lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
795 reference_fn=
lambda i, _:
796 loss_reference_fns[
'MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction=
'none'),
797 check_sum_reduction=
True,
798 check_gradgrad=
False,
802 def multimarginloss_margin_no_reduce_test():
803 t = torch.rand(5).mul(8).floor().long()
805 fullname=
'MultiMarginLoss_margin_no_reduce',
806 constructor=wrap_functional(
807 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction=
'none')),
808 input_fn=
lambda: torch.randn(5, 10),
809 reference_fn=
lambda i, _:
810 loss_reference_fns[
'MultiMarginLoss'](i, t.data.type_as(i).long(),
811 margin=0.5, reduction=
'none'),
812 check_sum_reduction=
True,
813 check_gradgrad=
False,
817 def multimarginloss_weights_no_reduce_test():
818 t = torch.rand(5).mul(8).floor().long()
819 weights = torch.rand(10)
821 fullname=
'MultiMarginLoss_weights_no_reduce',
822 constructor=wrap_functional(
823 lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
825 input_fn=
lambda: torch.randn(5, 10),
826 reference_fn=
lambda i, _:
827 loss_reference_fns[
'MultiMarginLoss'](i, t.data.type_as(i).long(),
828 weight=weights, reduction=
'none'),
829 check_sum_reduction=
True,
830 check_gradgrad=
False,
834 def fractional_max_pool2d_test(test_case):
835 random_samples = torch.DoubleTensor(1, 3, 2).uniform_()
836 if test_case ==
'ratio':
838 constructor=
lambda: nn.FractionalMaxPool2d(
839 2, output_ratio=0.5, _random_samples=random_samples),
840 input_size=(1, 3, 5, 7),
841 fullname=
'FractionalMaxPool2d_ratio')
842 elif test_case ==
'size':
844 constructor=
lambda: nn.FractionalMaxPool2d((2, 3), output_size=(
845 4, 3), _random_samples=random_samples),
846 input_size=(1, 3, 7, 6),
847 fullname=
'FractionalMaxPool2d_size')
850 def fractional_max_pool3d_test(test_case):
851 random_samples = torch.DoubleTensor(2, 4, 3).uniform_()
852 if test_case ==
'ratio':
854 constructor=
lambda: nn.FractionalMaxPool3d(
855 2, output_ratio=0.5, _random_samples=random_samples),
856 input_size=(2, 4, 5, 5, 5),
857 fullname=
'FractionalMaxPool3d_ratio')
858 elif test_case ==
'size':
860 constructor=
lambda: nn.FractionalMaxPool3d((2, 2, 2), output_size=(
861 4, 4, 4), _random_samples=random_samples),
862 input_size=(2, 4, 7, 7, 7),
863 fullname=
'FractionalMaxPool3d_size')
864 elif test_case ==
'asymsize':
866 constructor=
lambda: nn.FractionalMaxPool3d((4, 2, 3), output_size=(
867 10, 3, 2), _random_samples=random_samples),
868 input_size=(2, 4, 16, 7, 5),
869 fullname=
'FractionalMaxPool3d_asymsize')
873 poissonnllloss_no_reduce_test(),
874 bceloss_no_reduce_test(),
875 bceloss_weights_no_reduce_test(),
876 bce_with_logistic_legacy_enum_test(),
877 bce_with_logistic_no_reduce_test(),
878 bceloss_no_reduce_scalar_test(),
879 bceloss_weights_no_reduce_scalar_test(),
880 bce_with_logistic_no_reduce_scalar_test(),
881 kldivloss_with_target_no_reduce_test(),
882 kldivloss_no_reduce_test(),
883 kldivloss_no_reduce_scalar_test(),
884 l1loss_no_reduce_test(),
885 l1loss_no_reduce_scalar_test(),
886 mseloss_no_reduce_test(),
887 mseloss_no_reduce_scalar_test(),
888 nllloss_no_reduce_test(),
889 nllloss_no_reduce_ignore_index_test(),
890 nllloss_no_reduce_weights_test(),
891 nllloss_no_reduce_weights_ignore_index_test(),
892 nllloss_no_reduce_weights_ignore_index_neg_test(),
893 nllloss2d_no_reduce_test(),
894 nllloss2d_no_reduce_weights_test(),
895 nllloss2d_no_reduce_ignore_index_test(),
896 nlllossNd_no_reduce_test(),
897 nlllossNd_no_reduce_weights_test(),
898 nlllossNd_no_reduce_ignore_index_test(),
899 smoothl1loss_no_reduce_test(),
900 smoothl1loss_no_reduce_scalar_test(),
901 multilabelmarginloss_1d_no_reduce_test(),
902 multilabelmarginloss_index_neg_test(),
903 multilabelmarginloss_no_reduce_test(),
904 hingeembeddingloss_no_reduce_test(),
905 hingeembeddingloss_margin_no_reduce_test(),
906 softmarginloss_no_reduce_test(),
907 multilabelsoftmarginloss_no_reduce_test(),
908 multilabelsoftmarginloss_weights_no_reduce_test(),
909 multimarginloss_no_reduce_test(),
910 multimarginloss_1d_no_reduce_test(),
911 multimarginloss_p_no_reduce_test(),
912 multimarginloss_margin_no_reduce_test(),
913 multimarginloss_weights_no_reduce_test(),
914 fractional_max_pool2d_test(
'ratio'),
915 fractional_max_pool2d_test(
'size'),
916 fractional_max_pool3d_test(
'ratio'),
917 fractional_max_pool3d_test(
'size'),
918 fractional_max_pool3d_test(
'asymsize'),
920 module_name=
'BatchNorm1d',
921 constructor_args=(10,),
926 skip_double=TEST_WITH_ROCM,
927 test_cuda=(
not TEST_WITH_ROCM),
930 module_name=
'BatchNorm1d',
931 constructor_args=(5,),
932 input_size=(4, 5, 3),
936 skip_double=TEST_WITH_ROCM,
939 module_name=
'BatchNorm1d',
940 constructor_args=(10, 1e-3,
None),
944 desc=
'affine_simple_average',
945 skip_double=TEST_WITH_ROCM,
946 test_cuda=(
not TEST_WITH_ROCM),
949 module_name=
'BatchNorm1d',
950 constructor_args=(10, 1e-3, 0.3,
False),
955 skip_double=TEST_WITH_ROCM,
958 module_name=
'BatchNorm1d',
959 constructor_args=(10, 1e-3, 0.3,
True,
False),
963 desc=
'not_tracking_stats',
964 skip_double=TEST_WITH_ROCM,
965 test_cuda=(
not TEST_WITH_ROCM),
968 module_name=
'BatchNorm1d',
969 constructor_args=(5, 1e-3, 0.3,
False),
970 input_size=(4, 5, 3),
973 desc=
'3d_input_not_affine',
974 skip_double=TEST_WITH_ROCM,
977 module_name=
'BatchNorm2d',
978 constructor_args=(3,),
979 input_size=(2, 3, 6, 6),
982 skip_double=TEST_WITH_ROCM,
985 module_name=
'BatchNorm2d',
986 constructor_args=(3, 1e-3,
None),
987 input_size=(2, 3, 6, 6),
990 desc=
'2d_simple_average',
991 skip_double=TEST_WITH_ROCM,
994 module_name=
'BatchNorm2d',
995 constructor_args=(3, 1e-3, 0.8),
996 input_size=(2, 3, 6, 6),
1000 skip_double=TEST_WITH_ROCM,
1003 module_name=
'BatchNorm2d',
1004 constructor_args=(3, 1e-3, 0.8,
False),
1005 input_size=(2, 3, 6, 6),
1009 skip_double=TEST_WITH_ROCM,
1012 module_name=
'BatchNorm2d',
1013 constructor_args=(3, 1e-3, 0.8,
True,
False),
1014 input_size=(2, 3, 6, 6),
1017 desc=
'not_tracking_stats',
1018 skip_double=TEST_WITH_ROCM,
1021 module_name=
'BatchNorm3d',
1022 constructor_args=(3,),
1023 input_size=(2, 3, 4, 4, 4),
1028 module_name=
'BatchNorm3d',
1029 constructor_args=(3, 1e-3,
None),
1030 input_size=(2, 3, 4, 4, 4),
1033 desc=
'3d_simple_average',
1036 module_name=
'BatchNorm3d',
1037 constructor_args=(3, 1e-3, 0.7),
1038 input_size=(2, 3, 4, 4, 4),
1044 module_name=
'BatchNorm3d',
1045 constructor_args=(3, 1e-3, 0.7,
False),
1046 input_size=(2, 3, 4, 4, 4),
1052 module_name=
'BatchNorm3d',
1053 constructor_args=(3, 1e-3, 0.7,
True,
False),
1054 input_size=(2, 3, 4, 4, 4),
1057 desc=
'not_tracking_stats',
1060 module_name=
'InstanceNorm1d',
1061 constructor_args=(3, 1e-3, 0.3),
1062 input_size=(4, 3, 15),
1067 module_name=
'InstanceNorm1d',
1068 constructor_args=(3, 1e-3, 0.3,
False,
True),
1069 input_size=(4, 3, 15),
1072 desc=
'tracking_stats',
1075 module_name=
'InstanceNorm2d',
1076 constructor_args=(3, 1e-3, 0.3),
1077 input_size=(2, 3, 6, 6),
1082 module_name=
'InstanceNorm2d',
1083 constructor_args=(3, 1e-3, 0.3,
False,
True),
1084 input_size=(2, 3, 6, 6),
1087 desc=
'tracking_stats',
1090 module_name=
'InstanceNorm3d',
1091 constructor_args=(3, 1e-3, 0.3),
1092 input_size=(2, 3, 4, 4, 4),
1097 module_name=
'InstanceNorm3d',
1098 constructor_args=(3, 1e-3, 0.3,
False,
True),
1099 input_size=(2, 3, 4, 4, 4),
1102 desc=
'tracking_stats',
1105 module_name=
'LayerNorm',
1106 constructor_args=([5], 1e-3),
1107 input_size=(4, 5, 5),
1110 desc=
'1d_elementwise_affine',
1113 module_name=
'LayerNorm',
1114 constructor_args=([5], 1e-3,
False),
1115 input_size=(4, 5, 5),
1118 desc=
'1d_no_elementwise_affine',
1121 module_name=
'LayerNorm',
1122 constructor_args=([2, 2, 5], 1e-3),
1123 input_size=(4, 2, 2, 5),
1126 desc=
'3d_elementwise_affine',
1129 module_name=
'LayerNorm',
1130 constructor_args=([2, 2, 5], 1e-3,
False),
1131 input_size=(4, 2, 2, 5),
1134 desc=
'3d_no_elementwise_affine',
1137 module_name=
'GroupNorm',
1138 constructor_args=(3, 6, 1e-3),
1139 input_size=(4, 6, 5),
1145 module_name=
'GroupNorm',
1146 constructor_args=(5, 5, 1e-3,
False),
1147 input_size=(4, 5, 5),
1150 desc=
'1d_no_affine_IN',
1153 module_name=
'GroupNorm',
1154 constructor_args=(1, 5, 1e-3,
False),
1155 input_size=(4, 5, 5),
1158 desc=
'1d_no_affine_LN',
1161 module_name=
'GroupNorm',
1162 constructor_args=(3, 6, 1e-3),
1163 input_size=(4, 6, 2, 3),
1169 module_name=
'GroupNorm',
1170 constructor_args=(3, 3, 1e-3,
False),
1171 input_size=(4, 3, 2, 3),
1174 desc=
'2d_no_affine_IN',
1177 module_name=
'GroupNorm',
1178 constructor_args=(1, 3, 1e-3,
False),
1179 input_size=(4, 3, 2, 3),
1182 desc=
'2d_no_affine_LN',
1185 module_name=
'Conv1d',
1186 constructor_args=(4, 5, 3),
1187 input_size=(2, 4, 10),
1189 skip_double=TEST_WITH_ROCM,
1192 module_name=
'Conv1d',
1193 constructor_args=(4, 5, 3, 2),
1194 input_size=(2, 4, 10),
1197 skip_double=TEST_WITH_ROCM,
1200 module_name=
'Conv1d',
1201 constructor_args=(4, 5, 3, 1, 1),
1202 input_size=(2, 4, 10),
1205 skip_double=TEST_WITH_ROCM,
1208 module_name=
'Conv1d',
1209 constructor_args=(4, 5, 5, 1, 2),
1210 input_size=(2, 4, 10),
1213 skip_double=TEST_WITH_ROCM,
1216 module_name=
'Conv1d',
1217 constructor_args=(4, 4, 3, 1, 1),
1218 input_size=(1, 4, 1),
1221 skip_double=TEST_WITH_ROCM,
1224 module_name=
'Conv1d',
1225 constructor_args=(4, 4, 5, 1, 2),
1226 input_size=(1, 4, 1),
1229 skip_double=TEST_WITH_ROCM,
1232 fullname=
'Conv1d_dilated',
1233 constructor=
lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
1234 input_size=(2, 4, 10),
1235 skip_double=TEST_WITH_ROCM,
1238 fullname=
'Conv1d_groups',
1239 constructor=
lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
1240 input_size=(2, 4, 6),
1244 fullname=
'ConvTranspose1d',
1245 constructor=
lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
1247 input_size=(1, 3, 7),
1250 module_name=
'ConvTranspose1d',
1251 constructor_args=(3, 4, 3, 2, 1, 1, 1,
False),
1252 input_size=(1, 3, 6),
1257 module_name=
'ConvTranspose1d',
1258 constructor_args=(3, 4, 3, 2, 1, 1, 1,
True, 2),
1259 input_size=(1, 3, 6),
1264 fullname=
'ConvTranspose1d_groups',
1265 constructor=
lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
1267 input_size=(2, 4, 7),
1270 module_name=
'MaxPool1d',
1271 constructor_args=(4,),
1272 input_size=(2, 10, 4),
1275 module_name=
'MaxPool1d',
1276 constructor_args=(4, 4),
1277 input_size=(2, 10, 4),
1281 module_name=
'Conv2d',
1282 constructor_args=(3, 4, (3, 2)),
1283 input_size=(2, 3, 7, 5),
1287 module_name=
'Conv2d',
1288 constructor_args=(3, 4, (3, 3), (2, 2)),
1289 input_size=(2, 3, 6, 6),
1294 module_name=
'Conv2d',
1295 constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
1296 input_size=(2, 3, 6, 6),
1301 module_name=
'Conv2d',
1302 constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
1303 input_size=(2, 3, 8, 8),
1308 module_name=
'Conv2d',
1309 constructor_args=(3, 4, (3, 2), 1, 0, 1, 1,
False),
1310 input_size=(2, 3, 6, 5),
1315 fullname=
'Conv2d_groups',
1316 constructor=
lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1317 input_size=(2, 4, 6, 5),
1321 fullname=
'Conv2d_groups_thnn',
1322 constructor=
lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1323 input_size=(2, 4, 6, 5),
1326 module_name=
'ConvTranspose2d',
1327 constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
1329 input_size=(1, 3, 7, 6),
1332 module_name=
'ConvTranspose2d',
1333 constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1,
False, (2, 2)),
1334 input_size=(1, 3, 6, 7),
1339 module_name=
'ConvTranspose2d',
1340 constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1,
False),
1341 input_size=(1, 3, 6, 7),
1346 fullname=
'ConvTranspose2d_groups',
1347 constructor=
lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
1348 input_size=(1, 2, 4, 5),
1352 fullname=
'Conv2d_depthwise',
1353 constructor=
lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
1354 input_size=(2, 4, 6, 6),
1357 fullname=
'Conv2d_depthwise_with_multiplier',
1358 constructor=
lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
1359 input_size=(2, 4, 6, 6),
1362 fullname=
'Conv2d_depthwise_strided',
1363 constructor=
lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
1364 input_size=(2, 4, 6, 6),
1367 fullname=
'Conv2d_depthwise_padded',
1368 constructor=
lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
1369 input_size=(2, 4, 6, 6),
1372 fullname=
'Conv2d_depthwise_dilated',
1373 constructor=
lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
1374 input_size=(2, 4, 5, 5),
1377 module_name=
'MaxPool2d',
1378 constructor_args=((3, 3), (2, 2), (1, 1)),
1379 input_size=(1, 3, 7, 7),
1382 module_name=
'AvgPool1d',
1383 constructor_args=(2,),
1384 input_size=(2, 3, 6),
1387 module_name=
'AvgPool1d',
1388 constructor_args=((2,), (2,)),
1389 input_size=(2, 3, 6),
1393 module_name=
'AvgPool1d',
1394 constructor_args=(2, 2, 1),
1395 input_size=(2, 3, 6),
1399 module_name=
'AvgPool2d',
1400 constructor_args=((2, 2),),
1401 input_size=(2, 3, 6, 6),
1404 module_name=
'AvgPool2d',
1405 constructor_args=((2, 2), (2, 2)),
1406 input_size=(2, 3, 6, 6),
1410 module_name=
'AvgPool2d',
1411 constructor_args=((2, 2), (2, 2), (1, 1)),
1412 input_size=(2, 3, 6, 6),
1416 module_name=
'LPPool2d',
1417 constructor_args=(2, 2, 2),
1418 input_size=(1, 3, 7, 7),
1421 module_name=
'LPPool2d',
1422 constructor_args=(1.5, 2),
1423 input_fn=
lambda: torch.rand(1, 3, 7, 7),
1427 module_name=
'LPPool1d',
1428 constructor_args=(1.5, 2),
1429 input_fn=
lambda: torch.rand(1, 3, 7),
1433 module_name=
'LPPool1d',
1434 constructor_args=(2, 2, 3),
1435 input_size=(1, 3, 7),
1438 module_name=
'LocalResponseNorm',
1439 constructor_args=(3, ),
1440 input_size=(1, 5, 7),
1444 module_name=
'LocalResponseNorm',
1445 constructor_args=(2, ),
1446 input_size=(1, 5, 7, 7),
1447 desc=
'2d_uneven_pad',
1450 module_name=
'LocalResponseNorm',
1451 constructor_args=(1, 1., 0.5, 2.),
1452 input_size=(1, 5, 7, 7, 7),
1453 desc=
'3d_custom_params',
1456 module_name=
'ReflectionPad1d',
1457 constructor_args=((1, 2),),
1458 input_size=(2, 3, 8),
1461 module_name=
'ReflectionPad2d',
1462 constructor_args=((1, 2, 3, 4),),
1463 input_size=(2, 3, 8, 8),
1466 module_name=
'ReplicationPad1d',
1467 constructor_args=((1, 2),),
1468 input_size=(2, 3, 4),
1471 module_name=
'ReplicationPad2d',
1472 constructor_args=((1, 2, 3, 4),),
1473 input_size=(2, 3, 4, 4),
1476 module_name=
'ZeroPad2d',
1477 constructor_args=((1, 2, 3, 4),),
1478 input_size=(2, 3, 4, 4)
1481 module_name=
'ZeroPad2d',
1482 constructor_args=((-1, -1, -1, -2),),
1483 input_size=(2, 3, 4, 4),
1484 desc=
'negative_dims' 1487 module_name=
'ConstantPad1d',
1488 constructor_args=((1, 2), 2.),
1489 input_size=(2, 3, 4)
1492 module_name=
'ConstantPad2d',
1493 constructor_args=((1, 2, 3, 4), 2.),
1494 input_size=(2, 3, 4, 4)
1497 module_name=
'ConstantPad3d',
1498 constructor_args=((1, 2, 3, 4, 1, 0), 2.),
1499 input_size=(2, 3, 4, 4, 5)
1502 module_name=
'Conv3d',
1503 constructor_args=(3, 4, (2, 3, 4)),
1504 input_size=(2, 3, 3, 4, 5),
1508 module_name=
'Conv3d',
1509 constructor_args=(3, 4, (2, 3, 4), 1, 0, 1, 1,
False),
1510 input_size=(2, 3, 3, 4, 5),
1515 module_name=
'Conv3d',
1516 constructor_args=(3, 4, 2, 2),
1517 input_size=(2, 3, 5, 5, 5),
1522 module_name=
'Conv3d',
1523 constructor_args=(3, 4, 2, 2, 1),
1524 input_size=(2, 3, 5, 5, 5),
1526 desc=
'stride_padding',
1529 fullname=
'Conv3d_groups',
1530 constructor=
lambda: nn.Conv3d(4, 6, kernel_size=3, groups=2),
1531 input_size=(2, 4, 4, 5, 4),
1535 fullname=
'Conv3d_dilated',
1536 constructor=
lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
1537 input_size=(2, 3, 5, 5, 5),
1540 fullname=
'Conv3d_dilated_strided',
1541 constructor=
lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
1542 input_size=(2, 3, 5, 5, 5),
1545 module_name=
'ConvTranspose3d',
1546 constructor_args=(2, 3, (2, 3, 2)),
1548 input_size=(1, 2, 4, 5, 4),
1551 module_name=
'ConvTranspose3d',
1552 constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1,
True, (2, 2, 2)),
1554 input_size=(1, 2, 4, 5, 4),
1558 module_name=
'MaxPool3d',
1559 constructor_args=((2, 2, 2),),
1560 input_size=(2, 3, 5, 5, 5),
1563 module_name=
'MaxPool3d',
1564 constructor_args=(2, (2, 2, 2)),
1565 input_size=(2, 3, 5, 5, 5),
1569 module_name=
'MaxPool3d',
1570 constructor_args=(2, 2, (1, 1, 1)),
1571 input_size=(2, 3, 5, 5, 5),
1572 desc=
'stride_padding',
1575 module_name=
'AvgPool3d',
1576 constructor_args=((2, 2, 2),),
1577 input_size=(2, 3, 4, 4, 4),
1580 module_name=
'AvgPool3d',
1581 constructor_args=(2, (2, 2, 2)),
1582 input_size=(2, 3, 5, 5, 5),
1586 module_name=
'AvgPool3d',
1587 constructor_args=(2, 2, (1, 1, 1)),
1588 input_size=(2, 3, 5, 5, 5),
1592 module_name=
'AvgPool3d',
1593 constructor_args=(4, 2, (1, 2, 1)),
1594 input_size=(2, 3, 5, 5, 5),
1595 desc=
'stride_pad_gpu_fixedkw_output',
1598 module_name=
'AvgPool3d',
1599 constructor_args=((2, 4, 8), 1, (1, 1, 2)),
1600 input_size=(2, 3, 2, 4, 8),
1601 desc=
'stride_pad_gpu_general_output',
1604 module_name=
'AvgPool3d',
1605 constructor_args=(3, 1, 0),
1606 input_size=(2, 3, 4, 4, 4),
1607 desc=
'stride1_pad0_gpu_input',
1610 module_name=
'AvgPool3d',
1611 constructor_args=(2, 2, (1, 1, 1)),
1612 input_size=(2, 3, 4, 4, 4),
1613 desc=
'stride_pad_gpu_input_nooverlap',
1616 module_name=
'ReplicationPad3d',
1617 constructor_args=((1, 2, 3, 4, 5, 6),),
1618 input_size=(2, 3, 5, 5, 5),
1621 module_name=
'Embedding',
1622 constructor_args=(4, 3),
1623 input_fn=
lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1624 jacobian_input=
False,
1625 check_gradgrad=
False,
1628 module_name=
'EmbeddingBag',
1629 constructor_args=(4, 3),
1630 input_fn=
lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1631 jacobian_input=
False,
1632 check_gradgrad=
False,
1636 module_name=
'EmbeddingBag',
1637 constructor_args=(4, 3,
None, 2.,
False,
'sum'),
1638 input_fn=
lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1639 jacobian_input=
False,
1640 check_gradgrad=
False,
1644 module_name=
'EmbeddingBag',
1645 constructor_args=(4, 3,
None, 2.,
False,
'max'),
1646 input_fn=
lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1647 jacobian_input=
False,
1648 check_gradgrad=
False,
1652 fullname=
'EmbeddingBag_sparse',
1653 constructor=
lambda: nn.EmbeddingBag(4, 3, sparse=
True),
1654 input_fn=
lambda: torch.randperm(2).repeat(1, 2),
1655 jacobian_input=
False,
1656 check_gradgrad=
False,
1659 constructor=
lambda: nn.Embedding(4, 3, sparse=
True),
1660 input_fn=
lambda: torch.randperm(2).repeat(1, 2),
1661 jacobian_input=
False,
1662 fullname=
'Embedding_sparse',
1663 check_gradgrad=
False,
1666 module_name=
'PixelShuffle',
1667 constructor_args=(3,),
1668 input_size=(1, 9, 4, 4),
1671 constructor=wrap_functional(F.interpolate, size=12, scale_factor=
None, mode=
'nearest'),
1672 input_size=(1, 2, 4),
1673 fullname=
'interpolate_nearest_1d',
1677 constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=
None, mode=
'nearest'),
1678 input_size=(1, 2, 3),
1679 fullname=
'interpolate_nearest_tuple_1d',
1683 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=4., mode=
'nearest'),
1684 input_size=(1, 2, 4),
1685 fullname=
'interpolate_nearest_scale_1d',
1689 constructor=wrap_functional(F.interpolate, size=12, scale_factor=
None, mode=
'linear', align_corners=
False),
1690 input_size=(1, 2, 4),
1691 fullname=
'interpolate_linear_1d',
1695 constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=
None, mode=
'linear', align_corners=
False),
1696 input_size=(1, 2, 3),
1697 fullname=
'interpolate_linear_tuple_1d',
1701 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=4., mode=
'linear', align_corners=
False),
1702 input_size=(1, 2, 4),
1703 fullname=
'interpolate_linear_scale_1d',
1707 constructor=wrap_functional(F.interpolate, size=12, scale_factor=
None, mode=
'linear', align_corners=
True),
1708 input_size=(1, 2, 4),
1709 fullname=
'interpolate_linear_1d_align_corners',
1713 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=4., mode=
'linear', align_corners=
True),
1714 input_size=(1, 2, 4),
1715 fullname=
'interpolate_linear_scale_1d_align_corners',
1719 constructor=wrap_functional(F.interpolate, size=12, scale_factor=
None, mode=
'nearest'),
1720 input_size=(1, 2, 4, 4),
1721 fullname=
'interpolate_nearest_2d',
1725 constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=
None, mode=
'nearest'),
1726 input_size=(1, 2, 3, 4),
1727 fullname=
'interpolate_nearest_tuple_2d',
1731 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=4., mode=
'nearest'),
1732 input_size=(1, 2, 4, 4),
1733 fullname=
'interpolate_nearest_scale_2d',
1737 constructor=wrap_functional(F.interpolate, size=12, scale_factor=
None, mode=
'bilinear', align_corners=
False),
1738 input_size=(1, 2, 4, 4),
1739 fullname=
'interpolate_bilinear_2d',
1743 constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=
None,
1744 mode=
'bilinear', align_corners=
False),
1745 input_size=(1, 2, 2, 3),
1746 fullname=
'interpolate_bilinear_tuple_2d',
1750 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=4.,
1751 mode=
'bilinear', align_corners=
False),
1752 input_size=(1, 2, 4, 4),
1753 fullname=
'interpolate_bilinear_scale_2d',
1757 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=(2., 2.),
1758 mode=
'bilinear', align_corners=
False),
1759 input_size=(1, 2, 4, 4),
1760 fullname=
'interpolate_bilinear_scale_tuple_shared_2d',
1764 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=(2., 1.),
1765 mode=
'bilinear', align_corners=
False),
1766 input_size=(1, 2, 4, 4),
1767 fullname=
'interpolate_bilinear_scale_tuple_skewed_2d',
1771 constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=
None, mode=
'bilinear', align_corners=
True),
1772 input_size=(1, 2, 4, 4),
1773 fullname=
'interpolate_bilinear_tuple_2d_align_corners',
1777 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=(2., 1.),
1778 mode=
'bilinear', align_corners=
True),
1779 input_size=(1, 2, 4, 4),
1780 fullname=
'interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
1784 constructor=wrap_functional(F.interpolate, size=12, scale_factor=
None, mode=
'bicubic', align_corners=
False),
1785 input_size=(1, 2, 4, 4),
1786 fullname=
'interpolate_bicubic_2d',
1790 constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=
None,
1791 mode=
'bicubic', align_corners=
False),
1792 input_size=(1, 2, 2, 3),
1793 fullname=
'interpolate_bicubic_tuple_2d',
1797 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=4., mode=
'bicubic', align_corners=
False),
1798 input_size=(1, 2, 4, 4),
1799 fullname=
'interpolate_bicubic_scale_2d',
1803 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=(2., 2.),
1804 mode=
'bicubic', align_corners=
False),
1805 input_size=(1, 2, 4, 4),
1806 fullname=
'interpolate_bicubic_scale_tuple_shared_2d',
1810 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=(2., 1.),
1811 mode=
'bicubic', align_corners=
False),
1812 input_size=(1, 2, 4, 4),
1813 fullname=
'interpolate_bicubic_scale_tuple_skewed_2d',
1817 constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=
None, mode=
'bicubic', align_corners=
True),
1818 input_size=(1, 2, 4, 4),
1819 fullname=
'interpolate_bicubic_tuple_2d_align_corners',
1823 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=(2., 1.),
1824 mode=
'bicubic', align_corners=
True),
1825 input_size=(1, 2, 4, 4),
1826 fullname=
'interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
1830 constructor=wrap_functional(F.interpolate, size=12, scale_factor=
None, mode=
'nearest'),
1831 input_size=(1, 2, 4, 4, 4),
1832 fullname=
'interpolate_nearest_3d',
1836 constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=
None, mode=
'nearest'),
1837 input_size=(1, 2, 3, 4, 4),
1838 fullname=
'interpolate_nearest_tuple_3d',
1842 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=4., mode=
'nearest'),
1843 input_size=(1, 2, 4, 4, 4),
1844 fullname=
'interpolate_nearest_scale_3d',
1848 constructor=wrap_functional(F.interpolate, size=12, scale_factor=
None, mode=
'trilinear', align_corners=
False),
1849 input_size=(1, 2, 4, 4, 4),
1850 fullname=
'interpolate_trilinear_3d',
1854 constructor=wrap_functional(F.interpolate, size=(4, 6, 6),
1855 scale_factor=
None, mode=
'trilinear', align_corners=
False),
1856 input_size=(1, 2, 2, 3, 3),
1857 fullname=
'interpolate_trilinear_tuple_3d',
1861 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=3., mode=
'trilinear', align_corners=
False),
1862 input_size=(1, 2, 3, 4, 4),
1863 fullname=
'interpolate_trilinear_scale_3d',
1869 constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=
None,
1870 mode=
'trilinear', align_corners=
True),
1871 input_size=(1, 2, 2, 3, 3),
1872 fullname=
'interpolate_trilinear_tuple_3d_align_corners',
1876 constructor=wrap_functional(F.interpolate, size=
None, scale_factor=3., mode=
'trilinear', align_corners=
True),
1877 input_size=(1, 2, 3, 4, 4),
1878 fullname=
'interpolate_trilinear_scale_3d_align_corners',
1884 module_name=
'AdaptiveMaxPool1d',
1885 constructor_args=(3,),
1886 input_fn=
lambda: _rand_tensor_non_equal(1, 3, 5),
1889 module_name=
'AdaptiveMaxPool2d',
1890 constructor_args=(3,),
1891 input_fn=
lambda: _rand_tensor_non_equal(1, 3, 5, 6),
1895 module_name=
'AdaptiveMaxPool2d',
1896 constructor_args=((3, 4),),
1897 input_fn=
lambda: _rand_tensor_non_equal(1, 3, 5, 6),
1901 module_name=
'AdaptiveMaxPool2d',
1902 constructor_args=((3,
None),),
1903 input_fn=
lambda: _rand_tensor_non_equal(1, 3, 5, 6),
1907 module_name=
'AdaptiveMaxPool3d',
1908 constructor_args=(3,),
1909 input_fn=
lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
1913 module_name=
'AdaptiveMaxPool3d',
1914 constructor_args=((3, 4, 5),),
1915 input_fn=
lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
1919 module_name=
'AdaptiveMaxPool3d',
1920 constructor_args=((3,
None, 5),),
1921 input_fn=
lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
1925 module_name=
'AdaptiveMaxPool3d',
1926 constructor_args=(3,),
1927 input_fn=
lambda: _rand_tensor_non_equal(2, 3, 12, 9, 3),
1928 desc=
'single_nonatomic',
1931 module_name=
'AdaptiveMaxPool3d',
1932 constructor_args=((3, 4, 5),),
1933 input_fn=
lambda: _rand_tensor_non_equal(2, 3, 6, 4, 10),
1934 desc=
'tuple_nonatomic',
1937 module_name=
'AdaptiveAvgPool1d',
1938 constructor_args=(3,),
1939 input_fn=
lambda: torch.rand(1, 3, 5),
1942 module_name=
'AdaptiveAvgPool1d',
1943 constructor_args=(1,),
1944 input_fn=
lambda: torch.rand(1, 3, 5),
1948 module_name=
'AdaptiveAvgPool2d',
1949 constructor_args=(3,),
1950 input_fn=
lambda: torch.rand(1, 3, 5, 6),
1954 module_name=
'AdaptiveAvgPool2d',
1955 constructor_args=(1,),
1956 input_fn=
lambda: torch.rand(1, 3, 5, 6),
1957 desc=
'single_1x1output',
1960 module_name=
'AdaptiveAvgPool2d',
1961 constructor_args=((3, 4),),
1962 input_fn=
lambda: torch.rand(1, 3, 5, 6),
1966 module_name=
'AdaptiveAvgPool2d',
1967 constructor_args=((3,
None),),
1968 input_fn=
lambda: torch.rand(1, 3, 5, 6),
1972 module_name=
'AdaptiveAvgPool3d',
1973 constructor_args=(3,),
1974 input_fn=
lambda: torch.rand(2, 3, 5, 2, 7),
1978 module_name=
'AdaptiveAvgPool3d',
1979 constructor_args=((3, 4, 5),),
1980 input_fn=
lambda: torch.rand(2, 3, 5, 3, 7),
1984 module_name=
'AdaptiveAvgPool3d',
1985 constructor_args=((
None, 4, 5),),
1986 input_fn=
lambda: torch.rand(2, 3, 5, 3, 7),
1991 input_size=(3, 2, 5),
2002 input_size=(3, 2, 5),
2003 constructor_args=(2.,),
2005 reference_fn=
lambda x, _: torch.where(x >= 0, x, 2. * ((.5 * x).exp() - 1)),
2010 constructor_args=(2.,),
2012 reference_fn=
lambda x, _: torch.where(x >= 0, x, 2. * ((.5 * x).exp() - 1)),
2021 constructor_args=(1,),
2022 input_size=(5, 6, 7),
2026 constructor=wrap_functional(F.softmax, dim=-1),
2027 input_size=(2, 128),
2028 fullname=
'softmax_lastdim',
2032 constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
2033 input_size=(2, 128),
2034 fullname=
'softmax_lastdim_dtype',
2039 constructor=wrap_functional(F.softmax, dim=1),
2040 input_size=(2, 128, 2, 2),
2041 fullname=
'softmax_spatial_special',
2043 test_cuda=(
not TEST_WITH_ROCM)
2046 constructor=wrap_functional(F.softmax, dim=1),
2047 input_size=(2, 2, 4, 4),
2048 fullname=
'softmax_spatial',
2052 constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
2053 input_size=(2, 2, 4, 4),
2054 fullname=
'softmax_spatial_dtype',
2059 constructor=wrap_functional(F.softmax, dim=0),
2060 input_size=(2, 3, 4, 5),
2061 fullname=
'softmax_functional_dim0',
2066 constructor=wrap_functional(F.softmax, dim=3),
2067 input_size=(2, 3, 4, 5),
2068 fullname=
'softmax_functional_dim3',
2073 constructor=wrap_functional(F.softmax, dim=-1),
2075 fullname=
'softmax_functional_scalar',
2080 constructor=wrap_functional(F.log_softmax, dim=-1),
2081 input_size=(2, 128),
2082 fullname=
'log_softmax_lastdim',
2086 constructor=wrap_functional(F.log_softmax, dim=1),
2087 input_size=(2, 128, 2, 2),
2088 fullname=
'log_softmax_spatial_special',
2090 test_cuda=(
not TEST_WITH_ROCM)
2093 constructor=wrap_functional(F.log_softmax, dim=1),
2094 input_size=(2, 2, 4, 4),
2095 fullname=
'log_softmax_spatial',
2099 constructor=wrap_functional(F.log_softmax, dim=0),
2100 input_size=(2, 3, 4, 5),
2101 fullname=
'log_softmax_dim0',
2105 constructor=wrap_functional(F.log_softmax, dim=3),
2106 input_size=(2, 3, 4, 5),
2107 fullname=
'log_softmax_dim3',
2111 constructor=wrap_functional(F.log_softmax, dim=0),
2113 fullname=
'log_softmax_scalar',
2118 constructor=
lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
2119 input_size=(2, 4, 3, 3),
2120 check_gradgrad=
False,
2125 constructor=
lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
2126 input_size=(2, 16, 4),
2127 check_gradgrad=
False,
2131 fullname=
'Unfold_int_input',
2132 constructor=
lambda: nn.Unfold(2, 1, 0, 1),
2133 input_size=(2, 4, 3, 3),
2134 check_gradgrad=
False,
2138 fullname=
'Fold_int_input',
2139 constructor=
lambda: nn.Fold(3, 2, 1, 0, 1),
2140 input_size=(2, 16, 4),
2141 check_gradgrad=
False,
2145 module_name=
'Threshold',
2146 constructor_args=(2., 1.),
2149 desc=
'threshold_value_scalar' 2159 module_name=
'ReLU6',
2165 module_name=
'RReLU',
2166 constructor_args=(0.1, 0.9),
2168 desc=
'with_up_down_scalar',
2172 module_name=
'Hardtanh',
2174 reference_fn=
lambda i, _: i.clamp(-1, 1),
2178 module_name=
'Sigmoid',
2188 module_name=
'Softmax',
2189 constructor_args=(0,),
2191 reference_fn=
lambda i, _: torch.exp(i).div(torch.exp(i).sum(0,
True)),
2195 module_name=
'LogSoftmax',
2196 constructor_args=(0,),
2198 reference_fn=
lambda i, _: torch.exp(i).div_(torch.exp(i).sum(0,
False)).log_(),
2199 desc=
'multiparam_scalar',
2203 constructor_args=(2.,),
2208 module_name=
'Hardshrink',
2209 constructor_args=(2.,),
2214 module_name=
'LeakyReLU',
2215 constructor_args=(0.5,),
2218 desc=
'with_negval_scalar' 2221 module_name=
'LogSigmoid',
2223 reference_fn=
lambda i, _: i.sigmoid().log(),
2227 module_name=
'Softplus',
2228 constructor_args=(2, -100),
2230 reference_fn=(
lambda i, _: ((i * 2) > -100).type_as(i) * i +
2231 ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
2232 desc=
'beta_threshold_scalar',
2235 module_name=
'Softshrink',
2236 constructor_args=(1,),
2238 desc=
'lambda_scalar',
2241 module_name=
'PReLU',
2243 reference_fn=
lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
2247 module_name=
'Softsign',
2249 reference_fn=
lambda i, _: i.div(1 + torch.abs(i)),
2253 module_name=
'Softmin',
2254 constructor_args=(0,),
2259 module_name=
'Tanhshrink',
2264 fullname=
'Padding12_1dcircular',
2265 constructor=wrap_functional(F.pad, pad=(1, 2), mode=
'circular'),
2266 input_fn=
lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
2267 reference_fn=
lambda i, _: padding1d_circular(i, (1, 2)),
2268 skip_double=TEST_WITH_ROCM,
2272 fullname=
'Padding31_1dcircular',
2273 constructor=wrap_functional(F.pad, pad=(3, 1), mode=
'circular'),
2274 input_fn=
lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
2275 reference_fn=
lambda i, _: padding1d_circular(i, (3, 1)),
2276 skip_double=TEST_WITH_ROCM,
2280 fullname=
'Padding33_1dcircular',
2281 constructor=wrap_functional(F.pad, pad=(3, 3), mode=
'circular'),
2282 input_fn=
lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
2283 reference_fn=
lambda i, _: padding1d_circular(i, (3, 3)),
2284 skip_double=TEST_WITH_ROCM,
2288 fullname=
'Padding1221_2dcircular',
2289 constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode=
'circular'),
2290 input_fn=
lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
2291 reference_fn=
lambda i, _: padding2d_circular(i, (1, 2, 2, 1)),
2292 skip_double=TEST_WITH_ROCM,
2296 fullname=
'Padding2322_2dcircular',
2297 constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode=
'circular'),
2298 input_fn=
lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
2299 reference_fn=
lambda i, _: padding2d_circular(i, (2, 3, 2, 2)),
2300 skip_double=TEST_WITH_ROCM,
2304 fullname=
'Padding3331_2dcircular',
2305 constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode=
'circular'),
2306 input_fn=
lambda: torch.arange(9, out=torch.DoubleTensor()).reshape([1, 1, 3, 3]),
2307 reference_fn=
lambda i, _: padding2d_circular(i, (3, 3, 3, 1)),
2308 skip_double=TEST_WITH_ROCM,
2312 fullname=
'Padding122112_3dcircular',
2313 constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode=
'circular'),
2314 input_fn=
lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
2315 reference_fn=
lambda i, _: padding3d_circular(i, (1, 2, 2, 1, 1, 2)),
2316 skip_double=TEST_WITH_ROCM,
2320 fullname=
'Padding322112_3dcircular',
2321 constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode=
'circular'),
2322 input_fn=
lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
2323 reference_fn=
lambda i, _: padding3d_circular(i, (3, 2, 2, 1, 1, 2)),
2324 skip_double=TEST_WITH_ROCM,
2328 fullname=
'Padding332122_3dcircular',
2329 constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode=
'circular'),
2330 input_fn=
lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
2331 reference_fn=
lambda i, _: padding3d_circular(i, (3, 3, 2, 1, 2, 2)),
2332 skip_double=TEST_WITH_ROCM,
2337 module_name=
'Conv1d',
2338 constructor_args=(3, 4, 2, 2, (1,), 1, 1,
True,
'circular'),
2339 input_size=(2, 3, 5,),
2341 desc=
'stride1_pad1circular',
2344 module_name=
'Conv1d',
2345 constructor_args=(3, 4, 2, 2, (2,), 1, 1,
True,
'circular'),
2346 input_size=(2, 3, 5,),
2348 desc=
'stride1_pad2circular',
2351 module_name=
'Conv2d',
2352 constructor_args=(3, 4, (3, 3), (2, 2), (1, 2), 1, 1,
True,
'circular'),
2353 input_size=(2, 3, 3, 3),
2358 module_name=
'Conv3d',
2359 constructor_args=(3, 4, 2, 2, (1, 2, 3), 1, 1,
True,
'circular'),
2360 input_size=(2, 3, 3, 3, 3),
2362 desc=
'stride_pad1circular',
2367 def kldivloss_reference(input, target, reduction='mean'):
2368 safe_target = target * (target > 0).type_as(target)
2369 safe_target_log = (safe_target + (target <= 0).type_as(target)).log()
2370 result = safe_target * (safe_target_log - input)
2371 if reduction ==
'mean':
2372 return result.mean()
2373 elif reduction ==
'sum':
2375 elif reduction ==
'batchmean' and results.dim() != 0:
2376 return result.sum() / result.size(0)

def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
                        reduction='mean'):
    assert input.dim() >= 3
    N = input.size(0)
    C = input.size(1)
    out_size = (N,) + input.size()[2:]
    output = torch.zeros(out_size).type_as(input)

    if weight is None:
        weight = torch.ones(C).type_as(input)
    total_weight = 0
    for tup in product(*[range(size) for size in out_size]):
        t_nx = target[tup]
        norm = 0. if ignore_index == t_nx else weight[t_nx].item()
        input_index = list(tup)
        input_index.insert(1, t_nx)
        output[tup] = -input[tuple(input_index)] * norm
        total_weight += norm

    if reduction == 'mean':
        return output.sum() / total_weight
    elif reduction == 'sum':
        return output.sum()
    return output

def nllloss_reference(input, target, weight=None, ignore_index=-100,
                      reduction='mean'):

    def nll_loss_helper(input, target, weight, ignore_index):
        if target == ignore_index:
            return (0, 0)
        norm = 1 if weight is None else weight[target]
        result = -input[target] * norm
        return (result, norm)

    losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index)
                          for i, t in zip(input, target)]
    losses, weights = zip(*losses_and_weights)
    losses_tensor = input.new_tensor(losses)
    if reduction == 'mean':
        return sum(losses_tensor) / sum(weights)
    elif reduction == 'sum':
        return sum(losses_tensor)
    else:
        return losses_tensor
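

# Illustrative sketch (standalone helper, not referenced by the harness): with
# a per-class weight vector, nllloss_reference should match the weighted mean
# that F.nll_loss computes under its default 'mean' reduction.
def _example_check_nllloss_reference():
    inp = torch.rand(4, 3).log()
    tgt = torch.tensor([0, 2, 1, 2])
    w = torch.rand(3)
    assert torch.allclose(nllloss_reference(inp, tgt, weight=w),
                          F.nll_loss(inp, tgt, weight=w))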

def smoothl1loss_reference(input, target, reduction='mean'):
    abs_diff = (input - target).abs()
    ge_one_mask = (abs_diff >= 1).type_as(abs_diff)
    lt_one_mask = (abs_diff < 1).type_as(abs_diff)
    output = ge_one_mask * (abs_diff - 0.5) + lt_one_mask * 0.5 * (abs_diff ** 2)
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output

def _multilabelmarginloss_reference(input, target):
    targets = []
    for target_index in target:
        if target_index < 0:
            break
        targets.append(target_index)

    sum = 0
    for target_index in targets:
        for i in range(0, len(input)):
            if i not in targets:
                sum += max(0, 1 - input[target_index] + input[i])

    return sum

def multilabelmarginloss_reference(input, target, reduction='mean'):
    if input.dim() == 1:
        n = 1
        dim = input.size(0)
        output = input.new(n).zero_()
        output[0] = _multilabelmarginloss_reference(input, target)
    else:
        n = input.size(0)
        dim = input.size(1)
        output = input.new(n).zero_()
        for i in range(0, n):
            output[i] = _multilabelmarginloss_reference(input[i], target[i])

    if reduction == 'mean':
        return output.mean() / dim
    elif reduction == 'sum':
        return output.sum() / dim
    return output / dim
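

# Illustrative sketch (standalone helper, not referenced by the harness): a
# single-sample multilabel margin loss worked out by hand. Only class 2 is a
# target; -1 terminates the target list.
def _example_check_multilabelmargin_reference():
    inp = torch.tensor([0.1, 0.2, 0.7])
    tgt = torch.tensor([2, -1, -1])
    # max(0, 1 - 0.7 + 0.1) + max(0, 1 - 0.7 + 0.2) = 0.9, divided by dim == 3
    assert abs(float(multilabelmarginloss_reference(inp, tgt)) - 0.3) < 1e-6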

def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
    margin_clamp = (margin - input).clamp(min=0).type_as(input)
    output = torch.where(target == 1, input, margin_clamp)

    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output

def softmarginloss_reference(input, target, reduction='mean'):
    output = (1 + (-input * target).exp()).log()

    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output

def _multimarginloss_reference(input, target_idx, p, margin, weight):
    if weight is None:
        weight = input.new(len(input)).fill_(1)

    output = 0
    for i in range(0, len(input)):
        if i != target_idx:
            output += max(0, weight[target_idx] * (margin - input[target_idx] + input[i]) ** p)
    return output

def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
    if input.dim() == 1:
        n = 1
        dim = input.size(0)
        return _multimarginloss_reference(input, target[0], p, margin, weight) / dim
    else:
        n = input.size(0)
        dim = input.size(1)
        output = input.new(n)
        for x in range(0, n):
            output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)

        if reduction == 'mean':
            return output.mean() / dim
        elif reduction == 'sum':
            return output.sum() / dim
        return output / dim

def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
    def _cos(a, b):
        cos = a.new(a.size(0))
        for i in range(0, a.size(0)):
            cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5)
        return cos

    output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))

    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output

def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
                                reduction='mean'):
    d_p = torch.pairwise_distance(anchor, positive, p, eps)
    d_n = torch.pairwise_distance(anchor, negative, p, eps)
    if swap:
        d_s = torch.pairwise_distance(positive, negative, p, eps)
        d_n = torch.min(d_n, d_s)

    output = torch.clamp(margin + d_p - d_n, min=0.0)
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output

def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
    output = (-target * (input1 - input2) + margin).clamp(min=0)
    if reduction == 'mean':
        return output.mean()
    elif reduction == 'sum':
        return output.sum()
    return output

def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
    dt = log_probs.dtype
    log_probs = log_probs.double()
    targets = targets.long()
    cum_target_lengths = target_lengths.cumsum(0)
    losses = []
    for i in range(log_probs.size(1)):
        input_length = input_lengths[i].item()
        target_length = target_lengths[i].item()
        cum_target_length = cum_target_lengths[i].item()
        targets_prime = targets.new_full((2 * target_length + 1,), blank)
        if targets.dim() == 2:
            targets_prime[1::2] = targets[i, :target_length]
        else:
            targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
        probs = log_probs[:input_length, i].exp()
        alpha = log_probs.new_zeros((target_length * 2 + 1,))
        alpha[0] = probs[0, blank]
        alpha[1] = probs[0, targets_prime[1]]
        mask_third = (targets_prime[:-2] != targets_prime[2:])
        for t in range(1, input_length):
            alpha_next = alpha.clone()
            alpha_next[1:] += alpha[:-1]
            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
            alpha = probs[t, targets_prime] * alpha_next
        losses.append(-alpha[-2:].sum().log()[None])
    output = torch.cat(losses, 0)
    if reduction == 'mean':
        return (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean()
    elif reduction == 'sum':
        return output.sum()
    output = output.to(dt)
    return output
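

# Illustrative sketch (standalone helper, not referenced by the harness): the
# CTC reference above compared against nn.CTCLoss on a tiny batch
# (T=4 frames, N=2 samples, C=3 classes, blank=0).
def _example_check_ctc_reference():
    log_probs = torch.randn(4, 2, 3).log_softmax(2)
    targets = torch.tensor([[1, 2], [1, 1]])
    input_lengths = torch.tensor([4, 4])
    target_lengths = torch.tensor([2, 2])
    ref = ctcloss_reference(log_probs, targets, input_lengths, target_lengths)
    out = torch.nn.CTCLoss()(log_probs, targets, input_lengths, target_lengths)
    assert torch.allclose(ref.float(), out, atol=1e-4)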

def padding1d_circular(input, pad):
    r"""Circularly pad the last dimension of a 3d input.

    e.g. for input [[[0., 1., 2.], [3., 4., 5.]]] and pad (1, 2) the output is
        [[[2., 0., 1., 2., 0., 1.],
          [5., 3., 4., 5., 3., 4.]]]
    """
    return torch.cat([input[:, :, -pad[0]:], input,
                      input[:, :, 0:pad[1]]], dim=2)

def padding2d_circular(input, pad):
    r"""Circularly pad the last two dimensions of a 4d input.

    e.g. for input [[[[0., 1., 2.], [3., 4., 5.]]]] and pad (1, 2, 2, 1) the output is
        [[[[2., 0., 1., 2., 0., 1.],
           [5., 3., 4., 5., 3., 4.],
           [2., 0., 1., 2., 0., 1.],
           [5., 3., 4., 5., 3., 4.],
           [2., 0., 1., 2., 0., 1.]]]]
    """
    input = torch.cat([input[:, :, -pad[2]:], input, input[:, :, 0:pad[3]]], dim=2)
    return torch.cat([input[:, :, :, -pad[0]:], input, input[:, :, :, 0:pad[1]]], dim=3)

def padding3d_circular(input, pad):
    r"""Circularly pad the last three dimensions of a 5d input.

    e.g. for input torch.arange(12.).reshape(1, 1, 2, 2, 3) and pad
    (1, 2, 2, 1, 1, 2), each output depth slice is a 5x6 circular tiling of the
    corresponding 2x3 input slice, and the depth dimension itself is padded
    circularly from 2 to 5 slices.
    """
    input = torch.cat([input[:, :, -pad[4]:], input, input[:, :, 0:pad[5]]], dim=2)
    input = torch.cat([input[:, :, :, -pad[2]:], input, input[:, :, :, 0:pad[3]]], dim=3)
    return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4)
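

# Illustrative sketch (standalone helper, not referenced by the harness): the
# helpers above mirror F.pad(..., mode='circular'); for the 1d case exercised
# by the functional padding tests:
def _example_check_circular_padding():
    x = torch.arange(6, dtype=torch.double).reshape(1, 2, 3)
    assert torch.allclose(padding1d_circular(x, (1, 2)),
                          F.pad(x, (1, 2), mode='circular'))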

loss_reference_fns = {
    'KLDivLoss': kldivloss_reference,
    'NLLLoss': nllloss_reference,
    'NLLLossNd': nlllossNd_reference,
    'SmoothL1Loss': smoothl1loss_reference,
    'MultiLabelMarginLoss': multilabelmarginloss_reference,
    'HingeEmbeddingLoss': hingeembeddingloss_reference,
    'SoftMarginLoss': softmarginloss_reference,
    'MultiMarginLoss': multimarginloss_reference,
    'CosineEmbeddingLoss': cosineembeddingloss_reference,
    'TripletMarginLoss': tripletmarginloss_reference,
    'MarginRankingLoss': marginrankingloss_reference,
    'CTCLoss': ctcloss_reference,
}
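

# Illustrative sketch (standalone helper, not referenced by the harness): tests
# can look up the reference implementation for a module by name, e.g.:
def _example_lookup_reference_fn():
    inp, tgt = torch.randn(5, 10), torch.randn(5, 10)
    ref_fn = loss_reference_fns['SmoothL1Loss']
    return ref_fn(inp, tgt, reduction='sum')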

criterion_tests = [
    dict(
        module_name='L1Loss',
        input_size=(2, 3, 4),
        target_size=(2, 3, 4),
        reference_fn=lambda i, t, _: 1. / i.numel() *
        sum((a - b).abs().sum() for a, b in zip(i, t)),
    ),
    dict(
        module_name='NLLLoss',
        input_fn=lambda: torch.rand(15, 10).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='NLLLoss',
        constructor_args=(None, None, 2),
        input_fn=lambda: torch.rand(15, 10).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m)),
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10), None, 2),
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m), ignore_index=2),
        desc='weights_ignore_index',
    ),
    dict(
        module_name='NLLLoss',
        constructor_args_fn=lambda: (torch.rand(10), None, -1),
        input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1,
        reference_fn=lambda i, t, m:
            nllloss_reference(i, t, weight=get_weight(m), ignore_index=-1),
        desc='weights_ignore_index_neg',
    ),
    dict(
        module_name='KLDivLoss',
        input_fn=lambda: torch.rand(10, 10).log(),
        target_fn=lambda: torch.rand(10, 10),
        reference_fn=lambda i, t, m:
            kldivloss_reference(i, t, get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='MSELoss',
        input_size=(2, 3, 4, 5),
        target_size=(2, 3, 4, 5),
        reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() /
                                      (i.numel() if get_reduction(m) == 'mean' else 1)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='BCELoss',
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
            (i.numel() if get_reduction(m) else 1),
        check_gradgrad=False,
    ),
    dict(
        module_name='BCELoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.randn(15, 10).gt(0).double(),
        reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
            (i.numel() if get_reduction(m) else 1),
        check_gradgrad=False,
    ),
    dict(
        module_name='CrossEntropyLoss',
        input_size=(15, 10),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
    ),
    dict(
        module_name='CrossEntropyLoss',
        constructor_args_fn=lambda: (torch.rand(10),),
        input_size=(15, 10),
        target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
    ),
    dict(
        module_name='HingeEmbeddingLoss',
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        reference_fn=lambda i, t, m:
            hingeembeddingloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='HingeEmbeddingLoss',
        constructor_args=(0.5,),
        target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
        reference_fn=lambda i, t, m:
            hingeembeddingloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='MultiLabelMarginLoss',
        target_fn=lambda: torch.rand(10).mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiLabelMarginLoss',
        target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
        reference_fn=lambda i, t, m:
            multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiLabelSoftMarginLoss',
        target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
        reference_fn=lambda i, t, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()).sum() / i.numel(),
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        target_fn=lambda: torch.rand(1).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        constructor_args=(2,),
        input_fn=lambda: torch.rand(5, 10).clamp_(1e-2, 1 - 1e-2),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, p=2, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        constructor_args=(1, 0.5),
        legacy_constructor_args=(1, None, 0.5),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='MultiMarginLoss',
        constructor_args=(1, 1., torch.rand(10)),
        legacy_constructor_args=(1, torch.rand(10)),
        target_fn=lambda: torch.rand(5).mul(8).floor().long(),
        reference_fn=lambda i, t, m:
            multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)),
        check_sum_reduction=True,
        check_gradgrad=False,
    ),
    dict(
        module_name='SmoothL1Loss',
        target_size=(5, 10),
        check_sum_reduction=True,
        reference_fn=lambda i, t, m:
            smoothl1loss_reference(i, t, reduction=get_reduction(m)),
    ),
    dict(
        module_name='SoftMarginLoss',
        target_fn=lambda: torch.randn(5, 5).sign(),
        reference_fn=lambda i, t, m:
            softmarginloss_reference(i, t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='CosineEmbeddingLoss',
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        reference_fn=lambda i, t, m:
            cosineembeddingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='CosineEmbeddingLoss',
        constructor_args=(0.7,),
        input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
        target_fn=lambda: torch.randn(15).sign(),
        reference_fn=lambda i, t, m:
            cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='MarginRankingLoss',
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        reference_fn=lambda i, t, m:
            marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
    dict(
        module_name='MarginRankingLoss',
        constructor_args=(0.5,),
        input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
        target_fn=lambda: torch.randn(50).sign(),
        reference_fn=lambda i, t, m:
            marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)),
        check_sum_reduction=True,
    ),
]

class NNTestCase(TestCase):

    def _jacobian(self, input, num_out):
        if isinstance(input, tuple):
            return tuple(self._jacobian(elem, num_out) for elem in input)
        elif isinstance(input, list):
            return [self._jacobian(elem, num_out) for elem in input]
        else:
            return torch.zeros(input.nelement(), num_out)

    def _flatten_tensors(self, x):
        if isinstance(x, torch.Tensor):
            return x.to_dense().view(-1) if x.is_sparse else x.contiguous().view(-1)
        return tuple(self._flatten_tensors(a) for a in x)

    def _zero_grad_input(self, input):
        if isinstance(input, torch.Tensor):
            if input.requires_grad and input.grad is not None:
                input.grad.zero_()
                input.grad.detach_()
        else:
            for i in input:
                self._zero_grad_input(i)

    def _analytical_jacobian(self, module, input, jacobian_input=True, jacobian_parameters=True):
        output = self._forward(module, input)
        output_size = output.nelement()

        if jacobian_input:
            jacobian_inp = self._jacobian(input, output_size)
            flat_jacobian_input = list(iter_tensors(jacobian_inp))

        if jacobian_parameters:
            num_param = sum(p.numel() for p in self._get_parameters(module)[0])
            jacobian_param = torch.zeros(num_param, output_size)

        for i in range(output_size):
            param, d_param = self._get_parameters(module)
            d_param = [torch.zeros_like(p) if d is None else d
                       for (p, d) in zip(param, d_param)]

            # backpropagate a one-hot gradient for output element i
            d_out = torch.zeros_like(output)
            flat_d_out = d_out.view(-1)
            flat_d_out[i] = 1

            if jacobian_parameters:
                self._zero_grad_parameters(module)
            if jacobian_input:
                self._zero_grad_input(input)
            d_input = self._backward(module, input, output, d_out)

            if jacobian_input:
                for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)):
                    jacobian_x[:, i] = d_x.contiguous().view(-1)
            if jacobian_parameters:
                jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)

        res = tuple()
        if jacobian_input:
            res += jacobian_inp,
        if jacobian_parameters:
            res += jacobian_param,
        return res

    def _numerical_jacobian(self, module, input, jacobian_input=True, jacobian_parameters=True):
        def fw(input):
            return self._forward(module, input).detach()

        res = tuple()
        if jacobian_input:
            res += get_numerical_jacobian(fw, input, eps=1e-6),
        if jacobian_parameters:
            param, _ = self._get_parameters(module)
            res += torch.cat([get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param], 0),
        return res

    def check_jacobian(self, module, input, jacobian_input=True):
        jacobian_parameters = bool(self._get_parameters(module)[0])
        analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
        numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
        analytical_t = list(iter_tensors(analytical))
        numerical_t = list(iter_tensors(numerical))

        self.assertLessEqual(
            max(a.add(-1, n).abs().max() for a, n in zip(analytical_t, numerical_t)),
            PRECISION
        )

    def check_criterion_jacobian(self, criterion, input, target):
        eps = 1e-6
        self._forward_criterion(criterion, input, target)
        analytical_d_x = self._backward_criterion(criterion, input, target)
        numerical_d_x = deepcopy(analytical_d_x)

        input_t = iter_tensors(input)
        numerical_t = iter_tensors(numerical_d_x)
        for x, d_x in zip(input_t, numerical_t):
            x = x.view(-1).data
            d_x = d_x.view(-1).data
            for i in range(x.nelement()):
                original = x[i].item()
                x[i] = original + eps
                fx1 = self._forward_criterion(criterion, input, target)
                x[i] = original - eps
                fx2 = self._forward_criterion(criterion, input, target)
                deriv = (fx1 - fx2) / (2. * eps)
                d_x[i] = float(deriv)
                x[i] = original

        analytical_t = list(iter_tensors(analytical_d_x))
        numerical_t = list(iter_tensors(numerical_d_x))

        self.assertLessEqual(
            max(a.add(-1, n).abs().max() for a, n in zip(analytical_t, numerical_t)),
            PRECISION
        )
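
    # Illustrative sketch (standalone helper, not referenced by the harness):
    # the same central-difference estimate used above, written out for a
    # scalar function f of a scalar argument.
    @staticmethod
    def _example_central_difference(f, x, eps=1e-6):
        return (f(x + eps) - f(x - eps)) / (2. * eps)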

class TestBase(object):

    _required_arg_names = {'constructor_args', 'input', 'extra_args'}

    def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
        self.desc = desc
        self.fullname = fullname
        self.constructor = constructor
        self.reference_fn = reference_fn
        for name in self._required_arg_names:
            if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
                if name in {'constructor_args', 'extra_args'}:
                    kwargs[name] = tuple()
                else:
                    raise ValueError("{}: Specify {} by a value, a function to generate it, or it's size!"
                                     .format(self.get_name(), name))
        self._extra_kwargs = kwargs

    def get_name(self):
        test_name = 'test_' + self.constructor.__name__
        if self.desc:
            test_name += '_' + self.desc
        return test_name

    def _unpack(self, value):
        if isinstance(value, torch.Tensor):
            return value
        elif is_iterable(value):
            return type(value)(self._unpack(v) for v in value)
        else:
            return value

    @property
    def constructor_args(self):
        return self._get_arg('constructor_args', True)

    @property
    def extra_args(self):
        return self._get_arg('extra_args', True)

    def _get_arg(self, name, unpack):
        assert name in self._required_arg_names
        fn_name = name + '_fn'
        size_name = name + '_size'

        if name in self._extra_kwargs:
            value = self._extra_kwargs[name]
        elif fn_name in self._extra_kwargs:
            value = self._extra_kwargs[fn_name]()
        else:
            def map_tensor_sizes(sizes):
                if isinstance(sizes, list):
                    return [map_tensor_sizes(s) for s in sizes]
                elif isinstance(sizes, torch.Tensor):
                    return sizes.double()
                else:
                    return torch.randn(sizes)

            value = map_tensor_sizes(self._extra_kwargs[size_name])

        return self._unpack(value) if unpack else value

    def _get_input(self, unpack=True):
        return self._get_arg('input', unpack)

    def __call__(self, test_case):
        raise NotImplementedError

class ModuleTest(TestBase):

    def __init__(self, *args, **kwargs):
        super(ModuleTest, self).__init__(*args, **kwargs)
        self.jacobian_input = kwargs.get('jacobian_input', True)
        self.should_test_cuda = kwargs.get('test_cuda', True)
        self.should_test_pickle = kwargs.get('pickle', True)
        self.check_gradgrad = kwargs.get('check_gradgrad', True)
        self.FIXME_no_cuda_gradgrad_comparison = \
            kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
        self.precision = kwargs.get('precision', 2e-4)

    def __call__(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()

        if self.reference_fn is not None:
            out = test_case._forward(module, input)
            ref_input = deepcopy(input)
            expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0])
            test_case.assertEqual(out, expected_out)

        self.test_noncontig(test_case, module, input)

        if self.should_test_pickle:
            with TemporaryFile() as f:
                test_case._forward(module, input)
                torch.save(module, f)
                f.seek(0)
                module_copy = torch.load(f)
                test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))

        self._do_test(test_case, module, input)

    def noncontiguize(self, obj):
        if isinstance(obj, list):
            return [self.noncontiguize(o) for o in obj]
        tensor = obj
        ndim = tensor.dim()
        # pick the first dimension of size > 1 and make the tensor
        # non-contiguous along it by stacking and selecting
        dim = ndim
        for d in range(ndim):
            if tensor.size(d) > 1:
                dim = d + 1
                break
        noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
        assert noncontig.numel() == 1 or not noncontig.is_contiguous()
        noncontig.requires_grad = tensor.requires_grad
        return noncontig

    def test_noncontig(self, test_case, module, input):
        # scalar inputs cannot be made non-contiguous
        if isinstance(input, torch.Tensor) and input.dim() == 0:
            return
        if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
            return

        test_case._zero_grad_parameters(module)
        test_case._zero_grad_input(input)
        with freeze_rng_state():
            output = test_case._forward(module, input)
            grad_output = output.new(output.shape).normal_()
            output = output.clone()
            d_input = deepcopy(test_case._backward(module, input, output, grad_output))
            d_param = deepcopy(test_case._get_parameters(module)[1])

        nc_input = self.noncontiguize(input)
        nc_grad_output = self.noncontiguize(grad_output)
        for contig_i, contig_g in product((True, False), repeat=2):
            i = input if contig_i else nc_input
            go = grad_output if contig_g else nc_grad_output
            test_case._zero_grad_parameters(module)
            test_case._zero_grad_input(i)
            with freeze_rng_state():
                out = test_case._forward(module, i)
                grad = test_case._backward(module, i, out, go)

                test_case.assertEqual(out, output)
                test_case.assertEqual(grad, d_input, 1e-4)
                test_case.assertEqual(test_case._get_parameters(module)[1], d_param)

    def test_cuda(self, test_case):
        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')
        try:
            cpu_input = self._get_input()
            type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}
            gpu_input = to_gpu(cpu_input, type_map=type_map)

            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args).float().cuda()
            cpu_param = test_case._get_parameters(cpu_module)
            gpu_param = test_case._get_parameters(gpu_module)
            for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
                gpu_p.data.copy_(cpu_p)

            test_case._zero_grad_input(cpu_input)
            test_case._zero_grad_input(gpu_input)
            test_case._zero_grad_parameters(cpu_module)
            test_case._zero_grad_parameters(gpu_module)
            cpu_output = test_case._forward(cpu_module, cpu_input)
            gpu_output = test_case._forward(gpu_module, gpu_input)
            test_case.assertEqual(cpu_output, gpu_output, self.precision)

            # run backward on CPU and GPU and compare results
            cpu_gradOutput = cpu_output.clone().normal_()
            gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor')
            cpu_gradInput = test_case._backward(cpu_module, cpu_input, cpu_output, cpu_gradOutput)
            gpu_gradInput = test_case._backward(gpu_module, gpu_input, gpu_output, gpu_gradOutput)
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, self.precision)
            for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
                test_case.assertEqual(cpu_d_p, gpu_d_p, self.precision)

            # run double-backward on CPU and GPU and compare results
            if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
                cpu_output = cpu_module(cpu_input)
                gpu_output = gpu_module(gpu_input)

                cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
                gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
                gpu_gradOutput.requires_grad = True

                cpu_gradInputs = torch.autograd.grad(
                    cpu_output,
                    (cpu_input,) + tuple(cpu_module.parameters()),
                    cpu_gradOutput,
                    create_graph=True)
                gpu_gradInputs = torch.autograd.grad(
                    gpu_output,
                    (gpu_input,) + tuple(gpu_module.parameters()),
                    gpu_gradOutput,
                    create_graph=True)

                for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
                    test_case.assertEqual(cpu_d_i, gpu_d_i, self.precision)

                # mix the output into the second backward so that no input of
                # torch.autograd.grad is unreachable from the graph
                cpu_gg = torch.autograd.grad(
                    cpu_output.sum() + sum(map(lambda x: x.sum(), cpu_gradInputs)),
                    (cpu_input, cpu_gradOutput) + tuple(cpu_module.parameters()),
                    retain_graph=True)
                gpu_gg = torch.autograd.grad(
                    gpu_output.sum() + sum(map(lambda x: x.sum(), gpu_gradInputs)),
                    (gpu_input, gpu_gradOutput) + tuple(gpu_module.parameters()),
                    retain_graph=True)

                test_case.assertEqual(cpu_gradInput, gpu_gradInput, self.precision)
                for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
                    test_case.assertEqual(cpu_d_p, gpu_d_p, self.precision)

            self.test_noncontig(test_case, gpu_module, gpu_input)
        except NotImplementedError:
            pass
        except AttributeError as e:
            if len(e.args) == 1 and "'FloatTensor' object has no attribute 'scatter_'" in e.args[0]:
                pass
            else:
                raise

class CriterionTest(TestBase):

    _required_arg_names = TestBase._required_arg_names.union({'target'})

    def __init__(self, *args, **kwargs):
        super(CriterionTest, self).__init__(*args, **kwargs)
        self.should_test_cuda = kwargs.get('test_cuda', True)

    def _get_target(self):
        return self._get_arg('target', True)

    def __call__(self, test_case):
        module = self.constructor(*self.constructor_args)
        input = self._get_input()
        target = self._get_target()

        if self.reference_fn is not None:
            out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
            ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
            expected_out = self.reference_fn(*ref_args)
            test_case.assertEqual(out, expected_out)

        test_case.check_criterion_jacobian(module, input, target)
        self._do_extra_tests(test_case, module, input, target)

    def test_cuda(self, test_case):
        if not TEST_CUDA or not self.should_test_cuda:
            raise unittest.SkipTest('Excluded from CUDA tests')
        try:
            cpu_input = self._get_input()
            type_map = {
                'torch.DoubleTensor': torch.cuda.FloatTensor,
            }
            gpu_input = to_gpu(cpu_input, type_map=type_map)

            cpu_target = self._get_target()
            gpu_target = to_gpu(cpu_target, type_map=type_map)

            cpu_module = self.constructor(*self.constructor_args)
            gpu_module = self.constructor(*self.constructor_args).float().cuda()

            cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target)
            gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target)
            test_case.assertEqual(cpu_output, gpu_output, 4e-4)

            gradOutput = torch.randn(())
            cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target, gradOutput)
            gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target, gradOutput)
            test_case.assertEqual(cpu_gradInput, gpu_gradInput, 4e-4)
        except NotImplementedError:
            pass

    def _do_extra_tests(self, test_case, module, input, target):
        pass