test_jit.py
from __future__ import division
import torch
import torch.jit
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel as dp
import torch.optim as optim
import torch.cuda
from contextlib import contextmanager
from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
from torch.nn import Module
from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
from torch._six import inf, PY2, builtins
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
    skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
    freeze_rng_state, set_rng_seed
from common_nn import module_tests, new_module_tests, criterion_tests
from textwrap import dedent
from functools import wraps
import os
import io
import itertools
import sys
import unittest
import inspect
import textwrap
import numpy as np
import tempfile
import shutil
import warnings
import math
import types
import pickle
import copy

from common_methods_invocations import method_tests as autograd_method_tests
from common_methods_invocations import create_input, unpack_variables, \
    exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from torch.testing import FileCheck
from torch._C import TensorType, TupleType, FloatType, IntType, \
    ListType, StringType, DictType
from copy import deepcopy
import random
from typing import List, Dict, Optional, Tuple
from torch.jit.frontend import NotSupportedError
from torch.jit import BatchTensor
from torch import Tensor
from torch.jit.annotations import BroadcastingList2, BroadcastingList3

# For testing truediv in python 2
from test_module.future_div import div_int_future, div_float_future
from test_module.no_future_div import div_int_nofuture, div_float_nofuture


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings.
load_tests = load_tests

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False


skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_HALF = RUN_CUDA
if torch.cuda.is_available():
    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
    for d in range(torch.cuda.device_count()):
        major = torch.cuda.get_device_capability(d)[0]
        if (CUDA_VERSION < 8000 and major >= 6) or (CUDA_VERSION < 9000 and major >= 7):
            RUN_CUDA = False
        if (CUDA_VERSION < 9000 or major < 6):
            RUN_CUDA_HALF = False

RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1

PY35 = sys.version_info >= (3, 5)
WINDOWS = sys.platform == 'win32'


if WINDOWS:
    @contextmanager
    def TemporaryFileName():
        # Ideally we would like to not have to manually delete the file, but
        # NamedTemporaryFile opens the file, and it cannot be opened multiple
        # times on Windows. To support Windows, close the file after creation
        # and try to remove it manually.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            f.close()
            yield f.name
        finally:
            os.unlink(f.name)
else:
    @contextmanager  # noqa: T484
    def TemporaryFileName():
        with tempfile.NamedTemporaryFile() as f:
            yield f.name

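
# Illustrative usage (editor's sketch, not part of the original suite): the
# context manager yields a path that can be re-opened freely on any platform,
# which is what helpers like getExportImportCopy below rely on:
def _example_temporary_file_name():
    with TemporaryFileName() as fname:
        with open(fname, 'w') as f:
            f.write('hello')
        with open(fname, 'r') as f:
            assert f.read() == 'hello'
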

def LSTMCellF(input, hx, cx, *params):
    return LSTMCell(input, (hx, cx), *params)


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy


def LSTMCellC(*args, **kwargs):
    hy, cy = LSTMCellF(*args, **kwargs)
    return torch.cat((hy, cy))


def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy


# Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44
def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())
    # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias
    # Same as LSTMCell after this point
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = ingate.sigmoid()
    forgetgate = forgetgate.sigmoid()
    cellgate = cellgate.tanh()
    outgate = outgate.sigmoid()
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * cy.tanh()
    return hy, cy


def canonical(graph):
    return str(torch._C._jit_pass_canonicalize(graph))


def get_lstm_inputs(device, training=False, seq_length=None):
    input_shape = (3, 10) if seq_length is None else (seq_length, 3, 10)
    input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training)
    hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
    cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
    module = nn.LSTMCell(10, 20).to(device, torch.float)  # Just to allocate weights with correct sizes
    if training:
        params = tuple(module.parameters())
    else:
        params = tuple(p.requires_grad_(False) for p in module.parameters())
    return (input, hx, cx) + params
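

# Illustrative usage (editor's sketch, not part of the original suite): the
# tuple returned above unpacks directly into LSTMCellF's signature, so a
# single cell step on CPU looks like:
def _example_lstm_cell_step():
    inputs = get_lstm_inputs('cpu')
    hy, cy = LSTMCellF(*inputs)
    assert hy.shape == (3, 20) and cy.shape == (3, 20)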


def get_milstm_inputs(device, training=False):
    minibatch = 3
    input_size = 10
    hidden_size = 20
    x = torch.randn(minibatch, input_size, device=device, dtype=torch.float)
    hx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
    cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)

    ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training)
    hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training)
    alpha = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
    return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias
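

# Illustrative usage (editor's sketch, not part of the original suite): the
# nine tensors returned above line up with MiLSTMCell's parameter list:
def _example_milstm_cell_step():
    x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias = get_milstm_inputs('cpu')
    hy, cy = MiLSTMCell(x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias)
    assert hy.shape == (3, 20) and cy.shape == (3, 20)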


def get_fn(file_name, script_path):
    import importlib.util
    spec = importlib.util.spec_from_file_location(file_name, script_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    fn = module.fn
    return fn
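

# Illustrative usage (editor's sketch; the path below is hypothetical): get_fn
# imports a standalone script file and returns its top-level `fn`, which can
# then be traced or scripted like any other Python function:
#
#   fn = get_fn('my_script', '/tmp/my_script.py')
#   traced = torch.jit.trace(fn, torch.randn(2, 2))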


def get_execution_plan(graph_executor_state):
    execution_plans = list(graph_executor_state.execution_plans.values())
    num_plans = len(execution_plans)
    if num_plans != 1:
        raise RuntimeError('This test assumes this GraphExecutor should '
                           'only have one execution plan, got: {}'.format(num_plans))
    return execution_plans[0]


def get_grad_executor(plan_state, diff_graph_idx=None):
    if diff_graph_idx is None:
        nodes = list(plan_state.graph.nodes())
        if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"):
            pass
        else:
            raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
    grad_executors = list(plan_state.code.grad_executors())
    return grad_executors[diff_graph_idx or 0]


def backward_graph(script_module, diff_graph_idx=None):
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise RuntimeError('Expected ScriptModule')
    ge_state = script_module.get_debug_state()
    fwd_plan = get_execution_plan(ge_state)
    grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
    bwd_plan = get_execution_plan(grad_executor.get_debug_state())
    # Running JIT passes requires that we own the graph (with a shared_ptr).
    # The debug state struct does not own its graph so we make a copy of it.
    return bwd_plan.graph.copy()

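
# Illustrative usage (editor's sketch, not part of the original suite): given a
# ScriptModule `sm` whose forward has already been run with inputs that require
# grad, backward_graph returns a copy of the recorded backward graph, which can
# be inspected like any other graph:
#
#   bwd = backward_graph(sm)
#   print(bwd)  # or e.g. FileCheck().check(...).run(str(bwd))
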

# make it easy to quickly define/trace a function for these tests
def _trace(*args, **kwargs):
    def wrapper(func):
        return torch.jit.trace(func, args, **kwargs)
    return wrapper
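

# Illustrative usage (editor's sketch): the positional arguments to _trace are
# the example inputs handed to torch.jit.trace, so tests can write:
#
#   @_trace(torch.rand(1))
#   def foo(a):
#       return a + a
#
# after which `foo` is the traced function (see test_trace_annotation below).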


def enable_cpu_fuser(fn):
    def wrapper(*args, **kwargs):
        torch._C._jit_override_can_fuse_on_cpu(True)
        try:
            fn(*args, **kwargs)
        finally:
            torch._C._jit_override_can_fuse_on_cpu(False)
    return wrapper
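

# Illustrative usage (editor's sketch): enable_cpu_fuser decorates a test so
# CPU fusion is on only for its duration and is restored even on failure:
#
#   @enable_cpu_fuser
#   def test_fused_thing(self):
#       ...  # body runs with torch._C._jit_override_can_fuse_on_cpu(True)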


class JitTestCase(TestCase):
    _do_cuda_memory_leak_check = True
    _restored_warnings = False

    def setUp(self):
        # unittest overrides all warning filters and forces all of them to show up
        # after we install our own to silence those coming from inside PyTorch.
        # This will ensure that our filter still takes precedence.
        if not JitTestCase._restored_warnings:
            torch.jit.TracerWarning.ignore_lib_warnings()
            JitTestCase._restored_warnings = True
        torch._C._jit_set_emit_module_hook(self.emitModuleHook)

    def tearDown(self):
        # needs to be cleared because python might be unloaded before
        # the callback gets destructed
        torch._C._jit_set_emit_module_hook(None)
        torch._C._jit_clear_class_registry()

    @contextmanager
    def disableModuleHook(self):
        torch._C._jit_set_emit_module_hook(None)
        yield None
        torch._C._jit_set_emit_module_hook(self.emitModuleHook)

    def emitModuleHook(self, module):
        def copy_structure_and_params(m):
            c = torch.jit.ScriptModule()
            for name, v in m._get_parameters():
                c._register_parameter(name, v, False)
            for name, the_type, v in m._get_attributes():
                c._register_attribute(name, the_type, v)
            for name, s in m._get_modules():
                c._register_module(name, copy_structure_and_params(s))
            return c

        # disable the hook while we parse code, otherwise we will re-enter the hook
        with self.disableModuleHook():
            try:
                pp, constant_table = module._python_print()
            except RuntimeError as e:
                se = str(e)
                if "could not export python function" not in se and \
                   "closures are not exportable" not in se:
                    raise
                else:
                    return
            ppv = "op_version_set = 0\n{}".format(pp)
            sm = copy_structure_and_params(module)
            torch._C._jit_import_methods(sm, ppv, constant_table)
            pp2, _ = sm._python_print()
            if pp != pp2:
                self.assertMultiLineEqual(pp, pp2)

    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)

        if not also_test_file:
            return imported

        with TemporaryFileName() as fname:
            imported.save(fname)
            return torch.jit.load(fname, map_location=map_location)

    def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
        buffer = io.BytesIO()
        m.apply(lambda s: s._pack() if s._has_method('_pack') else None)
        torch.jit.save(m, buffer)
        m.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)
        imported.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)

        if not also_test_file:
            return imported

        # Ideally we would like to not have to manually delete the file, but
        # NamedTemporaryFile opens the file, and it cannot be opened multiple
        # times on Windows. To support Windows, close the file after creation
        # and try to remove it manually.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            f.close()
            imported.save(f.name)
            result = torch.jit.load(f.name, map_location=map_location)
        finally:
            os.unlink(f.name)

        result.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
        return result

    def assertGraphContains(self, graph, kind):
        self.assertTrue(any(n.kind() == kind for n in graph.nodes()))

    def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
        def perform_assert(graph, kind, actual, expected, consider_subgraphs):
            if actual == expected:
                return
            subgraph = 'including' if consider_subgraphs else 'excluding'
            raise AssertionError(
                '{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format(
                    graph, actual, kind, subgraph, expected))

        if consider_subgraphs:
            strgraph = str(graph)
            count = strgraph.count(kind) - strgraph.count('with {}'.format(kind))
            perform_assert(graph, kind, count, num_kind_nodes,
                           consider_subgraphs)
            return

        nodes = [node for node in graph.nodes()
                 if node.kind() == kind]
        perform_assert(graph, kind, len(nodes), num_kind_nodes,
                       consider_subgraphs)

    def assertExpectedONNXGraph(self, trace, *args, **kwargs):
        torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
        self.assertExpectedGraph(trace, *args, **kwargs)

    def assertExpectedGraph(self, trace, *args, **kwargs):
        if isinstance(trace, torch._C.Graph):
            graph = trace
        else:
            graph = trace.graph()

        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        graph = torch._C._jit_pass_canonicalize(graph)
        torch._C._jit_pass_lint(graph)
        self.assertExpected(str(graph), *args, **kwargs)

    def run_pass(self, name, trace):
        if isinstance(trace, torch._C.Graph):
            graph = trace
            set_graph = False
        else:
            set_graph = True
            graph = trace.graph()

        torch._C._jit_pass_lint(graph)
        result = getattr(torch._C, '_jit_pass_' + name)(graph)
        if result is not None:
            graph = result
        torch._C._jit_pass_lint(graph)

        if set_graph:
            trace.set_graph(graph)
        return graph

    def checkScript(self,
                    script,
                    inputs,
                    optimize=True,
                    outputs=None,
                    name='func',
                    capture_output=False,
                    frames_up=1,
                    check_expected=False):
        if isinstance(script, str):
            cu = torch.jit.CompilationUnit(script, optimize, _frames_up=frames_up)
            ge = getattr(cu, name)
        else:
            if capture_output:
                with self.capture_stdout() as captured:
                    outputs = script(*inputs)
            else:
                outputs = script(*inputs)
            # Check the string frontend first
            source = textwrap.dedent(inspect.getsource(script))
            self.checkScript(
                source,
                inputs,
                optimize,
                outputs,
                script.__name__,
                capture_output,
                frames_up=2,
                check_expected=check_expected)
            # Continue checking the Python frontend
            ge = torch.jit.script(script, optimize, _frames_up=1)

        if capture_output:
            with self.capture_stdout() as captured:
                outputs_ge = ge(*inputs)
            if not WINDOWS:
                self.assertExpected(captured[0], subname='stdout')
        else:
            outputs_ge = ge(*inputs)
        self.assertEqual(outputs, outputs_ge)

        if check_expected:
            self.assertExpectedGraph(ge.graph)

        return ge

    def checkTrace(self, func, reference_tensors, input_tensors=None,
                   optimize=True, drop=None, allow_unused=False, verbose=False,
                   inputs_require_grads=True, check_tolerance=1e-5, export_import=True,
                   _force_outplace=False):
        # TODO: check gradients for parameters, not just inputs
        def allSum(vs):
            # drop allows us to remove some values from ever being used
            # to test unused outputs
            if drop is not None:
                vs = vs[:-drop]
            # we don't want all the grad for all the outputs to be the same
            # so we multiply each by a constant
            return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None)
        if input_tensors is None:
            input_tensors = reference_tensors

        nograd_inputs = reference_tensors
        if inputs_require_grads:
            recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
        else:
            recording_inputs = reference_tensors

        if isinstance(func, torch._C.Graph):
            ge = torch._C.GraphExecutor(func, optimize)
        else:
            ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
                                 _force_outplace=_force_outplace)

        if export_import:
            ge = self.getExportImportCopy(ge)

        if verbose:
            print(ge.graph)

        # test no gradients case
        outputs = func(*nograd_inputs)
        outputs_ge = ge(*nograd_inputs)
        self.assertEqual(outputs, outputs_ge)

        # test single grad case
        outputs = func(*recording_inputs)
        if inputs_require_grads:
            grads = torch.autograd.grad(allSum(outputs), recording_inputs,
                                        allow_unused=allow_unused)

        outputs_ge = ge(*recording_inputs)
        if inputs_require_grads:
            grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
                                           allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_ge)
        if inputs_require_grads:
            self.assertEqual(grads, grads_ge)

        # test the grad grad case

        outputs = func(*recording_inputs)
        l1 = allSum(outputs)
        if inputs_require_grads:
            grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
                                        allow_unused=allow_unused)
        if inputs_require_grads:
            l2 = (allSum(grads) * l1)
            grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)

        if inputs_require_grads:
            recording_inputs = [Variable(t, requires_grad=True)
                                for t in reference_tensors]

        outputs_ge = ge(*recording_inputs)
        l1_ge = allSum(outputs_ge)
        if inputs_require_grads:
            grads_ge = torch.autograd.grad(
                l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)

        if inputs_require_grads:
            l2_ge = (allSum(grads_ge) * l1_ge)
            grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_ge)
        if inputs_require_grads:
            self.assertEqual(grads, grads_ge)
            for g2, g2_ge in zip(grads2, grads2_ge):
                if g2 is None and g2_ge is None:
                    continue
                self.assertTrue(torch.allclose(g2, g2_ge, atol=8e-4, rtol=8e-4))

        return ge

    def createScriptModuleFromGraph(self, trace):
        graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
        m = torch.jit.ScriptModule()
        m._create_method_from_graph("forward", graph)
        return m

    def assertExportImport(self, trace, inputs):
        m = self.createScriptModuleFromGraph(trace)
        self.assertExportImportModule(m, inputs)

    def assertExportImportModule(self, m, inputs):
        m_import = self.getExportImportCopy(m)
        self.assertEqual(self.runAndSaveRNG(m.forward, inputs),
                         self.runAndSaveRNG(m_import.forward, inputs))

    def runAndSaveRNG(self, func, inputs, kwargs=None):
        kwargs = kwargs if kwargs else {}
        with freeze_rng_state():
            results = func(*inputs, **kwargs)
        return results

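# Illustrative usage (editor's sketch, not part of the original suite): the
# helpers above are driven from JitTestCase subclasses, for example:
#
#   class ExampleTests(JitTestCase):
#       def test_mul_add(self):
#           def f(a, b):
#               return a * b + b
#           # traces f, round-trips it through save/load, and compares
#           # outputs plus first and second derivatives against eager mode
#           self.checkTrace(f, (torch.randn(2, 2), torch.randn(2, 2)))
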

# has to be at top level or Pickle complains
class FooToPickle(torch.nn.Module):
    def __init__(self):
        super(FooToPickle, self).__init__()
        self.bar = torch.jit.ScriptModule()


class TestJit(JitTestCase):

    @unittest.skip("Requires a lot of RAM")
    def test_big(self):
        m = torch.jit.ScriptModule()
        gig = int(1024 * 1024 * 1024 / 4)
        # a small tensor in the first 4GB
        m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float))
        # a large tensor in the first 4GB that ends outside of it
        m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float))
        # a small tensor in the >4GB space
        m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float))
        # a large tensor in the >4GB space
        m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float))

        m2 = self.getExportImportCopy(m)

        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))

    def test_simple(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        self.checkTrace(f, (x, y))

    def test_restore_device(self):
        # main purpose is checking map_location works
        m = torch.jit.ScriptModule()
        cpu_device_str = 'cpu'
        m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
                                         device=cpu_device_str))
        m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
                                             device=cpu_device_str))
        m2 = self.getExportImportCopy(m)
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertFalse(m2.p0.is_cuda)
        self.assertFalse(m2.b0.is_cuda)

    def test_model_save_error(self):
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(pickle.PickleError, "not supported"):
                torch.save(FooToPickle(), fname)

    def test_single_tuple_trace(self):
        x = torch.tensor(2.)

        def f2(x):
            return (x,)
        jit_f2 = torch.jit.trace(f2, x)
        assert f2(x) == jit_f2(x)  # fails

    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
    def test_restore_device_cuda(self):
        class MyModule(torch.jit.ScriptModule):
            def __init__(self):
                super(MyModule, self).__init__(False)
                self.register_buffer('b0', torch.randn(1, 3))
                self.p0 = nn.Parameter(torch.randn(2, 3))

            @torch.jit.script_method
            def forward(self, x):
                return x + self.b0 + self.p0

        m = MyModule()
        m.cuda(torch.cuda.device_count() - 1)
        cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)

        self.assertTrue(m.p0.is_cuda)
        self.assertTrue(m.b0.is_cuda)

        # restore to the saved devices
        m2 = self.getExportImportCopy(m)
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertEqual(str(m2.p0.device), cuda_device_str)
        self.assertEqual(str(m2.b0.device), cuda_device_str)

        # restore all to cpu using string
        cpu_device_str = 'cpu'
        m3 = self.getExportImportCopy(m, map_location=cpu_device_str)
        self.assertEqual(str(m3.p0.device), cpu_device_str)
        self.assertEqual(str(m3.b0.device), cpu_device_str)

        # restore all to first gpu using device
        m4 = self.getExportImportCopy(
            m3, map_location=torch.device('cuda:0'))
        self.assertEqual(str(m4.p0.device), 'cuda:0')
        self.assertEqual(str(m4.b0.device), 'cuda:0')

        # compute and compare the results
        input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1)
        origin_result = m(input)
        self.assertEqual(origin_result, m2(input))
        self.assertEqual(origin_result, m3(input.cpu()))
        self.assertEqual(origin_result, m4(input.cuda(0)))

    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
    def test_restore_shared_storage_on_cuda(self):
        whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
        m = torch.jit.ScriptModule()
        m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
        m.register_buffer('b0', whole_tensor.narrow(0, 3, 1))
        m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
        self.assertTrue(m2.p0.is_cuda)
        self.assertTrue(m2.b0.is_cuda)
        self.assertTrue(m2.p0.is_shared())
        self.assertTrue(m2.b0.is_shared())
        self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())

    def test_typeas_trace_check(self):
        a = torch.tensor([0.4], requires_grad=True)
        b = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            return x.type_as(y)

        trace = torch.jit.trace(f, (a, b))

    def test_peephole(self):
        a = torch.tensor([0.4])
        b = torch.tensor([0.7])
        c = torch.tensor([0], dtype=torch.int32)

        def f(x, y):
            return x.type_as(y)

        tf = torch.jit.trace(f, (a, b))
        FileCheck().check("type_as").run(str(tf.graph))
        self.run_pass('peephole', tf.graph)
        FileCheck().check_not("type_as").run(str(tf.graph))
        tf2 = torch.jit.trace(f, (a, c))
        s = str(tf2.graph)
        self.run_pass('peephole', tf2.graph)
        self.assertEqual(s, str(tf2.graph))

    def test_peephole_dynamic(self):
        def f(x, y):
            return x.type_as(y)

        fn = torch.jit.script(f)
        s = str(fn.graph)
        torch._C._jit_pass_peephole(fn.graph)
        self.assertEqual(s, str(fn.graph))

    @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
    def test_peephole_cuda(self):
        a = torch.tensor([0.4], device='cpu')
        b = torch.tensor([0.7], device='cuda')
        c = torch.tensor([0.7], device='cuda')

        def f(x, y):
            return x.type_as(y)

        trace = torch.jit.trace(f, (a, c))
        s = str(trace.graph)
        self.run_pass('peephole', trace.graph)
        self.assertEqual(s, str(trace.graph))
        trace = torch.jit.trace(f, (b, c))
        self.run_pass('peephole', trace.graph)
        self.assertTrue(len(list(trace.graph.nodes())) == 0)

    def test_index(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0], dtype=torch.int64)

        def fn(x, y):
            return x[y]

        fn_traced = torch.jit.trace(fn, (x, y,))

        self.assertEqual(fn(x, y), fn_traced(x, y))

    def test_disabled(self):
        torch.jit._enabled = False
        try:
            def f(x, y):
                return x + y

            self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f)
            self.assertIs(torch.jit.script(f), f)

            class MyModule(torch.jit.ScriptModule):
                @torch.jit.script_method
                def method(self, x):
                    return x

            # XXX: Unfortunately ScriptModule won't simply become Module now,
            # because that requires disabling the JIT at startup time, which
            # we can't do in here.
            # We need to or those two conditions to make it work with all versions of Python
            self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method))
        finally:
            torch.jit._enabled = True

    def test_train_eval(self):
        class Sub(nn.Module):
            def forward(self, input):
                if self.training:
                    return input
                else:
                    return -input

        class MyModule(torch.jit.ScriptModule):
            def __init__(self, module):
                super(MyModule, self).__init__()
                self.module = module

            @torch.jit.script_method
            def forward(self, input):
                return self.module(input) + 1

        m = MyModule(Sub())
        input = torch.rand(3, 4)
        self.assertEqual(input + 1, m(input))
        m.eval()
        self.assertEqual(-input + 1, m(input))

        # test batchnorm and dropout train/eval
        input = torch.randn(6, 10)
        batchnorm = nn.BatchNorm1d(10)
        dropout = nn.Dropout(p=0.2)

        m_batchnorm = MyModule(batchnorm)
        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
        batchnorm.eval()
        m_batchnorm.eval()
        self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))

        m_dropout = MyModule(dropout)
        dropout.eval()
        m_dropout.eval()
        self.assertEqual(dropout(input) + 1, m_dropout(input))

    def test_diff_subgraph_clones_constants(self):
        @torch.jit.script
        def f(x, y):
            return x + x + y + x + y + x + y + x + y + x

        def count_constants(graph):
            return sum(node.kind() == 'prim::Constant' for node in graph.nodes())

        graph = f.graph.copy()
        self.run_pass('cse', graph)
        self.run_pass('create_autodiff_subgraphs', graph)
        nodes = list(graph.nodes())
        self.assertEqual(count_constants(graph), 1)
        self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1)

    # Backwards tracing was broken for indexing by a constant,
    # because it's internally implemented using as_strided,
    # and we attempted to trace its derivative (which is not
    # currently supported.) It currently works because
    # slice() is now not marked as traceable.
    def test_index_constant(self):
        x = torch.tensor([0.4], requires_grad=True)

        def fn(x):
            return x[0]

        def run(f):
            y = f(x)
            grad = torch.autograd.grad(y, x)[0].clone()
            return y, grad

        traced_fn = torch.jit.trace(fn, torch.ones(1))
        self.assertEqual(run(fn), run(traced_fn))

    def test_scopes(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            out = x + y
            with torch.jit.scope('Foo'):
                out = x * out
                with torch.jit.scope('Bar'):
                    out = torch.tanh(out)
                out = torch.sigmoid(out)
            return out

        self.checkTrace(f, (x, y))

    def test_scopes_intermediate_node(self):
        class Net(nn.Module):
            def forward(self, x):
                return F.log_softmax(x, dim=0)

        net = Net()
        t = torch.ones(2, requires_grad=True)
        trace, outputs, inputs = torch.jit.get_trace_graph(net, (t,), return_inputs=True)
        self.assertEqual(outputs, self.createScriptModuleFromGraph(trace)(*inputs))
        self.assertExportImport(trace, (t,))
        torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
        FileCheck().check("onnx::LogSoftmax").check("scope: Net").run(str(trace))

    def test_scopes_identity_node(self):

        class Net(nn.Module):

            def __init__(self):
                super(Net, self).__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=2),
                )

            def forward(self, x):
                x = self.features(x)
                return x

        model = Net()

        t = torch.ones(1, 3, 227, 227, requires_grad=True)

        with torch.onnx.set_training(model, False):
            trace, _ = torch.jit.get_trace_graph(model, (t,))

        self.assertExportImport(trace, (t,) + tuple(model.parameters()))
        torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
        FileCheck().check("Net/Sequential[features]/Conv2d[0]").check("ReLU").check("MaxPool").run(str(trace))

    def test_canonicalize_tensor_iterator(self):
        x = torch.randn(4, 4)

        def f(x):
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

        traced = torch.jit.trace(f, (x,))
        f(x)
        graph = traced.graph_for(x)
        # There should be 4 int constants for the right sides of operators, plus one
        # for the alpha argument for add and sub
        self.assertTrue(str(traced.graph_for(x)).count(': int = prim::Constant') == 5)

    # TODO: adapt this test to check that GraphExecutor treats them differently
    @unittest.skip("Need to be adjusted to Graph Executor")
    def test_arg_configurations(self):
        """Different arg configurations should trigger different traces"""
        x = Variable(torch.FloatTensor(4, 4).uniform_())
        x_double = Variable(x.data.double())
        x_grad = Variable(x.data.clone(), requires_grad=True)
        y = Variable(torch.randn(4))

        configurations = [
            (x,),
            (x_double,),
            (x_grad,),
            (y,),
            ([x, x],),
            ([x, y],),
        ]
        if torch.cuda.is_available():
            x_cuda = Variable(x.data.cuda())
            configurations += [
                (x_cuda,),
                ([x, x_cuda],),
                ([x_cuda, x],),
                ([[x_cuda, x]],),
            ]
            if torch.cuda.device_count() > 1:
                x_cuda_1 = Variable(x.data.cuda(1))
                configurations += [
                    (x_cuda_1,),
                    ([x_cuda, x_cuda_1],),
                ]

        @torch.jit.compile(nderivs=0)
        def fn(*args):
            in_vars, _ = torch._C._jit_flatten(args)
            return in_vars[0] + 1

        for i, config in enumerate(configurations):
            self.assertFalse(fn.has_trace_for(*config))
            fn(*config)
            self.assertTrue(fn.has_trace_for(*config))
            for unk_config in configurations[i + 1:]:
                self.assertFalse(fn.has_trace_for(*unk_config))
        self.assertEqual(fn.hits, 0)

    def test_cse(self):
        x = torch.tensor([0.4, 0.3], requires_grad=True)
        y = torch.tensor([0.7, 0.5], requires_grad=True)

        def fn(x, y):
            w = (x + y) * (x + y) * (x + y)
            t = torch.tanh(w) + torch.tanh(w)
            z = (x + y) * (x + y) * (x + y) + t
            return z

        trace, _ = torch.jit.get_trace_graph(fn, (x, y))
        self.run_pass('cse', trace)
        do_exactly = True
        FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \
            .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return") \
            .run(str(trace))

        self.assertExportImport(trace, (x, y))

    def test_recursive_cse(self):
        x = torch.tensor([0.1])
        y = torch.tensor([0.2])

        def fn(x, y):
            z = x
            if bool(x + y > x):
                z = x + y
            return z

        graph = torch.jit.script(fn).graph
        self.run_pass('cse', graph)
        FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))

    def test_shape_analysis_broadcast(self):
        def broadcast(a, b):
            return a + b

        x = torch.randn(3, 1, 5, requires_grad=True)
        y = torch.randn(4, 1, 8, 5, requires_grad=True)

        graph = torch.jit.script(broadcast).graph
        torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
        FileCheck().check("Double(4, 3, 8, 5)").run(str(graph))

    # TODO: update verify to work with GraphExecutors
    @unittest.skip("verify needs to be updated to work with GraphExecutors")
    def test_verify(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)

        @torch.jit.compile
        def f(x, y):
            z = torch.sigmoid(x * (x + y))
            w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
            return z, w

        torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])

    @suppress_warnings
    def test_constant(self):
        x = torch.randn(2, 2, requires_grad=True)

        def f(x):
            return x.matmul(torch.diag(torch.tensor([2., 2.])))

        self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))

    def test_legacy_fail(self):
        class MyLegacyFn(Function):
            def forward(self, x):
                return x

            def backward(self, grad_output):
                return grad_output

        x = torch.tensor([0.], requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
            torch.jit.get_trace_graph(lambda x: MyLegacyFn()(x), (x,))

    def test_inplace_transplant(self):
        x = torch.tensor([0.], requires_grad=True)

        def fn(x):
            y = x.clone()
            y.add_(2)
            y.add_(3)
            return y

        trace, _ = torch.jit.get_trace_graph(fn, (x,))
        self.run_pass('dce', trace)
        FileCheck().check_count("aten::clone", 1, exactly=True) \
            .check_count("aten::add_", 2, exactly=True) \
            .check_next("return").run(str(trace))
        self.assertExportImport(trace, (x,))

    def test_inplace_flags(self):
        class InplaceFn(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                return x.add_(1)

            @staticmethod
            def backward(ctx, go):
                return go

        class RegularFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x.add(1)

            @staticmethod
            def backward(ctx, go):
                return go

        x = torch.tensor([0.], requires_grad=True)

        def fn(x):
            y = RegularFn.apply(x)
            y = InplaceFn.apply(y)
            y = InplaceFn.apply(y)
            y = RegularFn.apply(y)
            return y

        trace, _ = torch.jit.get_trace_graph(fn, (x,), _force_outplace=True)
        self.run_pass('dce', trace)
        ops = [n for n in trace.graph().nodes()]
        for op in ops:
            self.assertTrue(op.hasAttribute('inplace'))
        inplace_flags = [False, True, True, False]
        for op, is_inplace in zip(ops, inplace_flags):
            self.assertEqual(op.i('inplace'), is_inplace)

    def test_inplace_check(self):
        class MyInplaceFn(Function):
            @staticmethod
            def forward(self, x):
                x.add_(1)
                self.mark_dirty(x)
                return x

            @staticmethod
            def backward(self, grad):
                return grad

        def fn(x):
            return MyInplaceFn.apply(x)

        x = torch.randn(5, 5)
        ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
        with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
            ge(x)

    def do_trace_size(self, requires_grad):
        def fn(x):
            return x.view(x.shape[1] * 2, x.size(0), 2)

        x = torch.randn(5, 2, 4, requires_grad=requires_grad)
        y = torch.randn(4, 8, 4, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_fn = torch.jit.trace(fn, x)
        self.assertEqual(traced_fn(y), fn(y))
        self.assertEqual(traced_fn(x), fn(x))

    def test_trace_size(self):
        self.do_trace_size(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_size_with_grad(self):
        self.do_trace_size(True)

    def test_trace_casts(self):
        casts = [
            lambda x: x.byte(),
            lambda x: x.float(),
            lambda x: x.cpu(),
            lambda x: x.to(device='cpu'),
            lambda x: x.to(dtype=torch.int64),
            lambda x: x.to(device='cpu', dtype=torch.float),
            lambda x: x.to(x)
        ]

        def assertContainsCast(trace):
            self.assertEqual(sum(n.kind() == 'aten::to' for n in trace.graph.nodes()), 1)

        for cast in casts:
            trace = torch.jit.trace(cast, torch.randn(2, 2))
            assertContainsCast(trace)
            x = torch.randn(2, 2)
            self.assertEqual(trace(x), cast(x))

        def to_tensor(x, y):
            return x.to(y)

        to_tensor_trace = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
        assertContainsCast(to_tensor_trace)
        x, y = torch.randn(2, 2), torch.randn(1, 10)
        self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))

    def test_trace_warn(self):
        def fn(x):
            int(x)  # Warning 1.
            y = x * 1
            if y:   # Warning 2.
                pass
            q = [x, x * 4]
            z = q[y]  # Warning 3.
            float(z)  # Warning 4.
            z.tolist()  # Warning 5.
            z.numpy()  # Warning 6.
            for _ in torch.ones(4, 4):  # Warning 7.
                pass
            return z + 4

        with warnings.catch_warnings(record=True) as warns:
            traced_fn = torch.jit.trace(fn, torch.tensor([1]))
        warns = [str(w.message) for w in warns]
        self.assertEqual(len(warns), 7)
        self.assertIn('a Python integer', warns[0])
        self.assertIn('a Python boolean', warns[1])
        self.assertIn('a Python index', warns[2])
        self.assertIn('a Python float', warns[3])
        self.assertIn('a Python list', warns[4])
        self.assertIn('a NumPy array', warns[5])
        self.assertIn('Iterating over', warns[6])

    def test_trace_tuple(self):
        def fn(x, y):
            return x, (x * y[1], x * y[0])

        x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
        traced_fn = torch.jit.trace(fn, (x, y))
        self.assertEqual(traced_fn(x, y), fn(x, y))
        # should be a tuple nested within another tuple
        FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next("return") \
            .run(str(traced_fn.graph))
        self.assertExportImport(traced_fn.graph, (x, y))

    def test_trace_random(self):
        def f(mean, std):
            return torch.normal(mean, std)

        traced = torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False)
        mean, std = torch.zeros(5, 5), torch.ones(5, 5)
        with torch.random.fork_rng(devices=[]):
            output = f(mean, std)
        traced_output = traced(mean, std)
        self.assertEqual(output, traced_output)

    def test_trace_tensor_factory(self):
        def run(**kwargs):
            inputs_require_grads = kwargs.pop('inputs_require_grads', True)

            def fn(x):
                return x + torch.ones(2, 3, **kwargs)

            input_kwargs = kwargs.copy()
            if 'out' in input_kwargs:
                del input_kwargs['out']
            input = torch.ones(2, 3, **input_kwargs)
            self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
            # check we recorded 'ones' and did not just record a constant
            tfn = torch.jit.trace(fn, input)
            self.assertTrue("ones" in str(tfn.graph))
        run()
        run(dtype=torch.int, inputs_require_grads=False)
        run(out=torch.tensor([]))
        if RUN_CUDA:
            run(device="cuda:0")
        if RUN_CUDA_MULTI_GPU:
            run(device="cuda:1")

    def test_trace_indexed_assignment(self):
        def stuff(x, y):
            x = x.clone()
            x[0] = y
            return x
        example = torch.rand(3, 4)
        self.checkTrace(stuff, (example, example[0] + 1))

    # TODO: implement
    @unittest.expectedFailure
    def test_output_unflatten(self):
        """Check that outputs of traced functions retain the original structure and nesting"""
        def fn(x):
            return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)

        self.checkTrace(fn, (torch.randn(2, 2),))

    # TODO: implement
    @unittest.expectedFailure
    def test_input_flatten(self):
        """Check that inputs to traced functions are flattened"""

        def fn(x, t):
            y, z = t
            return x * y * z

        inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
        self.checkTrace(fn, inputs)

    # TODO: adapt to a GraphExecutor test
    @unittest.skip("Need to instrument GraphExecutors a bit more")
    def test_flags(self):
        x, y = torch.randn(2, 2)
        y = Variable(torch.randn(2, 2))

        @torch.jit.compile
        def fn(x, y):
            return (x * x + y * y + x * y).sum()

        grads = {}
        for rx, ry in product((True, False), repeat=2):
            x.requires_grad = rx
            y.requires_grad = ry

            self.assertFalse(fn.has_trace_for(x, y))
            out = fn(x, y)

            self.assertFalse(fn.has_trace_for(x, y))
            for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
                if not compute:
                    continue
                grad_v, = torch.autograd.grad(out, v, retain_graph=True)
                expected_grad = grads.setdefault(name, grad_v)
                self.assertEqual(grad_v, expected_grad)
            self.assertEqual(fn.has_trace_for(x, y), rx or ry)

    def test_python_ir(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)

        def doit(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        trace, _ = torch.jit.get_trace_graph(doit, (x, y))
        self.run_pass('dce', trace)
        self.run_pass('canonicalize', trace)
        g = trace.graph()
        g2 = torch._C.Graph()
        g_to_g2 = {}
        for node in g.inputs():
            g_to_g2[node] = g2.addInput()
        for node in g.nodes():
            n_ = g2.createClone(node, lambda x: g_to_g2[x])
            g2.appendNode(n_)
            for o, no in zip(node.outputs(), n_.outputs()):
                g_to_g2[o] = no

        for node in g.outputs():
            g2.registerOutput(g_to_g2[node])

        t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
        self.assertEqual(t_node.attributeNames(), ["a"])
        g2.appendNode(t_node)
        self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
        for node in g.nodes():
            self.assertTrue(g2.findNode(node.kind()) is not None)

    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
    @skipIfRocm
    def test_cpp_cuda(self):
        from cpp.jit import tests_setup
        tests_setup.setup()
        torch._C._jit_run_cpp_tests()
        tests_setup.shutdown()

    def test_batchnorm(self):
        x = torch.ones(2, 2, 2, 2)
        trace, outputs, inputs = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x,
                                                           _force_outplace=True, return_inputs=True)
        m = self.createScriptModuleFromGraph(trace)
        self.assertEqual(outputs, m(*inputs))

    def test_dropout(self):
        x = torch.ones(2, 2)
        with torch.random.fork_rng(devices=[]):
            trace, outputs, inputs = torch.jit.get_trace_graph(nn.Dropout(0.6), x, return_inputs=True)
        with torch.random.fork_rng(devices=[]):
            m = self.createScriptModuleFromGraph(trace)
            self.assertEqual(outputs, m(*inputs))

    @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda requires CUDA")
    def test_dropout_cuda(self):
        # Dropout AD is dispatched to _fused_dropout in CUDA case,
        # which is not included in TestJitGeneratedFunctional
        x = torch.ones(4, 4).cuda().requires_grad_()

        @torch.jit.script
        def func(x):
            return torch.nn.functional.dropout(x)

        with freeze_rng_state():
            out_ref = torch.nn.functional.dropout(x)
            grad_ref = torch.autograd.grad(out_ref.sum(), x)

        with freeze_rng_state():
            out = func(x)
            grad = torch.autograd.grad(out.sum(), x)

        self.assertEqual(out, out_ref)
        self.assertEqual(grad, grad_ref)

    def test_conv(self):
        x = torch.ones(20, 16, 50, 40)
        trace, outputs, inputs = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True)
        m = self.createScriptModuleFromGraph(trace)
        self.assertEqual(outputs, m(*inputs))

    def test_repeated_input(self):
        def fn(a, b):
            return a + b

        ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
        inputs = set(ge.graph.inputs())
        self.assertTrue(len(inputs) == 2)

    def test_repeated_output(self):
        def fn(a, b):
            z = a + b
            return z, z

        ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
        tuple_output = list(ge.graph.outputs())[0]
        tuple_inputs = list(tuple_output.node().inputs())
        self.assertTrue(tuple_inputs[0] == tuple_inputs[1])

    @skipIfNoTorchVision
    def test_alexnet(self):
        x = torch.ones(1, 3, 224, 224)
        model = torchvision.models.AlexNet()
        with torch.random.fork_rng(devices=[]):
            trace, outputs, inputs = torch.jit.get_trace_graph(model, x, return_inputs=True)
        self.run_pass('cse', trace)
        m = self.createScriptModuleFromGraph(trace)
        with torch.random.fork_rng(devices=[]):
            self.assertEqual(outputs, m(*inputs))

    def test_inplace_copy(self):
        x = torch.randn(4, 4, requires_grad=True)

        def f(x):
            out = Variable(torch.zeros(x.size()))
            out.copy_(x)
            return out

        trace, outputs, inputs = torch.jit.get_trace_graph(f, (x,), return_inputs=True)
        self.run_pass('dce', trace)
        m = self.createScriptModuleFromGraph(trace)
        self.assertEqual(outputs, m(*inputs))
        self.assertExportImport(trace, (x,))

    def test_shared_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.b = self.a = nn.Parameter(torch.randn(2, 2))

            def forward(self, x):
                return x * self.a + self.b

        m = MyModule()
        trace, _ = torch.jit.get_trace_graph(m, (torch.randn(2, 2),))
        self.run_pass('dce', trace)
        self.assertEqual(len(list(trace.graph().inputs())), 2)
        FileCheck().check("mul").check("add").run(str(trace))

    def test_trace_c10_ops(self):
        class MyModel(torch.nn.Module):
            def __init__(self):
                super(MyModel, self).__init__()

            def forward(self, scores, bbox_deltas, im_info, anchors):
                a, b = torch.ops._caffe2.GenerateProposals(
                    (scores), (bbox_deltas), (im_info), (anchors),
                    2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
                )
                return a, b
        model = MyModel()
        A = 4
        H = 10
        W = 8
        img_count = 3
        scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
        bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
                                     dtype=torch.float32)
        bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
        im_info = torch.ones(img_count, 3, dtype=torch.float32)
        anchors = torch.ones(A, 4, dtype=torch.float32)
        inputs = (scores, bbox_deltas, im_info, anchors)
        traced_model = torch.jit.trace(model, inputs)
        self.assertEqual(traced_model(*inputs), model(*inputs))
        self.assertExportImport(traced_model.graph, (scores, bbox_deltas, im_info, anchors))

    def test_nested_inplace(self):
        x = torch.randn(2, 2)
        trace, outputs, inputs = torch.jit.get_trace_graph(
            lambda x: F.threshold(x, 0, 0, inplace=True), (x,), return_inputs=True)
        m = self.createScriptModuleFromGraph(trace)
        self.assertEqual(outputs, m(*inputs))
        FileCheck().check("threshold_").run(str(trace))
        self.assertExportImport(trace, (x,))

    def run_ge_tests(self, optimize, use_cuda):
        def rand(*args):
            t = torch.rand(*args).float()
            if use_cuda:
                t = t.cuda()
            return t
        self.checkTrace(lambda a, b: a * b + b,
                        [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
                        optimize=optimize)
        # trivial identity
        self.checkTrace(lambda a, b: (
            b, a), [rand(1), rand(1)], optimize=optimize)

        def foo(a):
            t = a * a
            return t * t, 4 * t
        self.checkTrace(foo, [rand(1)], optimize=optimize)
        # unused input
        self.checkTrace(
            lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
            allow_unused=True)
        # test outputs that do not get used in grad
        self.checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
        # test autograd fallback
        self.checkTrace(lambda a, b: a * b /
                        (a - 2 * b) + b, [rand(1), rand(1)],
                        optimize=optimize)

    def test_ge_unoptimized(self):
        self.run_ge_tests(False, False)

    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
    @enable_cpu_fuser
    def test_ge_optimized(self):
        self.run_ge_tests(True, False)

    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_ge_cuda(self):
        self.run_ge_tests(True, True)

    # more manual test of graph executor that can be used as a scratchpad
    def test_ge(self):
        def foo(a, b):
            return a * b / (a - b) + b
        V = Variable
        a, b = V(torch.rand(1)), V(torch.rand(1))
        ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
        a, b = V(torch.rand(1), requires_grad=True), V(
            torch.rand(1), requires_grad=True)
        r, = ge(a, b)
        da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)

        l2 = (da * db + db * db)
        g2result = torch.autograd.grad(l2, [da, db])

        r = foo(a, b)
        da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
        self.assertEqual(da, da2)
        self.assertEqual(db, db2)
        l3 = (da2 * db2 + db2 * db2)
        g2result2 = torch.autograd.grad(l3, [da2, db2])
        self.assertEqual(g2result, g2result2)

    def test_trace_annotation(self):
        @_trace(torch.rand(1))
        def foo(a):
            return a + a + a

        x = torch.randn(5, 5)
        self.assertEqual(foo(x), x + x + x)

    def test_trace_script(self):
        @torch.jit.script
        def func1(x):
            # type: (Tuple[Tensor, Tensor]) -> Tensor
            return x[0] + x[1]

        @torch.jit.script
        def func2(x):
            # type: (List[Tensor]) -> Tensor
            return x[0] + x[1]

        a = torch.randn(5)
        b = torch.randn(5)

        expected = func1((a, b))
        traced = torch.jit.trace(func1, ((a, b),))
        result = traced((a, b))
        self.assertEqual(expected, result)

        expected = func2((a, b))
        traced = torch.jit.trace(func2, ((a, b),))
        result = traced((a, b))
        self.assertEqual(expected, result)

    def test_einsum(self):
        def outer(x, y):
            return torch.einsum('i,j->ij', (x, y))

        traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
        script = torch.jit.script(outer)
        fns = [traced, script]
        x, y = torch.randn(10), torch.randn(2)
        for fn in fns:
            self.assertGraphContains(fn.graph, kind='aten::einsum')
            self.assertEqual(fn(x, y), outer(x, y))

    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
    @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
    def test_traced_module_cuda(self):
        class Model(nn.Module):
            def __init__(self, num_features, num_layers):
                super(Model, self).__init__()
                self.num_layers = num_layers
                layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
                          for _ in range(num_layers)]
                self.submodule = nn.Sequential(*chain(*layers))

            def forward(self, x):
                for i in range(self.num_layers):
                    x = self.submodule[i](x) + x
                return x

        model = Model(5, 3)
        x = torch.randn(2, 5)
        traced_model = torch.jit.trace(model, x)

        # We're missing some attributes these modules had initially. Make sure we can
        # still get the __repr__()
        model.__repr__()

        # XXX: indexing sequentials is broken
        linear_submodule = next(iter(traced_model.submodule._modules.values()))

        # All attributes that aren't parameters should raise
        with self.assertRaises(AttributeError):
            linear_submodule.in_features
        linear_submodule.weight
        with self.assertRaises(RuntimeError):
            traced_model.asdf = 4
        linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
        with self.assertRaises(RuntimeError):
            del linear_submodule.weight

        # Submodules can't be called
        with self.assertRaises(RuntimeError):
            linear_submodule(x)

        # Type casts
        linear_submodule.cuda()
        traced_model.float().cuda()
        cuda_out = traced_model(x.float().cuda())
        traced_model.cpu()
        cpu_out = traced_model(x.float())
        self.assertEqual(cpu_out, cuda_out)
        traced_model.to('cuda')
        cuda_out = traced_model(x.float().cuda())
        traced_model.to('cpu')
        cpu_out = traced_model(x.float())
        self.assertEqual(cpu_out, cuda_out)
        traced_model.double()

        # state_dict + load_state_dict
        state = {k: v.clone() for k, v in traced_model.state_dict().items()}
        new_state = {k: v.clone().fill_(1) for k, v in state.items()}
        out = traced_model(x)
        traced_model.load_state_dict(new_state)
        out_ones = traced_model(x)
        traced_model.load_state_dict(state)
        out_state = traced_model(x)
        self.assertEqual(out, out_state)
        self.assertNotEqual(out, out_ones)

    def test_export_no_reorder(self):
        def func(a, b):
            return a * b / (a - 2 * b) + b

        recording_inputs = [torch.tensor([0.55619788169860839844], dtype=torch.float32, requires_grad=True),
                            torch.tensor([0.25947844982147216797], dtype=torch.float32, requires_grad=True)]

        ge1 = torch.jit.trace(func, recording_inputs, optimize=True)
        ge2 = self.getExportImportCopy(ge1)

        outputs_ge1 = ge1(*recording_inputs)
        outputs_ge2 = ge2(*recording_inputs)

        grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs)
        grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs)
        self.assertTrue(outputs_ge1 == outputs_ge2)
        self.assertTrue(grad_ge1 == grad_ge2)

    def test_python_function(self):
        class MyFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x + 1

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output

        @_trace(torch.zeros(2))
        def fn(x):
            return MyFn.apply(x + 2) + 3

        x = torch.tensor([1., 2., 3.])
        y = torch.randn(2, 2, requires_grad=True)
        fn(x)
        fn(y)

    def test_python_function_tup(self):
        class MyFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x + 1, x - 1

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, grad_output

        @_trace(torch.zeros(2))
        def fn(x):
            a, b = MyFn.apply(x + 2)
            return a + b + 3
        x = torch.tensor([1., 2., 3.])
        y = torch.randn(2, 2, requires_grad=True)
        fn(x)
        fn(y)

1695  def test_decompose_addmm(self):
1696  def does_decompose():
1697  @torch.jit.script
1698  def addmm(mat, mat1, mat2, alpha, beta):
1699  a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
1700  b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
1701 
1702  return a + b
1703 
1704  mat = torch.randn(2, 2)
1705  mat1 = torch.randn(2, 4)
1706  mat2 = torch.randn(4, 2)
1707  alpha = torch.FloatTensor([123.0])
1708  beta = torch.FloatTensor([321.0])
1709 
1710  out_ref = addmm(mat, mat1, mat2, alpha, beta)
1711  self.run_pass('canonicalize_ops', addmm.graph)
1712  out_test = addmm(mat, mat1, mat2, alpha, beta)
1713  self.assertEqual(out_ref, out_test)
1714  FileCheck().check_not("addmm").run(str(addmm.graph))
1715 
1716  def doesnt_decompose():
1717  @torch.jit.script
1718  def addmm(mat, mat1, mat2, alpha, beta):
1719  a = mat.addmm(mat1, mat2)
1720  b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
1721 
1722  orig = str(addmm.graph)
1723  self.run_pass('canonicalize_ops', addmm.graph)
1724  self.assertTrue(orig == str(addmm.graph))

  does_decompose()
  doesnt_decompose()
1725 
1726  def test_index_put(self):
1727  ten = torch.zeros(3, 3)
1728  mask = torch.Tensor([[True, True, True],
1729  [True, False, False],
1730  [True, True, False]]).byte()
1731 
1732  def test_fn(ten, mask):
1733  ten[mask] = torch.ones(6)
1734  return ten
1735 
1736  traced_test_fn = torch.jit.trace(test_fn, (ten, mask))
1737 
1738  ten = torch.rand(3, 3)
1739  self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
1740 
1741  def test_sparse_tensors_error(self):
1742  def get_sparse():
1743  return torch.sparse.FloatTensor(2, 3)
1744 
1745  @torch.jit.script
1746  def sparse(input):
1747  output = get_sparse()
1748  return output, input
1749 
1750  with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
1751  sparse(get_sparse())
1752 
1753  with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
1754  sparse(torch.tensor([1]))
1755 
1756  def test_tuple_specialization(self):
1757  @torch.jit.script
1758  def f(t):
1759  # type: (Tuple[Tensor, Tensor]) -> Tensor
1760  x, y = t
1761  return x + y
1762 
1763  t = torch.randn(2, 2), torch.randn(2, 2)
1764  f(t)
1765  graph = f.graph_for(t)
1766  input_types = list(next(graph.inputs()).type().elements())
1767  for t in input_types:
1768  self.assertEqual(t.kind(), 'DimensionedTensorType')
1769 
1770  def test_constant_prop_simple(self):
1771  @torch.jit.script
1772  def constant_prop(input_int):
1773  # type: (int) -> int
1774  a = 2 * 3
1775  b = a + 2
1776  return b - input_int
1777 
1778  out_ref = constant_prop(2)
1779  self.run_pass('constant_propagation', constant_prop.graph)
1780  out_test = constant_prop(2)
1781  self.assertEqual(out_ref, out_test)
1782  graph_str = str(constant_prop.graph)
1783  self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str)
1784  const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
1785  self.assertEqual(const, 8)
1786 
1787  def test_constant_prop_nested(self):
1788  @torch.jit.script
1789  def constant_prop(a):
1790  b = 2 + 1
1791  if bool(a < 2):
1792  c = b + 2
1793  else:
1794  c = b - 2
1795  return c
1796  out_ref = constant_prop(torch.tensor(2))
1797  self.run_pass('constant_propagation', constant_prop.graph)
1798  out_test = constant_prop(torch.tensor(2))
1799  self.assertEqual(out_ref, out_test)
1800  if_node = constant_prop.graph.findNode("prim::If")
1801  for block in if_node.blocks():
1802  for node in block.nodes():
1803  self.assertTrue(node.kind() == "prim::Constant")
1804 
1805  def test_constant_prop_print(self):
1806  @torch.jit.script
1807  def constant_prop(input_tensor):
1808  a = 2 * 3
1809  print(a)
1810  b = a + 2
1811  return b + input_tensor
1812 
1813  self.run_pass('constant_propagation', constant_prop.graph)
1814  graph = constant_prop.graph
1815  print_node = graph.findNode("prim::Print")
1816  self.assertTrue(print_node.input().toIValue() == 6)
1817 
1818  def test_constant_prop_rand(self):
1819  @torch.jit.script
1820  def constant_prop():
1821  a = torch.randn([3])
1822  b = a + 2
1823  return b
1824 
1825  self.run_pass('constant_propagation', constant_prop.graph)
1826  self.assertTrue("aten::randn" in str(constant_prop.graph))
1827 
1828  def test_constant_prop_none(self):
1829  @torch.jit.script
1830  def typed_none():
1831  # type: () -> Optional[int]
1832  return None
1833 
1834  @torch.jit.script
1835  def constant_prop():
1836  a = typed_none()
1837  b = typed_none()
1838  if (a is None and b is None):
1839  a = 2
1840  else:
1841  a = 1
1842  return a
1843 
1844  self.run_pass('constant_propagation', constant_prop.graph)
1845  graph_str = str(constant_prop.graph)
1846  self.assertTrue(graph_str.count("prim::Constant") == 1)
1847 
1848  def test_constant_prop_if_inline(self):
1849  @torch.jit.script
1850  def constant_prop():
1851  cond = True
1852  a = 1
1853  if cond:
1854  a = 1 * 2
1855  else:
1856  a = 1 // 0
1857  return a
1858 
1859  # testing that the 1 // 0 error is not thrown
1860  self.run_pass('constant_propagation', constant_prop.graph)
1861 
1862  def test_trace_records_names(self):
1863  def foo(bar, baz):
1864  baz = bar + 3
1865  quick_brown_fox = torch.neg(baz)
1866  for _ in range(20):
1867  yeet = quick_brown_fox - 3.14
1868  return yeet
1869 
1870  traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
1871  graph_str = str(traced.graph)
1872  assert 'bar' in graph_str
1873  assert 'baz' in graph_str
1874  assert 'quick_brown_fox' in graph_str
1875 
1876  def test_constant_prop_if_constant(self):
1877  @torch.jit.script
1878  def constant_prop(a, b):
1879  c0 = 1
1880  c1 = 1
1881  c2 = 1
1882  if bool(a): # -> c0, c1
1883  if bool(b): # -> c0
1884  if True: # -> c0
1885  c0 = c0 + 1
1886  if False:
1887  c1 = c1 + 1
1888  c2 = c2 + 1
1889  else: # -> c0, c1
1890  c1 = c1 + 1
1891 
1892  if True: # inlined
1893  c0 = c0 + 1 # dynamic
1894  c2 = c2 + 4 # set to 5
1895  return a + c0 + c1 + c2
1896 
1897  graph = constant_prop.graph
1898  self.run_pass('constant_propagation', graph)
1899  ifs = graph.findAllNodes("prim::If", recurse=False)
1900  snd_if_inlined = len(ifs) == 1
1901  self.assertTrue(snd_if_inlined)
1902  first_if = ifs[0]
1903  self.assertTrue(first_if.outputsSize() == 2)
1904  second_if = first_if.findNode("prim::If", recurse=False)
1905  self.assertTrue(second_if.outputsSize() == 1)
1906  self.assertTrue(second_if.findNode("prim::If") is None)
1907 
1908  def test_constant_prop_loop_constant(self):
1909  @torch.jit.script
1910  def constant_prop(cond, iter):
1911  # type: (bool, int) -> int
1912  b = 0
1913  while True:
1914  print("stays")
1915  for _ in range(2):
1916  print("stays")
1917  for _ in range(iter):
1918  print("stays")
1919  while cond:
1920  print("stays")
1921  while False:
1922  print("removed")
1923  for _i in range(0):
1924  print("removed")
1925  for _i in range(-4):
1926  print("removed")
1927  return b
1928 
1929  self.run_pass('constant_propagation', constant_prop.graph)
1930  graph = canonical(constant_prop.graph)
1931  self.assertTrue(graph.count("removed") == 0)
1932  self.assertTrue(graph.count("stays") == 1) # constant gets pooled
1933  self.assertTrue(graph.count("prim::Print") == 4)
1934 
1935  def test_constant_prop_remove_output(self):
1936  @torch.jit.script
1937  def constant_prop(iter):
1938  # type: (int) -> None
1939  a = 1
1940  b = 1
1941  c = 1
1942  for i in range(iter):
1943  if False:
1944  a = 10
1945  if i == 5:
1946  b = 2
1947  c = 3
1948  print(a, b, c)
1949 
1950  graph = constant_prop.graph
1951  self.run_pass('constant_propagation', graph)
1952  self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
1953 
1954  def test_trace_detach(self):
1955  def foo(x, w):
1956  return torch.matmul(x, w).detach()
1957 
1958  traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
1959 
1960  FileCheck().check("matmul").check("detach").run(str(traced.graph))
1961  x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
1962  traced_result = traced(x, w)
1963  self.assertEqual(foo(x, w), traced_result)
1964  self.assertFalse(traced_result.requires_grad)
1965  self.assertIsNone(traced_result.grad_fn)
1966 
1967  def test_trace_detach_inplace(self):
1968  def foo(x, w):
1969  y = torch.matmul(x, w)
1970  y.detach_()
1971  return y
1972 
1973  traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
1974 
1975  FileCheck().check("matmul").check("detach(").run(str(traced.graph))
1976  x, w = torch.rand(3, 4), torch.rand(4, 5)
1977  traced_result = traced(x, w)
1978  self.assertEqual(foo(x, w), traced_result)
1979  self.assertFalse(traced_result.requires_grad)
1980  self.assertIsNone(traced_result.grad_fn)
1981 
1982  def test_trace_detach_onnx_erase(self):
1983  class Mod(torch.nn.Module):
1984  def forward(self, x, w):
1985  return torch.matmul(x, w).detach()
1986 
1987  f = io.BytesIO()
1988  self.assertExpected(torch.onnx.export_to_pretty_string(
1989  Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
1990 
1991  def test_trace_slice_full_dim(self):
1992  def foo(x):
1993  return x[0:5, 0] + 1.0
1994 
1995  traced = torch.jit.trace(foo, (torch.rand(5, 4),))
1996  test_x = torch.rand(6, 3)
1997  self.assertEqual(foo(test_x), traced(test_x))
1998 
1999  def test_export_dropout(self):
2000  test = torch.nn.Dropout()
2001  test.eval()
2002 
2003  traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
2004  imported = self.getExportImportCopy(traced)
2005  x = torch.randn(3, 4)
2006  self.assertEqual(traced(x), imported(x))
2007 
2008  def test_onnx_transpose_incomplete_tensor_type(self):
2009  # Smoke test to get us into the state where we are attempting to export
2010  # a transpose op, where the input is a TensorType rather than a
2011  # CompleteTensorType. This would previously not work, since we would
2012  # take the size of the input and use the length of its sizes as the
2013  # number of dimensions in the permutation.
2014  class Foo(torch.jit.ScriptModule):
2015  @torch.jit.script_method
2016  def forward(self, x):
2017  return x.contiguous().transpose(0, 1).sum()
2018 
2019  class TraceMe(torch.nn.Module):
2020  def __init__(self):
2021  super(TraceMe, self).__init__()
2022  self.foo = Foo()
2023 
2024  def forward(self, x):
2025  return self.foo(x)
2026 
2027  tm = TraceMe()
2028  tm = torch.jit.trace(tm, torch.rand(3, 4))
2029  example_outputs = (tm(torch.rand(3, 4)),)
2030  f = io.BytesIO()
2031  torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
2032 
2033  @unittest.skipIf(not RUN_CUDA, "requires CUDA")
2034  def test_cuda_export_restore(self):
2035  class Sub(torch.jit.ScriptModule):
2036  def __init__(self):
2037  super(Sub, self).__init__()
2038  self.weight = nn.Parameter(torch.randn(3, 4))
2039 
2040  @torch.jit.script_method
2041  def forward(self, thing):
2042  return self.weight + thing
2043 
2044  class M(torch.jit.ScriptModule):
2045  def __init__(self):
2046  super(M, self).__init__()
2047  self.mod = Sub()
2048 
2049  @torch.jit.script_method
2050  def forward(self, v):
2051  return self.mod(v)
2052  m = M()
2053  m.cuda()
2054  m2 = self.getExportImportCopy(m)
2055  m2.cuda()
2056  input = torch.rand(3, 4).cuda()
2057  self.assertEqual(m(input), m2(input))
2058 
2059  def test_export_batchnorm(self):
2060  for mode in ['eval', 'train']:
2061  for clazz in [
2062  torch.nn.BatchNorm1d(100),
2063  torch.nn.BatchNorm1d(100, affine=False),
2064  torch.nn.BatchNorm2d(100),
2065  torch.nn.BatchNorm2d(100, affine=False)]:
2066  getattr(clazz, mode)()
2067  input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
2068  torch.randn(20, 100, 35, 45)
2069  traced = torch.jit.trace(clazz, (input,))
2070  imported = self.getExportImportCopy(traced)
2071  x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
2072  torch.randn(20, 100, 35, 45)
2073  self.assertEqual(traced(x), imported(x))
2074 
2075  def test_export_rnn(self):
2076  for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
2077  class RNNTest(torch.nn.Module):
2078  def __init__(self):
2079  super(RNNTest, self).__init__()
2080  self.rnn = clazz
2081 
2082  def forward(self, x, lengths, h0):
2083  packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
2084  out, h = self.rnn(packed, h0)
2085  padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
2086  return padded_outs
2087 
2088  test = RNNTest()
2089 
2090  traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
2091  imported = self.getExportImportCopy(traced)
2092  # NB: We make sure to pass in a batch with a different max sequence
2093  # length to ensure that the argument stashing for pad_packed works
2094  # properly.
2095  x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
2096  self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
2097 
2098  def test_export_lstm(self):
2099  class LSTMTest(torch.nn.Module):
2100  def __init__(self):
2101  super(LSTMTest, self).__init__()
2102  self.rnn = nn.LSTM(10, 20, 2)
2103 
2104  def forward(self, x, lengths, hiddens):
2105  h0, c0 = hiddens
2106  packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
2107  out, (h, c) = self.rnn(packed, (h0, c0))
2108  padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
2109  return padded_outs
2110 
2111  test = LSTMTest()
2112 
2113  traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
2114  torch.LongTensor([3, 2, 1]),
2115  (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
2116  imported = self.getExportImportCopy(traced)
2117  x, lengths, h0, c0 = \
2118  torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
2119  self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
2120 
2121  def test_trace_dict_input(self):
2122  class Bar(torch.nn.Module):
2123  def __init__(self):
2124  super(Bar, self).__init__()
2125  self.foo = Foo()
2126 
2127  def forward(self, a, b):
2128  return self.foo({'a': a, 'b': b})['a']
2129 
2130  class Foo(torch.nn.Module):
2131  def forward(self, x):
2132  return {'a': x['a'] * x['b']}
2133 
2134  x = (torch.rand(3), torch.rand(3))
2135  model = Bar()
2136  self.checkTrace(model, x)
2137 
2138  def test_trace_variable_instantiation(self):
2139  def random_foo(x):
2140  return Variable(Variable(x) + 1.0)
2141 
2142  random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
2143 
2144  x = torch.rand(5, 6)
2145  self.assertEqual(random_foo(x), random_foo_traced(x))
2146 
2147  def test_trace_slice_expr_complete_type(self):
2148  def random_foo(x):
2149  return x + 1.0
2150 
2151  random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
2152 
2153  @torch.jit.script
2154  def random_bar(x):
2155  return random_foo_traced(x)[0:1]
2156 
2157  x = torch.rand(3, 4)
2158  self.assertEqual(random_bar(x), (x + 1)[0:1])
2159 
2160  def test_export_tensoroption_to(self):
2161  def foo(x):
2162  return x.new_tensor(x[0]).cpu() + x
2163 
2164  traced = torch.jit.trace(foo, (torch.rand([2])))
2165  example_outputs = traced(torch.rand([2]))
2166 
2167  f = io.BytesIO()
2168  self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
2169  example_outputs=example_outputs))
2170 
2171  def test_pretty_printer(self):
2172  @torch.jit.script
2173  def if_test(a, b):
2174  # FIXME: use 0 instead of a.
2175  # c = 0
2176  c = a
2177  if bool(a < b):
2178  c = b
2179  else:
2180  c = a
2181  return c
2182 
2183  @torch.jit.script
2184  def if_one(a, b):
2185  c = b
2186  if bool(a < b):
2187  c = a
2188  return c
2189 
2190  @torch.jit.script
2191  def while_test(a, i):
2192  while bool(i < 3):
2193  a *= a
2194  i += 1
2195  return a
2196 
2197  @torch.jit.script
2198  def while_if_test(a, b):
2199  c = 0
2200  while bool(a < 10):
2201  a = a + 1
2202  b = b + 1
2203  if bool(a > b):
2204  c = 2
2205  else:
2206  c = 3
2207  return a + 1 + c
2208 
2209  @torch.jit.script
2210  def loop_use_test(y):
2211  x = y + 1
2212  z = x + 5
2213  while bool(y < 8):
2214  y += 1
2215  z = x
2216  return x, z
2217 
2218  def python_fn(x):
2219  return x + 10
2220 
2221  @torch.jit.script
2222  def python_op_name_test(y):
2223  return python_fn(y)
2224 
2225  @torch.jit.script
2226  def empty_int_list_test(y):
2227  x = torch.jit.annotate(List[int], [])
2228  return x[0]
2229 
2230  @torch.jit.script
2231  def empty_float_list_test(y):
2232  return [1.0, 2.0, 3.0]
2233 
2234  @torch.jit.script
2235  def print_weird_test(y):
2236  print("hi\016")
2237 
2238  self.assertExpected(if_test.graph.pretty_print(), "if_test")
2239  self.assertExpected(if_one.graph.pretty_print(), "if_one")
2240  self.assertExpected(while_test.graph.pretty_print(), "while_test")
2241  self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
2242  self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
2243  self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
2244  self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
2245  self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
2246  self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
2247 
2248  def test_cu_escaped_number(self):
2249  cu = torch.jit.CompilationUnit('''
2250  def foo(a):
2251  print("hi\016")
2252  ''')
2253  self.assertExpected(cu.foo.graph.pretty_print())
2254 
2255  def test_import_method(self):
2256  @torch.jit.script
2257  def foo(x, y):
2258  return 2 * x + y
2259 
2260  r, _ = foo._python_print()
2261  mod = torch.jit.ScriptModule()
2262  torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
2263  self.assertExpected(mod.graph.pretty_print())
2264 
2265  def test_function_default_values(self):
2266  outer_var = torch.tensor(20)
2267  outer_var2 = torch.tensor(30)
2268  a = torch.tensor(0.5)
2269  b = torch.tensor(10)
2270 
2271  @torch.jit.script
2272  def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
2273  return x + a + b + c
2274 
2275  self.assertEqual(
2276  simple_fn(torch.ones(1)),
2277  torch.ones(1) + 0.5 + 10 + (20 + 30))
2278  self.assertEqual(
2279  simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
2280  torch.ones(1) + 1 + 3 + 4)
2281 
2282  outer_c = torch.tensor(9)
2283  outer_flag = torch.tensor(False)
2284 
2285  @torch.jit.script
2286  def bool_fn(x, a=outer_c, flag=outer_flag):
2287  if bool(flag):
2288  result = x
2289  else:
2290  result = x + a
2291  return result
2292 
2293  self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
2294  self.assertEqual(
2295  bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
2296  torch.ones(1))
2297 
2298  @torch.jit.script
2299  def none_fn(x=None):
2300  # type: (Optional[int]) -> Optional[int]
2301  return x
2302 
2303  self.assertEqual(none_fn(), None)
2304  self.assertEqual(none_fn(1), 1)
2305 
2306  @torch.jit.script
2307  def hints(x, a=0.5, b=10):
2308  # type: (Tensor, float, int) -> Tensor
2309  return x + a + b
2310 
2311  self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
2312 
2313  with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
2314 
2315  @torch.jit.script
2316  def hints_bad_types(x, a=10, b=0.5): # noqa: T484
2317  # type: (Tensor, float, int) -> Tensor
2318  return x + a + b
2319 
2320  def test_module_default_values(self):
2321  four = torch.tensor(4)
2322 
2323  class Test(torch.jit.ScriptModule):
2324  def __init__(self):
2325  super(Test, self).__init__()
2326 
2327  @torch.jit.script_method
2328  def forward(self, input, other=four):
2329  return input + other
2330 
2331  t = Test()
2332  self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
2333 
2334  def test_warnings(self):
2335  import warnings
2336 
2337  @torch.jit.script
2338  def fn(x):
2339  if bool(x < 2):
2340  warnings.warn("x is less than 2")
2341  return x
2342 
2343  FileCheck().check("aten::warn").run(str(fn.graph))
2344 
2345  def test_no_erroneous_warnings(self):
2346  import warnings
2347 
2348  def fn(x):
2349  if bool(x > 0):
2350  warnings.warn('This should NOT be printed')
2351  x += 1
2352  return x
2353 
2354  with warnings.catch_warnings(record=True) as warns:
2355  fn_script = torch.jit.script(fn)
2356  fn_script(torch.tensor(0))
2357  warns = [str(w.message) for w in warns]
2358  self.assertEqual(len(warns), 0)
2359 
2360  @unittest.skipIf(sys.platform == "win32", "TODO: need to fix this test case for Windows")
2361  def test_torch_load_error(self):
2362  class J(torch.jit.ScriptModule):
2363  def __init__(self):
2364  super(J, self).__init__()
2365 
2366  @torch.jit.script_method
2367  def forward(self, input):
2368  return input + 100
2369 
2370  j = J()
2371  with tempfile.NamedTemporaryFile() as f:
2372  j.save(f.name)
2373  with self.assertRaisesRegex(RuntimeError, "is a zip"):
2374  torch.load(f.name)
2375 
2376  def test_legacy_constructors(self):
2377  def fn(x):
2378  return x.new_zeros(5, 5, requires_grad=False)
2379 
2380  with warnings.catch_warnings(record=True) as warns:
2381  torch.jit.trace(fn, (torch.ones(2, 2)))
2382  warns = [str(w.message) for w in warns]
2383  self.assertEqual(len(warns), 1)
2384  self.assertEqual(warns[0], "new_zeros is a legacy constructor and is not supported in the JIT.")
2385 
2386  def test_python_bindings(self):
2387  lstm_cell = torch.jit.script(LSTMCellS)
2388 
2389  def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
2390  for i in range(x.size(0)):
2391  hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
2392  return hx
2393 
2394  slstm = torch.jit.script(lstm)
2395 
2396  inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
2397  slstm(*inputs).sum().backward()
2398  global fw_graph
2399  fw_graph = slstm.graph_for(*inputs)
2400  nodes = [n for n in fw_graph.nodes()]
2401  tested_blocks = False
2402  for node in nodes:
2403  for output in [o for o in node.outputs()]:
2404  self.assertTrue(hasattr(output, 'type'))
2405  self.assertTrue(output.type() is not None)
2406  for input in [i for i in node.inputs()]:
2407  self.assertTrue(hasattr(input, 'type'))
2408  self.assertTrue(input.type() is not None)
2409  for block in [b for b in node.blocks()]:
2410  tested_blocks = True
2411  self.assertTrue(hasattr(block, 'inputs'))
2412  self.assertTrue(hasattr(block, 'outputs'))
2413  for output in [o for o in block.outputs()]:
2414  self.assertTrue(hasattr(output, 'type'))
2415  self.assertTrue(output.type() is not None)
2416  for input in [i for i in block.inputs()]:
2417  self.assertTrue(hasattr(input, 'type'))
2418  self.assertTrue(input.type() is not None)
2419  self.assertTrue(hasattr(block, 'returnNode'))
2420  self.assertTrue(type(block.returnNode()) == torch._C.Node)
2421  self.assertTrue(hasattr(block, 'paramNode'))
2422  self.assertTrue(type(block.paramNode()) == torch._C.Node)
2423  self.assertTrue(tested_blocks)
2424 
2425 
2426 class TestBatched(TestCase):
2427  # generate random examples and create a BatchTensor from them
2428  def rand_batch(self, *dims):
2429  dims = [dim for dim in dims if dim != ()]
2430  xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
2431  requires_grad=True) for i in range(dims[0])]
2432  xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
2433  return xs, xb
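 # A sketch of the dims convention rand_batch expects: the first argument is
 # the batch size and each following pair is (dynamic, size). A True flag
 # marks a per-example dynamic dimension whose length is drawn uniformly from
 # 1..size; a False flag marks a static dimension of exactly `size`.
 #
 # >>> xs, xb = self.rand_batch(4, (True, 3), (False, 2))
 # >>> len(xs)   # 4 examples, each of shape [1, d, 2] with 1 <= d <= 3
 # 4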
2434 
2435  def test_create_batchtensor(self):
2436  # create from tensorlist
2437  xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
2438  self.assertEqual(xs, batch.examples())
2439  # create from data, mask, dims
2440  batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
2441  self.assertEqual(xs, batch2.examples())
2442  # expand a tensor to a BatchTensor given batch_size
2443  xs = torch.rand(3, 4, 5)
2444  batch3 = BatchTensor(xs, 2)
2445  xs = xs.unsqueeze(0)
2446  self.assertEqual([xs, xs], batch3.examples())
2447 
2448  def test_batch_elementwise_unary(self):
2449  @torch.jit.batch(batch_size=4)
2450  def tanh(a):
2451  return torch.tanh(a)
2452 
2453  xs, batch = self.rand_batch(4, (True, 3), (False, 2))
2454  res_batch = tanh(batch)
2455  res = [torch.tanh(xs[j]) for j in range(4)]
2456  self.assertEqual(res, res_batch.examples())
2457 
2458  def test_batch_elementwise_binary(self):
2459  @torch.jit.batch(batch_size=4)
2460  def add(a, b):
2461  return a + b
2462 
2463  xs, batch = self.rand_batch(4, (True, 3), (False, 2))
2464  xs2, batch2 = xs, batch
2465  res_batch = add(batch, batch2)
2466  res = [torch.add(xs[j], xs2[j]) for j in range(4)]
2467  self.assertEqual(res, res_batch.examples())
2468 
2469  # test broadcast
2470  xs, batch = self.rand_batch(4, (False, 3), (False, 2))
2471  b = torch.rand(3, 2)
2472  res_batch = add(batch, b)
2473  res = [torch.add(xs[j], b) for j in range(4)]
2474  self.assertEqual(res, res_batch.examples())
2475 
2476  def test_batch_mm(self):
2477  @torch.jit.batch(batch_size=4)
2478  def mm(a, b):
2479  return torch.mm(a, b)
2480 
2481  xs, batch = self.rand_batch(4, (True, 3), (False, 2))
2482  xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
2483  res_batch = mm(batch, batch2)
2484  res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
2485  self.assertEqual(res, res_batch.examples())
2486 
2487  # test broadcast
2488  b = torch.rand(2, 4)
2489  res_batch = mm(batch, b)
2490  res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
2491  self.assertEqual(res, res_batch.examples())
2492 
2493  def test_batch_matmul(self):
2494  @torch.jit.batch(batch_size=4)
2495  def matmul(a, b):
2496  return torch.matmul(a, b)
2497 
2498  def matmul_test(xs, batch, xs2, batch2):
2499  ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
2500  ybs = matmul(batch, batch2)
2501  self.assertEqual(ys, ybs.examples())
2502 
2503  # 1 dimension * 1 dimension
2504  xs, batch = self.rand_batch(4, (False, 2))
2505  xs2, batch2 = self.rand_batch(4, (False, 2))
2506  matmul_test(xs, batch, xs2, batch2)
2507  # 1 dimension * 2 dimensions
2508  xs, batch = self.rand_batch(4, (False, 2))
2509  xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
2510  matmul_test(xs, batch, xs2, batch2)
2511  # 2 dimensions * 1 dimension
2512  xs, batch = self.rand_batch(4, (True, 3), (False, 2))
2513  xs2, batch2 = self.rand_batch(4, (False, 2))
2514  matmul_test(xs, batch, xs2, batch2)
2515  # 2 dimensions * 2 dimensions
2516  xs, batch = self.rand_batch(4, (True, 3), (False, 2))
2517  xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
2518  matmul_test(xs, batch, xs2, batch2)
2519 
2520  def test_batch_select(self):
2521  @torch.jit.batch(batch_size=4)
2522  def select(x):
2523  return torch.select(x, 1, 0)
2524 
2525  xs, batch = self.rand_batch(4, (True, 3), (True, 2))
2526  res_batch = select(batch)
2527  res = [torch.select(xs[j], 1, 0) for j in range(4)]
2528  self.assertEqual(res, res_batch.examples())
2529 
2530  xs, batch = self.rand_batch(4, (False, 3), (True, 2))
2531  res_batch = select(batch)
2532  res = [torch.select(xs[j], 1, 0) for j in range(4)]
2533  self.assertEqual(res, res_batch.examples())
2534 
2535  def test_batch_index_select(self):
2536  @torch.jit.batch(batch_size=4)
2537  def index_select(x, ind):
2538  return x.index_select(1, ind)
2539 
2540  xs, batch = self.rand_batch(4, (False, 5), (True, 2))
2541  ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
2542  ind_batch = BatchTensor(ind, torch.tensor([]).byte())
2543  res_batch = index_select(batch, ind_batch)
2544  res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
2545  self.assertEqual(res, res_batch.examples())
2546 
2547  def test_batch_where(self):
2548  @torch.jit.batch(batch_size=4)
2549  def where(c, a, b):
2550  return torch.where(c, a, b)
2551 
2552  xs, batch = self.rand_batch(4, (False, 3), (False, 2))
2553  xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
2554 
2555  dims = [4, (False, 3), (False, 2)]
2556  xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
2557  batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
2558 
2559  res_batch = where(batch_cond, batch, batch2)
2560  res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
2561  self.assertEqual(res, res_batch.examples())
2562 
2563  def test_batch_argmax(self):
2564  @torch.jit.batch(batch_size=4)
2565  def argmax(a):
2566  return torch.argmax(a, 1)
2567 
2568  xs, batch = self.rand_batch(4, (True, 5), (True, 6))
2569  res_batch = argmax(batch)
2570  res = [torch.argmax(xs[j], 1) for j in range(4)]
2571  self.assertEqual(res, res_batch.examples())
2572 
2573  @torch.jit.batch(batch_size=4)
2574  def argmax(a):
2575  return torch.argmax(a, 1, False)
2576 
2577  res_batch = argmax(batch)
2578  res = [torch.argmax(xs[j], 1, False) for j in range(4)]
2579  self.assertEqual(res, res_batch.examples())
2580 
2581  def test_batch_topk(self):
2582  @torch.jit.batch(batch_size=4)
2583  def topk(a):
2584  return torch.topk(a, 3, 1)
2585 
2586  xs, batch = self.rand_batch(4, (False, 5), (True, 6))
2587 
2588  # along static dim
2589  res_batch = topk(batch)
2590  res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
2591  res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
2592  self.assertEqual(res, res_batch[0].examples())
2593  self.assertEqual(res_idx, res_batch[1].examples())
2594 
2595  @torch.jit.batch(batch_size=4)
2596  def topk(a):
2597  return torch.topk(a, 1, 2)
2598 
2599  # along dynamic dim
2600  res_batch = topk(batch)
2601  res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
2602  res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
2603  self.assertEqual(res, res_batch[0].examples())
2604  self.assertEqual(res_idx, res_batch[1].examples())
2605 
2606  def test_batch_softmax(self):
2607  @torch.jit.batch(batch_size=4)
2608  def softmax(a):
2609  return torch.softmax(a, 1)
2610 
2611  xs, batch = self.rand_batch(4, (False, 5), (True, 6))
2612 
2613  # along static dim
2614  res_batch = softmax(batch)
2615  res = [torch.softmax(xs[j], 1) for j in range(4)]
2616  self.assertEqual(res, res_batch.examples())
2617 
2618  @torch.jit.batch(batch_size=4)
2619  def softmax(a):
2620  return torch.softmax(a, 2)
2621 
2622  # along dynamic dim
2623  res_batch = softmax(batch)
2624  res = [torch.softmax(xs[j], 2) for j in range(4)]
2625  self.assertEqual(res, res_batch.examples())
2626 
2627  def test_batch_view(self):
2628  @torch.jit.batch(batch_size=4)
2629  def view(a):
2630  return a.view([4, -1, 3])
2631 
2632  xs, batch = self.rand_batch(4, (True, 5), (False, 3))
2633  res_batch = view(batch)
2634  res = [xs[j].view([1, -1, 3]) for j in range(4)]
2635  self.assertEqual(res, res_batch.examples())
2636 
2637  def test_batch_cat(self):
2638  @torch.jit.batch(batch_size=4)
2639  def cat2(a, b):
2640  return torch.cat([a, b], 2)
2641 
2642  xs, batch = self.rand_batch(4, (True, 5), (False, 3))
2643  xs2, batch2 = xs, batch
2644  res_batch = cat2(batch, batch2)
2645  res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
2646  self.assertEqual(res, res_batch.examples())
2647 
2648  def test_batch_sum(self):
2649  @torch.jit.batch(batch_size=4)
2650  def batch_sum(a):
2651  return a.sum()
2652 
2653  xs, batch = self.rand_batch(4, (True, 5), (False, 3))
2654  res_batch = batch_sum(batch)
2655  res = [xs[j].sum().unsqueeze(0) for j in range(4)]
2656  self.assertEqual(res, res_batch.examples())
2657 
2658  def test_if_else(self):
2659  def single_if(a, b):
2660  if bool(a > b):
2661  a = a + b
2662  else:
2663  a = a - b
2664  return a
2665 
2666  batch_if = torch.jit.batch(batch_size=4)(single_if)
2667 
2668  a, batch_a = self.rand_batch(4, ())
2669  b, batch_b = self.rand_batch(4, ())
2670  res_batch = batch_if(batch_a, batch_b)
2671  res = [single_if(a[j], b[j]) for j in range(4)]
2672  self.assertEqual(res, res_batch.examples())
2673 
2674  script_if = torch.jit.script(single_if)
2675  torch.to_batch_graph(script_if.graph)
2676 
2677  def test_if_else_with_scalar(self):
2678  def single_if(a, b):
2679  if bool(a > 0.1):
2680  a = a + b
2681  else:
2682  a = a - b
2683  return a
2684 
2685  batch_if = torch.jit.batch(batch_size=4)(single_if)
2686 
2687  a, batch_a = self.rand_batch(4, ())
2688  b, batch_b = self.rand_batch(4, ())
2689  res_batch = batch_if(batch_a, batch_b)
2690  res = [single_if(a[j], b[j]) for j in range(4)]
2691  self.assertEqual(res, res_batch.examples())
2692 
2693  script_if = torch.jit.script(single_if)
2694  torch.to_batch_graph(script_if.graph)
2695 
2696  def test_if_noelse(self):
2697  def single_if(a, b):
2698  if bool(a > b):
2699  a = a + b
2700  return a
2701 
2702  batch_if = torch.jit.batch(batch_size=4)(single_if)
2703 
2704  a, batch_a = self.rand_batch(4, ())
2705  b, batch_b = self.rand_batch(4, ())
2706  res_batch = batch_if(batch_a, batch_b)
2707  res = [single_if(a[j], b[j]) for j in range(4)]
2708  self.assertEqual(res, res_batch.examples())
2709 
2710  script_if = torch.jit.script(single_if)
2711  torch.to_batch_graph(script_if.graph)
2712 
2713  def test_if_noelse_with_scalar(self):
2714  def single_if(a, b):
2715  if bool(a > 0.1):
2716  a = a + b
2717  return a
2718 
2719  batch_if = torch.jit.batch(batch_size=4)(single_if)
2720 
2721  a, batch_a = self.rand_batch(4, ())
2722  b, batch_b = self.rand_batch(4, ())
2723  res_batch = batch_if(batch_a, batch_b)
2724  res = [single_if(a[j], b[j]) for j in range(4)]
2725  self.assertEqual(res, res_batch.examples())
2726 
2727  script_if = torch.jit.script(single_if)
2728  torch.to_batch_graph(script_if.graph)
2729 
2730  def test_while(self):
2731  def single_while(a, b):
2732  while bool(a > b):
2733  a = a - b
2734  return a
2735 
2736  batch_while = torch.jit.batch(batch_size=4)(single_while)
2737 
2738  a, batch_a = self.rand_batch(4, ())
2739  b = [torch.abs(torch.rand(1)) for i in range(4)]
2740  batch_b = BatchTensor(b, torch.tensor([]).byte())
2741  res_batch = batch_while(batch_a, batch_b)
2742  res = [single_while(a[j], b[j]) for j in range(4)]
2743  self.assertEqual(res, res_batch.examples())
2744 
2745  script_while = torch.jit.script(single_while)
2746  torch.to_batch_graph(script_while.graph)
2747 
2748  def test_for(self):
2749  def single_for(x, y):
2750  for _ in range(10):
2751  x = x + y
2752  return x
2753 
2754  batch_for = torch.jit.batch(batch_size=4)(single_for)
2755 
2756  a, batch_a = self.rand_batch(4, ())
2757  b, batch_b = self.rand_batch(4, ())
2758  res_batch = batch_for(batch_a, batch_b)
2759  res = [single_for(a[j], b[j]) for j in range(4)]
2760  self.assertEqual(res, res_batch.examples())
2761 
2762  script_for = torch.jit.script(single_for)
2763  torch.to_batch_graph(script_for.graph)
2764 
2765  def test_lstm(self):
2766  def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
2767  for i in range(x_all.size(1)):
2768  x = x_all.select(1, i)
2769  i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2770  f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2771  o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2772  # activations
2773  i_t = torch.sigmoid(i_t)
2774  f_t = torch.sigmoid(f_t)
2775  o_t = torch.sigmoid(o_t)
2776  # cell computations
2777  c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2778  c_t = torch.tanh(c_t)
2779  c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2780  h_t = torch.mul(o_t, torch.tanh(c_t))
2781  h = h_t
2782  c = c_t
2783  return h
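 # For reference, the cell above follows the standard LSTM update
 # (element-wise products, matrix multiplies against the weights):
 #   i_t = sigmoid(x W_xi + h W_hi + b_i)
 #   f_t = sigmoid(x W_xf + h W_hf + b_f)
 #   o_t = sigmoid(x W_xo + h W_ho + b_o)
 #   g_t = tanh(x W_xc + h W_hc + b_c)
 #   c_t = f_t * c_{t-1} + i_t * g_t
 #   h_t = o_t * tanh(c_t)
 # Note the helper reuses the candidate g_t where c_{t-1} would normally
 # appear (c_t = g_t * f_t + i_t * g_t); that is harmless here because the
 # test only checks that batched and per-example executions agree.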
2784 
2785  LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
2786 
2787  batch_size, input_size, hidden_size = 4, 3, 2
2788  xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
2789  hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
2790  cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
2791 
2792  # input to hidden weights
2793  w_xi = torch.rand(input_size, hidden_size)
2794  w_xf = torch.rand(input_size, hidden_size)
2795  w_xo = torch.rand(input_size, hidden_size)
2796  w_xc = torch.rand(input_size, hidden_size)
2797  # hidden to hidden weights
2798  w_hi = torch.rand(hidden_size, hidden_size)
2799  w_hf = torch.rand(hidden_size, hidden_size)
2800  w_ho = torch.rand(hidden_size, hidden_size)
2801  w_hc = torch.rand(hidden_size, hidden_size)
2802  # bias terms
2803  b_i = torch.rand(hidden_size)
2804  b_f = torch.rand(hidden_size)
2805  b_o = torch.rand(hidden_size)
2806  b_c = torch.rand(hidden_size)
2807 
2808  ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
2809  w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
2810  ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
2811  w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
2812  self.assertEqual(ys, ybs.examples())
2813 
2814  def test_greedy_search(self):
2815  def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2816  b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
2817  iter_count = torch.zeros_like(iter_num)
2818  while bool(iter_count < iter_num):
2819  iter_count = iter_count + 1
2820  # LSTM Cell
2821  i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2822  f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2823  o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2824  # activations
2825  i_t = torch.sigmoid(i_t)
2826  f_t = torch.sigmoid(f_t)
2827  o_t = torch.sigmoid(o_t)
2828  # cell computations
2829  c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2830  c_t = torch.tanh(c_t)
2831  c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2832  h_t = torch.mul(o_t, torch.tanh(c_t))
2833  h = h_t
2834  c = c_t
2835  # calculate feature with max probability
2836  s_t = torch.matmul(h_t, w_hs) + b_s
2837  p_t = torch.softmax(s_t, 1)
2838  i_t = torch.argmax(p_t, 1)
2839  x = embed.index_select(1, i_t).squeeze(1)
2840  return h
2841 
2842  greedy_batch = torch.jit.batch(batch_size=4)(greedy)
2843 
2844  batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
2845  xs, batch = self.rand_batch(batch_size, (False, input_size))
2846  hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
2847  cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
2848  embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
2849  iter_num = [torch.randint(2, 5, (1,)) for i in range(batch_size)]
2850  iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
2851 
2852  # input to hidden weights
2853  w_xi = torch.rand(input_size, hidden_size)
2854  w_xf = torch.rand(input_size, hidden_size)
2855  w_xo = torch.rand(input_size, hidden_size)
2856  w_xc = torch.rand(input_size, hidden_size)
2857  # hidden to hidden weights
2858  w_hi = torch.rand(hidden_size, hidden_size)
2859  w_hf = torch.rand(hidden_size, hidden_size)
2860  w_ho = torch.rand(hidden_size, hidden_size)
2861  w_hc = torch.rand(hidden_size, hidden_size)
2862  # bias terms
2863  b_i = torch.rand(hidden_size)
2864  b_f = torch.rand(hidden_size)
2865  b_o = torch.rand(hidden_size)
2866  b_c = torch.rand(hidden_size)
2867  # hidden to vocab weights, bias
2868  w_hs = torch.rand(hidden_size, vocab_size)
2869  b_s = torch.rand(vocab_size)
2870 
2871  ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc,
2872  w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j]) for j in range(batch_size)]
2873  ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
2874  w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
2875  self.assertEqual(ys, ybs.examples())
2876 
2877  def test_beam_search(self):
2878  def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2879  b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
2880  k = 5
2881  vocab_size = embed.size(1)
2882  iter_count = torch.zeros_like(iter_num)
2883  max_len = idx.size(2)
2884  while bool(iter_count < iter_num):
2885  iter_count = iter_count + 1
2886  # LSTM Cell
2887  i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2888  f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2889  o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2890  # activations
2891  i_t = torch.sigmoid(i_t)
2892  f_t = torch.sigmoid(f_t)
2893  o_t = torch.sigmoid(o_t)
2894  # cell computations
2895  c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2896  c_t = torch.tanh(c_t)
2897  c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2898  h_t = torch.mul(o_t, torch.tanh(c_t))
2899  h = h_t
2900  c = c_t
2901  # calculate features with max probability
2902  s_t = torch.matmul(h_t, w_hs) + b_s
2903  s_t = s_t.view([1, s_t.size(1) * s_t.size(2)])
2904  p_t = torch.softmax(s_t, 1)
2905  prob_t, idx_t = torch.topk(p_t, k, 1)
2906  if int(idx_t.dim()) > 1:
2907  idx_t_tmp = idx_t.squeeze(0)
2908  else:
2909  idx_t_tmp = idx_t
2910  new_y = torch.fmod(idx_t_tmp, vocab_size)
2911  pre_y = idx_t_tmp / vocab_size
2912  x = embed.index_select(1, new_y)
2913  h = h_t.index_select(1, pre_y)
2914  c = c_t.index_select(1, pre_y)
2915  iter = int(iter_count[0])
2916  idx = torch.cat([idx.narrow(2, 0, iter).index_select(1, pre_y),
2917  torch.fmod(idx_t, vocab_size).unsqueeze(-1),
2918  idx.narrow(2, iter, max_len - iter)], 2)
2919  idx = idx.narrow(2, 0, max_len)
2920  return idx
2921 
2922  beam_batch = torch.jit.batch(batch_size=4)(beam)
2923 
2924  k = 5
2925  batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
2926  max_len = 5
2927  xs, batch = self.rand_batch(batch_size, (False, 1), (False, input_size))
2928  hx, h_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
2929  cx, c_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
2930  embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
2931  iter_num = [torch.randint(2, max_len + 1, (1,)) for i in range(batch_size)]
2932  iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
2933 
2934  # input to hidden weights
2935  w_xi = torch.rand(input_size, hidden_size)
2936  w_xf = torch.rand(input_size, hidden_size)
2937  w_xo = torch.rand(input_size, hidden_size)
2938  w_xc = torch.rand(input_size, hidden_size)
2939  # hidden to hidden weights
2940  w_hi = torch.rand(hidden_size, hidden_size)
2941  w_hf = torch.rand(hidden_size, hidden_size)
2942  w_ho = torch.rand(hidden_size, hidden_size)
2943  w_hc = torch.rand(hidden_size, hidden_size)
2944  # bias terms
2945  b_i = torch.rand(1, hidden_size)
2946  b_f = torch.rand(1, hidden_size)
2947  b_o = torch.rand(1, hidden_size)
2948  b_c = torch.rand(1, hidden_size)
2949  # hidden to vocab weights, bias
2950  w_hs = torch.rand(hidden_size, vocab_size)
2951  b_s = torch.rand(1, vocab_size)
2952 
2953  idx_batch = torch.jit.BatchTensor(torch.zeros([batch_size, k, max_len], dtype=torch.long),
2954  torch.zeros([batch_size, 1, max_len]).byte(),
2955  torch.tensor([0, 1]).byte())
2956  idx = [torch.zeros([1, k, max_len], dtype=torch.long) for _ in range(batch_size)]
2957 
2958  ys = [beam(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2959  b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j], idx[j]).narrow(2, 0, int(iter_num[j]))
2960  for j in range(batch_size)]
2961  ybs = beam_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
2962  w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch)
2963  self.assertEqual(ys, ybs.examples())
2964 
2965 
2966 def execWrapper(code, glob, loc):
2967  if PY2:
2968  exec(code) in glob, loc
2969  else:
2970  exec(code, glob, loc)
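
# A quick usage sketch for execWrapper, which runs `code` in the supplied
# global/local namespaces under both interpreters (exec is a statement in
# Python 2 and a function in Python 3, hence the two branches):
#
# >>> scope = {}
# >>> execWrapper("y = x + 1", {'x': 1}, scope)
# >>> scope['y']
# 2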
2971 
2972 
2973 class TestScript(JitTestCase):
2974  @contextmanager
2975  def capture_stdout(self):
2976  # No idea how to capture stdout from C++ on Windows
2977  if WINDOWS:
2978  yield ['']
2979  return
2980  import os
2981  import fcntl
2982  import errno
2983  sys.stdout.flush()
2984  stdout_fd = os.dup(1)
2985  r, w = os.pipe()
2986  try:
2987  # Override stdout with the pipe's write end - dup is guaranteed to return the lowest free fd
2988  os.close(1)
2989  os.dup(w)
2990 
2991  captured_stdout = ['']
2992  yield captured_stdout
2993  sys.stdout.flush() # Make sure that Python hasn't buffered anything
2994 
2995  # Do the ugly dance to read all the data that was written into the pipe
2996  fcntl.fcntl(r, fcntl.F_SETFL, os.O_NONBLOCK)
2997  total_stdout = ''
2998  while True:
2999  try:
3000  total_stdout += os.read(r, 1000).decode('ascii')
3001  except OSError as e:
3002  if e.errno != errno.EAGAIN:
3003  raise
3004  break
3005  captured_stdout[0] = total_stdout
3006  finally:
3007  # Revert the change, and clean up all fds
3008  os.close(1)
3009  os.dup(stdout_fd)
3010  os.close(stdout_fd)
3011  os.close(r)
3012  os.close(w)
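 # A usage sketch for the helper above (POSIX branch; on Windows it yields
 # ['']): anything printed to stdout inside the block, from Python or C++,
 # lands in captured[0] once the block exits.
 #
 # >>> with self.capture_stdout() as captured:
 # ...     print("hello")
 # >>> assert "hello" in captured[0]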
3013 
3014  def checkScriptRaisesRegex(self, script, inputs, exception, regex,
3015  optimize=True, outputs=None, capture_output=False):
3016  """
3017  Checks that a given function will throw the correct exception
3018  when executed with normal Python, the string frontend, and the AST frontend
3019  """
3020  # normal python
3021  with self.assertRaisesRegex(exception, regex):
3022  script(*inputs)
3023  # string frontend
3024  with self.assertRaisesRegex(exception, regex):
3025  source = textwrap.dedent(inspect.getsource(script))
3026  cu = torch.jit.CompilationUnit(source, optimize)
3027  ge = getattr(cu, script.__name__)
3028  ge(*inputs)
3029  # python AST frontend
3030  with self.assertRaisesRegex(exception, regex):
3031  ge = torch.jit.script(script, optimize)
3032  ge(*inputs)
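 # A usage sketch (hypothetical function and regex, for illustration only):
 # the same failure must surface from eager Python, the string frontend, and
 # the AST frontend, so a broad exception class plus a message substring that
 # all three share is the typical call shape.
 #
 # >>> def div(a, b):
 # ...     # type: (int, int) -> int
 # ...     return a // b
 # >>> self.checkScriptRaisesRegex(div, (1, 0), Exception, "division")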
3033 
3034  def test_training_param(self):
3035  class What(torch.jit.ScriptModule):
3036  @torch.jit.script_method
3037  def forward(self, x):
3038  # type: (int) -> int
3039  if self.training:
3040  r = x
3041  else:
3042  r = x + 4
3043  # check double use of training
3044  if self.training:
3045  r = r + 1
3046  return r
3047 
3048  w = What()
3049  self.assertEqual(4, w(3))
3050  w.train(False)
3051  self.assertEqual(7, w(3))
3052 
3053  def test_jitter_bug(self):
3054  @torch.jit.script
3055  def fn2(input, kernel_size):
3056  # type: (Tensor, List[int]) -> Tensor
3057  if kernel_size[0] > 1:
3058  _stride = [2]
3059  else:
3060  _stride = kernel_size
3061  print(_stride, kernel_size)
3062  return input
3063 
3064  @torch.jit.script
3065  def fn(input):
3066  # type: (Tensor) -> Tensor
3067  return fn2(input, [1])
3068 
3069  def test_parser_kwargonly(self):
3070  cu = torch.jit.CompilationUnit('''
3071  def foo(x, *, y) -> Tuple[Tensor, Tensor]:
3072  return x, x
3073  def bar(x):
3074  return foo(x, y=x)
3075  ''')
3076  self.assertTrue('*' in cu.module._get_method('foo').pretty_print_schema())
3077  with self.assertRaisesRegex(RuntimeError, "not provided"):
3078  torch.jit.CompilationUnit('''
3079  def foo(x, *, y) -> Tuple[Tensor, Tensor]:
3080  return x, x
3081  def bar(x):
3082  return foo(x, x)
3083  ''')
3084 
3085  def test_annoying_doubles(self):
3086  mod = types.ModuleType("temp")
3087  mod.inf = float("inf")
3088  mod.ninf = float("-inf")
3089  mod.nan = float("nan")
3090 
3091  with self.disableModuleHook():
3092  @torch.jit.script
3093  def foo():
3094  return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan
3095 
3096  pp, table = foo._get_method('forward').python_print()
3097  ppv = "op_version_set = 0\n{}".format(pp)
3098  sm = torch.jit.ScriptModule()
3099  torch._C._jit_import_methods(sm, ppv, table)
3100  r = foo()
3101  r2 = sm()
3102  # use precise assert, we are checking floating point details
3103  self.assertTrue(r[:-1] == r2[:-1])
3104  self.assertTrue(math.isnan(r[-1]) and math.isnan(r2[-1]))
3105 
3106  def test_type_annotate(self):
3107 
3108  def foo(a):
3109  return torch.jit.annotate(torch.Tensor, a)
3110 
3111  self.checkScript(foo, (torch.rand(3),))
3112 
3113  def bar():
3114  a = torch.jit.annotate(List[int], [])
3115  for _ in range(10):
3116  a.append(4)
3117  return a
3118 
3119  self.checkScript(bar, ())
3120 
3121  def baz(a):
3122  return torch.jit.annotate(float, a)
3123  self.checkScript(baz, (torch.rand(()),))
3124 
3125  # test annotate none types
3126  def annotate_none():
3127  return torch.jit.annotate(Optional[torch.Tensor], None)
3128 
3129  def annotate_none_no_optional():
3130  return torch.jit.annotate(torch.Tensor, None)
3131 
3132  self.checkScript(annotate_none, ())
3133  self.checkScript(annotate_none_no_optional, ())
3134 
3135  def test_robust_op_resolution(self):
3136  neg = torch.add # misleading name to make sure we resolve by function
3137 
3138  def stuff(x):
3139  return neg(x, x)
3140 
3141  a = (torch.rand(3),)
3142  self.checkScript(stuff, a)
3143 
3144  def test_tuple_io(self):
3145  def stuff(x):
3146  # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
3147  a, b = x
3148  return b, a
3149 
3150  a = (torch.rand(3), torch.rand(3))
3151  self.checkScript(stuff, (a,))
3152 
3153  def test_tuple_create_return(self):
3154  def stuff2(x):
3155  # type: (int) -> Tuple[Tensor, Tensor]
3156  a = (torch.ones(x), torch.zeros(x))
3157  return a
3158  self.checkScript(stuff2, (3,))
3159 
3160  def test_list_io(self):
3161  def stuff3(x):
3162  # type: (List[int]) -> Tuple[Tensor, List[int]]
3163  return torch.ones(x), x
3164  self.checkScript(stuff3, ([3, 2],))
3165 
3166  # to avoid defining sum_list in multiple tests
3167  def get_sum_list_fn(self):
3168  def sum_list(a):
3169  # type: (List[int]) -> int
3170  sum = 0
3171  for i in a:
3172  sum += i
3173 
3174  return sum
3175 
3176  return sum_list
3177 
3178  def test_sum_list_diff_elms(self):
3179  self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
3180 
3181  def test_sum_list_empty(self):
3182  self.checkScript(self.get_sum_list_fn(), ([],))
3183 
3184  def test_sum_list_one(self):
3185  self.checkScript(self.get_sum_list_fn(), ([1],))
3186 
3187  def test_sum_list_literal(self):
3188 
3189  def sum_list():
3190  # type: () -> int
3191  sum = 0
3192  for i in [1, 2, 3, 4, 5]:
3193  sum += i
3194 
3195  return sum
3196 
3197  self.checkScript(sum_list, ())
3198 
3199  def test_sum_list_wrong_type(self):
3200 
3201  with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
3202  @torch.jit.script
3203  def sum_list(a):
3204  # type: (int) -> int
3205  sum = 0
3206  for i in a: # noqa: T484
3207  sum += i
3208 
3209  return sum
3210 
3211  sum_list(1)
3212 
3213  def test_bool_list_io(self):
3214  @torch.jit.script
3215  def stuff4(x):
3216  # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
3217  return x, [True, False], [[True]]
3218 
3219  li_1, li_2, li_3 = stuff4([True])
3220  li_3 = li_3[0]
3221  for li in [li_1, li_2, li_3]:
3222  self.assertTrue(type(li[0]) == type(True))
3223 
3224  def test_nested_list(self):
3225  def foo(z):
3226  # type: (Tuple[int, List[List[int]]]) -> int
3227  x, y = z
3228  return y[0][1]
3229  self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))
3230 
3231  def test_nested_list_construct(self):
3232  def foo():
3233  return [[4]] + [[4, 5]]
3234  self.checkScript(foo, ())
3235 
3236  def test_tensor_shape(self):
3237  x = torch.empty(34, 56, 78)
3238 
3239  def f(x):
3240  return x.shape
3241 
3242  self.checkScript(f, (x,))
3243 
3244  def test_tensor_grad(self):
3245  x = torch.tensor(1.0, requires_grad=True)
3246  y = torch.tensor(1.0, requires_grad=False)
3247 
3248  def f(x):
3249  return x.requires_grad
3250 
3251  self.checkScript(f, (x,))
3252  self.checkScript(f, (y,))
3253 
3254  def test_tensor_dtype(self):
3255  x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
3256  x_long = torch.empty(34, 56, 78, dtype=torch.long)
3257  x_float32 = torch.empty(34, 56, 78, dtype=torch.float32)
3258 
3259  @torch.jit.script
3260  def byte(x):
3261  return x.dtype == torch.uint8
3262 
3263  @torch.jit.script
3264  def long(x):
3265  return x.dtype == torch.long
3266 
3267  @torch.jit.script
3268  def float32(x):
3269  return x.dtype == torch.float32
3270 
3271  self.assertTrue(byte(x_byte))
3272  self.assertFalse(byte(x_long))
3273  self.assertFalse(byte(x_float32))
3274  self.assertFalse(long(x_byte))
3275  self.assertTrue(long(x_long))
3276  self.assertFalse(long(x_float32))
3277  self.assertFalse(float32(x_byte))
3278  self.assertFalse(float32(x_long))
3279  self.assertTrue(float32(x_float32))
3280 
3281  @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
3282  def test_tensor_device(self):
3283  cpu = torch.empty(34, 56, 78, device='cpu')
3284  gpu = torch.empty(34, 56, 78, device='cuda')
3285 
3286  @torch.jit.script
3287  def same_device(x, y):
3288  return x.device == y.device
3289 
3290  self.assertTrue(same_device(cpu, cpu))
3291  self.assertTrue(same_device(gpu, gpu))
3292  self.assertFalse(same_device(cpu, gpu))
3293 
3294  @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
3295  def test_tensor_to_device(self):
3296  def to_device(x):
3297  return x.to(device="cuda").to(device=torch.device("cpu"))
3298 
3299  self.checkScript(to_device, (torch.ones(3, 4),))
3300 
3301  def test_tensor_to_cpu(self):
3302  def to_cpu(x):
3303  return x.cpu()
3304 
3305  x = torch.ones(3, 4)
3306  script_fn = torch.jit.script(to_cpu)
3307  self.assertEqual(to_cpu(x).device, script_fn(x).device)
3308  self.checkScript(to_cpu, (x,))
3309 
3310  @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
3311  def test_tensor_to_cuda(self):
3312  def to_cuda(x):
3313  return x.cuda()
3314 
3315  x = torch.ones(3, 4)
3316  script_fn = torch.jit.script(to_cuda)
3317  self.assertEqual(to_cuda(x).device, script_fn(x).device)
3318  self.checkScript(to_cuda, (x,))
3319 
3320  def test_generic_list_errors(self):
3321  with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
3322  @torch.jit.script
3323  def foo(x):
3324  return [[x]] + [[1]]
3325 
3326  def test_script_cu(self):
3327  cu = torch.jit.CompilationUnit('''
3328  def foo(a):
3329  b = a
3330  return b
3331  ''')
3332  a = Variable(torch.rand(1))
3333  self.assertEqual(a, cu.foo(a))
3334 
3335  # because the compilation unit ingests raw python strings,
3336  # escape sequences need their backslash escaped (\\n = \n)
3337  def test_string_cu(self):
3338  cu = torch.jit.CompilationUnit('''
3339  def foo(a):
3340  print(a, """a\\n\tb\\n""", 2, "a\
3341 a")
3342  return a
3343  ''')
3344  FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph))
3345 
3346  def test_string_ops(self):
3347  def foo():
3348  a = "a" + "b"
3349  return a + a, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab"
3350 
3351  self.checkScript(foo, ())
3352 
3353  def test_string_new_line(self):
3354  with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
3355  torch.jit.CompilationUnit('''
3356  def test_while(a):
3357  print("
3358  a")
3359  return a
3360  ''')
3361 
3362  def test_string_single_escape(self):
3363  with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
3364  torch.jit.CompilationUnit('''
3365  def test_while(a):
3366  print("\\")
3367  return a
3368  ''')
3369 
3370  def test_script_annotation(self):
3371  @torch.jit.script
3372  def foo(a):
3373  return a + a + a
3374  s = Variable(torch.rand(2))
3375  self.assertEqual(s + s + s, foo(s))
3376 
3377  def test_inf(self):
3378  @torch.jit.script
3379  def foo(a):
3380  return a < float('inf')
3381  s = torch.rand(1)
3382  self.assertTrue(foo(s))
3383 
3384  @torch.jit.script
3385  def bar(a):
3386  return a > float('-inf')
3387  s = torch.rand(1)
3388  self.assertTrue(bar(s))
3389 
3390  def test_add(self):
3391  def func(a, b):
3392  c = a + b
3393  c += a
3394  return c
3395 
3396  a = torch.rand(1, requires_grad=True)
3397  b = torch.rand(1, requires_grad=True)
3398  self.checkScript(func, (a, b), optimize=True)
3399 
3400  def test_mul(self):
3401  def func(a, b):
3402  return a * b
3403 
3404  a = torch.rand(1, requires_grad=True)
3405  b = torch.rand(1, requires_grad=True)
3406  self.checkScript(func, (a, b), optimize=True)
3407 
3408  @unittest.skipIf(not PY35, "Python 3.5 needed")
3409  def test_matmul_py3(self):
3410  code = dedent("""
3411  def fn(a, b):
3412  return a @ b
3413  """)
3414 
3415  with tempfile.TemporaryDirectory() as tmp_dir:
3416  script_path = os.path.join(tmp_dir, 'script.py')
3417  with open(script_path, 'w') as f:
3418  f.write(code)
3419  fn = get_fn('test_matmul_py3', script_path)
3420 
3421  a = torch.rand(4, 3, requires_grad=True)
3422  b = torch.rand(3, 2, requires_grad=True)
3423  self.checkScript(fn, (a, b), optimize=True)
3424 
3425  def test_pow(self):
3426  def func(a, b):
3427  return a ** b
3428 
3429  def func2(a, b, c, d):
3430  return c + a ** b ** d
3431 
3432  a = torch.rand(1, requires_grad=True)
3433  b = torch.rand(1, requires_grad=True)
3434  c = torch.rand(1, requires_grad=True)
3435  d = torch.rand(1, requires_grad=True)
3436  self.checkScript(func, (a, b), optimize=True)
3437  self.checkScript(func2, (a, b, c, d), optimize=True)
3438 
3439  def test_triple(self):
3440  def func(x):
3441  return 3. * x
3442 
3443  x = torch.rand(1, dtype=torch.float, requires_grad=True)
3444  self.checkScript(func, [x], optimize=True)
3445 
3446  def test_slice(self):
3447  def func(x):
3448  return x[:5]
3449 
3450  x = torch.rand(10, dtype=torch.float, requires_grad=True)
3451  self.checkScript(func, [x], optimize=True)
3452 
3453  def func2(x):
3454  return x[5:]
3455 
3456  self.checkScript(func2, [x], optimize=True)
3457 
3458  def test_gather(self):
3459  def func(x):
3460  return x[0]
3461 
3462  x = torch.rand(10, dtype=torch.float, requires_grad=True)
3463  self.checkScript(func, [x], optimize=True)
3464 
3465  def test_random(self):
3466  @torch.jit.script
3467  def f(mean, std):
3468  return torch.normal(mean, std)
3469 
3470  mean, std = torch.zeros(5, 5), torch.ones(5, 5)
3471  with torch.random.fork_rng(devices=[]):
3472  output = torch.normal(mean, std)
3473  with torch.random.fork_rng(devices=[]):
3474  script_output = f(mean, std)
3475  self.assertEqual(output, script_output)
3476 
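 # Helper: run `code_str` through plain Python exec and through
 # torch.jit.CompilationUnit, then check that the function named `fn_name`
 # produces the same result along both paths.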
3477  def _check_code(self, code_str, fn_name, inputs):
3478  scope = {}
3479  exec(code_str, globals(), scope)
3480  cu = torch.jit.CompilationUnit(code_str)
 3481  self.assertEqual(getattr(cu, fn_name)(*inputs), scope[fn_name](*inputs))
3482 
3483  @unittest.skipIf(not RUN_CUDA, 'no CUDA')
3484  def test_scriptmodule_releases_tensors_cuda(self):
3485  @torch.jit.script
3486  def fn(x, y):
3487  return x.sigmoid() * y.tanh()
3488 
3489  def test(backward=False):
3490  x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
3491  y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
3492  out = fn(x, y)
3493  if backward:
3494  out.sum().backward()
3495 
3496  with self.assertLeaksNoCudaTensors():
3497  test()
3498  test()
3499  test()
3500 
3501  with self.assertLeaksNoCudaTensors():
3502  test(backward=True)
3503  test(backward=True)
3504  test(backward=True)
3505 
3506  def test_index(self):
3507  def consec(size, start=0):
3508  numel = torch.tensor(size).prod().item()
3509  return torch.arange(numel).view(size)
3510 
3511  def check_indexing(indexing, tensor):
3512  template = dedent("""
3513  def func(x):
3514  return x{}
3515  """)
3516 
3517  self._check_code(template.format(indexing), "func", [tensor])
3518 
3519  def check_dynamic_indexing(indexing, tensor, value1, value2):
3520  value1 = torch.tensor(value1)
3521  value2 = torch.tensor(value2)
3522 
3523  template = dedent("""
3524  def func(x, value1, value2):
3525  i = int(value1)
3526  j = int(value2)
3527  return x{}
3528  """)
3529 
3530  self._check_code(template.format(indexing), "func", [tensor, value1, value2])
3531 
3532  # basic slices
3533  check_indexing('[0]', consec((3, 3)))
3534  check_indexing('[1]', consec((3, 3), 10))
3535  check_indexing('[2]', consec((3, 3), 19))
3536  check_indexing('[2]', consec((3,)))
3537  check_indexing('[-1]', consec((3, 3), 19))
3538  check_indexing('[0:2]', consec((3, 3, 3)))
3539  check_indexing('[1:-1]', consec((3, 3, 3)))
3540  check_indexing('[-3:-1]', consec((6, 3)))
3541  check_indexing('[1:]', consec((3, 3)))
3542  check_indexing('[:1]', consec((3, 3)))
3543  check_indexing('[:]', consec((3, 2)))
3544 
3545  # multi-dim: indexes
3546  check_indexing('[0, 1]', consec((3, 3)))
3547  check_indexing('[0, 1]', consec((3, 3, 2)))
3548  check_indexing('[1, 0, 2]', consec((3, 3, 3)))
3549  check_indexing('[2, -1]', consec((3, 3)))
3550 
3551  # multi-dim: mixed slicing and indexing
3552  check_indexing('[0, 1:2]', consec((3, 3)))
3553  check_indexing('[0, :1]', consec((3, 3, 2)))
3554  check_indexing('[1, 2:]', consec((3, 3, 3)))
3555  check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
3556  check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
3557  check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
3558  check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
3559  check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
3560 
3561  # zero-sized slices
3562  check_indexing('[0:0]', consec((2, 2)))
3563  check_indexing('[0:0, 1]', consec((3, 3)))
3564 
3565  # trivial expression usage
3566  check_indexing('[1+1]', consec((3, 3)))
3567  check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
3568 
3569  # dynamic expression usage
3570  check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
3571  check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
3572 
3573  def test_tensor_item(self):
3574  def test_scalar_to_float_coercion(x):
3575  return x.item() == 1
3576 
3577  self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1.0),))
3578  self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1),))
3579 
3580  def test_scalar_cast(x):
3581  scalar = x.item()
3582  return int(scalar), float(scalar)
3583 
 3584  self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
 3585  self.checkScript(test_scalar_cast, (torch.tensor(1),))
3586 
3587  expected_str = r"Use int\(tensor\) or float\(tensor\) to retrieve"
3588  with self.assertRaisesRegex(RuntimeError, expected_str):
3589  @torch.jit.script
3590  def int_fn(a):
3591  # type: (int) -> int
3592  return a
3593 
3594  @torch.jit.script
3595  def test_error_msg(x):
3596  return int_fn(x.item())
3597 
3598  def test_method_on_number(self):
3599  def func():
3600  c = 1
3601  return c.add(1)
3602  with self.assertRaisesRegex(RuntimeError, 'Cannot call methods on numbers'):
3603  torch.jit.script(func)
3604 
3605  # testing implicit conversion of tensors to scalars to match function arguments
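 # e.g. Tensor.unsqueeze expects an int dim, so calling x.unsqueeze(t) with a
 # zero-dim tensor t requires the compiler to insert an ImplicitTensorToNum cast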
3606  def test_scalar_to_num_conversions(self):
3607  @torch.jit.script
3608  def multiple_defs(x):
3609  c = 1
3610  x = x + c
3611  return x
3612 
3613  self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))
3614 
3615  @torch.jit.script
3616  def tensor_to_int_script(x, tensor):
3617  return x.unsqueeze(tensor)
3618 
3619  def tensor_to_int(x, tensor):
3620  return x.unsqueeze(tensor)
3621 
3622  @torch.jit.script
3623  def tensor_to_float_script(x, tensor):
3624  return x.addcmul(tensor, tensor, value=tensor)
3625 
3626  def tensor_to_float(x, tensor):
3627  return x.addcmul(tensor, tensor, value=tensor)
3628 
3629  x = torch.zeros(10)
3630  # float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
3631  tensors = [torch.tensor(1.1),
3632  torch.tensor(1.1, requires_grad=True),
3633  torch.tensor(0),
3634  torch.tensor([2])]
3635 
3636  script_funs = [tensor_to_int_script, tensor_to_float_script]
3637  funs = [tensor_to_int, tensor_to_float]
3638 
 3639  # return the result, or True if an exception was thrown
3640  def test_func(func, x, tensor):
3641  try:
3642  result = func(x, tensor)
 3643  except (RuntimeError, TypeError):
 3644  result = True
3647  return result
3648 
 3649  # assert that the result (or exception behavior) matches for each (function, input) pair
3650  for tensor in tensors:
3651  for i in range(len(script_funs)):
3652  self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
3653 
3654  def test_tuple_to_opt_list(self):
3655  @torch.jit.script
3656  def foo(x):
3657  # type: (Optional[List[int]]) -> int
3658  return 1
3659 
3660  @torch.jit.script
3661  def tuple_call():
3662  return foo((1, 2))
3663 
3664  def test_advancedindex(self):
3665  def consec(size, start=0):
3666  numel = torch.tensor(size).prod().item()
3667  return torch.arange(numel).view(size)
3668 
3669  def check_indexing(indexing, tensor, **kwargs):
3670  indices_dict = kwargs
3671 
3672  template = dedent("""
3673  def func(x{formals}):
3674  return x{expr}
3675  """)
3676 
3677  formals = []
3678  values = []
3679  for formal, value in indices_dict.items():
3680  formals.append(formal)
3681  values.append(value)
3682 
3683  formals = ''.join(map(', {}'.format, formals))
3684  inputs = [tensor] + values
3685  self._check_code(template.format(formals=formals, expr=indexing),
3686  "func", inputs)
3687 
3688  # Indexing with tensor (basic)
3689  check_indexing('[i]', consec((3, 3)), i=torch.tensor([0]))
3690  check_indexing('[i]', consec((3, 3)), i=torch.tensor(1))
3691  check_indexing('[i]', consec((3, 3)), i=torch.tensor([-2]))
3692  check_indexing('[i]', consec((3, 3), 2), i=torch.tensor([0, 0]))
3693  check_indexing('[i]', consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))
3694 
3695  # NB: indexing with tensors and indexing with sequences can be implemented
3696  # in a very similar way (sequences are converted to tensors), so only one
3697  # case needs to be tested extensively.
3698  # XXX: When we can index with sequences, replace these cases with
3699  # sequence indexing expressions; those are much easier to read.
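 # e.g. x[[0, 2]] selects the same rows as x[torch.tensor([0, 2])]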
3700 
3701  # Misc sequence advanced indexing
3702  inp = consec((4, 8, 5))
3703  to_check = [
3704  # [[0, 2], [1, 3]]
3705  ['[i, j]', {'i': [0, 2], 'j': [1, 3]}],
3706  # [[0, 2], [1, 3], [1, 1]]
3707  ['[i, j, k]', {'i': [0, 2], 'j': [1, 3], 'k': [1, 1]}],
3708  # [[0, 2], 1, [1, 1]]
3709  ['[i, j, k]', {'i': [0, 2], 'j': 1, 'k': [1, 1]}],
3710  # [:, :, [0, 3, 4]]
3711  ['[:, :, i]', {'i': [0, 3, 4]}],
3712  # [:, [2, 4, 5, 7], 2:4]
3713  ['[:, i, 2:4]', {'i': [0, 2, 3]}],
3714  # [[2, 3], :, :]
3715  ['[i, :, :]', {'i': [2, 3]}],
3716  # [:, [0, 2, 3], [1, 3, 4]]
3717  ['[:, i, j]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
3718  # [:, [0], [1, 2, 4]]
3719  ['[:, i, j]', {'i': [0], 'j': [1, 2, 4]}],
3720  # [:, [0, 1, 3], [4]]
3721  ['[:, i, j]', {'i': [0, 1, 3], 'j': [4]}],
3722  # [:, [[0, 1], [1, 0]], [[2, 3]]]
3723  ['[:, i, j]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
3724  # [:, [[0, 1], [2, 3]], [[0]]]
3725  ['[:, i, j]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
3726  # [:, [[5, 6]], [[0, 3], [4, 4]]]
3727  ['[:, i, j]', {'i': [[5, 6]], 'j': [[0, 3], [4, 4]]}],
3728  # [[0, 2, 3], [1, 3, 4], :]
3729  ['[i, j, :]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
3730  # [0, [1, 2, 4], :]
3731  ['[i, j, :]', {'i': 0, 'j': [1, 2, 4]}],
3732  # [[0, 1, 3], 4, :]
3733  ['[i, j, :]', {'i': [0, 1, 3], 'j': 4}],
3734  # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :]
3735  ['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 1], [3, 5]]}],
3736  # [[[0, 1], [1, 0]], [[2, 3]], :]
3737  ['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
3738  # [[[0, 1], [2, 3]], [[0]], :]
3739  ['[i, j, :]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
3740  # [[[2, 1]], [[0, 3], [4, 4]], :]
3741  ['[i, j, :]', {'i': [[2, 1]], 'j': [[0, 3], [4, 4]]}],
3742  # [[[2]], [[0, 3], [4, 1]], 0:2]
3743  ['[i, j, 0:2]', {'i': [[2]], 'j': [[0, 3], [4, 1]]}],
3744  ]
3745 
3746  for expr, argdict in to_check:
3747  tensordict = {k: torch.tensor(v) for (k, v) in argdict.items()}
3748  check_indexing(expr, inp, **tensordict)
3749 
3750  def test_keyword(self):
3751  @torch.jit.script
3752  def func(x):
3753  return torch.sum(x, dim=0)
3754 
3755  x = torch.rand(10, dtype=torch.float, requires_grad=True)
3756  y = func(x)
3757  y2 = torch.sum(x, dim=0)
3758  self.assertEqual(y, y2)
3759 
3760  def test_constant_pooling_none(self):
3761  @torch.jit.script
3762  def typed_nones(a=None, b=None, c=None):
3763  # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]] # noqa
3764  return a, b, c
3765 
3766  @torch.jit.script
3767  def test(a):
3768  # type: (bool) -> None
3769  if a:
3770  print(typed_nones())
3771  else:
3772  print(typed_nones())
3773 
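 # Constant pooling should deduplicate the constants emitted by the two
 # identical branches, so each typed None constant appears exactly once.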
3774  graph_str = str(test.graph)
3775  self.assertTrue(graph_str.count("bool? = prim::Constant") == 1)
3776  self.assertTrue(graph_str.count("int? = prim::Constant") == 1)
3777  self.assertTrue(graph_str.count("None = prim::Constant") == 1)
3778 
3779  def test_literal(self):
3780  def func1(a, b):
3781  c = a, b
3782  d, e = c
3783  return d + e
3784 
3785  def func2(a, b):
3786  c = a, (a, b)
3787  d, e = c
3788  f, g = e
3789  return d + f + g
3790 
3791  def func3(a, b):
3792  # type: (float, float) -> float
3793  c = 0., (0., 0.)
3794  x = True
3795  while x:
3796  x = False
3797  c = a, (a, b)
3798  d, e = c
3799  f, g = e
3800  return d + f + g
3801 
3802  a = torch.rand(1, requires_grad=True)
3803  b = torch.rand(1, requires_grad=True)
3804  self.checkScript(func1, (a, b), optimize=True)
3805  self.checkScript(func2, (a, b), optimize=True)
3806  self.checkScript(func3, (a.item(), b.item()), optimize=True)
3807 
3808  def test_expand(self):
3809  @torch.jit.script
3810  def func(x, y):
3811  return x + y
3812 
3813  x = torch.rand(2, 3, dtype=torch.float, requires_grad=True)
3814  y = torch.rand(3, dtype=torch.float, requires_grad=True)
3815  out = func(x, y)
3816  self.assertEqual(func(x, y), x + y)
3817 
3818  grad = torch.randn(2, 3, dtype=torch.float)
3819  out.backward(grad)
3820  self.assertEqual(x.grad, grad)
3821  self.assertEqual(y.grad, grad.sum(dim=0))
3822 
3823  def test_sum(self):
3824  @torch.jit.script
3825  def func(x):
3826  return x.sum(dim=[4])
3827 
3828  @torch.jit.script
3829  def func2(x):
3830  return x.sum(dim=4)
3831 
 3832  # test that shape analysis correctly handles sum's IntArrayRef[1] dim argument
3833  self.run_pass('constant_propagation', func.graph)
3834  self.run_pass('constant_propagation', func2.graph)
3835  torch._C._jit_pass_shape_analysis(
3836  func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
3837  torch._C._jit_pass_shape_analysis(
3838  func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
3839  self.assertTrue(func.graph.findNode("aten::sum").output().type().kind()
3840  == "DimensionedTensorType")
3841  self.assertTrue(func2.graph.findNode("aten::sum").output().type().kind()
3842  == "DimensionedTensorType")
3843 
3844  def test_cat(self):
3845  @torch.jit.script
3846  def func(x):
3847  return torch.cat((x, x), dim=0)
3848 
3849  x = torch.rand(10, dtype=torch.float, requires_grad=True)
3850  self.assertEqual(func(x), torch.cat((x, x), dim=0))
3851 
3852  @torch.jit.script
3853  def func2(x, y):
3854  return torch.cat((x, x), y)
3855 
3856  x = torch.rand([2, 2])
3857  y = torch.tensor(1)
3858  self.assertEqual(func2(x, y), torch.cat((x, x), y))
3859 
3860  def test_cat_lifts(self):
3861  @torch.jit.script
3862  def foo(x):
3863  return torch.cat([x, x], dim=1)
3864 
3865  @torch.jit.script
3866  def foo2(x):
3867  return torch.cat([], dim=1)
3868 
3869  @torch.jit.script
3870  def foo3(x):
3871  return torch.cat([x], dim=1)
3872 
3873  for g in [foo.graph, foo2.graph, foo3.graph]:
3874  FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g))
3875 
3876  def test_list_literal(self):
3877  def reassign():
3878  x = [1]
3879  if True:
3880  x = [2, 3]
3881  return
3882  self.checkScript(reassign, (), optimize=False)
3883 
3884  def reassign_arity_change():
3885  x = [1]
3886  if True:
3887  x = [1, 2, 3]
3888  return
3889  self.checkScript(reassign_arity_change, (), optimize=False)
3890 
3891  def reassign_from_empty_literal():
3892  x = []
3893  if True:
3894  x = [1, 2, 3]
3895  return
3896  with self.assertRaisesRegex(RuntimeError, r"previously has type Tensor\[\]"):
3897  self.checkScript(reassign_from_empty_literal, (), optimize=False)
3898 
3899  def reassign_from_empty_builtin():
3900  x = torch.jit.annotate(List[int], [])
3901  if True:
3902  x = [1, 2, 3]
3903  y = torch.jit.annotate(List[float], [])
3904  if True:
3905  y = [1.0, 2.0, 3.0]
3906  z = []
3907  if True:
3908  z = [torch.randn([1])]
3909  return
3910  self.checkScript(reassign_from_empty_builtin, (), optimize=False)
3911 
3912  def reassign_bad_type():
3913  x = [1]
3914  if True:
3915  x = [1.0]
3916  return
3917  with self.assertRaisesRegex(RuntimeError, "previously has type"):
3918  self.checkScript(reassign_bad_type, (), optimize=False)
3919 
3920  def reassign_nested():
3921  x = torch.jit.annotate(List[int], [])
3922  if True:
3923  x = [1, 2, 3]
3924  if True:
3925  x = [1.0]
3926  return
3927  with self.assertRaisesRegex(RuntimeError, "previously has type"):
3928  self.checkScript(reassign_nested, (), optimize=False)
3929 
3930  def test_list_gather(self):
3931  def index():
3932  a = [1, 2, 3]
3933  return a[1]
3934 
3935  self.checkScript(index, ())
3936 
3937  def negative_index():
3938  a = [1, 2, 3]
3939  return a[-1]
3940 
3941  self.checkScript(negative_index, ())
3942 
3943  def bad_index():
3944  a = [1, 2, 3]
3945  return a[4]
3946 
3947  self.checkScriptRaisesRegex(bad_index, (), IndexError,
3948  "list index out of range")
3949 
3950  def bad_negative_index():
3951  a = [1, 2, 3]
3952  return a[-5]
3953 
3954  self.checkScriptRaisesRegex(bad_negative_index, (), IndexError,
3955  "list index out of range")
3956 
3957  def test_tensor_len(self):
3958  def func(x):
3959  return len(x)
3960 
3961  self.checkScript(func, [torch.ones(4, 5, 6)])
3962 
3963  def test_list_len(self):
3964  def func():
3965  a = [1, 2, 3]
3966  return len(a) == 3
3967 
3968  self.checkScript(func, ())
3969 
3970  def func2():
3971  a = []
3972  return len(a) == 0
3973 
3974  self.checkScript(func2, ())
3975 
3976  def test_list_ops(self):
3977  def test_equality():
3978  a = [1, 2, 3]
3979  b = [1, 2, 3]
3980  return a == b
3981 
3982  self.checkScript(test_equality, (), optimize=True)
3983 
3984  def test_inequality():
3985  a = [1, 2, 3]
3986  b = [1, 2, 3]
3987  return a != b
3988 
 3989  self.checkScript(test_inequality, (), optimize=True)
3990 
3991  def test_non_equality():
3992  a = [1, 2, 3]
3993  b = [3]
3994  return a == b
3995 
3996  self.checkScript(test_non_equality, (), optimize=True)
3997 
3998  def test_non_inequality():
3999  a = [1, 2, 3]
4000  b = [3]
4001  return a != b
4002 
 4003  self.checkScript(test_non_inequality, (), optimize=True)
4004 
4005  def test_list_equality_as_cond():
4006  a = [1, 2, 3]
4007  b = [3]
4008  if a == b:
4009  c = 1
4010  else:
4011  c = 2
4012  return c
4013 
4014  self.checkScript(test_list_equality_as_cond, (), optimize=True)
4015 
4016  def test_list_add():
4017  a = [1, 2, 3]
4018  b = [2]
4019  c = a + b
4020  return c == [1, 2, 3, 2]
4021 
4022  self.checkScript(test_list_add, (), optimize=True)
4023 
4024  def test_list_add_empty():
4025  a = [1, 2, 3]
4026  b = torch.jit.annotate(List[int], [])
4027  c = a + b
4028  return c == [1, 2, 3]
4029 
4030  self.checkScript(test_list_add_empty, (), optimize=True)
4031 
4032  def test_tensor_list_equality():
4033  t1 = torch.ones([1, 1])
4034  t2 = torch.ones([1, 1])
4035  x = [t1, t2]
4036  y = [t2, t1]
4037  return x == y
4038 
4039  self.checkScript(test_tensor_list_equality, (), optimize=True)
4040 
4041  def test_invalid_list_equality():
4042  t1 = torch.ones([2, 2])
4043  t2 = torch.ones([2, 2])
4044  x = [t1, t2]
4045  y = [t2, t1]
4046  # will throw since the tensors have more than one element
4047  return x == y
4048 
 4049  self.checkScriptRaisesRegex(
 4050  test_invalid_list_equality,
4051  (),
4052  RuntimeError,
4053  "bool value of Tensor")
4054 
4055  def test_list_slice(self):
4056  def test_regular_slice():
4057  a = [0, 1, 2, 3, 4]
4058  return a[2:3] == [2]
4059  self.checkScript(test_regular_slice, ())
4060 
4061  def test_open_ended_slice():
4062  a = [0, 1, 2, 3, 4]
4063  return a[2:] == [2, 3, 4]
4064  self.checkScript(test_open_ended_slice, ())
4065 
4066  def test_open_ended_slice2():
4067  a = [0, 1, 2, 3, 4]
4068  return a[:2] == [0, 1]
4069  self.checkScript(test_open_ended_slice2, ())
4070 
4071  def test_negative_slice():
4072  a = [0, 1, 2, 3, 4]
4073  return a[:-1] == [0, 1, 2, 3]
4074  self.checkScript(test_negative_slice, ())
4075 
4076  def test_negative_slice2():
4077  a = [0, 1, 2, 3, 4]
4078  return a[-3:-1] == [2, 3]
4079  self.checkScript(test_negative_slice2, ())
4080 
4081  def test_backward_slice():
4082  a = [0, 1, 2, 3, 4]
4083  return a[3:2] == torch.jit.annotate(List[int], [])
4084  self.checkScript(test_backward_slice, ())
4085 
4086  def test_over_slice():
4087  a = [0, 1, 2, 3, 4]
4088  return a[3:10] == [3, 4]
 4089  self.checkScript(test_over_slice, ())
4090 
4091  def test_mutable_list_append(self):
4092  def test_append():
4093  a = [0, 1]
4094  a.append(2)
4095  a.append(3)
4096  return a == [0, 1, 2, 3]
4097  self.checkScript(test_append, ())
4098 
4099  def test_mutable_list_append_2(self):
4100  def test_append_2():
4101  a = [0, 1]
4102  a.append(2)
4103  a = [1]
4104  a.append(4)
4105  return a == [1, 4]
4106  self.checkScript(test_append_2, ())
4107 
4108  def test_mutable_list_append_if(self):
4109  def test_append_if():
4110  a = [1]
4111  if True:
4112  a.append(4)
4113  return a == [1, 4]
4114  self.checkScript(test_append_if, ())
4115 
4116  def test_mutable_list_append_if_else(self):
4117  def test_append_if_else():
4118  a = [1]
4119  if False:
4120  a.append(4)
4121  else:
4122  a.append(10)
4123  return a == [1, 10]
4124  self.checkScript(test_append_if_else, ())
4125 
4126  def test_mutable_list_append_loop(self):
4127  def test_append_loop():
4128  a = torch.jit.annotate(List[int], [])
4129  for i in range(5):
4130  a.append(i)
4131 
4132  return a == [0, 1, 2, 3, 4]
4133  self.checkScript(test_append_loop, ())
4134 
4135  def test_mutable_list_append_loop_if(self):
4136  def test_append_loop_if():
4137  a = torch.jit.annotate(List[int], [])
4138  for i in range(5):
4139  if i > 3:
4140  a.append(i)
4141  else:
4142  a.append(0)
4143 
4144  return a == [0, 0, 0, 0, 4]
4145  self.checkScript(test_append_loop_if, ())
4146 
4147  def test_mutable_list_nested_loop(self):
4148  def test_nested_loop():
4149  a = torch.jit.annotate(List[int], [])
4150  for i in range(2):
4151  for j in range(2):
4152  a.append(i + j)
4153 
4154  return a == [0, 1, 1, 2]
4155  self.checkScript(test_nested_loop, ())
4156 
4157  def test_mutable_list_function_inline(self):
4158  @torch.jit.script
4159  def bar(y):
4160  # type: (List[int]) -> None
4161  y.append(4)
4162 
4163  @torch.jit.script
4164  def foo():
4165  x = [1, 2, 3]
4166  bar(x)
4167  return x
4168 
4169  self.assertEqual(foo(), [1, 2, 3, 4])
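 # Script lists are reference types: the inlined call to bar mutates the
 # caller's list in place, which is why foo observes the appended element.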
4170 
4171  def test_mutable_list_reverse_empty(self):
4172  def test_reverse_empty():
4173  a = []
4174  a.reverse()
4175 
4176  return a == []
4177  self.checkScript(test_reverse_empty, ())
4178 
4179  def test_mutable_list_reverse(self):
4180  def test_reverse():
4181  a = [1, 2, 3, 4]
4182  a.reverse()
4183 
4184  return a == [4, 3, 2, 1]
4185  self.checkScript(test_reverse, ())
4186 
4187  def test_mutable_tensor_list_reverse(self):
4188  def test_tensor_reverse():
4189  a = [torch.tensor(1), torch.tensor(2)]
4190  a.reverse()
4191 
4192  return a == [torch.tensor(2), torch.tensor(1)]
4193  self.checkScript(test_tensor_reverse, ())
4194 
4195  def test_mutable_list_pop_empty(self):
4196  @torch.jit.script
4197  def test_pop_empty():
4198  a = torch.jit.annotate(List[int], [])
4199  return a.pop()
4200 
4201  with self.assertRaisesRegex(RuntimeError, "pop from empty list"):
4202  test_pop_empty()
4203 
4204  def test_mutable_list_pop(self):
4205  def test_pop():
4206  a = [1, 2, 3, 4]
4207  b = a.pop()
4208 
4209  return b == 4
4210 
4211  self.checkScript(test_pop, ())
4212 
4213  def test_mutable_list_pop2(self):
4214  def test_pop2():
4215  a = [1, 2, 3, 4]
4216  b = a.pop()
4217 
4218  return len(a) == 3
4219 
4220  self.checkScript(test_pop2, ())
4221 
4222  def test_mutable_list_pop_at(self):
4223  def test_pop_at():
4224  a = [1, 2, 3, 4]
4225  b = a.pop(1)
4226 
4227  return b == 2
4228 
4229  self.checkScript(test_pop_at, ())
4230 
4231  def test_mutable_list_pop_at2(self):
4232  def test_pop_at2():
4233  a = [1, 2, 3, 4]
4234  b = a.pop(1)
4235 
4236  return len(a) == 3
4237 
4238  self.checkScript(test_pop_at2, ())
4239 
4240  def test_mutable_list_pop_at_negative(self):
4241  def test_pop_at_negative():
4242  a = [1, 2, 3, 4]
4243  b = a.pop(-2)
4244 
4245  return b == 3
4246 
4247  self.checkScript(test_pop_at_negative, ())
4248 
4249  def test_mutable_list_pop_at_negative2(self):
4250  def test_pop_at_negative2():
4251  a = [1, 2, 3, 4]
4252  b = a.pop(-2)
4253 
4254  return len(a) == 3
4255 
4256  self.checkScript(test_pop_at_negative2, ())
4257 
4258  def test_mutable_list_pop_slice(self):
4259  def test_pop_slice():
4260  a = [1, 2, 3, 4]
4261  b = [1, 2, 3, 4]
4262 
4263  a.pop()
4264  b = b[:-1]
4265 
4266  return a == b
4267 
4268  self.checkScript(test_pop_slice, ())
4269 
4270  @unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3")
4271  def test_mutable_list_clear_empty(self):
4272  def test_clear_empty():
4273  a = torch.jit.annotate(List[int], [])
4274  a.clear()
4275 
4276  return len(a) == 0
4277  self.checkScript(test_clear_empty, ())
4278 
4279  @unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3")
4280  def test_mutable_list_clear(self):
4281  def test_clear():
4282  a = [1, 2, 3, 4]
4283  a.clear()
4284 
4285  return len(a) == 0
4286  self.checkScript(test_clear, ())
4287 
4288  def test_mutable_list_insert(self):
4289  def test_list_insert():
4290  a = [1, 2, 3, 4]
4291  a.insert(2, 5)
4292 
4293  return a == [1, 2, 5, 3, 4]
4294  self.checkScript(test_list_insert, ())
4295 
4296  def test_mutable_list_insert_negative(self):
4297  def test_list_insert_negative():
4298  a = [1, 2, 3, 4]
4299  a.insert(-1, 5)
4300 
4301  return a == [1, 2, 3, 5, 4]
4302  self.checkScript(test_list_insert_negative, ())
4303 
4304  def test_mutable_list_insert_neg_out_of_bounds(self):
4305  def test_list_insert_neg_out_of_bounds():
4306  a = [1, 2, 3, 4]
4307  a.insert(-10, 5)
4308 
4309  return a == [5, 1, 2, 3, 4]
4310  self.checkScript(test_list_insert_neg_out_of_bounds, ())
4311 
4312  def test_mutable_list_insert_out_of_bounds(self):
4313  def test_list_insert_out_of_bounds():
4314  a = [1, 2, 3, 4]
4315  a.insert(10, 5)
4316 
4317  return a == [1, 2, 3, 4, 5]
4318  self.checkScript(test_list_insert_out_of_bounds, ())
4319 
4320  def test_mutable_list_remove_not_existing(self):
4321  @torch.jit.script
4322  def test_list_remove_not_existing():
4323  a = [1, 2, 3, 4]
4324  a.remove(5)
4325 
4326  return a
4327 
4328  with self.assertRaisesRegex(RuntimeError, "x not in list"):
4329  test_list_remove_not_existing()
4330 
4331  def test_mutable_list_remove(self):
4332  def test_list_remove():
4333  a = [1, 2, 3, 4]
4334  a.remove(3)
4335 
4336  return a == [1, 2, 4]
4337  self.checkScript(test_list_remove, ())
4338 
4339  def test_list_index_not_existing(self):
4340  @torch.jit.script
4341  def list_index_not_existing():
4342  a = [4, 1, 3, 2]
4343  i = a.index(5)
4344 
4345  return i
4346 
4347  with self.assertRaisesRegex(RuntimeError, "'5' is not in list"):
4348  list_index_not_existing()
4349 
4350  def test_list_index(self):
4351  def list_index():
4352  a = [4, 1, 3, 2]
4353  i = a.index(3)
4354 
4355  return i == 2
4356  self.checkScript(list_index, ())
4357 
4358  def test_tensor_list_index(self):
4359  def tensor_list_index():
4360  a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
4361  i = a.index(torch.tensor(3))
4362 
4363  return i == 2
4364  self.checkScript(tensor_list_index, ())
4365 
4366  def test_tensor_list_index_not_existing(self):
4367  @torch.jit.script
4368  def tensor_list_index_not_existing():
4369  a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
4370  i = a.index(torch.tensor(5))
4371 
4372  return i
4373 
4374  with self.assertRaisesRegex(RuntimeError, "is not in list"):
4375  tensor_list_index_not_existing()
4376 
4377  def test_list_count(self):
4378  def list_count():
4379  a = [4, 1, 4, 2, 4]
4380  i = a.count(4)
4381 
4382  return i == 3
4383  self.checkScript(list_count, ())
4384 
4385  def test_list_count_not_existing(self):
4386  def list_count_not_existing():
4387  a = [4, 1, 4, 2, 4]
4388  i = a.count(5)
4389 
4390  return i == 0
4391  self.checkScript(list_count_not_existing, ())
4392 
4393  def test_tensor_list_count(self):
4394  def tensor_list_count():
4395  a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
4396  i = a.count(torch.tensor(4))
4397 
4398  return i == 3
4399  self.checkScript(tensor_list_count, ())
4400 
4401  def test_tensor_list_count_not_existing(self):
4402  def tensor_list_count_not_existing():
4403  a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
4404  i = a.count(torch.tensor(5))
4405 
4406  return i == 0
4407  self.checkScript(tensor_list_count_not_existing, ())
4408 
4409  def test_mutable_list_remove_tensor(self):
4410  def test_list_remove_tensor():
4411  a = [torch.ones(1), torch.zeros(1), torch.ones(2)]
4412  a.remove(torch.zeros(1))
4413 
4414  return len(a) == 2
4415  self.checkScript(test_list_remove_tensor, ())
4416 
4417  def test_mutable_list_remove2(self):
4418  def test_list_remove2():
4419  a = [1]
4420  a.remove(1)
4421 
4422  return len(a) == 0
4423  self.checkScript(test_list_remove2, ())
4424 
4425  def test_extend_list_mutable(self):
4426  @torch.jit.script
4427  def extend_list(a, b):
4428  # type: (List[Tensor], List[Tensor]) -> List[Tensor]
4429 
4430  a.extend(b)
4431  return a
4432 
4433  for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4434  for r in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4435  self.assertEqual(extend_list(l, r), l + r)
4436 
4437  def test_extend_list_immutable(self):
4438  @torch.jit.script
4439  def extend_list(a, b):
4440  # type: (List[int], List[int]) -> List[int]
4441 
4442  a.extend(b)
4443  return a
4444 
4445  for l in [[], [1], [1, 2, 3]]:
4446  for r in [[], [1], [1, 2, 3]]:
4447  self.assertEqual(extend_list(l, r), l + r)
4448 
4449  def test_copy_list_mutable(self):
4450  @torch.jit.script
4451  def copy_list(a):
4452  # type: (List[Tensor]) -> List[Tensor]
4453  return a.copy()
4454 
4455  for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4456  self.assertEqual(copy_list(l), l)
4457 
4458  def test_copy_list_immutable(self):
4459  @torch.jit.script
4460  def copy_list(a):
4461  # type: (List[int]) -> List[int]
4462  return a.copy()
4463 
4464  for l in [[], [1], [1, 2, 3]]:
4465  self.assertEqual(copy_list(l), l)
4466 
4467  def test_func_call(self):
4468  script = '''
4469  def add(a, b):
4470  return a + b
4471 
4472  def mul(a, x):
4473  return a * x
4474 
4475  def func(alpha, beta, x, y):
4476  return add(mul(alpha, x), mul(beta, y))
4477  '''
4478  alpha = torch.rand(1, dtype=torch.float, requires_grad=True)
4479  beta = torch.rand(1, dtype=torch.float, requires_grad=True)
4480  x = torch.rand(3, dtype=torch.float, requires_grad=True)
4481  y = torch.rand(3, dtype=torch.float, requires_grad=True)
4482  outputs = alpha * x + beta * y
4483  # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
4484  self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
4485 
4486  def test_resize_input_ops(self):
 4487  # resize_ and resize_as_ resize the input tensor in place. Because our shape
 4488  # analysis is flow invariant, we set any Tensor that can alias a resized
 4489  # Tensor to the base Tensor type, without size information.
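 # For example (a sketch): if y aliases x and x.resize_(...) appears anywhere
 # in the graph, shape analysis must forget the recorded sizes of both x and y.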
4490 
 4491  # test that a value which is an input of the graph gets handled
4492  def out_op_graph_input():
4493  @torch.jit.script
4494  def test(x, y, z):
4495  torch.mul(x, y, out=z)
4496  return z
4497 
4498  torch._C._jit_pass_shape_analysis(
4499  test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
4500  self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
4501  out_op_graph_input()
4502 
4503  def test_resize():
4504  @torch.jit.script
4505  def test(x):
4506  after_resize_alias = torch.zeros([2])
4507  for _i in range(5):
4508  b = x + 1
4509  f = [1]
4510  before_resize_alias = b.sub_(1)
4512  f.append(1)
4513  b.resize_(f)
4514  after_resize_alias = b.add_(1)
4515  return after_resize_alias
4516 
4517  g = test.graph
4518  self.run_pass('constant_propagation', g)
4519  torch._C._jit_pass_shape_analysis(
4520  g, (torch.zeros(1, 1),), False)
4521  resize_node = g.findNode("aten::resize_")
 4522  # the first input and the output of b.resize_ are both b
4523  self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
4524  self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())
4525 
4526  # correctly propagates to b alias set
4527  before_resize = g.findNode("aten::sub_")
4528  self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())
4529 
4530  after_resize = g.findNode("aten::add_")
4531  self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())
4532 
4533  test_resize()
4534 
4535  def test_resize_as():
4536  @torch.jit.script
4537  def test(x):
4538  b = torch.zeros([2, 2])
4539  b.resize_as_(x)
4540  return b
4541 
4542  g = test.graph
4543  self.run_pass('constant_propagation', g)
4544  torch._C._jit_pass_shape_analysis(
4545  g, (torch.zeros(1, 1),), False)
4546 
 4547  # x doesn't alias a resized tensor, so it shouldn't be set to the base Tensor type
4548  self.assertTrue(next(g.inputs()).type() != TensorType.get())
4549  # return is resized
4550  self.assertTrue(next(g.outputs()).type() == TensorType.get())
4551 
4552  test_resize_as()
4553 
4554  def test_view_shape_prop(self):
4555  cu = torch.jit.CompilationUnit('''
4556  def test_view_shape_prop(a):
4557  return a.view(size=[-1])
4558  ''')
4559  inputs = [torch.zeros(10, 10)]
4560  outputs = torch.zeros(100)
4561 
4562  real_outs = cu.test_view_shape_prop(*inputs)
4563  self.assertEqual(real_outs, outputs)
4564 
4565  def test_view_listconstruct_shape_prop(self):
4566  def fn(x):
4567  B = x.size(0)
4568  C = x.size(1)
4569  T = x.size(2)
4570  return x.view(T, B, C)
4571 
4572  x = torch.randn(3, 1, 5, requires_grad=True)
4573  graph = torch.jit.script(fn).graph
4574  torch._C._jit_pass_shape_analysis(graph, (x,), False)
 4575  output_kind = next(graph.outputs()).type().kind()
 4576  self.assertTrue(output_kind != 'TensorType')
4577 
4578  def test_integral_shape_inference(self):
4579  cu = torch.jit.CompilationUnit('''
4580  def test_integral_shape_inference(a):
4581  return a / a
4582  ''')
4583  inputs = [torch.ones(10, 10).type(torch.LongTensor)]
4584  outputs = torch.ones(10, 10)
4585 
4586  self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
4587 
4588  def test_fuser_multiple_blocks(self):
4589  cu = torch.jit.CompilationUnit('''
4590  def test_fuser_multiple_blocks(this, that, theother, meme):
4591  i = 0
4592  while i < 20:
4593  this = torch.cat([this, meme], dim=0)
4594  that = torch.cat([that, meme], dim=0)
4595  theother = torch.cat([theother, meme], dim=0)
4596  i = i + 1
4597  return this, that, theother
4598  ''')
4599 
4600  inputs = [torch.ones(0, 10, 10)] * 3
4601  inputs += [torch.ones(1, 10, 10)]
4602  outputs = [torch.ones(20, 10, 10)] * 3
4603 
4604  self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
4605 
4606  def test_dropout_script(self):
4607 
4608  eg = torch.zeros(1, 2, 3, requires_grad=True)
4609 
4610  @_trace(eg)
4611  def foo(x):
4612  x = torch.neg(x)
4613  return F.dropout(x)
4614 
4615  class MyDrop(nn.Module):
4616  def forward(self, x):
4617  return foo(x)
4618 
4619  f = io.BytesIO()
4620  torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
4621 
4622  @unittest.skip("RuntimeError: VariableType::ID() not implemented")
4623  def test_cast(self):
4624  script = '''
4625  def to_int(x):
4626  return int(x)
4627  '''
4628  x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
4629  out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
4630  self.checkScript(script, [x], optimize=True, outputs=[out], func='to_int')
4631 
4632  def test_python_frontend(self):
4633  def fn(x, y, z):
4634  q = None
4635  q = x + y - z.sigmoid()
4636  print(q)
4637  w = -z
4638  if not x and not y and z:
4639  m = x if not z else y
4640  while x < y > z:
4641  q = x
4642  assert 1 == 1, "hello"
4643  return x
4644 
 4645  ast = torch.jit.frontend.get_jit_def(fn)
 4646  self.assertExpected(str(ast))
4647 
4648  @unittest.skipIf(not PY2, "Requires python 2")
4649  def test_python_frontend_py2(self):
4650  def fn():
4651  raise Exception("hello")
 4652  ast = torch.jit.frontend.get_jit_def(fn)
 4653  self.assertExpected(str(ast))
4654 
4655  @unittest.skipIf(PY2, "Requires python 3")
4656  def test_python_frontend_py3(self):
4657  def fn():
4658  raise Exception("hello")
 4659  ast = torch.jit.frontend.get_jit_def(fn)
 4660  self.assertExpected(str(ast))
4661 
4662  def _make_scalar_vars(self, arr, dtype):
4663  return [torch.tensor(val, dtype=dtype) for val in arr]
4664 
4665  def test_string_print(self):
4666  def func(a):
4667  print(a, "a" 'b' '''c''' """d""", 2, 1.5)
4668  return a
4669 
4670  inputs = self._make_scalar_vars([1], torch.int64)
4671  self.checkScript(func, inputs, capture_output=True)
4672 
4673  def test_while(self):
4674  def func(a, b, max):
4675  while bool(a < max):
4676  a = a + 1
4677  b = b + 1
4678  c = a + b
4679  return c
4680 
4681  inputs = self._make_scalar_vars([1, 1, 10], torch.int64)
4682  self.checkScript(func, inputs, optimize=True)
4683 
4684  def test_fibb(self):
4685  def func(lim):
4686  first = 1
4687  second = 1
4688  i = 1
4689  somenum = 5
4690  dontmutateme = 3
4691  third = 0
4692  while bool(i < lim):
4693  third = first + second
4694  first = second
4695  second = third
4696  j = 0
4697  while j < 10:
4698  somenum = somenum * 2
4699  j = j + 1
4700  i = i + j
4701  i = i + dontmutateme
4702 
4703  st = second + third
4704  fs = first + second
4705  return third, st, fs
4706 
4707  inputs = self._make_scalar_vars([10], torch.int64)
4708  self.checkScript(func, inputs, optimize=True)
4709 
4710  def test_if(self):
4711  def func(a, b):
4712  # type: (int, int) -> int
4713  d = 3
4714  if bool(a > 10):
4715  a = 3 + d
4716  else:
4717  b = 3 + d
4718  d = 4
4719  c = a + b
4720  return c
4721 
4722  inputs = self._make_scalar_vars([1, -1], torch.int64)
4723  self.checkScript(func, inputs, optimize=True)
4724 
4725  def test_if_for_in_range(self):
4726  def func(a, b):
4727  # type: (int, int) -> int
4728  d = 3
4729  for _ in range(20):
4730  if bool(a > 10):
4731  a = 3 + d
4732  else:
4733  b = 3 + d
4734  d = 4
4735  c = a + b
4736  return d
4737  inputs = self._make_scalar_vars([1, -1], torch.int64)
4738  self.checkScript(func, inputs, optimize=True)
4739 
4740  def test_if_noelse(self):
4741  def func(a, b):
4742  if bool(a > 10):
4743  a = 3 + b
4744  c = a + b
4745  return c
4746 
4747  inputs = self._make_scalar_vars([-1, 1], torch.int64)
4748  self.checkScript(func, inputs, optimize=True)
4749 
4750  def test_if_is_none_dispatch(self):
4751 
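 # `is None` checks are dispatched statically where possible: branches that
 # are always or never taken collapse to a single constant, while genuinely
 # optional values emit a full if statement (see the constant counts below).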
4752  @torch.jit.script
4753  def test_lhs_none_rhs_none():
4754  # LHS, RHS both alwaysNone, dispatch always_none_branch
4755  # only emit one prim::Constant
4756  if None is None:
4757  return 1
4758  elif None is not None:
4759  return 2
4760  else:
4761  return 3
4762 
4763  self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1)
4764 
4765  @torch.jit.script
4766  def test_lhs_opt_rhs_none(lhs=None):
4767  # type: (Optional[Tensor]) -> int
4768  # LHS maybeNone: emit normal if stmt that contains 3 constants
4769  if lhs is not None:
4770  return 2
4771  elif lhs is None:
4772  return 1
4773  else:
4774  return 3
4775 
4776  self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
4777 
4778  @torch.jit.script
4779  def test_lhs_none_rhs_opt(rhs=None):
4780  # type: (Optional[Tensor]) -> int
4781  # RHS maybeNone, emit normal if stmt that contains 3 constants
4782  if None is rhs:
4783  return 1
4784  elif None is not rhs:
4785  return 2
4786  else:
4787  return 3
4788 
 4789  self.assertTrue(str(test_lhs_none_rhs_opt.graph).count(': int = prim::Constant') == 3)
4790 
4791  @torch.jit.script
4792  def test_lhs_never_rhs_none(lhs):
4793  # LHS neverNone, RHS alwaysNone dispatch never_none_branch
4794  # only emit one prim::Constant
4795  if lhs is None:
4796  return 1
4797  elif lhs is not None:
4798  return 2
4799  else:
4800  return 3
4801 
4802  self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1)
4803 
4804  @torch.jit.script
4805  def test_lhs_none_rhs_never(rhs):
4806  # LHS alwaysNone, RHS neverNone dispatch never_none_branch
4807  # only emit one prim::Constant
4808  if None is rhs:
4809  return 1
4810  elif None is not rhs:
4811  return 2
4812  else:
4813  return 3
4814 
4815  self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
4816 
4817  def test_explicit_bool_cast(self):
4818  with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
4819  @torch.jit.script
4820  def test_bool_cast(a):
4821  if a:
4822  return a + 2
4823  return a + 1
4824 
4825  def test_while_nonexistent_value(self):
4826  with self.assertRaisesRegex(RuntimeError, "undefined value x"):
 4827  torch.jit.CompilationUnit('''
 4828  def test_while(a, b):
4829  while bool(a < 10):
4830  a = a + x
4831  b = b + 1
4832  return a + b
4833  ''')
4834 
4835  def test_while_nonexistent_cond_value(self):
4836  with self.assertRaisesRegex(RuntimeError, "undefined value x"):
 4837  torch.jit.CompilationUnit('''
 4838  def test_while(a, b):
4839  while a < x:
4840  a = a + 1
4841  b = b + 1
4842  return a + b
4843  ''')
4844 
4845  def test_optional_refinement(self):
4846  @torch.jit.script
4847  def test_if_none_assignment(x):
4848  # type: (Optional[int]) -> int
4849  if x is None:
4850  x = 1
4851  return x + 1
4852 
4853  self.assertEqual(test_if_none_assignment(1), 2)
4854 
4855  @torch.jit.script
4856  def test_ternary(x):
4857  # type: (Optional[int]) -> int
4858  x = x if x is not None else 2
4859  return x
4860 
4861  @torch.jit.script
4862  def test_not_none(x):
4863  # type: (Optional[int]) -> None
4864  if x is not None:
4865  print(x + 1)
4866 
4867  @torch.jit.script
4868  def test_and(x, y):
4869  # type: (Optional[int], Optional[int]) -> None
4870  if x is not None and y is not None:
4871  print(x + y)
4872 
4873  @torch.jit.script
4874  def test_not(x, y):
4875  # type: (Optional[int], Optional[int]) -> None
4876  if not (x is not None and y is not None):
4877  pass
4878  else:
4879  print(x + y)
4880 
4881  @torch.jit.script
4882  def test_bool_expression(x):
4883  # type: (Optional[int]) -> None
4884  if x is not None and x < 2:
4885  print(x + 1)
4886 
4887  @torch.jit.script
4888  def test_nested_bool_expression(x, y):
4889  # type: (Optional[int], Optional[int]) -> int
4890  if x is not None and x < 2 and y is not None:
4891  x = x + y
4892  else:
4893  x = 5
4894  return x + 2
4895 
4896  @torch.jit.script
4897  def test_or(x, y):
4898  # type: (Optional[int], Optional[int]) -> None
4899  if y is None or x is None:
4900  pass
4901  else:
4902  print(x + y)
4903 
4904  # backwards compatibility
4905  @torch.jit.script
4906  def test_manual_unwrap_opt(x):
4907  # type: (Optional[int]) -> int
4908  if x is None:
4909  x = 1
4910  else:
 4911  x = torch.jit._unwrap_optional(x)
 4912  return x # noqa: T484
4913 
4914  with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
4915  @torch.jit.script
4916  def or_error(x, y):
4917  # type: (Optional[int], Optional[int]) -> None
4918  if x is None or y is None:
4919  print(x + y) # noqa: T484
4920 
4921  with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
4922  @torch.jit.script
4923  def and_error(x, y):
4924  # type: (Optional[int], Optional[int]) -> None
4925  if x is None and y is None:
4926  pass
4927  else:
4928  print(x + y) # noqa: T484
4929 
4930  with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
4931  @torch.jit.script
4932  def named_var(x):
4933  # type: (Optional[int]) -> None
4934  x_none = x is not None
4935  if x_none:
4936  print(x + 1) # noqa: T484
4937 
4938  with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
4939  @torch.jit.script
4940  def named_var_and(x, y):
4941  # type: (Optional[int], Optional[int]) -> None
4942  x_none = x is not None
4943  if y is not None and x_none:
4944  print(x + y) # noqa: T484
4945 
4946  def test_while_write_outer_then_read(self):
4947  def func(a, b):
4948  while bool(a < 10):
4949  a = a + 1
4950  b = a + 1
4951  return a + b
4952 
4953  inputs = self._make_scalar_vars([42, 1337], torch.int64)
4954  self.checkScript(func, inputs, optimize=True)
4955 
4956  def test_while_nest_if(self):
4957  def func(a, b):
4958  # type: (int, int) -> int
4959  c = 0
4960  while a < 10:
4961  a = a + 1
4962  b = b + 1
4963  if a > b:
4964  c = -a
4965  else:
4966  c = -b
4967  return c + 1
4968 
4969  inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
4970  self.checkScript(func, inputs, optimize=True)
4971 
4972  def test_math_ops(self):
4973 
4974  def test_floor():
4975  return math.floor(1.5)
4976 
4977  self.checkScript(test_floor, ())
4978 
4979  def test_if_nest_while(self):
4980  def func(a, b):
4981  # type: (int, int) -> int
4982  c = 0
4983  if a > b:
4984  while a > b:
4985  b = b + 1
4986  c = -b
4987  return c
4988 
4989  inputs = self._make_scalar_vars([4321, 1234], torch.int64)
4990  self.checkScript(func, inputs, optimize=True)
4991 
4992  def test_script_for_in_range(self):
4993  def fn():
4994  c = 0
4995  for i in range(100):
4996  c += i
4997  return c
4998  self.checkScript(fn, (), outputs=4950, optimize=True)
4999 
5000  def test_script_for_in_range_dynamic(self):
5001  def fn():
5002  c = 0
5003  for i in range(100):
5004  acc = 0
5005  for j in range(i):
5006  acc += j
5007  c += acc
5008  return c
5009  self.checkScript(fn, (), optimize=False)
5010 
5011  def test_script_for_in_range_ast(self):
5012  @torch.jit.script
5013  def test_script_for_in_range_ast():
5014  c = 0
5015  for i in range(100):
5016  acc = 0
5017  for j in range(i):
5018  acc += j
5019  c += acc
5020  return c
5021 
5022  self.assertEqual(test_script_for_in_range_ast(), 161700)
5023 
5024  def test_script_for_in_range_if_ast(self):
5025  @torch.jit.script
5026  def test_script_for_in_range_if_ast(x):
5027  output = x
5028  for i in range(20):
5029  if i == 0:
5030  output = x.unsqueeze(0)
5031  else:
5032  output = torch.cat((output, x.unsqueeze(0)), dim=0)
5033  return output
5034  inputs = self._make_scalar_vars([0], torch.int64)
5035 
5036  self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
5037 
5038  def test_script_optional_none(self):
5039  def none_stmt(x):
5040  output = None
5041  output = x
5042  return output
5043 
5044  def none_args(x):
5045  # type: (Optional[Tensor]) -> Optional[Tensor]
5046  return None
5047 
5048  self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True)
5049  self.checkScript(none_args, [None], optimize=True)
5050 
5051  # test undefined tensor None as default param
5052  def test_script_optional_tensor_none(x=None):
5053  # type: (Optional[Tensor]) -> Tensor
5054  res = torch.zeros(1, dtype=torch.int8)
5055  if x is None:
5056  res = res + 1
5057  else:
5058  res = x
5059  return res
5060 
5061  fn = test_script_optional_tensor_none
5062  scripted_fn = torch.jit.script(fn)
5063  self.assertEqual(fn(), scripted_fn())
5064  self.assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1)))
5065 
5066  # test typical None as default param
5067  def test_script_optional_other_none(x=None):
5068  # type: (Optional[float]) -> float
5069  res = 2.0
5070  if x is None:
5071  res = res + 1.0
5072  else:
5073  res = x
5074  return res
5075 
5076  fn = test_script_optional_other_none
5077  scripted_fn = torch.jit.script(fn)
5078  self.assertEqual(fn(), scripted_fn())
5079  self.assertEqual(fn(1.0), scripted_fn(1.0))
5080 
5081  def test_script_clamp_none(self):
5082  def test_script_clamp_max_none(x):
5083  return torch.clamp(x, min=2, max=None)
5084 
5085  def test_script_clamp_max(x):
5086  return torch.clamp(x, max=2)
5087 
5088  def test_script_clamp_min_none(x):
5089  return torch.clamp(x, min=None, max=2)
5090 
5091  def test_script_clamp_min(x):
5092  return torch.clamp(x, min=2)
5093 
5094  input = [torch.arange(0, 3)]
5095  self.checkScript(test_script_clamp_max_none, input, optimize=True)
5096  self.checkScript(test_script_clamp_max, input, optimize=True)
5097  self.checkScript(test_script_clamp_min_none, input, optimize=True)
5098  self.checkScript(test_script_clamp_min, input, optimize=True)
5099 
5100  def test_script_bool_constant(self):
5101  script = '''
5102  def test_script_bool_constant():
5103  a = True
5104  return a
5105  '''
5106  outputs = [1]
5107  self.checkScript(script, [], outputs[0], True, 'test_script_bool_constant')
5108 
5109  def test_ternary(self):
5110  def func(a, b):
5111  c = 3
5112  c = a + b if bool(a > 3) else b
5113  return c
5114 
5115  inputs_true = self._make_scalar_vars([5, 2], torch.int64)
5116  inputs_false = self._make_scalar_vars([1, 0], torch.int64)
5117  self.checkScript(func, inputs_true, optimize=True)
5118  self.checkScript(func, inputs_false, optimize=True)
5119 
5120  def test_print(self):
5121  def func(x, y):
5122  q = (x + y).sigmoid()
5123  print(q, 1, 2, [1, 2], [1.0, 2.0])
5124  w = -q
5125  return w * w
5126 
5127  x = torch.arange(4., requires_grad=True)
5128  y = torch.arange(0., 8, 2, requires_grad=True)
5129  self.checkScript(func, [x, y], optimize=True, capture_output=True)
5130 
5131  def test_format(self):
5132  def func(x):
5133  print("{}, I'm a {}".format("Hello", "test"))
5134  print("format blank".format())
5135  print("stuff before {}".format("hi"))
5136  print("{} stuff after".format("hi"))
5137  return x + 1
5138 
5139  x = torch.arange(4., requires_grad=True)
5140  self.checkScript(func, [x], optimize=True, capture_output=True)
5141 
5142  def test_logical_short_circuit(self):
5143  @torch.jit.script
5144  def testNoThrows(t):
5145  c1 = 1
5146  if (False and bool(t[1])) or (True or bool(t[1])):
5147  c1 = 0
5148  return c1
5149 
5150  self.assertEqual(0, testNoThrows(torch.randn(0)))
5151  ifs = testNoThrows.graph.findAllNodes("prim::If", recurse=False)
5152 
5153  # three ifs at the top level, and the second one has a nested if for
5154  # the or (True or bool(t[1])) expression
5155  self.assertTrue(len(ifs) == 3)
5156  self.assertTrue(ifs[0].findNode("prim::If") is None)
5157  self.assertTrue(ifs[1].findNode("prim::If").findNode("prim::If") is None)
5158  self.assertTrue(ifs[2].findNode("prim::If") is None)
5159 
5160  @torch.jit.script
5161  def throwsOr(t):
5162  c0 = False or bool(t[1])
5163  print(c0)
5164 
5165  @torch.jit.script
5166  def throwsAnd(t):
5167  c0 = True and bool(t[1])
5168  print(c0)
5169 
5170  t = torch.randn(0)
5171  with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
5172  throwsOr(t)
5173  with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
5174  throwsAnd(t)
5175 
5176  def test_type_cast(self):
5177  template = dedent('''
5178  def cast(v):
5179  # type: ({from_type}) -> {to_type}
5180  return {to_type}(v)
5181  ''')
5182 
5183  def check_cast(from_type, to_type, value, raises=False):
5184  code = template.format(from_type=from_type, to_type=to_type)
5185  expected = getattr(builtins, to_type)(value)
5186  if raises:
5187  with self.assertRaisesRegex(RuntimeError, "Cannot cast"):
5188  cu = torch.jit.CompilationUnit(code)
5189  else:
5190  self.checkScript(code, (value,), name='cast', outputs=expected)
5191 
5192  check_cast('int', 'float', 1)
5193  check_cast('int', 'bool', 1)
5194  check_cast('int', 'bool', 0)
5195 
5196  check_cast('float', 'int', 1.)
5197  check_cast('float', 'bool', 1.)
5198  check_cast('float', 'bool', 0.)
5199 
5200  check_cast('bool', 'int', True)
5201  check_cast('bool', 'float', True)
5202 
5203  def test_multiple_assignment(self):
5204  def outer_func(x):
5205  return x * 2, x + 2
5206 
5207  @torch.jit.script
5208  def func(x):
5209  y, z = outer_func(x)
5210  return y + z
5211 
5212  x = torch.arange(4)
5213  self.assertEqual(func(x), x * 2 + x + 2)
5214 
5215  def test_literals(self):
5216  def func(a):
5217  return a.view(size=[1, 2, 3])
5218 
5219  a = torch.randn(6)
5220  self.checkScript(func, [a], optimize=True)
5221 
5222  def test_return(self):
5223  def no_return(a):
5224  a + 1
5225 
5226  def void_return(a):
5227  return
5228 
5229  def one_return(a):
5230  return a + 1.
5231 
5232  def multiple_returns(a):
5233  return a * 1., a * 2., a * 3.
5234 
5235  a = torch.randn(1, dtype=torch.float)
5236  self.checkScript(no_return, [a], optimize=True)
5237  self.checkScript(void_return, [a], optimize=True)
5238  self.checkScript(one_return, [a], optimize=True)
5239  self.checkScript(multiple_returns, [a], optimize=True)
5240 
5241  with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
 5242  torch.jit.CompilationUnit('''
 5243  def no_return_bad_annotation(a):
5244  # type: (Tensor) -> Tensor
5245  a + 1
5246  ''')
5247 
5248  def test_error(self):
5249  @torch.jit.script
5250  def foo(a):
5251  return a.t()
5252  s = Variable(torch.rand(5, 5, 5))
 5253  # XXX: this should stay quiet in shape propagation and only fail in the interpreter
5254  with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
5255  foo(s)
5256 
5257  @torch.jit.script
5258  def bar(c, b):
5259  return c + b
5260 
5261  with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
5262  bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
5263 
5264  def test_binop_unsupported_error(self):
5265  with self.assertRaisesRegex(NotSupportedError, "unsupported binary operator:"):
5266  @torch.jit.script
5267  def binop(x, y):
5268  # Replace this with another unsupported op when/if it gets supported
5269  return x << y
5270 
5271  def test_bitwise_ops(self):
5272 
5273  def int_test():
5274  return 2 & 3, 2 ^ 3, 2 | 3
5275 
5276  self.checkScript(int_test, ())
5277 
5278  def bool_test(x, y):
5279  # type: (bool, bool) -> Tuple[bool, bool, bool]
5280  return x & y, x ^ y, x | y
5281 
5282  self.checkScript(bool_test, (True, False))
5283  self.checkScript(bool_test, (True, True))
5284 
5285  def tensor_test(x, y):
5286  return x & y, x ^ y, x | y
5287 
5288  x = torch.tensor(2)
5289  y = torch.tensor(3)
5290 
5291  self.checkScript(tensor_test, (x, y))
5292 
5293  def test_number_math(self):
5294  ops_template = dedent('''
5295  def func():
5296  return {scalar1} {op} {scalar2}
5297  ''')
5298  ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//']
5299  funcs_template = dedent('''
5300  def func():
5301  return {func}({scalar1}, {scalar2})
5302  ''')
5303  funcs = ['min', 'max']
5304  scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0']
5305  scalar_pairs = [(scalar1, scalar2) for scalar1 in scalars for scalar2 in scalars]
5306 
5307  def run_test(code):
5308  scope = {}
5309  execWrapper(code, globals(), scope)
5310  cu = torch.jit.CompilationUnit(code)
5311 
5312  self.assertEqual(cu.func(), scope['func']())
5313 
5314  for scalar1, scalar2 in scalar_pairs:
5315  for op in ops:
5316  code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2)
5317  run_test(code)
5318  for func in funcs:
5319  code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2)
5320  run_test(code)
5321 
5322  def test_number_div(self):
5323  self.checkScript(div_int_future, (), optimize=True)
5324  self.checkScript(div_float_future, (), optimize=True)
5325 
5326  if PY2:
5327  with self.assertRaisesRegex(RuntimeError, 'from __future__ import division'):
5328  torch.jit.script(div_int_nofuture)
5329  with self.assertRaisesRegex(RuntimeError, 'from __future__ import division'):
5330  torch.jit.script(div_float_nofuture)
5331  else:
5332  self.checkScript(div_int_nofuture, (), optimize=True)
5333  self.checkScript(div_float_nofuture, (), optimize=True)
5334 
5335  def test_floor_div(self):
5336  @torch.jit.script
5337  def foo(a, b):
5338  # type: (int, int) -> int
5339  return a // b
5340  for i in range(-8, 8):
5341  for j in range(-8, 8):
5342  if j != 0:
5343  self.assertEqual(foo(i, j), i // j)
5344  else:
5345  with self.assertRaisesRegex(RuntimeError, 'division by 0'):
5346  foo(i, j)
5347 
5348  def test_number_augassign(self):
5349  def func():
5350  z = 1
5351  z += 2
5352  return z
5353 
5354  self.checkScript(func, (), optimize=True)
5355 
5356  def test_number_neg(self):
5357  # int -> int
5358  def func1():
5359  return -8
5360 
5361  # float -> float
5362  def func2():
5363  return -3.14
5364 
5365  self.checkScript(func1, (), optimize=True)
5366  self.checkScript(func2, (), optimize=True)
5367 
5368  def _test_tensor_number_math(self, device='cpu'):
5369  template = dedent('''
5370  def func(t):
5371  return {lhs} {op} {rhs}
5372  ''')
5373 
5374  def test(op, const, swap_args):
5375  args = ('t', const)
5376  if swap_args:
5377  args = (const, 't')
5378 
5379  code = template.format(lhs=args[0], rhs=args[1], op=op)
5380  scope = {}
5381  execWrapper(code, globals(), scope)
5382  cu = torch.jit.CompilationUnit(code)
5383  self.assertEqual(cu.func(tensor), scope['func'](tensor))
5384 
5385  var_int = [2, -2]
5386  var_float = [1.4321, -1.2]
5387 
5388  ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/']
5389 
5390  float_tensor = torch.randn(5, 5, device=device)
5391  double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
5392  long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device)
5393  long_tensor[long_tensor == 0] = 2
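# presumably to avoid division or modulo by zero in the ops below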
5394 
5395  tensors = [float_tensor, double_tensor, long_tensor]
5396  consts = var_int + var_float
5397 
5398  for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]):
5399  # FIXME: things like 2 / long_tensor are not implemented correctly
5400  # Look in torch/tensor.py to see how PyTorch implements it.
5401  if op == '/' and tensor.data_ptr() == long_tensor.data_ptr():
5402  continue
5403 
5404  # the % operator does not support const % tensor
5405  if op == '%' and swap_args is True:
5406  continue
5407 
5408  test(op, const, swap_args)
5409 
5410  def test_tensor_number_math(self):
5411  self._test_tensor_number_math()
5412 
5413  def test_torch_tensor_bad_input(self):
5414  with self.assertRaisesRegex(RuntimeError, "Input list to torch.tensor must be of ints, floats, "
5415  "or bools, got None"):
5416  @torch.jit.script
5417  def test():
5418  return torch.tensor([None])
5419 
5420  with self.assertRaisesRegex(RuntimeError, "Note: empty lists are constructed as Tensor"):
5421  @torch.jit.script
5422  def tmp():
5423  return torch.tensor([])
5424 
5425  @torch.jit.script
5426  def foo():
5427  return torch.tensor([[2, 2], [1]])
5428  with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
5429  foo()
5430 
5431  @suppress_warnings
5432  def test_torch_tensor_empty_list(self):
5433  def func():
5434  return torch.tensor(torch.jit.annotate(List[int], []))
5435  cu = torch.jit.script(func)
5436  t1 = cu()
5437  t2 = func()
5438 
5439  # TorchScript returns an int tensor, Python returns a float tensor
5440  self.assertNotEqual(t1.dtype, t2.dtype)
5441 
5442  def func():
5443  li = torch.jit.annotate(List[int], [])
5444  return torch.tensor([li, li])
5445 
5446  self.checkScript(func, ())
5447 
5448  def func():
5449  li = torch.jit.annotate(List[int], [])
5450  return torch.tensor([[[li]]])
5451 
5452  self.checkScript(func, ())
5453 
5454  def test_torch_tensor(self):
5455  template = dedent('''
5456  def func():
5457  li = {list_create}
5458  return torch.tensor(li {options})
5459  ''')
5460 
5461  lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
5462  "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
5463 
5464  dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
5465  ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
5466  ", dtype=torch.int", ", dtype=torch.long"]
5467 
5468  devices = ['', ", device='cpu'"]
5469  if RUN_CUDA:
5470  devices.append(", device='cuda'")
5471 
5472  option_pairs = [dtype + device for dtype in dtypes for device in devices]
5473  for li in lists:
5474  for option in option_pairs:
5475  # a tensor from an empty list is float in Python but takes the annotated type in TorchScript
5476  if "annotate" in li and "dtype" not in option:
5477  continue
5478  code = template.format(list_create=li, options=option)
5479  scope = {}
5480  exec(code, globals(), scope)
5481  cu = torch.jit.CompilationUnit(code)
5482  t1 = cu.func()
5483  t2 = scope['func']()
5484  if t1.dtype == torch.float16: # equality NYI for half tensor
5485  self.assertTrue(str(t1) == str(t2))
5486  else:
5487  self.assertEqual(t1, t2)
5488  self.assertEqual(t1.dtype, t2.dtype)
5489  self.assertEqual(t1.device, t2.device)
5490 
5491  # adapted from test in test_torch
5492  def test_tensor_to(self):
5493  template = dedent('''
5494  def func(t):
5495  cuda = "{cuda}"
5496  device = "{device}"
5497  non_blocking = {non_blocking}
5498  return {to_str}
5499  ''')
5500 
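# s() formats a t.to(...) expression into the template, compiles it as a
# CompilationUnit, and runs it on t, so each assertion below exercises an
# aten::to overload through TorchScript.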
5501  def s(t, to_str, non_blocking=None, device=None, cuda=None):
5502  device = device if device is not None else str(t.device)
5503  non_blocking = non_blocking if non_blocking is not None else False
5504  cuda = "cuda" if cuda is None else cuda
5505  code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
5506  scope = {}
5507  cu = torch.jit.CompilationUnit(code)
5508  return cu.func(t)
5509 
5510  def test_copy_behavior(t, non_blocking=False):
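# to() returns self when dtype and device already match (no copy is made)
# unless copy=True is passed explicitly.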
5511  self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
5512  self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
5513  self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
5514  self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
5515  self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
5516  self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))
5517 
5518  devices = [t.device]
5519  if t.device.type == 'cuda':
5520  if t.device.index == -1:
5521  devices.append('cuda:{}'.format(torch.cuda.current_device()))
5522  elif t.device.index == torch.cuda.current_device():
5523  devices.append('cuda')
5524  for device in devices:
5525  self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
5526  self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
5527  self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
5528  self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
5529  non_blocking, device))
5530 
5531  t = torch.tensor(5)
5532  test_copy_behavior(t)
5533 
5534  self.assertEqual(t.device, s(t, "t.to('cpu')").device)
5535  self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
5536  self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
5537  self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
5538  self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
5539  self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
5540  self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
5541  self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
5542  self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())
5543 
5544  a = torch.tensor(5)
5545  if torch.cuda.is_available():
5546  for non_blocking in [True, False]:
5547  for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
5548  b = torch.tensor(5., device=cuda)
5549  test_copy_behavior(b, non_blocking)
5550  self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
5551  self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device"))
5552  self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
5553  self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
5554  self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
5555  self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
5556  self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
5557 
5558  # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
5559  t = torch.tensor(5).float().requires_grad_()
5560  out_ref = t.to(torch.float32)
5561  out = s(t, "t.to(torch.float32)")
5562  self.assertEqual(out_ref, out)
5563 
5564  grad_ref = torch.autograd.grad(out_ref.sum(), t)
5565  grad = torch.autograd.grad(out.sum(), t)
5566  self.assertEqual(grad_ref, grad)
5567 
5568  # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
5569  out_ref = t.to('cpu')
5570  out = s(t, "t.to('cpu')")
5571  self.assertEqual(out_ref, out)
5572 
5573  grad_ref = torch.autograd.grad(out_ref.sum(), t)
5574  grad = torch.autograd.grad(out.sum(), t)
5575  self.assertEqual(grad_ref, grad)
5576 
5577  # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
5578  @torch.jit.script
5579  def func2(t, t_ref):
5580  return t.to(t_ref)
5581 
5582  func2.debug_disable_autodiff_subgraph_inlining()
5583 
5584  t_ref = torch.tensor(4).double()
5585  out_ref = t.to(t_ref)
5586  out = func2(t, t_ref)
5587  grad_ref = torch.autograd.grad(out_ref.sum(), t)
5588  grad = torch.autograd.grad(out.sum(), t)
5589  self.assertEqual(grad_ref, grad)
5590 
5591  @unittest.skipIf(not RUN_CUDA, "No CUDA")
5592  def test_tensor_number_math_cuda(self):
5593  self._test_tensor_number_math(device='cuda')
5594 
5595  def test_not(self):
5596  # test the not operator in python
5597  # TODO: add more tests when bool conversions are ready
5598  def test_not_op(a):
5599  return not bool(a > 1)
5600 
5601  self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True)
5602 
5603  def test_is_isnot(self):
5604  # test the is and is not operators in python
5605  template = dedent('''
5606  def func():
5607  # type: () -> bool
5608  return {lhs} {op} {rhs}
5609  ''')
5610 
5611  def test(op, args):
5612  code = template.format(lhs=args[0], rhs=args[1], op=op)
5613  scope = {}
5614  execWrapper(code, globals(), scope)
5615  cu = torch.jit.CompilationUnit(code)
5616  self.assertEqual(
5617  cu.func(),
5618  scope['func'](),
5619  "Failed with op: {}, lhs: {}, rhs: {}"
5620  .format(op, args[0], args[1])
5621  )
5622 
5623  ops = ['is', 'is not']
5624  type_literals = [True, False, None, [1, 1]]
5625 
5626  # take the product of literals to try all type combinations
5627  for op, lhs, rhs in product(ops, type_literals, type_literals):
5628  test(op, [lhs, rhs])
5629 
5630  def test_isinstance(self):
5631  # test isinstance operator for static type checking
5632  template = dedent('''
5633  def func(x):
5634  # type: ({type_hint}) -> bool
5635  return isinstance(x, {typ})
5636  ''')
5637 
5638  def test(inp, typ, type_hint):
5639  code = template.format(typ=typ, type_hint=type_hint)
5640  scope = {}
5641  execWrapper(code, globals(), scope)
5642  cu = torch.jit.CompilationUnit(code)
5643  self.assertEqual(
5644  cu.func(inp),
5645  scope['func'](inp),
5646  "Failed with typ: {}"
5647  .format(typ)
5648  )
5649 
5650  inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
5651  type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple',
5652  '(list, tuple)', '(int, float, bool)']
5653  type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]',
5654  'List[int]', 'int']
5655 
5656  # zip inputs with type literals and annotations to try different combinations
5657  for inp, typ, type_hint in zip(inputs, type_literals, type_annotations):
5658  test(inp, typ, type_hint)
5659 
5660  # test optional isinstance check
5661  with self.assertRaisesRegex(RuntimeError, "Optional isinstance check is not supported"):
5662  @torch.jit.script
5663  def opt_func(x):
5664  # type: (Optional[int]) -> bool
5665  return isinstance(x, int)
5666 
5667  def test_python_call(self):
5668  def pyfunc(a):
5669  return a * 3.0
5670 
5671  cu = torch.jit.CompilationUnit('''
5672  def other_func(a):
5673  return a + a
5674 
5675  def test_call_python(a):
5676  b = pyfunc(a)
5677  b = other_func(b)
5678  i = 0
5679  step = 1
5680  while i < 10:
5681  b = pyfunc(b)
5682  if bool(b > 3.0):
5683  b = pyfunc(b)
5684  i = 11
5685  return b
5686  ''')
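# expected value: 1*3 = 3, then 3+3 = 6; in the loop 6*3 = 18 > 3, so 18*3 = 54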
5687  inputs = self._make_scalar_vars([1], torch.float)
5688  outputs = self._make_scalar_vars([54], torch.float)
5689 
5690  self.assertEqual(cu.test_call_python(*inputs), outputs[0])
5691 
5692  def test_python_call_failure(self):
5693  with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
5694  def pyfunc(a):
5695  return a * 3.0
5696 
5697  cu = torch.jit.CompilationUnit('''
5698  def other_func(a):
5699  return a + a
5700 
5701  def test_call_python(a):
5702  b = pyfunc(a)
5703  b = other_func(b)
5704  i = 0
5705  step = 1
5706  while i < 10:
5707  b = pyfunc2(b)
5708  if b > 3.0:
5709  b = pyfunc(b)
5710  i = 11
5711  return b
5712  ''')
5713  inputs = self._make_scalar_vars([1], torch.float)
5714  outputs = self._make_scalar_vars([54], torch.float)
5715 
5716  self.assertEqual(cu.test_call_python(*inputs), outputs[0])
5717 
5718  def test_python_call_annotation(self):
5719  def pyfunc(a):
5720  return a * 3.0
5721 
5722  @torch.jit.script
5723  def foo(a):
5724  return pyfunc(a) + pyfunc(a)
5725 
5726  inputs = self._make_scalar_vars([1], torch.float)
5727  outputs = self._make_scalar_vars([6], torch.float)
5728  self.assertEqual(foo(*inputs), outputs[0])
5729 
5730  def test_python_call_annotation_failure(self):
5731  with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
5732  def pyfunc(a):
5733  return a * 3.0
5734 
5735  @torch.jit.script
5736  def foo(a):
5737  return pyfunc2(a) + pyfunc(a)
5738 
5739  inputs = self._make_scalar_vars([1], torch.float)
5740  outputs = self._make_scalar_vars([6], torch.float)
5741 
5742  self.assertEqual(foo(*inputs), outputs[0])
5743 
5744  def test_desugar_module(self):
5745  import torch.nn.functional as F
5746 
5747  def fn(x, slope):
5748  a = torch.abs(x)
5749  b = torch.nn.functional.prelu(x, slope)
5750  c = F.prelu(x, slope)
5751  return a, b, c
5752 
5753  x = torch.arange(-3., 4)
5754  slope = torch.tensor([0.5])
5755  self.checkScript(fn, [x, slope], optimize=True)
5756 
5757  def test_script_docstring(self):
5758  @torch.jit.script
5759  def with_docstring(x):
5760  """test str"""
5761  y = x
5762  """y is the same as x"""
5763  return y
5764  self.assertEqual(with_docstring.__doc__, 'test str')
5765 
5766  def test_script_method_docstring(self):
5767  class A(torch.jit.ScriptModule):
5768  @torch.jit.script_method
5769  def with_docstring(self, x):
5770  """test str"""
5771  y = x
5772  """y is the same as x"""
5773  return y
5774  a = A()
5775  self.assertEqual(a.with_docstring.__doc__, 'test str')
5776 
5777  @unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
5778  'Quantized RNN requires FBGEMM. FBGEMM does not play'
5779  ' well with UBSAN at the moment, so we skip the test if'
5780  ' we are in a UBSAN environment.')
5781  def test_rnn_cell_quantized(self):
5782  d_in, d_hid = 2, 2
5783 
5784  for cell in [
5785  torch.nn.LSTMCell(d_in, d_hid).float(),
5786  torch.nn.GRUCell(d_in, d_hid).float(),
5787  torch.nn.RNNCell(d_in, d_hid).float(),
5788  ]:
5789  if isinstance(cell, torch.nn.LSTMCell):
5790  num_chunks = 4
5791  elif isinstance(cell, torch.nn.GRUCell):
5792  num_chunks = 3
5793  elif isinstance(cell, torch.nn.RNNCell):
5794  num_chunks = 1
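# num_chunks is the number of gate weight matrices stacked in
# weight_ih/weight_hh: 4 gates for LSTM, 3 for GRU, 1 for vanilla RNN.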
5795 
5796  # Replace parameter values such that the range of values is exactly
5797  # 255, so we get zero quantization error in the quantized
5798  # GEMM call. This is for testing purposes.
5799  #
5800  # Note that the current implementation does not support
5801  # accumulation values outside of the range representable by a
5802  # 16 bit integer, instead resulting in a saturated value. We
5803  # must take care that in our test we do not end up with a dot
5804  # product that overflows the int16 range, e.g.
5805  # (255*127+255*127) = 64770. So, we hardcode the test values
5806  # here and ensure a mix of signedness.
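# (A signed 16-bit accumulator covers [-32768, 32767], and
# 255*127 + 255*127 = 64770 would exceed that and saturate.)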
5807  vals = [[100, -155],
5808  [100, -155],
5809  [-155, 100],
5810  [-155, 100],
5811  [100, -155],
5812  [-155, 100],
5813  [-155, 100],
5814  [100, -155]]
5815  vals = vals[:d_hid * num_chunks]
5816  cell.weight_ih = torch.nn.Parameter(
5817  torch.tensor(vals, dtype=torch.float),
5818  requires_grad=False)
5819  cell.weight_hh = torch.nn.Parameter(
5820  torch.tensor(vals, dtype=torch.float),
5821  requires_grad=False)
5822 
5823  ref = copy.deepcopy(cell)
5824 
5825  cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
5826  x = torch.tensor([[100, -155],
5827  [-155, 100],
5828  [100, -155]], dtype=torch.float)
5829  h0_vals = [[-155, 100],
5830  [-155, 155],
5831  [100, -155]]
5832  hx = torch.tensor(h0_vals, dtype=torch.float)
5833  if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
5834  cx = torch.tensor(h0_vals, dtype=torch.float)
5835  hiddens = (hx, cx)
5836  else:
5837  hiddens = hx
5838 
5839  if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
5840  class ScriptWrapper(torch.jit.ScriptModule):
5841  def __init__(self, cell):
5842  super(ScriptWrapper, self).__init__()
5843  self.cell = cell
5844 
5845  @torch.jit.script_method
5846  def forward(self, x, hiddens):
5847  # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
5848  return self.cell(x, hiddens)
5849  else:
5850 
5851  class ScriptWrapper(torch.jit.ScriptModule):
5852  def __init__(self, cell):
5853  super(ScriptWrapper, self).__init__()
5854  self.cell = cell
5855 
5856  @torch.jit.script_method
5857  def forward(self, x, hiddens):
5858  # type: (torch.Tensor, torch.Tensor) -> torch.Tensor
5859  return self.cell(x, hiddens)
5860 
5861  cell = ScriptWrapper(cell)
5862  outs = cell(x, hiddens)
5863  cell = self.getExportImportCopyWithPacking(cell)
5864 
5865  outs = cell(x, hiddens)
5866  ref_outs = ref(x, hiddens)
5867 
5868  self.assertEqual(len(outs), len(ref_outs))
5869  for out, ref_out in zip(outs, ref_outs):
5870  torch.testing.assert_allclose(out, ref_out)
5871 
5872  def test_script_module(self):
5873  class M1(torch.jit.ScriptModule):
5874  def __init__(self):
5875  super(M1, self).__init__(False)
5876  self.weight = nn.Parameter(torch.randn(2))
5877 
5878  @torch.jit.script_method
5879  def forward(self, thing):
5880  return self.weight + thing
5881 
5882  class PModule(nn.Module):
5883  def __init__(self):
5884  super(PModule, self).__init__()
5885  self.a = nn.Parameter(torch.randn(2, 3))
5886 
5887  def forward(self, a):
5888  return self.a.mm(a)
5889 
5890  class M2(torch.jit.ScriptModule):
5891  def __init__(self):
5892  super(M2, self).__init__(False)
5893  # test submodule
5894  self.sub = M1()
5895  self.sub2 = PModule()
5896  # test parameters
5897  self.weight = nn.Parameter(torch.randn(2, 3))
5898  self.bias = nn.Parameter(torch.randn(2))
5899  # test defining a method from a string
5900  self.define("""
5901  def hi(self, a):
5902  return self.weight.mm(a)
5903  """)
5904  # test script methods
5905 
5906  @torch.jit.script_method
5907  def doit(self, input):
5908  # test use of parameter
5909  return self.weight.mm(input)
5910 
5911  @torch.jit.script_method
5912  def doit2(self, input):
5913  return self.weight.mm(input)
5914 
5915  @torch.jit.script_method
5916  def forward(self, input):
5917  a = self.doit(input)
5918  b = self.doit2(input)
5919  c = self.hi(input)
5920  d = self.sub2(input)
5921  return a + b + self.bias + self.sub(a) + c + d
5922  m2 = M2()
5923  input = torch.randn(3, 2)
5924  a = m2.weight.mm(input)
5925  b = m2.weight.mm(input)
5926  c = m2.weight.mm(input)
5927  d = m2.sub2.a.mm(input)
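# mirrors forward(): self.sub(a) == m2.sub.weight + a, which is why the
# reference below contains both m2.sub.weight and a second 'a' term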
5928  ref = a + b + m2.bias + m2.sub.weight + a + c + d
5929  self.assertEqual(ref, m2.forward(input))
5930  m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
5931  m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
5932  m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
5933  m2.sub2.a.data.zero_()
5934  self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
5935 
5936  def test_filecheck(self):
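# FileCheck asserts that the given patterns occur, in order, within a string;
# check_count/check_same/check_next/check_dag/check_not refine how and where
# the matches may occur, as exercised below.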
5937  def test_check():
5938  file = "232"
5939  FileCheck().check("2").check("3").check("2").run(file)
5940  FileCheck().check("232").run(file)
5941 
5942  with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
5943  FileCheck().check("22").run(file)
5944  with self.assertRaisesRegex(RuntimeError, "CHECK: 3"):
5945  FileCheck().check("3").check("3").run(file)
5946 
5947  test_check()
5948 
5949  def test_check_count():
5950  file = "22222"
5951  FileCheck().check_count("2", 5).run(file)
5952  FileCheck().check_count("22", 2).run(file)
5953  FileCheck().check_count("222", 1).run(file)
5954 
5955  with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
5956  FileCheck().check_count("2", 4, exactly=True).run(file)
5957 
5958  with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
5959  FileCheck().check_count("22", 3).run(file)
5960 
5961  with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"):
5962  FileCheck().check_count("2", 6).run(file)
5963 
5964  test_check_count()
5965 
5966  def test_check_same():
5967  file = "22\n33"
5968  FileCheck().check_same("22").run(file)
5969 
5970  with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
5971  FileCheck().check_same("33").run(file)
5972 
5973  file = "22 1 3"
5974 
5975  FileCheck().check("2").check_same("3").run(file)
5976  FileCheck().check_count("2", 2).check_same("3").run(file)
5977 
5978  test_check_same()
5979 
5980  def test_check_next():
5981  file = "\n1\n2\n3"
5982  FileCheck().check("1").check_next("2").check_next("3").run(file)
5983  FileCheck().check_next("1").check_next("2").check_next("3").run(file)
5984 
5985  with self.assertRaisesRegex(RuntimeError, "Expected to find"):
5986  FileCheck().check("1").check_next("2").run("12")
5987 
5988  with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
5989  FileCheck().check("1").check_next("2").run("1\n\n2")
5990 
5991  test_check_next()
5992 
5993  def test_check_dag():
5994  fc = FileCheck().check_dag("1").check_dag("2").check_not("2")
5995  fc.run("12")
5996  fc.run("21")
5997 
5998  fc = FileCheck()
5999  fc.check_not("3").check_dag("1").check_dag("2").check_not("3")
6000  fc.run("1 3 2")
6001  fc.run("2 3 1")
6002 
6003  fc = FileCheck().check_dag("1").check_dag("2").check("3")
6004  with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'):
6005  fc.run("1 3 2")
6006 
6007  test_check_dag()
6008 
6009  def test_check_not():
6010  FileCheck().check_not("2").check("1").run("12")
6011  FileCheck().check("2").check_not("2").run("12")
6012 
6013  with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
6014  FileCheck().check_not("2").check("1").run("21")
6015 
6016  with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
6017  FileCheck().check("2").check_not("1").run("21")
6018 
6019  # checks with distinct range matchings
6020  fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2")
6021  with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
6022  fb.run("22 2 22")
6023 
6024  fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2)
6025  with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
6026  fb.run("22 1 22")
6027 
6028  def test_script_module_call_noscript(self):
6029  class M(torch.jit.ScriptModule):
6030  def __init__(self):
6031  super(M, self).__init__(False)
6032  self.value = 1
6033 
6034  def foo(self):
6035  return torch.ones(2, 2) + self.value
6036 
6037  @torch.jit.script_method
6038  def forward(self, input):
6039  return input + self.foo()
6040 
6041  m = M()
6042  input = torch.randn(2, 2)
6043  o = m(input)
6044  self.assertEqual(o, input + torch.ones(2, 2) + 1)
6045  # check that we can change python attributes
6046  # and that those changes are picked up in script methods
6047  m.value = 2
6048  o = m(input)
6049  self.assertEqual(o, input + torch.ones(2, 2) + 2)
6050 
6051  def test_script_module_nochange_submodule(self):
6052  class M(torch.jit.ScriptModule):
6053  def __init__(self):
6054  super(M, self).__init__(False)
6055  self.sub = nn.Linear(5, 5)
6056 
6057  @torch.jit.script_method
6058  def forward(self, input):
6059  return self.sub(input)
6060 
6061  m = M()
6062  input = torch.randn(1, 5, 5)
6063  o = m(input)
6064  self.assertEqual(o, m.sub(input))
6065  with self.assertRaisesRegex(RuntimeError, "cannot re-assign"):
6066  m.sub = nn.Linear(5, 5)
6067 
6068  def test_script_inline_trace_multiple_args(self):
6069  class M(torch.jit.ScriptModule):
6070  def __init__(self):
6071  super(M, self).__init__(False)
6072 
6073  def forward(self, input, input2):
6074  return input + input2
6075 
6076  class M2(torch.jit.ScriptModule):
6077  def __init__(self):
6078  super(M2, self).__init__(False)
6079  self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3)))
6080 
6081  @torch.jit.script_method
6082  def forward(self, inp):
6083  return self.m(inp, inp)
6084 
6085  m2 = M2()
6086  m2(torch.zeros(4, 3))
6087 
6088  def test_script_module_const(self):
6089  class M(torch.jit.ScriptModule):
6090 
6091  __constants__ = ['b', 'i', 'c']
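# attributes listed in __constants__ are baked into the compiled graph
# as constants rather than looked up at run time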
6092 
6093  def __init__(self):
6094  super(M, self).__init__(False)
6095  self.b = False
6096  self.i = 1
6097  self.c = 3.5
6098 
6099  @torch.jit.script_method
6100  def forward(self):
6101  return self.b, self.i, self.c
6102 
6103  m = M()
6104  o0, o1, o2 = m()
6105  self.assertEqual(o0, 0)
6106  self.assertEqual(o1, 1)
6107  self.assertEqual(o2, 3.5)
6108 
6109  def test_script_module_fail_const(self):
6110  class M(torch.jit.ScriptModule):
6111  def __init__(self):
6112  super(M, self).__init__(False)
6113  self.b = False
6114 
6115  @torch.jit.script_method
6116  def forward(self):
6117  return self.b
6118  with self.assertRaisesRegex(RuntimeError, "is not usable in a script method"):
6119  M()
6120 
6121  def test_script_module_valid_consts(self):
6122  tester = self
6123 
6124  class Foo(torch.jit.ScriptModule):
6125  __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']
6126 
6127  def __init__(self):
6128  super(Foo, self).__init__(False)
6129  self.a = 1
6130  self.b = 1.2
6131  self.c = False
6132  with tester.assertRaisesRegex(
6133  TypeError,
6134  "'Linear' object for attribute 'd' is not a valid constant"):
6135  self.d = [nn.Linear(3, 4)]
6136  self.e = lambda x: x
6137  self.f = [3, 4, 5]
6138  tester.assertTrue(type(self.f) is tuple)
6139  self.g = [3, (3, 4), 5]
6140  with tester.assertRaisesRegex(TypeError, "not a valid constant"):
6141  self.h = type(1)
6142  with tester.assertRaisesRegex(TypeError, "not a valid constant"):
6143  self.i = (3, 4, {})
6144 
6145  f = Foo()
6146 
6147  def test_script_module_param_buffer_mutation(self):
6148  # TODO: add param mutation test case after JIT support it
6149  class ModuleBufferMutate(torch.jit.ScriptModule):
6150  def __init__(self):
6151  super(ModuleBufferMutate, self).__init__(False)
6152  self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
6153 
6154  @torch.jit.script_method
6155  def forward(self):
6156  if self.training:
6157  self.running_var += 1
6158  return self.running_var
6159 
6160  m = ModuleBufferMutate()
6161  self.assertEqual(m(), 1)
6162  m.eval()
6163  self.assertEqual(m(), 1)
6164 
6165  def test_script_module_for(self):
6166  class M(torch.jit.ScriptModule):
6167  __constants__ = ['b']
6168 
6169  def __init__(self):
6170  super(M, self).__init__(False)
6171  self.b = [1, 2, 3, 4]
6172 
6173  @torch.jit.script_method
6174  def forward(self):
6175  sum = 0
6176  for i in self.b:
6177  sum += i
6178  return sum
6179 
6180  m = M()
6181  self.assertEqual(m(), 10)
6182 
6183  def test_script_module_for2(self):
6184  class Sub(torch.jit.ScriptModule):
6185  def __init__(self):
6186  super(Sub, self).__init__(False)
6187  self.weight = nn.Parameter(torch.randn(2))
6188 
6189  @torch.jit.script_method
6190  def forward(self, thing):
6191  return self.weight + thing
6192 
6193  class M(torch.jit.ScriptModule):
6194  __constants__ = ['mods']
6195 
6196  def __init__(self):
6197  super(M, self).__init__(False)
6198  self.mods = nn.ModuleList([Sub() for i in range(10)])
6199 
6200  @torch.jit.script_method
6201  def forward(self, v):
6202  for m in self.mods:
6203  v = m(v)
6204  return v
6205 
6206  i = torch.Tensor(2)
6207  m = M()
6208  o = m(i)
6209  v = i
6210  for sub in m.mods:
6211  v = sub(v)
6212  self.assertEqual(o, v)
6213 
6214  def test_script_module_const_submodule_fail(self):
6215  class Sub(torch.jit.ScriptModule):
6216  def __init__(self):
6217  super(Sub, self).__init__(False)
6218  self.weight = nn.Parameter(torch.randn(2))
6219 
6220  @torch.jit.script_method
6221  def forward(self, thing):
6222  return self.weight + thing
6223 
6224  class M(torch.jit.ScriptModule):
6225  def __init__(self):
6226  super(M, self).__init__(False)
6227  self.mods = [Sub() for _ in range(10)]
6228 
6229  @torch.jit.script_method
6230  def forward(self):
6231  for _ in self.mods:
6232  print(1)
6233  return 4
6234 
6235  with self.assertRaisesRegex(RuntimeError, "did you forget to add it __constants__"):
6236  M()
6237 
6238  # Specialized error for Tensors
6239  class S(torch.jit.ScriptModule):
6240  def __init__(self):
6241  self.tensor_constant = torch.ones(2)
6242 
6243  @torch.jit.script_method
6244  def forward(self):
6245  return self.tensor_constant + 2
6246 
6247  with self.assertRaisesRegex(RuntimeError, "Tensors must be added to a module as a buffer or parameter"):
6248  S()
6249 
6250  class DerivedStateModule(torch.jit.ScriptModule):
6251  def __init__(self):
6252  super(TestScript.DerivedStateModule, self).__init__()
6253  self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
6254  self.register_buffer('derived', torch.neg(self.param).detach().clone())
6255 
6256  # This is a flag so we can test that the pack method was called
6257  self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
6258  # This is a flag so we can test that the unpack method was called
6259  self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
6260 
6261  @torch.jit.script_method
6262  def _pack(self):
6263  self.pack_called.set_(torch.ones(1, dtype=torch.long))
6264  self.derived.set_(torch.rand(1, dtype=torch.float).detach())
6265 
6266  @torch.jit.script_method
6267  def _unpack(self):
6268  self.unpack_called.set_(torch.ones(1, dtype=torch.long))
6269  self.derived.set_(torch.neg(self.param).detach())
6270 
6271  @torch.jit.script_method
6272  def forward(self, x):
6273  return x + self.derived
6274 
6275  def test_pack_unpack_state(self):
6276  sm = TestScript.DerivedStateModule()
6277  x = torch.rand(3, 4, dtype=torch.float)
6278  torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
6279 
6280  # Test save path
6281  self.assertFalse(sm.pack_called.item())
6282  self.assertFalse(sm.unpack_called.item())
6283  imported = self.getExportImportCopyWithPacking(sm)
6284  # ensure pack was called before serialization
6285  self.assertTrue(sm.pack_called.item())
6286  # ensure unpack was called after serialization so as to leave the module in an initialized state
6287  self.assertTrue(sm.unpack_called.item())
6288 
6289  torch.testing.assert_allclose(sm.derived, torch.neg(sm.param))
6290 
6291  # Test load paths
6292  self.assertTrue(imported.unpack_called.item())
6293  torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
6294 
6295  def test_pack_unpack_nested(self):
6296  class SubSubMod(torch.jit.ScriptModule):
6297  def __init__(self):
6298  super(SubSubMod, self).__init__()
6299  self.register_buffer('buf', torch.ones(3, 4) * 3)
6300 
6301  @torch.jit.script_method
6302  def _pack(self):
6303  self.buf.set_(torch.zeros(1, dtype=torch.double))
6304 
6305  @torch.jit.script_method
6306  def _unpack(self):
6307  self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
6308 
6309  @torch.jit.script_method
6310  def forward(self, x):
6311  return x + self.buf
6312 
6313  class SubMod(torch.jit.ScriptModule):
6314  def __init__(self):
6315  super(SubMod, self).__init__()
6316  self.register_buffer('buf', torch.ones(3, 4) * 2)
6317  self.ssm = SubSubMod()
6318 
6319  @torch.jit.script_method
6320  def _pack(self):
6321  self.buf.set_(torch.zeros(1, dtype=torch.double))
6322 
6323  @torch.jit.script_method
6324  def _unpack(self):
6325  self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
6326 
6327  @torch.jit.script_method
6328  def forward(self, x):
6329  return self.ssm(x + self.buf)
6330 
6331  class Mod(torch.jit.ScriptModule):
6332  def __init__(self):
6333  super(Mod, self).__init__()
6334  self.submod = SubMod()
6335  self.register_buffer('buf', torch.ones(3, 4) * 1)
6336 
6337  @torch.jit.script_method
6338  def _pack(self):
6339  self.buf.set_(torch.zeros(1, dtype=torch.double))
6340 
6341  @torch.jit.script_method
6342  def _unpack(self):
6343  self.buf.set_(torch.ones(3, 4, dtype=torch.double))
6344 
6345  @torch.jit.script_method
6346  def forward(self, x):
6347  return self.submod(x + self.buf)
6348 
6349  m = Mod()
6350  torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
6351  m.apply(lambda s: s._pack())
6352  torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4))
6353  m.apply(lambda s: s._unpack())
6354  torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
6355 
6356  def test_script_module_not_tuple(self):
6357  class M(torch.jit.ScriptModule):
6358  __constants__ = ['mods']
6359 
6360  def __init__(self):
6361  super(M, self).__init__(False)
6362  self.mods = 1
6363 
6364  @torch.jit.script_method
6365  def forward(self, v):
6366  for m in self.mods:
6367  print(m)
6368  return v
6369  with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
6370  M()
6371 
6372  def test_script_module_list_sequential_error(self):
6373  class M(torch.jit.ScriptModule):
6374  def __init__(self, mod_list):
6375  super(M, self).__init__(False)
6376  self.mods = mod_list
6377 
6378  @torch.jit.script_method
6379  def forward(self, v):
6380  for m in self.mods:
6381  v = m(v)
6382  return v
6383 
6384  with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
6385  a = M(nn.Sequential(nn.ReLU()))
6386  with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
6387  a = M(nn.ModuleList([nn.ReLU()]))
6388 
6389  def test_script_sequential_for(self):
6390  class Sub(torch.jit.ScriptModule):
6391  def __init__(self):
6392  super(Sub, self).__init__(False)
6393  self.weight = nn.Parameter(torch.randn(2))
6394 
6395  @torch.jit.script_method
6396  def forward(self, thing):
6397  return self.weight + thing
6398 
6399  class M(torch.jit.ScriptModule):
6400  __constants__ = ['mods']
6401 
6402  def __init__(self):
6403  super(M, self).__init__(False)
6404  self.mods = nn.Sequential(Sub(), Sub(), Sub())
6405 
6406  @torch.jit.script_method
6407  def forward(self, v):
6408  for m in self.mods:
6409  v = m(v)
6410  return v
6411 
6412  @torch.jit.script_method
6413  def forward2(self, v):
6414  return self.mods(v)
6415 
6416  i = torch.Tensor(2)
6417  m = M()
6418  o = m(i)
6419  v = i
6420  for sub in m.mods:
6421  v = sub(v)
6422  self.assertEqual(o, v)
6423 
6424  o2 = m.forward2(i)
6425  self.assertEqual(o2, v)
6426 
6427  def test_script_sequential_multi_output_fail(self):
6428  class Sub(torch.jit.ScriptModule):
6429  def __init__(self):
6430  super(Sub, self).__init__(False)
6431  self.weight = nn.Parameter(torch.randn(2))
6432 
6433  @torch.jit.script_method
6434  def forward(self, thing):
6435  return self.weight + thing
6436 
6437  class ReturnMulti(torch.jit.ScriptModule):
6438  def __init__(self):
6439  super(ReturnMulti, self).__init__(False)
6440 
6441  @torch.jit.script_method
6442  def forward(self, x):
6443  return x, x, x
6444 
6445  class HaveSequential(torch.jit.ScriptModule):
6446  __constants__ = ['someseq']
6447 
6448  def __init__(self):
6449  super(HaveSequential, self).__init__(False)
6450  self.someseq = nn.Sequential(
6451  Sub(),
6452  ReturnMulti(),
6453  Sub()
6454  )
6455 
6456  @torch.jit.script_method
6457  def forward(self, x):
6458  return self.someseq(x)
6459 
6460  with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
6461  hs = HaveSequential()
6462  i = torch.Tensor(2)
6463  hs(i)
6464 
6465  def test_constant_insert_fail_lint(self):
6466  @torch.jit.script
6467  def foo(x):
6468  y = x + 1
6469  z = torch.tensor([[1.0, 2.5]])
6470  print(x, z)
6471 
6472  # check that it doesn't error
6473  self.run_pass('constant_propagation', foo.graph)
6474  self.assertTrue("aten::tensor" in str(foo.graph)) # not constant propped
6475 
6476  def test_script_sequential_in_mod_list(self):
6477  class Sub(torch.jit.ScriptModule):
6478  def __init__(self):
6479  super(Sub, self).__init__(False)
6480  self.weight = nn.Parameter(torch.randn(2))
6481 
6482  @torch.jit.script_method
6483  def forward(self, thing):
6484  return self.weight + thing
6485 
6486  class M(torch.jit.ScriptModule):
6487  __constants__ = ['mods']
6488 
6489  def __init__(self):
6490  super(M, self).__init__(False)
6491  self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])
6492 
6493  @torch.jit.script_method
6494  def forward(self, v):
6495  for mod in self.mods:
6496  v = mod(v)
6497  return v
6498 
6499  m = M()
6500  graph = str(m.graph)
6501  self.assertTrue(graph.count("aten::add") == 5)
6502  self.assertTrue("python" not in graph)
6503 
6504  def test_script_nested_mod_list(self):
6505  class Sub(torch.jit.ScriptModule):
6506  def __init__(self):
6507  super(Sub, self).__init__(False)
6508  self.weight = nn.Parameter(torch.randn(2))
6509 
6510  @torch.jit.script_method
6511  def forward(self, thing):
6512  return self.weight + thing
6513 
6514  class M(torch.jit.ScriptModule):
6515  __constants__ = ['mods']
6516 
6517  def __init__(self):
6518  super(M, self).__init__(False)
6519  self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])
6520 
6521  @torch.jit.script_method
6522  def forward(self, v):
6523  for mod in self.mods:
6524  for m in mod:
6525  v = m(v)
6526  return v
6527 
6528  m = M()
6529  graph = str(m.graph)
6530  self.assertTrue(graph.count("aten::add") == 4)
6531  self.assertTrue("python" not in graph)
6532 
6533  def test_constant_as_attr(self):
6534  class M(torch.jit.ScriptModule):
6535  __constants__ = ['dim']
6536 
6537  def __init__(self):
6538  super(M, self).__init__(False)
6539  self.dim = 1
6540 
6541  @torch.jit.script_method
6542  def forward(self, v):
6543  return torch.cat([v, v, v], dim=self.dim)
6544  v = torch.zeros(1, 1)
6545  self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
6546 
6547  class StarTestSumStarred(torch.nn.Module):
6548  def __init__(self):
6549  super(TestScript.StarTestSumStarred, self).__init__()
6550 
6551  def forward(self, *inputs):
6552  output = inputs[0]
6553  for i in range(1, len(inputs)):
6554  output += inputs[i]
6555  return output
6556 
6557  class StarTestReturnThree(torch.nn.Module):
6558  def __init__(self):
6559  super(TestScript.StarTestReturnThree, self).__init__()
6560 
6561  def forward(self, rep):
6562  return rep, rep, rep
6563 
6564  def test_script_star_expr(self):
6565 
6566  class M2(torch.jit.ScriptModule):
6567  def __init__(self):
6568  super(M2, self).__init__(True)
6569  self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
6570  (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
6571  self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
6572 
6573  @torch.jit.script_method
6574  def forward(self, rep):
6575  tup = self.g(rep)
6576  return self.m(*tup)
6577 
6578  m = M2()
6579  self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6580 
6581  def test_script_star_expr_string(self):
6582  class M2(torch.jit.ScriptModule):
6583  def __init__(self):
6584  super(M2, self).__init__(True)
6585  self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
6586  (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
6587  self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
6588 
6589  self.define('''
6590  def forward(self, rep):
6591  tup = self.g(rep)
6592  return self.m(*tup)
6593  ''')
6594 
6595  m = M2()
6596  self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6597 
6598  class StarTestSumAndReturnThree(torch.nn.Module):
6599  def __init__(self):
6600  super(TestScript.StarTestSumAndReturnThree, self).__init__()
6601 
6602  def forward(self, *inputs):
6603  output = inputs[0]
6604  for i in range(1, len(inputs)):
6605  output += inputs[i]
6606  return output, output, output
6607 
6608  def test_script_star_assign(self):
6609  class M2(torch.jit.ScriptModule):
6610  def __init__(self):
6611  super(M2, self).__init__(True)
6612  self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
6613  self.define('''
6614  def forward(self, rep):
6615  head, *tail = self.g(rep)
6616  return head
6617  ''')
6618 
6619  m = M2()
6620  self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6621 
6622  def test_script_module_star_assign2(self):
6623  class M2(torch.jit.ScriptModule):
6624  def __init__(self):
6625  super(M2, self).__init__(True)
6626  self.g = torch.jit.trace(
6627  TestScript.StarTestSumAndReturnThree(),
6628  (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
6629  _force_outplace=True)
6630  self.define('''
6631  def forward(self, rep):
6632  *head, tail = self.g(rep, rep, rep)
6633  return tail
6634  ''')
6635 
6636  m = M2()
6637  self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
6638 
6639  def test_script_module_star_assign2_inplace(self):
6640  class M2(torch.jit.ScriptModule):
6641  def __init__(self):
6642  super(M2, self).__init__(True)
6643  self.g = torch.jit.trace(
6644  TestScript.StarTestSumAndReturnThree(),
6645  (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
6646  _force_outplace=False)
6647  self.define('''
6648  def forward(self, rep):
6649  *head, tail = self.g(rep, rep, rep)
6650  return tail
6651  ''')
6652 
6653  m = M2()
6654  # since forward() makes three aliases to the input `rep` before passing
6655  # it to StarTestSumAndReturnThree(), the in-place behavior differs
6656  # from the out-of-place case above.
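# (output aliases rep, so each in-place += doubles it: 1 -> 2 -> 4)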
6657  self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
6658 
6659  def test_script_module_star_assign_fail_pythonop(self):
6660 
6661  with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
6662  class M2(torch.jit.ScriptModule):
6663  def __init__(self):
6664  super(M2, self).__init__(True)
6665 
6666  def myfunc():
6667  return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
6668 
6669  self.define('''
6670  def forward(self, rep):
6671  a, *b = myfunc()
6672  return a
6673  ''')
6674 
6675  m = M2()
6676  m(torch.zeros(4, 3))
6677 
6678  def test_script_module_star_assign_fail_builtin(self):
6679  with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
6680  class M2(torch.jit.ScriptModule):
6681  def __init__(self):
6682  super(M2, self).__init__(True)
6683 
6684  self.define('''
6685  def forward(self, rep):
6686  a, *b = torch.neg(rep)
6687  return a
6688  ''')
6689 
6690  m = M2()
6691  m(torch.zeros(4, 3))
6692 
6693  def test_pack_padded_pad_packed_trace(self):
6694  from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6695  T, B, C = 3, 5, 7
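# T: sequence length, B: batch size, C: feature size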
6696 
6697  class PadPackedWrapper(torch.nn.Module):
6698  def __init__(self):
6699  super(PadPackedWrapper, self).__init__()
6700 
6701  def forward(self, x, seq_lens):
6702  x = pack_padded_sequence(x, seq_lens)
6703  x, _ = pad_packed_sequence(x)
6704  return x
6705 
6706  x = np.ones((T, B, C))
6707  seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
6708  # set padding value so we can test equivalence
6709  for b in range(B):
6710  if seq_lens[b] < T:
6711  x[seq_lens[b]:, b, :] = 0
6712  seq_lens = torch.from_numpy(seq_lens)
6713  x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)
6714 
6715  m = PadPackedWrapper()
6716  m_traced = torch.jit.trace(m, (x, seq_lens,))
6717 
6718  y = m(x, seq_lens)
6719  loss = torch.sum(y)
6720  loss.backward()
6721  grad = x.grad.clone()
6722  x.grad.zero_()
6723 
6724  y_traced = m_traced(x, seq_lens)
6725  loss_traced = torch.sum(y_traced)
6726  loss_traced.backward()
6727  grad_traced = x.grad.clone()
6728 
6729  self.assertEqual(y_traced, x)
6730  self.assertEqual(y_traced, y)
6731  self.assertEqual(grad, grad_traced)
6732 
6733  f = io.BytesIO()
6734  torch.onnx._export(m, (x, seq_lens), f, verbose=False)
6735 
6736  def test_script_outputs(self):
6737  with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
6738  @torch.jit.script
6739  def foo(a):
6740  c, d = a + a
6741  return c + d
6742 
6743  @torch.jit.script
6744  def return3():
6745  return 1, 2, 3
6746 
6747  with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
6748  @torch.jit.script
6749  def bind2():
6750  a, b = return3()
6751  print(a)
6752  print(b)
6753 
6754  @unittest.skipIf(not RUN_CUDA, "requires CUDA")
6755  def test_script_get_device_cuda(self):
6756  @torch.jit.script
6757  def foo(a):
6758  return a.get_device()
6759 
6760  v = torch.randn(1, device='cuda')
6761  self.assertEqual(foo(v), 0)
6762 
6763  def test_script_chunk(self):
6764  @torch.jit.script
6765  def foo(a):
6766  b, c = torch.chunk(a, dim=0, chunks=2)
6767  return b
6768  v = torch.rand(10, 3)
6769  self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
6770 
6771  def test_rnn_trace_override(self):
6772  from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6773  num_layers = 3
6774  T, B, C = 11, 5, 7
6775 
6776  class RNNTraceWrapper(torch.nn.Module):
6777  def __init__(self, cell_type):
6778  super(RNNTraceWrapper, self).__init__()
6779  if cell_type == 'RNN':
6780  self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
6781  elif cell_type == 'LSTM':
6782  self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
6783  elif cell_type == 'GRU':
6784  self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)
6785 
6786  def forward(self, x, seq_lens):
6787  x = pack_padded_sequence(x, seq_lens)
6788  x, _ = self.rnn(x)
6789  x, _ = pad_packed_sequence(x)
6790  return x
6791 
6792  for cell_type in ['RNN', 'LSTM', 'GRU']:
6793  x = torch.ones(T, B, C, requires_grad=True)
6794  seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
6795 
6796  m = RNNTraceWrapper(cell_type)
6797  m_traced = torch.jit.trace(m, (x, seq_lens,))
6798 
6799  y = m(x, seq_lens)
6800  loss = torch.sum(y)
6801  loss.backward()
6802  grad = x.grad.clone()
6803  x.grad.zero_()
6804 
6805  y_traced = m_traced(x, seq_lens)
6806  loss_traced = torch.sum(y_traced)
6807  loss_traced.backward()
6808  grad_traced = x.grad.clone()
6809 
6810  self.assertEqual(y_traced, y)
6811  self.assertEqual(grad, grad_traced)
6812 
6813  f = io.BytesIO()
6814  torch.onnx._export(m, (x, seq_lens), f, verbose=False)
6815 
6816  def test_python_call_non_tensor(self):
6817  def foo(a, b, c):
6818  # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
6819  d, e = c
6820  return b + e, a + d
6821 
6822  @torch.jit.script
6823  def bar():
6824  x = torch.ones(3, 4)
6825  a, b = foo(x, 3, (x, 3))
6826  return a, b
6827 
6828  self.assertEqual((6, torch.ones(3, 4) + 1), bar())
6829 
6830  def test_python_call_non_tensor_wrong(self):
6831  with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
6832  def foo():
6833  # type: () -> Tensor
6834  return ((3, 4),) # noqa: T484
6835 
6836  @torch.jit.script
6837  def bar():
6838  return foo()
6839 
6840  bar()
6841 
6842  def test_tuples(self):
6843  def foo(i):
6844  a = (i + 4, i * 2)
6845  c = a
6846  # some nonsense with if-statements and loops to check
6847  # that tuple lowering doesn't fail
6848  if True:
6849  c = (i * 9, i + 1)
6850  t0, t1 = c
6851  while False:
6852  t0, t1 = c
6853  c = (t1, t0)
6854  x = (1,)
6855  y = 1,
6856  return t0, x, y
6857 
6858  v = torch.rand(10, 3)
6859  self.checkScript(foo, (v,))
6860 
6861  with self.assertRaisesRegex(RuntimeError, r"variable 'a' previously has type \(Tensor, Tensor\)"):
6862  @torch.jit.script
6863  def mixtypes(x):
6864  a = (x, x)
6865  if True:
6866  a = 4
6867 
6868  def test_if_tuple_sizes(self):
6869  with self.assertRaisesRegex(RuntimeError, "Type mismatch"):
6870  @torch.jit.script
6871  def diff_tuple_sizes(x):
6872  if False:
6873  c0 = ((x, x), (x, x, x))
6874  else:
6875  c0 = ((x, x, x), (x, x))
6876  return c0
6877 
6878  def test_if_different_type(self):
6879  with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int "
6880  "in the true branch and type float in the false branch:"):
6881  @torch.jit.script
6882  def diff_type_used():
6883  if False:
6884  c0 = 1
6885  else:
6886  c0 = 1.0
6887  return c0
6888 
6889  with self.assertRaisesRegex(RuntimeError, "variable 'c0' previously has type float"):
6890  @torch.jit.script
6891  def diff_existing_type(x):
6892  c0 = 1.0
6893  if False:
6894  c0 = 1
6895  print(x)
6896  return x
6897 
6898  @torch.jit.script
6899  def diff_type_unused():
6900  if True:
6901  c0 = 1
6902  print(c0)
6903  else:
6904  c0 = 1.0
6905  print(c0)
6906  return 1
6907 
6908  def test_if_list_cat(self):
6909  # testing that lists of different lengths don't throw an error on cat in shape prop
6910  @torch.jit.script
6911  def test_list(x):
6912  if bool(x.sum() < 1):
6913  c = [x, x]
6914  else:
6915  c = [x, x, x]
6916  return torch.cat(c)
6917 
6918  b = torch.zeros(2, 4)
6919  test_list.graph.propagate_shapes((b,), False)
6920 
6921  def test_if_supertype(self):
6922  @torch.jit.script
6923  def tensor_unifying(x, y, z):
6924  # testing that the dynamic type is appropriately set for y and z
6925  if True:
6926  x, y, z = x, y, z
6927  else:
6928  x, y, z = x, x, y
6929 
6930  return x, y, z
6931 
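# After shape propagation, x keeps its specialized type because both
# branches assign it the same value; y and z unify to the generic Tensor
# type because the branch types disagree, as asserted below.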
6932  a = torch.zeros(2, 2, dtype=torch.float)
6933  b = torch.zeros(2, 4, dtype=torch.long)
6934  c = torch.zeros(2, 4, dtype=torch.float)
6935 
6936  tensor_unifying.graph.propagate_shapes((a, b, c), False)
6937  if_outputs = list(tensor_unifying.graph.findNode("prim::If").outputs())
6938  self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
6939  self.assertTrue(if_outputs[1].type().str() == "Tensor")
6940  self.assertTrue(if_outputs[2].type().str() == "Tensor")
6941 
6942  def test_list_unify(self):
6943  # allowing a unified int?[] would cause a runtime error because
6944  # the index operation expects int?[] to be a generic list,
6945  # but in the true branch the IValue will be an int list
6946  with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"):
6947  @torch.jit.script
6948  def list_optional_fails(x):
6949  # type: (bool) -> Optional[int]
6950  if x:
6951  y = [1]
6952  else:
6953  y = [None] # noqa: T484
6954  return y[0]
6955 
6956  @torch.jit.script
6957  def list_tensors(x):
6958  # type: (bool) -> Tuple[Tensor, List[Tensor]]
6959  if x:
6960  a = torch.zeros([1, 1])
6961  y = [a]
6962  else:
6963  a = torch.zeros([1, 2])
6964  y = [a]
6965  return a, y
6966 
6967  self.run_pass('constant_propagation', list_tensors.graph)
6968  m = torch.jit.ScriptModule()
6969  m._create_method_from_graph("forward", list_tensors.graph)
6970  # testing that tensor type of lists is unified
6971  self.getExportImportCopy(m)
6972 
6973  def test_type_annotations_repeated_list(self):
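# BroadcastingList3[T] accepts either a single T (broadcast to a list of
# length 3) or a sequence of exactly three Ts.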
6974  @torch.jit.script
6975  def float_fn(x, y):
6976  # type: (float, BroadcastingList3[float]) -> List[float]
6977  return y
6978  self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
6979  self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
6980 
6981  @torch.jit.script
6982  def float_fn_call():
6983  print(float_fn(1.0, 1.0))
6984  print(float_fn(1.0, (1.0, 1.0, 1.0)))
6985 
6986  @torch.jit.script
6987  def int_fn(x):
6988  # type: (BroadcastingList3[int]) -> List[int]
6989  return x
6990  self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
6991  self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
6992 
6993  @torch.jit.script
6994  def int_fn_call():
6995  print(int_fn(1))
6996  print(int_fn((1, 1, 1)))
6997 
6998  with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
6999  @torch.jit.script # noqa: T484
7000  def fn(x):
7001  # type: (BroadcastingListx[int]) -> List[int] # noqa: T484
7002  return x
7003 
7004  # using CU so that flake8 error on int[2] is not raised (noqa not working)
7005  with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
7006  cu = torch.jit.CompilationUnit('''
7007  def nested(x, y):
7008  # type: (int, Tuple[int, int[2]]) -> List[int]
7009  return x # noqa: T484
7010  ''')
7011 
7012  def test_ntuple_builtins(self):
7013  from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
7014 
7015  def test_ints():
7016  return _single(1), _pair(2), _triple(3), _quadruple(4)
7017 
7018  def test_floats():
7019  return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
7020 
7021  self.checkScript(test_ints, ())
7022  self.checkScript(test_floats, ())
7023 
7024  def test_embedding_renorm_grad_error(self):
7025  # Testing that the builtin call to embedding_renorm_ correctly throws
7026  # an error when .backward() is called on its input
7027 
7028  def embedding_norm(input, embedding_matrix, max_norm):
7029  F.embedding(input, embedding_matrix, max_norm=0.01)
7030 
7031  @torch.jit.script
7032  def embedding_norm_script(input, embedding_matrix, max_norm):
7033  # type: (Tensor, Tensor, float) -> None
7034  F.embedding(input, embedding_matrix, max_norm=0.01)
7035 
7036  for _ in [embedding_norm, embedding_norm_script]:
7037  input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
7038  embedding_matrix = torch.randn(10, 3)
7039 
7040  var1 = torch.randn(10, 3, requires_grad=True)
7041  var2 = var1.detach().requires_grad_()
7042  output1 = var1 * embedding_matrix
7043  output2 = var2 * embedding_matrix
7044 
7045  output1.sum().backward()
7046 
7047  ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
7048  with self.assertRaisesRegex(RuntimeError, "modified"):
7049  output2.sum().backward()
7050 
7051  def test_type_annotations(self):
7052  def fn(x, y):
7053  # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
7054  return x, x * 2, x * 3
7055 
7056  with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
7057  @torch.jit.script
7058  def script_fn(x):
7059  x, y, z, w = fn(x, x)
7060 
7061  with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
7062  @torch.jit.script
7063  def script_fn2(x):
7064  x, y = fn(x, x)
7065 
7066  def fn_unpack(x):
7067  y, z, w = fn(x, x)
7068  return y
7069 
7070  def fn_index(x):
7071  q = fn(x, x)
7072  return x
7073 
7074  def fn_string(str, strpair):
7075  # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
7076  str1, str2 = strpair
7077  return str, 2, str1, str2
7078 
7079  x = torch.ones(2, 2)
7080  self.checkScript(fn_unpack, (x,), optimize=True)
7081  self.checkScript(fn_index, (x,), optimize=True)
7082  self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
7083 
7084  def test_type_annotations_varargs(self):
7085  def fn_varargs(x, *args):
7086  return args[0] if args else x
7087 
7088  def fn1(x, y, z):
7089  return fn_varargs(x)
7090 
7091  def fn2(x, y, z):
7092  return fn_varargs(x, y)
7093 
7094  def fn3(x, y, z):
7095  return fn_varargs(x, y, z)
7096 
7097  x, y, z = [torch.randn(2, 2) for _ in range(3)]
7098  self.checkScript(fn1, (x, y, z), optimize=True)
7099  self.checkScript(fn2, (x, y, z), optimize=True)
7100  self.checkScript(fn3, (x, y, z), optimize=True)
7101 
7102  @unittest.skipIf(not PY35, "Python 3.5 needed")
7103  def test_type_annotation_py3(self):
7104  import importlib.util
7105 
7106  code = dedent("""
7107  import torch
7108  from torch import Tensor
7109  from typing import Tuple
7110 
7111  def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
7112  return (x, y + z, z)
7113  """)
7114 
7115  with tempfile.TemporaryDirectory() as tmp_dir:
7116  script_path = os.path.join(tmp_dir, 'script.py')
7117  with open(script_path, 'w') as f:
7118  f.write(code)
7119  fn = get_fn('test_type_annotation_py3', script_path)
7120 
7121  with self.assertRaisesRegex(RuntimeError, r"expected a value of type Tensor for argument"
7122  r" '0' but found \(Tensor, Tensor\)"):
7123  @torch.jit.script
7124  def bad_fn(x):
7125  x, y = fn((x, x), x, x)
7126  return y
7127 
7128  with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
7129  @torch.jit.script
7130  def bad_fn2(x):
7131  x, y = fn(x, x, x)
7132  return y
7133 
7134  with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
7135  @torch.jit.script
7136  def bad_fn3(x):
7137  x, y, z, w = fn(x, x, x)
7138  return y
7139 
7140  def good_fn(x):
7141  y, z, w = fn(x, x, x)
7142  return y, z, w
7143 
7144  self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
7145 
7146  def test_type_annotation_module(self):
7147  class BaseModule(torch.jit.ScriptModule):
7148  def foo(self, x):
7149  # type: (Tensor) -> Tensor
7150  return x + 1
7151 
7152  def bar(self, x, y):
7153  # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
7154  return x + y, y
7155 
7156  def baz(self, x, y):
7157  return x
7158 
7159  class ModuleTooMany(BaseModule):
7160  @torch.jit.script_method
7161  def method(self, x):
7162  return self.foo(x, x)
7163 
7164  class ModuleTooFew(BaseModule):
7165  @torch.jit.script_method
7166  def method(self, x):
7167  return self.bar(x)
7168 
7169  class ModuleTooManyAssign(BaseModule):
7170  @torch.jit.script_method
7171  def method(self, x):
7172  y, z, w = self.bar(x, x)
7173  return x
7174 
7175  class ModuleDefault(BaseModule):
7176  @torch.jit.script_method
7177  def method(self, x):
7178  y = self.baz(x)
7179  return x
7180 
7181  with self.assertRaisesRegex(RuntimeError, "expected at most 1 arguments but found 2"):
7182  ModuleTooMany()
7183  with self.assertRaisesRegex(RuntimeError, "argument 1 not provided"):
7184  ModuleTooFew()
7185  with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
7186  ModuleTooManyAssign()
7187  with self.assertRaisesRegex(RuntimeError, "argument 1 not provided."):
7188  ModuleDefault()
7189 
7190  def test_script_define_order(self):
7191  class M(torch.jit.ScriptModule):
7192  def __init__(self):
7193  pass
7194 
7195  @torch.jit.script_method
7196  def call_foo(self, input):
7197  return self.foo(input)
7198 
7199  @torch.jit.script_method
7200  def foo(self, input):
7201  return input + 1
7202  m = M()
7203  self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
7204 
7205  def test_script_define_order_recursive_fail(self):
7206  class M(torch.jit.ScriptModule):
7207  def __init__(self):
7208  pass
7209 
7210  @torch.jit.script_method
7211  def call_foo(self, input):
7212  return self.foo(input)
7213 
7214  @torch.jit.script_method
7215  def foo(self, input):
7216  self.call_foo(input)
7217 
7218  with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
7219  M()
7220 
7221  def test_script_kwargs_fn_call(self):
7222  class M(torch.jit.ScriptModule):
7223  def __init__(self):
7224  pass
7225 
7226  @torch.jit.script_method
7227  def call_foo(self, input):
7228  return self.foo(input=input, bar=1)
7229 
7230  @torch.jit.script_method
7231  def foo(self, bar, input):
7232  # type: (int, Tensor) -> Tensor
7233  return input + bar
7234  m = M()
7235  self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
7236 
7237  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
7238  def test_trace_of_script(self):
7239  @torch.jit.script
7240  def foo(a, c):
7241  b = 0.0
7242  if bool(a == 0.0):
7243  b = 1.0
7244  return b + c
7245 
7246  a = torch.ones(1, dtype=torch.float)
7247 
7248  @_trace(torch.zeros(1, dtype=torch.float))
7249  def use(b):
7250  return foo(b - 1.0, a) + 1.0
7251 
7252  # test that we propagated shapes through the function
7253  self.assertTrue("Dynamic" not in str(use.graph))
7254 
7255  self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
7256  self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
7257 
7258  def test_if_define(self):
7259  @torch.jit.script
7260  def foo(a):
7261  if bool(a == 0):
7262  b = 1
7263  else:
7264  b = 0
7265  return b + 1
7266 
7267  @torch.jit.script
7268  def foo2(a):
7269  b = 0
7270  if bool(a == 0):
7271  b = 1
7272  return b + 1
7273 
7274  @torch.jit.script
7275  def foo3(a):
7276  b = 1
7277  if bool(a == 0):
7278  c = 4
7279  else:
7280  b = 0
7281  return b + 1
7282 
7283  a = torch.ones(1, dtype=torch.long)
7284  b = torch.zeros(1, dtype=torch.long)
7285  self.assertEqual(1, foo(a))
7286  self.assertEqual(2, foo(b))
7287  self.assertEqual(1, foo2(a))
7288  self.assertEqual(2, foo2(b))
7289  self.assertEqual(1, foo3(a))
7290  self.assertEqual(2, foo3(b))
7291 
7292  def test_script_module_export_submodule(self):
7293  class M1(torch.jit.ScriptModule):
7294  def __init__(self):
7295  super(M1, self).__init__(False)
7296  self.weight = nn.Parameter(torch.randn(2))
7297 
7298  @torch.jit.script_method
7299  def forward(self, thing):
7300  return self.weight + thing
7301 
7302  class M2(torch.jit.ScriptModule):
7303  def __init__(self):
7304  super(M2, self).__init__(False)
7305  # test submodule
7306  self.sub = M1()
7307  self.weight = nn.Parameter(torch.randn(2, 3))
7308  self.bias = nn.Parameter(torch.randn(2))
7309  self.define("""
7310  def hi(self, a):
7311  return self.weight.mm(a)
7312  """)
7313 
7314  @torch.jit.script_method
7315  def doit(self, input):
7316  return self.weight.mm(input)
7317 
7318  @torch.jit.script_method
7319  def doit2(self, input):
7320  return self.weight.mm(input)
7321 
7322  @torch.jit.script_method
7323  def doit3(self, input):
7324  return input + torch.ones([1], dtype=torch.double)
7325 
7326  @torch.jit.script_method
7327  def forward(self, input):
7328  a = self.doit(input)
7329  b = self.doit2(input)
7330  c = self.hi(input)
7331  return a + b + self.bias + c
7332 
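 # getExportImportCopy round-trips the module through export and
 # re-import; the copy must behave identically to the original.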
7333  m_orig = M2()
7334  m_import = self.getExportImportCopy(m_orig)
7335 
7336  input = torch.randn(3, 2)
7337  self.assertEqual(m_orig.doit(input), m_import.doit(input))
7338  self.assertEqual(m_orig.hi(input), m_import.hi(input))
7339  self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
7340  self.assertEqual(m_orig.forward(input), m_import.forward(input))
7341 
7342  @skipIfNoTorchVision
7343  def test_script_module_trace_resnet18(self):
7344  x = torch.ones(1, 3, 224, 224)
7345  m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
7346  m_import = self.getExportImportCopy(m_orig)
7347 
7348  input = torch.randn(1, 3, 224, 224, requires_grad=True)
7349  output_orig = m_orig(input)
7350  output_orig.sum().backward()
7351  grad_orig = input.grad.clone()
7352  input.grad.zero_()
7353 
7354  output_import = m_import(input)
7355  output_import.sum().backward()
7356  grad_import = input.grad.clone()
7357 
7358  self.assertEqual(output_orig, output_import)
7359  self.assertEqual(grad_orig, grad_import)
7360 
7361  @skipIfNoTorchVision
7362  def test_script_module_script_resnet(self):
7363  def conv1x1(in_planes, out_planes, stride=1):
7364  """1x1 convolution"""
7365  return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
7366 
7367  def conv3x3(in_planes, out_planes, stride=1):
7368  """3x3 convolution with padding"""
7369  return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7370  padding=1, bias=False)
7371 
7372  class BasicBlock(torch.jit.ScriptModule):
7373  expansion = 1
7374  __constants__ = ['downsample']
7375 
7376  def __init__(self, inplanes, planes, stride=1, downsample=None):
7377  super(BasicBlock, self).__init__()
7378  self.conv1 = conv3x3(inplanes, planes, stride)
7379  self.bn1 = nn.BatchNorm2d(planes)
7380  self.relu = nn.ReLU(inplace=True)
7381  self.conv2 = conv3x3(planes, planes)
7382  self.bn2 = nn.BatchNorm2d(planes)
7383  self.downsample = downsample
7384  self.stride = stride
7385 
7386  @torch.jit.script_method
7387  def forward(self, x):
7388  residual = x
7389 
7390  out = self.conv1(x)
7391  out = self.bn1(out)
7392  out = self.relu(out)
7393 
7394  out = self.conv2(out)
7395  out = self.bn2(out)
7396 
7397  if self.downsample is not None:
7398  residual = self.downsample(x)
7399 
7400  out += residual
7401  out = self.relu(out)
7402 
7403  return out
7404 
7405  class ResNet(torch.jit.ScriptModule):
7406  __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
7407 
7408  def __init__(self, block, layers, num_classes=1000):
7409  super(ResNet, self).__init__()
7410  self.inplanes = 64
7411  self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
7412  bias=False)
7413  self.bn1 = nn.BatchNorm2d(64)
7414  self.relu = nn.ReLU(inplace=True)
7415  self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
7416  self.layer1 = self._make_layer(block, 64, layers[0])
7417  self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
7418  self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
7419  self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
7420  self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
7421  self.fc = nn.Linear(512 * block.expansion, num_classes)
7422 
7423  for m in self.modules():
7424  if isinstance(m, nn.Conv2d):
7425  nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
7426  elif isinstance(m, nn.BatchNorm2d):
7427  nn.init.constant_(m.weight, 1)
7428  nn.init.constant_(m.bias, 0)
7429 
7430  def _make_layer(self, block, planes, blocks, stride=1):
7431  downsample = None
7432  if stride != 1 or self.inplanes != planes * block.expansion:
7433  downsample = nn.Sequential(
7434  conv1x1(self.inplanes, planes * block.expansion, stride),
7435  nn.BatchNorm2d(planes * block.expansion),
7436  )
7437 
7438  layers = []
7439  layers.append(block(self.inplanes, planes, stride, downsample))
7440  self.inplanes = planes * block.expansion
7441  for _ in range(1, blocks):
7442  layers.append(block(self.inplanes, planes))
7443 
7444  return nn.Sequential(*layers)
7445 
7446  @torch.jit.script_method
7447  def forward(self, x):
7448  x = self.conv1(x)
7449  x = self.bn1(x)
7450  x = self.relu(x)
7451  x = self.maxpool(x)
7452 
7453  x = self.layer1(x)
7454  x = self.layer2(x)
7455  x = self.layer3(x)
7456  x = self.layer4(x)
7457 
7458  x = self.avgpool(x)
7459  x = x.view(x.size(0), -1)
7460  x = self.fc(x)
7461 
7462  return x
7463 
7464  resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
7465 
7466  resnet18_imported = self.getExportImportCopy(resnet18)
7467 
7468  input = torch.randn(1, 3, 224, 224, requires_grad=True)
7469  output_orig = resnet18(input)
7470  output_orig.sum().backward()
7471  grad_orig = input.grad.clone()
7472  input.grad.zero_()
7473  output_import = resnet18_imported(input)
7474  output_import.sum().backward()
7475  grad_import = input.grad.clone()
7476 
7477  self.assertEqual(output_orig, output_import)
7478  self.assertEqual(grad_orig, grad_import)
7479 
7480  def test_script_module_export_tensor_type(self):
7481  class M(torch.jit.ScriptModule):
7482 
7483  def __init__(self, type):
7484  super(M, self).__init__(False)
7485  self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
7486 
7487  @torch.jit.script_method
7488  def foo(self):
7489  return self.param
7490 
7491  for type in [torch.float, torch.double]:
7492  m_orig = M(type)
7493  m_import = self.getExportImportCopy(m_orig)
7494  # check to make sure the storage wasn't resized
7495  self.assertTrue(m_orig.param.storage().size() == 25)
7496  self.assertEqual(m_orig.foo(), m_import.foo())
7497  self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
7498 
7499  @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
7500  def test_script_module_export_tensor_cuda(self):
7501  class M(torch.jit.ScriptModule):
7502 
7503  def __init__(self):
7504  super(M, self).__init__(False)
7505  self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
7506 
7507  @torch.jit.script_method
7508  def foo(self):
7509  return self.param
7510 
7511  m_orig = M()
7512  m_import = self.getExportImportCopy(m_orig)
7513  # check to make sure the storage wasn't resized
7514  self.assertTrue(m_orig.param.storage().size() == 25)
7515  self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
7516  self.assertEqual(m_orig.foo(), m_import.foo())
7517  self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
7518 
7519  def test_script_module_export_blocks(self):
7520  class M(torch.jit.ScriptModule):
7521  def __init__(self, n, m):
7522  super(M, self).__init__()
7523  self.weight = torch.nn.Parameter(torch.rand(n, m))
7524 
7525  @torch.jit.script_method
7526  def forward(self, input):
7527  if bool(input.sum() > 0):
7528  output = self.weight.mv(input)
7529  else:
7530  output = self.weight + input
7531  return output
7532 
7533  m_orig = M(200, 200)
7534  m_import = self.getExportImportCopy(m_orig)
7535 
7536  t = torch.rand(200)
7537  self.assertEqual(m_orig(t), m_import(t))
7538 
7539  def test_script_module_export_shared_storage(self):
7540  class M(torch.jit.ScriptModule):
7541 
7542  def __init__(self):
7543  super(M, self).__init__(False)
7544  self.param1 = torch.nn.Parameter(torch.rand(5, 5))
7545  self.param2 = torch.nn.Parameter(self.param1[3])
7546  self.param3 = torch.nn.Parameter(torch.rand(5, 5))
7547  self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
7548 
7549  @torch.jit.script_method
7550  def foo(self):
7551  return self.param1 + self.param2 + self.param3 + self.param4
7552 
7553  m_orig = M()
7554  m_import = self.getExportImportCopy(m_orig)
7555 
7556  self.assertEqual(m_orig.foo(), m_import.foo())
7557  self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
7558  self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
7559 
7560  def test_onnx_export_script_module(self):
7561  class ModuleToExport(torch.jit.ScriptModule):
7562  def __init__(self):
7563  super(ModuleToExport, self).__init__()
7564 
7565  @torch.jit.script_method
7566  def forward(self, x):
7567  y = x - x
7568  return x + x
7569 
7570  mte = ModuleToExport()
7571  outputs = mte(torch.zeros(1, 2, 3))
7572  self.assertExpected(torch.onnx.export_to_pretty_string(
7573  mte, (torch.zeros(1, 2, 3),), None, verbose=False,
7574  example_outputs=outputs))
7575 
7576  def test_trace_nested_datatypes(self):
7577  @torch.jit.script
7578  def foo(x):
7579  return [[x + 1, x - 1], [x + 2, x - 2]]
7580 
7581  def bar(x):
7582  list_stuff = foo(x)
7583  return list_stuff[0][0], list_stuff[1][1]
7584 
7585  traced = torch.jit.trace(bar, torch.rand(3, 4))
7586  x = torch.rand(5, 6)
7587  self.assertEqual(bar(x), traced(x))
7588 
7589  @suppress_warnings
7590  def test_onnx_export_func_with_warnings(self):
7591  @torch.jit.script
7592  def func_with_warning(inp):
7593  return torch.nn.functional.sigmoid(inp) # triggers a deprecation warning
7594 
7595  class WarningTest(torch.nn.Module):
7596  def __init__(self):
7597  super(WarningTest, self).__init__()
7598 
7599  def forward(self, x):
7600  return func_with_warning(x)
7601 
7602  outputs = WarningTest()(torch.randn(42))
7603  # no exception
7604  torch.onnx.export_to_pretty_string(
7605  WarningTest(), torch.randn(42), None, verbose=False,
7606  example_outputs=outputs)
7607 
7608  def test_onnx_export_script_python_fail(self):
7609  class ModuleToInline(torch.jit.ScriptModule):
7610  def __init__(self):
7611  super(ModuleToInline, self).__init__()
7612 
7613  def forward(self, x):
7614  return torch.neg(x)
7615 
7616  class ModuleToExport(torch.jit.ScriptModule):
7617  def __init__(self):
7618  super(ModuleToExport, self).__init__()
7619  self.mod = ModuleToInline()
7620 
7621  @torch.jit.script_method
7622  def forward(self, x):
7623  y = self.mod(x)
7624  return y + y
7625 
7626  mte = ModuleToExport()
7627  outputs = mte(torch.zeros(1, 2, 3))
7628  f = io.BytesIO()
7629  with self.assertRaisesRegex(RuntimeError, "Couldn't export Python operator"):
7630  torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
7631  example_outputs=outputs)
7632 
7633  def test_onnx_export_script_inline_trace(self):
7634  class ModuleToInline(torch.jit.ScriptModule):
7635  def __init__(self):
7636  super(ModuleToInline, self).__init__()
7637 
7638  def forward(self, x):
7639  return torch.neg(x)
7640 
7641  class ModuleToExport(torch.jit.ScriptModule):
7642  def __init__(self):
7643  super(ModuleToExport, self).__init__()
7644  self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))
7645 
7646  @torch.jit.script_method
7647  def forward(self, x):
7648  y = self.mod(x)
7649  return y + y
7650 
7651  mte = ModuleToExport()
7652  outputs = mte(torch.zeros(1, 2, 3))
7653  self.assertExpected(torch.onnx.export_to_pretty_string(
7654  mte, (torch.zeros(1, 2, 3),), None, verbose=False,
7655  example_outputs=outputs))
7656 
7657  def test_onnx_export_script_inline_script(self):
7658  class ModuleToInline(torch.jit.ScriptModule):
7659  def __init__(self):
7660  super(ModuleToInline, self).__init__()
7661 
7662  @torch.jit.script_method
7663  def forward(self, x):
7664  return torch.neg(x)
7665 
7666  class ModuleToExport(torch.jit.ScriptModule):
7667  def __init__(self):
7668  super(ModuleToExport, self).__init__()
7669  self.mod = ModuleToInline()
7670 
7671  @torch.jit.script_method
7672  def forward(self, x):
7673  y = self.mod(x)
7674  return y + y
7675 
7676  mte = ModuleToExport()
7677  outputs = mte(torch.zeros(1, 2, 3))
7678  self.assertExpected(torch.onnx.export_to_pretty_string(
7679  mte, (torch.zeros(1, 2, 3),), None, verbose=False,
7680  example_outputs=outputs))
7681 
7682  def test_onnx_export_script_module_loop(self):
7683  class ModuleToExport(torch.jit.ScriptModule):
7684  def __init__(self):
7685  super(ModuleToExport, self).__init__()
7686 
7687  @torch.jit.script_method
7688  def forward(self, x):
7689  # test that we support end-to-end ONNX export of loops and
7690  # nested loops, with and without a loop index
7691  for _ in range(5):
7692  for i in range(3):
7693  x = x + i
7694  return x
7695 
7696  mte = ModuleToExport()
7697  outputs = mte(torch.zeros(1, 2, 3))
7698  self.assertExpected(torch.onnx.export_to_pretty_string(
7699  mte, (torch.zeros(1, 2, 3),), None, verbose=False,
7700  example_outputs=outputs))
7701 
7702  def test_onnx_export_script_truediv(self):
7703  class ModuleToExport(torch.jit.ScriptModule):
7704  def __init__(self):
7705  super(ModuleToExport, self).__init__()
7706 
7707  @torch.jit.script_method
7708  def forward(self, x):
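 # x.size(0) is an int, so true division by 2 produces a float that
 # the export must be able to represent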
7709  z = x.size(0) / 2
7710  return x + z
7711 
7712  mte = ModuleToExport()
7713  outputs = mte(torch.zeros(1, 2, 3))
7714  self.assertExpected(torch.onnx.export_to_pretty_string(
7715  mte, (torch.zeros(1, 2, 3),), None, verbose=False,
7716  example_outputs=outputs))
7717 
7718  def test_onnx_raw_export_script_truediv(self):
7719  class ModuleToExport(torch.jit.ScriptModule):
7720  def __init__(self):
7721  super(ModuleToExport, self).__init__()
7722 
7723  @torch.jit.script_method
7724  def forward(self, x):
7725  z = x.size(0) / 2
7726  return x + z
7727 
7728  mte = ModuleToExport()
7729  outputs = mte(torch.zeros(1, 2, 3))
7730  self.assertExpected(torch.onnx.export_to_pretty_string(
7731  mte, (torch.zeros(1, 2, 3),), None, verbose=False,
7732  example_outputs=outputs, export_raw_ir=True))
7733 
7734  def test_onnx_export_script_non_alpha_add_sub(self):
7735  class ModuleToExport(torch.jit.ScriptModule):
7736  def __init__(self):
7737  super(ModuleToExport, self).__init__()
7738 
7739  @torch.jit.script_method
7740  def forward(self, x):
7741  bs = x.size(0) + 1
7742  return bs - 1
7743 
7744  mte = ModuleToExport()
7745  outputs = torch.LongTensor([mte(torch.rand(3, 4))])
7746  self.assertExpected(torch.onnx.export_to_pretty_string(
7747  mte, (torch.rand(3, 4),), None, verbose=False,
7748  example_outputs=outputs))
7749 
7750  def test_onnx_export_script_module_if(self):
7751  class ModuleToExport(torch.jit.ScriptModule):
7752  def __init__(self):
7753  super(ModuleToExport, self).__init__()
7754 
7755  @torch.jit.script_method
7756  def forward(self, x):
7757  if bool(torch.sum(x) > 0):
7758  x = torch.neg(x)
7759  return x
7760 
7761  mte = ModuleToExport()
7762  outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
7763  self.assertExpected(torch.onnx.export_to_pretty_string(
7764  mte, (torch.zeros(1, 2, 3),), None, verbose=False,
7765  example_outputs=outputs))
7766 
7767  def test_onnx_export_script_inline_params(self):
7768  class ModuleToInline(torch.jit.ScriptModule):
7769  def __init__(self):
7770  super(ModuleToInline, self).__init__()
7771  self.m = torch.nn.Parameter(torch.ones(3, 3))
7772  self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
7773 
7774  @torch.jit.script_method
7775  def forward(self, x):
7776  return torch.mm(x, self.m)
7777 
7778  class ModuleToExport(torch.jit.ScriptModule):
7779  def __init__(self):
7780  super(ModuleToExport, self).__init__()
7781  self.mod = ModuleToInline()
7782  self.param = torch.nn.Parameter(torch.ones(3, 4))
7783 
7784  @torch.jit.script_method
7785  def forward(self, x):
7786  y = self.mod(x)
7787  return torch.mm(y, self.param)
7788 
7789  mte = ModuleToExport()
7790  result = mte(torch.zeros(2, 3))
7791  reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
7792  self.assertEqual(result, reference)
7793  self.assertExpected(torch.onnx.export_to_pretty_string(
7794  mte, (torch.ones(2, 3),), None, verbose=False,
7795  example_outputs=result, propagate=True))
7796 
7797  def test_trace_with_size(self):
7798  @_trace(torch.zeros(1, 1))
7799  def foo(x):
7800  return x + 1
7801 
7802  @torch.jit.script
7803  def bar(x):
7804  y = int(foo(x))
7805  if True:
7806  y = 7
7807  return y + 1
7808 
7809  self.assertEqual(8, bar(torch.ones(1, 1)))
7810 
7811  def test_tracing_slicing(self):
7812  @_trace(torch.zeros(10))
7813  def foo_trace(x):
7814  return x[-5:-3]
7815 
7816  @torch.jit.script
7817  def foo_script(x):
7818  return x[-5:-3]
7819 
7820  def foo(x):
7821  return x[-5:-3]
7822 
7823  a = torch.arange(0, 8)
7824  b = torch.arange(0, 20)
7825  self.assertEqual(foo_trace(a), foo_script(a))
7826  self.assertEqual(foo_trace(a), foo(a))
7827  self.assertNotEqual(foo_trace(a), foo_trace(b))
7828 
7829  def test_tracing_indexing(self):
7830  @_trace(torch.zeros(10))
7831  def foo_trace(x):
7832  return x[-2]
7833 
7834  @torch.jit.script
7835  def foo_script(x):
7836  return x[-2]
7837 
7838  def foo(x):
7839  return x[-2]
7840 
7841  a = torch.arange(0, 8)
7842  b = torch.arange(0, 20)
7843  self.assertEqual(foo_script(a), foo_trace(a))
7844  self.assertEqual(foo_trace(a), foo(a))
7845  self.assertNotEqual(foo_trace(a), foo_trace(b))
7846 
7847  def test_index_select_shape_prop(self):
7848 
7849  @torch.jit.script
7850  def foo(x, y):
7851  return torch.index_select(x, index=y, dim=1)
7852 
7853  a = torch.zeros(2, 2)
7854  b = torch.zeros(4, dtype=torch.long)
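 # selecting the 4 indices of b along dim 1 of the (2, 2) Double
 # tensor should propagate the output shape Double(2, 4)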
7855  torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
7856  FileCheck().check("Double(2, 4)").run(str(foo.graph))
7857 
7858  def test_onnx_export_speculate(self):
7859 
7860  class Foo(torch.jit.ScriptModule):
7861  def __init__(self, m):
7862  super(Foo, self).__init__()
7863  self.m = m
7864 
7865  @torch.jit.script_method
7866  def forward(self, x):
7867  x += x
7868  # because we are testing whether we emit the `if` statement correctly,
7869  # we cannot use `True` as the condition: constant propagation
7870  # would remove the `if` statements.
7871  c = torch.sum(x) > 4
7872  if bool(c):
7873  if bool(c):
7874  y = self.m(x)
7875  else:
7876  y = self.m(x)
7877  else:
7878  y = self.m(x)
7879  return y
7880 
7881  linear = torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float))
7882 
7883  @torch.jit.script
7884  def transpose(x):
7885  return x.t()
7886 
7887  f1 = Foo(transpose)
7888  outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
7889  f2 = Foo(linear)
7890  outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
7891 
7892  onnx_ish = torch.onnx.export_to_pretty_string(
7893  f1,
7894  (torch.ones(1, 10, dtype=torch.float), ),
7895  None, verbose=False, example_outputs=outputs_f1)
7896  self.assertExpected(onnx_ish, subname='f1')
7897  onnx_ish = torch.onnx.export_to_pretty_string(
7898  f2,
7899  (torch.ones(1, 10, dtype=torch.float), ),
7900  None, verbose=False, example_outputs=outputs_f2)
7901  self.assertExpected(onnx_ish, subname='f2')
7902 
7903  def test_onnx_export_shape_reshape(self):
7904  class Foo(torch.nn.Module):
7905  def forward(self, x):
7906  import torch.onnx.operators
7907  x = x.repeat(5, 1, 1)
7908  shape = torch.onnx.operators.shape_as_tensor(x)
7909  reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
7910  return reshaped
7911 
7912  foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
7913  outputs = foo(torch.zeros(1, 2, 3))
7914  f = io.BytesIO()
7915  s = torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
7916  example_outputs=outputs)
7917  self.assertExpected(s)
7918 
7919  def test_shape_analysis_loop(self):
7920  def foo(a, b, x):
7921  c = a
7922  # on the first iteration of the loop it appears that
7923  # c should be expanded to the size of b,
7924  # but on the second and later iterations there is no broadcast and the
7925  # sizes are different.
7926  # previously this would cause the compiler to (1) enter an infinite
7927  # loop trying to compute the shape, and (2) insert invalid
7928  # broadcasts.
7929  # this test ensures we don't regress on these issues
7930  for _ in range(2):
7931  a = c + b
7932  c = x
7933  b = x
7934  return a
7935 
7936  self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
7937 
7938  def test_intlist_args(self):
7939  def func_1(x):
7940  return torch.nn.functional.adaptive_avg_pool1d(x, 1)
7941 
7942  def func_2(x):
7943  return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
7944 
7945  def func_3(x):
7946  return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
7947 
7948  x = torch.randn(8, 8, 8)
7949  self.checkScript(func_1, [x], optimize=True)
7950  self.checkScript(func_2, [x], optimize=True)
7951  self.checkScript(func_3, [x], optimize=True)
7952 
7953  def test_wrong_implicit_expand(self):
7954 
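 # traced with shapes (3,) and (1,): the implicit expand of b must not
 # be baked into the trace, so running with two (4,) tensors still works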
7955  @_trace(torch.zeros(3), torch.zeros(1))
7956  def foo(a, b):
7957  return a + b
7958 
7959  a = torch.rand(4)
7960  b = torch.rand(4)
7961  self.assertEqual(a + b, foo(a, b))
7962 
7963  def test_builtin_args_fails(self):
7964 
7965  with self.assertRaisesRegex(RuntimeError, 'expected at most'):
7966  @torch.jit.script
7967  def f0(a):
7968  torch.sum(a, a, a, a)
7969 
7970  with self.assertRaisesRegex(RuntimeError, 'argument self not provided'):
7971  @torch.jit.script
7972  def f1(a):
7973  torch.sum(foo=4)
7974 
7975  with self.assertRaisesRegex(RuntimeError, 'specified twice'):
7976  @torch.jit.script
7977  def f2(a):
7978  torch.sum(a, self=a)
7979 
7980  with self.assertRaisesRegex(RuntimeError, 'not provided'):
7981  @torch.jit.script
7982  def f3(a):
7983  torch.sum(dim=4)
7984 
7985  with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but found Tensor'):
7986  @torch.jit.script
7987  def f4(a):
7988  torch.cat(a)
7989 
7990  with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but found int\[\]'):
7991  @torch.jit.script
7992  def f5(a):
7993  torch.cat([3])
7994 
7995  with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'):
7996  @torch.jit.script
7997  def f6(a):
7998  a.expand(size=[3, [4]])
7999 
8000  with self.assertRaisesRegex(RuntimeError, 'xpected a value of type Tensor for argument \'self\''):
8001  @torch.jit.script
8002  def f7(a):
8003  torch.sum([4])
8004 
8005  def test_builtin_args(self):
8006 
8007  def t0(a):
8008  # default arg dim
8009  return torch.cat([a, a])
8010 
8011  self.checkScript(t0, (torch.zeros(1, 1),))
8012 
8013  def t1(a):
8014  # keywords out of order
8015  return torch.cat(dim=1, tensors=[a, a])
8016 
8017  self.checkScript(t1, (torch.zeros(1, 1, 2),))
8018 
8019  def t2(a):
8020  # mix const/non-const attributes
8021  if True:
8022  b = 1
8023  else:
8024  b = 0
8025  return torch.sum(a, dim=b, keepdim=False)
8026 
8027  self.checkScript(t2, (torch.zeros(1, 1, 2),))
8028 
8029  def test_parser_type_annotations(self):
8030  cu = torch.jit.CompilationUnit('''
8031  def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
8032  return x, x
8033  ''')
8034 
8035  self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
8036 
8037  def test_parser_type_annotations_comment(self):
8038  cu = torch.jit.CompilationUnit('''
8039  def foo(x, y):
8040  # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
8041  return x, x
8042  ''')
8043 
8044  self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
8045 
8046  def test_parser_type_annotations_unknown_type(self):
8047  with self.assertRaisesRegex(RuntimeError, r'Unknown type name Foo'):
8048  cu = torch.jit.CompilationUnit('''
8049  def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
8050  return x, x
8051  ''')
8052 
8053  def test_parser_type_annotations_subscript_non_ident(self):
8054  with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
8055  cu = torch.jit.CompilationUnit('''
8056  def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
8057  return x, x
8058  ''')
8059 
8060  def test_parser_type_annotations_subscript_tensor(self):
8061  with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
8062  cu = torch.jit.CompilationUnit('''
8063  def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
8064  return x, x
8065  ''')
8066 
8067  def test_parser_type_annotations_incompatible_expression(self):
8068  with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
8069  cu = torch.jit.CompilationUnit('''
8070  def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
8071  return x, x
8072  ''')
8073 
8074  def test_gather_dynamic_index(self):
8075  def t(x):
8076  gather1 = x[0]
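 # build the second index from an expression so the gather goes
 # through a computed (dynamic) index rather than a literal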
8077  idx = 0 + 1
8078  gather2 = x[idx]
8079  return gather1 + gather2
8080 
8081  self.checkScript(t, (torch.zeros(3, 2, 3),))
8082 
8083  def test_slice_dynamic_index(self):
8084  def t(x):
8085  slice1 = x[0:1]
8086  zero = 0
8087  one = zero + 1
8088  slice2 = x[zero:one]
8089  return slice1 + slice2
8090 
8091  self.checkScript(t, (torch.zeros(3, 2, 3),))
8092 
8093  def test_addmm_grad(self):
8094  """ This test checks several things:
8095  1. An expand node was inserted before the addmm operating on the
8096  bias term.
8097  2. The fused form of addmm appears in the graph that is ultimately
8098  executed.
8099  3. A sum op was emitted for accumulating gradients along the 0th
8100  (expanded) dimension of the bias term.
8101  4. The correct symbolic representation for the backward pass of the
8102  mm operator was emitted (x.t() -> mm)
8103 
8104  TODO: we should actually check these conditions once we have a way
8105  to dump the GraphExecutor state, namely the processed forward graph
8106  and the backward graph.
8107  """
8108  @torch.jit.script
8109  def addmm_grad_test(b, x, w):
8110  return torch.addmm(b, x, w)
8111 
8112  # Initialize param and input values
8113  w_init = torch.rand(2, 5)
8114  b_init = torch.rand(5)
8115  x = torch.rand(3, 2)
8116 
8117  # Clone trainable params
8118  b = b_init.clone()
8119  b.requires_grad_()
8120  w = w_init.clone()
8121  w.requires_grad_()
8122 
8123  # Test symbolic differentiation
8124  y = addmm_grad_test(b, x, w)
8125  y.sum().backward()
8126 
8127  # clone params for autograd reference
8128  b_ref = b_init.clone()
8129  b_ref.requires_grad_()
8130  w_ref = w_init.clone()
8131  w_ref.requires_grad_()
8132  y_ref = torch.addmm(b_ref, x, w_ref)
8133  y_ref.sum().backward()
8134 
8135  self.assertEqual(w.grad, w_ref.grad)
8136  self.assertEqual(b.grad, b_ref.grad)
8137 
8138  def test_zeros(self):
8139  class M(torch.jit.ScriptModule):
8140  __constants__ = ['d']
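 # listing 'd' in __constants__ bakes the device in as a
 # compile-time constant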
8141 
8142  def __init__(self):
8143  self.d = torch.device('cpu')
8144 
8145  @torch.jit.script_method
8146  def create(self):
8147  return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
8148 
8149  r = M().create()
8150  self.assertEqual(r.dtype, torch.float)
8151  self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
8152 
8153  def test_vararg_zeros(self):
8154  def foo():
8155  return torch.zeros(3, 4, 5, dtype=torch.int)
8156 
8157  self.checkScript(foo, ())
8158 
8159  def test_rand(self):
8160  def test_rand():
8161  a = torch.rand([3, 4])
8162  return a + 1.0 - a
8163 
8164  self.checkScript(test_rand, ())
8165 
8166  def test_erase_number_types(self):
8167  def func(a):
8168  b = 7 + 1 + 3
8169  c = a + b
8170  c += b
8171  return c
8172 
8173  graph = torch.jit.script(func).graph
8174  FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph))
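 # remove_inplace_ops rewrites aten::add_ out-of-place,
 # erase_number_types drops the scalar int constant, and dce then
 # removes the dead nodes, so neither pattern should remain below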
8175  self.run_pass('remove_inplace_ops', graph)
8176  self.run_pass('erase_number_types', graph)
8177  self.run_pass('dce', graph)
8178  FileCheck().check_not("int = prim::Constant").check_not("aten::add_").run(str(graph))
8179 
8180  def test_mm_batching(self):
8181  lstm_cell = torch.jit.script(LSTMCellS)
8182 
8183  def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
8184  for i in range(x.size(0)):
8185  hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
8186  return hx
8187 
8188  slstm = torch.jit.script(lstm)
8189 
8190  inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
8191  slstm(*inputs).sum().backward()
8192 
8193  fw_graph = slstm.graph_for(*inputs)
8194  bw_graph = backward_graph(slstm, diff_graph_idx=0)
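 # the batch-mm optimization should batch mms that share a common
 # operand (MMBatchSide) and tree-reduce sums of mms (MMTreeReduce)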
8195  self.assertTrue('prim::MMBatchSide' in str(fw_graph))
8196  self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
8197 
8198  sout = slstm(*inputs)
8199  out = lstm(*inputs)
8200  self.assertEqual(slstm(*inputs), lstm(*inputs))
8201  self.assertEqual(torch.autograd.grad(slstm(*inputs).sum(), inputs),
8202  torch.autograd.grad(lstm(*inputs).sum(), inputs))
8203 
8204  def test_loop_unrolling(self):
8205  def fn(x):
8206  y = 0
8207  for i in range(int(x)):
8208  y -= i
8209  return y
8210 
8211  graph = torch.jit.script(fn).graph
8212  self.run_pass('loop_unrolling', graph)
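 # expect the body duplicated unroll_factor times inside the main
 # loop, followed by an epilogue loop for the remaining iterations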
8213  unroll_factor = 8
8214  FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \
8215  .check("prim::Loop").check("aten::sub").run(str(graph))
8216  self.checkScript(fn, (torch.tensor(10),))
8217 
8218  def test_loop_unrolling_const(self):
8219  def fn():
8220  y = 0
8221  for _ in range(10):
8222  y -= 1
8223  return y
8224 
8225  def fn2():
8226  y = 0
8227  for i in range(10):
8228  y -= i
8229  return y
8230 
8231  def check(fn, name):
8232  graph = torch.jit.script(fn).graph
8233  self.run_pass('loop_unrolling', graph)
8234  # entirely unrolled
8235  FileCheck().check_not("prim::Loop").run(str(graph))
8236  self.checkScript(fn, ())
8237 
8238  check(fn, 'add_const')
8239  check(fn2, 'add_iter')
8240 
8241  def test_loop_unrolling_nested(self):
8242  def fn(x):
8243  y = 0
8244  for _ in range(10):
8245  for j in range(int(x)):
8246  y -= j
8247  return y
8248 
8249  graph = torch.jit.script(fn).graph
8250  self.run_pass('loop_unrolling', graph)
8251  # inner loop with 8 subs followed by loop epilogue
8252  unroll_factor = 8
8253  FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \
8254  .check("prim::Loop").check("aten::sub").run(str(graph))
8255  self.checkScript(fn, (torch.tensor(10),))
8256 
8257  def test_loop_unroll_unused_counter(self):
8258  def fn(x):
8259  y = 0
8260  for _ in range(int(x)):
8261  y -= 1
8262  return y
8263 
8264  graph = torch.jit.script(fn).graph
8265  self.run_pass('loop_unrolling', graph)
8266  FileCheck().check("prim::Loop").check_not("aten::add").check("return") \
8267  .run(str(graph))
8268 
8269  def test_loop_unroll_negative(self):
8270  def fn(x):
8271  y = 0
8272  for _ in range(int(x)):
8273  y += 1
8274  return y
8275 
8276  self.checkScript(fn, (torch.tensor(-20),))
8277  self.checkScript(fn, (torch.tensor(-2),))
8278  self.checkScript(fn, (torch.tensor(-1),))
8279  self.checkScript(fn, (torch.tensor(0),))
8280  self.checkScript(fn, (torch.tensor(1),))
8281  self.checkScript(fn, (torch.tensor(2),))
8282 
8283  def test_where(self):
8284  def fn(x, y):
8285  return torch.where(x > 0.0, x, y)
8286 
8287  self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
8288 
8289  def test_where_method(self):
8290  def fn(x, y):
8291  return x.where(x > 0.0, y)
8292 
8293  self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
8294 
8295  def test_reassign_module_lhs(self):
8296  with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\' because it has type value and self is'
8297  ' not a first-class value. Only reassignments to first-class values are allowed'):
8298  class ReassignSelfLHS(torch.jit.ScriptModule):
8299  @torch.jit.script_method
8300  def forward(self, x):
8301  for _ in range(20):
8302  self = x
8303  return self
8304 
8305  ReassignSelfLHS()
8306 
8307  def test_reassign_module_rhs(self):
8308  with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module because x is not a'
8309  ' first-class value. Only reassignments to first-class values are allowed'):
8310  class ReassignSelfRHS(torch.jit.ScriptModule):
8311  @torch.jit.script_method
8312  def forward(self, x):
8313  for _ in range(20):
8314  x = self
8315  return self
8316 
8317  ReassignSelfRHS()
8318 
8319  def test_unknown_builtin(self):
8320  with self.assertRaisesRegex(RuntimeError, 'unknown builtin op'):
8321  @torch.jit.script
8322  def unknown_builtin(x):
8323  return x.splork(3)
8324 
8325  def test_return_tuple(self):
8326  def return_tuple(x):
8327  a = (x, x)
8328  return a, x
8329  self.checkScript(return_tuple, (torch.rand(4),))
8330 
8331  def test_method_no_self(self):
8332  with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
8333  class MethodNoSelf(torch.jit.ScriptModule):
8334  @torch.jit.script_method
8335  def forward():
8336  return torch.zeros(3, 4)
8337 
8338  MethodNoSelf()
8339 
8340  def test_return_stmt_not_at_end(self):
8341  def return_stmt(x):
8342  if bool(x > 3):
8343  return x + 3
8344  else:
8345  return x
8346  self.checkScript(return_stmt, (torch.rand(1),))
8347 
8348  def test_for_range_no_arg(self):
8349  with self.assertRaisesRegex(RuntimeError, r'range\(\) expects 1 argument but got 0'):
8350  @torch.jit.script
8351  def range_no_arg(x):
8352  for _ in range():
8353  x += 1
8354  return x
8355 
8356  def test_list_iterables(self):
8357  with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
8358  cu = torch.jit.CompilationUnit('''
8359  def list_iterables(x):
8360  for i, j in [2, 3, 4], [5, 6, 7]:
8361  x += i
8362  x += j
8363  return x
8364  ''')
8365 
8366  def test_for_tuple_unpack(self):
8367  with self.assertRaisesRegex(RuntimeError, 'Iteration variable unpacking is not supported'):
8368  cu = torch.jit.CompilationUnit('''
8369  def for_tuple_unpack(x, y):
8370  for i, j in [[3, 4], [5, 6], [7, 8]]:
8371  x += i
8372  y += j
8373  return x, y
8374  ''')
8375 
8376  def test_single_starred_lhs(self):
8377  with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
8378  ' of another non-starred expression'):
8379  cu = torch.jit.CompilationUnit('''
8380  def single_starred_lhs(x):
8381  a = (x, x, x)
8382  *b, = a
8383  return b
8384  ''')
8385 
8386  def test_singleton_tuple_unpack(self):
8387  def foo(a):
8388  b, = (a,)
8389  return b + 1
8390  self.checkScript(foo, (torch.rand(3),))
8391 
8392  def test_multi_reduction(self):
8393  with self.assertRaisesRegex(
8394  RuntimeError,
8395  'augmented assignment can only have one LHS expression'):
8396  cu = torch.jit.CompilationUnit('''
8397  def multi_reduction(x):
8398  a, b += x
8399  return a, b
8400  ''')
8401 
8402  def test_invalid_call_arguments(self):
8403  with self.assertRaisesRegex(RuntimeError, 'arguments for call are not valid'):
8404  @torch.jit.script
8405  def invalid_call_arguments(x):
8406  return torch.unsqueeze(3, 4, 5, 6, 7, 8)
8407 
8408  def test_invalid_lhs_assignment(self):
8409  with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
8410  cu = torch.jit.CompilationUnit('''
8411  def invalid_lhs_assignment(x):
8412  x + 1 = x
8413  return x
8414  ''')
8415 
8416  def test_multi_starred_expr_lhs(self):
8417  with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
8418  cu = torch.jit.CompilationUnit('''
8419  def multi_starred_expr_lhs():
8420  a, *b, *c = [1, 2, 3, 4, 5, 6]
8421  return a
8422  ''')
8423 
8424  def test_pack_tuple_into_non_var(self):
8425  with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
8426  cu = torch.jit.CompilationUnit('''
8427  def pack_tuple_into_non_var(x):
8428  a, *1 = (3, 4, 5)
8429  return x
8430  ''')
8431 
8432  def test_print_kwargs(self):
8433  with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
8434  cu = torch.jit.CompilationUnit('''
8435  def print_kwargs(x):
8436  print(x, flush=True)
8437  return x
8438  ''')
8439 
8440  def test_builtin_use_as_value(self):
8441  with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
8442  @torch.jit.script
8443  def builtin_use_as_value(x):
8444  return x.unsqueeze
8445 
8446  def test_wrong_use_as_tuple(self):
8447  with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
8448  def test_fn():
8449  return 3
8450 
8451  @torch.jit.script
8452  def wrong_use_as_tuple(self):
8453  a, b = test_fn
8454  return a
8455 
8456  def test_wrong_attr_lookup(self):
8457  with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
8458  @torch.jit.script
8459  def wrong_attr_lookup(self, x):
8460  a = x.unsqueeze.myattr
8461  return a
8462 
8463  def test_wrong_use_as_callable(self):
8464  with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
8465  @torch.jit.script
8466  def wrong_use_as_callable(x):
8467  return x(3, 4, 5)
8468 
8469  def test_python_val_doesnt_have_attr(self):
8470  with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
8471 
8472  @torch.jit.script
8473  def python_val_doesnt_have_attr():
8474  # this has to be a module, otherwise attribute lookup would not be
8475  # allowed in the first place
8476  return shutil.abcd
8477 
8478  def test_wrong_module_attr_lookup(self):
8479  with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value:'):
8480  import io
8481 
8482  @torch.jit.script
8483  def wrong_module_attr_lookup():
8484  return io.BytesIO
8485 
8486  def test_wrong_method_call_inputs(self):
8487  with self.assertRaisesRegex(RuntimeError, 'argument y not provided'):
8488  class SomeModule(torch.jit.ScriptModule):
8489 
8490  @torch.jit.script_method
8491  def foo(self, x, y):
8492  return x
8493 
8494  @torch.jit.script_method
8495  def forward(self, x, y):
8496  return self.foo(x)
8497  SomeModule()
8498 
8499  def test_single_starred_expr_for_loop(self):
8500  with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
8501  cu = torch.jit.CompilationUnit('''
8502  def test():
8503  x = 0
8504  for *a in [1, 2, 3]:
8505  x = x + 1
8506  return x
8507  ''')
8508 
8509  def test_duplicate(self):
8510  with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
8511  cu = torch.jit.CompilationUnit('''
8512  def test():
8513  return 1
8514 
8515  def test():
8516  return 2
8517  ''')
8518 
8519  def test_call_ge(self):
8520  with self.assertRaisesRegex(RuntimeError, 'expected at most 1 arguments but found 3'):
8521  @_trace(torch.zeros(1, 2, 3))
8522  def foo(x):
8523  return x
8524 
8525  @torch.jit.script
8526  def test_fn():
8527  return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
8528 
8529  def test_wrong_return_type(self):
8530  with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
8531  def somefunc():
8532  # type: () -> Tuple[Tuple[Tensor, Tensor]]
8533  return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484
8534 
8535  @torch.jit.script
8536  def wrong_return_type():
8537  return somefunc()
8538  wrong_return_type()
8539 
8540  # Tests for calling between different front-end modes
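 # Each test calls a (python|traced|script) function or module from a
 # (tracing|traced-module|script) caller and checks how the callee
 # shows up in the resulting graph (inlined vs. an opaque Python call).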
8541  def test_call_python_fn_from_tracing_fn(self):
8542  def python_fn(x):
8543  return torch.neg(x)
8544 
8545  @_trace(torch.rand(3, 4))
8546  def traced_fn(x):
8547  return python_fn(x) + 1
8548 
8549  # The neg op in the python function should be properly inlined to the
8550  # graph
8551  FileCheck().check("aten::neg").run(str(traced_fn.graph))
8552 
8553  def test_call_python_mod_from_tracing_fn(self):
8554  class PythonMod(torch.nn.Module):
8555  def __init__(self):
8556  super(PythonMod, self).__init__()
8557  self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
8558 
8559  def forward(self, x):
8560  return torch.mm(x, self.param)
8561 
8562  pm = PythonMod()
8563 
8564  @_trace(torch.rand(3, 4))
8565  def traced_fn(x):
8566  return pm(x) + 1.0
8567 
8568  # Note: the parameter self.param from the Python module is inlined
8569  # into the graph
8570  self.assertTrue(len(list(traced_fn.graph.inputs())) == 1)
8571  FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph))
8572 
8573  def test_call_traced_fn_from_tracing_fn(self):
8574  @_trace(torch.rand(3, 4))
8575  def traced_fn1(x):
8576  return torch.neg(x)
8577 
8578  @_trace(torch.rand(3, 4))
8579  def traced_fn(x):
8580  return traced_fn1(x) + 1
8581 
8582  FileCheck().check("aten::neg").check_same("scope: traced_fn1").check("aten::add") \
8583  .run(str(traced_fn.graph))
8584 
8585  def test_call_traced_mod_from_tracing_fn(self):
8586  class TracedModule(torch.nn.Module):
8587  def __init__(self):
8588  super(TracedModule, self).__init__()
8589  self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
8590 
8591  def forward(self, x):
8592  return torch.mm(x, self.param)
8593 
8594  tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
8595 
8596  @_trace(torch.rand(3, 4))
8597  def traced_fn(x):
8598  return tm(x) + 1.0
8599 
8600  # Note: the parameter self.param from the traced module is baked into
8601  # the graph as a constant
8602  FileCheck().check("prim::Constant[value=<Tensor>]").check("aten::mm") \
8603  .check("aten::add").run(str(traced_fn.graph))
8604 
8605  def test_call_script_fn_from_tracing_fn(self):
8606  @torch.jit.script
8607  def script_fn(x):
8608  return torch.neg(x)
8609 
8610  @_trace(torch.rand(3, 4))
8611  def traced_fn(x):
8612  return script_fn(x) + 1
8613 
8614  FileCheck().check("aten::neg").check("aten::add").run(str(traced_fn.graph))
8615 
8616  def test_call_script_mod_from_tracing_fn(self):
8617  with self.disableModuleHook():
8618  class ScriptMod(torch.jit.ScriptModule):
8619  def __init__(self):
8620  super(ScriptMod, self).__init__()
8621  self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False)
8622 
8623  @torch.jit.script_method
8624  def forward(self, x):
8625  for _i in range(4):
8626  x += self.param
8627  return x
8628 
8629  sm = ScriptMod()
8630 
8631  @_trace(torch.rand(3, 4))
8632  def traced_fn(x):
8633  return sm(x) + 1.0
8634 
8635  # the parameter turns into a constant and the loop is preserved
8636  FileCheck().check("prim::Constant[value=<Tensor>]").check("Loop") \
8637  .run(str(traced_fn.graph))
8638 
8639  def test_call_python_fn_from_traced_module(self):
8640  def python_fn(x):
8641  return torch.neg(x)
8642 
8643  class TracedModule(torch.nn.Module):
8644  def __init__(self):
8645  super(TracedModule, self).__init__()
8646  self.param = torch.nn.Parameter(torch.rand(4, 3))
8647 
8648  def forward(self, x):
8649  return torch.mm(python_fn(x), self.param)
8650 
8651  tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
8652 
8653  # Note: parameter self.param from the traced module should appear as
8654  # an input to the graph and the neg op from the Python function should
8655  # be properly inlined
8656  self.assertTrue(len(list(tm.graph.inputs())) == 2)
8657  FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph))
8658 
8659  def test_call_python_mod_from_traced_module(self):
8660  class PythonModule(torch.nn.Module):
8661  def __init__(self):
8662  super(PythonModule, self).__init__()
8663  self.param = torch.nn.Parameter(torch.rand(5, 7))
8664 
8665  def forward(self, x):
8666  return torch.mm(x, self.param)
8667 
8668  class TracedModule(torch.nn.Module):
8669  def __init__(self):
8670  super(TracedModule, self).__init__()
8671  self.param = torch.nn.Parameter(torch.rand(4, 5))
8672  self.mod = PythonModule()
8673 
8674  def forward(self, x):
8675  return self.mod(torch.mm(x, self.param)) + 1.0
8676 
8677  tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
8678 
8679  # Note: the parameters from both modules should appear in the flattened
8680  # inputs of the graph. All ops from both modules should be inlined.
8681  self.assertTrue(len(list(tm.graph.inputs())) == 3)
8682  FileCheck().check_not("value=<Tensor>").check_count("aten::mm", 2).check("aten::add") \
8683  .run(str(tm.graph))
8684 
8685  def test_call_traced_fn_from_traced_module(self):
8686  @_trace(torch.rand(3, 4))
8687  def traced_fn(x):
8688  return torch.neg(x)
8689 
8690  class TracedModule(torch.nn.Module):
8691  def __init__(self):
8692  super(TracedModule, self).__init__()
8693  self.param = torch.nn.Parameter(torch.rand(4, 5))
8694 
8695  def forward(self, x):
8696  return traced_fn(torch.mm(x, self.param))
8697 
8698  tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
8699  # Note: neg op from the traced function should be properly inlined
8700  FileCheck().check("aten::mm").check_same("scope: TracedModule") \
8701  .check_next("aten::neg").check("scope: TracedModule/traced_fn") \
8702  .run(str(tm.graph))
8703 
8704  def test_trace_hierarchy(self):
8705  # Test that we preserve the module hierarchy for a ScriptModule
8706  # submodule during tracing
8707 
8708  class AnotherScriptMod(torch.jit.ScriptModule):
8709  def __init__(self):
8710  super(AnotherScriptMod, self).__init__()
8711  self.param = torch.nn.Parameter(torch.rand(1, 2, 3))
8712 
8713  @torch.jit.script_method
8714  def bar(self):
8715  return torch.zeros(4, 5)
8716 
8717  class SomeScriptMod(torch.jit.ScriptModule):
8718  def __init__(self):
8719  super(SomeScriptMod, self).__init__()
8720  self.asm = AnotherScriptMod()
8721 
8722  @torch.jit.script_method
8723  def foo(self):
8724  return torch.zeros(3, 4)
8725 
8726  @torch.jit.script_method
8727  def bar(self):
8728  return torch.zeros(4, 3)
8729 
8730  class TraceMe(torch.nn.Module):
8731  def __init__(self):
8732  super(TraceMe, self).__init__()
8733  self.ssm = SomeScriptMod()
8734 
8735  def forward(self, x):
8736  return self.ssm.bar() + x
8737 
8738  orig = TraceMe()
8739  traced = torch.jit.trace(orig, (torch.rand(4, 3),))
8740  # for each of these checks, check that *BOTH* the underlying
8741  # _C.ScriptModule object has the expected method/param, as well as the
8742  # Python object that wraps it.
8743  self.assertTrue(traced.ssm._has_method('foo'))
8744  self.assertTrue(hasattr(traced.ssm, 'foo'))
8745 
8746  imported = self.getExportImportCopy(traced)
8747 
8748  self.assertTrue(imported.ssm._has_method('foo'))
8749  self.assertTrue(hasattr(imported.ssm, 'foo'))
8750 
8751  self.assertTrue(imported.ssm.asm._has_method('bar'))
8752  self.assertTrue(hasattr(imported.ssm.asm, 'bar'))
8753 
8754  self.assertTrue(imported.ssm.asm._has_parameter('param'))
8755  self.assertTrue(hasattr(imported.ssm.asm, 'param'))
8756 
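 # Nest trace -> script -> trace and check that the registered
 # parameter still serializes through torch.jit.save.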
8757  def test_trace_parameter(self):
8758  class Param(nn.Module):
8759  def __init__(self):
8760  super(Param, self).__init__()
8761  self.register_parameter("bias", nn.Parameter(torch.Tensor(4, 4)))
8762 
8763  def forward(self, x):
8764  return x
8765 
8766  class M3(torch.jit.ScriptModule):
8767  def __init__(self, model):
8768  super(M3, self).__init__(False)
8769  self.traced = torch.jit.trace(model, (torch.rand(3, 3)))
8770 
8771  @torch.jit.script_method
8772  def forward(self, x):
8773  return self.traced(x)
8774 
8775  class M2(nn.Module):
8776  def __init__(self, model):
8777  super(M2, self).__init__()
8778  self.module = M3(model)
8779 
8780  def forward(self, x):
8781  return self.module(x)
8782 
8783  class M1(torch.jit.ScriptModule):
8784  def __init__(self, model):
8785  super(M1, self).__init__(False)
8786  self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))
8787 
8788  @torch.jit.script_method
8789  def forward(self, x):
8790  return self.traced(x)
8791 
8792  module = M1(Param())
8793  f = io.BytesIO()
8794  torch.jit.save(module, f)
8795 
8796  def test_call_traced_module_from_traced_module(self):
8797  class TracedModule1(torch.nn.Module):
8798  def __init__(self):
8799  super(TracedModule1, self).__init__()
8800  self.param = torch.nn.Parameter(torch.rand(5, 7))
8801 
8802  def forward(self, x):
8803  return torch.mm(x, self.param)
8804 
8805  class TracedModule(torch.nn.Module):
8806  def __init__(self):
8807  super(TracedModule, self).__init__()
8808  self.param = torch.nn.Parameter(torch.rand(4, 5))
8809  self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
8810 
8811  def forward(self, x):
8812  return self.mod(torch.mm(x, self.param)) + 1.0
8813 
8814  tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
8815 
8816  # Note: the parameters from both modules should appear in the flattened
8817  # inputs of the graph. All ops from both modules should be inlined.
8818  self.assertTrue(len(list(tm.graph.inputs())) == 3)
8819  FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph))
8820 
8821  def test_call_script_fn_from_traced_module(self):
8822  @torch.jit.script
8823  def traced_fn(x):
8824  return torch.neg(x)
8825 
8826  class TracedModule(torch.nn.Module):
8827  def __init__(self):
8828  super(TracedModule, self).__init__()
8829  self.param = torch.nn.Parameter(torch.rand(4, 5))
8830 
8831  def forward(self, x):
8832  return traced_fn(torch.mm(x, self.param))
8833 
8834  tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
8835  # Note: neg op from the script function should be properly inlined
8836  FileCheck().check("aten::mm").check("aten::neg").run(str(tm.graph))
8837 
8838  def test_call_script_module_from_traced_module(self):
8839  class ScriptMod(torch.jit.ScriptModule):
8840  def __init__(self):
8841  super(ScriptMod, self).__init__()
8842  self.param_foo = torch.nn.Parameter(torch.rand(5, 7))
8843 
8844  @torch.jit.script_method
8845  def forward(self, x):
8846  return torch.mm(x, self.param_foo)
8847 
8848  class TracedModule(torch.nn.Module):
8849  def __init__(self):
8850  super(TracedModule, self).__init__()
8851  self.param = torch.nn.Parameter(torch.rand(4, 5))
8852  self.mod = ScriptMod()
8853 
8854  def forward(self, x):
8855  return self.mod(torch.mm(x, self.param)) + 1.0
8856 
8857  tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
8858 
8859  # Note: the parameters from both modules should appear in the flattened
8860  # inputs of the graph. All ops from both modules should be inlined.
8861  self.assertTrue(len(list(tm.graph.inputs())) == 3)
8862  FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph))
8863 
8864  def test_call_python_fn_from_script_fn(self):
8865  def python_fn(x):
8866  return torch.neg(x)
8867 
8868  @torch.jit.script
8869  def script_fn(x):
8870  return python_fn(x) + 1
8871 
8872  # Note: the call to python_fn appears as `^python_fn()` and is called
8873  # as a PythonOp in the interpreter
8874  a = torch.tensor(1)
8875  self.assertEqual(script_fn(a), torch.tensor(0))
8876  FileCheck().check("python_fn").run(str(script_fn.graph))
8877 
8878  def test_call_python_mod_from_script_fn(self):
8879  class PythonModule(torch.nn.Module):
8880  def __init__(self):
8881  super(PythonModule, self).__init__()
8882  self.param = torch.nn.Parameter(torch.rand(5, 7))
8883 
8884  def forward(self, x):
8885  return torch.mm(x, self.param)
8886 
8887  pm = PythonModule()
8888 
8889  @torch.jit.script
8890  def script_fn(x):
8891  return pm(x) + 1
8892 
8893  # Note: the call to pm(x) appears as ^<python_value>() in the graph.
8894  # Parameters are NOT inlined.
8895  FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph))
8896 
8897  def test_call_traced_fn_from_script_fn(self):
8898  @_trace(torch.rand(3, 4))
8899  def traced_fn(x):
8900  return torch.neg(x)
8901 
8902  @torch.jit.script
8903  def script_fn(x):
8904  return traced_fn(x) + 1
8905 
8906  # Note: the neg op from traced_fn should be properly inlined into the
8907  # script function's graph
8908  FileCheck().check("aten::neg").check("aten::add").run(str(script_fn.graph))
8909 
8910  def test_call_traced_mod_from_script_fn(self):
8911  class TracedModule(torch.nn.Module):
8912  def __init__(self):
8913  super(TracedModule, self).__init__()
8914 
8915  def forward(self, x):
8916  return torch.mm(x, torch.zeros(4, 3))
8917 
8918  tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
8919 
8920  @torch.jit.script
8921  def script_fn(x):
8922  return tm(x) + 1
8923 
8924  FileCheck().check("aten::zeros").check_same("scope: TracedModule").check("aten::mm") \
8925  .check("aten::add").run(str(script_fn.graph))
8926 
8927  def test_call_script_fn_from_script_fn(self):
8928  @torch.jit.script
8929  def script_fn1(x):
8930  return torch.neg(x)
8931 
8932  @torch.jit.script
8933  def script_fn(x):
8934  return script_fn1(x) + 1
8935 
8936  # Note: the neg op from script_fn1 should be properly inlined into the
8937  # graph of script_fn
8938  FileCheck().check("aten::neg").run(str(script_fn.graph))
8939 
8940  def test_call_script_mod_from_script_fn(self):
8941  class ScriptMod(torch.jit.ScriptModule):
8942  def __init__(self):
8943  super(ScriptMod, self).__init__()
8944 
8945  @torch.jit.script_method
8946  def forward(self, x):
8947  return torch.mm(x, torch.zeros([4, 3]))
8948 
8949  sm = ScriptMod()
8950 
8951  @torch.jit.script
8952  def script_fn(x):
8953  return sm(x) + 1
8954 
8955  FileCheck().check("zeros").check("aten::mm").check("add").run(str(script_fn.graph))
8956 
8957  def test_call_python_fn_from_script_module(self):
8958  def python_fn(x):
8959  return torch.neg(x)
8960 
8961  class ScriptMod(torch.jit.ScriptModule):
8962  def __init__(self):
8963  super(ScriptMod, self).__init__()
8964  self.param = torch.nn.Parameter(torch.rand(4, 3))
8965 
8966  @torch.jit.script_method
8967  def forward(self, x):
8968  return python_fn(torch.mm(x, self.param))
8969 
8970  sm = ScriptMod()
8971  FileCheck().check("aten::mm").check("python_fn") \
8972  .run(str(sm.__getattr__('forward').graph))
8973 
8974  def test_call_python_mod_from_script_module(self):
8975  class PythonMod(torch.nn.Module):
8976  def __init__(self):
8977  super(PythonMod, self).__init__()
8978  self.param = torch.nn.Parameter(torch.rand(3, 5))
8979 
8980  def forward(self, x):
8981  return torch.mm(x, self.param)
8982 
8983  class ScriptMod(torch.jit.ScriptModule):
8984  def __init__(self):
8985  super(ScriptMod, self).__init__()
8986  self.param = torch.nn.Parameter(torch.rand(4, 3))
8987  self.pm = PythonMod()
8988 
8989  @torch.jit.script_method
8990  def forward(self, x):
8991  return self.pm(torch.mm(x, self.param))
8992 
8993  sm = ScriptMod()
8994  # Note: the call into PythonMod appears as ^<python_value>(). Parameters
8995  # are NOT inlined
8996  FileCheck().check("aten::mm").check("python_value").run(str(sm.graph))
8997 
8998  def test_call_tracing_fn_from_script_module(self):
8999  @_trace(torch.rand(3, 3))
9000  def traced_fn(x):
9001  return torch.neg(x)
9002 
9003  class ScriptMod(torch.jit.ScriptModule):
9004  def __init__(self):
9005  super(ScriptMod, self).__init__()
9006  self.param = torch.nn.Parameter(torch.rand(4, 3))
9007 
9008  @torch.jit.script_method
9009  def forward(self, x):
9010  return traced_fn(torch.mm(x, self.param))
9011 
9012  sm = ScriptMod()
9013  FileCheck().check("aten::mm").check("aten::neg").run(str(sm.__getattr__('forward').graph))
9014 
9015  def test_call_tracing_mod_from_script_module(self):
9016  class TracedMod(torch.nn.Module):
9017  def __init__(self):
9018  super(TracedMod, self).__init__()
9019  self.param = torch.nn.Parameter(torch.rand(3, 5))
9020 
9021  def forward(self, x):
9022  return torch.mm(x, self.param)
9023 
9024  class ScriptMod(torch.jit.ScriptModule):
9025  def __init__(self):
9026  super(ScriptMod, self).__init__()
9027  self.param = torch.nn.Parameter(torch.rand(4, 3))
9028  self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
9029 
9030  @torch.jit.script_method
9031  def forward(self, x):
9032  return self.tm(torch.mm(x, self.param))
9033 
9034  sm = ScriptMod()
9035  # Note: the parameters from both modules should appear in the flattened
9036  # input list to the graph. The mm op from TracedMod should be properly
9037  # inlined
9038  self.assertTrue(len(list(sm.graph.inputs())) == 3)
9039  FileCheck().check("aten::mm").check("aten::mm").run(str(sm.graph))
9040 
9041  def test_call_script_fn_from_script_module(self):
9042  @torch.jit.script
9043  def script_fn(x):
9044  return torch.neg(x)
9045 
9046  class ScriptMod(torch.jit.ScriptModule):
9047  def __init__(self):
9048  super(ScriptMod, self).__init__()
9049  self.param = torch.nn.Parameter(torch.rand(4, 3))
9050 
9051  @torch.jit.script_method
9052  def forward(self, x):
9053  return script_fn(torch.mm(x, self.param))
9054 
9055  sm = ScriptMod()
9056  graph = (sm.__getattr__('forward').graph)
9057  FileCheck().check("aten::mm").check("aten::neg").run(str(graph))
9058 
9059  def test_call_script_mod_from_script_module(self):
9060  class ScriptMod1(torch.jit.ScriptModule):
9061  def __init__(self):
9062  super(ScriptMod1, self).__init__()
9063  self.param = torch.nn.Parameter(torch.rand(3, 5))
9064 
9065  @torch.jit.script_method
9066  def forward(self, x):
9067  return torch.mm(x, self.param)
9068 
9069  class ScriptMod(torch.jit.ScriptModule):
9070  def __init__(self):
9071  super(ScriptMod, self).__init__()
9072  self.param = torch.nn.Parameter(torch.rand(4, 3))
9073  self.tm = ScriptMod1()
9074 
9075  @torch.jit.script_method
9076  def forward(self, x):
9077  return self.tm(torch.mm(x, self.param))
9078 
9079  sm = ScriptMod()
9080  # Note: the parameters from both modules should appear in the flattened
9081  # input list to the graph. The mm op from ScriptMod1 should be properly
9082  # inlined
9083  # 3 %-prefixed values in the graph input list, two mm ops in the body
9084  FileCheck().check_count('%', 3).check(":").check_count("mm", 2).run(str(sm.graph))
9085 
9086  def test_module_with_params_called_fails(self):
9087  with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful "
9088  "modules to be inlined must be submodules of the callee."):
9089  class ScriptMod(torch.jit.ScriptModule):
9090  def __init__(self):
9091  super(ScriptMod, self).__init__()
9092  self.param = torch.nn.Parameter(torch.rand(3, 3))
9093 
9094  @torch.jit.script_method
9095  def forward(self, x):
9096  return torch.mm(x, self.param)
9097 
9098  sm = ScriptMod()
9099 
9100  @torch.jit.script
9101  def some_func(x):
9102  return sm(x)
9103 
9104  def test_index_put_trace_with_view(self):
9105  @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
9106  def test_index_put(target, indices, rhs):
9107  target[indices] = rhs
9108  return target
9109 
9110  FileCheck().check("aten::view").check("index_put_").run(str(test_index_put.graph))
9111 
9112  def test_index_put_trace_without_view(self):
9113  @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
9114  def test_index_put(target, indices, rhs):
9115  target[indices] = rhs
9116  return target
9117 
9118  FileCheck().check_not("aten::view").check("index_put_").run(str(test_index_put.graph))
9119 
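 # Editor's sketch (illustrative, not an original test): the eager behavior
 # the two traces above record. The (1, 1, 1, 4) right-hand side must be
 # reshaped to fill the four indexed slots, which is why only the first
 # trace contains an aten::view.
 def _sketch_index_put_broadcast(self):
     target = torch.rand(100)
     indices = torch.tensor([1, 2, 3, 4])
     rhs = torch.rand(1, 1, 1, 4)  # extra leading dims force a view when traced
     target[indices] = rhs
     return target
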
9120  def test_tuple_indexing(self):
9121  def tuple_index(a):
9122  if bool(a):
9123  b = (1, 2)
9124  else:
9125  b = (0, 2)
9126  return b[-2], b[1]
9127 
9128  self.checkScript(tuple_index, (torch.tensor([0]),))
9129  self.checkScript(tuple_index, (torch.tensor([1]),))
9130  self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
9131  tuple_comp = torch.jit.script(tuple_index)
9132  FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph))
9133 
9134  with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
9135  @torch.jit.script
9136  def test_non_constant_input(a):
9137  if bool(a):
9138  b = 1
9139  else:
9140  b = 0
9141  c = (0, 1)
9142  return c[b]
9143 
9144  def test_indexing_float():
9145  c = (1, 2)
9146  return c[0.1]
9147  self.checkScriptRaisesRegex(test_indexing_float, (), Exception,
9148  "tuple indices must")
9149 
9150  def test_indexing_out_of_bounds_pos():
9151  c = (1, 2)
9152  return c[2]
9153 
9154  self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
9155  "out of range")
9156 
9157  def test_indexing_out_of_bounds_neg():
9158  c = (1, 2)
9159  return c[-3]
9160 
9161  self.checkScriptRaisesRegex(test_indexing_out_of_bounds_neg, (), Exception,
9162  "out of range")
9163 
9164  def test_namedtuple_attr(self):
9165  def f(x):
9166  return x.max(dim=1).indices + torch.max(x, dim=1).indices
9167 
9168  self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True)
9169 
9170  with self.assertRaisesRegex(RuntimeError, "Unknown attribute to named tuple"):
9171  @torch.jit.script
9172  def g1(x):
9173  return x.max(dim=1).unknown_symbol
9174 
9175  with self.assertRaisesRegex(RuntimeError, "Getting attributes of tuples is not supported"):
9176  @torch.jit.script
9177  def g2(x):
9178  print((x, x, x).__doc__)
9179  return x
9180 
9181  def test_tuple_slicing(self):
9182  def tuple_slice(a):
9183  if bool(a):
9184  b = (1, 2, 3, 4)
9185  else:
9186  b = (4, 3, 2, 1)
9187  c = b[-4:4]
9188  e = c[1:-1]
9189  return e
9190 
9191  self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
9192  tuple_graph = torch.jit.script(tuple_slice).graph
9193  slices = tuple_graph.findAllNodes("prim::TupleSlice")
9194  num_outputs = set(map(lambda x: len(x.output().type().elements()), slices))
9195  # one tuple slice should have an output with 2 elements, the other with 4
9196  self.assertTrue(num_outputs == {2, 4})
9197  self.run_pass('lower_all_tuples', tuple_graph)
9198  self.assertTrue('Tuple' not in str(tuple_graph))
9199  tuple_comp = torch.jit.script(tuple_slice)
9200  self.assertEqual(tuple_comp(torch.tensor(1)), (2, 3))
9201 
9202  @torch.jit.script
9203  def test_indexing_end_out_of_bounds():
9204  c = (1, 2)
9205  return c[2:10]
9206 
9207  self.assertEqual(test_indexing_end_out_of_bounds(), ())
9208 
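 # Editor's note (hedged sketch, not an original test): tuple elements may
 # have different types, so prim::TupleIndex and prim::TupleSlice must be
 # resolved at compile time; that is why the non-constant index above is
 # rejected while constant indices and slices lower away entirely.
 def _sketch_tuple_static_index(self):
     @torch.jit.script
     def pick(x):
         pair = (x, 1)
         return pair[0]  # constant index, resolved during compilation
     return pick(torch.ones(2))
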
9209  def test_unwrap_optional_builtin(self):
9210  def test(x):
9211  # type: (Optional[int]) -> int
9212  x = torch.jit._unwrap_optional(x)
9213  x = x + x # noqa: T484
9214  return x
9215 
9216  self.checkScript(test, (3,))
9217 
9218  with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
9219  test(None)
9220 
9221  test_script = torch.jit.script(test)
9222  with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
9223  test_script(None)
9224 
9225  @torch.jit.script
9226  def test_test():
9227  return torch.jit._unwrap_optional(1)
9228 
9229  with self.assertRaisesRegex(RuntimeError, "cannot match an Optional\\[T\\] to None"):
9230  @torch.jit.script
9231  def test_no_type():
9232  # type: () -> int
9233  return torch.jit._unwrap_optional(None)
9234 
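 # Editor's sketch (an assumption-labeled illustration, not an original
 # test): on builds that support None-check refinement, an explicit
 # `is None` branch is the scriptable alternative to
 # torch.jit._unwrap_optional, narrowing Optional[int] down to int.
 def _sketch_optional_refinement(self):
     @torch.jit.script
     def refine(x):
         # type: (Optional[int]) -> int
         if x is None:
             return 0
         else:
             return x + x
     self.assertEqual(refine(3), 6)
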
9235  def test_indexing_error(self):
9236  with self.assertRaisesRegex(RuntimeError, "only supported on lists, dictionaries, tensors, and tuples"):
9237  @torch.jit.script
9238  def test_wrong_type():
9239  a = 8
9240  return a[0]
9241 
9242  def test_annotated_script_fn(self):
9243  @torch.jit.script
9244  def foo(x, y, z):
9245  # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
9246  return x
9247 
9248  self.assertExpected(foo.__getattr__('forward').pretty_print_schema())
9249 
9250  def test_annotated_script_method(self):
9251  class SM(torch.jit.ScriptModule):
9252  @torch.jit.script_method
9253  def forward(self, x, y):
9254  # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
9255  return y, y, y
9256 
9257  sm = SM()
9258 
9259  self.assertExpected(sm.__getattr__('forward').pretty_print_schema())
9260 
9261  def test_annotated_script_fn_return_mismatch(self):
9262  with self.assertRaisesRegex(RuntimeError, "but is actually of type"):
9263  @torch.jit.script
9264  def return_tup(x):
9265  # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
9266  return x, x # noqa: T484
9267 
9268  def test_annotated_script_fn_arg_mismatch(self):
9269  with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
9270  @torch.jit.script
9271  def tuple_arg(x):
9272  # type: (Tuple[Tensor, Tensor]) -> Tensor
9273  return x + 1 # noqa: T484
9274 
9275  def test_script_non_tensor_args_outputs(self):
9276  @torch.jit.script
9277  def fn(x, y):
9278  # type: (Tensor, float) -> float
9279  return float((x + y).sum())
9280 
9281  x = torch.ones(2, 2)
9282  z = fn(x, 1)
9283  self.assertIsInstance(z, float)
9284  self.assertEqual(z, 8.)
9285 
9286  @unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
9287  def test_inline_and_run_annotated_script_fn(self):
9288  @torch.jit.script
9289  def to_inline(x, y):
9290  # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
9291  return y
9292 
9293  @torch.jit.script
9294  def some_func(x):
9295  return to_inline((x, x), x)
9296 
9297  x = torch.rand(3, 4)
9298  self.assertEqual(some_func(x), x)
9299 
9300  def test_file_format_serialization(self):
9301  import tempfile
9302  filename = tempfile.mktemp()
9303  writer = torch._C.PyTorchFileWriter(filename)
9304  import os
9305  import random
9306  buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
9307  offsets = []
9308  for i, buf in enumerate(buffers):
9309  writer.write_record(str(i), buf, len(buf))
9310  offsets.append(i)
9311  import pickle
9312  serialized_offsets = pickle.dumps(offsets)
9313  writer.write_record("meta", serialized_offsets, len(serialized_offsets))
9314  writer.write_end_of_file()
9315 
9316  reader = torch._C.PyTorchFileReader(filename)
9317  serialized_offsets_read = reader.get_record("meta")
9318  parsed_serialized_offsets = pickle.loads(serialized_offsets_read)
9319 
9320  for i, offset in enumerate(parsed_serialized_offsets):
9321  data = reader.get_record(str(offset))
9322  assert(data == buffers[i])
9323 
9324  # for each type, the input type annotation and corresponding return type annotation
9325  def type_input_return_pairs(self):
9326  return [
9327  ('Tensor', 'Tensor'),
9328  ('torch.Tensor', 'Tensor'),
9329  ('str', 'str'),
9330  ('int', 'int'),
9331  ('bool', 'bool'),
9332  ('BroadcastingList3[float]', 'List[float]'),
9333  ('BroadcastingList2[int]', 'List[int]'),
9334  ('List[int]', 'List[int]'),
9335  ('Optional[int]', 'Optional[int]'),
9336  ]
9337 
9338  # substitute an (input, return) type pair into the code template
9339  def format_code(self, code, pair):
9340  return code.format(input=pair[0], output=pair[1])
9341 
9342  # ***** Type annotation tests *****
9343  # Test combinations of:
9344  # {String frontend, Python AST Frontend}
9345  # {Python 3-style type annotations, MyPy-style type comments}
9346  # {Script method, Script function}
9347 
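 # Editor's sketch (illustrative only): the MyPy-style variant can be
 # written inline because it remains valid Python 2 syntax; the Python
 # 3-style variants below go through _get_py3_code to avoid a py2
 # SyntaxError in this file.
 def _sketch_inline_mypy_annotation(self):
     @torch.jit.script
     def annotated(x, y):
         # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
         return x, y
     self.assertEqual(annotated(torch.ones(1), torch.zeros(1)),
                      (torch.ones(1), torch.zeros(1)))
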
9348  # String frontend , Python 3-style type annotations , Script function
9349  def test_annot_string_py3_fn(self):
9350  code = '''
9351  def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
9352  return x, x
9353  '''
9354  test_str = []
9355  for pair in self.type_input_return_pairs():
9356  cu = torch.jit.CompilationUnit(self.format_code(code, pair))
9357  test_str.append(cu.__getattr__('foo').pretty_print_schema())
9358  self.assertExpected("\n".join(test_str))
9359 
9360  # String frontend , Python 3-style type annotations , Script method
9361  def test_annot_string_py3_method(self):
9362  class TestModule(torch.jit.ScriptModule):
9363  def __init__(self):
9364  super(TestModule, self).__init__()
9365 
9366  code = '''
9367  def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
9368  return x, x
9369  '''
9370  test_str = []
9371  for pair in self.type_input_return_pairs():
9372  tm = TestModule()
9373  tm.define(self.format_code(code, pair))
9374  test_str.append(tm.__getattr__('foo').pretty_print_schema())
9375  self.assertExpected("\n".join(test_str))
9376 
9377  # String frontend , MyPy-style type comments , Script function
9378  def test_annot_string_mypy_fn(self):
9379  code = '''
9380  def foo(x, y):
9381  # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
9382  return x, x
9383  '''
9384  test_str = []
9385  for pair in self.type_input_return_pairs():
9386  cu = torch.jit.CompilationUnit(self.format_code(code, pair))
9387  test_str.append(cu.__getattr__('foo').pretty_print_schema())
9388  self.assertExpected("\n".join(test_str))
9389 
9390  # String frontend , MyPy-style type comments , Script method
9391  def test_annot_string_mypy_method(self):
9392  class TestModule(torch.jit.ScriptModule):
9393  def __init__(self):
9394  super(TestModule, self).__init__()
9395 
9396  code = '''
9397  def foo(self, x, y):
9398  # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
9399  return x, x
9400  '''
9401 
9402  test_str = []
9403  for pair in self.type_input_return_pairs():
9404  tm = TestModule()
9405  tm.define(self.format_code(code, pair))
9406  test_str.append(tm.__getattr__('foo').pretty_print_schema())
9407  self.assertExpected("\n".join(test_str))
9408 
9409  # Helper function to eval Python3 code without causing a syntax error for
9410  # this file under py2
9411  def _get_py3_code(self, code, fn_name):
9412  with tempfile.TemporaryDirectory() as tmp_dir:
9413  script_path = os.path.join(tmp_dir, 'script.py')
9414  with open(script_path, 'w') as f:
9415  f.write(code)
9416  import importlib.util
9417  spec = importlib.util.spec_from_file_location(fn_name, script_path)
9418  module = importlib.util.module_from_spec(spec)
9419  spec.loader.exec_module(module)
9420  fn = getattr(module, fn_name)
9421  return fn
9422 
9423  # Python AST Frontend , Python 3-style type annotations , Script function
9424  @unittest.skipIf(not PY35, "Python 3.5 needed")
9425  def test_annot_ast_py3_fn(self):
9426  code = dedent('''
9427  from typing import Tuple, List, Optional
9428  from torch import Tensor
9429  from torch.jit.annotations import BroadcastingList2, BroadcastingList3
9430  import torch
9431  @torch.jit.script
9432  def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
9433  return x, x
9434  ''')
9435  test_str = []
9436  for pair in self.type_input_return_pairs():
9437  fn = self._get_py3_code(self.format_code(code, pair), 'foo')
9438  test_str.append(fn.__getattr__('forward').pretty_print_schema())
9439  self.assertExpected("\n".join(test_str))
9440 
9441  # Python AST Frontend , Python 3-style type annotations , Script method
9442  @unittest.skipIf(not PY35, "Python 3.5 needed")
9443  def test_annot_ast_py3_method(self):
9444  code = dedent('''
9445  from typing import Tuple, List, Optional
9446  from torch import Tensor
9447  from torch.jit.annotations import BroadcastingList2, \\
9448  BroadcastingList3
9449  import torch
9450  class FooModule(torch.jit.ScriptModule):
9451  @torch.jit.script_method
9452  def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
9453  return x, x
9454  instance = FooModule()
9455  ''')
9456 
9457  test_str = []
9458  for pair in self.type_input_return_pairs():
9459  fn = self._get_py3_code(self.format_code(code, pair), 'instance')
9460  test_str.append(fn.__getattr__('foo').pretty_print_schema())
9461  self.assertExpected("\n".join(test_str))
9462 
9463  # Python AST Frontend , MyPy-style type comments , Script function
9464  @unittest.skipIf(not PY35, "Python 3.5 needed")
9465  def test_annot_ast_mypy_fn(self):
9466  code = dedent('''
9467  import torch
9468  @torch.jit.script
9469  def foo(x, y):
9470  # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
9471  return x, x
9472  ''')
9473 
9474  test_str = []
9475  for pair in self.type_input_return_pairs():
9476  fn = self._get_py3_code(self.format_code(code, pair), 'foo')
9477  test_str.append(fn.__getattr__('forward').pretty_print_schema())
9478  self.assertExpected("\n".join(test_str))
9479 
9480  # Python AST Frontend , MyPy-style type comments , Script method
9481  @unittest.skipIf(not PY35, "Python 3.5 needed")
9482  def test_annot_ast_mypy_method(self):
9483  code = dedent('''
9484  import torch
9485  class FooModule(torch.jit.ScriptModule):
9486  @torch.jit.script_method
9487  def foo(self, x, y):
9488  # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
9489  return x, x
9490  instance = FooModule()
9491  ''')
9492 
9493  test_str = []
9494  for pair in self.type_input_return_pairs():
9495  fn = self._get_py3_code(self.format_code(code, pair), 'instance')
9496  test_str.append(fn.__getattr__('foo').pretty_print_schema())
9497  self.assertExpected("\n".join(test_str))
9498 
9499  def test_method_casts_script(self):
9500  cast_types = [
9501  'byte', 'char', 'double', 'float', 'int', 'long', 'short'
9502  ]
9503 
9504  for cast_type in cast_types:
9505  cu = torch.jit.CompilationUnit('''
9506  def cast_to(x):
9507  return x.{cast_type}()
9508  '''.format(cast_type=cast_type))
9509 
9510  x = torch.rand(3, 4, 5) * 128
9511  cu_result = cu.cast_to(x)
9512  reference = getattr(x, cast_type)()
9513  self.assertEqual(cu_result, reference)
9514 
9515  def test_listconstruct_erasure(self):
9516  class FooMod(torch.nn.Module):
9517  def forward(self, x):
9518  mask = x < 0.0
9519  return x[mask]
9520 
9521  import io
9522  f = io.BytesIO()
9523  self.assertExpected(torch.onnx.export_to_pretty_string(
9524  FooMod(), (torch.rand(3, 4),), f,
9525  operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
9526 
9527  def test_trace_checker_arange_as_constant(self):
9528  with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
9529  @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
9530  def foo(x):
9531  y = torch.arange(0, x.shape[0]).double()
9532  return x + y.unsqueeze(1)
9533 
9534  @suppress_warnings
9535  def test_trace_checker_dot_data(self):
9536  with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
9537  r'across invocations'):
9538  @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
9539  def foo(x):
9540  y = x.data
9541  return x + y
9542 
9543  @suppress_warnings
9544  def test_trace_checker_control_flow(self):
9545  def foo(x):
9546  for _ in range(x.size(0)):
9547  x = torch.neg(x)
9548  return x
9549 
9550  with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
9551  torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
9552 
9553  @suppress_warnings
9554  def test_trace_checker_memoization(self):
9555  with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
9556  def foo(x):
9557  if not hasattr(foo, 'cache'):
9558  foo.cache = torch.neg(x)
9559  return x + foo.cache
9560 
9561  traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
9562 
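 # Editor's sketch (hedged, not an original test): the usual fix for the
 # checker failures above is torch.jit.script, which compiles
 # data-dependent control flow instead of burning the example input's
 # shape into the trace.
 def _sketch_script_instead_of_trace(self):
     @torch.jit.script
     def foo(x):
         for _ in range(x.size(0)):
             x = torch.neg(x)
         return x
     # the loop count follows the input at runtime: neg twice is identity
     self.assertEqual(foo(torch.ones(2)), torch.ones(2))
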
9563  # These tests don't work because UBSAN has a false positive about accessing
9564  # out of bounds on a dynamically sized struct internal to asmjit
9565  if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
9566  def test_int8_quantization_module(self):
9567  K1, N1 = 2, 2
9568 
9569  class FooBar(torch.nn.Module):
9570  def __init__(self):
9571  super(FooBar, self).__init__()
9572  self.linear1 = torch.nn.Linear(K1, N1).float()
9573 
9574  def forward(self, x):
9575  x = self.linear1(x)
9576  return x
9577 
9578  fb = FooBar()
9579  fb.linear1.weight = torch.nn.Parameter(
9580  torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False)
9581  fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False)
9582  fb_ref = FooBar()
9583  fb_ref.linear1.weight = torch.nn.Parameter(fb.linear1.weight.clone(), requires_grad=False)
9584  fb_ref.linear1.bias = torch.nn.Parameter(fb.linear1.bias.clone(), requires_grad=False)
9586 
9587  x = (torch.rand(1, K1).float() - 0.5) / 10.0
9588  traced = torch.jit.trace(fb, (x,))
9589  fb = self.getExportImportCopyWithPacking(traced)
9590 
9591  x = torch.tensor([[100, -150]], dtype=torch.float)
9592  y = fb(x)
9593  y_ref = fb_ref(x)
9594  torch.testing.assert_allclose(y, y_ref, rtol=0.0001, atol=1e-3)
9595 
9596  def checkTracerWarning(self, *args, **kwargs):
9597  with warnings.catch_warnings(record=True) as warns:
9598  torch.jit.trace(*args, **kwargs)
9599  self.assertGreater(len(warns), 0)
9600  for warn in warns:
9601  self.assertIn("cause the trace to be incorrect", str(warn.message))
9602 
9603  def test_trace_checker_slice_lhs(self):
9604  def foo(x):
9605  for i in range(3):
9606  x[i, :] = torch.zeros(4)
9607  return x
9608 
9609  self.checkTrace(foo, (torch.rand(3, 4),))
9610 
9611  def test_trace_checker_inplace_on_view(self):
9612  def foo(x):
9613  x.view(-1).add_(-x.view(-1))
9614  return x
9615 
9616  self.assertWarnsRegex(lambda: torch.jit.trace(foo,
9617  torch.rand(3, 4),
9618  check_inputs=[torch.rand(5, 6)],
9619  _force_outplace=True),
9620  'Output nr 1. of the traced function does not match the '
9621  'corresponding output of the Python function')
9622 
9623  def test_lhs_index_fails(self):
9624  def foo(x):
9625  x[0, 1] = 4
9626  return x
9627  self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
9628 
9629  def test_lhs_index_trivial(self):
9630  def foo(y, x):
9631  y[...] = x
9632  return y
9633  self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False)
9634 
9635  def test_inplace_warn(self):
9636  def foo(x):
9637  x.view(-1).add_(-x.view(-1))
9638  return x
9639  self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
9640 
9641  @suppress_warnings
9642  def test_trace_checker_dropout_train(self):
9643  def foo(x):
9644  return torch.dropout(x, p=0.5, train=True)
9645 
9646  self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
9647  'Output nr 1. of the traced function does not match the '
9648  'corresponding output of the Python function')
9649  self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
9650  'Trace had nondeterministic nodes')
9651 
9652  def test_trace_checker_dropout_notrain(self):
9653  input = torch.rand(3, 4)
9654 
9655  @_trace(input)
9656  def foo(x):
9657  return torch.dropout(x, p=0.5, train=False)
9658 
9659  self.assertEqual(foo(input), input)
9660 
9661  def test_export_dynamic_slice(self):
9662  class DynamicSliceExportMod(torch.jit.ScriptModule):
9663  @torch.jit.script_method
9664  def forward(self, x):
9665  retval = x[0]
9666  for i in range(x.size(1)):
9667  retval += torch.sum(x[0:i], dim=0)
9668  return retval
9669 
9670  mod = DynamicSliceExportMod()
9671 
9672  input = torch.rand(3, 4, 5)
9673  example_outs = mod(input)
9674 
9675  f = io.BytesIO()
9676  exported = torch.onnx.export_to_pretty_string(
9677  DynamicSliceExportMod(), (input,), f, example_outputs=example_outs)
9678  self.assertExpected(exported)
9679 
9680  def test_string_frontend_elif(self):
9681  code = '''
9682  def elif_test(niter : int):
9683  rv = 0
9684  for i in range(niter):
9685  if i % 3 == 0 and i % 5 == 0:
9686  rv += 35
9687  elif i % 3 == 0:
9688  rv += 3
9689  elif i % 5 == 0:
9690  rv += 5
9691  else:
9692  rv += i
9693  return rv
9694  '''
9695 
9696  self.checkScript(code, (101,), name='elif_test', outputs=3028)
9697 
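 # Editor's sketch (plain-Python replica, not an original test): a quick
 # check of the expected output 3028 used above for niter=101.
 def _sketch_elif_reference(self):
     rv = 0
     for i in range(101):
         if i % 3 == 0 and i % 5 == 0:
             rv += 35
         elif i % 3 == 0:
             rv += 3
         elif i % 5 == 0:
             rv += 5
         else:
             rv += i
     self.assertEqual(rv, 3028)
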
9698  def test_pyop_exception_message(self):
9699  class Foo(torch.jit.ScriptModule):
9700  def __init__(self):
9701  super(Foo, self).__init__()
9702  self.conv = nn.Conv2d(1, 10, kernel_size=5)
9703 
9704  @torch.jit.script_method
9705  def forward(self, x):
9706  return self.conv(x)
9707  foo = Foo()
9708  # testing that the correct error message propagates
9709  with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
9710  foo(torch.ones([123])) # wrong size
9711 
9712  def test_builtin_error_message(self):
9713  from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
9714 
9715  with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
9716  @torch.jit.script
9717  def close_match(x):
9718  return x.masked_fill(True)
9719 
9720  with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
9721  "supported in TorchScript"):
9722  @torch.jit.script
9723  def unknown_op(x):
9724  torch.set_grad_enabled(True)
9725  return x
9726 
9727  def test_exceptions(self):
9728  cu = torch.jit.CompilationUnit('''
9729  def foo(cond):
9730  if bool(cond):
9731  raise ValueError(3)
9732  return 1
9733  ''')
9734 
9735  cu.foo(torch.tensor(0))
9736  with self.assertRaisesRegex(torch.jit.Error, "Exception"):
9737  cu.foo(torch.tensor(1))
9738 
9739  @torch.jit.script
9740  def foo(cond):
9741  a = 3
9742  if bool(cond):
9743  raise ArbitraryError(a, "hi")
9744  if False:
9745  raise ArbitraryError
9746  return a
9747 
9748  foo(torch.tensor(0))
9749  # we don't currently validate the name of the exception
9750  with self.assertRaisesRegex(torch.jit.Error, "Exception"):
9751  foo(torch.tensor(1))
9752 
9753  @torch.jit.script
9754  def foo_except_used():
9755  a = Exception()
9756  print(a)
9757  raise a
9758 
9759  # `a` is used (printed and raised), so it is not dead-code-eliminated
9760  with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
9761  foo_except_used()
9762 
9763  # We don't validate the expr following raise
9764  @torch.jit.script
9765  def foo():
9766  raise 3 + 4
9767 
9768  # no control flow analysis yet
9769  with self.assertRaisesRegex(RuntimeError, "undefined value a"):
9770  @torch.jit.script
9771  def foo():
9772  if True:
9773  a = 1
9774  else:
9775  raise Exception("Hi")
9776  return a
9777 
9778  def test_assertions(self):
9779  cu = torch.jit.CompilationUnit('''
9780  def foo(cond):
9781  assert bool(cond), "hi"
9782  return 0
9783  ''')
9784 
9785  cu.foo(torch.tensor(1))
9786  with self.assertRaisesRegex(torch.jit.Error, "Exception"):
9787  cu.foo(torch.tensor(0))
9788 
9789  @torch.jit.script
9790  def foo(cond):
9791  assert bool(cond), "hi"
9792 
9793  foo(torch.tensor(1))
9794  # we don't currently validate the name of the exception
9795  with self.assertRaisesRegex(torch.jit.Error, "Exception"):
9796  foo(torch.tensor(0))
9797 
9798  def test_weak_script_function(self):
9799  outer_var = 10
9800  outer_var2 = 11
9801 
9802  def not_a_script_fn(x):
9803  return x + 2
9804 
9805  @torch.jit.script
9806  def even_more_inner(x):
9807  return x + 1
9808 
9809  @torch.jit.script
9810  def inner(x):
9811  return not_a_script_fn(x) + x + even_more_inner(x)
9812 
9813  @torch.jit.script
9814  def strong_script_fn(x):
9815  if bool(x.norm() > 2):
9816  x = x + 3
9817  return x + 4 + inner(x)
9818 
9819  @torch._jit_internal.weak_script
9820  def weak_script_fn_inner(x):
9821  return x + 6 + not_a_script_fn(x)
9822 
9823  @torch._jit_internal.weak_script
9824  def weak_script_fn(x):
9825  return x + 5 + weak_script_fn_inner(x) + weak_script_fn_inner(x)
9826 
9827  def fn(x):
9828  x = not_a_script_fn(x)
9829  x = strong_script_fn(x)
9830  return weak_script_fn(x)
9831 
9832  input = torch.randn(3, 4, 5)
9833  self.checkScript(fn, (input,))
9834 
9835  def test_python_op_exception(self):
9836  def python_op(x):
9837  raise Exception("bad!")
9838 
9839  @torch.jit.script
9840  def fn(x):
9841  return python_op(x)
9842 
9843  with self.assertRaisesRegex(RuntimeError, "operation failed in interpreter"):
9844  fn(torch.tensor(4))
9845 
9846  def test_trace_contiguous(self):
9847  def foo(x):
9848  return x[:, :, ::2].contiguous().view(12)
9849 
9850  x = torch.rand(2, 3, 4)
9851  traced = torch.jit.trace(foo, (x,))
9852  y = traced(x)
9853  self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
9854 
9855  # This tests the logic in THPVariable_contiguous. There is short-circuiting
9856  # code that prevents us from even getting to VariableType::contiguous, since
9857  # it is an optimization that prevents us from acquiring the GIL for touching
9858  # the device. We needed to add the tracing logic directly into the
9859  # THPVariable_contiguous function only for the path where we are skipping
9860  # dispatch into contiguous. We should see an aten::contiguous in this trace!
9861  def test_trace_contiguous_short_circuit(self):
9862  def foo(x):
9863  return x.contiguous()
9864 
9865  x = torch.rand(2, 3, 4)
9866  traced = torch.jit.trace(foo, (x,))
9867  FileCheck().check("aten::contiguous").run(str(traced.graph))
9868 
9869  def test_weak_module(self):
9870 
9871  @torch._jit_internal.weak_module
9872  class Weak(torch.nn.Module):
9873  __constants__ = ['number']
9874 
9875  def __init__(self):
9876  super(Weak, self).__init__()
9877  self.number = 199
9878 
9879  def python_op_in_weak_module(self, x):
9880  return x + 123
9881 
9882  @torch._jit_internal.weak_script_method
9883  def forward(self, x):
9884  return 55 + self.number + self.python_op_in_weak_module(x)
9885 
9886  class OtherStrong(torch.jit.ScriptModule):
9887  __constants__ = ['number']
9888 
9889  def __init__(self):
9890  super(OtherStrong, self).__init__()
9891  self.number = 357
9892 
9893  def python_op_in_strong_module(self, x):
9894  return x + 456
9895 
9896  @torch.jit.script_method
9897  def forward(self, x):
9898  return x + self.number + self.python_op_in_strong_module(x)
9899 
9900  class Passthrough(torch.jit.ScriptModule):
9901  def __init__(self):
9902  super(Passthrough, self).__init__()
9903  self.weak = Weak()
9904 
9905  @torch.jit.script_method
9906  def forward(self, x):
9907  return self.weak(x)
9908 
9909  weak_mod = Weak()
9910  x = torch.ones(1)
9911  expected_result = 55 + 199 + (x + 123)
9912 
9913  # Ensure weak mod is running without the JIT by passing the wrong type
9914  # (i.e. not a tensor)
9915  weak_mod(2)
9916 
9917  python_result = weak_mod(x)
9918  strong_mod = Passthrough()
9919  script_result = strong_mod(x)
9920 
9921  self.assertEqual(python_result, expected_result)
9922  self.assertEqual(script_result, expected_result)
9923 
9924  class Strong(torch.jit.ScriptModule):
9925  def __init__(self):
9926  super(Strong, self).__init__()
9927  self.weak = Weak()
9928  self.strong = OtherStrong()
9929 
9930  @torch.jit.script_method
9931  def forward(self, x):
9932  y = 2 * x
9933  return y + 1 + self.weak(y) + self.strong(y)
9934 
9935  strong_mod = Strong()
9936  strong_mod2 = Strong()
9937  x = torch.ones(1)
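 # Editor's note: forward computes y + 1 + weak(y) + strong(y) with
 # y = 2 * x, weak(y) = 55 + 199 + (y + 123), and
 # strong(y) = y + 357 + (y + 456), which gives the sum below.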
9938  expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
9939  script_result = strong_mod(x)
9940  script_result2 = strong_mod2(x)
9941  self.assertEqual(script_result, expected_result)
9942  self.assertEqual(script_result, script_result2)
9943 
9944  def test_weak_module_parameters_and_buffers(self):
9945  weights = torch.randn(10, 10)
9946  bias = torch.randn(10)
9947  weights2 = torch.randn(10, 10)
9948  bias2 = torch.randn(10)
9949 
9950  @torch._jit_internal.weak_module
9951  class TestLinear(torch.nn.Module):
9952  def __init__(self, in_features, out_features):
9953  super(TestLinear, self).__init__()
9954  self.in_features = in_features
9955  self.out_features = out_features
9956  self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
9957  self.bias = torch.nn.Parameter(torch.Tensor(out_features))
9958  self.register_buffer('counter', torch.ones(out_features))
9959  self.reset_parameters()
9960 
9961  def reset_parameters(self):
9962  torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
9963  if self.bias is not None:
9964  fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
9965  bound = 1 / math.sqrt(fan_in)
9966  torch.nn.init.uniform_(self.bias, -bound, bound)
9967 
9968  @torch._jit_internal.weak_script_method
9969  def forward(self, input):
9970  return F.linear(input, self.weight, self.bias) + self.counter
9971 
9972  # Initialize a ScriptModule that uses the weak module above multiple times
9973  class Strong(torch.jit.ScriptModule):
9974  def __init__(self):
9975  super(Strong, self).__init__()
9976  self.fc1 = TestLinear(10, 10)
9977  self.fc1.weight = torch.nn.Parameter(weights)
9978  self.fc1.bias = torch.nn.Parameter(bias)
9979  self.fc2 = TestLinear(10, 10)
9980  self.fc2.weight = torch.nn.Parameter(weights2)
9981  self.fc2.bias = torch.nn.Parameter(bias2)
9982 
9983  @torch.jit.script_method
9984  def forward(self, x):
9985  return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
9986 
9987  strong_mod = Strong()
9988 
9989  # Run same calculation as module
9990  inp = torch.ones(10)
9991  lin = torch.nn.Linear(10, 10)
9992  lin.weight = torch.nn.Parameter(weights)
9993  lin.bias = torch.nn.Parameter(bias)
9994  lin2 = torch.nn.Linear(10, 10)
9995  lin2.weight = torch.nn.Parameter(weights2)
9996  lin2.bias = torch.nn.Parameter(bias2)
9997  expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
9998 
9999  self.assertEqual(strong_mod(inp), expected_result)
10000  self.assertExportImportModule(strong_mod, (inp,))
10001 
10002  def test_weak_module_nested(self):
10003  @torch._jit_internal.weak_module
10004  class OtherWeak(torch.nn.Module):
10005  __constants__ = ['constant']
10006 
10007  def __init__(self, in_features, out_features):
10008  super(OtherWeak, self).__init__()
10009  self.in_features = in_features
10010  self.out_features = out_features
10011  self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
10012  self.bias = torch.nn.Parameter(torch.ones(out_features))
10013  self.constant = 3
10014 
10015  @torch._jit_internal.weak_script_method
10016  def forward(self, x):
10017  return x * x + self.constant + F.linear(x, self.weight, self.bias)
10018 
10019  class OtherStrong(torch.jit.ScriptModule):
10020 
10021  def __init__(self):
10022  super(OtherStrong, self).__init__()
10023 
10024  @torch.jit.script_method
10025  def forward(self, x):
10026  return x + 27
10027 
10028  @torch._jit_internal.weak_module
10029  class Weak(torch.nn.Module):
10030  def __init__(self, in_features, out_features):
10031  super(Weak, self).__init__()
10032  self.in_features = in_features
10033  self.out_features = out_features
10034  self.weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
10035  self.bias = torch.nn.Parameter(2 * torch.ones(out_features))
10036  self.weak_submodule = OtherWeak(10, 10)
10037  self.strong_submodule = OtherStrong()
10038 
10039  @torch._jit_internal.weak_script_method
10040  def forward(self, x):
10041  return x + self.weak_submodule(x) + self.strong_submodule(x) \
10042  + F.linear(x, self.weight, self.bias)
10043 
10044  class Strong(torch.jit.ScriptModule):
10045  __constants__ = ['constant']
10046 
10047  def __init__(self):
10048  super(Strong, self).__init__()
10049  self.weak = Weak(10, 10)
10050 
10051  @torch.jit.script_method
10052  def forward(self, x):
10053  return x + self.weak(x)
10054 
10055  strong_mod = Strong()
10056  inp = torch.randn(10)
10057  result = strong_mod(inp)
10058  expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
10059  + F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
10060  + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
10061  self.assertEqual(result, expected_result)
10062 
10063  def test_weak_module_submodule(self):
10064  @torch._jit_internal.weak_module
10065  class Weak(torch.nn.Module):
10066  def __init__(self):
10067  super(Weak, self).__init__()
10068  self.param = torch.nn.Parameter(100 * torch.ones(5))
10069 
10070  @torch._jit_internal.weak_script_method
10071  def forward(self, x):
10072  return x + self.param
10073 
10074  weak = Weak()
10075 
10076  class OtherStrong(torch.jit.ScriptModule):
10077  def __init__(self):
10078  super(OtherStrong, self).__init__()
10079  self.weak = weak
10080  self.weak2 = weak
10081 
10082  @torch.jit.script_method
10083  def forward(self, x):
10084  return x + self.weak(x)
10085 
10086  class Strong(torch.jit.ScriptModule):
10087  def __init__(self):
10088  super(Strong, self).__init__()
10089  self.weak = Weak()
10090 
10091  @torch.jit.script_method
10092  def forward(self, x):
10093  return self.weak(x) + weak(x)
10094 
10095  other_strong_mod = OtherStrong()
10096 
10097  self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
10098 
10099  with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
10100  strong_mod = Strong()
10101 
10102  def test_weak_module_copying(self):
10103  class Submodule(torch.nn.Module):
10104  def __init__(self):
10105  super(Submodule, self).__init__()
10106 
10107  def forward(self, x):
10108  return x + 100
10109 
10110  @torch._jit_internal.weak_module
10111  class Weak(torch.nn.Module):
10112  def __init__(self, in_features, out_features):
10113  super(Weak, self).__init__()
10114  self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
10115  self.bias = torch.nn.Parameter(torch.ones(out_features))
10116  self.register_buffer("buffer", torch.ones(out_features))
10117  self.submodule = Submodule()
10118 
10119  @torch._jit_internal.weak_script_method
10120  def forward(self, x):
10121  return F.linear(x, self.weight, self.bias) \
10122  + self.buffer + self.submodule(x)
10123 
10124  class Strong(torch.jit.ScriptModule):
10125  def __init__(self, weak):
10126  super(Strong, self).__init__()
10127  self.weak = weak
10128 
10129  @torch.jit.script_method
10130  def forward(self, x):
10131  return self.weak(x)
10132 
10133  inp = torch.ones(5, 5) * 5
10134  weak_mod = Weak(5, 5)
10135  strong_mod = Strong(weak_mod)
10136 
10137  self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
10138  self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
10139 
10140  self.assertIs(strong_mod.weak.weight, weak_mod.weight)
10141  self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
10142  self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)
10143 
10144  # Test lookup fallback
10145  weak_mod.new_attribute = 10
10146  self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)
10147 
10148  weak_mod.weight.data += torch.ones(5, 5) * 100
10149  self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
10150 
10151  # Re-assignment is not tracked
10152  weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
10153  self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
10154 
10155  def test_backend_cudnn_enabled(self):
10156  # Only test that this compiles
10157  @torch.jit.script
10158  def fn(x):
10159  if torch.backends.cudnn.enabled:
10160  x = x + 2
10161  else:
10162  x = x + 3
10163  return x
10164 
10165  def test_inplace_add(self):
10166 
10167  def foo(a, b):
10168  c = a + b
10169  c.add_(b)
10170  return c
10171  self.checkScript(foo, (torch.rand(3), torch.rand(3)))
10172 
10173  def test_add_out(self):
10174  def foo(a, b):
10175  c = a + b
10176  e = 2 * a
10177  torch.add(c, b, out=e)
10178  return e
10179  self.checkScript(foo, (torch.rand(3), torch.rand(3)))
10180 
10181  def test_augmented_assign(self):
10182  def foo(a, b):
10183  a += b
10184  a -= b
10185  a /= b
10186  a *= b
10187  return a, b
10188  self.checkScript(foo, (torch.rand(3), torch.rand(3)))
10189 
10190  def test_pass(self):
10191  def foo(x):
10192  # type: (bool) -> int
10193  for _i in range(3):
10194  pass
10195  if x:
10196  pass
10197  else:
10198  pass
10199  return 3
10200 
10201  self.checkScript(foo, (True,))
10202 
10203  def test_optional_conversion(self):
10204  @torch.jit.script
10205  def other_fn(x=None):
10206  # type: (Optional[int]) -> int
10207  return torch.jit._unwrap_optional(x)
10208 
10209  @torch.jit.script
10210  def fn(x):
10211  # type: (int) -> int
10212  return other_fn(x)
10213 
10214  self.assertEqual(fn(2), 2)
10215 
10216  @torch.jit.script
10217  def unify_to_optional(x):
10218  # type: (bool) -> Optional[int]
10219  if x:
10220  a = None
10221  else:
10222  a = 2
10223  return a
10224 
10225  self.assertEqual(unify_to_optional(True), None)
10226  self.assertEqual(unify_to_optional(False), 2)
10227 
10228  @torch.jit.script
10229  def opt_list(x):
10230  # type: (Optional[List[float]]) -> int
10231  return 2
10232 
10233  @torch.jit.script
10234  def broadcast_opt_list(x):
10235  # type: (Optional[BroadcastingList2[float]]) -> int
10236  return 2
10237 
10238  @torch.jit.script
10239  def opt_list_tuple_caller(x):
10240  # type: (Tuple[float, float]) -> int
10241  return opt_list(x) + broadcast_opt_list(x)
10242 
10243  self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
10244 
10245  def test_lhs_indexing(self):
10246  def foo(a, b):
10247  a = a.clone()
10248  a[0] = b
10249  return a
10250  self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10251 
10252  def test_lhs_advanced_indexing_assignment(self):
10253  def foo(x, y):
10254  a = torch.exp(x)
10255  b = x == 1
10256  a[b] = y[b]
10257  return a
10258  self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
10259 
10260  def test_lhs_advanced_indexing_augmented_assignment(self):
10261  def foo(x, y):
10262  a = torch.exp(x)
10263  b = x == 1
10264  a[b] += y[b]
10265  return a
10266  self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
10267 
10268  def test_lhs_indexing_list(self):
10269  def foo(a, b):
10270  ls = [a]
10271  ls[0] = b
10272  return ls
10273  self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10274 
10275  def test_inplace_copy_script(self):
10276  def foo(x):
10277  a = torch.rand(3, 4)
10278  a.copy_(x)
10279  return a
10280  self.checkScript(foo, (torch.rand(3, 4),))
10281 
10282  def test_lhs_indexing_increment(self):
10283  def foo(a, b):
10284  a[0] += b
10285  return a
10286  self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10287 
10288  def test_lhs_indexing_increment_list(self):
10289  def foo(a, b):
10290  a = a.clone()
10291  ls = [a, b]
10292  ls[0] += b
10293  return ls
10294  self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10295 
10296  def test_lhs_indexing_increment_list_prim(self):
10297  def foo():
10298  ls = [1, 2, 3]
10299  ls[0] += 5
10300  return ls
10301  self.checkScript(foo, ())
10302 
10303  def test_lhs_indexing_multi(self):
10304  def foo(a, b):
10305  a = a.clone()
10306  foo, a[0], bar = (1, b, 3)
10307  return foo, a, bar
10308  self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10309 
10310  def test_bool_dispatch(self):
10311  with self.disableModuleHook(): # TODO: Python print broadcasting list
10312  def kwarg_false(x):
10313  # type: (Tensor) -> Tensor
10314  return F.max_pool1d(x, 1, 1, return_indices=False)
10315  self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
10316 
10317  def kwarg_true(x):
10318  # type: (Tensor) -> Tuple[Tensor, Tensor]
10319  return F.max_pool1d(x, 1, 1, return_indices=True)
10320  self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
10321 
10322  def full_kwarg_false(x):
10323  # type: (Tensor) -> Tensor
10324  return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
10325  self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
10326 
10327  def full_kwarg_true(x):
10328  # type: (Tensor) -> Tuple[Tensor, Tensor]
10329  return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
10330  self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
10331 
10332  def use_default(x):
10333  # type: (Tensor) -> Tensor
10334  return F.max_pool1d(x, 1, 1)
10335  self.checkScript(use_default, (torch.randn(3, 3, 3),))
10336 
10337  def arg_false(x):
10338  # type: (Tensor) -> Tensor
10339  return F.max_pool1d(x, 1, 1, 0, 1, False, False)
10340  self.checkScript(arg_false, (torch.randn(3, 3, 3),))
10341 
10342  def arg_true(x):
10343  # type: (Tensor) -> Tuple[Tensor, Tensor]
10344  return F.max_pool1d(x, 1, 1, 0, 1, False, True)
10345  self.checkScript(arg_true, (torch.randn(3, 3, 3),))
10346 
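 # Editor's sketch (illustrative assumption, not an original test):
 # return_indices participates in bool dispatch, so it must be a
 # compile-time constant; its value selects between overloads whose return
 # types differ (Tensor vs. Tuple[Tensor, Tensor]).
 def _sketch_bool_dispatch_constant(self):
     @torch.jit.script
     def pool_with_indices(x):
         # type: (Tensor) -> Tuple[Tensor, Tensor]
         return F.max_pool1d(x, 1, 1, return_indices=True)
     out, indices = pool_with_indices(torch.randn(3, 3, 3))
     self.assertEqual(out.shape, indices.shape)
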
10347  def test_infer_size(self):
10348  from torch._C import _infer_size
10349 
10350  def fn(x, y):
10351  # type: (Tensor, Tensor) -> List[int]
10352  return _infer_size(x.size(), y.size())
10353 
10354  self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
10355 
10356  def test_mutable_dce(self):
10357  @torch.jit.script
10358  def foo():
10359  a = torch.rand(2, 3)
10360  a += torch.rand(2, 3)
10361  b = torch.rand(2, 3)
10362  b += torch.rand(2, 3)
10363  # b should be cleaned up but not a
10364  return a
10365 
10366  FileCheck().check_count("aten::rand", 2, exactly=True) \
10367  .check_count("aten::add", 1, exactly=True).run(str(foo.graph))
10368 
10369  def test_mutable_dce_block(self):
10370  @torch.jit.script
10371  def foo():
10372  a = torch.rand(2, 3)
10373  a += torch.rand(2, 3)
10374  b = torch.rand(2, 3)
10375  if bool(a > torch.zeros(2, 3)):
10376  b += torch.rand(2, 3)
10377  a += torch.rand(2, 3)
10378  # a should be cleaned up but not b
10379  return b
10380 
10381  FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \
10382  .run(str(foo.graph))
10383 
10384  def test_mutable_dce_graph_input(self):
10385  @torch.jit.script
10386  def foo(a):
10387  a += torch.rand(2, 3)
10388  # shouldn't clean up `a` even though it's not used in the output
10389 
10390  FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph))
10391 
10392  def test_mutable_dce_list(self):
10393  @torch.jit.script
10394  def foo(a):
10395  l = []
10396  l.append(a)
10397  c = l[0]
10398  b = torch.rand(2, 3)
10399  c += torch.rand(2, 3)
10400  return b
10401 
10402  # c does not get cleaned up because there is a wildcard + mutation
10403  FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph))
10404 
10405  def test_mutable_dce_loop(self):
10406  @torch.jit.script
10407  def foo(a):
10408  l = []
10409  l.append(a)
10410  i = 0
10411  b = torch.rand(2, 3)
10412  while i < 1:
10413  dead = torch.rand(2, 3)
10414  c = l[0]
10415  c += torch.rand(2, 3)
10416  i += 1
10417  return b
10418 
10419  FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::select") \
10420  .check_count("aten::rand", 1, exactly=True).run(str(foo.graph))
10421 
10422  def test_mutable_dce_wildcards(self):
10423  def fn():
10424  x = torch.ones(2, 3)
10425  l = []
10426  l.append(x)
10427  x_view = l[0]
10428  x.add_(torch.ones(2, 3))
10429  return x_view
10430 
10431  self.checkScript(fn, ())
10432 
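 # Editor's sketch (hedged): dead-code elimination must respect aliasing,
 # so an op survives whenever its output may alias a value that is later
 # mutated or escapes; list elements alias conservatively ("wildcards"),
 # which the tests above exercise. A minimal illustration:
 def _sketch_dce_respects_mutation(self):
     @torch.jit.script
     def foo(a):
         dead = torch.rand(2, 3)  # unused, so DCE may remove it
         a.add_(1.0)              # in-place mutation must be kept
         return a
     return foo(torch.rand(2, 3))
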
10433  def test_cpp_function_tensor_str(self):
10434  x = torch.randn(2, 2)
10435  scale = torch.randn(2, 2, requires_grad=True)
10436  shift = torch.randn(2, 2, requires_grad=True)
10437 
10438  @torch.jit.script
10439  def fn(x, scale, shift):
10440  return scale * x + shift
10441 
10442  with self.capture_stdout() as captured:
10443  print(fn(x, scale, shift))
10444 
10445  def test_non_final_return(self):
10446 
10447  def simple(x):
10448  if bool(x > 3):
10449  return x + 1
10450  else:
10451  return x + 2
10452  raise RuntimeError("nope")
10453 
10454  def nest(x):
10455  x = x + 1
10456  if bool(x > 3):
10457  if bool(x > 4):
10458  x += 1
10459  return x + 1
10460  else:
10461  return x + 2
10462 
10463  def early_ret(x):
10464  x = x + 1
10465  if bool(x > 3):
10466  return x + 1
10467  x = x + 1
10468  return x + 2
10469 
10470  def nest_early_ret(x):
10471  x = x + 1
10472  if bool(x > 3):
10473  if bool(x > 4):
10474  return x + 2
10475  return x + 1
10476  x = x + 1
10477  return x + 2
10478 
10479  self.checkScript(simple, torch.rand(1))
10480  self.checkScript(nest, torch.rand(1))
10481  self.checkScript(early_ret, torch.rand(1))
10482  self.checkScript(nest_early_ret, torch.rand(1))
10483 
10484  with self.assertRaisesRegex(RuntimeError, "early"):
10485  @torch.jit.script
10486  def not_early_ret(x):
10487  if bool(x > 3):
10488  if bool(x > 4):
10489  return 1
10490  print("foo")
10491  else:
10492  print("5")
10493  return 7
10494 
10495  with self.assertRaisesRegex(RuntimeError, "some paths"):
10496  @torch.jit.script
10497  def not_total_ret(x):
10498  if bool(x > 3):
10499  if bool(x > 4):
10500  return 1
10501  else:
10502  return 2
10503  else:
10504  print("5")
10505  return 7
10506 
10507  with self.assertRaisesRegex(RuntimeError, "from a loop"):
10508  @torch.jit.script
10509  def nest_while_ret(x):
10510  while bool(x > 4):
10511  if bool(x < 3):
10512  return 4
10513  return 5
10514 
10515  with self.assertRaisesRegex(RuntimeError, "from a loop"):
10516  @torch.jit.script
10517  def nest_for_ret(x):
10518  for _ in range(3):
10519  if bool(x < 3):
10520  return 4
10521  return 5
10522 
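 # Editor's sketch (illustrative, not an original test): since returning
 # from inside a loop is rejected above, the conventional workaround is a
 # flag plus a result variable checked after the loop.
 def _sketch_loop_return_workaround(self):
     @torch.jit.script
     def first_hit(x):
         # type: (Tensor) -> int
         result = 5
         found = False
         for _ in range(3):
             if not found and bool(x < 3):
                 result = 4
                 found = True
         return result
     self.assertEqual(first_hit(torch.tensor(1.0)), 4)
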
10523  def test_overloading(self):
10524  @torch._jit_internal.weak_module
10525  class W(torch.nn.Module):
10526  __overloads__ = {'forward': ['forward_tuple', 'forward_tensor']}
10527 
10528  def __init__(self):
10529  super(W, self).__init__()
10530 
10531  @torch._jit_internal.weak_script_method
10532  def forward_tuple(self, x):
10533  # type: (Tuple[Tensor, Tensor]) -> Tensor
10534  return x[0] + 5
10535 
10536  def forward(self, x):
10537  # manually do argument switching
10538  if isinstance(x, tuple):
10539  return self.forward_tuple(x)
10540  else:
10541  return self.forward_tensor(x)
10542 
10543  @torch._jit_internal.weak_script_method
10544  def forward_tensor(self, x):
10545  # type: (Tensor) -> Tensor
10546  return x + 20
10547 
10548  class S(torch.jit.ScriptModule):
10549  def __init__(self):
10550  super(S, self).__init__()
10551  self.weak = W()
10552 
10553  @torch.jit.script_method
10554  def forward(self, x):
10555  return self.weak(x) + self.weak((x, x))
10556 
10557  s = S()
10558  x = torch.ones(1)
10559  self.assertEqual(s(x), x + 20 + 5 + x)
10560 
10561  w = W()
10562  self.assertEqual(w((x, x)), x + 5)
10563  self.assertEqual(w(x), x + 20)
10564 
10565  def test_select_after_chunk(self):
10566  def foo(x):
10567  chunked = torch.chunk(x, 1)
10568  foo = chunked[0]
10569  foo.add_(5)
10570  return x
10571 
10572  self.checkScript(foo, [torch.rand(2, 3)])
10573 
10574  def test_nn_LSTM(self):
10575  input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
10576 
10577  class S(torch.jit.ScriptModule):
10578  def __init__(self):
10579  super(S, self).__init__()
10580  self.x = torch.nn.LSTM(5, 5)
10581 
10582  @torch.jit.script_method
10583  def forward(self, input):
10584  # type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
10585  return self.x(input)
10586 
10587  eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0]
10588  script_out = self.runAndSaveRNG(lambda x: S()(x), (input,))[0]
10589 
10590  self.assertEqual(eager_out, script_out)
10591 
10592  def test_list_python_op(self):
10593  def python_list_op(lst):
10594  # type: (List[Tensor]) -> Tensor
10595  return lst[0]
10596 
10597  def fn(lst):
10598  # type: (List[Tensor]) -> Tensor
10599  return python_list_op(lst)
10600 
10601  self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
10602 
10603  def test_ignore_decorator(self):
10604  class M(torch.jit.ScriptModule):
10605  def __init__(self):
10606  super(M, self).__init__()
10607  tensor = torch.zeros(1, requires_grad=False)
10608  self.register_buffer('some_state', torch.nn.Parameter(tensor))
10609 
10610  @torch.jit.script_method
10611  def forward(self, x):
10612  self.ignored_code(x)
10613  return x
10614 
10615  @torch.jit.ignore
10616  def ignored_code(self, x):
10617  self.some_state = torch.tensor((100,))
10618 
10619  # Assert ignored code is run
10620  m = M()
10621  self.assertEqual(m.some_state, torch.zeros(1))
10622  m(torch.ones(1))
10623  self.assertEqual(m.some_state, torch.zeros(1) + 100)
10624 
10625  # Export and ensure ignored code not present
10626  pp, constants = m._python_print()
10627  printed = torch.jit.ScriptModule()
10628  ppv = "op_version_set = 0\n{}".format(pp)
10629  torch._C._jit_import_methods(printed, ppv, constants)
10630  self.assertIn('IgnoredPythonOp', ppv)
10631  self.assertNotIn('ignored_code', ppv)
10632 
10633  with self.assertRaisesRegex(torch.jit.Error, "This Python function is annotated to be ignored"):
10634  printed(torch.ones(1))
10635 
10636  def test_view_write(self):
10637  def fn(x, y):
10638  l = []
10639  l.append(x)
10640  x_view = l[0]
10641  a = x + x
10642  x_view.add_(y)
10643  b = x + x
10644  return a == b
10645  self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
10646 
10647  def test_dict_view(self):
10648  def fn(x, y):
10649  l = {"a": x}
10650  x_view = l["a"]
10651  a = x + x
10652  x_view.add_(y)
10653  b = x + x
10654  return a == b
10655  self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
10656 
10657  def test_dict_ops(self):
10658  d = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
10659 
10660  @torch.jit.script
10661  def keys(x):
10662  # type: (Dict[str, Tensor]) -> List[str]
10663  return list(x.keys())
10664 
10665  self.assertEqual(set(keys(d)), set(d.keys()))
10666 
10667  @torch.jit.script
10668  def values(x):
10669  # type: (Dict[str, Tensor]) -> List[Tensor]
10670  return list(x.values())
10671 
10672  self.assertEqual(set(values(d)), set(d.values()))
10673 
10674  def length(x):
10675  # type: (Dict[str, Tensor]) -> int
10676  return len(x)
10677 
10678  self.checkScript(length, (d,))
10679 
10680  def test_dict(self):
10681  def simple(x):
10682  # type: (Dict[str, int]) -> Dict[str, int]
10683  return x
10684 
10685  self.checkScript(simple, ({'item': 20, 'other_item': 120},))
10686 
10687  def index(x):
10688  # type: (Dict[str, int]) -> int
10689  return x['item']
10690 
10691  self.checkScript(index, ({'item': 20, 'other_item': 120},))
10692 
10693  def type_default():
10694  # type: () -> Dict[str, Tensor]
10695  return {}
10696 
10697  self.checkScript(type_default, ())
10698 
10699  @torch.jit.script
10700  def missing_index(x):
10701  # type: (Dict[str, int]) -> int
10702  return x['dne']
10703 
10704  with self.assertRaisesRegex(RuntimeError, "KeyError"):
10705  missing_index({'item': 20, 'other_item': 120})
10706 
10707  code = dedent('''
10708  def literal1():
10709  return torch.jit.annotate(Dict[int, float], {})
10710  def literal2():
10711  return torch.jit.annotate(Dict[int, float], {10: 1.2})
10712  ''')
10713  cu = torch.jit.CompilationUnit(code)
10714  self.assertEqual({}, cu.literal1())
10715  self.assertEqual({10: 1.2}, cu.literal2())
10716 
10717  cu = torch.jit.CompilationUnit(dedent('''
10718  def literal3():
10719  return torch.jit.annotate(Dict[int, float], {10: 1.2, 11: 1.3})
10720  '''))
10721  self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3())
10722 
10723  def list_of_dicts():
10724  # type: () -> List[Dict[str, Tensor]]
10725  return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}]
10726 
10727  self.checkScript(list_of_dicts, ())
10728 
10729  def test_dict_mutability(self):
10730  @torch.jit.script
10731  def fn():
10732  # type: () -> Dict[str, int]
10733  a = torch.jit.annotate(Dict[str, int], {})
10734  a['ok'] = 10
10735  return a
10736 
10737  self.assertEqual(fn(), {'ok': 10})
10738 
10739  def test_dict_membership(self):
10740  def fn(x, y):
10741  # type: (Dict[int, int], int) -> int
10742  return x.get(y, 3)
10743 
10744  d = {1: 2, 3: 4}
10745  self.checkScript(fn, (d, 3))
10746  self.checkScript(fn, (d, 2))
10747 
10748  def optional(x, y):
10749  # type: (Dict[int, int], int) -> bool
10750  res = x.get(y)
10751  return res is None
10752 
10753  self.checkScript(optional, (d, 3))
10754  self.checkScript(optional, (d, 2))
10755 
10756  with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"):
10757  @torch.jit.script
10758  def bad_types(x, y):
10759  # type: (Dict[int, int], int) -> int
10760  return x.get(y) # noqa: T484
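
        # The failure above is fixed by refining the Optional[int] with an
        # explicit None check before using the value (illustrative addition;
        # `good_types` is not part of the original suite).
        def good_types(x, y):
            # type: (Dict[int, int], int) -> int
            res = x.get(y)
            if res is None:
                return -1
            return res

        self.checkScript(good_types, (d, 3))
        self.checkScript(good_types, (d, 5))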
10761 
10762  def test_dict_to_python(self):
10763  def python_lookup(my_dict, keys):
10764  # type: (Dict[str, Tensor], List[str]) -> List[Tensor]
10765  return [my_dict[k] for k in keys]
10766 
10767  def fn(my_dict, keys):
10768  # type: (Dict[str, Tensor], List[str]) -> List[Tensor]
10769  return python_lookup(my_dict, keys)
10770 
10771  a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
10772  self.checkScript(fn, (a_dict, ['a', 'c']))
10773 
10774  def test_module_attrs(self):
10775  class M(torch.jit.ScriptModule):
10776  def __init__(self, table):
10777  super(M, self).__init__()
10778  self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])
10779  self.x = torch.nn.Parameter(torch.tensor([100.0]))
10780 
10781  @torch.jit.script_method
10782  def forward(self, key):
10783  # type: (str) -> Tensor
10784  return self.table[key] + self.x
10785 
10786  with self.disableModuleHook():
10787  # TODO: re-enable module hook when Python printing of attributes is
10788  # supported
10789  m = M({char: torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
10790  self.assertEqual(m("c"), torch.tensor([103]))
10791 
10792  def test_tensor_import_export(self):
10793  @torch.jit.script
10794  def foo(x):
10795  a = torch.tensor(1)
10796  b = torch.tensor([1, 2])
10797  c = [a, b]
10798  return c
10799 
10800  self.run_pass('constant_propagation', foo.graph)
10801  m = torch.jit.ScriptModule()
10802  m._create_method_from_graph("forward", foo.graph)
10803  self.getExportImportCopy(m)
10804 
10805  def test_attribute_serialization(self):
10806  class M(torch.jit.ScriptModule):
10807  def __init__(self):
10808  super(M, self).__init__()
10809  self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
10810  self.float = torch.jit.Attribute(2.3, float)
10811  self.int = torch.jit.Attribute(99, int)
10812  self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
10813  self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
10814  self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
10815  self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
10816 
10817  @torch.jit.script_method
10818  def forward(self):
10819  return (self.table, self.float, self.int, self.tuple, self.list, self.int_list)
10820 
10821  m = M()
10822  imported_m = self.getExportImportCopy(m)
10823  self.assertEqual(m(), imported_m())
10824 
10825  @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
10826  def test_attribute_unpickling(self):
10827  import zipfile
10828 
10829  class M(torch.jit.ScriptModule):
10830  def __init__(self):
10831  super(M, self).__init__()
10832  self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
10833  self.float = torch.jit.Attribute(2.3, float)
10834  self.int = torch.jit.Attribute(99, int)
10835  self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
10836  self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
10837  self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
10838  self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
10839 
10840  @torch.jit.script_method
10841  def forward(self):
10842  return (self.table, self.float, self.int, self.tuple, self.list, self.int_list)
10843 
10844  class TensorID(object):
10845  def __setstate__(self, id):
10846  self.id = id
10847 
10848  class IntList(object):
10849  def __setstate__(self, data):
10850  self.data = data
10851 
10852  class JitUnpickler(pickle.Unpickler):
10853  def find_class(self, module, name):
10854  if module != '__main__':
10855  return None
10856 
10857  if name == 'TensorID':
10858  return TensorID
10859  elif name == 'IntList':
10860  return IntList
10861 
10862  with TemporaryFileName() as fname:
10863  M().save(fname)
10864  archive_name = os.path.basename(os.path.normpath(fname))
10865  archive = zipfile.ZipFile(fname, 'r')
10866  pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
10867  JitUnpickler(io.BytesIO(pickled_data)).load()
10868 
10869  def test_submodule_attribute_serialization(self):
10870  class S(torch.jit.ScriptModule):
10871  def __init__(self, list_data):
10872  super(S, self).__init__()
10873  self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
10874  self.list = torch.jit.Attribute(list_data, List[Tuple[int, int]])
10875 
10876  @torch.jit.script_method
10877  def forward(self):
10878  return (self.table, self.list)
10879 
10880  class M(torch.jit.ScriptModule):
10881  def __init__(self):
10882  super(M, self).__init__()
10883  self.table = torch.jit.Attribute({"this": "is", "a different": "dict"}, Dict[str, str])
10884  self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
10885  self.s1 = S([(1, 2)])
10886  self.s2 = S([(4, 5)])
10887 
10888  @torch.jit.script_method
10889  def forward(self):
10890  return (self.table, self.tensor, self.s1.table, self.s2.list, self.s1.list)
10891 
10892  m = M()
10893  imported_m = self.getExportImportCopy(m)
10894  self.assertEqual(m(), imported_m())
10895 
10896  def test_optional_tuple(self):
10897  def fn(x=None):
10898  # type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
10899  if x is None:
10900  new_x = (1, 2)
10901  else:
10902  new_x = x
10903  return new_x
10904 
10905  self.checkScript(fn, ((3, 4),))
10906  self.checkScript(fn, ())
10907 
10908  def test_split(self):
10909  def split_two(tensor):
10910  a, b, c = torch.split(tensor, 2, dim=1)
10911  return a, b, c
10912  x = torch.randn(3, 6)
10913  y = torch.randn(3, 6)
10914  self.checkScript(split_two, [(x + y)])
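
    def test_split_shapes_sketch(self):
        # Illustrative addition: for a (3, 6) input, split size 2 along
        # dim=1 yields exactly three (3, 2) chunks, which is why the tuple
        # unpack in test_split binds three values.
        chunks = torch.split(torch.zeros(3, 6), 2, dim=1)
        self.assertEqual(len(chunks), 3)
        for chunk in chunks:
            self.assertEqual(chunk.shape, torch.Size([3, 2]))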
10915 
10916 
10917 class MnistNet(nn.Module):
10918  def __init__(self):
10919  super(MnistNet, self).__init__()
10920  self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
10921  self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
10922  self.conv2_drop = nn.Dropout2d()
10923  self.fc1 = nn.Linear(320, 50)
10924  self.fc2 = nn.Linear(50, 10)
10925 
10926  def forward(self, x):
10927  x = F.relu(F.max_pool2d(self.conv1(x), 2))
10928  x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
10929  x = x.view(-1, 320)
10930  x = F.relu(self.fc1(x))
10931  x = F.dropout(x, training=self.training)
10932  x = self.fc2(x)
10933  return F.log_softmax(x, dim=1)
10934 
10935 
10936 class TestModels(JitTestCase):
10937  @staticmethod
10938  def _test_dcgan_models(self, device, check_export_import=True):
10939  class DCGANGenerator(nn.Module):
10940  def __init__(self, nz, ngf, nc):
10941  super(DCGANGenerator, self).__init__()
10942  self.main = nn.Sequential(
10943  # input is Z, going into a convolution
10944  nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
10945  nn.BatchNorm2d(ngf * 8),
10946  nn.ReLU(True),
10947  # state size. (ngf*8) x 4 x 4
10948  nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
10949  nn.BatchNorm2d(ngf * 4),
10950  nn.ReLU(True),
10951  # state size. (ngf*4) x 8 x 8
10952  nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
10953  nn.BatchNorm2d(ngf * 2),
10954  nn.ReLU(True),
10955  # state size. (ngf*2) x 16 x 16
10956  nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
10957  nn.BatchNorm2d(ngf),
10958  nn.ReLU(True),
10959  # state size. (ngf) x 32 x 32
10960  nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
10961  nn.Tanh()
10962  # state size. (nc) x 64 x 64
10963  )
10964 
10965  def forward(self, input):
10966  return self.main(input)
10967 
10968  class DCGANDiscriminator(nn.Module):
10969  def __init__(self, nc, ndf):
10970  super(DCGANDiscriminator, self).__init__()
10971  self.main = nn.Sequential(
10972  # input is (nc) x 64 x 64
10973  nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
10974  nn.LeakyReLU(0.2, inplace=True),
10975  # state size. (ndf) x 32 x 32
10976  nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
10977  nn.BatchNorm2d(ndf * 2),
10978  nn.LeakyReLU(0.2, inplace=True),
10979  # state size. (ndf*2) x 16 x 16
10980  nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
10981  nn.BatchNorm2d(ndf * 4),
10982  nn.LeakyReLU(0.2, inplace=True),
10983  # state size. (ndf*4) x 8 x 8
10984  nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
10985  nn.BatchNorm2d(ndf * 8),
10986  nn.LeakyReLU(0.2, inplace=True),
10987  # state size. (ndf*8) x 4 x 4
10988  nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
10989  nn.Sigmoid()
10990  )
10991 
10992  def forward(self, input):
10993  return self.main(input).view(-1, 1).squeeze(1)
10994 
10995  bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
10996  self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
10997  (torch.rand(bs, nz, 1, 1, device=device),),
10998  export_import=check_export_import)
10999  example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
11000  self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
11001  export_import=check_export_import)
11002 
11003  def test_dcgan_models(self):
11004  self._test_dcgan_models(self, device='cpu')
11005 
11006  @unittest.skipIf(not RUN_CUDA, "no CUDA")
11007  def test_dcgan_models_cuda(self):
11008  # XXX: export_import on CUDA modules doesn't work (#11480)
11009  self._test_dcgan_models(self, device='cuda', check_export_import=False)
11010 
11011  @staticmethod
11012  def _test_neural_style(self, device, check_export_import=True):
11013  class TransformerNet(torch.nn.Module):
11014  def __init__(self):
11015  super(TransformerNet, self).__init__()
11016  # Initial convolution layers
11017  self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
11018  self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
11019  self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
11020  self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
11021  self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
11022  self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
11023  # Residual layers
11024  self.res1 = ResidualBlock(128)
11025  self.res2 = ResidualBlock(128)
11026  self.res3 = ResidualBlock(128)
11027  self.res4 = ResidualBlock(128)
11028  self.res5 = ResidualBlock(128)
11029  # Upsampling Layers
11030  self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
11031  self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
11032  self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
11033  self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
11034  self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
11035  # Non-linearities
11036  self.relu = torch.nn.ReLU()
11037 
11038  def forward(self, X):
11039  y = self.relu(self.in1(self.conv1(X)))
11040  y = self.relu(self.in2(self.conv2(y)))
11041  y = self.relu(self.in3(self.conv3(y)))
11042  y = self.res1(y)
11043  y = self.res2(y)
11044  y = self.res3(y)
11045  y = self.res4(y)
11046  y = self.res5(y)
11047  y = self.relu(self.in4(self.deconv1(y)))
11048  y = self.relu(self.in5(self.deconv2(y)))
11049  y = self.deconv3(y)
11050  return y
11051 
11052  class ConvLayer(torch.nn.Module):
11053  def __init__(self, in_channels, out_channels, kernel_size, stride):
11054  super(ConvLayer, self).__init__()
11055  reflection_padding = kernel_size // 2
11056  self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
11057  self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
11058 
11059  def forward(self, x):
11060  out = self.reflection_pad(x)
11061  out = self.conv2d(out)
11062  return out
11063 
11064  class ResidualBlock(torch.nn.Module):
11065  """ResidualBlock
11066  introduced in: https://arxiv.org/abs/1512.03385
11067  recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
11068  """
11069 
11070  def __init__(self, channels):
11071  super(ResidualBlock, self).__init__()
11072  self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
11073  self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
11074  self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
11075  self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
11076  self.relu = torch.nn.ReLU()
11077 
11078  def forward(self, x):
11079  residual = x
11080  out = self.relu(self.in1(self.conv1(x)))
11081  out = self.in2(self.conv2(out))
11082  out = out + residual
11083  return out
11084 
11085  class UpsampleConvLayer(torch.nn.Module):
11086  """UpsampleConvLayer
11087  Upsamples the input and then does a convolution. This method gives better results
11088  compared to ConvTranspose2d.
11089  ref: http://distill.pub/2016/deconv-checkerboard/
11090  """
11091 
11092  def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
11093  super(UpsampleConvLayer, self).__init__()
11094  self.upsample = upsample
11095  if upsample:
11096  self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
11097  reflection_padding = kernel_size // 2
11098  self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
11099  self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
11100 
11101  def forward(self, x):
11102  x_in = x
11103  if self.upsample:
11104  x_in = self.upsample_layer(x_in)
11105  out = self.reflection_pad(x_in)
11106  out = self.conv2d(out)
11107  return out
11108 
11109  self.checkTrace(TransformerNet(), (torch.rand(5, 3, 16, 16),), export_import=check_export_import)
11110 
11111  def test_neural_style(self):
11112  self._test_neural_style(self, device='cpu')
11113 
11114  @unittest.skipIf(not RUN_CUDA, "no CUDA")
11115  def test_neural_style_cuda(self):
11116  # XXX: export_import on CUDA modules doesn't work (#11480)
11117  self._test_neural_style(self, device='cuda', check_export_import=False)
11118 
11119  @staticmethod
11120  def _test_mnist(self, device, check_export_import=True):
11121  # eval() is present because dropout makes this nondeterministic
11122  self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
11123  export_import=check_export_import)
11124 
11125  def test_mnist(self):
11126  self._test_mnist(self, device='cpu')
11127 
11128  @unittest.skipIf(not RUN_CUDA, "no CUDA")
11129  def test_mnist_cuda(self):
11130  # XXX: export_import on CUDA modules doesn't work (#11480)
11131  self._test_mnist(self, device='cuda', check_export_import=False)
11132 
11133  @unittest.skipIf(not RUN_CUDA, "no CUDA")
11134  def test_mnist_training_leaks_no_memory_cuda(self):
11135  net = MnistNet().cuda()
11136  # MnistNet uses dropout, don't check its trace
11137  traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')],
11138  check_trace=False)
11139 
11140  def train(iters):
11141  for _ in range(iters):
11142  # Get some fake data
11143  inp = torch.randn(5, 1, 28, 28, device='cuda')
11144  out = traced_net(inp)
11145 
11146  # Here's some fake loss
11147  out.sum().backward()
11148 
11149  # Zero out grads
11150  traced_net.zero_grad()
11151 
11152  # Set it up so the params have .grad fields so they are not reported as leaks
11153  train(1)
11154 
11155  with self.assertLeaksNoCudaTensors():
11156  train(5)
11157 
11158  @staticmethod
11159  def _test_reinforcement_learning(self, device, test_export_import=True):
11160  class Policy(nn.Module):
11161  def __init__(self):
11162  super(Policy, self).__init__()
11163  self.affine1 = nn.Linear(4, 128)
11164  self.affine2 = nn.Linear(128, 2)
11165 
11166  def forward(self, x):
11167  x = F.relu(self.affine1(x))
11168  action_scores = self.affine2(x)
11169  return F.softmax(action_scores, dim=1)
11170 
11171  self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
11172  export_import=test_export_import)
11173 
11174  def test_reinforcement_learning(self):
11175  self._test_reinforcement_learning(self, device='cpu')
11176 
11177  @unittest.skipIf(not RUN_CUDA, "no CUDA")
11178  def test_reinforcement_learning_cuda(self):
11179  # XXX: export_import on CUDA modules doesn't work (#11480)
11180  self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
11181 
11182  @staticmethod
11183  def _test_snli(self, device, check_export_import=True, quantized=False):
11184  class Bottle(nn.Module):
11185 
11186  def forward(self, input):
11187  if len(input.size()) <= 2:
11188  return super(Bottle, self).forward(input)
11189  size = input.size()[:2]
11190  out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
11191  return out.view(size[0], size[1], -1)
11192 
11193  class Linear(Bottle, nn.Linear):
11194  pass
11195 
11196  class Encoder(nn.Module):
11197 
11198  def __init__(self, config):
11199  super(Encoder, self).__init__()
11200  self.config = config
11201  input_size = config.d_proj if config.projection else config.d_embed
11202  dropout = 0 if config.n_layers == 1 else config.dp_ratio
11203  self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
11204  num_layers=config.n_layers, dropout=dropout,
11205  bidirectional=config.birnn)
11206 
11207  def forward(self, inputs):
11208  batch_size = inputs.size()[1]
11209  state_shape = self.config.n_cells, batch_size, self.config.d_hidden
11210  h0 = c0 = inputs.new_zeros(state_shape)
11211  outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
11212  return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
11213 
11214  class SNLIClassifier(nn.Module):
11215 
11216  def __init__(self, config):
11217  super(SNLIClassifier, self).__init__()
11218  self.config = config
11219  self.embed = nn.Embedding(config.n_embed, config.d_embed)
11220  self.projection = Linear(config.d_embed, config.d_proj)
11221  self.encoder = Encoder(config)
11222  self.dropout = nn.Dropout(p=config.dp_ratio)
11223  self.relu = nn.ReLU()
11224  seq_in_size = 2 * config.d_hidden
11225  if self.config.birnn:
11226  seq_in_size *= 2
11227  lin_config = [seq_in_size] * 2
11228  self.out = nn.Sequential(
11229  Linear(*lin_config),
11230  self.relu,
11231  self.dropout,
11232  Linear(*lin_config),
11233  self.relu,
11234  self.dropout,
11235  Linear(*lin_config),
11236  self.relu,
11237  self.dropout,
11238  Linear(seq_in_size, config.d_out))
11239 
11240  def forward(self, premise, hypothesis):
11241  prem_embed = self.embed(premise)
11242  hypo_embed = self.embed(hypothesis)
11243  if self.config.fix_emb:
11244  prem_embed = prem_embed.detach()
11245  hypo_embed = hypo_embed.detach()
11246  if self.config.projection:
11247  prem_embed = self.relu(self.projection(prem_embed))
11248  hypo_embed = self.relu(self.projection(hypo_embed))
11249  premise = self.encoder(prem_embed)
11250  hypothesis = self.encoder(hypo_embed)
11251  scores = self.out(torch.cat([premise, hypothesis], 1))
11252  return scores
11253 
11254  class Config:
11255  n_embed = 100
11256  d_embed = 100
11257  d_proj = 300
11258  dp_ratio = 0.0 # For deterministic testing TODO: change by fixing seed in checkTrace?
11259  d_hidden = 30
11260  birnn = True
11261  d_out = 300
11262  fix_emb = True
11263  projection = True
11264  n_layers = 2
11265  n_cells = 4 # 2 * n_layers because birnn = True
11266 
11267  premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
11268  hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)
11269 
11270  if quantized:
11271  snli = SNLIClassifier(Config()).cpu()
11272  torch.jit.quantized.quantize_linear_modules(snli)
11273  # we don't do export/import checks because we would need to call
11274  # _pack/_unpack
11275  self.checkTrace(snli, (premise, hypothesis), inputs_require_grads=False,
11276  export_import=False)
11277  else:
11278  self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
11279  inputs_require_grads=False, export_import=check_export_import)
11280 
11281  def test_snli(self):
11282  self._test_snli(self, device='cpu')
11283 
11284  if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
11285  def test_snli_quantized(self):
11286  self._test_snli(self, device='cpu', quantized=True)
11287 
11288  @unittest.skipIf(not RUN_CUDA, "no CUDA")
11289  def test_snli_cuda(self):
11290  # XXX: export_import on CUDA modules doesn't work (#11480)
11291  self._test_snli(self, device='cuda', check_export_import=False)
11292 
11293  @staticmethod
11294  def _test_super_resolution(self, device, check_export_import=True):
11295  import torch.nn.init as init
11296 
11297  class Net(nn.Module):
11298 
11299  def __init__(self, upscale_factor):
11300  super(Net, self).__init__()
11301 
11302  self.relu = nn.ReLU()
11303  self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
11304  self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
11305  self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
11306  self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
11307  self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
11308 
11309  def forward(self, x):
11310  x = self.relu(self.conv1(x))
11311  x = self.relu(self.conv2(x))
11312  x = self.relu(self.conv3(x))
11313  x = self.pixel_shuffle(self.conv4(x))
11314  return x
11315 
11316  net = Net(upscale_factor=4).to(device)
11317  self.checkTrace(net, (torch.rand(5, 1, 32, 32, device=device),),
11318  export_import=check_export_import)
11319 
11320  def test_super_resolution(self):
11321  self._test_super_resolution(self, device='cpu')
11322 
11323  @unittest.skipIf(not RUN_CUDA, 'no CUDA')
11324  def test_super_resolution_cuda(self):
11325  # XXX: export_import on CUDA modules doesn't work (#11480)
11326  self._test_super_resolution(self, device='cuda', check_export_import=False)
11327 
11328  @suppress_warnings
11329  def test_time_sequence_prediction(self):
11330  class Sequence(torch.jit.ScriptModule):
11331  def __init__(self):
11332  super(Sequence, self).__init__()
11333  self.lstm1 = nn.LSTMCell(1, 51)
11334  self.lstm2 = nn.LSTMCell(51, 51)
11335  self.linear = nn.Linear(51, 1)
11336 
11337  # TODO: we cannot pass a tuple to a Python op, and type annotations
11338  # do not propagate to the Python signature, hence the wrapper
11339  # see https://github.com/pytorch/pytorch/issues/8778
11340  # and https://github.com/pytorch/pytorch/issues/8777
11341  def test_lstm1(self, input, hx, cx):
11342  # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
11343  return self.lstm1(input, (hx, cx))
11344 
11345  def test_lstm2(self, input, hx, cx):
11346  # type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
11347  return self.lstm2(input, (hx, cx))
11348 
11349  # TODO: tensor constructors are not yet supported in script
11350  # see https://github.com/pytorch/pytorch/issues/8814
11351  def test_tensor(self):
11352  return torch.tensor([], dtype=torch.double)
11353 
11354  @torch.jit.script_method
11355  def forward(self, input):
11356  # TODO: add future as input with default val
11357  # see https://github.com/pytorch/pytorch/issues/8724
11358  outputs = self.test_tensor()
11359  h_t = torch.zeros((3, 51), dtype=torch.double)
11360  c_t = torch.zeros((3, 51), dtype=torch.double)
11361  h_t2 = torch.zeros((3, 51), dtype=torch.double)
11362  c_t2 = torch.zeros((3, 51), dtype=torch.double)
11363 
11364  output = torch.zeros([3, 51])
11365  future = 2
11366 
11367  # TODO: chunk call should appear as the for loop iterable
11368  # We hard-code it to 4 for now.
11369  a, b, c, d = input.chunk(input.size(1), dim=1)
11370  for input_t in (a, b, c, d):
11371  h_t, c_t = self.test_lstm1(input_t, h_t, c_t)
11372  h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
11373  output = self.linear(h_t2)
11374  outputs = torch.cat((outputs, output), 1)
11375  for _ in range(future): # if we should predict the future
11376  h_t, c_t = self.test_lstm1(output, h_t, c_t)
11377  h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
11378  output = self.linear(h_t2)
11379  outputs = torch.cat((outputs, output), 1)
11380  return outputs
11381 
11382  # TODO: toggle export_import once above issues are fixed
11383  self.checkTrace(Sequence(), (torch.rand(3, 4),),
11384  export_import=False)
11385 
11386  @staticmethod
11387  def _test_vae(self, device, check_export_import=True, quantized=False):
11388  class VAE(nn.Module):
11389  def __init__(self):
11390  super(VAE, self).__init__()
11391 
11392  self.fc1 = nn.Linear(784, 400)
11393  self.fc21 = nn.Linear(400, 20)
11394  self.fc22 = nn.Linear(400, 20)
11395  self.fc3 = nn.Linear(20, 400)
11396  self.fc4 = nn.Linear(400, 784)
11397 
11398  def encode(self, x):
11399  h1 = F.relu(self.fc1(x))
11400  return self.fc21(h1), self.fc22(h1)
11401 
11402  def reparameterize(self, mu, logvar):
11403  if self.training:
11404  std = torch.exp(0.5 * logvar)
11405  eps = torch.randn_like(std)
11406  return eps.mul(std).add_(mu)
11407  else:
11408  return mu
11409 
11410  def decode(self, z):
11411  h3 = F.relu(self.fc3(z))
11412  return torch.sigmoid(self.fc4(h3))
11413 
11414  def forward(self, x):
11415  mu, logvar = self.encode(x.view(-1, 784))
11416  z = self.reparameterize(mu, logvar)
11417  return self.decode(z), mu, logvar
11418 
11419  if quantized:
11420  vae = VAE().to(device).eval()
11421  torch.jit.quantized.quantize_linear_modules(vae)
11422  # We don't do export/import checks because we would need to call
11423  # _unpack and _pack
11424  self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device),),
11425  export_import=False, allow_unused=True,
11426  inputs_require_grads=False)
11427  else:
11428  # eval() is present because randn_like makes this nondeterministic
11429  self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
11430  export_import=check_export_import)
11431 
11432  def test_vae(self):
11433  self._test_vae(self, device='cpu')
11434 
11435  if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
11436  def test_vae_quantized(self):
11437  self._test_vae(self, device='cpu', quantized=True)
11438 
11439  @unittest.skipIf(not RUN_CUDA, "no CUDA")
11440  def test_vae_cuda(self):
11441  # XXX: export_import on CUDA modules doesn't work (#11480)
11442  self._test_vae(self, device='cuda', check_export_import=False)
11443 
11444 
11445 # Smoke tests for export methods
11446 class TestPytorchExportModes(JitTestCase):
11447  class MyModel(nn.Module):
11448  def __init__(self):
11449  super(TestPytorchExportModes.MyModel, self).__init__()
11450 
11451  def forward(self, x):
11452  return x.transpose(0, 1)
11453 
11454  def test_protobuf(self):
11455  torch_model = TestPytorchExportModes.MyModel()
11456  fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
11457  f = io.BytesIO()
11458  torch.onnx._export(torch_model, (fake_input), f, verbose=False,
11459  export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)
11460 
11461  def test_zipfile(self):
11462  torch_model = TestPytorchExportModes.MyModel()
11463  fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
11464  f = io.BytesIO()
11465  torch.onnx._export(torch_model, (fake_input), f, verbose=False,
11466  export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)
11467 
11468  def test_compressed_zipfile(self):
11469  torch_model = TestPytorchExportModes.MyModel()
11470  fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
11471  f = io.BytesIO()
11472  torch.onnx._export(torch_model, (fake_input), f, verbose=False,
11473  export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)
11474 
11475  def test_directory(self):
11476  torch_model = TestPytorchExportModes.MyModel()
11477  fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
11478  d = tempfile.mkdtemp()
11479  torch.onnx._export(torch_model, (fake_input), d, verbose=False,
11480  export_type=torch.onnx.ExportTypes.DIRECTORY)
11481  shutil.rmtree(d)
11482 
11483  def test_onnx_multiple_return(self):
11484  @torch.jit.script
11485  def foo(a):
11486  return (a, a)
11487  f = io.BytesIO()
11488  x = torch.ones(3)
11489  torch.onnx._export(foo, (x,), f, example_outputs=(x, x))
11490 
11491  @skipIfNoLapack
11492  def test_aten_fallback(self):
11493  class ModelWithAtenNotONNXOp(nn.Module):
11494  def forward(self, x, y):
11495  abcd = x + y
11496  defg = torch.qr(abcd)
11497  return defg
11498 
11499  x = torch.rand(3, 4)
11500  y = torch.rand(3, 4)
11501  f = io.BytesIO()
11502  exported = torch.onnx.export_to_pretty_string(
11503  ModelWithAtenNotONNXOp(), (x, y), f,
11504  operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
11505  self.assertExpected(exported)
11506 
11507  # torch.fmod is used to test ONNX_ATEN.
11508  # If you plan to remove fmod from aten, or find this test failing,
11509  # please contact @Rui.
11510  def test_onnx_aten(self):
11511  class ModelWithAtenFmod(nn.Module):
11512  def forward(self, x, y):
11513  return torch.fmod(x, y)
11514 
11515  f = io.BytesIO()
11516  x = torch.randn(3, 4, dtype=torch.float32)
11517  y = torch.randn(3, 4, dtype=torch.float32)
11518  exported = torch.onnx.export_to_pretty_string(
11519  ModelWithAtenFmod(), (x, y), f,
11520  operator_export_type=OperatorExportTypes.ONNX_ATEN)
11521  self.assertExpected(exported)
11522 
11523 
11524 # known to be failing in tracer
11525 EXCLUDE_TRACED = {
11526  # The following fail due to #12024.
11527  # A prim::ListConstruct is involved and the indices get traced as TensorType,
11528  # which always require_grad. This causes a crash in autodiff.
11529  'test___getitem___adv_index',
11530  'test___getitem___adv_index_beg',
11531  'test___getitem___adv_index_comb',
11532  'test___getitem___adv_index_dup',
11533  'test___getitem___adv_index_sub',
11534  'test___getitem___adv_index_sub_2',
11535  'test___getitem___adv_index_sub_3',
11536  'test___getitem___adv_index_var',
11537 
11538 }
11539 
11540 EXCLUDE_TYPE_CHECK = {
11541  # slogdet tests use itemgetter to select its only differentiable output,
11542  # but this happens outside of the graph we handle, so there are fewer
11543  # reference outputs than graph outputs.
11544  'test_slogdet_1x1_neg_det',
11545  'test_slogdet_1x1_pos_det',
11546  'test_slogdet_distinct_singular_values',
11547  'test_slogdet_neg_det',
11548  'test_slogdet_pos_det',
11549  'test_slogdet_symmetric',
11550  'test_slogdet_symmetric_pd',
11551 }
11552 
11553 # known to be failing in script
11554 EXCLUDE_SCRIPT = {
11555  'test_norm_fro',
11556  'test_norm_fro_default',
11557  'test_norm_nuc',
11558 
11559  # aten op has additional cudnn argument
11560  'test_nn_unfold',
11561 
11562  # flaky test - TODO fix
11563  'test_nn_ctc_loss',
11564 
11565  # unknown builtin op
11566  'test_nn_fold',
11567 }
11568 
11569 EXCLUDE_PYTHON_PRINT = {
11570  # no support for BroadcastingList in python printer
11571  'test_nn_max_unpool1d',
11572  'test_nn_max_unpool2d',
11573  'test_nn_max_unpool3d',
11574  'test_nn_max_pool1d',
11575  'test_nn_max_pool2d',
11576  'test_nn_max_pool3d',
11577  'test_nn_max_pool1d_with_indices',
11578 }
11579 
11580 EXCLUDE_SCRIPT_MODULES = {
11581  'test_nn_AdaptiveAvgPool2d_tuple_none',
11582  'test_nn_AdaptiveAvgPool3d_tuple_none',
11583  'test_nn_AdaptiveMaxPool2d_tuple_none',
11584  'test_nn_AdaptiveMaxPool3d_tuple_none',
11585 }
11586 
11587 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
11588  'test_nn_avg_pool2d',
11589  'test_nn_adaptive_avg_pool1d',
11590  'test_nn_adaptive_avg_pool2d',
11591  'test_nn_adaptive_avg_pool3d',
11592  'test_nn_batch_norm',
11593  'test_nn_embedding',
11594  'test_nn_log_softmax',
11595  'test_nn_softmax',
11596  'test_nn_softmax_with_all_args',
11597  'test_nn_threshold',
11598  'test_nn_nll_loss',
11599  # All test_nn_interpolate_* tests should have been added here too,
11600  # but they use autodiff since their subgraphs contain more than
11601  # 2 nodes.
11602 }
11603 
11604 
11605 # make a new function where all non-tensor arguments in 'args' have been partially
11606 # applied, and all tensor arguments remain.
11607 # used to trace functions when some arguments are not tensors
11608 def partial_apply_nontensors(fn, args, **kwargs):
11609  source = ['t' if isinstance(arg, torch.Tensor) else 's' for arg in args]
11610 
11611  def new_fn(*tensors_):
11612  tensors = iter(tensors_)
11613  return fn(*(args[i] if s == 's' else next(tensors) for i, s in enumerate(source)), **kwargs)
11614 
11615  return new_fn, [arg for arg in args if isinstance(arg, torch.Tensor)]
11616 
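# Usage sketch for partial_apply_nontensors (illustrative addition): the
# scalar is baked into new_fn, leaving a tensors-only callable that
# torch.jit.trace can consume.
def _partial_apply_nontensors_example():
    x = torch.ones(2)
    new_fn, tensor_args = partial_apply_nontensors(torch.add, (x, 2.0))
    assert len(tensor_args) == 1 and tensor_args[0] is x
    assert torch.equal(new_fn(*tensor_args), torch.add(x, 2.0))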
11617 
11618 # create a trace function from input fn
11619 #
11620 # disable_autodiff_subgraph_inlining:
11621 # Don't inline autodiff subgraphs so we can test autodiff
11622 def create_traced_fn(self, fn,
11623  disable_autodiff_subgraph_inlining=False):
11624  def traced_fn(*inputs, **kwargs):
11625  fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
11626  traced = torch.jit.trace(fn_tensors, inputs_tensors)
11627  self.assertExportImport(traced.graph, inputs_tensors)
11628  if disable_autodiff_subgraph_inlining:
11629  traced.debug_disable_autodiff_subgraph_inlining()
11630  output = traced(*inputs_tensors)
11631  traced_fn.last_graph = traced.graph_for(*inputs_tensors)
11632  return output
11633  return traced_fn
11634 
11635 script_template = '''
11636 def the_method({}):
11637  return {}
11638 '''
11639 
11640 script_method_template = '''
11641 def forward({}):
11642  return {}
11643 '''
11644 
11645 
11646 def get_constant(x):
11647  if x == inf:
11648  return 'float(\'inf\')' if PY2 else 'math.inf'
11649  if x == -inf:
11650  return 'float(\'-inf\')' if PY2 else '-math.inf'
11651  return x
11652 
11653 
11654 def get_script_args(args):
11655  formals = []
11656  tensors = []
11657  actuals = []
11658  for arg in args:
11659  if isinstance(arg, torch.Tensor):
11660  name = 'i{}'.format(len(formals))
11661  formals.append(name)
11662  actuals.append(name)
11663  tensors.append(arg)
11664  elif isinstance(arg, str):
11665  actuals.append("'{}'".format(arg))
11666  else:
11667  actuals.append(str(get_constant(arg)))
11668  return (formals, tensors, actuals)
11669 
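# Usage sketch for get_script_args (illustrative addition): tensors become
# fresh formals ('i0', 'i1', ...), strings are re-quoted, and other
# constants are stringified for substitution into script_template.
def _get_script_args_example():
    formals, tensors, actuals = get_script_args((torch.ones(2), 'mean', 2))
    assert formals == ['i0']
    assert actuals == ['i0', "'mean'", '2']
    assert len(tensors) == 1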
11670 
11671 # create a script function from (name, func_type, output_process_fn),
11672 # returns a function that takes (args, kwargs), runs the compiled function, and
11673 # then applies the post process fn to the outputs
11674 def create_script_fn(self, method_name, func_type, output_process_fn,
11675  disable_autodiff_subgraph_inlining=False):
11676  def script_fn(*args, **kwargs):
11677  formals, tensors, actuals = get_script_args(args)
11678  kwargs_str = ''
11679  for k, v in kwargs.items():
11680  kwargs_str += ', ' + k + '=' + str(v)
11681  if func_type == 'functional':
11682  call = 'torch.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
11683  elif func_type == 'method':
11684  call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
11685  elif func_type == 'nn_functional':
11686  call = 'torch.nn.functional.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
11687  else:
11688  raise RuntimeError('Unsupported function type')
11689 
11690  script = script_template.format(', '.join(formals), call)
11691 
11692  CU = torch.jit.CompilationUnit(script)
11693  if disable_autodiff_subgraph_inlining:
11694  CU.the_method.debug_disable_autodiff_subgraph_inlining()
11695  self.assertExportImport(CU.the_method.graph, tensors)
11696  output = output_process_fn(CU.the_method(*tensors))
11697  script_fn.last_graph = CU.the_method.graph_for(*tensors)
11698  return output
11699  return script_fn
11700 
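# For example (illustrative note), method_name='add' with
# func_type='method' and two tensor arguments generates a script body of
# roughly:
#
#     def the_method(i0, i1):
#         return i0.add(i1)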
11701 
11702 def check_alias_annotation(method_name, args, kwargs):
11703  formals, tensors, actuals = get_script_args(args)
11704  kwargs_str = ''
11705  for k, v in kwargs.items():
11706  kwargs_str += ', ' + k + '=' + str(v)
11707  call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
11708  script = script_template.format(', '.join(formals), call)
11709  CU = torch.jit.CompilationUnit(script)
11710  torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name)
11711 
11712 
11713 def check_output_types(self, func, ref_outputs, args, kwargs):
11714  graph = getattr(func, 'last_graph', None)
11715  types = [o.type() for o in graph.outputs()]
11716  self.assertTrue(len(types) == 1)
11717  t = types[0]
11718  torch._C._jit_assert_is_instance(ref_outputs, t)
11719 
11720 
11721 def check_against_reference(self, func, reference_func, args, kwargs=None,
11722  allow_unused=True, check_types=True, no_grad=False):
11723  kwargs = kwargs if kwargs else {}
11724 
11725  def allSum(vs):
11726  if isinstance(vs, torch.Tensor):
11727  vs = (vs,)
11728  return sum((i + 1) * v.sum()
11729  for i, v in enumerate(vs)
11730  if v is not None and v.dtype.is_floating_point)
11731 
11732  def clone_inputs(requires_grad):
11733  inputs = [
11734  arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
11735  if isinstance(arg, torch.Tensor) else arg for arg in args
11736  ]
11737  return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]
11738 
11739  nograd_inputs, nograd_tensors = clone_inputs(False)
11740  recording_inputs, recording_tensors = clone_inputs(True)
11741 
11742  # test no gradients case
11743  outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
11744  outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
11745  self.assertEqual(outputs, outputs_test)
11746 
11747  if check_types:
11748  check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
11749 
11750  if no_grad:
11751  # skip grad tests
11752  return
11753 
11754  # test single grad case
11755  outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
11756  grads = torch.autograd.grad(allSum(outputs), recording_tensors,
11757  allow_unused=allow_unused)
11758 
11759  outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
11760  grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
11761  allow_unused=allow_unused)
11762  self.assertEqual(outputs, outputs_test)
11763  self.assertEqual(grads, grads_test)
11764 
11765  # test the grad grad case
11766  if self._testMethodName in nn_functional_single_grad:
11767  return
11768 
11769  outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
11770  l1 = allSum(outputs)
11771  grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
11772  allow_unused=allow_unused)
11773  l2 = (allSum(grads) * l1)
11774  grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
11775 
11776  recording_inputs, recording_tensors = clone_inputs(True)
11777 
11778  outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
11779  l1_test = allSum(outputs_test)
11780  grads_test = torch.autograd.grad(
11781  l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
11782  l2_test = (allSum(grads_test) * l1_test)
11783  grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
11784 
11785  self.assertEqual(outputs, outputs_test)
11786  self.assertEqual(grads, grads_test)
11787  for g2, g2_test in zip(grads2, grads2_test):
11788  if g2 is None and g2_test is None:
11789  continue
11790  self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
11791 
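# Note on the grad-grad phase above (illustrative addition): l2 =
# allSum(grads) * l1 folds the first-order gradients back into a scalar,
# so differentiating l2 exercises second-order autograd through both the
# reference function and the scripted/traced one.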
11792 
11793 class TestFuser(JitTestCase):
11794  def assertAllFused(self, graph, except_for=()):
11795  if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
11796  graph = next(graph.nodes()).g('Subgraph')
11797  allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
11798  self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
11799  'got {}'.format(graph))
11800  self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
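
        # In other words, a graph counts as fully fused when, apart from
        # constants and the kinds listed in except_for, every node has been
        # collapsed into exactly one prim::FusionGroup.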
11801 
11802  def _test_fused_abs(self, device='cpu'):
11803 
11804  @torch.jit.script
11805  def func(x):
11806  return x.abs() * 2
11807 
11808  a = torch.randn(5, device=device)
11809  self.assertEqual(func(a), a.abs() * 2)
11810  self.assertAllFused(func.graph_for(a))
11811 
11812  @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
11813  @enable_cpu_fuser
11814  def test_abs_cpu(self):
11815  self._test_fused_abs()
11816 
11817  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11818  @unittest.skipIf(not RUN_CUDA, "requires CUDA")
11819  @skipIfRocm
11820  def test_abs_cuda(self):
11821  self._test_fused_abs(device="cuda")
11822 
11823  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11824  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
11825  def test_arg_configurations_smoke_cuda(self):
11826  # A smoke test to make sure we won't use the same kernel for contiguous
11827  # and non-contiguous arguments.
11828  # TODO: add optionally enabled debug counters to the fuser to verify
11829  # that we really can tell the difference between configurations
11830  def f(x, y):
11831  z1, z2 = (x + y).chunk(2, dim=1)
11832  return z1 * z2
11833 
11834  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
11835  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
11836  traced_f = torch.jit.trace(f, (x, y,))
11837  self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
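
        # Note (illustrative addition): x.t() only swaps strides, producing
        # a non-contiguous view, while x.t().contiguous() materializes the
        # same values densely; the fuser's kernel cache must key on layout
        # to tell these apart.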
11838 
11839  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11840  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
11841  @skipIfRocm
11842  def test_broadcast_cuda(self):
11843  def scaleshift(x, scale, shift):
11844  return x * scale + shift
11845 
11846  inputs = [
11847  torch.randn(4, 4, dtype=torch.float, device='cuda'),
11848  torch.randn(4, dtype=torch.float, device='cuda'),
11849  torch.randn(4, dtype=torch.float, device='cuda'),
11850  ]
11851  ge = self.checkTrace(scaleshift, inputs)
11852  self.assertAllFused(ge.graph_for(*inputs))
11853 
11854  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11855  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
11856  @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
11857  def test_cuda_half(self):
11858  x = torch.randn(4, 4, dtype=torch.half, device='cuda')
11859  y = torch.randn(4, 4, dtype=torch.half, device='cuda')
11860 
11861  funcs = [
11862  self.fn_test_comparison_gt_lt,
11863  self.fn_test_relu,
11864  self.fn_test_exp
11865  ]
11866 
11867  # Note: non-fused inputs must be float to prevent loss of precision
11868  inputs = (x.float(), y.float())
11869  fusion_inputs = (x, y)
11870  for fn in funcs:
11871  local_inputs = [t.clone().requires_grad_() for t in inputs]
11872  local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]
11873 
11874  # Verifies outputs
11875  fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False, optimize=True)
11876  outputs = fn(*local_inputs)
11877  fusion_outputs = fusion(*local_fusion_inputs)
11878  outputs_half = [t.half() for t in outputs]
11879  self.assertEqual(outputs_half, fusion_outputs)
11880 
11881  # Verifies gradients
11882  for output, fusion_output in zip(outputs_half, fusion_outputs):
11883  grads = torch.autograd.grad(
11884  output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
11885  fusion_grads = torch.autograd.grad(
11886  fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
11887  grads_half = [t.half() for t in grads]
11888  self.assertEqual(grads_half, fusion_grads)
11889 
11890  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11891  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
11892  @skipIfRocm
11893  def test_checks_cat_inputs(self):
11894  # We shouldn't treat cat nodes as broadcasting. All their inputs
11895  # need to be checked for having the same map size, before we can
11896  # run the kernel.
11897  @torch.jit.script
11898  def f(x, y):
11899  return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
11900 
11901  # NOTE: y is broadcastable to x, but output of f(x, y) should have
11902  # shape 3x4, and not 4x4.
11903  x = torch.randn(2, 4, dtype=torch.float, device='cuda')
11904  y = torch.randn(1, 4, dtype=torch.float, device='cuda')
11905 
11906  self.assertEqual(f(x, y).shape, (3, 4))
11907  self.assertAllFused(f.graph_for(x, y))
11908 
11909  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11910  @unittest.skipIf(not RUN_CUDA, "No CUDA")
11911  @skipIfRocm
11912  def test_chunk_cuda(self):
11913  def fn(x):
11914  a, b, c = x.chunk(3, 1)
11915  return a * b + c
11916 
11917  inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
11918 
11919  ge = self.checkScript(fn, inputs)
11920  graph = ge.graph_for(*inputs)
11921  self.assertAllFused(graph)
11922  FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph))
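
        # Note (illustrative addition): FileCheck runs its patterns against
        # the string dump of the IR; the check above pins down that the
        # three-way chunk was lowered to a single prim::ConstantChunk op
        # inside the fusion group.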
11923 
11924  @staticmethod
11925  def _test_chunk_correctness(self, device='cpu'):
11926  def chunk_4_0(x):
11927  x0, x1, x2, x3 = x.chunk(4, 0)
11928  return x0 + x1 + x2 + x3
11929 
11930  def chunk_4_1(x):
11931  x0, x1, x2, x3 = x.chunk(4, 1)
11932  return x0 + x1 + x2 + x3
11933 
11934  def chunk_4_last(x):
11935  x0, x1, x2, x3 = x.chunk(4, 2)
11936  return x0 + x1 + x2 + x3
11937 
11938  fns = [chunk_4_0, chunk_4_1, chunk_4_last]
11939  tensors = [
11940  # splitSize = 1
11941  torch.randn(4, 4, 4, dtype=torch.float, device=device),
11942 
11943  # contiguous case
11944  torch.randn(12, 8, 16, dtype=torch.float, device=device),
11945 
11946  # non-contiguous case
11947  torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
11948  ]
11949 
11950  for tensor in tensors:
11951  for fn in fns:
11952  self.checkScript(fn, [tensor])
11953 
11954  @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
11955  @enable_cpu_fuser
11956  def test_chunk_correctness(self):
11957  return self._test_chunk_correctness(self, 'cpu')
11958 
11959  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11960  @unittest.skipIf(not RUN_CUDA, "No CUDA")
11961  def test_chunk_correctness_cuda(self):
11962  return self._test_chunk_correctness(self, 'cuda')
11963 
11964  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11965  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
11966  @skipIfRocm
11967  def test_chunk_distributes_cuda(self):
11968  def f(x, y):
11969  z1, z2 = (x + y).chunk(2, dim=1)
11970  return z1 * z2
11971 
11972  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
11973  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
11974 
11975  ge = self.checkTrace(f, (x, y))
11976  graph = ge.graph_for(x, y)
11977  FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_0') \
11978  .check_count('ConstantChunk', 2, exactly=True).run(str(graph))
11979 
11980  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
11981  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
11982  @skipIfRocm
11983  def test_chunk_motion_deduplicates_inputs(self):
11984  def func1(x):
11985  z = x * x
11986  z0, z1 = z.chunk(2)
11987  return z0 * z1
11988 
11989  def func2(x):
11990  z = x * x * x
11991  z0, z1 = z.chunk(2)
11992  return z0 * z1
11993 
11994  inputs = [
11995  torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
11996  ]
11997  for func in [func1, func2]:
11998  module = self.checkScript(func, inputs)
11999  forward_graph = module.graph_for(*inputs)
12000  self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
12001  fusion_group = list(forward_graph.nodes())[-1]
12002  self.assertEqual(len(list(fusion_group.inputs())), 1)
12003 
12004  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12005  @unittest.skipIf(not RUN_CUDA, "No CUDA")
12006  @skipIfRocm
12007  def test_chunk_multiple_cuda(self):
12008  # The arguments are intentionally used out of order as a test to see
12009  # if the fusion compiler adds extra args in the correct order
12010  def fn(s, x, y, z):
12011  z1, z2 = z.chunk(2, 2)
12012  x1, x2, x3 = x.chunk(3, 1)
12013  y1, y2 = y.chunk(2, 0)
12014  return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
12015 
12016  inputs = [
12017  torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
12018  torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
12019  torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
12020  torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
12021  ]
12022 
12023  ge = self.checkScript(fn, inputs)
12024  self.assertAllFused(ge.graph_for(*inputs))
12025 
12026  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12027  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12028  @skipIfRocm
12029  def test_clamp(self):
12030  def func2(a, b):
12031  return torch.clamp(a + b, min=0, max=2)
12032 
12033  def funcInf(a, b):
12034  return torch.clamp(a + b, min=0, max=float('inf'))
12035 
12036  def funcOptMin(a, b):
12037  return torch.clamp(a + b, max=2)
12038 
12039  def funcOptMax(a, b):
12040  return torch.clamp(a + b, min=0)
12041 
12042  a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
12043  b = torch.randn(4, 4, dtype=torch.float, device='cuda')
12044  nan = torch.tensor(float('nan'))
12045 
12046  funcs = (func2, funcInf, funcOptMin, funcOptMax)
12047  for f, inputs in product(funcs, [[a, b], [a, nan]]):
12048  inp1, inp2 = inputs
12049  s = self.checkScript(f, (inp1, inp2))
12050  self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size'})
12051 
12052  c = s(inp1, inp2)
12053  c.sum().backward()
12054  graph = backward_graph(s)
12055  self.assertAllFused(graph)
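
        # Note (illustrative addition): the [a, nan] pairing exercises the
        # fused kernel's handling of NaN inputs, which must match eager
        # torch.clamp semantics.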
12056 
12057  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12058  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12059  @skipIfRocm
12060  def test_comparison_eq_ne(self):
12061  def f(x, y):
12062  mask = (x == 0).type_as(x)
12063  z = x * mask + y
12064  mask = (x != 0).type_as(x)
12065  z = z * mask + y
12066  return z
12067 
12068  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
12069  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
12070 
12071  ge = self.checkTrace(f, (x, y))
12072  self.assertAllFused(ge.graph_for(x, y))
12073 
12074  @staticmethod
12075  def fn_test_comparison_gt_lt(x, y):
12076  mask = (x > 0).type_as(x)
12077  z = x * mask + y
12078  mask = (x < 0).type_as(x)
12079  z = z * mask + y
12080  return z
12081 
12082  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12083  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12084  @skipIfRocm
12085  def test_comparison_gt_lt_cuda(self):
12086  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
12087  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
12088 
12089  ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
12090  self.assertAllFused(ge.graph_for(x, y))
12091 
12092  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12093  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12094  @skipIfRocm
12095  def test_comparison_ge_le_cuda(self):
12096  def f(x, y):
12097  mask = (x >= 0).type_as(x)
12098  z = x * mask + y
12099  mask = (x <= 0).type_as(x)
12100  z = z * mask + y
12101  return z
12102 
12103  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
12104  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
12105 
12106  ge = self.checkTrace(f, (x, y))
12107  self.assertAllFused(ge.graph_for(x, y))
12108  x.requires_grad_(True)
12109  y.requires_grad_(True)
12110  self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
12111 
12112  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12113  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12114  @skipIfRocm
12115  def test_concat_cuda(self):
12116  hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
12117  cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
12118 
12119  def foo(hx, cx):
12120  return torch.cat((hx + cx, hx * cx))
12121 
12122  ge = self.checkTrace(foo, (hx, cx))
12123  graph = ge.graph_for(hx, cx)
12124  self.assertAllFused(graph)
12125  FileCheck().check("FusedConcat").check_next("return").run(str(graph))
12126 
12127  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12128  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12129  @skipIfRocm
12130  def test_concat_invariant_cuda(self):
12131  # Invariant: the output of prim::FusedConcat may
12132  # not be an input to any node inside the FusionGroup.
12133  def fn(x, y, z):
12134  x1 = x + y
12135  y1 = x - y
12136  w = torch.cat([x1, y1])
12137  return w + z
12138 
12139  x = torch.randn(2, 2, dtype=torch.float, device='cuda')
12140  y = torch.randn(2, 2, dtype=torch.float, device='cuda')
12141  z = torch.randn(4, 2, dtype=torch.float, device='cuda')
12142  ge = self.checkTrace(fn, (x, y, z))
12143  graph = ge.graph_for(x, y, z)
12144  self.assertAllFused(graph, except_for={'aten::add'})
12145  FileCheck().check("FusedConcat").check_next("return").run(str(graph))
12146 
12147  @staticmethod
12148  def fn_test_exp(x, y):
12149  return (x + .5 * y).exp()
12150 
12151  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12152  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12153  @skipIfRocm
12154  def test_exp_cuda(self):
12155  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
12156  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
12157 
12158  ge = self.checkTrace(self.fn_test_exp, (x, y))
12159  self.assertAllFused(ge.graph_for(x, y))
12160 
12161  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12162  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12163  @skipIfRocm
12164  def test_fuse_batch_norm(self):
12165 
12166  class ResLike(torch.jit.ScriptModule):
12167  def __init__(self, optimize=True):
12168  super(ResLike, self).__init__(optimize)
12169  self.bn = nn.BatchNorm2d(16)
12170 
12171  @torch.jit.script_method
12172  def forward(self, x, y):
12173  return y + torch.relu(self.bn(x))
12174 
12175  model = ResLike().cuda()
12176  model_noopt = ResLike(optimize=False).cuda()
12177  model_noopt.load_state_dict(model.state_dict())
12178  x = torch.randn(2, 16, 8, 8, device='cuda')
12179  y = torch.randn(2, 16, 8, 8, device='cuda')
12180  # FIXME: We need differentiation for CNNs for this optimization to trigger
12181  with torch.no_grad():
12182  out = model(x, y)
12183  graph = model.graph_for(x, y)
12184  rep = str(graph)
12185 
12186  out_noopt = model_noopt(x, y)
12187  rep_noopt = str(model_noopt.graph_for(x, y))
12188  self.assertEqual(out, out_noopt, prec=3e-5)
12189 
12190  # Check that batch_norm has really been decomposed
12191  self.assertIn('aten::batch_norm_update_stats', rep)
12192  self.assertNotIn('aten::batch_norm(', rep)
12193  self.assertIn('aten::batch_norm(', rep_noopt)
12194 
12195  # Make sure the fusion group is big, and contains aten::sqrt, which could
12196  # originate only from decomposing batch_norm in this case
12197  fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
12198  self.assertEqual(len(fusion_groups), 1)
12199  fused_graph = fusion_groups[0].g('Subgraph')
12200  self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes()))
12201 
12202  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12203  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12204  @skipIfRocm
12205  def test_threshold(self):
12206  def f(x):
12207  return torch.threshold(x, 0, -10) + x + x + x
12208 
12209  x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
12210  scripted = torch.jit.script(f)
12211 
12212  self.assertEqual(f(x), scripted(x))
12213  self.assertAllFused(scripted.graph_for(x))
12214 
12215  @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
12216  @enable_cpu_fuser
12217  def test_fuser_deduplication(self):
12218  # Check that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation;
12219  # see the discussion in PR #14957.
12220  def f(x, y):
12221  return torch.sigmoid(x + y)
12222 
12223  b = torch.randn(5, 5, requires_grad=True)
12224  a = torch.randn(5, 5, requires_grad=True)
12225  s = self.checkScript(f, (a, b))
12226  self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
12227 
12228  c = s(a, b)
12229  ga, gb = torch.autograd.grad(c.sum(), [a, b])
12230  graph = backward_graph(s)
12231  self.assertAllFused(graph)
12232  # check that ga and gb share storage, i.e. they were generated as a single output in the fuser
12233  self.assertEqual(ga.data_ptr(), gb.data_ptr())
12234 
12235  @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
12236  @enable_cpu_fuser
12237  def test_fuser_iou(self):
12238  # This checks if most of Intersection over Union is fused.
12239  # In particular, the backward contains many _grad_sum_to_size.
12240  def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
12241  ltx = torch.max(b1x1, b2x1) # [N,M]
12242  lty = torch.max(b1y1, b2y1)
12243  rbx = torch.min(b1x2, b2x2)
12244  rby = torch.min(b1y2, b2y2)
12245 
12246  w = (rbx - ltx).clamp(min=0, max=float('inf')) # [N,M]
12247  h = (rby - lty).clamp(min=0, max=float('inf')) # [N,M]
12248  inter = w * h # [N,M]
12249 
12250  area1 = (b1x2 - b1x1) * (b1y2 - b1y1) # [N,1]
12251  area2 = (b2x2 - b2x1) * (b2y2 - b2y1) # [1,M]
12252  iou = inter / (area1 + area2 - inter)
12253  return iou
12254 
12255  box1 = torch.randn(5, 4, requires_grad=True)
12256  box2 = torch.randn(5, 4, requires_grad=True)
12257  # unsqueeze cannot currently be fused
12258  b1x1 = box1[:, 0].unsqueeze(1) # [N,1]
12259  b1y1 = box1[:, 1].unsqueeze(1)
12260  b1x2 = box1[:, 2].unsqueeze(1)
12261  b1y2 = box1[:, 3].unsqueeze(1)
12262  b2x1 = box2[:, 0].unsqueeze(0) # [1,M]
12263  b2y1 = box2[:, 1].unsqueeze(0)
12264  b2x2 = box2[:, 2].unsqueeze(0)
12265  b2y2 = box2[:, 3].unsqueeze(0)
12266 
12267  s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
12268  self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
12269  except_for={'aten::size', 'prim::BroadcastSizes'})
12270 
12271  c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
12272  torch.autograd.grad(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
12273  graph = backward_graph(s)
12274  self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes'})
12275 
12276  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12277  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12278  @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
12279  @skipIfRocm
12280  @enable_cpu_fuser
12281  def test_fusion_reuse_multi_gpu(self):
12282  def fn(x, y):
12283  return x * y * x * y
12284 
12285  inputs_cpu = [
12286  torch.randn(4, 4, dtype=torch.float),
12287  torch.randn(4, 4, dtype=torch.float),
12288  ]
12289  inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
12290  inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]
12291 
12292  # Should not crash; these should compile different kernels.
12293  ge = self.checkScript(fn, inputs_cpu)
12294  self.assertAllFused(ge.graph_for(*inputs_cpu))
12295  ge(*inputs_cuda0)
12296  ge(*inputs_cuda1)
12297 
12298  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12299  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12300  @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
12301  @skipIfRocm
12302  @enable_cpu_fuser
12303  def test_kernel_cache_multi_gpu(self):
12304  def not_fusible(x):
12305  return x
12306 
12307  def fn(x, y, z):
12308  x_out = x * x * x * x * x # fusion: lambda x. x * x * x * x * x
12309  y_out = y * y * y * y * y
12310  z_out = z * z * z * z * z
12311  return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
12312 
12313  inputs = [
12314  torch.randn(4, 4, dtype=torch.float),
12315  torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
12316  torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
12317  ]
12318 
12319  prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
12320 
12321  # There are 3 FusionGroups. Because they have the same graph, they
12322  # should reuse the same KernelSpec in the KernelSpec cache.
12323  ge = self.checkScript(fn, inputs)
12324  self.assertGraphContainsExactly(
12325  ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
12326  new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
12327  # XXX: This assumes that the same kernel isn't already used by another test
12328  self.assertEqual(new_cache_size - prev_cache_size, 1)
12329 
12330  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12331  @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
12332  @skipIfRocm
12333  def test_nonzero_device_cuda(self):
12334  device = 'cuda:1'
12335  x = torch.tensor([0.4], dtype=torch.float, device=device)
12336  y = torch.tensor([0.7], dtype=torch.float, device=device)
12337 
12338  def doit(x, y):
12339  return torch.sigmoid(torch.tanh(x * (x + y) + x))
12340 
12341  ge = self.checkTrace(doit, (x, y))
12342  self.assertAllFused(ge.graph_for(x, y))
12343 
12344  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12345  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12346  @skipIfRocm
12347  def test_lstm_cuda(self):
12348  inputs = get_lstm_inputs('cuda', training=True)
12349  module = self.checkScript(LSTMCellS, inputs)
12350  forward_graph = module.graph_for(*inputs)
12351  self.assertGraphContainsExactly(
12352  forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
12353  self.assertTrue(len(list(forward_graph.nodes())) == 2)
12354  # Everything is differentiable but TupleConstruct return
12355  FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
12356  .check_next("return").run(str(forward_graph))
12357 
12358  hy, cy = module(*inputs)
12359  (hy + cy).sum().backward()
12360  backward = backward_graph(module)
12361  FileCheck().check("FusionGroup_0").check_next("FusionGroup_1") \
12362  .check_not("FusionGroup_2").run(str(backward))
12363 
12364  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12365  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12366  @skipIfRocm
12367  def test_lstm_concat_cuda(self):
12368  inputs = get_lstm_inputs('cuda')
12369  ge = self.checkTrace(LSTMCellC, inputs)
12370  graph = ge.graph_for(*inputs)
12371  FileCheck().check("FusedConcat").check_next("return").run(str(graph))
12372 
12373  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12374  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12375  @skipIfRocm
12376  def test_lstm_gates_permutations_cuda(self):
12377  # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
12378  # Test that any permutation of this will still result in one FusionGroup.
12379  choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
12380  template = dedent('''
12381  def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
12382  gates = {} + {} + {} + {}
12383  ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
12384  return ingate * forgetgate * cellgate * outgate
12385  ''')
12386  for permutation in itertools.permutations(choices, len(choices)):
12387  code = template.format(*permutation)
12388  scope = {}
12389  exec(code, globals(), scope)
12390  cu = torch.jit.CompilationUnit(code)
12391 
12392  inputs = get_lstm_inputs('cuda', training=False)
12393  self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
12394  forward_graph = cu.cell.graph_for(*inputs)
12395  self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
12396 
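 # Illustrative sketch (helper name is hypothetical; not exercised by the
 # suite): TorchScript source in a string compiles via
 # torch.jit.CompilationUnit and is callable as an attribute of the unit,
 # which is the mechanism the permutation test above relies on.
 def _example_compilation_unit(self):
     code = dedent('''
     def double(x):
         return x + x
     ''')
     cu = torch.jit.CompilationUnit(code)
     return cu.double(torch.ones(2, 2))
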
12397  # TODO: Fuser doesn't work at all when inputs require grad. Fix that
12398  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12399  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12400  @skipIfRocm
12401  def test_lstm_traced_cuda(self):
12402  inputs = get_lstm_inputs('cuda')
12403  ge = self.checkTrace(LSTMCellF, inputs)
12404  graph = ge.graph_for(*inputs)
12405  FileCheck().check_not("Chunk").check_not("aten::add").check_not("aten::sigmoid") \
12406  .check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
12407  .check_next("return").check_not("FusionGroup_1").run(str(graph))
12408 
12409  @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
12410  @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
12411  @enable_cpu_fuser
12412  def test_lstm_traced_cpu(self):
12413  inputs = get_lstm_inputs('cpu')
12414  try:
12415  ge = self.checkTrace(LSTMCellF, inputs)
12416  graph = ge.graph_for(*inputs)
12417  FileCheck().check("FusionGroup").run(str(graph))
12418  except RuntimeError as e:
12419  if 'Failed to compile' in e.args[0]:
12420  warnings.warn('CPU fuser test has failed! This is not a hard failure, '
12421  'because the kernels sometimes trigger bugs in compilers '
12422  '(most notably GCC 7.2).')
12423  raise unittest.SkipTest('Failed to compile')
12424  else:
12425  raise
12426 
12427  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12428  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12429  @skipIfRocm
12430  def test_milstm_cuda(self):
12431  inputs = get_milstm_inputs('cuda', training=True)
12432  module = self.checkScript(MiLSTMCell, inputs)
12433  forward_graph = module.graph_for(*inputs)
12434  self.assertGraphContainsExactly(
12435  forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
12436  FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
12437  .check_next("return").check("FusionGroup").run(str(forward_graph))
12438  hy, cy = module(*inputs)
12439  (hy + cy).sum().backward()
12440 
12441  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12442  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12443  @skipIfRocm
12444  def test_rand_cuda(self):
12445  class M(torch.jit.ScriptModule):
12446  __constants__ = ['d']
12447 
12448  def __init__(self):
  super(M, self).__init__()
12449  self.d = torch.device('cuda')
12450 
12451  @torch.jit.script_method
12452  def create(self, x):
12453  return x * x + x + torch.rand_like(x)
12454 
12455  x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
12456  m = M()
12457  out1 = m.create(x)
12458  out2 = m.create(x)
12459  self.assertNotEqual(out1, out2)
12460  self.assertTrue(torch.all(out1 >= 0))
12461  self.assertTrue(torch.all(out1 < 1))
12462  self.assertTrue(torch.all(out2 >= 0))
12463  self.assertTrue(torch.all(out2 < 1))
12464  self.assertAllFused(m.create.graph_for(x))
12465 
12466  @staticmethod
12467  def fn_test_relu(x, y):
12468  return F.relu(x + .5 * y)
12469 
12470  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12471  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12472  @skipIfRocm
12473  def test_relu_cuda(self):
12474  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
12475  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
12476 
12477  ge = self.checkTrace(self.fn_test_relu, (x, y))
12478  self.assertAllFused(ge.graph_for(x, y))
12479 
12480  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12481  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12482  @skipIfRocm
12483  def test_erf_cuda(self):
12484  def fn_test_erf(x):
12485  return F.relu(torch.erf(x) - torch.erfc(x))
12486 
12487  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
12488  ge = self.checkTrace(fn_test_erf, (x,))
12489  self.assertAllFused(ge.graph_for(x))
12490  x.requires_grad_(True)
12491  self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes"))
12492 
12493  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12494  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12495  @skipIfRocm
12496  def test_rand_broadcast_cuda(self):
12497  def fn_test_rand(x, y):
12498  r = torch.rand_like(y)
12499  return r * x + x
12500 
12501  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
12502  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
12503  script_f = torch.jit.script(fn_test_rand)
12504  out = script_f(x, y)
12505  self.assertAllFused(script_f.graph_for(x, y))
12506  x.requires_grad_(True)
12507  out = script_f(x, y)
12508  self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
12509  # test that broadcasting random produces correct results
12510  x = torch.ones(4, 4, dtype=torch.float, device='cuda')
12511  y = torch.ones(4, dtype=torch.float, device='cuda')
12512  out = script_f(x, y)
12513  self.assertEqual(out[0], out[1])
12514 
12515  @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
12516  @enable_cpu_fuser
12517  def test_scalar(self):
12518  def fn(x, y):
12519  return 2 * x + y
12520 
12521  x = torch.tensor(0.1, dtype=torch.float, device='cpu')
12522  y = torch.tensor(1, dtype=torch.float, device='cpu')
12523  ge = self.checkScript(fn, (x, y))
12524  self.assertAllFused(ge.graph_for(x, y))
12525 
12526  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12527  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12528  @skipIfRocm
12529  def test_small_constant_cuda(self):
12530  def fn_test_small_constant(x, y):
12531  return (1e-8 * x + 5e-9 * y) * 1e8
12532  x = torch.randn(4, 4, dtype=torch.float, device='cuda')
12533  y = torch.randn(4, 4, dtype=torch.float, device='cuda')
12534 
12535  ge = self.checkTrace(fn_test_small_constant, (x, y))
12536  self.assertAllFused(ge.graph_for(x, y))
12537 
12538  @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
12539  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12540  @skipIfRocm
12541  def test_tensor_scalar_ops_cuda(self):
12542  def should_fuse(x):
12543  z = 3.
12544  y = x + z
12545  return x * y
12546 
12547  # XXX: right now we only support fusing scalars if
12548  # they're constant (#9940)
12549  def should_not_fuse(x, z):
12550  y = x + int(z)
12551  return x * y
12552 
12553  inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
12554  ge = self.checkScript(should_fuse, inputs)
12555  self.assertAllFused(ge.graph_for(*inputs))
12556 
12557  inputs = [
12558  torch.randn(2, 2, dtype=torch.float, device='cuda'),
12559  torch.tensor(3., dtype=torch.float, device='cuda'),
12560  ]
12561  ge = self.checkScript(should_not_fuse, inputs)
12562  self.assertGraphContainsExactly(
12563  ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
12564 
12565  @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
12566  @enable_cpu_fuser
12567  def test_where_and_typing(self):
12568  def f(x, y):
12569  mask = x > y
12570  res = torch.where(mask, x, y)
12571  return mask, res
12572 
12573  script_f = torch.jit.script(f)
12574 
12575  x = torch.randn(4, 4, dtype=torch.double)
12576  y = torch.randn(4, 4, dtype=torch.double)
12577 
12578  result1, result2 = script_f(x, y)
12579  expected1, expected2 = f(x, y)
12580  self.assertEqual(result1, expected1)
12581  self.assertEqual(result2, expected2)
12582  self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
12583 
12584  @unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows")
12585  @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
12586  def test_windows_cuda(self):
12587  def scaleshift(x, scale, shift):
12588  return x * scale + shift
12589 
12590  inputs = [
12591  torch.randn(4, 4, dtype=torch.float, device='cuda'),
12592  torch.randn(4, dtype=torch.float, device='cuda'),
12593  torch.randn(4, dtype=torch.float, device='cuda'),
12594  ]
12595 
12596  ge = self.checkScript(scaleshift, inputs)
12597  self.assertGraphContainsExactly(
12598  ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
12599 
12600 
12601 # NB: torch.jit.script, when used as a function, uses the current scope
12602 # to resolve variable names. This function cannot be made local to
12603 # TestAutodiffSubgraphSlicing because those tests call torch.jit.script on functions
12604 # in a different scope than they are defined in.
12605 def pyfn(a, b):
12606  return a * b
12607 
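# Illustrative sketch (hypothetical helper, not used by the tests below):
# because torch.jit.script resolves free variables in the scope where the
# function is defined, a scripted function here can call pyfn directly; the
# call is emitted as a Python op in the graph, which is why pyfn counts as
# "not autodiff supported" in the tests that follow.
def _example_script_scope_resolution():
    @torch.jit.script
    def calls_pyfn(a, b):
        return pyfn(a, b)
    return calls_pyfn(torch.ones(2, 2), torch.ones(2, 2))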
12608 
12609 class TestAutodiffSubgraphSlicing(JitTestCase):
12610  # TODO: It would be better to test directly on graphs instead of in the
12611  # current end-to-end fashion.
12612  def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
12613  ge = torch.jit.script(fn)
12614  ge.debug_disable_autodiff_subgraph_inlining()
12615  inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
12616  ge(*inputs)
12617  return ge.graph_for(*inputs)
12618 
12619  def assertGraphSize(self, graph, size):
12620  self.assertEqual(len(list(graph.nodes())), size)
12621 
12622  def test_simple_merge(self):
12623  # o --> o
12624  def fn(x, y, z):
12625  a = x * y
12626  b = a * z
12627  return b
12628 
12629  graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
12630 
12631  self.assertGraphSize(graph, 1)
12632  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
12633 
12634  def test_simple_no_merge(self):
12635  # o: autodiff supported. x: not autodiff supported.
12636  # o --> x
12637  def fn(x, y, z):
12638  a = x * y
12639  b = pyfn(a, z)
12640  return b
12641 
12642  graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
12643 
12644  self.assertGraphSize(graph, 2)
12645  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
12646 
12647  def test_does_not_merge_unrelated(self):
12648  # o o
12649  def fn(w, x, y, z):
12650  a = x * y
12651  b = w * z
12652  return a, b
12653 
12654  graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
12655 
12656  self.assertGraphSize(graph, 3)
12657  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
12658 
12659  def test_merges_without_cycles(self):
12660  # o --> o --> o
12661  # |           ^
12662  # \___________/
12663  def fn(w, x, y):
12664  a = w * x
12665  b = a * y
12666  c = a * b
12667  return c
12668 
12669  graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
12670 
12671  self.assertGraphSize(graph, 1)
12672  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
12673 
12674  def test_merges_dense(self):
12675  #   o    o
12676  #   |\  /|
12677  #   | \/ |
12678  #   | /\ |
12679  #   vv  vv
12680  #   o    o
12681  def fn(x, y):
12682  a, b = x.chunk(2)
12683  c, d = y.chunk(2)
12684  return a + c, b + d
12685 
12686  graph = self._perform_ad_subgraph_slicing(fn, 2, 2)
12687 
12688  self.assertGraphSize(graph, 2)
12689  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
12690 
12691  def test_does_not_create_cycles(self):
12692  # o --> x --> o
12693  # |           ^
12694  # \___________/
12695  def fn(w, x, y):
12696  a = w * x
12697  b = pyfn(a, y)
12698  c = a * b
12699  return c
12700 
12701  graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
12702 
12703  self.assertGraphSize(graph, 3)
12704  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
12705 
12706  def test_merges_up(self):
12707  # o --> x     o
12708  # |           ^
12709  # \___________/
12710  def fn(w, x, y, z):
12711  a = w * x
12712  b = pyfn(a, y)
12713  c = a * z
12714  return b, c
12715 
12716  graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
12717 
12718  self.assertGraphSize(graph, 3)
12719  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
12720 
12721  def test_merges_down(self):
12722  # o     x --> o
12723  # |           ^
12724  # \___________/
12725  def fn(v, w, x, y):
12726  a = v * w
12727  b = pyfn(x, y)
12728  c = b * a
12729  return a, c
12730 
12731  graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
12732 
12733  self.assertGraphSize(graph, 3)
12734  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
12735 
12736  def test_respects_lexical_scoping(self):
12737  def fn(x, k):
12738  y = x * 1.1
12739  if bool(k):
12740  k = k + y
12741  z = y * k
12742  return z, k
12743 
12744  graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
12745 
12746  # We should not have combined the two multiplications into
12747  # the same group; they should each be a separate DiffGraph
12748  self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
12749 
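# Illustrative helper (hypothetical name, not used above): the slicing
# assertions ultimately walk graph.nodes() and compare node.kind() strings,
# roughly like this (without recursing into subgraphs):
def _example_count_nodes_of_kind(graph, kind):
    return sum(1 for node in graph.nodes() if node.kind() == kind)
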
12750 
12751 class TestCustomOperators(JitTestCase):
12752 
12753  def test_dynamic_op_registry(self):
12754  from torch._ops import _OpNamespace
12755  self.assertTrue(hasattr(torch, 'ops'))
12756 
12757  if '_test' in torch.ops.__dict__:
12758  torch.ops.__dict__.pop('_test')
12759 
12760  # Don't use `hasattr()` because it will call `__getattr__`.
12761  self.assertNotIn('_test', torch.ops.__dict__)
12762  torch.ops._test
12763  self.assertIn('_test', torch.ops.__dict__)
12764  self.assertEqual(type(torch.ops._test), _OpNamespace)
12765 
12766  self.assertNotIn('leaky_relu', torch.ops._test.__dict__)
12767  op = torch.ops._test.leaky_relu
12768  self.assertTrue(callable(op))
12769  self.assertIn('leaky_relu', torch.ops._test.__dict__)
12770  op2 = torch.ops._test.leaky_relu
12771  self.assertEqual(op, op2)
12772 
12773  def test_simply_calling_an_operator(self):
12774  input = torch.randn(100)
12775  output = torch.ops.aten.relu(input)
12776  self.assertEqual(output, input.relu())
12777 
12778  def test_default_arguments_are_used(self):
12779  output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
12780  self.assertEqual(output, torch.tensor([-0.01, 1]))
12781 
12782  def test_only_kwargs(self):
12783  output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))
12784  self.assertEqual(output, torch.tensor(-0.01))
12785 
12786  def test_passing_too_many_args(self):
12787  with self.assertRaisesRegex(
12788  RuntimeError,
12789  r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)"
12790  ):
12791  torch.ops.aten.relu(1, 2)
12792 
12793  def test_passing_too_few_args(self):
12794  with self.assertRaisesRegex(
12795  RuntimeError,
12796  r"aten::relu\(\) is missing value for argument 'self'."
12797  ):
12798  torch.ops.aten.relu()
12799 
12800  def test_passing_one_positional_but_not_the_second(self):
12801  with self.assertRaisesRegex(
12802  RuntimeError,
12803  r"aten::transpose\(\) is missing value for argument 'dim0'."
12804  ):
12805  torch.ops.aten.transpose(torch.ones(5, 5))
12806 
12807  def test_passing_an_argument_both_as_positional_and_kwarg(self):
12808  with self.assertRaisesRegex(
12809  RuntimeError,
12810  "Argument 'self' specified both as positional and keyword argument"
12811  ):
12812  torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))
12813 
12814  def test_passing_unknown_kwargs(self):
12815  with self.assertRaisesRegex(
12816  RuntimeError,
12817  "Unknown keyword argument 'foo' for operator '_test::leaky_relu'"
12818  ):
12819  torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))
12820 
12821  def test_passing_and_returning_lists(self):
12822  # Replace with actual test once we support lists.
12823  a, b = torch.rand(5), torch.rand(5)
12824  output = torch.ops._test.cat([a, b])
12825  output_ref = torch.cat([a, b])
12826  self.assertEqual(output, output_ref)
12827 
12828  def test_calling_scripted_custom_op(self):
12829  @torch.jit.script
12830  def func(x):
12831  return torch.ops.aten.relu(x)
12832  input = torch.ones(5, 5)
12833  self.assertEqual(func(input), input.relu())
12834 
12835  def test_calling_traced_custom_op(self):
12836  input = torch.ones(5, 5)
12837  func = torch.jit.trace(torch.ops.aten.relu, [input])
12838  self.assertEqual(func(input), input.relu())
12839 
12840  def test_script_graph_for_custom_ops_matches_traced_graph(self):
12841  input = torch.ones(5, 5)
12842  trace = torch.jit.trace(torch.ops.aten.relu, [input])
12843  self.assertExpectedInline(canonical(trace.graph), '''\
12844 graph(%0 : Double(5, 5)):
12845  %1 : Double(5, 5) = aten::relu(%0)
12846  return (%1)
12847 ''')
12848 
12849  def test_script_graph_contains_custom_op(self):
12850  @torch.jit.script
12851  def func(x):
12852  return torch.ops.aten.relu(x)
12853  self.assertExpectedInline(canonical(func.graph), '''\
12854 graph(%x : Tensor):
12855  %1 : Tensor = aten::relu(%x)
12856  return (%1)
12857 ''')
12858 
12859  def test_generic_list(self):
12860  self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')
12861 
12862 
12863 class TestJitGeneratedAutograd(JitTestCase):
12864  pass
12865 
12866 
12867 class TestJitGeneratedModule(JitTestCase):
12868  pass
12869 
12870 
12871 class TestJitGeneratedFunctional(JitTestCase):
12872  pass
12873 
12874 
12875 # UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
12876 # and we have to disable the failing tests here instead.
12877 UBSAN_BLACKLISTED_TESTS = [
12878  "test___rdiv___constant",
12879  "test___rdiv___scalar_constant",
12880  "test_addcdiv",
12881  "test_addcdiv_broadcast_all",
12882  "test_addcdiv_broadcast_rhs",
12883  "test_addcdiv_scalar",
12884  "test_addcdiv_scalar_broadcast_lhs",
12885  "test_addcdiv_scalar_broadcast_rhs",
12886  "test_addcdiv_scalar_scale",
12887  "test_addcdiv_scalar_scale_broadcast_lhs",
12888  "test_addcdiv_scalar_scale_broadcast_rhs",
12889  "test_addcdiv_scale",
12890  "test_addcdiv_scale_broadcast_all",
12891  "test_addcdiv_scale_broadcast_rhs",
12892  "test_add_broadcast_all",
12893  "test_add_broadcast_lhs",
12894  "test_add_broadcast_rhs",
12895  "test_add_constant",
12896  "test_add_scalar",
12897  "test_add_scalar_broadcast_lhs",
12898  "test_add_scalar_broadcast_rhs",
12899  "test_div",
12900  "test_div_broadcast_all",
12901  "test_div_broadcast_lhs",
12902  "test_div_broadcast_rhs",
12903  "test_div_scalar",
12904  "test_div_scalar_broadcast_lhs",
12905  "test_div_scalar_broadcast_rhs",
12906  "test_rsqrt",
12907  "test_rsqrt_scalar",
12908  "test_add",
12909  "test_reciprocal",
12910  "test_reciprocal_scalar",
12911 ]
12912 
12913 L = 20
12914 M = 10
12915 S = 5
12916 
12917  # modules that currently cannot be exported / imported
12918 EXCLUDE_MODULE_EXPORT_IMPORT = {
12919  'EmbeddingBag',
12920  'MaxPool1d',
12921  'MaxPool2d',
12922  'MaxPool3d',
12923  'AdaptiveAvgPool2d',
12924  'AdaptiveAvgPool3d',
12925  'Fold',
12926  'Unfold',
12927 }
12928 
12929  # NB: JIT script tests cover all nn functional interfaces; script mode does
12930  # not support in-place operations yet, so no in-place operation tests are added.
12931  # All deprecated functions have been removed.
12932  # Each entry is a tuple (see the illustrative sketch after this list):
12933  # (
12934  # method name,
12935  # input size / constructing fn,
12936  # args (a tuple represents the shape of a tensor arg),
12937  # test variant name (used as the test-name suffix;
12938  # 'inplace' skips grad tests), // optional
12939  # fn to determine if the test should be skipped, // optional
12940  # fn mapping output to the part that should be gradcheck'ed, // optional
12941  # kwargs for the function, // optional
12942  # )
12943 nn_functional_tests = [
12944  ('conv1d', (S, S, S), ((S, S, S),)),
12945  ('conv2d', (S, S, S, S), ((S, S, S, S),)),
12946  ('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
12947  ('conv_transpose1d', (S, S, S), ((S, S, S),)),
12948  ('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
12949  ('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
12950  ('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
12951  ('avg_pool1d', (S, S, S), (3,)),
12952  ('avg_pool2d', (S, S, S, S), (3,)),
12953  ('avg_pool3d', (S, S, S, S, S), (3,)),
12954  ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
12955  ('max_pool1d', (S, S, S), (2, 1)),
12956  ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
12957  ('max_pool2d', (S, S, S, S), (2, 1)),
12958  ('max_pool3d', (S, S, S, S, S), (2, 1)),
12959  ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
12960  ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
12961  ('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
12962  ('lp_pool1d', (S, S, S), (2., 3, 2,)),
12963  ('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
12964  ('adaptive_max_pool1d', (S, S, S), (5,)),
12965  ('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
12966  ('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
12967  ('adaptive_avg_pool1d', (S, S, S), (5,)),
12968  ('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],)),
12969  ('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
12970  ('dropout', (S, S, S), (0.5,)),
12971  ('alpha_dropout', (S, S, S), (0.5,)),
12972  ('dropout2d', (S, S, S), (0.5,)),
12973  ('dropout3d', (S, S, S), (0.5,)),
12974  ('feature_alpha_dropout', (S, S, S), (0.5,)),
12975  ('threshold', (S, S, S), (0.1, 2.),),
12976  ('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
12977  ('relu', (S, S, S), (),),
12978  ('relu', (S, S, S), (), 'inplace'),
12979  ('glu', (S - 1, S - 1, S - 1), (),),
12980  ('hardtanh', (S, S, S), (-0.5, 0.5),),
12981  ('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
12982  ('relu6', (S, S, S), (),),
12983  ('relu6', (S, S, S), (True,), 'inplace'),
12984  ('elu', (S, S, S), (0.9,),),
12985  ('elu', (S, S, S), (0.9, True), 'inplace'),
12986  ('selu', (S, S, S), (),),
12987  ('selu', (S, S, S), (True,), 'inplace'),
12988  ('celu', (S, S, S), (0.9,),),
12989  ('celu', (S, S, S), (0.9, True), 'inplace'),
12990  ('leaky_relu', (S, S, S), (0.02,),),
12991  ('leaky_relu', (S, S, S), (0.02,), 'inplace'),
12992  ('rrelu', (S, S), (0.1, 0.3, False),),
12993  ('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
12994  ('hardshrink', (S, S, S), (0.4,),),
12995  ('tanhshrink', (S, S, S), (),),
12996  ('softsign', (S, S, S), (),),
12997  ('softplus', (S, S, S), (),),
12998  ('softmin', (S, S, S), (0,),),
12999  ('softmax', (S, S, S), (0,),),
13000  ('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args'),
13001  ('tanh', (S, S, S), (),),
13002  ('sigmoid', (S, S, S), (),),
13003  ('log_softmax', (S, S, S), (0,),),
13004  ('linear', (S, S), ((M, S),),),
13005  ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
13006  ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ),),
13007  ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
13008  ('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),),
13009  ('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
13010  ('layer_norm', (S, S, S, S), ([5],),),
13011  ('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
13012  ('layer_norm', (S, S, S, S), ([5], None, (S,)), 'with_only_bias'),
13013  ('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
13014  ('group_norm', (S, S, S), (1, torch.rand(5),),),
13015  ('local_response_norm', (S, S, S), (2, ),),
13016  ('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),),),
13017  ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
13018  ('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
13019  ('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
13020  ('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
13021  ('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
13022  ('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13023  ('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13024  ('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13025  ('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
13026  ('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
13027  ('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
13028  ('margin_ranking_loss', (3, S), ((3, S), (S,)),),
13029  ('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13030  ('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13031  ('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13032  ('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
13033  ('pixel_shuffle', (1, 9, 4, 4), (3,),),
13034  ('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
13035  ('pad', (3, 3, 4, 2), ([1, 1],),),
13036  ('pairwise_distance', (S, S), ((S, S),),),
13037  ('pdist', (S, S), (),),
13038  ('cosine_similarity', (S, S), ((S, S),),),
13039  ('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
13040  ('normalize', (S, S, S), (),),
13041  ('unfold', (S, S, S, S), ([2, 3]),),
13042  ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
13043  ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
13044  ('gumbel_softmax', (S, S), (2.,),),
13045  ('gumbel_softmax', (S, S), (2., True,), 'hard'),
13046  ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
13047  ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
13048  1, 1., non_differentiable(torch.randn(S))),),
13049  ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
13050  non_differentiable(torch.randn(3, 2))),),
13051  ('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
13052  (non_differentiable(torch.rand(3, 2)),
13053  non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
13054  ('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
13055  (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
13056  torch.randint(1, S, (S,), dtype=torch.long))),
13057  ('upsample', torch.randn(S, S, M, M), (None, 2), 'with_scale'),
13058  ('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
13059  ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
13060  ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
13061  ('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
13062  ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
13063  ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
13064  ('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
13065  ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
13066  ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
13067  ('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
13068  ('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
13069  ('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
13070  ('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
13071  ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
13072  ('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
13073  ('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
13074  ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
13075  ('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
13076  ('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
13077  ('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
13078  ('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
13079  ('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
13080  ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
13081  ('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
13082  ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
13083  ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
13084  ('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
13085  ('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
13086  ('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
13087  ('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
13088 ]
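
# Illustrative sketch (an assumption about how to read the entries above, not
# code the harness runs): a shape tuple stands in for a random tensor of that
# shape, so the ('linear', (S, S), ((M, S),)) entry corresponds roughly to:
def _example_nn_functional_entry():
    x = torch.randn(S, S, requires_grad=True)
    weight = torch.randn(M, S, requires_grad=True)
    return F.linear(x, weight)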
13089 
13090 
13091 # Test names in this set are only checked for a single derivative
13092 nn_functional_single_grad = frozenset('test_nn_' + name for name in [
13093  'pdist',
13094  'multilabel_margin_loss',
13095  'max_unpool3d',
13096  'multi_margin_loss',
13097  'binary_cross_entropy',
13098  'binary_cross_entropy_size_average',
13099  'ctc_loss',
13100  'grid_sample',
13101 ])
13102 
13103 # additional modules test
13104 # TODO: delete this list once we make all nn_tests work
13105 additional_module_tests = [
13106  {
13107  'module_name': 'Bilinear',
13108  'constructor_args': (S, S, M),
13109  'input_size': (S, S),
13110  'extra_args': ((S, S),)
13111  },
13112  {
13113  'module_name': 'RNNCell',
13114  'constructor_args': (S, S),
13115  'input_size': (S, S),
13116  },
13117  {
13118  'module_name': 'LSTMCell',
13119  'constructor_args': (S, S),
13120  'input_size': (S, S),
13121  },
13122  {
13123  'module_name': 'GRUCell',
13124  'constructor_args': (S, S),
13125  'input_size': (S, S),
13126  },
13127 ]
13128 
13129 
13130 def add_autograd_test(
13131  name,
13132  self_size,
13133  args,
13134  variant_name='',
13135  dim_args_idx=(),
13136  skipTestIf=(),
13137  output_process_fn=lambda x: x,
13138  kwargs=None):
13139  basic_test_name = 'test_' + name
13140  if variant_name != '':
13141  basic_test_name += '_' + variant_name
13142 
13143  for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
13144  test_name = basic_test_name
13145  new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg for i, arg in enumerate(args)]
13146  test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)
13147  new_args = tuple(new_args)
13148 
13149  # for-loop bodies don't define scopes, so we have to save the variables
13150  # we want to close over in some way
13151  def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name,
13152  output_process_fn=output_process_fn):
13153  def check(name):
13154  set_rng_seed(2)
13155  is_magic_method = name[:2] == '__' and name[-2:] == '__'
13156  is_inplace = name[-1] == "_" and not is_magic_method
13157  self_variable = create_input((self_size,))[0][0]
13158  # FIXME: run grad checks on inplace self
13159  if is_inplace:
13160  self_variable.requires_grad = False
13161  # need to record this because methods can change the size (e.g. unsqueeze)
13162  args_variable, kwargs_variable = create_input(args, requires_grad=not is_inplace, call_kwargs=kwargs)
13163  self_tensor = deepcopy(self_variable.data)
13164  args_tensor = deepcopy(unpack_variables(args_variable))
13165 
13166  def fn(*inputs, **kwargs):
13167  output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
13168  return output_process_fn(output)
13169 
13170  check_types = test_name not in EXCLUDE_TYPE_CHECK
13171 
13172  if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name):
13173  # Test with disable_autodiff_subgraph_inlining, which forces the graph
13174  # to contain DifferentiableGraph nodes whenever possible. This allows us
13175  # to test autodiff; we assume that autograd is correct and use autodiff for backprop
13176  if test_name not in EXCLUDE_TRACED:
13177  check_against_reference(self,
13178  create_traced_fn(self, fn,
13179  disable_autodiff_subgraph_inlining=True),
13180  fn, (self_variable,) + args_variable, kwargs_variable,
13181  check_types=check_types)
13182 
13183  if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
13184  check_against_reference(self,
13185  create_script_fn(self, name, 'method', output_process_fn,
13186  disable_autodiff_subgraph_inlining=True),
13187  fn, (self_variable,) + args_variable, kwargs_variable,
13188  check_types=check_types)
13189 
13190  # functional interface tests
13191  if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
13192  def fn(*inputs, **kwargs):
13193  output = getattr(torch, name)(*inputs, **kwargs)
13194  return output_process_fn(output)
13195 
13196  f_args_variable = (self_variable,) + args_variable
13197  f_args_tensor = (self_tensor,) + args_tensor
13198 
13199  if not is_inplace and test_name not in EXCLUDE_TRACED:
13200  check_against_reference(self,
13201  create_traced_fn(self, fn,
13202  disable_autodiff_subgraph_inlining=True),
13203  fn, f_args_variable, kwargs_variable, check_types=check_types)
13204 
13205  if not is_inplace and test_name not in EXCLUDE_SCRIPT:
13206  check_against_reference(self,
13207  create_script_fn(self, name, 'functional', output_process_fn,
13208  disable_autodiff_subgraph_inlining=True),
13209  fn, f_args_variable, kwargs_variable,
13210  check_types=check_types)
13211 
13212  # alias annotation testing
13213  if is_inplace and test_name not in EXCLUDE_SCRIPT:
13214  check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable)
13215 
13216  check(name)
13217  inplace_name = name + '_'
13218  # can't broadcast inplace to left hand side
13219  broadcast_skip_inplace = 'broadcast_lhs' in test_name or 'broadcast_all' in test_name
13220  if hasattr(torch.ones(1), inplace_name) and not broadcast_skip_inplace:
13221  check(inplace_name)
13222 
13223  post_add_test(test_name, skipTestIf, do_test, TestJitGeneratedAutograd)
13224 
13225 
13226 def suppress_warnings(fn):
13227  @wraps(fn)
13228  def wrapper(*args, **kwargs):
13229  with warnings.catch_warnings(record=True):
13230  return fn(*args, **kwargs)
13231  return wrapper
13232 
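# Illustrative usage (hypothetical function, never called): a callable wrapped
# with suppress_warnings runs with its warnings recorded and thus silenced.
@suppress_warnings
def _example_quiet():
    warnings.warn('this warning is swallowed by the decorator')
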
13233 
13234 def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=(),
13235  output_process_fn=lambda x: x, kwargs=None):
13236  test_name = 'test_nn_' + name
13237 
13238  if variant_name != '':
13239  test_name = test_name + '_' + variant_name
13240 
13241  no_grad = variant_name == 'inplace'
13242 
13243  @suppress_warnings
13244  def do_test(self, name=name, args=args, test_name=test_name):
13245  torch.manual_seed(2)
13246 
13247  self_variable = create_input((self_size,))[0][0]
13248 
13249  # need to record this because methods can change the size (e.g. unsqueeze)
13250  args_variable, kwargs_variable = create_input(args, call_kwargs=kwargs)
13251 
13252  self_tensor = deepcopy(self_variable.data)
13253  args_tensor = deepcopy(unpack_variables(args_variable))
13254 
13255  if not no_grad:
13256  output_variable = getattr(F, name)(self_variable, *args_variable, **kwargs_variable)
13257 
13258  def fn(*inputs, **kwargs):
13259  output = getattr(F, name)(*inputs, **kwargs)
13260  return output_process_fn(output)
13261 
13262  f_args_variable = (self_variable,) + args_variable
13263  f_args_tensor = (self_tensor,) + args_tensor
13264 
13265  if test_name not in EXCLUDE_SCRIPT:
13266  disable_ad_subgraph_inlining = test_name in DISABLE_AUTODIFF_SUBGRAPH_INLINING
13267 
13268  def run_test():
13269  script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
13270  disable_autodiff_subgraph_inlining=disable_ad_subgraph_inlining)
13271  check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
13272 
13273  if test_name in EXCLUDE_PYTHON_PRINT:
13274  with self.disableModuleHook():
13275  run_test()
13276  else:
13277  run_test()
13278 
13279  post_add_test(test_name, skipTestIf, do_test, TestJitGeneratedFunctional)
13280 
13281 
13282 def add_nn_module_test(*args, **kwargs):
13283  if 'module_name' in kwargs:
13284  name = kwargs['module_name']
13285  elif 'fullname' in kwargs:
13286  name = kwargs['fullname']
13287  elif 'constructor' in kwargs:
13288  name = kwargs['constructor'].__name__
13289 
13290  no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad']
13291 
13292  module_name = name.split("_")[0]
13293 
13294  module = getattr(torch.nn, module_name, None)
13295  if module is None or torch._jit_internal.weak_types.get(module) is None:
13296  return
13297 
13298  if 'desc' in kwargs and 'eval' in kwargs['desc']:
13299  # eval() is not supported, so skip these tests
13300  return
13301 
13302  test_name = name
13303  if 'desc' in kwargs:
13304  test_name = "{}_{}".format(test_name, kwargs['desc'])
13305  test_name = 'test_nn_{}'.format(test_name)
13306 
13307  @suppress_warnings
13308  def do_test(self):
13309  if test_name in EXCLUDE_SCRIPT_MODULES:
13310  return
13311  if 'constructor' in kwargs:
13312  nn_module = kwargs['constructor']
13313  else:
13314  nn_module = getattr(torch.nn, name)
13315 
13316  if "FunctionalModule" in str(nn_module):
13317  return
13318 
13319  if 'constructor_args_fn' in kwargs:
13320  constructor_args = kwargs['constructor_args_fn']()
13321  else:
13322  constructor_args = kwargs.get('constructor_args', ())
13323 
13324  # Construct a script module that passes arguments through
13325  # to self.submodule
13326  def create_script_module(*args, **kwargs):
13327  formals, tensors, actuals = get_script_args(args)
13328 
13329  method_args = ', '.join(['self'] + actuals)
13330  call_args_str = ', '.join(actuals)
13331  call = "self.submodule({})".format(call_args_str)
13332  script = script_method_template.format(method_args, call)
13333 
13334  submodule_constants = []
13335  if kwargs.get('is_constant'):
13336  submodule_constants = ['submodule']
13337 
13338  # Create module to use the script method
13339  class TheModule(torch.jit.ScriptModule):
13340  __constants__ = submodule_constants
13341 
13342  def __init__(self):
13343  super(TheModule, self).__init__()
13344  self.submodule = nn_module(*constructor_args)
13345  # module cannot be imported / exported
13346  if module_name in EXCLUDE_MODULE_EXPORT_IMPORT:
13347  with self.disableModuleHook():
13348  module = TheModule()
13349  module.define(script)
13350  create_script_module.last_graph = module.graph
13351  mod = module(*args)
13352  else:
13353  module = TheModule()
13354  module.define(script)
13355  self.assertExportImportModule(module, tensors)
13356  create_script_module.last_graph = module.graph
13357  mod = module(*args)
13358  return mod
13359 
13360  # Construct a normal nn module to stay consistent with create_script_module
13361  # and make use of a single global rng_state in module initialization
13362  def create_nn_module(*args, **kwargs):
13363  module = nn_module(*constructor_args)
13364  return module(*args)
13365 
13366  # Set up inputs from tuple of sizes or constructor fn
13367  if 'input_fn' in kwargs:
13368  input = kwargs['input_fn']()
13369  else:
13370  input = (kwargs['input_size'],)
13371 
13372  # Extra parameters to forward()
13373  if 'extra_args' in kwargs:
13374  input = input + kwargs['extra_args']
13375 
13376  if 'target_size' in kwargs:
13377  input = input + (kwargs['target_size'],)
13378  elif 'target_fn' in kwargs:
13379  if torch.is_tensor(input):
13380  input = (input,)
13381  input = input + (kwargs['target_fn'](),)
13382 
13383  args_variable, kwargs_variable = create_input(input)
13384  f_args_variable = deepcopy(unpack_variables(args_variable))
13385 
13386  # Check against Python module as reference
13387  check_against_reference(self, create_script_module, create_nn_module, f_args_variable, no_grad=no_grad)
13388 
13389  post_add_test(test_name, (), do_test, TestJitGeneratedModule)
13390 
13391 
13392 def post_add_test(test_name, skipTestIf, do_test, test_class):
13393  assert not hasattr(test_class, test_name), 'Two tests have the same name: ' + test_name
13394 
13395  for skip in skipTestIf:
13396  do_test = skip(do_test)
13397 
13398  if not (TEST_WITH_UBSAN and test_name in UBSAN_BLACKLISTED_TESTS):
13399  setattr(test_class, test_name, do_test)
13400 
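# Illustrative sketch (hypothetical names, deliberately not executed): the
# TestJitGenerated* classes above start out empty and acquire their cases via
# setattr, exactly as post_add_test does:
#
#   def _example_case(self):
#       self.assertTrue(True)
#   setattr(TestJitGeneratedAutograd, 'test_example_case', _example_case)
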
13401 
13402 class TestAsync(JitTestCase):
13403  def test_async_python(self):
13404  @torch.jit.script
13405  def foo(x):
13406  return torch.neg(x)
13407 
13408  x = torch.rand(3, 4)
13409  fut = torch.jit._fork(foo, x)
13410  y_hat = foo(x)
13411  y = torch.jit._wait(fut)
13412  # assert nothing; only to make sure the fake python path works
13413 
13414  def test_async_parsing(self):
13415  @torch.jit.script
13416  def foo(x):
13417  # type: (Tensor) -> List[Tensor]
13418  return [torch.neg(x), x.t()]
13419 
13420  @torch.jit.script
13421  def bar(x):
13422  futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
13423  for _ in range(3):
13424  future = torch.jit.annotate(
13425  Future[List[Tensor]],
13426  torch.jit._fork(foo, x)
13427  )
13428  futures.append(future)
13429 
13430  output = torch.jit.annotate(List[List[Tensor]], [])
13431  for i in range(3):
13432  output.append(torch.jit._wait(futures[i]))
13433  return output
13434 
13435  x = torch.rand(3, 3)
13436  result = bar(x)
13437  self.assertEqual(len(result), 3)
13438 
13439  def test_async_script(self):
13440  @torch.jit.script
13441  def foo(x):
13442  return torch.neg(x), x
13443 
13444  x = torch.rand(3, 4)
13445 
13446  @torch.jit.script
13447  def wait_script(x):
13448  fut = torch.jit._fork(foo, x)
13449  y_hat = foo(x)
13450  y = torch.jit._wait(fut)
13451  return y, y_hat
13452 
13453  y, y_hat = wait_script(x)
13454 
13455  self.assertEqual(y, y_hat)
13456 
13457  def test_async_script_capture(self):
13458  class Mod(torch.jit.ScriptModule):
13459  __constants__ = ['const']
13460 
13461  def __init__(self):
13462  super(Mod, self).__init__(False)
13463  self.const = 42
13464  self.param = nn.Parameter(torch.randn(2, 2))
13465 
13466  @torch.jit.script_method
13467  def foo(self, x1, x2):
13468  return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param
13469 
13470  @torch.jit.script_method
13471  def wait_script(self, x1, x2):
13472  fut = torch.jit._fork(self.foo, x1, x2)
13473  y_hat = self.foo(x1, x2)
13474  y = torch.jit._wait(fut)
13475  return y, y_hat
13476 
13477  x1 = torch.rand(3, 4)
13478  x2 = torch.rand(5, 6)
13479 
13480  m = Mod()
13481  y, y_hat = m.wait_script(x1, x2)
13482 
13483  self.assertEqual(y, y_hat)
13484 
13485  def test_async_script_nested(self):
13486  @torch.jit.script
13487  def foo(x):
13488  return torch.neg(x), x
13489 
13490  x = torch.rand(3, 4)
13491 
13492  @torch.jit.script
13493  def wait_script(x):
13494  fut = torch.jit._fork(foo, x)
13495  y_hat = foo(x)
13496  y = torch.jit._wait(fut)
13497  return y, y_hat
13498 
13499  @torch.jit.script
13500  def wait_script_nest(x):
13501  fut = torch.jit._fork(wait_script, x)
13502  return torch.jit._wait(fut)
13503 
13504  y, y_hat = wait_script_nest(x)
13505 
13506  self.assertEqual(y, y_hat)
13507 
13508  def test_async_script_no_script_mod(self):
13509  x = torch.rand(3, 4)
13510 
13511  with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
13512  @torch.jit.script
13513  def wait_script(x):
13514  fut = torch.jit._fork(x)
13515  return fut
13516 
13517  def test_async_script_multi_waits(self):
13518  @torch.jit.script
13519  def foo(x):
13520  return torch.neg(x).t() + x
13521 
13522  @torch.jit.script
13523  def wait_script(x):
13524  fut = torch.jit._fork(foo, x)
13525 
13526  # wait twice on the same future
13527  y1 = torch.jit._wait(fut)
13528  y2 = torch.jit._wait(fut)
13529  return y1, y2
13530 
13531  x = torch.rand(2, 2)
13532  y1, y2 = wait_script(x)
13533  self.assertEqual(y1, y2)
13534 
13535  def test_async_script_multi_forks(self):
13536  @torch.jit.script
13537  def foo1(x):
13538  return torch.neg(x).t() + x
13539 
13540  @torch.jit.script
13541  def foo2(x, y):
13542  return torch.neg(x).t() + x + torch.neg(y).t()
13543 
13544  @torch.jit.script
13545  def foo3(x, y, z):
13546  return torch.neg(z).t() + y.t() + x
13547 
13548  x1 = torch.rand(10, 10)
13549  x2 = torch.rand(10, 10)
13550  x3 = torch.rand(10, 10)
13551 
13552  @torch.jit.script
13553  def wait_script(x1, x2, x3):
13554  f1 = torch.jit._fork(foo1, x1)
13555  f2 = torch.jit._fork(foo2, x1, x2)
13556  f3 = torch.jit._fork(foo3, x1, x2, x3)
13557  f4 = torch.jit._fork(foo1, x2)
13558  f5 = torch.jit._fork(foo2, x2, x3)
13559 
13560  # ignore some forks
13561  y1 = torch.jit._wait(f1)
13562  y2 = torch.jit._wait(f2)
13563  y3 = torch.jit._wait(f3)
13564 
13565  return y1, y2, y3
13566 
13567  y1, y2, y3 = wait_script(x1, x2, x3)
13568  self.assertEqual(y1, foo1(x1))
13569  self.assertEqual(y2, foo2(x1, x2))
13570  self.assertEqual(y3, foo3(x1, x2, x3))
13571 
13572  def test_async_script_trace(self):
13573  class Traced(nn.Module):
13574  def __init__(self):
13575  super(Traced, self).__init__()
13576 
13577  def forward(self, x):
13578  return (torch.neg(x), x)
13579 
13580  class Mod(torch.jit.ScriptModule):
13581  def __init__(self):
13582  super(Mod, self).__init__(False)
13583  x = torch.rand(3, 3)
13584  self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
13585 
13586  @torch.jit.script_method
13587  def forward(self, x):
13588  # type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]
13589  future1 = torch.jit._fork(self.traced, x)
13590  future2 = torch.jit._fork(torch.neg, x)
13591 
13592  tensor_tuple = torch.jit._wait(future1)
13593  tensor_single = torch.jit._wait(future2)
13594 
13595  tensor_list = []
13596  tensor_list.append(tensor_tuple[0])
13597  tensor_list.append(tensor_single)
13598 
13599  # return a nested structure of tensors
13600  return (tensor_list, tensor_tuple, tensor_tuple[1])
13601 
13602  class TupleCl(nn.Module):
13603  def __init__(self):
13604  super(TupleCl, self).__init__()
13605  self.module = Mod()
13606 
13607  def forward(self, x):
13608  z = torch.neg(x)
13609  y = self.module(x)
13610  list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
13611  return tuple(list)
13612 
13613  x = torch.rand(3, 3)
13614  module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
13615 
13616  # Make sure we have forks
13617  self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
13618  # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
13619  self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1)
13620  self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True)
13621 
13622  y = torch.neg(x)
13623  self.assertEqual(module(x), (y, y, y, y, x, x))
13624 
13625  def test_async_script_error(self):
13626  x = torch.rand(3, 4)
13627 
13628  @torch.jit.script
13629  def foo(x):
13630  # error here
13631  return x.t() + x
13632 
13633  @torch.jit.script
13634  def wait_script(x):
13635  fut = torch.jit._fork(foo, x)
13636  return torch.jit._wait(fut)
13637 
13638  @torch.jit.script
13639  def wait_script_nest(x):
13640  fut = torch.jit._fork(wait_script, x)
13641  return torch.jit._wait(fut)
13642 
13643  # no future
13644  error_msg = 'The size.*must match the size of tensor'
13645  with self.assertRaisesRegex(Exception, error_msg):
13646  foo(x)
13647 
13648  # one future
13649  with self.assertRaisesRegex(Exception, error_msg):
13650  wait_script(x)
13651 
13652  # two futures with a different error
13653  x = torch.rand(3, 4, 5)
13654  with self.assertRaisesRegex(Exception, 'expects a tensor with <= 2 dimensions'):
13655  wait_script_nest(x)
13656 
13657  def test_async_grad_guard_with_grad(self):
13658  @torch.jit.script
13659  def foo(x):
13660  y = x * 2
13661  return y.requires_grad
13662 
13663  @torch.jit.script
13664  def bar(x):
13665  fut = torch.jit._fork(foo, x)
13666  requires_grad_in_fork = torch.jit._wait(fut)
13667  z = x * 2
13668  return (requires_grad_in_fork, z.requires_grad)
13669 
13670  x = torch.randn(3, requires_grad=True)
13671 
13672  with torch.enable_grad():
13673  (inside_fork, after_wait) = bar(x)
13674 
13675  self.assertEqual(inside_fork, True)
13676  self.assertEqual(after_wait, True)
13677 
13678  def test_async_grad_guard_no_grad(self):
13679  @torch.jit.script
13680  def foo(x):
13681  y = x * 2
13682  return y.requires_grad
13683 
13684  @torch.jit.script
13685  def bar(x):
13686  fut = torch.jit._fork(foo, x)
13687  requires_grad_in_fork = torch.jit._wait(fut)
13688  z = x * 2
13689  return (requires_grad_in_fork, z.requires_grad)
13690 
13691  x = torch.randn(3, requires_grad=True)
13692 
13693  with torch.no_grad():
13694  (inside_fork, after_wait) = bar(x)
13695 
13696  self.assertEqual(inside_fork, False)
13697  self.assertEqual(after_wait, False)
13698 
13699  def test_trace_fork_wait(self):
13700  def fork_body(x):
13701  return x.neg(), x.neg() + 1
13702 
13703  def fn(x):
13704  fut = torch.jit._fork(fork_body, x)
13705  vals = torch.jit._wait(fut)
13706  return vals[0], vals[1], x - 1
13707 
13708  traced = torch.jit.trace(fn, (torch.rand(3, 4),))
13709  x = torch.rand(3, 4)
13710  self.assertEqual(fn(x), traced(x))
13711 
13712  self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=1)
13713  self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=1)
13714  self.assertGraphContainsExactly(traced.graph, kind='aten::neg', num_kind_nodes=2, consider_subgraphs=True)
13715 
13716  def test_trace_fork_wait_leaking(self):
13717  my_list = []
13718 
13719  def fork_body(x):
13720  my_list.append(x + 1)
13721  return x + 1
13722 
13723  def fn(x):
13724  fut = torch.jit._fork(fork_body, x)
13725  val = torch.jit._wait(fut)
13726  return my_list[0]
13727 
13728  with self.assertRaisesRegex(RuntimeError, 'did not have observable data dependence with trace inputs; '
13729  'this probably indicates your program cannot be understood '
13730  'by the tracer.'):
13731  traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)
13732 
13733  def test_trace_fork_wait_inline(self):
13734  def fork_body(x):
13735  return x + 1, x + 2
13736 
13737  def fn(x):
13738  fut = torch.jit._fork(fork_body, x)
13739  val = torch.jit._wait(fut)
13740  return val[1]
13741 
13742  traced = torch.jit.trace(fn, (torch.rand(3, 4),))
13743  torch._C._jit_pass_inline_fork_wait(traced.graph)
13744  torch._C._jit_pass_dce(traced.graph)
13745  self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=0)
13746  self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=0)
13747  self.assertGraphContainsExactly(traced.graph, kind='aten::add', num_kind_nodes=2)
13748 
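# A hedged sketch of the inlining flow exercised above (the _jit_pass_* helpers
# are internal and subject to change): inline_fork_wait splices each forked
# subgraph back into the main graph, and dce removes the dead fork/wait nodes.
def _example_inline_fork_wait(fn):  # illustrative helper name
    traced = torch.jit.trace(fn, (torch.rand(3, 4),))
    torch._C._jit_pass_inline_fork_wait(traced.graph)  # fold fork bodies inline
    torch._C._jit_pass_dce(traced.graph)               # drop dead fork/wait nodes
    return traced.graph                                # now contains plain aten ops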
13749  def test_trace_fork_wait_inline_onnx(self):
13750  def fork_body(x):
13751  return torch.neg(x), torch.neg(x)
13752 
13753  class MyMod(torch.nn.Module):
13754  def forward(self, x):
13755  fut = torch.jit._fork(fork_body, x)
13756  val = torch.jit._wait(fut)
13757  return val[1]
13758 
13759  # smoke test for ONNX export
13760  f = io.BytesIO()
13761  torch.onnx.export(MyMod(), (torch.rand(3, 4),), f)
13762 
13763  def test_save_load_with_extra_files(self):
13764  class MyMod(torch.jit.ScriptModule):
13765  @torch.jit.script_method
13766  def forward(self, a):
13767  return a
13768 
13769  expected_extra_files = torch._C.ExtraFilesMap()
13770  expected_extra_files['foo'] = 'bar'
13771  m = MyMod()
13772 
13773  # Save to file.
13774  with TemporaryFileName() as fname:
13775  m.save(fname, _extra_files=expected_extra_files)
13776  extra_files = torch._C.ExtraFilesMap()
13777  extra_files['foo'] = ''
13778  torch.jit.load(fname, _extra_files=extra_files)
13779  self.assertEqual('bar', extra_files['foo'])
13780 
13781  # Use torch.jit API
13782  torch.jit.save(m, fname, _extra_files=expected_extra_files)
13783  extra_files['foo'] = ''
13784  torch.jit.load(fname, _extra_files=extra_files)
13785  self.assertEqual('bar', extra_files['foo'])
13786 
13787  # Save to buffer.
13788  buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
13789  extra_files = torch._C.ExtraFilesMap()
13790  extra_files['foo'] = ''
13791  torch.jit.load(buffer, _extra_files=extra_files)
13792  self.assertEqual('bar', extra_files['foo'])
13793 
13794  # Use torch.jit API
13795  buffer = io.BytesIO()
13796  torch.jit.save(m, buffer, _extra_files=expected_extra_files)
13797  buffer.seek(0)
13798  extra_files = torch._C.ExtraFilesMap()
13799  extra_files['foo'] = ''
13800  torch.jit.load(buffer, _extra_files=extra_files)
13801  self.assertEqual('bar', extra_files['foo'])
13802 
13803  # Requesting the nonexistent extra file 'bar' should raise
13804  with self.assertRaises(RuntimeError):
13805  extra_files['bar'] = ''
13806  torch.jit.load(buffer, _extra_files=extra_files)
13807 
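# A hedged usage sketch of the _extra_files round trip tested above: keys placed
# in the map before load select which archived files to read back in place.
def _example_extra_files_roundtrip(m, fname):  # illustrative helper name
    files = torch._C.ExtraFilesMap()
    files['foo'] = 'bar'
    torch.jit.save(m, fname, _extra_files=files)  # archive 'foo' with the model
    out = torch._C.ExtraFilesMap()
    out['foo'] = ''                               # key to fetch; value filled on load
    torch.jit.load(fname, _extra_files=out)
    return out['foo']                             # == 'bar'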
13808 
13809 class TestDataParallel(JitTestCase):
13810  class Mpy(torch.nn.Module):
13811  def __init__(self):
13812  super(TestDataParallel.Mpy, self).__init__()
13813  self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
13814  nn.ReLU(), nn.Linear(2, 2))
13815 
13816  def forward(self, input):
13817  return self.m(input)
13818 
13819  class Mpy1(torch.nn.Module):
13820  def __init__(self, block):
13821  super(TestDataParallel.Mpy1, self).__init__()
13822  self.m = block
13823 
13824  def forward(self, input):
13825  return self.m.forward(input)
13826 
13827  class Mpy2(torch.nn.Module):
13828  def __init__(self, block1, block2):
13829  super(TestDataParallel.Mpy2, self).__init__()
13830  self.m1 = block1
13831  self.m2 = block2
13832 
13833  def forward(self, input):
13834  x = self.m1.forward(input)
13835  return self.m2(x)
13836 
13837  class Msm(torch.jit.ScriptModule):
13838 
13839  __constants__ = ['m']
13840 
13841  def __init__(self):
13842  super(TestDataParallel.Msm, self).__init__(False)
13843  self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
13844  nn.ReLU(), nn.Linear(2, 2))
13845 
13846  @torch.jit.script_method
13847  def forward(self, input):
13848  return self.m(input)
13849 
13850  class Msm1(torch.jit.ScriptModule):
13851  def __init__(self, block):
13852  super(TestDataParallel.Msm1, self).__init__(False)
13853  self.block = block
13854 
13855  @torch.jit.script_method
13856  def forward(self, input):
13857  x = self.block(input)
13858  return x
13859 
13860  def check_replicas(self, module, replicas, input_shape=(2, 2)):
13861  input = torch.randn(input_shape).cuda()
13862  expected_output = module(input).data
13863  for i, replica in enumerate(replicas):
13864  for p in replica.parameters():
13865  self.assertEqual(p.get_device(), i)
13866  for b in replica.buffers():
13867  self.assertEqual(b.get_device(), i)
13868  replica_input = input.cuda(i)
13869  self.assertEqual(replica(replica_input).data, expected_output)
13870 
13871  @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
13872  @skipIfRocm
13873  def test_python_submodule_exception(self):
13874  module = self.Msm1(self.Mpy()).cuda()
13875  msg = "Cannot replicate.*"
13876  with self.assertRaisesRegex(Exception, msg):
13877  dp.replicate(module, {0, 1})
13878 
13879  @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
13880  @skipIfRocm
13881  def test_python_submodule_script(self):
13882  module = self.Mpy1(self.Msm()).cuda()
13883  replicas = dp.replicate(module, {0, 1})
13884  self.check_replicas(module, replicas)
13885 
13886  @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
13887  @skipIfRocm
13888  def test_shared_module(self):
13889  s = self.Msm()
13890  p1 = self.Mpy1(s)
13891  module = self.Mpy2(p1, s).cuda()
13892  replicas = dp.replicate(module, {0, 1})
13893  self.check_replicas(module, replicas)
13894 
13895  @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
13896  @skipIfRocm
13897  def test_traced_module(self):
13898  module = torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda()
13899  replicas = dp.replicate(module, {0, 1})
13900  self.check_replicas(module, replicas)
13901 
13902  @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
13903  @skipIfRocm
13904  def test_tensor_sharing(self):
13905  module = self.Msm1(self.Msm()).cuda()
13906  replica = dp.replicate(module, {0, 1})
13907  optimizer = optim.SGD(module.parameters(), lr=1, momentum=1)
13908  x = torch.ones(2, 2, requires_grad=True).cuda()
13909  first_forward = module.forward(x)
13910  first_forward.sum().backward()
13911  optimizer.step()
13912  second_forward = module.forward(first_forward)
13913 
13914  # replica which is on the same GPU has a shallow copy of the original
13915  # params and buffers
13916  r0_forward = replica[0].forward(x)
13917  self.assertEqual(second_forward, r0_forward)
13918 
13919  # replica which is on a different GPU has a deep copy of the original
13920  # params and buffers
13921  x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
13922  r1_forward = replica[1].forward(x1)
13923  self.assertEqual(first_forward, r1_forward)
13924 
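# A hedged sketch of the replication contract the tests above check: replicate()
# produces one copy of the module per device; the copy on the source device
# shares the original parameters and buffers, while copies on other devices hold
# deep copies. The helper name is illustrative and requires two CUDA devices.
def _example_replicate(module):
    replicas = dp.replicate(module.cuda(), {0, 1})  # one replica per device id
    y0 = replicas[0](torch.ones(2, 2).cuda(0))      # shares storage with module
    y1 = replicas[1](torch.ones(2, 2).cuda(1))      # independent deep copy
    return y0, y1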
13925 
13926 class TestClassType(JitTestCase):
13927  def test_get_with_method(self):
13928  @torch.jit.script
13929  class FooTest:
13930  def __init__(self, x):
13931  self.foo = x
13932 
13933  def getFooTest(self):
13934  return self.foo
13935 
13936  @torch.jit.script
13937  def fn(x):
13938  foo = FooTest(x)
13939  return foo.getFooTest()
13940 
13941  input = torch.ones(2, 3)
13942  self.assertEqual(fn(input), input)
13943 
13944  def test_get_attr(self):
13945  @torch.jit.script
13946  class FooTest:
13947  def __init__(self, x):
13948  self.foo = x
13949 
13950  @torch.jit.script
13951  def fn(x):
13952  foo = FooTest(x)
13953  return foo.foo
13954 
13955  input = torch.ones(2, 3)
13956  self.assertEqual(fn(input), input)
13957 
13958  def test_set_attr_in_method(self):
13959  @torch.jit.script
13960  class FooTest:
13961  def __init__(self, x):
13962  # type: (int) -> None
13963  self.foo = x
13964 
13965  def incFooTest(self, y):
13966  # type: (int) -> None
13967  self.foo = self.foo + y
13968 
13969  @torch.jit.script
13970  def fn(x):
13971  # type: (int) -> int
13972  foo = FooTest(x)
13973  foo.incFooTest(2)
13974  return foo.foo
13975 
13976  self.assertEqual(fn(1), 3)
13977 
13978  def test_set_attr_type_mismatch(self):
13979  with self.assertRaisesRegex(RuntimeError, "Wrong type for attribute assignment"):
13980  @torch.jit.script
13981  class FooTest:
13982  def __init__(self, x):
13983  self.foo = x
13984  self.foo = 10 # should error since int != Tensor
13985 
13986  def test_get_attr_not_initialized(self):
13987  with self.assertRaisesRegex(RuntimeError, "Tried to access to nonexistent attribute"):
13988  @torch.jit.script
13989  class FooTest:
13990  def __init__(self, x):
13991  self.foo = x
13992 
13993  def get_non_initialized(self):
13994  return self.asdf # asdf isn't an attr
13995 
13996  def test_set_attr_non_initialized(self):
13997  with self.assertRaisesRegex(RuntimeError, "Tried to set nonexistent attribute"):
13998  @torch.jit.script
13999  class FooTest:
14000  def __init__(self, x):
14001  self.foo = x
14002 
14003  def set_non_initialized(self, y):
14004  self.bar = y # can't assign to non-initialized attr
14005 
14006  def test_type_annotations(self):
14007  with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
14008  @torch.jit.script
14009  class FooTest:
14010  def __init__(self, x):
14011  # type: (bool) -> None
14012  self.foo = x
14013 
14014  @torch.jit.script
14015  def fn(x):
14016  FooTest(x)
14017 
14018  fn(2)
14019 
14020  def test_conditional_set_attr(self):
14021  with self.assertRaisesRegex(RuntimeError, "assignment cannot be in a control-flow block"):
14022  @torch.jit.script
14023  class FooTest:
14024  def __init__(self, x):
14025  if True:
14026  self.attr = x
14027 
14028  def test_class_type_as_param(self):
14029  @torch.jit.script
14030  class FooTest:
14031  def __init__(self, x):
14032  self.attr = x
14033 
14034  @torch.jit.script
14035  def fn(foo):
14036  # type: (FooTest) -> Tensor
14037  return foo.attr
14038 
14039  @torch.jit.script
14040  def fn2(x):
14041  foo = FooTest(x)
14042  return fn(foo)
14043 
14044  input = torch.ones(1)
14045  self.assertEqual(fn2(input), input)
14046 
14047  def test_out_of_order_methods(self):
14048  @torch.jit.script
14049  class FooTest:
14050  def __init__(self, x):
14051  self.x = x
14052  self.x = self.get_stuff(x)
14053 
14054  def get_stuff(self, y):
14055  return self.x + y
14056 
14057  @torch.jit.script
14058  def fn(x):
14059  f = FooTest(x)
14060  return f.x
14061 
14062  input = torch.ones(1)
14063  self.assertEqual(fn(input), input + input)
14064 
14065  def test_save_load_with_classes(self):
14066  @torch.jit.script
14067  class FooTest:
14068  def __init__(self, x):
14069  self.x = x
14070 
14071  def get_x(self):
14072  return self.x
14073 
14074  class MyMod(torch.jit.ScriptModule):
14075  @torch.jit.script_method
14076  def forward(self, a):
14077  foo = FooTest(a)
14078  return foo.get_x()
14079 
14080  m = MyMod()
14081 
14082  buffer = io.BytesIO()
14083  torch.jit.save(m, buffer)
14084 
14085  # classes are globally registered for now, so we need to clear the JIT
14086  # registry to simulate loading a new model
14087  torch._C._jit_clear_class_registry()
14088 
14089  buffer.seek(0)
14090  m_loaded = torch.jit.load(buffer)
14091 
14092  input = torch.rand(2, 3)
14093  output = m_loaded(input)
14094  self.assertEqual(input, output)
14095 
14096  def test_save_load_with_classes_nested(self):
14097  @torch.jit.script
14098  class FooNestedTest:
14099  def __init__(self, y):
14100  self.y = y
14101 
14102  @torch.jit.script
14103  class FooNestedTest2:
14104  def __init__(self, y):
14105  self.y = y
14106  self.nested = FooNestedTest(y)
14107 
14108  @torch.jit.script
14109  class FooTest:
14110  def __init__(self, x):
14111  self.class_attr = FooNestedTest(x)
14112  self.class_attr2 = FooNestedTest2(x)
14113  self.x = self.class_attr.y + self.class_attr2.y
14114 
14115  class MyMod(torch.jit.ScriptModule):
14116  @torch.jit.script_method
14117  def forward(self, a):
14118  foo = FooTest(a)
14119  return foo.x
14120 
14121  m = MyMod()
14122 
14123  buffer = io.BytesIO()
14124  torch.jit.save(m, buffer)
14125 
14126  # classes are globally registered for now, so we need to clear the JIT
14127  # registry to simulate loading a new model
14128  torch._C._jit_clear_class_registry()
14129 
14130  buffer.seek(0)
14131  m_loaded = torch.jit.load(buffer)
14132 
14133  input = torch.rand(2, 3)
14134  output = m_loaded(input)
14135  self.assertEqual(2 * input, output)
14136 
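# A minimal sketch (class and function names are illustrative) of the scripted
# class pattern exercised above: @torch.jit.script compiles the class's methods,
# and each attribute takes the type of its first assignment in __init__.
@torch.jit.script
class _ExampleHolder:
    def __init__(self, x):
        self.x = x        # attribute type fixed as Tensor by this assignment

    def double(self):
        return self.x * 2

@torch.jit.script
def _example_use_holder(x):
    return _ExampleHolder(x).double()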
14137 
14138 for test in autograd_method_tests():
14139  add_autograd_test(*test)
14140 
14141 for test in nn_functional_tests:
14142  add_nn_functional_test(*test)
14143 
14144 for test in module_tests + new_module_tests + additional_module_tests:
14145  add_nn_module_test(**test)
14146 
14147 for test in criterion_tests:
14148  test['no_grad'] = True
14149  add_nn_module_test(**test)
14150 
14151 if __name__ == '__main__':
14152  run_tests()