Caffe2 - Python API
A deep learning, cross-platform ML framework
common_nn.py
1 import sys
2 import tempfile
3 import unittest
4 from copy import deepcopy
5 from itertools import product
6 from functools import reduce
7 from operator import mul
8 
9 
10 import torch
11 import torch.cuda
12 import torch.nn as nn
13 import torch.nn.functional as F
14 from torch.nn.functional import _Reduction
15 from common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
16  TEST_WITH_ROCM, skipIfRocm
17 from common_cuda import TEST_CUDA
18 from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
19 from torch.autograd import Variable
21 
22 
23 # The tarfile module tries to obtain a file object's name in Python 3.3
24 if sys.version_info[:2] == (3, 3):
25  TemporaryFile = tempfile.NamedTemporaryFile
26 else:
27  TemporaryFile = tempfile.TemporaryFile
28 PRECISION = 1e-5
29 
30 
31 def get_reduction(m):
32     result = getattr(m, 'reduction', None)
33     if result is None:
34         result = _Reduction.legacy_get_string(getattr(m, 'sizeAverage', None), True, emit_warning=False)
35     assert result is not None
36     return result
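# Illustrative example (editorial note, not part of the original file): for a module
# built as nn.MSELoss(reduction='sum'), get_reduction simply returns 'sum'; for a
# legacy module that only exposes a sizeAverage flag, the value is translated through
# _Reduction.legacy_get_string instead.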
37 
38 
39 def get_weight(m):
40     result = getattr(m, 'weight', None)
41     if result is not None:
42         return result
43     return getattr(m, 'weights', None)
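# Illustrative example (editorial note, not part of the original file): for a weighted
# loss such as nn.NLLLoss(weight=w), get_weight returns the module's `weight` buffer;
# the 'weights' fallback handles modules that store it under the plural attribute name.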
44 
45 module_tests = [
46  dict(
47  module_name='Linear',
48  constructor_args=(10, 8),
49  input_size=(4, 10),
50  reference_fn=lambda i, p: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8),
51  ),
52  dict(
53  module_name='Linear',
54  constructor_args=(10, 8, False),
55  input_size=(4, 10),
56  desc='no_bias',
57  reference_fn=lambda i, p: torch.mm(i, p[0].t())
58  ),
59  dict(
60  module_name='Threshold',
61  constructor_args=(2., 1.),
62  input_size=(2, 3, 4, 5),
63  check_inplace=True,
64  desc='threshold_value'
65  ),
66  dict(
67  module_name='Threshold',
68  constructor_args=(2., 10.),
69  input_size=(2, 3, 4, 5),
70  desc='large_value'
71  ),
72  dict(
73  module_name='ReLU',
74  input_size=(2, 3, 4, 5),
75  check_inplace=True,
76  ),
77  dict(
78  module_name='ReLU6',
79  input_size=(2, 3, 4, 5),
80  check_inplace=True,
81  ),
82  dict(
83  module_name='RReLU',
84  input_size=(1, 2, 2),
85  test_cuda=False,
86  ),
87  dict(
88  module_name='RReLU',
89  constructor_args=(0.1, 0.9),
90  input_size=(4, 4, 5),
91  desc='with_up_down',
92  test_cuda=False,
93  ),
94  dict(
95  module_name='Hardtanh',
96  input_size=(3, 2, 5),
97  reference_fn=lambda i, _: i.clamp(-1, 1),
98  ),
99  dict(
100  module_name='Sigmoid',
101  input_size=(2, 3, 4, 5)
102  ),
103  dict(
104  module_name='Tanh',
105  input_size=(2, 3, 4, 5)
106  ),
107  dict(
108  module_name='Softmax',
109  constructor_args=(1,),
110  input_size=(10, 20),
111  reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)),
112  ),
113  dict(
114  module_name='Softmax2d',
115  input_size=(1, 3, 10, 20),
116  reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, False)),
117  ),
118  dict(
119  module_name='LogSoftmax',
120  constructor_args=(1,),
121  input_size=(10, 20),
122  reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(),
123  ),
124  dict(
125  module_name='LogSoftmax',
126  constructor_args=(1,),
127  input_size=(1, 3, 10, 20),
128  reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
129  desc='multiparam',
130  ),
131  dict(
132  module_name='ELU',
133  constructor_args=(2.,),
134  input_size=(3, 2, 5),
135  reference_fn=lambda x, _: torch.where(x >= 0, x, 2 * (x.exp() - 1)),
136  ),
137  # TODO: reference function
138  dict(
139  module_name='Hardshrink',
140  constructor_args=(2.,),
141  input_size=(4, 3, 2, 4),
142  ),
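# Sketch only (not part of the original tests): a plausible reference for the Hardshrink
# case above, with lambda = 2, would be
#     reference_fn=lambda i, _: i * (i.abs() > 2.).type_as(i)
# i.e. entries with magnitude at most 2 are zeroed and all others pass through unchanged.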
143  dict(
144  module_name='LeakyReLU',
145  input_size=(3, 2, 5),
146  check_inplace=True
147  ),
148  dict(
149  module_name='LeakyReLU',
150  constructor_args=(0.5,),
151  input_size=(3, 2, 5),
152  check_inplace=True,
153  desc='with_negval'
154  ),
155  dict(
156  module_name='LogSigmoid',
157  input_size=(2, 3, 4),
158  reference_fn=lambda i, _: i.sigmoid().log(),
159  ),
160  dict(
161  module_name='Softplus',
162  input_size=(10, 20),
163  reference_fn=lambda i, _: torch.log(1 + torch.exp(i)),
164  ),
165  dict(
166  module_name='Softplus',
167  constructor_args=(2,),
168  input_size=(10, 20),
169  reference_fn=lambda i, _: 1. / 2. * torch.log(1 + torch.exp(2 * i)),
170  desc='beta',
171  ),
172  dict(
173  module_name='Softplus',
174  constructor_args=(2, -100),
175  input_size=(10, 20),
176  reference_fn=(lambda i, _: ((i * 2) > -100).type_as(i) * i +
177  ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
178  desc='beta_threshold',
179  ),
180  dict(
181  module_name='Softshrink',
182  input_size=(3, 2, 5),
183  ),
184  dict(
185  module_name='Softshrink',
186  constructor_args=(1,),
187  input_size=(3, 2, 5),
188  desc='lambda',
189  ),
190  dict(
191  module_name='CrossMapLRN2d',
192  constructor_args=(5, 5e-3, 1e-3, 2),
193  input_size=(2, 3, 6, 6),
194  check_gradgrad=False,
195  ),
196  dict(
197  module_name='PReLU',
198  input_size=(2, 3, 4),
199  reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
200  desc='1d',
201  ),
202  dict(
203  module_name='PReLU',
204  constructor_args=(3,),
205  input_size=(2, 3, 4),
206  desc='1d_multiparam',
207  reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
208  ),
209  dict(
210  module_name='PReLU',
211  input_size=(2, 3, 4, 5),
212  desc='2d',
213  reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
214  ),
215  dict(
216  module_name='PReLU',
217  constructor_args=(3,),
218  input_size=(2, 3, 4, 5),
219  desc='2d_multiparam',
220  reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
221  ),
222  dict(
223  module_name='PReLU',
224  input_size=(2, 3, 4, 5, 6),
225  reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
226  desc='3d',
227  ),
228  dict(
229  module_name='PReLU',
230  constructor_args=(3,),
231  input_size=(2, 3, 4, 5, 6),
232  desc='3d_multiparam',
233  reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
234  ),
235  dict(
236  module_name='Softsign',
237  input_size=(3, 2, 5),
238  reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
239  ),
240  dict(
241  module_name='Softmin',
242  constructor_args=(1,),
243  input_size=(10, 20),
244  ),
245  dict(
246  module_name='Softmin',
247  constructor_args=(1,),
248  input_size=(2, 3, 5, 10),
249  desc='multidim',
250  ),
251  dict(
252  module_name='Tanhshrink',
253  input_size=(2, 3, 4, 5),
254  ),
255 ]
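# Each dict above defines one module test case. Keys seen in this list include:
#   module_name / constructor_args      -- which nn module to build and with what arguments
#   input_size                          -- shape of the randomly generated test input
#   reference_fn(input, params)         -- an independent implementation to compare against
#   desc, check_inplace, check_gradgrad, test_cuda -- per-case options for the test harness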
256 
257 
258 # Generates a random tensor with all-distinct values. This ensures that duplicate
259 # values won't cause test failures for modules like MaxPooling.
260 # size should be small, otherwise randperm fails / long overflows.
261 def _rand_tensor_non_equal(*size):
262  total = reduce(mul, size, 1)
263  return torch.randperm(total).view(*size).double()
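# Illustrative example (editorial note, not part of the original file):
# _rand_tensor_non_equal(2, 3) yields a 2x3 double tensor whose six entries are a
# permutation of 0..5, so pooling tests never have to break ties between equal values.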
264 
265 
266 def wrap_functional(fn, **kwargs):
267     class FunctionalModule(nn.Module):
268         def forward(self, *args):
269             return fn(*args, **kwargs)
270     return FunctionalModule
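# Illustrative usage (editorial note, not part of the original file):
#     ReluModule = wrap_functional(F.relu)
#     out = ReluModule()(torch.randn(2, 3))   # same values as F.relu on that input
# The class returned by wrap_functional is what many test dicts below pass as `constructor`.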
271 
272 
273 def poissonnllloss_no_reduce_test():
274  t = torch.randn(10, 10)
275  return dict(
276  fullname='PoissonNLLLLoss_no_reduce',
277  constructor=wrap_functional(
278  lambda i: F.poisson_nll_loss(i, t.type_as(i), reduction='none')),
279  input_fn=lambda: torch.rand(10, 10),
280  pickle=False)
281 
282 
283 def bceloss_no_reduce_test():
284  t = Variable(torch.randn(15, 10).gt(0).double())
285  return dict(
286  fullname='BCELoss_no_reduce',
287  constructor=wrap_functional(
288  lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
289  input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
290  reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
291  check_gradgrad=False,
292  pickle=False)
293 
294 
295 def bceloss_no_reduce_scalar_test():
296  t = torch.randn(()).gt(0).double()
297  return dict(
298  fullname='BCELoss_no_reduce_scalar',
299  constructor=wrap_functional(
300  lambda i: F.binary_cross_entropy(i, t.type_as(i), reduction='none')),
301  input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
302  reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()),
303  check_gradgrad=False,
304  pickle=False)
305 
306 
307 def bceloss_weights_no_reduce_test():
308  t = Variable(torch.randn(15, 10).gt(0).double())
309  weights = torch.rand(10)
310  return dict(
311  fullname='BCELoss_weights_no_reduce',
312  constructor=wrap_functional(
313  lambda i: F.binary_cross_entropy(i, t.type_as(i),
314  weight=weights.type_as(i), reduction='none')),
315  input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
316  reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
317  check_gradgrad=False,
318  pickle=False
319  )
320 
321 
322 def bceloss_weights_no_reduce_scalar_test():
323  t = torch.randn(()).double()
324  weights = torch.rand(())
325  return dict(
326  fullname='BCELoss_weights_no_reduce_scalar',
327  constructor=wrap_functional(
328  lambda i: F.binary_cross_entropy(i, t.type_as(i),
329  weight=weights.type_as(i), reduction='none')),
330  input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
331  reference_fn=lambda i, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
332  check_gradgrad=False,
333  pickle=False
334  )
335 
336 
337 def bce_with_logistic_legacy_enum_test():
338  t = Variable(torch.randn(15, 10).gt(0).double())
339  sigmoid = nn.Sigmoid()
340  return dict(
341  fullname='BCEWithLogitsLoss_legacy_enum',
342  constructor=wrap_functional(
343  lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduce=False)),
344  input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
345  reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
346  check_gradgrad=False,
347  pickle=False,
348  )
349 
350 
351 def bce_with_logistic_no_reduce_test():
352  t = Variable(torch.randn(15, 10).gt(0).double())
353  sigmoid = nn.Sigmoid()
354  return dict(
355  fullname='BCEWithLogitsLoss_no_reduce',
356  constructor=wrap_functional(
357  lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
358  input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
359  reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
360  check_gradgrad=False,
361  pickle=False,
362  )
363 
364 
365 def bce_with_logistic_no_reduce_scalar_test():
366  t = torch.randn(()).gt(0).double()
367  sigmoid = nn.Sigmoid()
368  return dict(
369  fullname='BCEWithLogitsLoss_no_reduce_scalar',
370  constructor=wrap_functional(
371  lambda i: F.binary_cross_entropy_with_logits(i, t.type_as(i), reduction='none')),
372  input_fn=lambda: torch.rand(()).clamp_(2.8e-2, 1 - 2.8e-2),
373  reference_fn=lambda i, m: -(t * sigmoid(i).log() + (1 - t) * (1 - sigmoid(i)).log()),
374  check_gradgrad=False,
375  pickle=False
376  )
377 
378 
379 def kldivloss_with_target_no_reduce_test():
380  i = torch.rand(10, 10).log()
381  return dict(
382  fullname='KLDivLoss_with_target_no_reduce',
383  constructor=wrap_functional(
384  lambda t: F.kl_div(i.type_as(t), t, reduction='none')),
385  input_fn=lambda: torch.rand(10, 10),
386  reference_fn=lambda t, _:
387  loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduction='none'),
388  pickle=False)
389 
390 
391 def kldivloss_no_reduce_test():
392  t = torch.randn(10, 10)
393  return dict(
394  fullname='KLDivLoss_no_reduce',
395  constructor=wrap_functional(
396  lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
397  input_fn=lambda: torch.rand(10, 10).log(),
398  reference_fn=lambda i, _:
399  loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
400  pickle=False,
401  )
402 
403 
404 def kldivloss_no_reduce_scalar_test():
405  t = torch.randn(())
406  return dict(
407  fullname='KLDivLoss_no_reduce_scalar',
408  constructor=wrap_functional(
409  lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
410  input_fn=lambda: torch.rand(()).log(),
411  reference_fn=lambda i, _:
412  loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
413  pickle=False)
414 
415 
416 def l1loss_no_reduce_test():
417  t = torch.randn(2, 3, 4)
418  return dict(
419  fullname='L1Loss_no_reduce',
420  constructor=wrap_functional(
421  lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
422  input_fn=lambda: torch.randn(2, 3, 4),
423  reference_fn=lambda i, m: (i - t.type_as(i)).abs(),
424  pickle=False)
425 
426 
427 def l1loss_no_reduce_scalar_test():
428  t = torch.randn(())
429  return dict(
430  fullname='L1Loss_no_reduce_scalar',
431  constructor=wrap_functional(
432  lambda i: F.l1_loss(i, t.type_as(i), reduction='none')),
433  input_fn=lambda: torch.randn(()),
434  reference_fn=lambda i, m: (i - t.type_as(i)).abs(),
435  pickle=False)
436 
437 
438 def mseloss_no_reduce_test():
439  input_size = (2, 3, 4, 5)
440  target = torch.randn(*input_size)
441  return dict(
442  fullname='MSELoss_no_reduce',
443  constructor=wrap_functional(
444  lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
445  input_size=input_size,
446  reference_fn=lambda i, m: (i - target).pow(2),
447  pickle=False)
448 
449 
450 def mseloss_no_reduce_scalar_test():
451  input_size = ()
452  target = torch.randn(input_size)
453  return dict(
454  fullname='MSELoss_no_reduce_scalar',
455  constructor=wrap_functional(
456  lambda i: F.mse_loss(i, target.type_as(i), reduction='none')),
457  input_size=input_size,
458  reference_fn=lambda i, m: (i - target).pow(2),
459  pickle=False)
460 
461 
462 def nllloss_no_reduce_test():
463  t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
464  kwargs = {'reduction': 'none'}
465  return dict(
466  fullname='NLLLoss_no_reduce',
467  constructor=wrap_functional(
468  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
469  input_fn=lambda: torch.rand(15, 10).log(),
470  reference_fn=lambda i, _:
471  loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
472  pickle=False)
473 
474 
475 def nllloss_no_reduce_ignore_index_test():
476  t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
477  kwargs = {'ignore_index': 2, 'reduction': 'none'}
478  return dict(
479  fullname='NLLLoss_no_reduce_ignore_index',
480  constructor=wrap_functional(
481  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
482  input_fn=lambda: torch.rand(15, 10).log(),
483  reference_fn=lambda i, _:
484  loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs),
485  pickle=False)
486 
487 
488 def nllloss_no_reduce_weights_test():
489  t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
490  weight = torch.rand(10)
491 
492  def kwargs(i):
493  return {'weight': weight.type_as(i), 'reduction': 'none'}
494 
495  return dict(
496  fullname='NLLLoss_no_reduce_weights',
497  constructor=wrap_functional(
498  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
499  input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
500  reference_fn=lambda i, _:
501  loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
502  pickle=False)
503 
504 
505 def nllloss_no_reduce_weights_ignore_index_test():
506  t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
507  weight = torch.rand(10)
508 
509  def kwargs(i):
510  return {'weight': weight.type_as(i), 'reduction': 'none',
511  'ignore_index': 2}
512 
513  return dict(
514  fullname='NLLLoss_no_reduce_weights_ignore_index',
515  constructor=wrap_functional(
516  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i.data))),
517  input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
518  reference_fn=lambda i, _:
519  loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
520  pickle=False)
521 
522 
523 def nllloss_no_reduce_weights_ignore_index_neg_test():
524  t = Variable(torch.Tensor(15).uniform_().mul(10).floor().long())
525  weight = torch.rand(10)
526 
527  def kwargs(i):
528  return {'weight': weight.type_as(i), 'reduction': 'none',
529  'ignore_index': -1}
530 
531  return dict(
532  fullname='NLLLoss_no_reduce_weights_ignore_index_neg',
533  constructor=wrap_functional(
534  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
535  input=torch.rand(15, 10).add(1e-2).log(),
536  reference_fn=lambda i, _:
537  loss_reference_fns['NLLLoss'](i, t.type_as(i).long(), **kwargs(i)),
538  pickle=False)
539 
540 
541 def nllloss2d_no_reduce_test():
542  t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
543  kwargs = {'reduction': 'none'}
544  return dict(
545  fullname='NLLLoss2d_no_reduce',
546  constructor=wrap_functional(
547  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
548  input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
549  reference_fn=lambda i, _:
550  loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
551  pickle=False)
552 
553 
554 def nllloss2d_no_reduce_ignore_index_test():
555  t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
556  kwargs = {'ignore_index': 1, 'reduction': 'none'}
557  return dict(
558  fullname='NLLLoss2d_no_reduce_ignore_index',
559  constructor=wrap_functional(
560  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
561  input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
562  reference_fn=lambda i, _:
563  loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
564  pickle=False)
565 
566 
567 def nllloss2d_no_reduce_weights_test():
568  t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
569  weight = torch.rand(3)
570 
571  def kwargs(i):
572  return {'weight': weight.type_as(i), 'reduction': 'none'}
573 
574  return dict(
575  fullname='NLLLoss2d_no_reduce_weights',
576  constructor=wrap_functional(
577  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
578  input_fn=lambda: torch.rand(2, 3, 5, 5).log(),
579  reference_fn=lambda i, _:
580  loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
581  pickle=False)
582 
583 
584 def nlllossNd_no_reduce_test():
585  t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
586  kwargs = {'reduction': 'none'}
587  return dict(
588  fullname='NLLLossNd_no_reduce',
589  constructor=wrap_functional(
590  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
591  input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
592  reference_fn=lambda i, _:
593  loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
594  pickle=False)
595 
596 
597 def nlllossNd_no_reduce_ignore_index_test():
598  t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
599  kwargs = {'ignore_index': 1, 'reduction': 'none'}
600  return dict(
601  fullname='NLLLossNd_no_reduce_ignore_index',
602  constructor=wrap_functional(
603  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs)),
604  input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
605  reference_fn=lambda i, _:
606  loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs),
607  pickle=False)
608 
609 
610 def nlllossNd_no_reduce_weights_test():
611  t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
612  weight = torch.rand(3)
613 
614  def kwargs(i):
615  return {'weight': weight.type_as(i), 'reduction': 'none'}
616 
617  return dict(
618  fullname='NLLLossNd_no_reduce_weights',
619  constructor=wrap_functional(
620  lambda i: F.nll_loss(i, t.type_as(i).long(), **kwargs(i))),
621  input_fn=lambda: torch.rand(2, 3, 5, 5, 2, 2).log(),
622  reference_fn=lambda i, _:
623  loss_reference_fns['NLLLossNd'](i, t.type_as(i).long(), **kwargs(i)),
624  pickle=False)
625 
626 
627 def smoothl1loss_no_reduce_test():
628  t = torch.randn(2, 3, 4)
629  return dict(
630  fullname='SmoothL1Loss_no_reduce',
631  constructor=wrap_functional(
632  lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
633  input_fn=lambda: torch.randn(2, 3, 4),
634  reference_fn=lambda i, _:
635  loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
636  pickle=False)
637 
638 
639 def smoothl1loss_no_reduce_scalar_test():
640  t = torch.randn(())
641  return dict(
642  fullname='SmoothL1Loss_no_reduce_scalar',
643  constructor=wrap_functional(
644  lambda i: F.smooth_l1_loss(i, t.type_as(i), reduction='none')),
645  input_fn=lambda: torch.randn(()),
646  reference_fn=lambda i, _:
647  loss_reference_fns['SmoothL1Loss'](i, t.type_as(i), reduction='none'),
648  pickle=False)
649 
650 
651 def multilabelmarginloss_1d_no_reduce_test():
652  t = Variable(torch.rand(10).mul(10).floor().long())
653  return dict(
654  fullname='MultiLabelMarginLoss_1d_no_reduce',
655  constructor=wrap_functional(
656  lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
657  input_fn=lambda: torch.randn(10),
658  reference_fn=lambda i, _:
659  loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
660  check_sum_reduction=True,
661  check_gradgrad=False,
662  pickle=False)
663 
664 
665 def multilabelmarginloss_index_neg_test():
666  t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1))
667  return dict(
668  fullname='MultiLabelMarginLoss_index_neg',
669  constructor=wrap_functional(
670  lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
671  input_fn=lambda: torch.randn(5, 10),
672  reference_fn=lambda i, _:
673  loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
674  check_sum_reduction=True,
675  check_gradgrad=False,
676  pickle=False)
677 
678 
679 def multilabelmarginloss_no_reduce_test():
680  t = Variable(torch.rand(5, 10).mul(10).floor().long())
681  return dict(
682  fullname='MultiLabelMarginLoss_no_reduce',
683  constructor=wrap_functional(
684  lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduction='none')),
685  input_fn=lambda: torch.randn(5, 10),
686  reference_fn=lambda i, _:
687  loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
688  check_sum_reduction=True,
689  check_gradgrad=False,
690  pickle=False)
691 
692 
693 def hingeembeddingloss_no_reduce_test():
694  t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
695  return dict(
696  fullname='HingeEmbeddingLoss_no_reduce',
697  constructor=wrap_functional(
698  lambda i: F.hinge_embedding_loss(i, t.type_as(i), reduction='none')),
699  input_fn=lambda: torch.randn(10),
700  reference_fn=lambda i, _:
701  loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), reduction='none'),
702  check_sum_reduction=True,
703  pickle=False)
704 
705 
706 def hingeembeddingloss_margin_no_reduce_test():
707  t = Variable(torch.randn(10).gt(0).double().mul_(2).sub(1))
708  return dict(
709  fullname='HingeEmbeddingLoss_margin_no_reduce',
710  constructor=wrap_functional(
711  lambda i: F.hinge_embedding_loss(i, t.type_as(i), margin=0.5, reduction='none')),
712  input_fn=lambda: torch.randn(10),
713  reference_fn=lambda i, _:
714  loss_reference_fns['HingeEmbeddingLoss'](i, t.type_as(i), margin=0.5, reduction='none'),
715  check_sum_reduction=True,
716  pickle=False)
717 
718 
719 def softmarginloss_no_reduce_test():
720  t = torch.randn(5, 5)
721  return dict(
722  fullname='SoftMarginLoss_no_reduce',
723  constructor=wrap_functional(
724  lambda i: F.soft_margin_loss(i, t.type_as(i), reduction='none')),
725  input_fn=lambda: torch.randn(5, 5),
726  reference_fn=lambda i, _:
727  loss_reference_fns['SoftMarginLoss'](i, t.type_as(i), reduction='none'),
728  pickle=False)
729 
730 
731 def multilabelsoftmarginloss_no_reduce_test():
732  t = torch.rand(5, 10).mul(2).floor()
733  return dict(
734  fullname='MultiLabelSoftMarginLoss_no_reduce',
735  constructor=wrap_functional(
736  lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i), reduction='none')),
737  input_fn=lambda: torch.randn(5, 10),
738  reference_fn=lambda i, m:
739  (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log())).sum(dim=1) / i.size(1),
740  check_gradgrad=False,
741  pickle=False)
742 
743 
744 def multilabelsoftmarginloss_weights_no_reduce_test():
745  t = torch.rand(5, 10).mul(2).floor()
746  weights = torch.rand(10)
747  return dict(
748  fullname='MultiLabelSoftMarginLoss_weights_no_reduce',
749  constructor=wrap_functional(
750  lambda i: F.multilabel_soft_margin_loss(i, t.type_as(i),
751  weight=weights.type_as(i), reduction='none')),
752  input_fn=lambda: torch.randn(5, 10),
753  reference_fn=lambda i, m:
754  (-(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()) * weights).sum(dim=1) / i.size(1),
755  check_sum_reduction=True,
756  check_gradgrad=False,
757  pickle=False)
758 
759 
760 def multimarginloss_no_reduce_test():
761  t = torch.rand(5).mul(8).floor().long()
762  return dict(
763  fullname='MultiMarginLoss_no_reduce',
764  constructor=wrap_functional(
765  lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
766  input_fn=lambda: torch.randn(5, 10),
767  reference_fn=lambda i, _:
768  loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
769  check_sum_reduction=True,
770  check_gradgrad=False,
771  pickle=False)
772 
773 
774 def multimarginloss_1d_no_reduce_test():
775  t = torch.rand(1).mul(8).floor().long()
776  return dict(
777  fullname='MultiMarginLoss_1d_no_reduce',
778  constructor=wrap_functional(
779  lambda i: F.multi_margin_loss(i, t.type_as(i).long(), reduction='none')),
780  input_fn=lambda: torch.randn(10),
781  reference_fn=lambda i, _:
782  loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), reduction='none'),
783  check_sum_reduction=True,
784  check_gradgrad=False,
785  pickle=False)
786 
787 
788 def multimarginloss_p_no_reduce_test():
789  t = torch.rand(5).mul(8).floor().long()
790  return dict(
791  fullname='MultiMarginLoss_p_no_reduce',
792  constructor=wrap_functional(
793  lambda i: F.multi_margin_loss(i, t.type_as(i).long(), p=2, reduction='none')),
794  input_fn=lambda: torch.randn(5, 10).clamp_(1e-2, 1 - 1e-2),
795  reference_fn=lambda i, _:
796  loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(), p=2, reduction='none'),
797  check_sum_reduction=True,
798  check_gradgrad=False,
799  pickle=False)
800 
801 
802 def multimarginloss_margin_no_reduce_test():
803  t = torch.rand(5).mul(8).floor().long()
804  return dict(
805  fullname='MultiMarginLoss_margin_no_reduce',
806  constructor=wrap_functional(
807  lambda i: F.multi_margin_loss(i, t.type_as(i).long(), margin=0.5, reduction='none')),
808  input_fn=lambda: torch.randn(5, 10),
809  reference_fn=lambda i, _:
810  loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
811  margin=0.5, reduction='none'),
812  check_sum_reduction=True,
813  check_gradgrad=False,
814  pickle=False)
815 
816 
817 def multimarginloss_weights_no_reduce_test():
818  t = torch.rand(5).mul(8).floor().long()
819  weights = torch.rand(10)
820  return dict(
821  fullname='MultiMarginLoss_weights_no_reduce',
822  constructor=wrap_functional(
823  lambda i: F.multi_margin_loss(i, t.type_as(i).long(), weight=weights.type_as(i),
824  reduction='none')),
825  input_fn=lambda: torch.randn(5, 10),
826  reference_fn=lambda i, _:
827  loss_reference_fns['MultiMarginLoss'](i, t.data.type_as(i).long(),
828  weight=weights, reduction='none'),
829  check_sum_reduction=True,
830  check_gradgrad=False,
831  pickle=False)
832 
833 
834 def fractional_max_pool2d_test(test_case):
835     random_samples = torch.DoubleTensor(1, 3, 2).uniform_()
836     if test_case == 'ratio':
837         return dict(
838             constructor=lambda: nn.FractionalMaxPool2d(
839                 2, output_ratio=0.5, _random_samples=random_samples),
840             input_size=(1, 3, 5, 7),
841             fullname='FractionalMaxPool2d_ratio')
842     elif test_case == 'size':
843         return dict(
844             constructor=lambda: nn.FractionalMaxPool2d((2, 3), output_size=(
845                 4, 3), _random_samples=random_samples),
846             input_size=(1, 3, 7, 6),
847             fullname='FractionalMaxPool2d_size')
848 
849 
850 def fractional_max_pool3d_test(test_case):
851     random_samples = torch.DoubleTensor(2, 4, 3).uniform_()
852     if test_case == 'ratio':
853         return dict(
854             constructor=lambda: nn.FractionalMaxPool3d(
855                 2, output_ratio=0.5, _random_samples=random_samples),
856             input_size=(2, 4, 5, 5, 5),
857             fullname='FractionalMaxPool3d_ratio')
858     elif test_case == 'size':
859         return dict(
860             constructor=lambda: nn.FractionalMaxPool3d((2, 2, 2), output_size=(
861                 4, 4, 4), _random_samples=random_samples),
862             input_size=(2, 4, 7, 7, 7),
863             fullname='FractionalMaxPool3d_size')
864     elif test_case == 'asymsize':
865         return dict(
866             constructor=lambda: nn.FractionalMaxPool3d((4, 2, 3), output_size=(
867                 10, 3, 2), _random_samples=random_samples),
868             input_size=(2, 4, 16, 7, 5),
869             fullname='FractionalMaxPool3d_asymsize')
870 
871 
872 new_module_tests = [
873  poissonnllloss_no_reduce_test(),
874  bceloss_no_reduce_test(),
875  bceloss_weights_no_reduce_test(),
876  bce_with_logistic_legacy_enum_test(),
877  bce_with_logistic_no_reduce_test(),
878  bceloss_no_reduce_scalar_test(),
879  bceloss_weights_no_reduce_scalar_test(),
880  bce_with_logistic_no_reduce_scalar_test(),
881  kldivloss_with_target_no_reduce_test(),
882  kldivloss_no_reduce_test(),
883  kldivloss_no_reduce_scalar_test(),
884  l1loss_no_reduce_test(),
885  l1loss_no_reduce_scalar_test(),
886  mseloss_no_reduce_test(),
887  mseloss_no_reduce_scalar_test(),
888  nllloss_no_reduce_test(),
889  nllloss_no_reduce_ignore_index_test(),
890  nllloss_no_reduce_weights_test(),
891  nllloss_no_reduce_weights_ignore_index_test(),
892  nllloss_no_reduce_weights_ignore_index_neg_test(),
893  nllloss2d_no_reduce_test(),
894  nllloss2d_no_reduce_weights_test(),
895  nllloss2d_no_reduce_ignore_index_test(),
896  nlllossNd_no_reduce_test(),
897  nlllossNd_no_reduce_weights_test(),
898  nlllossNd_no_reduce_ignore_index_test(),
899  smoothl1loss_no_reduce_test(),
900  smoothl1loss_no_reduce_scalar_test(),
901  multilabelmarginloss_1d_no_reduce_test(),
902  multilabelmarginloss_index_neg_test(),
903  multilabelmarginloss_no_reduce_test(),
904  hingeembeddingloss_no_reduce_test(),
905  hingeembeddingloss_margin_no_reduce_test(),
906  softmarginloss_no_reduce_test(),
907  multilabelsoftmarginloss_no_reduce_test(),
908  multilabelsoftmarginloss_weights_no_reduce_test(),
909  multimarginloss_no_reduce_test(),
910  multimarginloss_1d_no_reduce_test(),
911  multimarginloss_p_no_reduce_test(),
912  multimarginloss_margin_no_reduce_test(),
913  multimarginloss_weights_no_reduce_test(),
914  fractional_max_pool2d_test('ratio'),
915  fractional_max_pool2d_test('size'),
916  fractional_max_pool3d_test('ratio'),
917  fractional_max_pool3d_test('size'),
918  fractional_max_pool3d_test('asymsize'),
919  dict(
920  module_name='BatchNorm1d',
921  constructor_args=(10,),
922  input_size=(4, 10),
923  cudnn=True,
924  check_eval=True,
925  desc='affine',
926  skip_double=TEST_WITH_ROCM,
927  test_cuda=(not TEST_WITH_ROCM),
928  ),
929  dict(
930  module_name='BatchNorm1d',
931  constructor_args=(5,),
932  input_size=(4, 5, 3),
933  cudnn=True,
934  check_eval=True,
935  desc='3d_input',
936  skip_double=TEST_WITH_ROCM,
937  ),
938  dict(
939  module_name='BatchNorm1d',
940  constructor_args=(10, 1e-3, None),
941  input_size=(4, 10),
942  cudnn=True,
943  check_eval=True,
944  desc='affine_simple_average',
945  skip_double=TEST_WITH_ROCM,
946  test_cuda=(not TEST_WITH_ROCM),
947  ),
948  dict(
949  module_name='BatchNorm1d',
950  constructor_args=(10, 1e-3, 0.3, False),
951  input_size=(4, 10),
952  cudnn=True,
953  check_eval=True,
954  desc='not_affine',
955  skip_double=TEST_WITH_ROCM,
956  ),
957  dict(
958  module_name='BatchNorm1d',
959  constructor_args=(10, 1e-3, 0.3, True, False),
960  input_size=(4, 10),
961  cudnn=True,
962  check_eval=True,
963  desc='not_tracking_stats',
964  skip_double=TEST_WITH_ROCM,
965  test_cuda=(not TEST_WITH_ROCM),
966  ),
967  dict(
968  module_name='BatchNorm1d',
969  constructor_args=(5, 1e-3, 0.3, False),
970  input_size=(4, 5, 3),
971  cudnn=True,
972  check_eval=True,
973  desc='3d_input_not_affine',
974  skip_double=TEST_WITH_ROCM,
975  ),
976  dict(
977  module_name='BatchNorm2d',
978  constructor_args=(3,),
979  input_size=(2, 3, 6, 6),
980  cudnn=True,
981  check_eval=True,
982  skip_double=TEST_WITH_ROCM,
983  ),
984  dict(
985  module_name='BatchNorm2d',
986  constructor_args=(3, 1e-3, None),
987  input_size=(2, 3, 6, 6),
988  cudnn=True,
989  check_eval=True,
990  desc='2d_simple_average',
991  skip_double=TEST_WITH_ROCM,
992  ),
993  dict(
994  module_name='BatchNorm2d',
995  constructor_args=(3, 1e-3, 0.8),
996  input_size=(2, 3, 6, 6),
997  cudnn=True,
998  check_eval=True,
999  desc='momentum',
1000  skip_double=TEST_WITH_ROCM,
1001  ),
1002  dict(
1003  module_name='BatchNorm2d',
1004  constructor_args=(3, 1e-3, 0.8, False),
1005  input_size=(2, 3, 6, 6),
1006  cudnn=True,
1007  check_eval=True,
1008  desc='not_affine',
1009  skip_double=TEST_WITH_ROCM,
1010  ),
1011  dict(
1012  module_name='BatchNorm2d',
1013  constructor_args=(3, 1e-3, 0.8, True, False),
1014  input_size=(2, 3, 6, 6),
1015  cudnn=True,
1016  check_eval=True,
1017  desc='not_tracking_stats',
1018  skip_double=TEST_WITH_ROCM,
1019  ),
1020  dict(
1021  module_name='BatchNorm3d',
1022  constructor_args=(3,),
1023  input_size=(2, 3, 4, 4, 4),
1024  cudnn=True,
1025  check_eval=True,
1026  ),
1027  dict(
1028  module_name='BatchNorm3d',
1029  constructor_args=(3, 1e-3, None),
1030  input_size=(2, 3, 4, 4, 4),
1031  cudnn=True,
1032  check_eval=True,
1033  desc='3d_simple_average',
1034  ),
1035  dict(
1036  module_name='BatchNorm3d',
1037  constructor_args=(3, 1e-3, 0.7),
1038  input_size=(2, 3, 4, 4, 4),
1039  cudnn=True,
1040  check_eval=True,
1041  desc='momentum',
1042  ),
1043  dict(
1044  module_name='BatchNorm3d',
1045  constructor_args=(3, 1e-3, 0.7, False),
1046  input_size=(2, 3, 4, 4, 4),
1047  cudnn=True,
1048  check_eval=True,
1049  desc='not_affine',
1050  ),
1051  dict(
1052  module_name='BatchNorm3d',
1053  constructor_args=(3, 1e-3, 0.7, True, False),
1054  input_size=(2, 3, 4, 4, 4),
1055  cudnn=True,
1056  check_eval=True,
1057  desc='not_tracking_stats',
1058  ),
1059  dict(
1060  module_name='InstanceNorm1d',
1061  constructor_args=(3, 1e-3, 0.3),
1062  input_size=(4, 3, 15),
1063  cudnn=True,
1064  check_eval=True,
1065  ),
1066  dict(
1067  module_name='InstanceNorm1d',
1068  constructor_args=(3, 1e-3, 0.3, False, True),
1069  input_size=(4, 3, 15),
1070  cudnn=True,
1071  check_eval=True,
1072  desc='tracking_stats',
1073  ),
1074  dict(
1075  module_name='InstanceNorm2d',
1076  constructor_args=(3, 1e-3, 0.3),
1077  input_size=(2, 3, 6, 6),
1078  cudnn=True,
1079  check_eval=True,
1080  ),
1081  dict(
1082  module_name='InstanceNorm2d',
1083  constructor_args=(3, 1e-3, 0.3, False, True),
1084  input_size=(2, 3, 6, 6),
1085  cudnn=True,
1086  check_eval=True,
1087  desc='tracking_stats',
1088  ),
1089  dict(
1090  module_name='InstanceNorm3d',
1091  constructor_args=(3, 1e-3, 0.3),
1092  input_size=(2, 3, 4, 4, 4),
1093  cudnn=True,
1094  check_eval=True,
1095  ),
1096  dict(
1097  module_name='InstanceNorm3d',
1098  constructor_args=(3, 1e-3, 0.3, False, True),
1099  input_size=(2, 3, 4, 4, 4),
1100  cudnn=True,
1101  check_eval=True,
1102  desc='tracking_stats',
1103  ),
1104  dict(
1105  module_name='LayerNorm',
1106  constructor_args=([5], 1e-3),
1107  input_size=(4, 5, 5),
1108  cudnn=True,
1109  check_eval=True,
1110  desc='1d_elementwise_affine',
1111  ),
1112  dict(
1113  module_name='LayerNorm',
1114  constructor_args=([5], 1e-3, False),
1115  input_size=(4, 5, 5),
1116  cudnn=True,
1117  check_eval=True,
1118  desc='1d_no_elementwise_affine',
1119  ),
1120  dict(
1121  module_name='LayerNorm',
1122  constructor_args=([2, 2, 5], 1e-3),
1123  input_size=(4, 2, 2, 5),
1124  cudnn=True,
1125  check_eval=True,
1126  desc='3d_elementwise_affine',
1127  ),
1128  dict(
1129  module_name='LayerNorm',
1130  constructor_args=([2, 2, 5], 1e-3, False),
1131  input_size=(4, 2, 2, 5),
1132  cudnn=True,
1133  check_eval=True,
1134  desc='3d_no_elementwise_affine',
1135  ),
1136  dict(
1137  module_name='GroupNorm',
1138  constructor_args=(3, 6, 1e-3),
1139  input_size=(4, 6, 5),
1140  cudnn=True,
1141  check_eval=True,
1142  desc='1d_affine',
1143  ),
1144  dict(
1145  module_name='GroupNorm',
1146  constructor_args=(5, 5, 1e-3, False),
1147  input_size=(4, 5, 5),
1148  cudnn=True,
1149  check_eval=True,
1150  desc='1d_no_affine_IN', # this setting is equivalent to InstanceNorm
1151  ),
1152  dict(
1153  module_name='GroupNorm',
1154  constructor_args=(1, 5, 1e-3, False),
1155  input_size=(4, 5, 5),
1156  cudnn=True,
1157  check_eval=True,
1158  desc='1d_no_affine_LN', # this setting is equivalent to LayerNorm
1159  ),
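# Editorial note (not part of the original file): with affine disabled, GroupNorm with
# num_groups == num_channels normalizes each channel on its own, matching InstanceNorm,
# while GroupNorm with num_groups == 1 normalizes over all channels at once, matching
# LayerNorm over the channel and remaining dimensions -- which is what the comments on
# the two cases above refer to.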
1160  dict(
1161  module_name='GroupNorm',
1162  constructor_args=(3, 6, 1e-3),
1163  input_size=(4, 6, 2, 3),
1164  cudnn=True,
1165  check_eval=True,
1166  desc='2d_affine',
1167  ),
1168  dict(
1169  module_name='GroupNorm',
1170  constructor_args=(3, 3, 1e-3, False),
1171  input_size=(4, 3, 2, 3),
1172  cudnn=True,
1173  check_eval=True,
1174  desc='2d_no_affine_IN', # this setting is equivalent to InstanceNorm
1175  ),
1176  dict(
1177  module_name='GroupNorm',
1178  constructor_args=(1, 3, 1e-3, False),
1179  input_size=(4, 3, 2, 3),
1180  cudnn=True,
1181  check_eval=True,
1182  desc='2d_no_affine_LN', # this setting is equivalent to LayerNorm
1183  ),
1184  dict(
1185  module_name='Conv1d',
1186  constructor_args=(4, 5, 3),
1187  input_size=(2, 4, 10),
1188  cudnn=True,
1189  skip_double=TEST_WITH_ROCM,
1190  ),
1191  dict(
1192  module_name='Conv1d',
1193  constructor_args=(4, 5, 3, 2),
1194  input_size=(2, 4, 10),
1195  cudnn=True,
1196  desc='stride',
1197  skip_double=TEST_WITH_ROCM,
1198  ),
1199  dict(
1200  module_name='Conv1d',
1201  constructor_args=(4, 5, 3, 1, 1),
1202  input_size=(2, 4, 10),
1203  cudnn=True,
1204  desc='pad1',
1205  skip_double=TEST_WITH_ROCM,
1206  ),
1207  dict(
1208  module_name='Conv1d',
1209  constructor_args=(4, 5, 5, 1, 2),
1210  input_size=(2, 4, 10),
1211  cudnn=True,
1212  desc='pad2',
1213  skip_double=TEST_WITH_ROCM,
1214  ),
1215  dict(
1216  module_name='Conv1d',
1217  constructor_args=(4, 4, 3, 1, 1),
1218  input_size=(1, 4, 1),
1219  cudnn=True,
1220  desc='pad1size1',
1221  skip_double=TEST_WITH_ROCM,
1222  ),
1223  dict(
1224  module_name='Conv1d',
1225  constructor_args=(4, 4, 5, 1, 2),
1226  input_size=(1, 4, 1),
1227  cudnn=True,
1228  desc='pad2size1',
1229  skip_double=TEST_WITH_ROCM,
1230  ),
1231  dict(
1232  fullname='Conv1d_dilated',
1233  constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
1234  input_size=(2, 4, 10),
1235  skip_double=TEST_WITH_ROCM,
1236  ),
1237  dict(
1238  fullname='Conv1d_groups',
1239  constructor=lambda: nn.Conv1d(4, 6, kernel_size=3, groups=2),
1240  input_size=(2, 4, 6),
1241  cudnn=True,
1242  ),
1243  dict(
1244  fullname='ConvTranspose1d',
1245  constructor=lambda: nn.ConvTranspose1d(3, 4, kernel_size=3, stride=(3,), padding=1, output_padding=(1,)),
1246  cudnn=True,
1247  input_size=(1, 3, 7),
1248  ),
1249  dict(
1250  module_name='ConvTranspose1d',
1251  constructor_args=(3, 4, 3, 2, 1, 1, 1, False),
1252  input_size=(1, 3, 6),
1253  cudnn=True,
1254  desc='no_bias',
1255  ),
1256  dict(
1257  module_name='ConvTranspose1d',
1258  constructor_args=(3, 4, 3, 2, 1, 1, 1, True, 2),
1259  input_size=(1, 3, 6),
1260  cudnn=True,
1261  desc='dilated',
1262  ),
1263  dict(
1264  fullname='ConvTranspose1d_groups',
1265  constructor=lambda: nn.ConvTranspose1d(4, 6, 3, stride=(3,), padding=1, output_padding=(1,), groups=2),
1266  cudnn=True,
1267  input_size=(2, 4, 7),
1268  ),
1269  dict(
1270  module_name='MaxPool1d',
1271  constructor_args=(4,),
1272  input_size=(2, 10, 4),
1273  ),
1274  dict(
1275  module_name='MaxPool1d',
1276  constructor_args=(4, 4),
1277  input_size=(2, 10, 4),
1278  desc='stride',
1279  ),
1280  dict(
1281  module_name='Conv2d',
1282  constructor_args=(3, 4, (3, 2)),
1283  input_size=(2, 3, 7, 5),
1284  cudnn=True,
1285  ),
1286  dict(
1287  module_name='Conv2d',
1288  constructor_args=(3, 4, (3, 3), (2, 2)),
1289  input_size=(2, 3, 6, 6),
1290  cudnn=True,
1291  desc='strided',
1292  ),
1293  dict(
1294  module_name='Conv2d',
1295  constructor_args=(3, 4, (3, 3), (2, 2), (1, 1)),
1296  input_size=(2, 3, 6, 6),
1297  cudnn=True,
1298  desc='padding',
1299  ),
1300  dict(
1301  module_name='Conv2d',
1302  constructor_args=(3, 2, (3, 3), (2, 2), (1, 1), (2, 2)),
1303  input_size=(2, 3, 8, 8),
1304  cudnn=True,
1305  desc='dilated',
1306  ),
1307  dict(
1308  module_name='Conv2d',
1309  constructor_args=(3, 4, (3, 2), 1, 0, 1, 1, False),
1310  input_size=(2, 3, 6, 5),
1311  cudnn=True,
1312  desc='no_bias',
1313  ),
1314  dict(
1315  fullname='Conv2d_groups',
1316  constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1317  input_size=(2, 4, 6, 5),
1318  cudnn=True,
1319  ),
1320  dict(
1321  fullname='Conv2d_groups_thnn',
1322  constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1323  input_size=(2, 4, 6, 5),
1324  ),
1325  dict(
1326  module_name='ConvTranspose2d',
1327  constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),
1328  cudnn=True,
1329  input_size=(1, 3, 7, 6),
1330  ),
1331  dict(
1332  module_name='ConvTranspose2d',
1333  constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False, (2, 2)),
1334  input_size=(1, 3, 6, 7),
1335  cudnn=True,
1336  desc='dilated',
1337  ),
1338  dict(
1339  module_name='ConvTranspose2d',
1340  constructor_args=(3, 4, 3, (2, 3), 1, (1, 1), 1, False),
1341  input_size=(1, 3, 6, 7),
1342  cudnn=True,
1343  desc='no_bias',
1344  ),
1345  dict(
1346  fullname='ConvTranspose2d_groups',
1347  constructor=lambda: nn.ConvTranspose2d(2, 4, (2, 3), groups=2),
1348  input_size=(1, 2, 4, 5),
1349  cudnn=True,
1350  ),
1351  dict(
1352  fullname='Conv2d_depthwise',
1353  constructor=lambda: nn.Conv2d(4, 4, (3, 3), groups=4),
1354  input_size=(2, 4, 6, 6),
1355  ),
1356  dict(
1357  fullname='Conv2d_depthwise_with_multiplier',
1358  constructor=lambda: nn.Conv2d(4, 8, (3, 3), groups=4),
1359  input_size=(2, 4, 6, 6),
1360  ),
1361  dict(
1362  fullname='Conv2d_depthwise_strided',
1363  constructor=lambda: nn.Conv2d(4, 4, (3, 3), stride=(2, 2), groups=4),
1364  input_size=(2, 4, 6, 6),
1365  ),
1366  dict(
1367  fullname='Conv2d_depthwise_padded',
1368  constructor=lambda: nn.Conv2d(4, 4, (3, 3), padding=(1, 1), groups=4),
1369  input_size=(2, 4, 6, 6),
1370  ),
1371  dict(
1372  fullname='Conv2d_depthwise_dilated',
1373  constructor=lambda: nn.Conv2d(4, 4, (2, 2), dilation=(2, 2), groups=4),
1374  input_size=(2, 4, 5, 5),
1375  ),
1376  dict(
1377  module_name='MaxPool2d',
1378  constructor_args=((3, 3), (2, 2), (1, 1)),
1379  input_size=(1, 3, 7, 7),
1380  ),
1381  dict(
1382  module_name='AvgPool1d',
1383  constructor_args=(2,),
1384  input_size=(2, 3, 6),
1385  ),
1386  dict(
1387  module_name='AvgPool1d',
1388  constructor_args=((2,), (2,)),
1389  input_size=(2, 3, 6),
1390  desc='stride',
1391  ),
1392  dict(
1393  module_name='AvgPool1d',
1394  constructor_args=(2, 2, 1),
1395  input_size=(2, 3, 6),
1396  desc='stride_pad',
1397  ),
1398  dict(
1399  module_name='AvgPool2d',
1400  constructor_args=((2, 2),),
1401  input_size=(2, 3, 6, 6),
1402  ),
1403  dict(
1404  module_name='AvgPool2d',
1405  constructor_args=((2, 2), (2, 2)),
1406  input_size=(2, 3, 6, 6),
1407  desc='stride',
1408  ),
1409  dict(
1410  module_name='AvgPool2d',
1411  constructor_args=((2, 2), (2, 2), (1, 1)),
1412  input_size=(2, 3, 6, 6),
1413  desc='stride_pad',
1414  ),
1415  dict(
1416  module_name='LPPool2d',
1417  constructor_args=(2, 2, 2),
1418  input_size=(1, 3, 7, 7),
1419  ),
1420  dict(
1421  module_name='LPPool2d',
1422  constructor_args=(1.5, 2),
1423  input_fn=lambda: torch.rand(1, 3, 7, 7),
1424  desc='norm',
1425  ),
1426  dict(
1427  module_name='LPPool1d',
1428  constructor_args=(1.5, 2),
1429  input_fn=lambda: torch.rand(1, 3, 7),
1430  desc='norm',
1431  ),
1432  dict(
1433  module_name='LPPool1d',
1434  constructor_args=(2, 2, 3),
1435  input_size=(1, 3, 7),
1436  ),
1437  dict(
1438  module_name='LocalResponseNorm',
1439  constructor_args=(3, ),
1440  input_size=(1, 5, 7),
1441  desc='1d',
1442  ),
1443  dict(
1444  module_name='LocalResponseNorm',
1445  constructor_args=(2, ),
1446  input_size=(1, 5, 7, 7),
1447  desc='2d_uneven_pad',
1448  ),
1449  dict(
1450  module_name='LocalResponseNorm',
1451  constructor_args=(1, 1., 0.5, 2.),
1452  input_size=(1, 5, 7, 7, 7),
1453  desc='3d_custom_params',
1454  ),
1455  dict(
1456  module_name='ReflectionPad1d',
1457  constructor_args=((1, 2),),
1458  input_size=(2, 3, 8),
1459  ),
1460  dict(
1461  module_name='ReflectionPad2d',
1462  constructor_args=((1, 2, 3, 4),),
1463  input_size=(2, 3, 8, 8),
1464  ),
1465  dict(
1466  module_name='ReplicationPad1d',
1467  constructor_args=((1, 2),),
1468  input_size=(2, 3, 4),
1469  ),
1470  dict(
1471  module_name='ReplicationPad2d',
1472  constructor_args=((1, 2, 3, 4),),
1473  input_size=(2, 3, 4, 4),
1474  ),
1475  dict(
1476  module_name='ZeroPad2d',
1477  constructor_args=((1, 2, 3, 4),),
1478  input_size=(2, 3, 4, 4)
1479  ),
1480  dict(
1481  module_name='ZeroPad2d',
1482  constructor_args=((-1, -1, -1, -2),),
1483  input_size=(2, 3, 4, 4),
1484  desc='negative_dims'
1485  ),
1486  dict(
1487  module_name='ConstantPad1d',
1488  constructor_args=((1, 2), 2.),
1489  input_size=(2, 3, 4)
1490  ),
1491  dict(
1492  module_name='ConstantPad2d',
1493  constructor_args=((1, 2, 3, 4), 2.),
1494  input_size=(2, 3, 4, 4)
1495  ),
1496  dict(
1497  module_name='ConstantPad3d',
1498  constructor_args=((1, 2, 3, 4, 1, 0), 2.),
1499  input_size=(2, 3, 4, 4, 5)
1500  ),
1501  dict(
1502  module_name='Conv3d',
1503  constructor_args=(3, 4, (2, 3, 4)),
1504  input_size=(2, 3, 3, 4, 5),
1505  cudnn=True,
1506  ),
1507  dict(
1508  module_name='Conv3d',
1509  constructor_args=(3, 4, (2, 3, 4), 1, 0, 1, 1, False),
1510  input_size=(2, 3, 3, 4, 5),
1511  cudnn=True,
1512  desc='no_bias',
1513  ),
1514  dict(
1515  module_name='Conv3d',
1516  constructor_args=(3, 4, 2, 2),
1517  input_size=(2, 3, 5, 5, 5),
1518  cudnn=True,
1519  desc='stride',
1520  ),
1521  dict(
1522  module_name='Conv3d',
1523  constructor_args=(3, 4, 2, 2, 1),
1524  input_size=(2, 3, 5, 5, 5),
1525  cudnn=True,
1526  desc='stride_padding',
1527  ),
1528  dict(
1529  fullname='Conv3d_groups',
1530  constructor=lambda: nn.Conv3d(4, 6, kernel_size=3, groups=2),
1531  input_size=(2, 4, 4, 5, 4),
1532  cudnn=True,
1533  ),
1534  dict(
1535  fullname='Conv3d_dilated',
1536  constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2),
1537  input_size=(2, 3, 5, 5, 5),
1538  ),
1539  dict(
1540  fullname='Conv3d_dilated_strided',
1541  constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
1542  input_size=(2, 3, 5, 5, 5),
1543  ),
1544  dict(
1545  module_name='ConvTranspose3d',
1546  constructor_args=(2, 3, (2, 3, 2)),
1547  cudnn=True,
1548  input_size=(1, 2, 4, 5, 4),
1549  ),
1550  dict(
1551  module_name='ConvTranspose3d',
1552  constructor_args=(2, 3, (2, 3, 2), 1, 0, 0, 1, True, (2, 2, 2)),
1553  cudnn=True,
1554  input_size=(1, 2, 4, 5, 4),
1555  desc='dilated',
1556  ),
1557  dict(
1558  module_name='MaxPool3d',
1559  constructor_args=((2, 2, 2),),
1560  input_size=(2, 3, 5, 5, 5),
1561  ),
1562  dict(
1563  module_name='MaxPool3d',
1564  constructor_args=(2, (2, 2, 2)),
1565  input_size=(2, 3, 5, 5, 5),
1566  desc='stride',
1567  ),
1568  dict(
1569  module_name='MaxPool3d',
1570  constructor_args=(2, 2, (1, 1, 1)),
1571  input_size=(2, 3, 5, 5, 5),
1572  desc='stride_padding',
1573  ),
1574  dict(
1575  module_name='AvgPool3d',
1576  constructor_args=((2, 2, 2),),
1577  input_size=(2, 3, 4, 4, 4),
1578  ),
1579  dict(
1580  module_name='AvgPool3d',
1581  constructor_args=(2, (2, 2, 2)),
1582  input_size=(2, 3, 5, 5, 5),
1583  desc='stride',
1584  ),
1585  dict(
1586  module_name='AvgPool3d',
1587  constructor_args=(2, 2, (1, 1, 1)),
1588  input_size=(2, 3, 5, 5, 5),
1589  desc='stride_pad',
1590  ),
1591  dict(
1592  module_name='AvgPool3d',
1593  constructor_args=(4, 2, (1, 2, 1)),
1594  input_size=(2, 3, 5, 5, 5),
1595  desc='stride_pad_gpu_fixedkw_output',
1596  ),
1597  dict(
1598  module_name='AvgPool3d',
1599  constructor_args=((2, 4, 8), 1, (1, 1, 2)),
1600  input_size=(2, 3, 2, 4, 8),
1601  desc='stride_pad_gpu_general_output',
1602  ),
1603  dict(
1604  module_name='AvgPool3d',
1605  constructor_args=(3, 1, 0),
1606  input_size=(2, 3, 4, 4, 4),
1607  desc='stride1_pad0_gpu_input',
1608  ),
1609  dict(
1610  module_name='AvgPool3d',
1611  constructor_args=(2, 2, (1, 1, 1)),
1612  input_size=(2, 3, 4, 4, 4),
1613  desc='stride_pad_gpu_input_nooverlap',
1614  ),
1615  dict(
1616  module_name='ReplicationPad3d',
1617  constructor_args=((1, 2, 3, 4, 5, 6),),
1618  input_size=(2, 3, 5, 5, 5),
1619  ),
1620  dict(
1621  module_name='Embedding',
1622  constructor_args=(4, 3),
1623  input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1624  jacobian_input=False,
1625  check_gradgrad=False,
1626  ),
1627  dict(
1628  module_name='EmbeddingBag',
1629  constructor_args=(4, 3),
1630  input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1631  jacobian_input=False,
1632  check_gradgrad=False,
1633  desc='mean',
1634  ),
1635  dict(
1636  module_name='EmbeddingBag',
1637  constructor_args=(4, 3, None, 2., False, 'sum'),
1638  input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1639  jacobian_input=False,
1640  check_gradgrad=False,
1641  desc='sum',
1642  ),
1643  dict(
1644  module_name='EmbeddingBag',
1645  constructor_args=(4, 3, None, 2., False, 'max'),
1646  input_fn=lambda: torch.empty(2, 3, dtype=torch.long).random_(4),
1647  jacobian_input=False,
1648  check_gradgrad=False,
1649  desc='max',
1650  ),
1651  dict(
1652  fullname='EmbeddingBag_sparse',
1653  constructor=lambda: nn.EmbeddingBag(4, 3, sparse=True),
1654  input_fn=lambda: torch.randperm(2).repeat(1, 2),
1655  jacobian_input=False,
1656  check_gradgrad=False,
1657  ),
1658  dict(
1659  constructor=lambda: nn.Embedding(4, 3, sparse=True),
1660  input_fn=lambda: torch.randperm(2).repeat(1, 2),
1661  jacobian_input=False,
1662  fullname='Embedding_sparse',
1663  check_gradgrad=False,
1664  ),
1665  dict(
1666  module_name='PixelShuffle',
1667  constructor_args=(3,),
1668  input_size=(1, 9, 4, 4),
1669  ),
1670  dict(
1671  constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
1672  input_size=(1, 2, 4),
1673  fullname='interpolate_nearest_1d',
1674  pickle=False,
1675  ),
1676  dict(
1677  constructor=wrap_functional(F.interpolate, size=(12, ), scale_factor=None, mode='nearest'),
1678  input_size=(1, 2, 3),
1679  fullname='interpolate_nearest_tuple_1d',
1680  pickle=False,
1681  ),
1682  dict(
1683  constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
1684  input_size=(1, 2, 4),
1685  fullname='interpolate_nearest_scale_1d',
1686  pickle=False,
1687  ),
1688  dict(
1689  constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=False),
1690  input_size=(1, 2, 4),
1691  fullname='interpolate_linear_1d',
1692  pickle=False,
1693  ),
1694  dict(
1695  constructor=wrap_functional(F.interpolate, size=(4, ), scale_factor=None, mode='linear', align_corners=False),
1696  input_size=(1, 2, 3),
1697  fullname='interpolate_linear_tuple_1d',
1698  pickle=False,
1699  ),
1700  dict(
1701  constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=False),
1702  input_size=(1, 2, 4),
1703  fullname='interpolate_linear_scale_1d',
1704  pickle=False,
1705  ),
1706  dict(
1707  constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='linear', align_corners=True),
1708  input_size=(1, 2, 4),
1709  fullname='interpolate_linear_1d_align_corners',
1710  pickle=False,
1711  ),
1712  dict(
1713  constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='linear', align_corners=True),
1714  input_size=(1, 2, 4),
1715  fullname='interpolate_linear_scale_1d_align_corners',
1716  pickle=False,
1717  ),
1718  dict(
1719  constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
1720  input_size=(1, 2, 4, 4),
1721  fullname='interpolate_nearest_2d',
1722  pickle=False,
1723  ),
1724  dict(
1725  constructor=wrap_functional(F.interpolate, size=(12, 16), scale_factor=None, mode='nearest'),
1726  input_size=(1, 2, 3, 4),
1727  fullname='interpolate_nearest_tuple_2d',
1728  pickle=False,
1729  ),
1730  dict(
1731  constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
1732  input_size=(1, 2, 4, 4),
1733  fullname='interpolate_nearest_scale_2d',
1734  pickle=False,
1735  ),
1736  dict(
1737  constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bilinear', align_corners=False),
1738  input_size=(1, 2, 4, 4),
1739  fullname='interpolate_bilinear_2d',
1740  pickle=False,
1741  ),
1742  dict(
1743  constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
1744  mode='bilinear', align_corners=False),
1745  input_size=(1, 2, 2, 3),
1746  fullname='interpolate_bilinear_tuple_2d',
1747  pickle=False,
1748  ),
1749  dict(
1750  constructor=wrap_functional(F.interpolate, size=None, scale_factor=4.,
1751  mode='bilinear', align_corners=False),
1752  input_size=(1, 2, 4, 4),
1753  fullname='interpolate_bilinear_scale_2d',
1754  pickle=False,
1755  ),
1756  dict(
1757  constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
1758  mode='bilinear', align_corners=False),
1759  input_size=(1, 2, 4, 4),
1760  fullname='interpolate_bilinear_scale_tuple_shared_2d',
1761  pickle=False,
1762  ),
1763  dict(
1764  constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
1765  mode='bilinear', align_corners=False),
1766  input_size=(1, 2, 4, 4),
1767  fullname='interpolate_bilinear_scale_tuple_skewed_2d',
1768  pickle=False,
1769  ),
1770  dict(
1771  constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bilinear', align_corners=True),
1772  input_size=(1, 2, 4, 4),
1773  fullname='interpolate_bilinear_tuple_2d_align_corners',
1774  pickle=False,
1775  ),
1776  dict(
1777  constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
1778  mode='bilinear', align_corners=True),
1779  input_size=(1, 2, 4, 4),
1780  fullname='interpolate_bilinear_scale_tuple_skewed_2d_align_corners',
1781  pickle=False,
1782  ),
1783  dict(
1784  constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='bicubic', align_corners=False),
1785  input_size=(1, 2, 4, 4),
1786  fullname='interpolate_bicubic_2d',
1787  pickle=False,
1788  ),
1789  dict(
1790  constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None,
1791  mode='bicubic', align_corners=False),
1792  input_size=(1, 2, 2, 3),
1793  fullname='interpolate_bicubic_tuple_2d',
1794  pickle=False,
1795  ),
1796  dict(
1797  constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='bicubic', align_corners=False),
1798  input_size=(1, 2, 4, 4),
1799  fullname='interpolate_bicubic_scale_2d',
1800  pickle=False,
1801  ),
1802  dict(
1803  constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 2.),
1804  mode='bicubic', align_corners=False),
1805  input_size=(1, 2, 4, 4),
1806  fullname='interpolate_bicubic_scale_tuple_shared_2d',
1807  pickle=False,
1808  ),
1809  dict(
1810  constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
1811  mode='bicubic', align_corners=False),
1812  input_size=(1, 2, 4, 4),
1813  fullname='interpolate_bicubic_scale_tuple_skewed_2d',
1814  pickle=False,
1815  ),
1816  dict(
1817  constructor=wrap_functional(F.interpolate, size=(4, 6), scale_factor=None, mode='bicubic', align_corners=True),
1818  input_size=(1, 2, 4, 4),
1819  fullname='interpolate_bicubic_tuple_2d_align_corners',
1820  pickle=False,
1821  ),
1822  dict(
1823  constructor=wrap_functional(F.interpolate, size=None, scale_factor=(2., 1.),
1824  mode='bicubic', align_corners=True),
1825  input_size=(1, 2, 4, 4),
1826  fullname='interpolate_bicubic_scale_tuple_skewed_2d_align_corners',
1827  pickle=False,
1828  ),
1829  dict(
1830  constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='nearest'),
1831  input_size=(1, 2, 4, 4, 4),
1832  fullname='interpolate_nearest_3d',
1833  pickle=False,
1834  ),
1835  dict(
1836  constructor=wrap_functional(F.interpolate, size=(12, 16, 16), scale_factor=None, mode='nearest'),
1837  input_size=(1, 2, 3, 4, 4),
1838  fullname='interpolate_nearest_tuple_3d',
1839  pickle=False,
1840  ),
1841  dict(
1842  constructor=wrap_functional(F.interpolate, size=None, scale_factor=4., mode='nearest'),
1843  input_size=(1, 2, 4, 4, 4),
1844  fullname='interpolate_nearest_scale_3d',
1845  pickle=False,
1846  ),
1847  dict(
1848  constructor=wrap_functional(F.interpolate, size=12, scale_factor=None, mode='trilinear', align_corners=False),
1849  input_size=(1, 2, 4, 4, 4),
1850  fullname='interpolate_trilinear_3d',
1851  pickle=False,
1852  ),
1853  dict(
1854  constructor=wrap_functional(F.interpolate, size=(4, 6, 6),
1855  scale_factor=None, mode='trilinear', align_corners=False),
1856  input_size=(1, 2, 2, 3, 3),
1857  fullname='interpolate_trilinear_tuple_3d',
1858  pickle=False,
1859  ),
1860  dict(
1861  constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=False),
1862  input_size=(1, 2, 3, 4, 4),
1863  fullname='interpolate_trilinear_scale_3d',
1864  # See https://github.com/pytorch/pytorch/issues/5006
1865  precision=3e-4,
1866  pickle=False,
1867  ),
1868  dict(
1869  constructor=wrap_functional(F.interpolate, size=(4, 6, 6), scale_factor=None,
1870  mode='trilinear', align_corners=True),
1871  input_size=(1, 2, 2, 3, 3),
1872  fullname='interpolate_trilinear_tuple_3d_align_corners',
1873  pickle=False,
1874  ),
1875  dict(
1876  constructor=wrap_functional(F.interpolate, size=None, scale_factor=3., mode='trilinear', align_corners=True),
1877  input_size=(1, 2, 3, 4, 4),
1878  fullname='interpolate_trilinear_scale_3d_align_corners',
1879  # See https://github.com/pytorch/pytorch/issues/5006
1880  precision=3e-4,
1881  pickle=False,
1882  ),
1883  dict(
1884  module_name='AdaptiveMaxPool1d',
1885  constructor_args=(3,),
1886  input_fn=lambda: _rand_tensor_non_equal(1, 3, 5),
1887  ),
1888  dict(
1889  module_name='AdaptiveMaxPool2d',
1890  constructor_args=(3,),
1891  input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
1892  desc='single',
1893  ),
1894  dict(
1895  module_name='AdaptiveMaxPool2d',
1896  constructor_args=((3, 4),),
1897  input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
1898  desc='tuple',
1899  ),
1900  dict(
1901  module_name='AdaptiveMaxPool2d',
1902  constructor_args=((3, None),),
1903  input_fn=lambda: _rand_tensor_non_equal(1, 3, 5, 6),
1904  desc='tuple_none',
1905  ),
1906  dict(
1907  module_name='AdaptiveMaxPool3d',
1908  constructor_args=(3,),
1909  input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
1910  desc='single',
1911  ),
1912  dict(
1913  module_name='AdaptiveMaxPool3d',
1914  constructor_args=((3, 4, 5),),
1915  input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
1916  desc='tuple',
1917  ),
1918  dict(
1919  module_name='AdaptiveMaxPool3d',
1920  constructor_args=((3, None, 5),),
1921  input_fn=lambda: _rand_tensor_non_equal(2, 3, 5, 6, 7),
1922  desc='tuple_none',
1923  ),
1924  dict(
1925  module_name='AdaptiveMaxPool3d',
1926  constructor_args=(3,),
1927  input_fn=lambda: _rand_tensor_non_equal(2, 3, 12, 9, 3),
1928  desc='single_nonatomic',
1929  ),
1930  dict(
1931  module_name='AdaptiveMaxPool3d',
1932  constructor_args=((3, 4, 5),),
1933  input_fn=lambda: _rand_tensor_non_equal(2, 3, 6, 4, 10),
1934  desc='tuple_nonatomic',
1935  ),
1936  dict(
1937  module_name='AdaptiveAvgPool1d',
1938  constructor_args=(3,),
1939  input_fn=lambda: torch.rand(1, 3, 5),
1940  ),
1941  dict(
1942  module_name='AdaptiveAvgPool1d',
1943  constructor_args=(1,),
1944  input_fn=lambda: torch.rand(1, 3, 5),
1945  desc='one_output',
1946  ),
1947  dict(
1948  module_name='AdaptiveAvgPool2d',
1949  constructor_args=(3,),
1950  input_fn=lambda: torch.rand(1, 3, 5, 6),
1951  desc='single',
1952  ),
1953  dict(
1954  module_name='AdaptiveAvgPool2d',
1955  constructor_args=(1,),
1956  input_fn=lambda: torch.rand(1, 3, 5, 6),
1957  desc='single_1x1output',
1958  ),
1959  dict(
1960  module_name='AdaptiveAvgPool2d',
1961  constructor_args=((3, 4),),
1962  input_fn=lambda: torch.rand(1, 3, 5, 6),
1963  desc='tuple',
1964  ),
1965  dict(
1966  module_name='AdaptiveAvgPool2d',
1967  constructor_args=((3, None),),
1968  input_fn=lambda: torch.rand(1, 3, 5, 6),
1969  desc='tuple_none',
1970  ),
1971  dict(
1972  module_name='AdaptiveAvgPool3d',
1973  constructor_args=(3,),
1974  input_fn=lambda: torch.rand(2, 3, 5, 2, 7),
1975  desc='single',
1976  ),
1977  dict(
1978  module_name='AdaptiveAvgPool3d',
1979  constructor_args=((3, 4, 5),),
1980  input_fn=lambda: torch.rand(2, 3, 5, 3, 7),
1981  desc='tuple',
1982  ),
1983  dict(
1984  module_name='AdaptiveAvgPool3d',
1985  constructor_args=((None, 4, 5),),
1986  input_fn=lambda: torch.rand(2, 3, 5, 3, 7),
1987  desc='tuple_none',
1988  ),
1989  dict(
1990  module_name='SELU',
1991  input_size=(3, 2, 5),
1992  check_inplace=True
1993  ),
1994  dict(
1995  module_name='SELU',
1996  input_size=(),
1997  check_inplace=True,
1998  desc='scalar'
1999  ),
2000  dict(
2001  module_name='CELU',
2002  input_size=(3, 2, 5),
2003  constructor_args=(2.,),
2004  check_inplace=True,
2005  reference_fn=lambda x, _: torch.where(x >= 0, x, 2. * ((.5 * x).exp() - 1)),
2006  ),
2007  dict(
2008  module_name='CELU',
2009  input_size=(),
2010  constructor_args=(2.,),
2011  check_inplace=True,
2012  reference_fn=lambda x, _: torch.where(x >= 0, x, 2. * ((.5 * x).exp() - 1)),
2013  desc='scalar'
2014  ),
2015  dict(
2016  module_name='GLU',
2017  input_size=(5, 6),
2018  ),
2019  dict(
2020  module_name='GLU',
2021  constructor_args=(1,),
2022  input_size=(5, 6, 7),
2023  desc='dim',
2024  ),
2025  dict(
2026  constructor=wrap_functional(F.softmax, dim=-1),
2027  input_size=(2, 128), # trigger the last-dim algo in CUDA
2028  fullname='softmax_lastdim',
2029  pickle=False,
2030  ),
2031  dict(
2032  constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
2033  input_size=(2, 128),
2034  fullname='softmax_lastdim_dtype',
2035  pickle=False,
2036  test_cuda=False
2037  ),
2038  dict(
2039  constructor=wrap_functional(F.softmax, dim=1),
2040  input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo
2041  fullname='softmax_spatial_special',
2042  pickle=False,
2043  test_cuda=(not TEST_WITH_ROCM)
2044  ),
2045  dict(
2046  constructor=wrap_functional(F.softmax, dim=1),
2047  input_size=(2, 2, 4, 4), # regular spatial algorithm
2048  fullname='softmax_spatial',
2049  pickle=False,
2050  ),
2051  dict(
2052  constructor=wrap_functional(F.softmax, dim=1, dtype=torch.float64),
2053  input_size=(2, 2, 4, 4), # regular spatial algorithm
2054  fullname='softmax_spatial_dtype',
2055  pickle=False,
2056  test_cuda=False
2057  ),
2058  dict(
2059  constructor=wrap_functional(F.softmax, dim=0),
2060  input_size=(2, 3, 4, 5),
2061  fullname='softmax_functional_dim0',
2062  test_cuda=False,
2063  pickle=False,
2064  ),
2065  dict(
2066  constructor=wrap_functional(F.softmax, dim=3),
2067  input_size=(2, 3, 4, 5),
2068  fullname='softmax_functional_dim3',
2069  test_cuda=False,
2070  pickle=False,
2071  ),
2072  dict(
2073  constructor=wrap_functional(F.softmax, dim=-1),
2074  input_size=(),
2075  fullname='softmax_functional_scalar',
2076  test_cuda=False,
2077  pickle=False,
2078  ),
2079  dict(
2080  constructor=wrap_functional(F.log_softmax, dim=-1),
2081  input_size=(2, 128), # trigger the last-dim algo in CUDA
2082  fullname='log_softmax_lastdim',
2083  pickle=False,
2084  ),
2085  dict(
2086  constructor=wrap_functional(F.log_softmax, dim=1),
2087  input_size=(2, 128, 2, 2), # trigger special case of spatial CUDA algo
2088  fullname='log_softmax_spatial_special',
2089  pickle=False,
2090  test_cuda=(not TEST_WITH_ROCM)
2091  ),
2092  dict(
2093  constructor=wrap_functional(F.log_softmax, dim=1),
2094  input_size=(2, 2, 4, 4), # regular spatial algorithm
2095  fullname='log_softmax_spatial',
2096  pickle=False,
2097  ),
2098  dict(
2099  constructor=wrap_functional(F.log_softmax, dim=0),
2100  input_size=(2, 3, 4, 5),
2101  fullname='log_softmax_dim0',
2102  pickle=False,
2103  ),
2104  dict(
2105  constructor=wrap_functional(F.log_softmax, dim=3),
2106  input_size=(2, 3, 4, 5),
2107  fullname='log_softmax_dim3',
2108  pickle=False,
2109  ),
2110  dict(
2111  constructor=wrap_functional(F.log_softmax, dim=0),
2112  input_size=(),
2113  fullname='log_softmax_scalar',
2114  pickle=False,
2115  ),
2116  dict(
2117  fullname='Unfold',
2118  constructor=lambda: nn.Unfold((2, 2), (1, 1), (0, 0), (1, 1)),
2119  input_size=(2, 4, 3, 3),
2120  check_gradgrad=False,
2121  test_cuda=True,
2122  ),
2123  dict(
2124  fullname='Fold',
2125  constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)),
2126  input_size=(2, 16, 4),
2127  check_gradgrad=False,
2128  test_cuda=True,
2129  ),
2130  dict(
2131  fullname='Unfold_int_input',
2132  constructor=lambda: nn.Unfold(2, 1, 0, 1),
2133  input_size=(2, 4, 3, 3),
2134  check_gradgrad=False,
2135  test_cuda=True,
2136  ),
2137  dict(
2138  fullname='Fold_int_input',
2139  constructor=lambda: nn.Fold(3, 2, 1, 0, 1),
2140  input_size=(2, 16, 4),
2141  check_gradgrad=False,
2142  test_cuda=True,
2143  ),
2144  dict(
2145  module_name='Threshold',
2146  constructor_args=(2., 1.),
2147  input_size=(),
2148  check_inplace=True,
2149  desc='threshold_value_scalar'
2150  ),
2151 
2152  dict(
2153  module_name='ReLU',
2154  input_size=(),
2155  check_inplace=True,
2156  desc='scalar'
2157  ),
2158  dict(
2159  module_name='ReLU6',
2160  input_size=(),
2161  check_inplace=True,
2162  desc='scalar'
2163  ),
2164  dict(
2165  module_name='RReLU',
2166  constructor_args=(0.1, 0.9),
2167  input_size=(),
2168  desc='with_up_down_scalar',
2169  test_cuda=False,
2170  ),
2171  dict(
2172  module_name='Hardtanh',
2173  input_size=(),
2174  reference_fn=lambda i, _: i.clamp(-1, 1),
2175  desc='scalar'
2176  ),
2177  dict(
2178  module_name='Sigmoid',
2179  input_size=(),
2180  desc='scalar',
2181  ),
2182  dict(
2183  module_name='Tanh',
2184  input_size=(),
2185  desc='scalar',
2186  ),
2187  dict(
2188  module_name='Softmax',
2189  constructor_args=(0,),
2190  input_size=(),
2191  reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(0, True)),
2192  desc='scalar',
2193  ),
2194  dict(
2195  module_name='LogSoftmax',
2196  constructor_args=(0,),
2197  input_size=(),
2198  reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(0, False)).log_(),
2199  desc='multiparam_scalar',
2200  ),
2201  dict(
2202  module_name='ELU',
2203  constructor_args=(2.,),
2204  input_size=(),
2205  desc='scalar',
2206  ),
2207  dict(
2208  module_name='Hardshrink',
2209  constructor_args=(2.,),
2210  input_size=(),
2211  desc='scalar',
2212  ),
2213  dict(
2214  module_name='LeakyReLU',
2215  constructor_args=(0.5,),
2216  input_size=(),
2217  check_inplace=True,
2218  desc='with_negval_scalar'
2219  ),
2220  dict(
2221  module_name='LogSigmoid',
2222  input_size=(),
2223  reference_fn=lambda i, _: i.sigmoid().log(),
2224  desc='scalar'
2225  ),
2226  dict(
2227  module_name='Softplus',
2228  constructor_args=(2, -100),
2229  input_size=(),
2230  reference_fn=(lambda i, _: ((i * 2) > -100).type_as(i) * i +
2231  ((i * 2) <= -100).type_as(i) * 1. / 2. * torch.log(1 + torch.exp(2 * i))),
2232  desc='beta_threshold_scalar',
2233  ),
2234  dict(
2235  module_name='Softshrink',
2236  constructor_args=(1,),
2237  input_size=(),
2238  desc='lambda_scalar',
2239  ),
2240  dict(
2241  module_name='PReLU',
2242  input_size=(),
2243  reference_fn=lambda i, p: torch.clamp(i, min=0) + torch.clamp(i, max=0) * p[0][0],
2244  desc='scalar',
2245  ),
2246  dict(
2247  module_name='Softsign',
2248  input_size=(),
2249  reference_fn=lambda i, _: i.div(1 + torch.abs(i)),
2250  desc='scalar',
2251  ),
2252  dict(
2253  module_name='Softmin',
2254  constructor_args=(0,),
2255  input_size=(),
2256  desc='scalar',
2257  ),
2258  dict(
2259  module_name='Tanhshrink',
2260  input_size=(),
2261  desc='scalar',
2262  ),
2263  dict(
2264  fullname='Padding12_1dcircular',
2265  constructor=wrap_functional(F.pad, pad=(1, 2), mode='circular'),
2266  input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
2267  reference_fn=lambda i, _: padding1d_circular(i, (1, 2)),
2268  skip_double=TEST_WITH_ROCM,
2269  pickle=False,
2270  ),
2271  dict(
2272  fullname='Padding31_1dcircular',
2273  constructor=wrap_functional(F.pad, pad=(3, 1), mode='circular'),
2274  input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
2275  reference_fn=lambda i, _: padding1d_circular(i, (3, 1)),
2276  skip_double=TEST_WITH_ROCM,
2277  pickle=False,
2278  ),
2279  dict(
2280  fullname='Padding33_1dcircular',
2281  constructor=wrap_functional(F.pad, pad=(3, 3), mode='circular'),
2282  input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 2, 3]),
2283  reference_fn=lambda i, _: padding1d_circular(i, (3, 3)),
2284  skip_double=TEST_WITH_ROCM,
2285  pickle=False,
2286  ),
2287  dict(
2288  fullname='Padding1221_2dcircular',
2289  constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1), mode='circular'),
2290  input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
2291  reference_fn=lambda i, _: padding2d_circular(i, (1, 2, 2, 1)),
2292  skip_double=TEST_WITH_ROCM,
2293  pickle=False,
2294  ),
2295  dict(
2296  fullname='Padding2322_2dcircular',
2297  constructor=wrap_functional(F.pad, pad=(2, 3, 2, 2), mode='circular'),
2298  input_fn=lambda: torch.arange(6, out=torch.DoubleTensor()).reshape([1, 1, 2, 3]),
2299  reference_fn=lambda i, _: padding2d_circular(i, (2, 3, 2, 2)),
2300  skip_double=TEST_WITH_ROCM,
2301  pickle=False,
2302  ),
2303  dict(
2304  fullname='Padding3331_2dcircular',
2305  constructor=wrap_functional(F.pad, pad=(3, 3, 3, 1), mode='circular'),
2306  input_fn=lambda: torch.arange(9, out=torch.DoubleTensor()).reshape([1, 1, 3, 3]),
2307  reference_fn=lambda i, _: padding2d_circular(i, (3, 3, 3, 1)),
2308  skip_double=TEST_WITH_ROCM,
2309  pickle=False,
2310  ),
2311  dict(
2312  fullname='Padding122112_3dcircular',
2313  constructor=wrap_functional(F.pad, pad=(1, 2, 2, 1, 1, 2), mode='circular'),
2314  input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
2315  reference_fn=lambda i, _: padding3d_circular(i, (1, 2, 2, 1, 1, 2)),
2316  skip_double=TEST_WITH_ROCM,
2317  pickle=False,
2318  ),
2319  dict(
2320  fullname='Padding322112_3dcircular',
2321  constructor=wrap_functional(F.pad, pad=(3, 2, 2, 1, 1, 2), mode='circular'),
2322  input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
2323  reference_fn=lambda i, _: padding3d_circular(i, (3, 2, 2, 1, 1, 2)),
2324  skip_double=TEST_WITH_ROCM,
2325  pickle=False,
2326  ),
2327  dict(
2328  fullname='Padding332122_3dcircular',
2329  constructor=wrap_functional(F.pad, pad=(3, 3, 2, 1, 2, 2), mode='circular'),
2330  input_fn=lambda: torch.arange(12, out=torch.DoubleTensor()).reshape([1, 1, 2, 2, 3]),
2331  reference_fn=lambda i, _: padding3d_circular(i, (3, 3, 2, 1, 2, 2)),
2332  skip_double=TEST_WITH_ROCM,
2333  pickle=False,
2334  ),
2335 
2336  dict(
2337  module_name='Conv1d',
2338  constructor_args=(3, 4, 2, 2, (1,), 1, 1, True, 'circular'),
2339  input_size=(2, 3, 5,),
2340  cudnn=True,
2341  desc='stride1_pad1circular',
2342  ),
2343  dict(
2344  module_name='Conv1d',
2345  constructor_args=(3, 4, 2, 2, (2,), 1, 1, True, 'circular'),
2346  input_size=(2, 3, 5,),
2347  cudnn=True,
2348  desc='stride1_pad2circular',
2349  ),
2350  dict(
2351  module_name='Conv2d',
2352  constructor_args=(3, 4, (3, 3), (2, 2), (1, 2), 1, 1, True, 'circular'),
2353  input_size=(2, 3, 3, 3),
2354  cudnn=True,
2355  desc='pad2circular'
2356  ),
2357  dict(
2358  module_name='Conv3d',
2359  constructor_args=(3, 4, 2, 2, (1, 2, 3), 1, 1, True, 'circular'),
2360  input_size=(2, 3, 3, 3, 3),
2361  cudnn=True,
2362  desc='stride_pad1circular',
2363  ),
2364 ]
2365 
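# Illustrative sketch (not part of the original test list; the helper name is
# hypothetical): each wrap_functional entry above binds keyword arguments onto the
# functional, and the harness feeds it a tensor of shape input_size. For instance,
# the 'interpolate_bilinear_scale_2d' entry corresponds roughly to this call:
def _example_interpolate_bilinear_scale_2d():
    x = torch.randn(1, 2, 4, 4)
    # scale_factor=4. upsamples each spatial dimension by 4x: (4, 4) -> (16, 16)
    y = F.interpolate(x, size=None, scale_factor=4., mode='bilinear', align_corners=False)
    assert y.shape == (1, 2, 16, 16)
    return y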
2366 
2367 def kldivloss_reference(input, target, reduction='mean'):
2368  safe_target = target * (target > 0).type_as(target)
2369  safe_target_log = (safe_target + (target <= 0).type_as(target)).log()
2370  result = safe_target * (safe_target_log - input)
2371  if reduction == 'mean':
2372  return result.mean()
2373  elif reduction == 'sum':
2374  return result.sum()
2375  elif reduction == 'batchmean' and result.dim() != 0:
2376  return result.sum() / result.size(0)
2377  return result
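
# Illustrative sketch (hypothetical helper, not part of the test suite):
# kldivloss_reference expects `input` to already hold log-probabilities and
# computes target * (log(target) - input) pointwise before reducing.
def _example_kldivloss_reference():
    p = torch.tensor([0.5, 0.5])          # target distribution
    q = torch.tensor([0.25, 0.75]).log()  # input, as log-probabilities
    expected = (p * (p.log() - q)).mean()
    assert torch.allclose(kldivloss_reference(q, p), expected)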
2378 
2379 
2380 def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
2381  reduction='mean'):
2382  assert input.dim() >= 3
2383  N = input.size(0)
2384  C = input.size(1)
2385  out_size = (N,) + input.size()[2:]
2386  output = torch.zeros(out_size).type_as(input)
2387 
2388  if weight is None:
2389  weight = torch.ones(C).type_as(input)
2390  total_weight = 0
2391  for tup in product(*[range(size) for size in out_size]):
2392  t_nx = target[tup]
2393  norm = 0. if ignore_index == t_nx else weight[t_nx].item()
2394  input_index = list(tup)
2395  input_index.insert(1, t_nx)
2396  output[tup] = -input[tuple(input_index)] * norm
2397  total_weight += norm
2398 
2399  if reduction == 'mean':
2400  return output.sum() / total_weight
2401  elif reduction == 'sum':
2402  return output.sum()
2403  return output
2404 
2405 
2406 def nllloss_reference(input, target, weight=None, ignore_index=-100,
2407  reduction='mean'):
2408 
2409  def nll_loss_helper(input, target, weight, ignore_index):
2410  if target == ignore_index:
2411  return (0, 0)
2412  norm = 1 if weight is None else weight[target]
2413  result = -input[target] * norm
2414  return (result, norm)
2415 
2416  losses_and_weights = [nll_loss_helper(i, t, weight, ignore_index)
2417  for i, t in zip(input, target)]
2418  losses, weights = zip(*losses_and_weights)
2419  losses_tensor = input.new_tensor(losses)
2420  if reduction == 'mean':
2421  return sum(losses_tensor) / sum(weights)
2422  elif reduction == 'sum':
2423  return sum(losses_tensor)
2424  else:
2425  return losses_tensor
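
# Illustrative sketch (hypothetical helper): for each row, the NLL reference simply
# negates the log-probability at the target index, optionally scaled by a class
# weight, and averages by the total weight under 'mean' reduction.
def _example_nllloss_reference():
    log_probs = torch.tensor([[0.7, 0.2, 0.1],
                              [0.1, 0.8, 0.1]]).log()
    target = torch.tensor([0, 1])
    expected = -(log_probs[0, 0] + log_probs[1, 1]) / 2  # mean over the two samples
    assert torch.allclose(nllloss_reference(log_probs, target), expected)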
2426 
2427 
2428 def smoothl1loss_reference(input, target, reduction='mean'):
2429  abs_diff = (input - target).abs()
2430  ge_one_mask = (abs_diff >= 1).type_as(abs_diff)
2431  lt_one_mask = (abs_diff < 1).type_as(abs_diff)
2432  output = ge_one_mask * (abs_diff - 0.5) + lt_one_mask * 0.5 * (abs_diff ** 2)
2433  if reduction == 'mean':
2434  return output.mean()
2435  elif reduction == 'sum':
2436  return output.sum()
2437  return output
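
# Illustrative sketch (hypothetical helper): SmoothL1 is quadratic for |x| < 1 and
# linear beyond, i.e. 0.5 * x**2 if |x| < 1 else |x| - 0.5, applied to input - target.
def _example_smoothl1loss_reference():
    x = torch.tensor([0.5, 2.0])
    t = torch.zeros(2)
    expected = torch.tensor([0.5 * 0.5 ** 2, 2.0 - 0.5]).mean()
    assert torch.allclose(smoothl1loss_reference(x, t), expected)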
2438 
2439 
2440 def _multilabelmarginloss_reference(input, target):
2441  targets = []
2442  for target_index in target:
2443  if target_index < 0:
2444  break
2445  targets.append(target_index)
2446 
2447  sum = 0
2448  for target_index in targets:
2449  for i in range(0, len(input)):
2450  if i not in targets:
2451  sum += max(0, 1 - input[target_index] + input[i])
2452 
2453  return sum
2454 
2455 
2456 def multilabelmarginloss_reference(input, target, reduction='mean'):
2457  if input.dim() == 1:
2458  n = 1
2459  dim = input.size(0)
2460  output = input.new(n).zero_()
2461  output[0] = _multilabelmarginloss_reference(input, target)
2462  else:
2463  n = input.size(0)
2464  dim = input.size(1)
2465  output = input.new(n).zero_()
2466  for i in range(0, n):
2467  output[i] = _multilabelmarginloss_reference(input[i], target[i])
2468 
2469  if reduction == 'mean':
2470  return output.mean() / dim
2471  elif reduction == 'sum':
2472  return output.sum() / dim
2473  return output / dim
2474 
2475 
2476 def hingeembeddingloss_reference(input, target, margin=1.0, reduction='mean'):
2477  margin_clamp = (margin - input).clamp(min=0).type_as(input)
2478  output = torch.where(target == 1, input, margin_clamp)
2479 
2480  if reduction == 'mean':
2481  return output.mean()
2482  elif reduction == 'sum':
2483  return output.sum()
2484  return output
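
# Illustrative sketch (hypothetical helper): for target == 1 the hinge embedding
# loss is the input itself; for target == -1 it is max(0, margin - input).
def _example_hingeembeddingloss_reference():
    x = torch.tensor([0.3, 0.2])
    y = torch.tensor([1.0, -1.0])
    expected = torch.tensor([0.3, 1.0 - 0.2]).mean()
    assert torch.allclose(hingeembeddingloss_reference(x, y), expected)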
2485 
2486 
2487 def softmarginloss_reference(input, target, reduction='mean'):
2488  output = (1 + (-input * target).exp()).log()
2489 
2490  if reduction == 'mean':
2491  return output.mean()
2492  elif reduction == 'sum':
2493  return output.sum()
2494  return output
2495 
2496 
2497 def _multimarginloss_reference(input, target_idx, p, margin, weight):
2498  if weight is None:
2499  weight = input.new(len(input)).fill_(1)
2500 
2501  output = 0
2502  for i in range(0, len(input)):
2503  if i != target_idx:
2504  output += max(0, weight[target_idx] * (margin - input[target_idx] + input[i]) ** p)
2505  return output
2506 
2507 
2508 def multimarginloss_reference(input, target, p=1, margin=1, weight=None, reduction='mean'):
2509  if input.dim() == 1:
2510  n = 1
2511  dim = input.size(0)
2512  return _multimarginloss_reference(input, target[0], p, margin, weight) / dim
2513  else:
2514  n = input.size(0)
2515  dim = input.size(1)
2516  output = input.new(n)
2517  for x in range(0, n):
2518  output[x] = _multimarginloss_reference(input[x], target[x], p, margin, weight)
2519 
2520  if reduction == 'mean':
2521  return output.mean() / dim
2522  elif reduction == 'sum':
2523  return output.sum() / dim
2524  return output / dim
2525 
2526 
2527 def cosineembeddingloss_reference(input1, input2, target, margin=0, reduction='mean'):
2528  def _cos(a, b):
2529  cos = a.new(a.size(0))
2530  for i in range(0, a.size(0)):
2531  cos[i] = (a[i] * b[i]).sum() / ((((a[i] * a[i]).sum() + 1e-12) * ((b[i] * b[i]).sum() + 1e-12)) ** 0.5)
2532  return cos
2533 
2534  output = torch.where(target == 1, 1 - _cos(input1, input2), (_cos(input1, input2) - margin).clamp(min=0))
2535 
2536  if reduction == 'mean':
2537  return output.mean()
2538  elif reduction == 'sum':
2539  return output.sum()
2540  return output
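
# Illustrative sketch (hypothetical helper): with target == 1 the loss is
# 1 - cos(a, b); with target == -1 it is max(0, cos(a, b) - margin).
def _example_cosineembeddingloss_reference():
    a = torch.tensor([[1., 0.]])
    b = torch.tensor([[0., 1.]])
    # orthogonal vectors: cos = 0, so the loss for target == 1 is exactly 1
    loss = cosineembeddingloss_reference(a, b, torch.tensor([1.]))
    assert torch.allclose(loss, torch.tensor(1.))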
2541 
2542 
2543 def tripletmarginloss_reference(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False,
2544  reduction='mean'):
2545  d_p = torch.pairwise_distance(anchor, positive, p, eps)
2546  d_n = torch.pairwise_distance(anchor, negative, p, eps)
2547  if swap:
2548  d_s = torch.pairwise_distance(positive, negative, p, eps)
2549  d_n = torch.min(d_n, d_s)
2550 
2551  output = torch.clamp(margin + d_p - d_n, min=0.0)
2552  if reduction == 'mean':
2553  return output.mean()
2554  elif reduction == 'sum':
2555  return output.sum()
2556  return output
2557 
2558 
2559 def marginrankingloss_reference(input1, input2, target, margin=0, reduction='mean'):
2560  output = (-target * (input1 - input2) + margin).clamp(min=0)
2561  if reduction == 'mean':
2562  return output.mean()
2563  elif reduction == 'sum':
2564  return output.sum()
2565  return output
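
# Illustrative sketch (hypothetical helper): margin ranking penalizes pairs whose
# ordering disagrees with the target sign, i.e. max(0, -target * (x1 - x2) + margin).
def _example_marginrankingloss_reference():
    x1 = torch.tensor([1.0])
    x2 = torch.tensor([3.0])
    target = torch.tensor([1.0])  # we want x1 > x2, but x1 - x2 = -2
    assert torch.allclose(marginrankingloss_reference(x1, x2, target), torch.tensor(2.0))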
2566 
2567 
2568 # This directly follows Graves et al.'s paper; unlike the production implementation, it does not use log-space
2569 def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean'):
2570  input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
2571  target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
2572  dt = log_probs.dtype
2573  log_probs = log_probs.double() # we need the accuracy as we are not in logspace
2574  targets = targets.long()
2575  cum_target_lengths = target_lengths.cumsum(0)
2576  losses = []
2577  for i in range(log_probs.size(1)):
2578  input_length = input_lengths[i].item()
2579  target_length = target_lengths[i].item()
2580  cum_target_length = cum_target_lengths[i].item()
2581  targets_prime = targets.new_full((2 * target_length + 1,), blank)
2582  if targets.dim() == 2:
2583  targets_prime[1::2] = targets[i, :target_length]
2584  else:
2585  targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]
2586  probs = log_probs[:input_length, i].exp()
2587  alpha = log_probs.new_zeros((target_length * 2 + 1,))
2588  alpha[0] = probs[0, blank]
2589  alpha[1] = probs[0, targets_prime[1]]
2590  mask_third = (targets_prime[:-2] != targets_prime[2:])
2591  for t in range(1, input_length):
2592  alpha_next = alpha.clone()
2593  alpha_next[1:] += alpha[:-1]
2594  alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
2595  alpha = probs[t, targets_prime] * alpha_next
2596  losses.append(-alpha[-2:].sum().log()[None])
2597  output = torch.cat(losses, 0)
2598  if reduction == 'mean':
2599  return (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean()
2600  elif reduction == 'sum':
2601  return output.sum()
2602  output = output.to(dt)
2603  return output
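
# Illustrative usage sketch (shapes and the helper name below are hypothetical):
# log_probs is (T, N, C) with the blank symbol at index `blank`, targets holds the
# label sequences, and the per-sample forward (alpha) recursion above sums the
# probabilities of all valid alignments.
def _example_ctcloss_reference():
    T, N, C = 10, 2, 5
    log_probs = torch.randn(T, N, C).log_softmax(2)
    targets = torch.randint(1, C, (N, 4), dtype=torch.long)
    input_lengths = torch.full((N,), T, dtype=torch.long)
    target_lengths = torch.full((N,), 4, dtype=torch.long)
    return ctcloss_reference(log_probs, targets, input_lengths, target_lengths)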
2604 
2605 
2606 def padding1d_circular(input, pad):
2607  r""" input:
2608  [[[0., 1., 2.],
2609  [3., 4., 5.]]]
2610  pad: (1, 2)
2611  output:
2612  [[[2., 0., 1., 2., 0., 1.],
2613  [5., 3., 4., 5., 3., 4.]]]
2614  """
2615  return torch.cat([input[:, :, -pad[0]:], input,
2616  input[:, :, 0:pad[1]]], dim=2)
2617 
2618 
2619 def padding2d_circular(input, pad):
2620  r"""input:
2621  [[[[0., 1., 2],
2622  [3., 4., 5.]]]]
2623  pad: (1, 2, 2, 1)
2624  output:
2625  [[[[2., 0., 1., 2., 0., 1.],
2626  [5., 3., 4., 5., 3., 4.],
2627  [2., 0., 1., 2., 0., 1.],
2628  [5., 3., 4., 5., 3., 4.],
2629  [2., 0., 1., 2., 0., 1.]]]]
2630  """
2631  input = torch.cat([input[:, :, -pad[2]:], input, input[:, :, 0:pad[3]]], dim=2)
2632  return torch.cat([input[:, :, :, -pad[0]:], input, input[:, :, :, 0:pad[1]]], dim=3)
2633 
2634 
2635 def padding3d_circular(input, pad):
2636  r"""input:
2637  [[[[[ 0., 1., 2.],
2638  [ 3., 4., 5.]],
2639  [[ 6., 7., 8.],
2640  [ 9., 10., 11.]]]]]
2641  pad: (1, 2, 2, 1, 1, 2)
2642  output: [[[[[ 8., 6., 7., 8., 6., 7.],
2643  [11., 9., 10., 11., 9., 10.],
2644  [ 8., 6., 7., 8., 6., 7.],
2645  [11., 9., 10., 11., 9., 10.],
2646  [ 8., 6., 7., 8., 6., 7.]],
2647 
2648  [[ 2., 0., 1., 2., 0., 1.],
2649  [ 5., 3., 4., 5., 3., 4.],
2650  [ 2., 0., 1., 2., 0., 1.],
2651  [ 5., 3., 4., 5., 3., 4.],
2652  [ 2., 0., 1., 2., 0., 1.]],
2653 
2654  [[ 8., 6., 7., 8., 6., 7.],
2655  [11., 9., 10., 11., 9., 10.],
2656  [ 8., 6., 7., 8., 6., 7.],
2657  [11., 9., 10., 11., 9., 10.],
2658  [ 8., 6., 7., 8., 6., 7.]],
2659 
2660  [[ 2., 0., 1., 2., 0., 1.],
2661  [ 5., 3., 4., 5., 3., 4.],
2662  [ 2., 0., 1., 2., 0., 1.],
2663  [ 5., 3., 4., 5., 3., 4.],
2664  [ 2., 0., 1., 2., 0., 1.]],
2665 
2666  [[ 8., 6., 7., 8., 6., 7.],
2667  [11., 9., 10., 11., 9., 10.],
2668  [ 8., 6., 7., 8., 6., 7.],
2669  [11., 9., 10., 11., 9., 10.],
2670  [ 8., 6., 7., 8., 6., 7.]]]]]
2671  """
2672  input = torch.cat([input[:, :, -pad[4]:], input, input[:, :, 0:pad[5]]], dim=2)
2673  input = torch.cat([input[:, :, :, -pad[2]:], input, input[:, :, :, 0:pad[3]]], dim=3)
2674  return torch.cat([input[:, :, :, :, -pad[0]:], input, input[:, :, :, :, 0:pad[1]]], dim=4)
2675 
2676 
2677 loss_reference_fns = {
2678  'KLDivLoss': kldivloss_reference,
2679  'NLLLoss': nllloss_reference,
2680  'NLLLossNd': nlllossNd_reference,
2681  'SmoothL1Loss': smoothl1loss_reference,
2682  'MultiLabelMarginLoss': multilabelmarginloss_reference,
2683  'HingeEmbeddingLoss': hingeembeddingloss_reference,
2684  'SoftMarginLoss': softmarginloss_reference,
2685  'MultiMarginLoss': multimarginloss_reference,
2686  'CosineEmbeddingLoss': cosineembeddingloss_reference,
2687  'TripletMarginLoss': tripletmarginloss_reference,
2688  'MarginRankingLoss': marginrankingloss_reference,
2689  'CTCLoss': ctcloss_reference,
2690 }
2691 
2692 
2693 criterion_tests = [
2694  dict(
2695  module_name='L1Loss',
2696  input_size=(2, 3, 4),
2697  target_size=(2, 3, 4),
2698  reference_fn=lambda i, t, _: 1. / i.numel() *
2699  sum((a - b).abs().sum() for a, b in zip(i, t)),
2700  ),
2701  dict(
2702  module_name='NLLLoss',
2703  input_fn=lambda: torch.rand(15, 10).log(),
2704  target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
2705  reference_fn=lambda i, t, m:
2706  nllloss_reference(i, t, reduction=get_reduction(m)),
2707  check_sum_reduction=True
2708  ),
2709  dict(
2710  module_name='NLLLoss',
2711  constructor_args=(None, None, 2),
2712  input_fn=lambda: torch.rand(15, 10).log(),
2713  target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
2714  reference_fn=lambda i, t, _: nllloss_reference(i, t, ignore_index=2),
2715  desc='ignore_index'
2716  ),
2717  dict(
2718  module_name='NLLLoss',
2719  constructor_args_fn=lambda: (torch.rand(10),),
2720  input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
2721  target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
2722  reference_fn=lambda i, t, m:
2723  nllloss_reference(i, t, weight=get_weight(m)),
2724  desc='weights',
2725  ),
2726  dict(
2727  module_name='NLLLoss',
2728  constructor_args_fn=lambda: (torch.rand(10), None, 2),
2729  input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
2730  target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
2731  reference_fn=lambda i, t, m:
2732  nllloss_reference(i, t, weight=get_weight(m), ignore_index=2),
2733  desc='weights_ignore_index'
2734  ),
2735  dict(
2736  module_name='NLLLoss',
2737  constructor_args_fn=lambda: (torch.rand(10), None, -1),
2738  input_fn=lambda: torch.rand(15, 10).add(1e-2).log(),
2739  target_fn=lambda: torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1,
2740  reference_fn=lambda i, t, m:
2741  nllloss_reference(i, t, weight=get_weight(m), ignore_index=-1),
2742  desc='weights_ignore_index_neg'
2743  ),
2744  dict(
2745  module_name='KLDivLoss',
2746  input_fn=lambda: torch.rand(10, 10).log(),
2747  target_fn=lambda: torch.rand(10, 10),
2748  reference_fn=lambda i, t, m:
2749  kldivloss_reference(i, t, get_reduction(m)),
2750  check_sum_reduction=True,
2751  ),
2752  dict(
2753  module_name='MSELoss',
2754  input_size=(2, 3, 4, 5),
2755  target_size=(2, 3, 4, 5),
2756  reference_fn=lambda i, t, m: ((i - t).abs().pow(2).sum() / (i.numel()
2757  if get_reduction(m) == 'mean' else 1)),
2758  check_sum_reduction=True,
2759  ),
2760  dict(
2761  module_name='BCELoss',
2762  input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
2763  target_fn=lambda: torch.randn(15, 10).gt(0).double(),
2764  reference_fn=lambda i, t, m: -(t * i.log() + (1 - t) * (1 - i).log()).sum() /
2765  (i.numel() if get_reduction(m) == 'mean' else 1),
2766  check_gradgrad=False,
2767  ),
2768  dict(
2769  module_name='BCELoss',
2770  constructor_args_fn=lambda: (torch.rand(10),),
2771  input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
2772  target_fn=lambda: torch.randn(15, 10).gt(0).double(),
2773  reference_fn=lambda i, t, m: -((t * i.log() + (1 - t) * (1 - i).log()) * get_weight(m)).sum() /
2774  (i.numel() if get_reduction(m) == 'mean' else 1),
2775  desc='weights',
2776  check_gradgrad=False,
2777  ),
2778  dict(
2779  module_name='CrossEntropyLoss',
2780  input_size=(15, 10),
2781  target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
2782  ),
2783  dict(
2784  module_name='CrossEntropyLoss',
2785  constructor_args_fn=lambda: (torch.rand(10),),
2786  input_size=(15, 10),
2787  target_fn=lambda: torch.Tensor(15).uniform_().mul(10).floor().long(),
2788  desc='weights',
2789  ),
2790  dict(
2791  module_name='HingeEmbeddingLoss',
2792  input_size=(10,),
2793  target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
2794  reference_fn=lambda i, t, m:
2795  hingeembeddingloss_reference(i, t, reduction=get_reduction(m)),
2796  check_sum_reduction=True,
2797  ),
2798  dict(
2799  module_name='HingeEmbeddingLoss',
2800  constructor_args=(0.5,),
2801  input_size=(10,),
2802  target_fn=lambda: torch.randn(10).gt(0).double().mul_(2).sub(1),
2803  reference_fn=lambda i, t, m:
2804  hingeembeddingloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
2805  desc='margin',
2806  check_sum_reduction=True,
2807  ),
2808  dict(
2809  module_name='MultiLabelMarginLoss',
2810  input_size=(10,),
2811  target_fn=lambda: torch.rand(10).mul(10).floor().long(),
2812  reference_fn=lambda i, t, m:
2813  multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
2814  desc="1d",
2815  check_sum_reduction=True,
2816  check_gradgrad=False,
2817  ),
2818  dict(
2819  module_name='MultiLabelMarginLoss',
2820  input_size=(5, 10),
2821  target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(),
2822  reference_fn=lambda i, t, m:
2823  multilabelmarginloss_reference(i, t, reduction=get_reduction(m)),
2824  check_sum_reduction=True,
2825  check_gradgrad=False,
2826  ),
2827  dict(
2828  module_name='MultiLabelSoftMarginLoss',
2829  input_size=(5, 10),
2830  target_fn=lambda: torch.rand(5, 10).mul(2).floor(),
2831  reference_fn=lambda i, t, m: -(t * i.sigmoid().log() + (1 - t) * (-i).sigmoid().log()).sum() / i.numel(),
2832  check_gradgrad=False,
2833  ),
2834  dict(
2835  module_name='MultiMarginLoss',
2836  input_size=(5, 10),
2837  target_fn=lambda: torch.rand(5).mul(8).floor().long(),
2838  reference_fn=lambda i, t, m:
2839  multimarginloss_reference(i, t, reduction=get_reduction(m)),
2840  check_sum_reduction=True,
2841  check_gradgrad=False,
2842  ),
2843  dict(
2844  module_name='MultiMarginLoss',
2845  input_size=(10,),
2846  target_fn=lambda: torch.rand(1).mul(8).floor().long(),
2847  reference_fn=lambda i, t, m:
2848  multimarginloss_reference(i, t, reduction=get_reduction(m)),
2849  desc='1d',
2850  check_sum_reduction=True,
2851  check_gradgrad=False,
2852  ),
2853  dict(
2854  module_name='MultiMarginLoss',
2855  constructor_args=(2,),
2856  input_fn=lambda: torch.rand(5, 10).clamp_(1e-2, 1 - 1e-2),
2857  target_fn=lambda: torch.rand(5).mul(8).floor().long(),
2858  reference_fn=lambda i, t, m:
2859  multimarginloss_reference(i, t, p=2, reduction=get_reduction(m)),
2860  desc='p',
2861  check_sum_reduction=True,
2862  check_gradgrad=False,
2863  ),
2864  dict(
2865  module_name='MultiMarginLoss',
2866  constructor_args=(1, 0.5),
2867  legacy_constructor_args=(1, None, 0.5),
2868  input_size=(5, 10),
2869  target_fn=lambda: torch.rand(5).mul(8).floor().long(),
2870  reference_fn=lambda i, t, m:
2871  multimarginloss_reference(i, t, margin=0.5, reduction=get_reduction(m)),
2872  desc='margin',
2873  check_sum_reduction=True,
2874  check_gradgrad=False,
2875  ),
2876  dict(
2877  module_name='MultiMarginLoss',
2878  constructor_args=(1, 1., torch.rand(10)),
2879  legacy_constructor_args=(1, torch.rand(10)),
2880  input_size=(5, 10),
2881  target_fn=lambda: torch.rand(5).mul(8).floor().long(),
2882  reference_fn=lambda i, t, m:
2883  multimarginloss_reference(i, t, weight=get_weight(m), reduction=get_reduction(m)),
2884  desc='weights',
2885  check_sum_reduction=True,
2886  check_gradgrad=False,
2887  ),
2888  dict(
2889  module_name='SmoothL1Loss',
2890  input_size=(5, 10),
2891  target_size=(5, 10),
2892  check_sum_reduction=True,
2893  reference_fn=lambda i, t, m:
2894  smoothl1loss_reference(i, t, reduction=get_reduction(m)),
2895  ),
2896  dict(
2897  module_name='SoftMarginLoss',
2898  input_size=(5, 5),
2899  target_fn=lambda: torch.randn(5, 5).sign(),
2900  reference_fn=lambda i, t, m:
2901  softmarginloss_reference(i, t, reduction=get_reduction(m)),
2902  check_sum_reduction=True,
2903  ),
2904  dict(
2905  module_name='CosineEmbeddingLoss',
2906  input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
2907  target_fn=lambda: torch.randn(15).sign(),
2908  reference_fn=lambda i, t, m:
2909  cosineembeddingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
2910  check_sum_reduction=True,
2911  ),
2912  dict(
2913  module_name='CosineEmbeddingLoss',
2914  constructor_args=(0.7,),
2915  input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10)),
2916  target_fn=lambda: torch.randn(15).sign(),
2917  reference_fn=lambda i, t, m:
2918  cosineembeddingloss_reference(i[0], i[1], t, margin=0.7, reduction=get_reduction(m)),
2919  desc='margin',
2920  check_sum_reduction=True,
2921  ),
2922  dict(
2923  module_name='MarginRankingLoss',
2924  input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
2925  target_fn=lambda: torch.randn(50).sign(),
2926  reference_fn=lambda i, t, m:
2927  marginrankingloss_reference(i[0], i[1], t, reduction=get_reduction(m)),
2928  check_sum_reduction=True,
2929  ),
2930  dict(
2931  module_name='MarginRankingLoss',
2932  constructor_args=(0.5,),
2933  input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10)),
2934  target_fn=lambda: torch.randn(50).sign(),
2935  reference_fn=lambda i, t, m:
2936  marginrankingloss_reference(i[0], i[1], t, margin=0.5, reduction=get_reduction(m)),
2937  desc='margin',
2938  check_sum_reduction=True,
2939  ),
2940 ]
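
# Illustrative sketch (hypothetical driver, not the real harness): a criterion_tests
# entry is consumed roughly as follows -- the module is built from constructor_args,
# input/target come from the *_fn or *_size fields, and reference_fn is compared
# against the module's output. For the first L1Loss entry above:
def _example_run_l1loss_entry():
    module = nn.L1Loss()
    i = torch.randn(2, 3, 4)
    t = torch.randn(2, 3, 4)
    expected = 1. / i.numel() * sum((a - b).abs().sum() for a, b in zip(i, t))
    assert torch.allclose(module(i, t), expected)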
2941 
2942 
2943 class NNTestCase(TestCase):
2944 
2945  def _jacobian(self, input, num_out):
2946  if isinstance(input, tuple):
2947  return tuple(self._jacobian(elem, num_out) for elem in input)
2948  elif isinstance(input, list):
2949  return [self._jacobian(elem, num_out) for elem in input]
2950  else:
2951  return torch.zeros(input.nelement(), num_out)
2952 
2953  def _flatten_tensors(self, x):
2954  if isinstance(x, torch.Tensor):
2955  if x.is_sparse:
2956  return x.to_dense().view(-1)
2957  else:
2958  return x.view(-1)
2959  else:
2960  return tuple(self._flatten_tensors(a) for a in x)
2961 
2962  def _zero_grad_input(self, input):
2963  if isinstance(input, torch.Tensor):
2964  if input.requires_grad and input.grad is not None:
2965  input.grad.zero_()
2966  input.grad.detach_()
2967  else:
2968  for i in input:
2969  self._zero_grad_input(i)
2970 
2971  def _analytical_jacobian(self, module, input, jacobian_input=True, jacobian_parameters=True):
2972  output = self._forward(module, input)
2973  output_size = output.nelement()
2974 
2975  if jacobian_input:
2976  jacobian_inp = self._jacobian(input, output_size)
2977  flat_jacobian_input = list(iter_tensors(jacobian_inp))
2978 
2979  if jacobian_parameters:
2980  num_param = sum(p.numel() for p in self._get_parameters(module)[0])
2981  jacobian_param = torch.zeros(num_param, output_size)
2982 
2983  for i in range(output_size):
2984  param, d_param = self._get_parameters(module)
2985  # replace missing (None) gradients with zeros so the parameters line up
2986  d_param = [torch.zeros_like(p) if d is None else d for (p, d) in zip(param, d_param)]
2987 
2988  d_out = torch.zeros_like(output)
2989  flat_d_out = d_out.view(-1)
2990  flat_d_out[i] = 1
2991 
2992  if jacobian_parameters:
2993  self._zero_grad_parameters(module)
2994  # Tensors will accumulate gradient from multiple steps
2995  if jacobian_input:
2996  self._zero_grad_input(input)
2997  d_input = self._backward(module, input, output, d_out)
2998 
2999  if jacobian_input:
3000  for jacobian_x, d_x in zip(flat_jacobian_input, iter_tensors(d_input)):
3001  jacobian_x[:, i] = d_x.contiguous().view(-1)
3002  if jacobian_parameters:
3003  jacobian_param[:, i] = torch.cat(self._flatten_tensors(d_param), 0)
3004 
3005  res = tuple()
3006  if jacobian_input:
3007  res += jacobian_inp,
3008  if jacobian_parameters:
3009  res += jacobian_param,
3010 
3011  return res
3012 
3013  def _numerical_jacobian(self, module, input, jacobian_input=True, jacobian_parameters=True):
3014  def fw(input):
3015  return self._forward(module, input).detach()
3016 
3017  res = tuple()
3018  if jacobian_input:
3019  res += get_numerical_jacobian(fw, input, eps=1e-6),
3020  if jacobian_parameters:
3021  param, _ = self._get_parameters(module)
3022  res += torch.cat([get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param], 0),
3023  return res
3024 
3025  def check_jacobian(self, module, input, jacobian_input=True):
3026  jacobian_parameters = bool(self._get_parameters(module)[0])
3027  analytical = self._analytical_jacobian(module, input, jacobian_input, jacobian_parameters)
3028  numerical = self._numerical_jacobian(module, input, jacobian_input, jacobian_parameters)
3029  analytical_t = list(iter_tensors(analytical))
3030  numerical_t = list(iter_tensors(numerical))
3031 
3032  # TODO: compare structure
3033  self.assertLessEqual(
3034  max(a.add(-1, n).abs().max() for a, n in zip(analytical_t, numerical_t)),
3035  PRECISION
3036  )
3037 
3038  def check_criterion_jacobian(self, criterion, input, target):
3039  eps = 1e-6
3040  self._forward_criterion(criterion, input, target)
3041  analytical_d_x = self._backward_criterion(criterion, input, target)
3042  numerical_d_x = deepcopy(analytical_d_x)
3043 
3044  input_t = iter_tensors(input)
3045  numerical_t = iter_tensors(numerical_d_x)
3046  for x, d_x in zip(input_t, numerical_t):
3047  x = x.view(-1).data
3048  d_x = d_x.view(-1).data
3049  for i in range(x.nelement()):
3050  original = x[i].item()
3051  x[i] = original + eps
3052  fx1 = self._forward_criterion(criterion, input, target)
3053  x[i] = original - eps
3054  fx2 = self._forward_criterion(criterion, input, target)
3055  deriv = (fx1 - fx2) / (2. * eps)
3056  d_x[i] = float(deriv)
3057  x[i] = original
3058 
3059  # TODO: check structure
3060  analytical_t = list(iter_tensors(analytical_d_x))
3061  numerical_t = list(iter_tensors(numerical_d_x))
3062 
3063  self.assertLessEqual(
3064  max(a.add(-1, n).abs().max() for a, n in zip(analytical_t, numerical_t)),
3065  PRECISION
3066  )
3067 
3068 
3069 class TestBase(object):
3070 
3071  _required_arg_names = {'constructor_args', 'input', 'extra_args'}
3072 
3073  def __init__(self, constructor, desc='', reference_fn=None, fullname=None, **kwargs):
3074  self.desc = desc
3075  self.fullname = fullname
3076  self.constructor = constructor
3077  self.reference_fn = reference_fn
3078  for name in self._required_arg_names:
3079  if name not in kwargs and name + '_fn' not in kwargs and name + '_size' not in kwargs:
3080  if name in {'constructor_args', 'extra_args'}:
3081  kwargs[name] = tuple()
3082  else:
3083  raise ValueError("{}: Specify {} by a value, a function to generate it, or its size!"
3084  .format(self.get_name(), name))
3085  self._extra_kwargs = kwargs
3086  self._arg_cache = {}
3087 
3088  def get_name(self):
3089  if self.fullname is not None:
3090  return 'test_' + self.fullname
3091 
3092  test_name = 'test_' + self.constructor.__name__
3093  if self.desc:
3094  test_name += '_' + self.desc
3095  return test_name
3096 
3097  def _unpack(self, value):
3098  if isinstance(value, torch.Tensor):
3099  return value
3100  elif is_iterable(value):
3101  return type(value)(self._unpack(v) for v in value)
3102  else:
3103  return value
3104 
3105  @property
3106  def constructor_args(self):
3107  return self._get_arg('constructor_args', True)
3108 
3109  @property
3110  def extra_args(self):
3111  return self._get_arg('extra_args', True)
3112 
3113  def _get_arg(self, name, unpack):
3114  assert name in self._required_arg_names
3115 
3116  if name not in self._arg_cache:
3117  fn_name = name + '_fn'
3118  size_name = name + '_size'
3119 
3120  if name in self._extra_kwargs:
3121  self._arg_cache[name] = self._extra_kwargs[name]
3122  elif fn_name in self._extra_kwargs:
3123  self._arg_cache[name] = self._extra_kwargs[fn_name]()
3124  else:
3125  assert size_name in self._extra_kwargs
3126 
3127  def map_tensor_sizes(sizes):
3128  if isinstance(sizes, list):
3129  return [map_tensor_sizes(s) for s in sizes]
3130  elif isinstance(sizes, torch.Tensor):
3131  return sizes.double()
3132  else:
3133  return torch.randn(sizes)
3134 
3135  self._arg_cache[name] = map_tensor_sizes(self._extra_kwargs[size_name])
3136 
3137  return self._unpack(self._arg_cache[name]) if unpack else self._arg_cache[name]
3138 
3139  def _get_input(self, unpack=True):
3140  return self._get_arg('input', unpack)
3141 
3142  def __call__(self, test_case):
3143  raise NotImplementedError
3144 
3145 
3146 class ModuleTest(TestBase):
3147 
3148  def __init__(self, *args, **kwargs):
3149  super(ModuleTest, self).__init__(*args, **kwargs)
3150  self.jacobian_input = kwargs.get('jacobian_input', True)
3151  self.should_test_cuda = kwargs.get('test_cuda', True)
3152  self.should_test_pickle = kwargs.get('pickle', True)
3153  self.check_gradgrad = kwargs.get('check_gradgrad', True)
3154  self.FIXME_no_cuda_gradgrad_comparison = \
3155  kwargs.get('FIXME_no_cuda_gradgrad_comparison', False)
3156  self.precision = kwargs.get('precision', 2e-4)
3157 
3158  def __call__(self, test_case):
3159  module = self.constructor(*self.constructor_args)
3160  input = self._get_input()
3161 
3162  if self.reference_fn is not None:
3163  out = test_case._forward(module, input)
3164  ref_input = deepcopy(input)
3165  expected_out = self.reference_fn(ref_input, test_case._get_parameters(module)[0])
3166  test_case.assertEqual(out, expected_out)
3167  self.test_noncontig(test_case, module, input)
3168 
3169  if self.should_test_pickle:
3170  # TODO: do this with in-memory files as soon as torch.save will support it
3171  with TemporaryFile() as f:
3172  test_case._forward(module, input)
3173  torch.save(module, f)
3174  f.seek(0)
3175  module_copy = torch.load(f)
3176  test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
3177 
3178  self._do_test(test_case, module, input)
3179 
3180  def noncontiguize(self, obj):
3181  if isinstance(obj, list):
3182  return [self.noncontiguize(o) for o in obj]
3183  tensor = obj
3184  ndim = tensor.dim()
3185  # Always making only the last dimension noncontiguous is easy to hide
3186  # bugs because .view(-1) will still work. So try to find a dim with size
3187  # > 1 and make that non-contiguous, i.e., stack + select on the
3188  # dimension directly after that.
3189  dim = ndim
3190  for d in range(ndim):
3191  if tensor.size(d) > 1:
3192  dim = d + 1
3193  break
3194  noncontig = torch.stack([torch.empty_like(tensor), tensor], dim).select(dim, 1).detach()
3195  assert noncontig.numel() == 1 or not noncontig.is_contiguous()
3196  noncontig.requires_grad = tensor.requires_grad
3197  return noncontig
3198 
3199  def test_noncontig(self, test_case, module, input):
3200  # skip 0-dim (scalar) inputs: they cannot be made non-contiguous
3201  if isinstance(input, torch.Tensor) and input.dim() == 0:
3202  return
3203  if any(i.dim() == 0 for i in input if isinstance(i, torch.Tensor)):
3204  return
3205 
3206  test_case._zero_grad_parameters(module)
3207  test_case._zero_grad_input(input)
3208  with freeze_rng_state():
3209  output = test_case._forward(module, input)
3210  grad_output = output.new(output.shape).normal_()
3211  output = output.clone()
3212  d_input = deepcopy(test_case._backward(module, input, output, grad_output))
3213  d_param = deepcopy(test_case._get_parameters(module)[1])
3214 
3215  nc_input = self.noncontiguize(input)
3216  nc_grad_output = self.noncontiguize(grad_output)
3217  for contig_i, contig_g in product((True, False), repeat=2):
3218  i = input if contig_i else nc_input
3219  go = grad_output if contig_g else nc_grad_output
3220  test_case._zero_grad_parameters(module)
3221  test_case._zero_grad_input(i)
3222  with freeze_rng_state():
3223  out = test_case._forward(module, i)
3224  grad = test_case._backward(module, i, out, go)
3225 
3226  test_case.assertEqual(out, output)
3227  test_case.assertEqual(grad, d_input, 1e-4)
3228  test_case.assertEqual(test_case._get_parameters(module)[1], d_param)
3229 
3230  def test_cuda(self, test_case):
3231  if not TEST_CUDA or not self.should_test_cuda:
3232  raise unittest.SkipTest('Excluded from CUDA tests')
3233  try:
3234  cpu_input = self._get_input()
3235  type_map = {'torch.DoubleTensor': torch.cuda.FloatTensor}
3236  gpu_input = to_gpu(cpu_input, type_map=type_map)
3237 
3238  cpu_module = self.constructor(*self.constructor_args)
3239  gpu_module = self.constructor(*self.constructor_args).float().cuda()
3240  cpu_param = test_case._get_parameters(cpu_module)
3241  gpu_param = test_case._get_parameters(gpu_module)
3242  for cpu_p, gpu_p in zip(cpu_param[0], gpu_param[0]):
3243  gpu_p.data.copy_(cpu_p)
3244 
3245  test_case._zero_grad_input(cpu_input)
3246  test_case._zero_grad_input(gpu_input)
3247  test_case._zero_grad_parameters(cpu_module)
3248  test_case._zero_grad_parameters(gpu_module)
3249  cpu_output = test_case._forward(cpu_module, cpu_input)
3250  gpu_output = test_case._forward(gpu_module, gpu_input)
3251  test_case.assertEqual(cpu_output, gpu_output, self.precision)
3252 
3253  # Run backwards on CPU and GPU and compare results
3254  for _ in range(5):
3255  cpu_gradOutput = cpu_output.clone().normal_()
3256  gpu_gradOutput = cpu_gradOutput.type('torch.cuda.FloatTensor')
3257  cpu_gradInput = test_case._backward(cpu_module, cpu_input, cpu_output, cpu_gradOutput)
3258  gpu_gradInput = test_case._backward(gpu_module, gpu_input, gpu_output, gpu_gradOutput)
3259  test_case.assertEqual(cpu_gradInput, gpu_gradInput, self.precision)
3260  for cpu_d_p, gpu_d_p in zip(cpu_param[1], gpu_param[1]):
3261  test_case.assertEqual(cpu_d_p, gpu_d_p, self.precision)
3262 
3263  # Run double-backwards on CPU and GPU and compare results
3264  if self.check_gradgrad and not self.FIXME_no_cuda_gradgrad_comparison:
3265  cpu_output = cpu_module(cpu_input)
3266  gpu_output = gpu_module(gpu_input)
3267 
3268  cpu_gradOutput = torch.randn_like(cpu_output, requires_grad=True)
3269  gpu_gradOutput = cpu_gradOutput.type_as(gpu_output).detach()
3270  gpu_gradOutput.requires_grad = True
3271 
3272  cpu_gradInputs = torch.autograd.grad(
3273  cpu_output,
3274  (cpu_input,) + tuple(cpu_module.parameters()),
3275  cpu_gradOutput,
3276  create_graph=True)
3277  gpu_gradInputs = torch.autograd.grad(
3278  gpu_output,
3279  (gpu_input,) + tuple(gpu_module.parameters()),
3280  gpu_gradOutput,
3281  create_graph=True)
3282 
3283  for cpu_d_i, gpu_d_i in zip(cpu_gradInputs, gpu_gradInputs):
3284  test_case.assertEqual(cpu_d_i, gpu_d_i, self.precision)
3285 
3286  # We mix output into the second backwards computation so that
3287  # torch.autograd.grad doesn't complain that some inputs
3288  # are unreachable (which can happen if you differentiate
3289  # only on the gradient).
3290  cpu_gg = torch.autograd.grad(
3291  cpu_output.sum() + sum(map(lambda x: x.sum(), cpu_gradInputs)),
3292  (cpu_input, cpu_gradOutput) + tuple(cpu_module.parameters()),
3293  retain_graph=True)
3294  gpu_gg = torch.autograd.grad(
3295  gpu_output.sum() + sum(map(lambda x: x.sum(), gpu_gradInputs)),
3296  (gpu_input, gpu_gradOutput) + tuple(gpu_module.parameters()),
3297  retain_graph=True)
3298 
3299  test_case.assertEqual(cpu_gradInput, gpu_gradInput, self.precision)
3300  for cpu_d_p, gpu_d_p in zip(cpu_gg, gpu_gg):
3301  test_case.assertEqual(cpu_d_p, gpu_d_p, self.precision)
3302 
3303  self.test_noncontig(test_case, gpu_module, gpu_input)
3304  except NotImplementedError:
3305  pass
3306  # TODO: remove this after CUDA scatter_ is implemented
3307  except AttributeError as e:
3308  if len(e.args) == 1 and "'FloatTensor' object has no attribute 'scatter_'" in e.args[0]:
3309  pass
3310  else:
3311  raise
3312 
3313 
3314 class CriterionTest(TestBase):
3315 
3316  _required_arg_names = TestBase._required_arg_names.union({'target'})
3317 
3318  def __init__(self, *args, **kwargs):
3319  super(CriterionTest, self).__init__(*args, **kwargs)
3320  self.should_test_cuda = kwargs.get('test_cuda', True)
3321  self.check_forward_only = kwargs.get('check_forward_only', False)
3322 
3323  def _get_target(self):
3324  return self._get_arg('target', True)
3325 
3326  def __call__(self, test_case):
3327  module = self.constructor(*self.constructor_args)
3328  input = self._get_input()
3329 
3330  # Check that these methods don't raise errors
3331  module.__repr__()
3332  str(module)
3333 
3334  target = self._get_target()
3335 
3336  if self.reference_fn is not None:
3337  out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
3338  ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
3339  expected_out = self.reference_fn(*ref_args)
3340  test_case.assertEqual(out, expected_out)
3341 
3342  if self.check_forward_only:
3343  return
3344 
3345  test_case.check_criterion_jacobian(module, input, target)
3346  self._do_extra_tests(test_case, module, input, target)
3347 
3348  def test_cuda(self, test_case):
3349  if not TEST_CUDA or not self.should_test_cuda:
3350  raise unittest.SkipTest('Excluded from CUDA tests')
3351  try:
3352  cpu_input = self._get_input()
3353  type_map = {
3354  'torch.DoubleTensor': torch.cuda.FloatTensor,
3355  }
3356  gpu_input = to_gpu(cpu_input, type_map=type_map)
3357 
3358  cpu_target = self._get_target()
3359  gpu_target = to_gpu(cpu_target, type_map=type_map)
3360 
3361  cpu_module = self.constructor(*self.constructor_args)
3362  gpu_module = self.constructor(*self.constructor_args).float().cuda()
3363 
3364  cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target)
3365  gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target)
3366  test_case.assertEqual(cpu_output, gpu_output, 4e-4)
3367 
3368  gradOutput = torch.randn(())
3369  cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target, gradOutput)
3370  gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target, gradOutput)
3371  test_case.assertEqual(cpu_gradInput, gpu_gradInput, 4e-4)
3372  except NotImplementedError:
3373  pass
3374 
3375  def _do_extra_tests(self, test_case, module, input, target):
3376  pass