1 from test_pytorch_common 
import TestCase, run_tests, skipIfNoLapack, flatten
     6 from torch.nn import Module, functional
    18 import common_utils 
as common
    21 '''Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]    22           --no-onnx: no onnx python dependence    23           --produce-onnx-test-data: generate onnx test data    30 def export_to_pbtxt(model, inputs, *args, **kwargs):
    32         model, inputs, 
None, verbose=
False, google_printer=
True,
    36 def export_to_pb(model, inputs, *args, **kwargs):
    37     kwargs[
'operator_export_type'] = torch.onnx.OperatorExportTypes.ONNX
    45     def __init__(self, f, params=None):
    48         super(FuncModule, self).__init__()
    50         self.
params = nn.ParameterList(list(params))
    52     def forward(self, *args):
    53         return self.
f(*itertools.chain(args, self.
params))
    58     def assertONNX(self, f, args, params=None, **kwargs):
    61         if isinstance(f, nn.Module):
    66         onnx_model_pbtxt = export_to_pbtxt(m, args, **kwargs)
    67         subname = kwargs.pop(
'subname', 
None)
    68         self.assertExpected(onnx_model_pbtxt, subname)
    70             onnx_model_pb = export_to_pb(m, args, **kwargs)
    73             import onnx.numpy_helper
    74             import test_onnx_common
    75             model_def = onnx.ModelProto.FromString(onnx_model_pb)
    76             onnx.checker.check_model(model_def)
    78                 test_function = inspect.stack()[1][0].f_code.co_name
    79                 test_name = test_function[0:4] + 
"_operator" + test_function[4:]
    80                 output_dir = os.path.join(test_onnx_common.pytorch_operator_dir, test_name)
    84                 assert not os.path.exists(output_dir), 
"{} should not exist!".format(output_dir)
    85                 os.makedirs(output_dir)
    86                 with open(os.path.join(output_dir, 
"model.onnx"), 
'wb') 
as file:
    87                     file.write(model_def.SerializeToString())
    88                 data_dir = os.path.join(output_dir, 
"test_data_set_0")
    90                 if isinstance(args, Variable):
    92                 for index, var 
in enumerate(flatten(args)):
    93                     tensor = onnx.numpy_helper.from_array(var.data.numpy())
    94                     with open(os.path.join(data_dir, 
"input_{}.pb".format(index)), 
'wb') 
as file:
    95                         file.write(tensor.SerializeToString())
    97                 if isinstance(outputs, Variable):
    99                 for index, var 
in enumerate(flatten(outputs)):
   100                     tensor = onnx.numpy_helper.from_array(var.data.numpy())
   101                     with open(os.path.join(data_dir, 
"output_{}.pb".format(index)), 
'wb') 
as file:
   102                         file.write(tensor.SerializeToString())
   104     def assertONNXRaises(self, err, f, args, params=None, **kwargs):
   107         if isinstance(f, nn.Module):
   111         self.assertExpectedRaises(err, 
lambda: export_to_pbtxt(m, args, **kwargs))
   113     def assertONNXRaisesRegex(self, err, reg, f, args, params=None, **kwargs):
   116         if isinstance(f, nn.Module):
   120         with self.assertRaisesRegex(err, reg):
   121             export_to_pbtxt(m, args, **kwargs)
   123     def test_basic(self):
   126         self.
assertONNX(
lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), (x, y))
   132     def test_index(self):
   136     def test_type_as(self):
   140     def test_addconstant(self):
   141         x = torch.randn(2, 3, requires_grad=
True).double()
   144     def test_add_broadcast(self):
   145         x = torch.randn(2, 3, requires_grad=
True).double()
   146         y = torch.randn(3, requires_grad=
True).double()
   149     def test_add_left_broadcast(self):
   150         x = torch.randn(3, requires_grad=
True).double()
   151         y = torch.randn(2, 3, requires_grad=
True).double()
   154     def test_add_size1_broadcast(self):
   155         x = torch.randn(2, 3, requires_grad=
True).double()
   156         y = torch.randn(2, 1, requires_grad=
True).double()
   159     def test_add_size1_right_broadcast(self):
   160         x = torch.randn(2, 3, requires_grad=
True).double()
   161         y = torch.randn(3, requires_grad=
True).double()
   164     def test_add_size1_singleton_broadcast(self):
   165         x = torch.randn(2, 3, requires_grad=
True).double()
   166         y = torch.randn(1, 3, requires_grad=
True).double()
   170         x = torch.randn(2, 3, requires_grad=
True).double()
   173     def test_transpose(self):
   174         x = 
torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=
True)
   175         self.
assertONNX(
lambda x: x.transpose(0, 1).transpose(1, 0), x)
   177     def test_chunk(self):
   181     def test_split(self):
   182         x = 
torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
   183         self.
assertONNX(
lambda x: torch.split(x, 2, 1), x)
   185     def test_split_with_sizes(self):
   186         x = 
torch.tensor([[0.0, 1.0, 1.0, 0.0, 2.0, 2.0], [2.0, 3.0, 3.0, 2.0, 1.0, 1.0]])
   187         self.
assertONNX(
lambda x: torch.split(x, [2, 1, 3], 1), x)
   189     def test_concat2(self):
   190         x = torch.randn(2, 3)
   191         y = torch.randn(2, 3)
   192         self.
assertONNX(
lambda inputs: torch.cat(inputs, 1), ((x, y),))
   195         m1 = torch.randn(2, 3, requires_grad=
True)
   196         m2 = torch.randn(3, 4, requires_grad=
True)
   199     def test_addmm(self):
   200         m1 = torch.randn(2, 3, requires_grad=
True)
   201         m2 = torch.randn(3, 4, requires_grad=
True)
   202         m3 = torch.randn(4, requires_grad=
True)
   203         self.
assertONNX(
lambda x, y, z: torch.addmm(torch.addmm(z, x, y), x, y), (m1, m2, m3))
   205     def test_permute2(self):
   207         self.
assertONNX(
lambda x: x.permute(0, 1, 4, 2, 5, 3), x)
   210         x = 
torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=
True)
   211         self.
assertONNX(nn.ReflectionPad2d((2, 3, 0, 1)), x)
   213     def test_params(self):
   214         x = 
torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=
True)
   215         y = nn.Parameter(
torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=
True))
   216         self.
assertONNX(
lambda x, y: -torch.sigmoid(torch.tanh(x * (x + y))), x, params=(y, ))
   218     def test_symbolic_mismatch(self):
   219         class MyFun(Function):
   227             def forward(ctx, x, y):
   234         with self.assertRaisesRegex(TypeError, 
"occurred when translating MyFun"):
   235             export_to_pbtxt(
FuncModule(MyFun().apply), (x, y))
   238     def test_batchnorm(self):
   239         x = torch.ones(2, 2, 2, 2, requires_grad=
True)
   242     def test_batchnorm_1d(self):
   243         x = torch.ones(2, 2, requires_grad=
True)
   246     def test_batchnorm_training(self):
   247         x = torch.ones(2, 2, 2, 2, requires_grad=
True)
   248         self.
assertONNX(nn.BatchNorm2d(2), x, training=
True)
   251         x = torch.ones(20, 16, 50, 40, requires_grad=
True)
   252         self.
assertONNX(nn.Conv2d(16, 13, 3, bias=
False), x)
   254     def test_convtranspose(self):
   255         x = torch.ones(2, 3, 4, 5, requires_grad=
True)
   256         self.
assertONNX(nn.ConvTranspose2d(3, 3, 3, stride=3, bias=
False,
   257                                            padding=1, output_padding=2), x)
   259     def test_maxpool(self):
   260         x = torch.randn(20, 16, 50)
   263     def test_avg_pool2d(self):
   264         x = torch.randn(20, 16, 50, 32)
   267     def test_maxpool_indices(self):
   268         x = torch.randn(20, 16, 50)
   269         self.
assertONNX(nn.MaxPool1d(3, stride=2, return_indices=
True), x)
   271     def test_at_op(self):
   272         x = torch.randn(3, 4)
   274         class MyFun(Function):
   278                 return g.at(
"add", x, x)
   284         class MyModule(Module):
   285             def forward(self, x):
   286                 return MyFun.apply(x)
   291         x = torch.randn(3, 4, requires_grad=
True)
   292         self.
assertONNX(
lambda x: torch.clamp(x, min=-0.5, max=0.5), x)
   294     def test_clip_min(self):
   295         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   296         self.
assertONNX(
lambda x: x.clamp(min=-0.1), x)
   298     def test_clip_max(self):
   299         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   300         self.
assertONNX(
lambda x: x.clamp(max=0.1), x)
   302     def test_hardtanh(self):
   303         x = torch.randn(3, 4, requires_grad=
True)
   304         self.
assertONNX(
lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)
   307         x = torch.randn(3, 4, requires_grad=
True)
   308         self.
assertONNX(
lambda x: torch.full(x.shape, 2), x)
   310     def test_full_like(self):
   311         x = torch.randn(3, 4, requires_grad=
True)
   312         self.
assertONNX(
lambda x: torch.full_like(x, 2), x)
   315         x = torch.randn(3, 4, requires_grad=
True)
   316         y = torch.randn(3, 4, requires_grad=
True)
   317         self.
assertONNX(
lambda x, y: torch.max(x, y), (x, y))
   320         x = torch.randn(3, 4, requires_grad=
True)
   321         y = torch.randn(3, 4, requires_grad=
True)
   322         self.
assertONNX(
lambda x, y: torch.min(x, y), (x, y))
   325         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   328     def test_reduced_mean(self):
   329         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   330         self.
assertONNX(
lambda x: torch.mean(x, dim=2), x)
   332     def test_reduced_mean_keepdim(self):
   333         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   334         self.
assertONNX(
lambda x: torch.mean(x, dim=2, keepdim=
True), x)
   337         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   340     def test_reduced_sum(self):
   341         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   342         self.
assertONNX(
lambda x: torch.sum(x, dim=2), x)
   344     def test_reduced_sum_keepdim(self):
   345         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   346         self.
assertONNX(
lambda x: torch.sum(x, dim=2, keepdim=
True), x)
   349         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   352     def test_reduced_prod(self):
   353         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   354         self.
assertONNX(
lambda x: torch.prod(x, dim=2), x)
   356     def test_reduced_prod_keepdim(self):
   357         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   358         self.
assertONNX(
lambda x: torch.prod(x, dim=2, keepdim=
True), x)
   361         x = torch.randn(3, 4, requires_grad=
True)
   364     def test_equal(self):
   365         x = torch.randn(1, 2, 3, 1, requires_grad=
False).int()
   366         y = torch.randn(1, 4, requires_grad=
False).int()
   370         x = torch.randn(1, 2, 3, 1, requires_grad=
False).int()
   371         y = torch.randn(1, 4, requires_grad=
False).int()
   375         x = torch.randn(1, 2, 3, 1, requires_grad=
False).int()
   376         y = torch.randn(1, 4, requires_grad=
False).int()
   380         x = torch.randn(3, 4, requires_grad=
False).int()
   381         y = torch.randn(3, 4, requires_grad=
False).int()
   385         x = torch.randn(3, 4, requires_grad=
False).int()
   386         y = torch.randn(3, 4, requires_grad=
False).int()
   390         x = torch.randn(3, 4, requires_grad=
True)
   394         x = torch.randn(3, 4, requires_grad=
True)
   398         x = torch.randn(3, 4, requires_grad=
True)
   402         x = torch.randn(3, 4, requires_grad=
True)
   406         x = torch.rand(3, 4, requires_grad=
True)
   410         x = torch.rand(3, 4, requires_grad=
True)
   413     def test_slice(self):
   414         x = torch.rand(3, 4, requires_grad=
True)
   417     def test_narrow(self):
   418         x = torch.randn(3, 3, requires_grad=
True)
   419         self.
assertONNX(
lambda x: torch.narrow(x, 0, 0, 2), x)
   422         x = torch.randn(3, 4, requires_grad=
True)
   425     def test_view_flatten(self):
   426         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   427         self.
assertONNX(
lambda x: x.view(x.size()[0], x.numel() // x.size()[0]), x)
   429     def test_flatten(self):
   430         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   431         self.
assertONNX(
lambda x: torch.flatten(x), x)
   433     def test_flatten2D(self):
   434         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   435         self.
assertONNX(
lambda x: torch.flatten(x, 1), x)
   437     def test_isnan(self):
   441     def test_argmax(self):
   442         x = torch.randn(4, 4, requires_grad=
True)
   443         self.
assertONNX(
lambda x: torch.argmax(x, dim=1), x)
   445     def test_logsoftmax(self):
   446         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   450         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   451         y = torch.randn(1, 2, 3, 4, requires_grad=
True)
   452         self.
assertONNX(
lambda x, y: x.pow(y), (x, y))
   455         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   459         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   462     def test_repeat(self):
   463         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   464         self.
assertONNX(
lambda x: x.repeat(1, 2, 3, 4), x)
   466     def test_repeat_dim_overflow(self):
   467         x = torch.randn(1, 2, requires_grad=
True)
   468         self.
assertONNX(
lambda x: x.repeat(1, 2, 3, 4), x)
   471         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   472         self.
assertONNX(
lambda x: x.norm(p=2, dim=2), (x))
   474     @unittest.skip(
"Temporary - waiting for https://github.com/onnx/onnx/pull/1773.")
   475     def test_upsample(self):
   476         x = torch.randn(1, 2, 3, 4, requires_grad=
True)
   477         self.
assertONNX(
lambda x: nn.functional.interpolate(x, scale_factor=2., mode=
'bilinear'), x)
   479     def test_unsqueeze(self):
   480         x = torch.randn(3, 4, requires_grad=
True)
   481         self.
assertONNX(
lambda x: x.unsqueeze(len(x.shape)), x)
   483     def test_batchnorm_noaffine(self):
   484         x = torch.randn(128, 128, 1, 1, requires_grad=
True)
   485         self.
assertONNX(nn.BatchNorm2d(128, affine=
False), x)
   487     def test_embedding_bags(self):
   488         emb_bag = nn.EmbeddingBag(10, 8)
   493     def test_implicit_expand(self):
   494         x = torch.randn(3, 4, requires_grad=
True)
   497     def test_reduce_sum_negative_indices(self):
   498         x = torch.randn(3, 4, requires_grad=
True)
   501     def test_randn(self):
   502         x = torch.randn(1, 2, 3, 4)
   503         self.
assertONNX(
lambda x: torch.randn(1, 2, 3, 4) + x, x)
   505     def test_rrelu(self):
   506         x = torch.randn(1, 2, 3, 4)
   509     def test_log_sigmoid(self):
   510         x = torch.randn(1, 2, 3, 4)
   513     def test_linear(self):
   514         x = torch.randn(3, 4)
   515         self.
assertONNX(torch.nn.Linear(4, 5, bias=
True), x)
   517     def test_zeros_like(self):
   518         x = torch.randn(5, 8, requires_grad=
True)
   519         self.
assertONNX(
lambda x: torch.zeros_like(x), x)
   521     def test_ones_like(self):
   522         x = torch.randn(6, 10, requires_grad=
True)
   523         self.
assertONNX(
lambda x: torch.ones_like(x), x)
   525     def test_expand(self):
   526         x = torch.randn(6, 1, requires_grad=
True)
   527         self.
assertONNX(
lambda x: x.expand(4, 6, 2), x)
   530         x = torch.randn(1, 2, 3, 1, requires_grad=
False).int()
   531         y = torch.randn(1, 4, requires_grad=
False).int()
   532         self.
assertONNX(
lambda x, y: torch.ne(x, y), (x, y))
   534     def test_reducemax(self):
   535         x = torch.randn(1, 2, 3, 4)
   538     def test_reducemin(self):
   539         x = torch.randn(1, 2, 3, 4)
   543         x = torch.randn(1, 2, 3, 4)
   546     def test_dropout(self):
   547         x = torch.randn(3, 4, requires_grad=
True)
   548         self.
assertONNX(
lambda x: torch.max(functional.dropout(x, training=
False)), x)
   550     def test_nonzero(self):
   551         x = 
torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=
True)
   552         self.
assertONNX(
lambda x: torch.nonzero(x), x)
   554     def test_master_opset(self):
   555         x = torch.randn(2, 3).float()
   556         y = torch.randn(2, 3).float()
   557         self.
assertONNX(
lambda x, y: x + y, (x, y), opset_version=10)
   559     def test_retain_param_name_disabled(self):
   560         class MyModule(Module):
   562                 super(MyModule, self).__init__()
   563                 self.
fc1 = nn.Linear(4, 5, bias=
False)
   564                 self.fc1.weight.data.fill_(2.)
   565                 self.
fc2 = nn.Linear(5, 6, bias=
False)
   566                 self.fc2.weight.data.fill_(3.)
   568             def forward(self, x):
   569                 return self.
fc2(self.
fc1(x))
   571         x = torch.randn(3, 4).float()
   572         self.
assertONNX(MyModule(), (x,), _retain_param_name=
False)
   575 if __name__ == 
'__main__':
   576     no_onnx_dep_flag = 
'--no-onnx'   577     _onnx_dep = no_onnx_dep_flag 
not in common.UNITTEST_ARGS
   578     if no_onnx_dep_flag 
in common.UNITTEST_ARGS:
   579         common.UNITTEST_ARGS.remove(no_onnx_dep_flag)
   580     onnx_test_flag = 
'--produce-onnx-test-data'   581     _onnx_test = onnx_test_flag 
in common.UNITTEST_ARGS
   582     if onnx_test_flag 
in common.UNITTEST_ARGS:
   583         common.UNITTEST_ARGS.remove(onnx_test_flag)
   586         import test_onnx_common
   587         for d 
in glob.glob(os.path.join(test_onnx_common.pytorch_operator_dir, 
"test_operator_*")):
 def export_to_pretty_string(args, kwargs)
 
def assertONNX(self, f, args, params=None, kwargs)