Caffe2 - Python API
A deep learning, cross-platform ML framework
test_jit.py
1 from __future__ import division
2 import torch
3 import torch.jit
4 import torch.nn as nn
5 import torch.nn.functional as F
6 import torch.nn.parallel as dp
7 import torch.optim as optim
8 import torch.cuda
10 from contextlib import contextmanager
11 from itertools import product, chain
12 import torch.jit.frontend
13 from torch.autograd import Variable, Function
14 from torch.nn import Module
15 from torch.autograd.function import traceable
16 from torch.testing import assert_allclose
17 from torch.onnx import OperatorExportTypes
18 from torch._six import inf, PY2, builtins
19 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
20  skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
21  freeze_rng_state, set_rng_seed
22 from common_nn import module_tests, new_module_tests, criterion_tests
23 from textwrap import dedent
24 from functools import wraps
25 import os
26 import io
27 import itertools
28 import sys
29 import unittest
30 import inspect
31 import textwrap
32 import numpy as np
33 import tempfile
34 import shutil
35 import warnings
36 import math
37 import types
38 import pickle
39 import copy
40 
41 from common_methods_invocations import method_tests as autograd_method_tests
42 from common_methods_invocations import create_input, unpack_variables, \
43  exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
44 from torch.testing import FileCheck
45 from torch._C import TensorType, TupleType, FloatType, IntType, \
46  ListType, StringType, DictType
47 from copy import deepcopy
48 import random
49 from typing import List, Dict, Optional, Tuple
50 from torch.jit.frontend import NotSupportedError
51 from torch.jit import BatchTensor
52 from torch import Tensor
53 from torch.jit.annotations import BroadcastingList2, BroadcastingList3
54 
55 # For testing truediv in python 2
56 from test_module.future_div import div_int_future, div_float_future
57 from test_module.no_future_div import div_int_nofuture, div_float_nofuture
58 
59 
# Re-bind load_tests (imported from common_utils) at module level so unittest
# picks it up to filter tests for sharding on Sandcastle; the self-assignment
# also keeps flake8 from reporting the import as unused.
load_tests = load_tests

try:
    import torchvision  # noqa: F401
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True


# Decorator for tests that can only run when torchvision is importable.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
72 
RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_HALF = RUN_CUDA
if torch.cuda.is_available():
    # Disable CUDA tests when a device's compute capability is newer than the
    # CUDA toolkit this build was compiled with can handle: major >= 6 needs
    # CUDA >= 8.0 and major >= 7 needs CUDA >= 9.0. Half-precision tests
    # additionally require CUDA >= 9.0 and major >= 6 on every device.
    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
    for d in range(torch.cuda.device_count()):
        major = torch.cuda.get_device_capability(d)[0]
        if (CUDA_VERSION < 8000 and major >= 6) or (CUDA_VERSION < 9000 and major >= 7):
            RUN_CUDA = False
        if (CUDA_VERSION < 9000 or major < 6):
            RUN_CUDA_HALF = False

RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1

PY35 = sys.version_info >= (3, 5)
WINDOWS = sys.platform == 'win32'
88 
89 
if sys.platform == 'win32':
    @contextmanager
    def TemporaryFileName():
        """Yield the path of a fresh temporary file and delete it afterwards.

        On Windows a NamedTemporaryFile cannot be opened a second time while it
        is still open, so the handle is closed right away (delete=False) and the
        file is unlinked manually once the caller's ``with`` block finishes.
        """
        handle = tempfile.NamedTemporaryFile(delete=False)
        try:
            handle.close()
            yield handle.name
        finally:
            os.unlink(handle.name)
else:
    @contextmanager  # noqa: T484
    def TemporaryFileName():
        """Yield the path of a temporary file that is removed when the block exits."""
        with tempfile.NamedTemporaryFile() as handle:
            yield handle.name
107 
108 
def LSTMCellF(input, hx, cx, *params):
    """Flat-argument adapter for LSTMCell: bundles ``hx``/``cx`` into the hidden pair."""
    hidden = (hx, cx)
    return LSTMCell(input, hidden, *params)
111 
112 
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """Reference LSTM cell built on ``F.linear``.

    Args:
        input: input for this time step.
        hidden: ``(hx, cx)`` pair holding the previous hidden and cell state.
        w_ih, w_hh: input-to-hidden and hidden-to-hidden weights.
        b_ih, b_hh: optional biases passed to ``F.linear``.

    Returns:
        ``(hy, cy)``: the next hidden state and cell state.
    """
    h_prev, c_prev = hidden
    pre_activation = F.linear(input, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)

    # Gate order along dim 1 is: input, forget, cell candidate, output.
    i_raw, f_raw, g_raw, o_raw = pre_activation.chunk(4, 1)
    i_gate, f_gate, o_gate = [torch.sigmoid(t) for t in (i_raw, f_raw, o_raw)]
    g_gate = torch.tanh(g_raw)

    c_next = (f_gate * c_prev) + (i_gate * g_gate)
    h_next = o_gate * torch.tanh(c_next)
    return h_next, c_next
126 
127 
def LSTMCellC(*args, **kwargs):
    """Run LSTMCellF and concatenate its ``(hy, cy)`` outputs along dim 0."""
    hidden_out, cell_out = LSTMCellF(*args, **kwargs)
    return torch.cat([hidden_out, cell_out])
131 
132 
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    """LSTM cell using explicit ``mm``/``t`` matmuls and separately added biases.

    Returns:
        ``(hy, cy)``: the next hidden state and cell state.
    """
    gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
    # Gate order along dim 1 is: input, forget, cell candidate, output.
    i_pre, f_pre, g_pre, o_pre = gates.chunk(4, 1)
    cy = (torch.sigmoid(f_pre) * cx) + (torch.sigmoid(i_pre) * torch.tanh(g_pre))
    hy = torch.sigmoid(o_pre) * torch.tanh(cy)
    return hy, cy
143 
144 
# Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44
def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
    """Multiplicative-integration LSTM cell.

    Gate pre-activations follow Section 2.1 of https://arxiv.org/pdf/1606.06630.pdf:
    ``alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias`` where
    ``Wx = x @ w_ih^T`` and ``Uz = hx @ w_hh^T``. The state update after that
    is the standard LSTM one.

    Returns:
        ``(hy, cy)``: the next hidden state and cell state.
    """
    Wx = x.mm(w_ih.t())
    Uz = hx.mm(w_hh.t())
    gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias

    # Gate order along dim 1 is: input, forget, cell candidate, output.
    pieces = gates.chunk(4, 1)
    ingate, forgetgate, outgate = [pieces[k].sigmoid() for k in (0, 1, 3)]
    cellgate = pieces[2].tanh()

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * cy.tanh()
    return hy, cy
160 
161 
def canonical(graph):
    """Return the text of ``graph`` after the JIT canonicalize pass (stable for comparisons)."""
    canonicalized = torch._C._jit_pass_canonicalize(graph)
    return str(canonicalized)
164 
165 
def get_lstm_inputs(device, training=False, seq_length=None):
    """Build random ``(input, hx, cx, *params)`` for a 10 -> 20 LSTM cell with batch 3.

    Args:
        device: device on which every tensor is allocated.
        training: when True, inputs and parameters require grad; otherwise
            the parameters have ``requires_grad`` switched off.
        seq_length: if given, the input gets a leading time dimension.

    Returns:
        Tuple of input, hx, cx followed by the nn.LSTMCell parameters.
    """
    batch, input_size, hidden_size = 3, 10, 20
    if seq_length is None:
        input_shape = (batch, input_size)
    else:
        input_shape = (seq_length, batch, input_size)

    tensor_kwargs = dict(dtype=torch.float, device=device, requires_grad=training)
    input = torch.randn(*input_shape, **tensor_kwargs)
    hx = torch.randn(batch, hidden_size, **tensor_kwargs)
    cx = torch.randn(batch, hidden_size, **tensor_kwargs)

    # The module only serves to allocate weights with the correct sizes.
    cell = nn.LSTMCell(input_size, hidden_size).to(device, torch.float)
    if training:
        params = tuple(cell.parameters())
    else:
        params = tuple(p.requires_grad_(False) for p in cell.parameters())
    return (input, hx, cx) + params
177 
178 
def get_milstm_inputs(device, training=False):
    """Build random inputs for MiLSTMCell with batch 3, input size 10 and hidden size 20.

    Returns:
        ``(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias)``. The data
        tensors never require grad; the weights and the four per-gate vectors
        require grad only when ``training`` is True.
    """
    minibatch, input_size, hidden_size = 3, 10, 20
    gate_size = 4 * hidden_size

    def rand(shape, requires_grad):
        return torch.randn(*shape, device=device, dtype=torch.float,
                           requires_grad=requires_grad)

    x = rand((minibatch, input_size), False)
    hx = rand((minibatch, hidden_size), False)
    cx = rand((minibatch, hidden_size), False)

    ih = rand((gate_size, input_size), training)
    hh = rand((gate_size, hidden_size), training)
    # Created in this order: alpha, input beta, hidden beta, bias.
    alpha, ibeta, hbeta, bias = [rand((gate_size,), training) for _ in range(4)]
    return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias
194 
195 
def get_fn(file_name, script_path):
    """Import the file at ``script_path`` as module ``file_name`` and return its ``fn``."""
    import importlib.util
    module_spec = importlib.util.spec_from_file_location(file_name, script_path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded.fn
203 
204 
def get_execution_plan(graph_executor_state):
    """Return the only execution plan stored in a GraphExecutor debug state.

    Raises:
        RuntimeError: if the state holds zero or more than one plan.
    """
    plans = list(graph_executor_state.execution_plans.values())
    if len(plans) == 1:
        return plans[0]
    raise RuntimeError('This test assumes this GraphExecutor should '
                       'only have one execution plan, got: {}'.format(len(plans)))
212 
213 
def get_grad_executor(plan_state, diff_graph_idx=None):
    """Return a gradient executor from an execution-plan debug state.

    Args:
        plan_state: execution plan state exposing ``graph`` and ``code``.
        diff_graph_idx: index into ``code.grad_executors()``. When None, the
            plan's graph must be a single node, optionally followed by a
            ``prim::TupleConstruct``, and executor 0 is returned.

    Raises:
        RuntimeError: when ``diff_graph_idx`` is None and the graph does not
            have that shape.
    """
    if diff_graph_idx is None:
        nodes = list(plan_state.graph.nodes())
        lone_node = len(nodes) == 1
        node_then_tuple = len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"
        if not (lone_node or node_then_tuple):
            raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
    executors = list(plan_state.code.grad_executors())
    return executors[diff_graph_idx or 0]
223 
224 
def backward_graph(script_module, diff_graph_idx=None):
    """Return an owned copy of the backward graph for a ScriptModule's forward.

    Raises:
        RuntimeError: if ``script_module`` is not a ``torch.jit.ScriptModule``.
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise RuntimeError('Expected ScriptModule')
    forward_plan = get_execution_plan(script_module.get_debug_state())
    grad_exec = get_grad_executor(forward_plan, diff_graph_idx=diff_graph_idx)
    backward_plan = get_execution_plan(grad_exec.get_debug_state())
    # The debug-state struct does not own its graph, while JIT passes need a
    # graph held by a shared_ptr, so hand back a copy.
    return backward_plan.graph.copy()
235 
236 
237 # make it easy to quicky define/trace a function for these tests
238 def _trace(*args, **kwargs):
239  def wrapper(func):
240  return torch.jit.trace(func, args, **kwargs)
241  return wrapper
242 
243 
def enable_cpu_fuser(fn):
    """Decorator that runs ``fn`` with the JIT CPU-fusion override switched on.

    The override is reset to False afterwards, even if ``fn`` raises. The
    wrapped function's return value is passed through, and its name and
    docstring are preserved.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        torch._C._jit_override_can_fuse_on_cpu(True)
        try:
            return fn(*args, **kwargs)
        finally:
            torch._C._jit_override_can_fuse_on_cpu(False)
    return wrapper
252 
253 
class JitTestCase(TestCase):
    # NOTE(review): presumably opts this suite into common_utils.TestCase's
    # CUDA memory-leak check -- confirm against common_utils.
    _do_cuda_memory_leak_check = True
    # Set to True by the first setUp() call so that setUp's one-time
    # warning-filter setup is not repeated for every test.
    _restored_warnings = False
257 
258  def setUp(self):
259  # unittest overrides all warning filters and forces all of them to show up
260  # after we install our own to silence those coming from inside PyTorch.
261  # This will ensure that our filter still takes precedence.
262  if not JitTestCase._restored_warnings:
264  JitTestCase._restored_warnings = True
265  torch._C._jit_set_emit_module_hook(self.emitModuleHook)
266 
267  def tearDown(self):
268  # needs to be cleared because python might be unloaded before
269  # the callback gets destucted
270  torch._C._jit_set_emit_module_hook(None)
271  torch._C._jit_clear_class_registry()
272 
273  @contextmanager
274  def disableModuleHook(self):
275  torch._C._jit_set_emit_module_hook(None)
276  yield None
277  torch._C._jit_set_emit_module_hook(self.emitModuleHook)
278 
279  def emitModuleHook(self, module):
280  def copy_structure_and_params(m):
282  for name, v in m._get_parameters():
283  c._register_parameter(name, v, False)
284  for name, the_type, v in m._get_attributes():
285  c._register_attribute(name, the_type, v)
286  for name, s in m._get_modules():
287  c._register_module(name, copy_structure_and_params(s))
288  return c
289 
290  # disable the hook while we parse code, otherwise we will re-enter the hook
291  with self.disableModuleHook():
292  try:
293  pp, constant_table = module._python_print()
294  except RuntimeError as e:
295  se = str(e)
296  if "could not export python function" not in se and \
297  "closures are not exportable" not in se:
298  raise
299  else:
300  return
301  ppv = "op_version_set = 0\n{}".format(pp)
302  sm = copy_structure_and_params(module)
303  torch._C._jit_import_methods(sm, ppv, constant_table)
304  pp2, _ = sm._python_print()
305  if pp != pp2:
306  self.assertMultiLineEqual(pp, pp2)
307 
308  def getExportImportCopy(self, m, also_test_file=True, map_location=None):
309  buffer = io.BytesIO()
310  torch.jit.save(m, buffer)
311  buffer.seek(0)
312  imported = torch.jit.load(buffer, map_location=map_location)
313 
314  if not also_test_file:
315  return imported
316 
317  with TemporaryFileName() as fname:
318  imported.save(fname)
319  return torch.jit.load(fname, map_location=map_location)
320 
321  def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
322  buffer = io.BytesIO()
323  m.apply(lambda s: s._pack() if s._has_method('_pack') else None)
324  torch.jit.save(m, buffer)
325  m.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
326  buffer.seek(0)
327  imported = torch.jit.load(buffer, map_location=map_location)
328  imported.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
329 
330  if not also_test_file:
331  return imported
332 
333  # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
334  # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
335  # close the file after creation and try to remove it manually
336  f = tempfile.NamedTemporaryFile(delete=False)
337  try:
338  f.close()
339  imported.save(f.name)
340  result = torch.jit.load(f.name, map_location=map_location)
341  finally:
342  os.unlink(f.name)
343 
344  result.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
345  return result
346 
347  def assertGraphContains(self, graph, kind):
348  self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
349 
350  def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
351  def perform_assert(graph, kind, actual, expected, consider_subgraphs):
352  if actual == expected:
353  return
354  subgraph = 'including' if consider_subgraphs else 'excluding'
355  raise AssertionError(
356  '{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format(
357  graph, actual, kind, subgraph, expected))
358 
359  if consider_subgraphs:
360  strgraph = str(graph)
361  count = strgraph.count(kind) - strgraph.count('with {}'.format(kind))
362  perform_assert(graph, kind, count, num_kind_nodes,
363  consider_subgraphs)
364  return
365 
366  nodes = [node for node in graph.nodes()
367  if node.kind() == kind]
368  perform_assert(graph, kind, len(nodes), num_kind_nodes,
369  consider_subgraphs)
370 
371  def assertExpectedONNXGraph(self, trace, *args, **kwargs):
372  torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
373  self.assertExpectedGraph(trace, *args, **kwargs)
374 
375  def assertExpectedGraph(self, trace, *args, **kwargs):
376  if isinstance(trace, torch._C.Graph):
377  graph = trace
378  else:
379  graph = trace.graph()
380 
381  torch._C._jit_pass_lint(graph)
382  torch._C._jit_pass_dce(graph)
383  torch._C._jit_pass_lint(graph)
384  graph = torch._C._jit_pass_canonicalize(graph)
385  torch._C._jit_pass_lint(graph)
386  self.assertExpected(str(graph), *args, **kwargs)
387 
388  def run_pass(self, name, trace):
389  if isinstance(trace, torch._C.Graph):
390  graph = trace
391  set_graph = False
392  else:
393  set_graph = True
394  graph = trace.graph()
395 
396  torch._C._jit_pass_lint(graph)
397  result = getattr(torch._C, '_jit_pass_' + name)(graph)
398  if result is not None:
399  graph = result
400  torch._C._jit_pass_lint(graph)
401 
402  if set_graph:
403  trace.set_graph(graph)
404  return graph
405 
    def checkScript(self,
                    script,
                    inputs,
                    optimize=True,
                    outputs=None,
                    name='func',
                    capture_output=False,
                    frames_up=1,
                    check_expected=False):
        """Compile ``script`` with TorchScript and check its result against ``outputs``.

        Args:
            script: TorchScript source (str) or a Python function. A function
                is first run eagerly to produce the reference outputs. Its
                source is then checked through the string frontend (recursive
                call), and finally it is compiled with ``torch.jit.script``.
            inputs: positional arguments for both the reference and compiled calls.
            optimize: forwarded to the CompilationUnit / ``torch.jit.script``.
            outputs: reference outputs. They are used as given only when
                ``script`` is a str; for a function they are overwritten by the
                eager result.
            name: attribute fetched from the CompilationUnit when ``script`` is a str.
            capture_output: capture stdout around the calls. The compiled run's
                stdout is compared with the expect file (skipped on Windows).
            frames_up: forwarded as ``_frames_up`` to the CompilationUnit.
                NOTE(review): presumably selects the caller frame used to
                resolve free names; the recursive call passes 2 to account for
                the extra checkScript frame -- confirm.
            check_expected: also compare ``ge.graph`` with an expect file.

        Returns:
            The compiled function ``ge``.
        """
        if isinstance(script, str):
            # String frontend: compile the source text and fetch the named function.
            cu = torch.jit.CompilationUnit(script, optimize, _frames_up=frames_up)
            ge = getattr(cu, name)
        else:
            # Eager reference run of the Python function.
            if capture_output:
                with self.capture_stdout() as captured:
                    outputs = script(*inputs)
            else:
                outputs = script(*inputs)
            # Check the string frontend first
            source = textwrap.dedent(inspect.getsource(script))
            self.checkScript(
                source,
                inputs,
                optimize,
                outputs,
                script.__name__,
                capture_output,
                frames_up=2,
                check_expected=check_expected)
            # Continue checking the Python frontend
            ge = torch.jit.script(script, optimize, _frames_up=1)

        if capture_output:
            with self.capture_stdout() as captured:
                outputs_ge = ge(*inputs)
            if not WINDOWS:
                # captured[0] holds the text printed by the compiled call.
                self.assertExpected(captured[0], subname='stdout')
        else:
            outputs_ge = ge(*inputs)
        self.assertEqual(outputs, outputs_ge)

        if check_expected:
            self.assertExpectedGraph(ge.graph)

        return ge
451 
    def checkTrace(self, func, reference_tensors, input_tensors=None,
                   optimize=True, drop=None, allow_unused=False, verbose=False,
                   inputs_require_grads=True, check_tolerance=1e-5, export_import=True,
                   _force_outplace=False):
        """Trace ``func`` and compare the traced executor against eager execution.

        Outputs are compared three ways: with no grad, with first-order
        gradients, and with second-order (grad-of-grad) gradients of a
        weighted sum of the outputs with respect to the inputs.

        Args:
            func: Python callable to trace, or a ``torch._C.Graph`` that is
                wrapped directly in a GraphExecutor.
            reference_tensors: inputs used to run both versions.
            input_tensors: example inputs for tracing; defaults to reference_tensors.
            optimize: forwarded to ``torch.jit.trace`` / the GraphExecutor.
            drop: if set, the last ``drop`` outputs are excluded from the
                gradient sums, which exercises unused outputs.
            allow_unused: forwarded to ``torch.autograd.grad``.
            verbose: print the traced graph.
            inputs_require_grads: when False, all gradient checks are skipped.
            check_tolerance: forwarded to ``torch.jit.trace``.
            export_import: round-trip the traced module through save/load first.
            _force_outplace: forwarded to ``torch.jit.trace``.

        Returns:
            The traced (and, if ``export_import``, re-imported) executor.
        """
        # TODO: check gradients for parameters, not just inputs
        def allSum(vs):
            # drop allows us to remove some values from ever being used
            # to test unused outputs
            if drop is not None:
                vs = vs[:-drop]
            # we don't want all the grad for all the outputs to be the same
            # so we multiply each by a constant
            return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None)
        if input_tensors is None:
            input_tensors = reference_tensors

        nograd_inputs = reference_tensors
        if inputs_require_grads:
            # Fresh leaf copies so gradients are taken w.r.t. these tensors.
            recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
        else:
            recording_inputs = reference_tensors

        # NOTE(review): when func is a torch._C.Graph it is still called as
        # func(*inputs) below; confirm Graph objects are callable or that this
        # path is unused.
        if isinstance(func, torch._C.Graph):
            ge = torch._C.GraphExecutor(func, optimize)
        else:
            ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
                                 _force_outplace=_force_outplace)

        if export_import:
            ge = self.getExportImportCopy(ge)

        if verbose:
            print(ge.graph)

        # test no gradients case
        outputs = func(*nograd_inputs)
        outputs_ge = ge(*nograd_inputs)
        self.assertEqual(outputs, outputs_ge)

        # test single grad case
        outputs = func(*recording_inputs)
        if inputs_require_grads:
            grads = torch.autograd.grad(allSum(outputs), recording_inputs,
                                        allow_unused=allow_unused)

        outputs_ge = ge(*recording_inputs)
        if inputs_require_grads:
            grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
                                           allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_ge)
        if inputs_require_grads:
            self.assertEqual(grads, grads_ge)

        # test the grad grad case

        outputs = func(*recording_inputs)
        l1 = allSum(outputs)
        if inputs_require_grads:
            # create_graph=True so the gradients themselves can be differentiated.
            grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
                                        allow_unused=allow_unused)
        if inputs_require_grads:
            l2 = (allSum(grads) * l1)
            grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)

        if inputs_require_grads:
            # New leaf inputs for the traced run, so its graph is independent
            # of the eager one built above.
            recording_inputs = [Variable(t, requires_grad=True)
                                for t in reference_tensors]

        outputs_ge = ge(*recording_inputs)
        l1_ge = allSum(outputs_ge)
        if inputs_require_grads:
            grads_ge = torch.autograd.grad(
                l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)

        if inputs_require_grads:
            l2_ge = (allSum(grads_ge) * l1_ge)
            grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_ge)
        if inputs_require_grads:
            self.assertEqual(grads, grads_ge)
            # Second-order grads are compared with a looser tolerance; inputs
            # that are unused on both sides (None) are skipped.
            for g2, g2_ge in zip(grads2, grads2_ge):
                if g2 is None and g2_ge is None:
                    continue
                self.assertTrue(torch.allclose(g2, g2_ge, atol=8e-4, rtol=8e-4))

        return ge
539 
540  def createScriptModuleFromGraph(self, trace):
541  graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
543  m._create_method_from_graph("forward", graph)
544  return m
545 
546  def assertExportImport(self, trace, inputs):
547  m = self.createScriptModuleFromGraph(trace)
548  self.assertExportImportModule(m, inputs)
549 
550  def assertExportImportModule(self, m, inputs):
551  m_import = self.getExportImportCopy(m)
552  self.assertEqual(self.runAndSaveRNG(m.forward, inputs),
553  self.runAndSaveRNG(m_import.forward, inputs))
554 
555  def runAndSaveRNG(self, func, inputs, kwargs=None):
556  kwargs = kwargs if kwargs else {}
557  with freeze_rng_state():
558  results = func(*inputs, **kwargs)
559  return results
560 
561 
# has to be at top level or Pickle complains
class FooToPickle(torch.nn.Module):
    """Plain nn.Module whose ``bar`` child is a ScriptModule.

    Used by test_model_save_error to check that torch.save refuses to pickle
    it. It must be defined at module top level so pickle can find the class.
    """
    def __init__(self):
        super(FooToPickle, self).__init__()
        child = torch.jit.ScriptModule()
        self.bar = child
567 
568 
570 
571  @unittest.skip("Requires a lot of RAM")
572  def test_big(self):
574  gig = int(1024 * 1024 * 1024 / 4)
575  # a small tensor in the first 4GB
576  m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float))
577  # a large tensor in the first 4GB that ends outside of it
578  m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float))
579  # a small tensor in >4GB space
580  m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float))
581  # s large tensor in the > 4GB space
582  m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float))
583 
584  m2 = self.getExportImportCopy(m)
585 
586  self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
587 
588  def test_simple(self):
589  x = torch.tensor([0.4], requires_grad=True)
590  y = torch.tensor([0.7], requires_grad=True)
591 
592  def f(x, y):
593  return torch.sigmoid(torch.tanh(x * (x + y)))
594 
595  self.checkTrace(f, (x, y))
596 
597  def test_restore_device(self):
598  # main purpose is checking map_location works
600  cpu_device_str = 'cpu'
601  m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
602  device=cpu_device_str))
603  m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
604  device=cpu_device_str))
605  m2 = self.getExportImportCopy(m)
606  self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
607  self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
608  self.assertFalse(m2.p0.is_cuda)
609  self.assertFalse(m2.b0.is_cuda)
610 
611  def test_model_save_error(self):
612  with TemporaryFileName() as fname:
613  with self.assertRaisesRegex(pickle.PickleError, "not supported"):
614  torch.save(FooToPickle(), fname)
615 
616  def test_single_tuple_trace(self):
617  x = torch.tensor(2.)
618 
619  def f2(x):
620  return (x,)
621  jit_f2 = torch.jit.trace(f2, x)
622  assert f2(x) == jit_f2(x) # fails
623 
624  @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
625  def test_restore_device_cuda(self):
626  class MyModule(torch.jit.ScriptModule):
627  def __init__(self):
628  super(MyModule, self).__init__(False)
629  self.register_buffer('b0', torch.randn(1, 3))
630  self.p0 = nn.Parameter(torch.randn(2, 3))
631 
632  @torch.jit.script_method
633  def forward(self, x):
634  return x + self.b0 + self.p0
635 
636  m = MyModule()
637  m.cuda(torch.cuda.device_count() - 1)
638  cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)
639 
640  self.assertTrue(m.p0.is_cuda)
641  self.assertTrue(m.b0.is_cuda)
642 
643  # restore to the saved devices
644  m2 = self.getExportImportCopy(m)
645  self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
646  self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
647  self.assertEqual(str(m2.p0.device), cuda_device_str)
648  self.assertEqual(str(m2.b0.device), cuda_device_str)
649 
650  # restore all to cpu using string
651  cpu_device_str = 'cpu'
652  m3 = self.getExportImportCopy(m, map_location=cpu_device_str)
653  self.assertEqual(str(m3.p0.device), cpu_device_str)
654  self.assertEqual(str(m3.b0.device), cpu_device_str)
655 
656  # restore all to first gpu using device
657  m4 = self.getExportImportCopy(
658  m3, map_location=torch.device('cuda:0'))
659  self.assertEqual(str(m4.p0.device), 'cuda:0')
660  self.assertEqual(str(m4.b0.device), 'cuda:0')
661 
662  # compute and compare the results
663  input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1)
664  origin_result = m(input)
665  self.assertEqual(origin_result, m2(input))
666  self.assertEqual(origin_result, m3(input.cpu()))
667  self.assertEqual(origin_result, m4(input.cuda(0)))
668 
669  @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
670  def test_restore_shared_storage_on_cuda(self):
671  whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
673  m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
674  m.register_buffer('b0', whole_tensor.narrow(0, 3, 1))
675  m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
676  self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
677  self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
678  self.assertTrue(m2.p0.is_cuda)
679  self.assertTrue(m2.b0.is_cuda)
680  self.assertTrue(m2.p0.is_shared())
681  self.assertTrue(m2.b0.is_shared())
682  self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())
683 
684  def test_typeas_trace_check(self):
685  a = torch.tensor([0.4], requires_grad=True)
686  b = torch.tensor([0.7], requires_grad=True)
687 
688  def f(x, y):
689  return x.type_as(y)
690 
691  trace = torch.jit.trace(f, (a, b))
692 
693  def test_peephole(self):
694  a = torch.tensor([0.4])
695  b = torch.tensor([0.7])
696  c = torch.tensor([0], dtype=torch.int32)
697 
698  def f(x, y):
699  return x.type_as(y)
700 
701  tf = torch.jit.trace(f, (a, b))
702  FileCheck().check("type_as").run(str(tf.graph))
703  self.run_pass('peephole', tf.graph)
704  FileCheck().check_not("type_as").run(str(tf.graph))
705  tf2 = torch.jit.trace(f, (a, c))
706  s = str(tf2.graph)
707  self.run_pass('peephole', tf2.graph)
708  self.assertEqual(s, str(s))
709 
710  def test_peephole_dynamic(self):
711  def f(x, y):
712  return x.type_as(y)
713 
714  fn = torch.jit.script(f)
715  s = str(fn.graph)
716  torch._C._jit_pass_peephole(fn.graph)
717  self.assertEqual(s, str(fn.graph))
718 
719  @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
720  def test_peephole_cuda(self):
721  a = torch.tensor([0.4], device='cpu')
722  b = torch.tensor([0.7], device='cuda')
723  c = torch.tensor([0.7], device='cuda')
724 
725  def f(x, y):
726  return x.type_as(y)
727 
728  trace = torch.jit.trace(f, (a, c))
729  s = str(trace.graph)
730  self.run_pass('peephole', trace.graph)
731  self.assertEqual(s, str(trace.graph))
732  trace = torch.jit.trace(f, (b, c))
733  self.run_pass('peephole', trace.graph)
734  self.assertTrue(len(list(trace.graph.nodes())) == 0)
735 
736  def test_index(self):
737  x = torch.tensor([0.4], requires_grad=True)
738  y = torch.tensor([0], dtype=torch.int64)
739 
740  def fn(x, y):
741  return x[y]
742 
743  fn_traced = torch.jit.trace(fn, (x, y,))
744 
745  self.assertEqual(fn(x, y), fn_traced(x, y))
746 
747  def test_disabled(self):
748  torch.jit._enabled = False
749  try:
750  def f(x, y):
751  return x + y
752 
753  self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f)
754  self.assertIs(torch.jit.script(f), f)
755 
756  class MyModule(torch.jit.ScriptModule):
757  @torch.jit.script_method
758  def method(self, x):
759  return x
760 
761  # XXX: Unfortunately ScriptModule won't simply become Module now,
762  # because that requires disabling the JIT at startup time, which
763  # we can't do in here.
764  # We need to or those two conditions to make it work with all versions of Python
765  self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method))
766  finally:
767  torch.jit._enabled = True
768 
769  def test_train_eval(self):
770  class Sub(nn.Module):
771  def forward(self, input):
772  if self.training:
773  return input
774  else:
775  return -input
776 
777  class MyModule(torch.jit.ScriptModule):
778  def __init__(self, module):
779  super(MyModule, self).__init__()
780  self.module = module
781 
782  @torch.jit.script_method
783  def forward(self, input):
784  return self.module(input) + 1
785 
786  m = MyModule(Sub())
787  input = torch.rand(3, 4)
788  self.assertEqual(input + 1, m(input))
789  m.eval()
790  self.assertEqual(-input + 1, m(input))
791 
792  # test batchnorm and dropout train/eval
793  input = torch.randn(6, 10)
794  batchnorm = nn.BatchNorm1d(10)
795  dropout = nn.Dropout(p=0.2)
796 
797  m_batchnorm = MyModule(batchnorm)
798  self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
799  batchnorm.eval()
800  m_batchnorm.eval()
801  self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
802 
803  m_dropout = MyModule(dropout)
804  dropout.eval()
805  m_dropout.eval()
806  self.assertEqual(dropout(input) + 1, m_dropout(input))
807 
808  def test_diff_subgraph_clones_constants(self):
809  @torch.jit.script
810  def f(x, y):
811  return x + x + y + x + y + x + y + x + y + x
812 
813  def count_constants(graph):
814  return sum(node.kind() == 'prim::Constant' for node in graph.nodes())
815 
816  graph = f.graph.copy()
817  self.run_pass('cse', graph)
818  self.run_pass('create_autodiff_subgraphs', graph)
819  nodes = list(graph.nodes())
820  self.assertEqual(count_constants(graph), 1)
821  self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1)
822 
823  # Backwards tracing was broken for indexing by a constant,
824  # because it's internally implemented using as_strided,
825  # and we attempted to trace its derivative (which is not
826  # currently supported.) It currently works because
827  # slice() is now not marked as traceable.
828  def test_index_constant(self):
829  x = torch.tensor([0.4], requires_grad=True)
830 
831  def fn(x):
832  return x[0]
833 
834  def run(f):
835  y = f(x)
836  grad = torch.autograd.grad(y, x)[0].clone()
837  return y, grad
838 
839  traced_fn = torch.jit.trace(fn, torch.ones(1))
840  self.assertEqual(run(fn), run(traced_fn))
841 
842  def test_scopes(self):
843  x = torch.tensor([0.4], requires_grad=True)
844  y = torch.tensor([0.7], requires_grad=True)
845 
846  def f(x, y):
847  out = x + y
848  with torch.jit.scope('Foo'):
849  out = x * out
850  with torch.jit.scope('Bar'):
851  out = torch.tanh(out)
852  out = torch.sigmoid(out)
853  return out
854 
855  self.checkTrace(f, (x, y))
856 
857  def test_scopes_intermediate_node(self):
858  class Net(nn.Module):
859  def forward(self, x):
860  return F.log_softmax(x, dim=0)
861 
862  net = Net()
863  t = torch.ones(2, requires_grad=True)
864  trace, outputs, inputs = torch.jit.get_trace_graph(net, (t,), return_inputs=True)
865  self.assertEqual(outputs, self.createScriptModuleFromGraph(trace)(*inputs))
866  self.assertExportImport(trace, (t,))
867  torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
868  FileCheck().check("onnx::LogSoftmax").check("scope: Net").run(str(trace))
869 
870  def test_scopes_identity_node(self):
871 
872  class Net(nn.Module):
873 
874  def __init__(self):
875  super(Net, self).__init__()
876  self.features = nn.Sequential(
877  nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
878  nn.ReLU(inplace=True),
879  nn.MaxPool2d(kernel_size=3, stride=2),
880  )
881 
882  def forward(self, x):
883  x = self.features(x)
884  return x
885 
886  model = Net()
887 
888  t = torch.ones(1, 3, 227, 227, requires_grad=True)
889 
890  with torch.onnx.set_training(model, False):
891  trace, _ = torch.jit.get_trace_graph(model, (t,))
892 
893  self.assertExportImport(trace, (t,) + tuple(model.parameters()))
894  torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
895  FileCheck().check("Net/Sequential[features]/Conv2d[0]").check("ReLU").check("MaxPool").run(str(trace))
896 
897  def test_canonicalize_tensor_iterator(self):
898  x = torch.randn(4, 4)
899 
900  def f(x):
901  x = x + 2
902  x = x - 4
903  x = x * 6
904  x = x / 8
905  return x
906 
907  traced = torch.jit.trace(f, (x,))
908  f(x)
909  graph = traced.graph_for(x)
910  # There should be 4 int constants for the right sides of operators, plus one
911  # for the alpha argument for add and sub
912  self.assertTrue(str(traced.graph_for(x)).count(': int = prim::Constant') == 5)
913 
914  # TODO: adapt this test to check that GraphExecutor treats them differently
915  @unittest.skip("Need to be adjusted to Graph Executor")
917  """Different arg configurations should trigger different traces"""
918  x = Variable(torch.FloatTensor(4, 4).uniform_())
919  x_double = Variable(x.data.double())
920  x_grad = Variable(x.data.clone(), requires_grad=True)
921  y = Variable(torch.randn(4))
922 
923  configurations = [
924  (x,),
925  (x_double,),
926  (x_grad,),
927  (y,),
928  ([x, x],),
929  ([x, y],),
930  ]
932  x_cuda = Variable(x.data.cuda())
933  configurations += [
934  (x_cuda,),
935  ([x, x_cuda],),
936  ([x_cuda, x],),
937  ([[x_cuda, x]],),
938  ]
939  if torch.cuda.device_count() > 1:
940  x_cuda_1 = Variable(x.data.cuda(1))
941  configurations += [
942  (x_cuda_1,),
943  ([x_cuda, x_cuda_1],),
944  ]
945 
946  @torch.jit.compile(nderivs=0)
947  def fn(*args):
948  in_vars, _ = torch._C._jit_flatten(args)
949  return in_vars[0] + 1
950 
951  for i, config in enumerate(configurations):
952  self.assertFalse(fn.has_trace_for(*config))
953  fn(*config)
954  self.assertTrue(fn.has_trace_for(*config))
955  for unk_config in configurations[i + 1:]:
956  self.assertFalse(fn.has_trace_for(*unk_config))
957  self.assertEqual(fn.hits, 0)
958 
def test_cse(self):
    """After CSE the repeated (x + y) products and tanh(w) calls are deduplicated."""
    x = torch.tensor([0.4, 0.3], requires_grad=True)
    y = torch.tensor([0.7, 0.5], requires_grad=True)

    def fn(x, y):
        w = (x + y) * (x + y) * (x + y)
        t = torch.tanh(w) + torch.tanh(w)
        z = (x + y) * (x + y) * (x + y) + t
        return z

    trace, _ = torch.jit.get_trace_graph(fn, (x, y))
    self.run_pass('cse', trace)
    exact = True
    (FileCheck()
     .check_count("add", 1)
     .check_count("mul", 2, exact)
     .check_count("tanh", 1, exact)
     .check_count("add", 2, exact)
     .check_next("return")
     .run(str(trace)))

    self.assertExportImport(trace, (x, y))
977 
def test_recursive_cse(self):
    """CSE also deduplicates expressions inside an if-block against the enclosing scope.

    After the pass, no aten::add / aten::gt may follow the first "block" in
    the printed graph. That means the `x + y` inside the if reuses the value
    computed for the condition.
    """
    # The original created two example tensors here that were never used;
    # scripting needs no example inputs, so they are dropped.
    def fn(x, y):
        z = x
        if bool(x + y > x):
            z = x + y
        return z

    graph = torch.jit.script(fn).graph
    self.run_pass('cse', graph)
    FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
991 
def test_shape_analysis_broadcast(self):
    """Complete shape analysis infers the broadcast shape of a + b."""
    def broadcast(a, b):
        return a + b

    lhs = torch.randn(3, 1, 5, requires_grad=True)
    rhs = torch.randn(4, 1, 8, 5, requires_grad=True)

    scripted_graph = torch.jit.script(broadcast).graph
    torch._C._jit_pass_complete_shape_analysis(scripted_graph, (lhs, rhs), False)
    # (3, 1, 5) broadcast against (4, 1, 8, 5) yields (4, 3, 8, 5).
    FileCheck().check("Double(4, 3, 8, 5)").run(str(scripted_graph))
1002 
# TODO: update verify to work with GraphExecutors
@unittest.skip("verify needs to be updated to work with GraphExecutors")
def test_verify(self):
    """(Skipped) run torch.jit.verify on a legacy-compiled two-output function."""
    x = torch.tensor([0.4], requires_grad=True)
    y = torch.tensor([0.7], requires_grad=True)

    @torch.jit.compile
    def f(x, y):
        z = torch.sigmoid(x * (x + y))
        w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
        return z, w

    def product_loss(z, w):
        return z * w

    torch.jit.verify(f, (x, y), loss_fn=product_loss, devices=[])
1016 
@suppress_warnings
def test_constant(self):
    """Trace a function that builds a tensor literal internally and recheck it on a second input."""
    x = torch.randn(2, 2, requires_grad=True)

    def f(x):
        return x.matmul(torch.diag(torch.tensor([2., 2.])))

    second_input = torch.ones(2, 2, requires_grad=True)
    self.checkTrace(f, (x,), (second_input,))
1025 
def test_legacy_fail(self):
    """Tracing through an old-style (instance-method) Function raises an error naming it."""
    class MyLegacyFn(Function):
        def forward(self, x):
            return x

        def backward(self, grad_output):
            return grad_output

    x = torch.tensor([0.], requires_grad=True)

    def call_legacy(inp):
        return MyLegacyFn()(inp)

    with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
        torch.jit.get_trace_graph(call_legacy, (x,))
1037 
def test_inplace_transplant(self):
    """The trace keeps one clone followed by two in-place add_ nodes right before return."""
    x = torch.tensor([0.], requires_grad=True)

    def fn(x):
        y = x.clone()
        y.add_(2)
        y.add_(3)
        return y

    trace, _ = torch.jit.get_trace_graph(fn, (x,))
    self.run_pass('dce', trace)
    (FileCheck()
     .check_count("aten::clone", 1, exactly=True)
     .check_count("aten::add_", 2, exactly=True)
     .check_next("return")
     .run(str(trace)))
    self.assertExportImport(trace, (x,))
1053 
def test_inplace_flags(self):
    """Traced custom Functions carry an 'inplace' attribute that is set only for the mark_dirty one."""
    class InplaceFn(Function):
        @staticmethod
        def forward(ctx, x):
            ctx.mark_dirty(x)
            return x.add_(1)

        @staticmethod
        def backward(ctx, go):
            return go

    class RegularFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x.add(1)

        @staticmethod
        def backward(ctx, go):
            return go

    x = torch.tensor([0.], requires_grad=True)

    def fn(x):
        y = RegularFn.apply(x)
        y = InplaceFn.apply(y)
        y = InplaceFn.apply(y)
        y = RegularFn.apply(y)
        return y

    trace, _ = torch.jit.get_trace_graph(fn, (x,), _force_outplace=True)
    self.run_pass('dce', trace)
    nodes = list(trace.graph().nodes())
    for node in nodes:
        self.assertTrue(node.hasAttribute('inplace'))
    # Expected flags follow the call order in fn: Regular, Inplace, Inplace, Regular.
    expected_flags = [False, True, True, False]
    for node, expected in zip(nodes, expected_flags):
        self.assertEqual(node.i('inplace'), expected)
1091 
def test_inplace_check(self):
    """A GraphExecutor built with _force_outplace=True raises on an in-place custom Function."""
    class MyInplaceFn(Function):
        @staticmethod
        def forward(ctx, x):
            x.add_(1)
            ctx.mark_dirty(x)
            return x

        @staticmethod
        def backward(ctx, grad):
            return grad

    def fn(x):
        return MyInplaceFn.apply(x)

    sample = torch.randn(5, 5)
    executor = torch._C.GraphExecutor(fn, (sample,), lambda var: '', _force_outplace=True)
    with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
        executor(sample)
1111 
def do_trace_size(self, requires_grad):
    """Trace a view whose shape comes from x.shape / x.size() and check it on a differently sized input.

    Args:
        requires_grad: forwarded to both example tensors.
    """
    def fn(x):
        return x.view(x.shape[1] * 2, x.size(0), 2)

    traced_input = torch.randn(5, 2, 4, requires_grad=requires_grad)
    other_input = torch.randn(4, 8, 4, requires_grad=requires_grad)

    # Check that it behaves as expected
    traced_fn = torch.jit.trace(fn, traced_input)
    for sample in (other_input, traced_input):
        self.assertEqual(traced_fn(sample), fn(sample))
1123 
def test_trace_size(self):
    """Size-dependent tracing without gradients."""
    self.do_trace_size(requires_grad=False)
1126 
# Covers the separate graph-executor path taken when inputs require
# gradients and the traced code depends on tensor sizes.
def test_trace_size_with_grad(self):
    """Size-dependent tracing with gradient-requiring inputs."""
    self.do_trace_size(requires_grad=True)
1131 
def test_trace_casts(self):
    """Every cast variant traces to exactly one aten::to node and matches eager output."""
    casts = [
        lambda x: x.byte(),
        lambda x: x.float(),
        lambda x: x.cpu(),
        lambda x: x.to(device='cpu'),
        lambda x: x.to(dtype=torch.int64),
        lambda x: x.to(device='cpu', dtype=torch.float),
        lambda x: x.to(x)
    ]

    def expect_single_cast(traced):
        cast_nodes = [node for node in traced.graph.nodes() if node.kind() == 'aten::to']
        self.assertEqual(len(cast_nodes), 1)

    for cast in casts:
        traced_cast = torch.jit.trace(cast, torch.randn(2, 2))
        expect_single_cast(traced_cast)
        sample = torch.randn(2, 2)
        self.assertEqual(traced_cast(sample), cast(sample))

    def to_tensor(x, y):
        return x.to(y)

    traced_to_tensor = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
    expect_single_cast(traced_to_tensor)
    src, like = torch.randn(2, 2), torch.randn(1, 10)
    self.assertEqual(traced_to_tensor(src, like), to_tensor(src, like))
1159 
def test_trace_warn(self):
    """Tracing emits one warning per Python-value escape, in source order."""
    def fn(x):
        int(x)  # Warning 1.
        y = x * 1
        if y:  # Warning 2.
            pass
        q = [x, x * 4]
        z = q[y]  # Warning 3.
        float(z)  # Warning 4.
        z.tolist()  # Warning 5.
        z.numpy()  # Warning 6.
        for _ in torch.ones(4, 4):  # Warning 7.
            pass
        return z + 4

    with warnings.catch_warnings(record=True) as caught:
        torch.jit.trace(fn, torch.tensor([1]))
    messages = [str(w.message) for w in caught]

    expected_fragments = [
        'a Python integer',
        'a Python boolean',
        'a Python index',
        'a Python float',
        'a Python list',
        'a NumPy array',
        'Iterating over',
    ]
    self.assertEqual(len(messages), len(expected_fragments))
    for fragment, message in zip(expected_fragments, messages):
        self.assertIn(fragment, message)
1186 
def test_trace_tuple(self):
    """Nested tuple outputs produce exactly two TupleConstructs, the last one feeding return."""
    def fn(x, y):
        return x, (x * y[1], x * y[0])

    x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
    example = (x, y)
    traced_fn = torch.jit.trace(fn, example)
    self.assertEqual(traced_fn(*example), fn(*example))
    # should be a tuple nested within another tuple
    (FileCheck()
     .check_count("prim::TupleConstruct", 2, exactly=True)
     .check_next("return")
     .run(str(traced_fn.graph)))
    self.assertExportImport(traced_fn.graph, example)
1198 
def test_trace_random(self):
    """A traced torch.normal matches eager output when both start from the same RNG state."""
    def f(mean, std):
        return torch.normal(mean, std)

    traced = torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False)
    mean, std = torch.zeros(5, 5), torch.ones(5, 5)
    # fork_rng restores the RNG state on exit, so the traced call below
    # starts from the same state as the eager call inside the block.
    with torch.random.fork_rng(devices=[]):
        eager_out = f(mean, std)
    self.assertEqual(eager_out, traced(mean, std))
1209 
def test_trace_tensor_factory(self):
    """torch.ones inside a traced function is recorded as an op for several kwarg combinations."""
    def run(**kwargs):
        inputs_require_grads = kwargs.pop('inputs_require_grads', True)

        def fn(x):
            return x + torch.ones(2, 3, **kwargs)

        # `out=` applies only to the factory inside fn, not to building the input.
        input_kwargs = {key: value for key, value in kwargs.items() if key != 'out'}
        example = torch.ones(2, 3, **input_kwargs)
        self.checkTrace(fn, (example,), inputs_require_grads=inputs_require_grads)
        # check we recorded 'ones' and did not just record a constant
        traced = torch.jit.trace(fn, example)
        self.assertIn("ones", str(traced.graph))

    run()
    run(dtype=torch.int, inputs_require_grads=False)
    run(out=torch.tensor([]))
    if RUN_CUDA:
        run(device="cuda:0")
    if RUN_CUDA_MULTI_GPU:
        run(device="cuda:1")
1232 
def test_trace_indexed_assignment(self):
    """Trace a row assignment into a cloned tensor."""
    def stuff(x, y):
        x = x.clone()
        x[0] = y
        return x

    example = torch.rand(3, 4)
    new_row = example[0] + 1
    self.checkTrace(stuff, (example, new_row))
1240 
1241  # TODO: implement
1242  @unittest.expectedFailure
1244  """Check that outputs of traced functions retain the original structure and nesting"""
1245  def fn(x):
1246  return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
1247 
1248  self.checkTrace(fn, (torch.randn(2, 2),))
1249 
1250  # TODO: implement
1251  @unittest.expectedFailure
1253  """Check that inputs to traced functions are flattened"""
1254 
1255  def fn(x, t):
1256  y, z = t
1257  return x * y * z
1258 
1259  inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
1260  self.checkTrace(fn, inputs)
1261 
# TODO: adapt to a GraphExecutor test
@unittest.skip("Need to instrument GraphExecutors a bit more")
def test_flags(self):
    """(Skipped) gradients of a legacy-compiled fn stay consistent across requires_grad combinations."""
    # NOTE(review): this unpacks the two rows of a 2x2 tensor into x and y;
    # y is immediately replaced below.
    x, y = torch.randn(2, 2)
    y = Variable(torch.randn(2, 2))

    @torch.jit.compile
    def fn(x, y):
        return (x * x + y * y + x * y).sum()

    first_seen = {}
    for rx, ry in product((True, False), repeat=2):
        x.requires_grad = rx
        y.requires_grad = ry

        self.assertFalse(fn.has_trace_for(x, y))
        out = fn(x, y)

        self.assertFalse(fn.has_trace_for(x, y))
        for var, label, wants_grad in [(x, 'x', rx), (y, 'y', ry)]:
            if wants_grad:
                grad_var, = torch.autograd.grad(out, var, retain_graph=True)
                reference = first_seen.setdefault(label, grad_var)
                self.assertEqual(grad_var, reference)
        self.assertEqual(fn.has_trace_for(x, y), rx or ry)
1288 
def test_python_ir(self):
    """Rebuild a traced graph node by node through the Python IR bindings."""
    x = torch.tensor([0.4], requires_grad=True)
    y = torch.tensor([0.7], requires_grad=True)

    def doit(x, y):
        return torch.sigmoid(torch.tanh(x * (x + y)))

    trace, _ = torch.jit.get_trace_graph(doit, (x, y))
    self.run_pass('dce', trace)
    self.run_pass('canonicalize', trace)
    src = trace.graph()
    dst = torch._C.Graph()
    # Maps each value of `src` to its counterpart in `dst`.
    value_map = {}
    for value in src.inputs():
        value_map[value] = dst.addInput()
    for node in src.nodes():
        copied = dst.createClone(node, lambda v: value_map[v])
        dst.appendNode(copied)
        value_map.update(zip(node.outputs(), copied.outputs()))

    for value in src.outputs():
        dst.registerOutput(value_map[value])

    tensor_node = dst.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
    self.assertEqual(tensor_node.attributeNames(), ["a"])
    dst.appendNode(tensor_node)
    self.assertTrue(torch.equal(torch.ones(2, 2), tensor_node.t("a")))
    for node in src.nodes():
        self.assertTrue(dst.findNode(node.kind()) is not None)
1319 
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
@skipIfRocm
def test_cpp_cuda(self):
    """Run the C++ JIT tests, always calling tests_setup.shutdown() afterwards."""
    from cpp.jit import tests_setup
    tests_setup.setup()
    try:
        torch._C._jit_run_cpp_tests()
    finally:
        # Previously shutdown() was skipped when a C++ test raised, leaving
        # whatever setup() created in place for later tests.
        tests_setup.shutdown()
1328 
def test_batchnorm(self):
    """A ScriptModule rebuilt from an out-of-place BatchNorm2d trace reproduces its output."""
    x = torch.ones(2, 2, 2, 2)
    trace, outputs, inputs = torch.jit.get_trace_graph(
        nn.BatchNorm2d(2), x, _force_outplace=True, return_inputs=True)
    module = self.createScriptModuleFromGraph(trace)
    self.assertEqual(outputs, module(*inputs))
1335 
def test_dropout(self):
    """A ScriptModule rebuilt from a Dropout trace matches it when replayed from the same RNG state."""
    x = torch.ones(2, 2)
    with torch.random.fork_rng(devices=[]):
        trace, outputs, inputs = torch.jit.get_trace_graph(nn.Dropout(0.6), x, return_inputs=True)
    with torch.random.fork_rng(devices=[]):
        module = self.createScriptModuleFromGraph(trace)
        self.assertEqual(outputs, module(*inputs))
1343 
@unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA")
def test_dropout_cuda(self):
    """Scripted dropout matches eager dropout (output and gradient) on CUDA under a frozen RNG."""
    # Dropout AD is dispatched to _fused_dropout in CUDA case,
    # which is not included in TestJitGeneratedFunctional
    x = torch.ones(4, 4).cuda().requires_grad_()

    @torch.jit.script
    def func(x):
        return torch.nn.functional.dropout(x)

    def output_and_grad(dropout_fn):
        # Each run starts from the same frozen RNG state, so masks agree.
        with freeze_rng_state():
            result = dropout_fn(x)
            return result, torch.autograd.grad(result.sum(), x)

    out_ref, grad_ref = output_and_grad(torch.nn.functional.dropout)
    out, grad = output_and_grad(func)

    self.assertEqual(out, out_ref)
    self.assertEqual(grad, grad_ref)
1364 
def test_conv(self):
    """A ScriptModule rebuilt from a bias-free Conv2d trace reproduces its output."""
    x = torch.ones(20, 16, 50, 40)
    conv = nn.Conv2d(16, 13, 3, bias=False)
    trace, outputs, inputs = torch.jit.get_trace_graph(conv, x, return_inputs=True)
    module = self.createScriptModuleFromGraph(trace)
    self.assertEqual(outputs, module(*inputs))
1370 
def test_repeated_input(self):
    """Passing the same tensor twice still yields two distinct graph inputs."""
    def fn(a, b):
        return a + b

    ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
    distinct_inputs = set(ge.graph.inputs())
    self.assertTrue(len(distinct_inputs) == 2)
1378 
def test_repeated_output(self):
    """Returning the same value twice builds a tuple whose two inputs are the same value."""
    def fn(a, b):
        z = a + b
        return z, z

    ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
    tuple_output = next(ge.graph.outputs())
    first, second = list(tuple_output.node().inputs())[:2]
    self.assertTrue(first == second)
1388 
@skipIfNoTorchVision
def test_alexnet(self):
    """A CSE'd AlexNet trace rebuilt as a ScriptModule reproduces the traced output."""
    x = torch.ones(1, 3, 224, 224)
    model = torchvision.models.AlexNet()
    # Both runs are wrapped in fork_rng so any random ops (presumably
    # AlexNet's dropout) start from the same RNG state.
    with torch.random.fork_rng(devices=[]):
        trace, outputs, inputs = torch.jit.get_trace_graph(model, x, return_inputs=True)
    self.run_pass('cse', trace)
    module = self.createScriptModuleFromGraph(trace)
    with torch.random.fork_rng(devices=[]):
        self.assertEqual(outputs, module(*inputs))
1399 
def test_inplace_copy(self):
    """A traced copy_ into a fresh tensor survives DCE, rebuilding and export/import."""
    x = torch.randn(4, 4, requires_grad=True)

    def f(x):
        out = Variable(torch.zeros(x.size()))
        out.copy_(x)
        return out

    trace, outputs, inputs = torch.jit.get_trace_graph(f, (x, ), return_inputs=True)
    self.run_pass('dce', trace)
    module = self.createScriptModuleFromGraph(trace)
    self.assertEqual(outputs, module(*inputs))
    self.assertExportImport(trace, (x,))
1413 
def test_shared_param(self):
    """A parameter registered under two names becomes a single graph input."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.b = self.a = nn.Parameter(torch.randn(2, 2))

        def forward(self, x):
            return x * self.a + self.b

    module = MyModule()
    trace, _ = torch.jit.get_trace_graph(module, (torch.randn(2, 2),))
    self.run_pass('dce', trace)
    # One input for x plus one for the shared parameter.
    graph_inputs = list(trace.graph().inputs())
    self.assertEqual(len(graph_inputs), 2)
    FileCheck().check("mul").check("add").run(str(trace))
1428 
def test_trace_c10_ops(self):
    """Trace a model calling the caffe2 GenerateProposals op and round-trip it through export/import."""
    class MyModel(torch.nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()

        def forward(self, scores, bbox_deltas, im_info, anchors):
            a, b = torch.ops._caffe2.GenerateProposals(
                scores, bbox_deltas, im_info, anchors,
                2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
            )
            return a, b

    model = MyModel()
    num_anchors, height, width = 4, 10, 8
    img_count = 3
    scores = torch.ones(img_count, num_anchors, height, width, dtype=torch.float32)
    delta_count = img_count * 4 * num_anchors * height * width
    bbox_deltas = torch.linspace(0, 10, steps=delta_count, dtype=torch.float32).view(
        img_count, 4 * num_anchors, height, width)
    im_info = torch.ones(img_count, 3, dtype=torch.float32)
    anchors = torch.ones(num_anchors, 4, dtype=torch.float32)
    inputs = (scores, bbox_deltas, im_info, anchors)
    traced_model = torch.jit.trace(model, inputs)
    self.assertEqual(traced_model(*inputs), model(*inputs))
    self.assertExportImport(traced_model.graph, inputs)
1455 
def test_nested_inplace(self):
    """An in-place F.threshold is traced as threshold_ and survives rebuilding and export/import."""
    x = torch.randn(2, 2)

    def threshold_inplace(x):
        return F.threshold(x, 0, 0, inplace=True)

    trace, outputs, inputs = torch.jit.get_trace_graph(threshold_inplace, (x, ), return_inputs=True)
    module = self.createScriptModuleFromGraph(trace)
    self.assertEqual(outputs, module(*inputs))
    FileCheck().check("threshold_").run(str(trace))
    self.assertExportImport(trace, (x,))
1464 
def run_ge_tests(self, optimize, use_cuda):
    """Run a battery of small checkTrace cases through the graph executor.

    Args:
        optimize: forwarded to every checkTrace call.
        use_cuda: when true, all example tensors are moved to CUDA.
    """
    def make_input(*shape):
        tensor = torch.rand(*shape).float()
        return tensor.cuda() if use_cuda else tensor

    self.checkTrace(lambda a, b: a * b + b,
                    [make_input(1), make_input(1)], [make_input(2, 3), make_input(2, 3)],
                    optimize=optimize)
    # trivial identity
    self.checkTrace(lambda a, b: (b, a), [make_input(1), make_input(1)], optimize=optimize)

    def foo(a):
        t = a * a
        return t * t, 4 * t

    self.checkTrace(foo, [make_input(1)], optimize=optimize)
    # unused input
    self.checkTrace(lambda a, b: a * a, [make_input(1), make_input(1)],
                    optimize=optimize, allow_unused=True)
    # test outputs that do not get used in grad
    self.checkTrace(foo, [make_input(1)], drop=1, optimize=optimize)
    # test autograd fallback
    self.checkTrace(lambda a, b: a * b / (a - 2 * b) + b,
                    [make_input(1), make_input(1)],
                    optimize=optimize)
1492 
def test_ge_unoptimized(self):
    """Graph-executor battery on CPU without optimization."""
    self.run_ge_tests(optimize=False, use_cuda=False)
1495 
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
@enable_cpu_fuser
def test_ge_optimized(self):
    """Graph-executor battery on CPU with optimization (CPU fuser enabled)."""
    self.run_ge_tests(optimize=True, use_cuda=False)
1500 
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_ge_cuda(self):
    """Graph-executor battery on CUDA with optimization."""
    self.run_ge_tests(optimize=True, use_cuda=True)
1505 
# more manual test of graph executor that can be used as a scratchpad
def test_ge(self):
    """First- and second-order grads through a GraphExecutor match eager autograd."""
    def foo(a, b):
        return a * b / (a - b) + b

    def grads_and_second_order(out, a, b):
        # First-order grads kept differentiable, then grads of a quadratic in them.
        da, db = torch.autograd.grad(out + 3, [a, b], create_graph=True)
        second = torch.autograd.grad(da * db + db * db, [da, db])
        return da, db, second

    example = (Variable(torch.rand(1)), Variable(torch.rand(1)))
    ge = torch._C.GraphExecutor(foo, example, lambda var: '')
    a = Variable(torch.rand(1), requires_grad=True)
    b = Variable(torch.rand(1), requires_grad=True)

    r, = ge(a, b)
    ge_da, ge_db, ge_second = grads_and_second_order(r, a, b)
    eager_da, eager_db, eager_second = grads_and_second_order(foo(a, b), a, b)

    self.assertEqual(ge_da, eager_da)
    self.assertEqual(ge_db, eager_db)
    self.assertEqual(ge_second, eager_second)
1528 
def test_trace_annotation(self):
    """A function decorated with _trace generalizes beyond its example input shape."""
    @_trace(torch.rand(1))
    def foo(a):
        return a + a + a

    x = torch.randn(5, 5)
    expected = x + x + x
    self.assertEqual(foo(x), expected)
1536 
def test_trace_script(self):
    """Tracing script functions that take a Tuple or a List of tensors reproduces their results."""
    @torch.jit.script
    def func1(x):
        # type: (Tuple[Tensor, Tensor]) -> Tensor
        return x[0] + x[1]

    @torch.jit.script
    def func2(x):
        # type: (List[Tensor]) -> Tensor
        return x[0] + x[1]

    a = torch.randn(5)
    b = torch.randn(5)
    pair = (a, b)

    for scripted in (func1, func2):
        expected = scripted(pair)
        traced = torch.jit.trace(scripted, (pair,))
        self.assertEqual(expected, traced(pair))
1560 
def test_einsum(self):
    """Traced and scripted einsum both keep an aten::einsum node and match eager output."""
    def outer(x, y):
        return torch.einsum('i,j->ij', (x, y))

    traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
    script = torch.jit.script(outer)
    # The original built this list but then looped over a duplicate literal.
    fns = [traced, script]
    x, y = torch.randn(10), torch.randn(2)
    for fn in fns:
        self.assertGraphContains(fn.graph, kind='aten::einsum')
        self.assertEqual(fn(x, y), outer(x, y))
1572 
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "calls .cuda()")
def test_traced_module_cuda(self):
    """Check attribute restrictions, device/dtype casts and state_dict round-trips on a traced module."""
    class Model(nn.Module):
        def __init__(self, num_features, num_layers):
            super(Model, self).__init__()
            self.num_layers = num_layers
            layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
                      for _ in range(num_layers)]
            self.submodule = nn.Sequential(*chain(*layers))

        def forward(self, x):
            for i in range(self.num_layers):
                x = self.submodule[i](x) + x
            return x

    model = Model(5, 3)
    x = torch.randn(2, 5)
    traced_model = torch.jit.trace(model, x)

    # The traced modules are missing some attributes the originals had (see the
    # in_features check below). Make sure we can still get their __repr__().
    # The original called model.__repr__(), which only exercised the untraced
    # module and so never tested this.
    traced_model.__repr__()

    # XXX: indexing sequentials is broken
    linear_submodule = next(iter(traced_model.submodule._modules.values()))

    # All attributes that aren't parameters should raise
    with self.assertRaises(AttributeError):
        linear_submodule.in_features
    linear_submodule.weight
    with self.assertRaises(RuntimeError):
        traced_model.asdf = 4
    linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
    with self.assertRaises(RuntimeError):
        del linear_submodule.weight

    # Submodules can't be called
    with self.assertRaises(RuntimeError):
        linear_submodule(x)

    # Type casts
    linear_submodule.cuda()
    traced_model.float().cuda()
    cuda_out = traced_model(x.float().cuda())
    traced_model.cpu()
    cpu_out = traced_model(x.float())
    self.assertEqual(cpu_out, cuda_out)
    traced_model.to('cuda')
    cuda_out = traced_model(x.float().cuda())
    traced_model.to('cpu')
    cpu_out = traced_model(x.float())
    self.assertEqual(cpu_out, cuda_out)
    traced_model.double()

    # state_dict + load_state_dict
    state = {k: v.clone() for k, v in traced_model.state_dict().items()}
    new_state = {k: v.clone().fill_(1) for k, v in state.items()}
    out = traced_model(x)
    traced_model.load_state_dict(new_state)
    out_ones = traced_model(x)
    traced_model.load_state_dict(state)
    out_state = traced_model(x)
    self.assertEqual(out, out_state)
    self.assertNotEqual(out, out_ones)
1638 
def test_export_no_reorder(self):
    """An export/import copy of an optimized trace gives bit-identical outputs and gradients."""
    def func(a, b):
        return a * b / (a - 2 * b) + b

    recording_inputs = [torch.tensor([0.55619788169860839844], dtype=torch.float32, requires_grad=True),
                        torch.tensor([0.25947844982147216797], dtype=torch.float32, requires_grad=True)]

    original = torch.jit.trace(func, recording_inputs, optimize=True)
    reimported = self.getExportImportCopy(original)

    def output_and_grads(module):
        out = module(*recording_inputs)
        return out, torch.autograd.grad(out, recording_inputs)

    out_original, grads_original = output_and_grads(original)
    out_reimported, grads_reimported = output_and_grads(reimported)
    # Exact equality on purpose: reordering float ops would change the low bits.
    self.assertTrue(out_original == out_reimported)
    self.assertTrue(grads_original == grads_reimported)
1656 
def test_python_function(self):
    """A _trace-decorated function calling a custom autograd Function runs on new input shapes."""
    class MyFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x + 1

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    @_trace(torch.zeros(2))
    def fn(x):
        return MyFn.apply(x + 2) + 3

    x = torch.tensor([1., 2., 3.])
    y = torch.randn(2, 2, requires_grad=True)
    for arg in (x, y):
        fn(arg)
1675 
def test_python_function_tup(self):
    """Same as test_python_function, but the custom Function returns a tuple."""
    class MyFn(Function):
        @staticmethod
        def forward(ctx, x):
            return x + 1, x - 1

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output, grad_output

    @_trace(torch.zeros(2))
    def fn(x):
        a, b = MyFn.apply(x + 2)
        return a + b + 3

    x = torch.tensor([1., 2., 3.])
    y = torch.randn(2, 2, requires_grad=True)
    for arg in (x, y):
        fn(arg)
1694 
def test_decompose_addmm(self):
    """canonicalize_ops rewrites addmm only when alpha/beta are constant 1.

    Fixes relative to the previous version:
      * neither helper was ever called, so the test checked nothing;
      * `doesnt_decompose` referenced `addm` (NameError) and its scripted fn
        had no return;
      * the helper bodies were swapped. The addmm whose alpha/beta come from
        int(alpha)/int(beta) is non-constant, so it can never be decomposed.
        That body belongs in `doesnt_decompose`.
    """
    def does_decompose():
        @torch.jit.script
        def addmm(mat, mat1, mat2):
            a = mat.addmm(mat1, mat2)
            b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
            return a + b

        mat = torch.randn(2, 2)
        mat1 = torch.randn(2, 4)
        mat2 = torch.randn(4, 2)

        out_ref = addmm(mat, mat1, mat2)
        self.run_pass('canonicalize_ops', addmm.graph)
        out_test = addmm(mat, mat1, mat2)
        self.assertEqual(out_ref, out_test)
        FileCheck().check_not("addmm").run(str(addmm.graph))

    def doesnt_decompose():
        @torch.jit.script
        def addmm(mat, mat1, mat2, alpha, beta):
            a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
            b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
            return a + b

        orig = str(addmm.graph)
        self.run_pass('canonicalize_ops', addmm.graph)
        self.assertTrue(orig == str(addmm.graph))

    does_decompose()
    doesnt_decompose()
1725 
def test_index_put(self):
    """Traced masked assignment (index_put_) matches eager on a fresh input."""
    ten = torch.zeros(3, 3)
    mask = torch.Tensor([[True, True, True],
                         [True, False, False],
                         [True, True, False]]).byte()

    def test_fn(ten, mask):
        ten[mask] = torch.ones(6)
        return ten

    traced_test_fn = torch.jit.trace(test_fn, (ten, mask))

    ten = torch.rand(3, 3)
    # test_fn mutates and returns its argument. Passing `ten` to both calls
    # made them return the very same tensor, so the comparison was vacuous.
    # Give each call its own copy.
    self.assertEqual(test_fn(ten.clone(), mask), traced_test_fn(ten.clone(), mask))
1740 
def test_sparse_tensors_error(self):
    """Script functions raise when a sparse tensor is passed in or produced."""
    def get_sparse():
        return torch.sparse.FloatTensor(2, 3)

    @torch.jit.script
    def sparse(inp):
        output = get_sparse()
        return output, inp

    calls = (
        lambda: sparse(get_sparse()),
        lambda: sparse(torch.tensor([1])),
    )
    for call in calls:
        with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
            call()
1755 
def test_tuple_specialization(self):
    """The specialized graph types each element of a tuple input as a DimensionedTensorType."""
    @torch.jit.script
    def f(t):
        # type: (Tuple[Tensor, Tensor]) -> Tensor
        x, y = t
        return x + y

    pair = torch.randn(2, 2), torch.randn(2, 2)
    f(pair)
    specialized = f.graph_for(pair)
    element_types = list(next(specialized.inputs()).type().elements())
    for elem_type in element_types:
        self.assertEqual(elem_type.kind(), 'DimensionedTensorType')
1769 
def test_constant_prop_simple(self):
    """Constant propagation folds 2 * 3 + 2 into a single constant 8 without changing results."""
    @torch.jit.script
    def constant_prop(input_int):
        # type: (int) -> int
        a = 2 * 3
        b = a + 2
        return b - input_int

    before = constant_prop(2)
    self.run_pass('constant_propagation', constant_prop.graph)
    after = constant_prop(2)
    self.assertEqual(before, after)
    printed = str(constant_prop.graph)
    self.assertTrue(not ("aten::add" in printed or "aten::mul" in printed))
    folded = constant_prop.graph.findNode("prim::Constant").output().toIValue()
    self.assertEqual(folded, 8)
1786 
def test_constant_prop_nested(self):
    """Constant propagation leaves only prim::Constant nodes inside both branches of an if."""
    @torch.jit.script
    def constant_prop(a):
        b = 2 + 1
        if bool(a < 2):
            c = b + 2
        else:
            c = b - 2
        return c

    before = constant_prop(torch.tensor(2))
    self.run_pass('constant_propagation', constant_prop.graph)
    after = constant_prop(torch.tensor(2))
    self.assertEqual(before, after)
    if_node = constant_prop.graph.findNode("prim::If")
    branch_kinds = [node.kind() for block in if_node.blocks() for node in block.nodes()]
    for kind in branch_kinds:
        self.assertTrue(kind == "prim::Constant")
1804 
def test_constant_prop_print(self):
    """The argument of a prim::Print is folded to the constant 6."""
    @torch.jit.script
    def constant_prop(input_tensor):
        a = 2 * 3
        print(a)
        b = a + 2
        return b + input_tensor

    self.run_pass('constant_propagation', constant_prop.graph)
    printed_value = constant_prop.graph.findNode("prim::Print").input().toIValue()
    self.assertTrue(printed_value == 6)
1817 
def test_constant_prop_rand(self):
    """Constant propagation must not fold away a nondeterministic aten::randn."""
    @torch.jit.script
    def constant_prop():
        a = torch.randn([3])
        b = a + 2
        return b

    self.run_pass('constant_propagation', constant_prop.graph)
    self.assertIn("aten::randn", str(constant_prop.graph))
1827 
def test_constant_prop_none(self):
    """`is None` checks on None-returning calls fold so that a single constant remains."""
    @torch.jit.script
    def typed_none():
        # type: () -> Optional[int]
        return None

    @torch.jit.script
    def constant_prop():
        a = typed_none()
        b = typed_none()
        if (a is None and b is None):
            a = 2
        else:
            a = 1
        return a

    self.run_pass('constant_propagation', constant_prop.graph)
    constant_count = str(constant_prop.graph).count("prim::Constant")
    self.assertTrue(constant_count == 1)
1847 
def test_constant_prop_if_inline(self):
    """Constant propagation must not evaluate the dead `1 // 0` branch of a constant-true if."""
    @torch.jit.script
    def constant_prop():
        cond = True
        a = 1
        if cond:
            a = 1 * 2
        else:
            a = 1 // 0
        return a

    # The pass must not raise the ZeroDivisionError from the dead `1 // 0` branch.
    self.run_pass('constant_propagation', constant_prop.graph)
1861 
def test_trace_records_names(self):
    """The traced graph's printed form contains the Python variable names used in the traced function."""
    def foo(bar, baz):
        baz = bar + 3
        quick_brown_fox = torch.neg(baz)
        for _ in range(20):
            yeet = quick_brown_fox - 3.14
        return yeet

    traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
    graph_str = str(traced.graph)
    # TestCase assertions instead of bare `assert`, which `python -O` strips.
    self.assertIn('bar', graph_str)
    self.assertIn('baz', graph_str)
    self.assertIn('quick_brown_fox', graph_str)
1875 
def test_constant_prop_if_constant(self):
    """Constant ifs are inlined or removed; the remaining ifs output only the values that vary.

    After the pass:
      * one top-level If remains, with 2 outputs;
      * its nested If has 1 output and no If nested inside it.
    """
    @torch.jit.script
    def constant_prop(a, b):
        c0 = 1
        c1 = 1
        c2 = 1
        if bool(a):  # -> c0, c1
            if bool(b):  # -> c0
                if True:  # -> c0
                    c0 = c0 + 1
                    if False:
                        c1 = c1 + 1
                        c2 = c2 + 1
        else:  # -> c0, c1
            c1 = c1 + 1

        if True:  # inlined
            c0 = c0 + 1  # dynamic
            c2 = c2 + 4  # set to 5
        return a + c0 + c1 + c2

    graph = constant_prop.graph
    self.run_pass('constant_propagation', graph)
    top_level_ifs = graph.findAllNodes("prim::If", recurse=False)
    # The trailing `if True:` has been inlined, leaving a single top-level If.
    self.assertTrue(len(top_level_ifs) == 1)
    outer_if = top_level_ifs[0]
    self.assertTrue(outer_if.outputsSize() == 2)
    inner_if = outer_if.findNode("prim::If", recurse=False)
    self.assertTrue(inner_if.outputsSize() == 1)
    self.assertTrue(inner_if.findNode("prim::If") is None)
1907 
def test_constant_prop_loop_constant(self):
    """Loops that provably never run are removed; all others keep their prim::Print."""
    @torch.jit.script
    def constant_prop(cond, iter):
        # type: (bool, int) -> int
        b = 0
        while True:
            print("stays")
        for _ in range(2):
            print("stays")
        for _ in range(iter):
            print("stays")
        while cond:
            print("stays")
        while False:
            print("removed")
        for _i in range(0):
            print("removed")
        for _i in range(-4):
            print("removed")
        return b

    self.run_pass('constant_propagation', constant_prop.graph)
    printed = canonical(constant_prop.graph)
    self.assertTrue(printed.count("removed") == 0)
    self.assertTrue(printed.count("stays") == 1)  # constant gets pooled
    self.assertTrue(printed.count("prim::Print") == 4)
1934 
def test_constant_prop_remove_output(self):
    """Loop-carried values that are never changed are dropped from the loop.

    `a` is only reassigned under `if False`, so after constant propagation
    only `b` and `c` remain as outputs of the loop.
    """
    @torch.jit.script
    def constant_prop(iter):
        # type: (int) -> None
        a = 1
        b = 1
        c = 1
        for i in range(iter):
            if False:
                a = 10
            if i == 5:
                b = 2
                c = 3
        print(a, b, c)

    graph = constant_prop.graph
    self.run_pass('constant_propagation', graph)
    self.assertEqual(graph.findNode("prim::Loop").outputsSize(), 2)
1953 
def test_trace_detach(self):
    """A traced `.detach()` is recorded in the graph and cuts autograd history."""
    def foo(x, w):
        return torch.matmul(x, w).detach()

    traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))

    graph_text = str(traced.graph)
    FileCheck().check("matmul").check("detach").run(graph_text)

    # The weight requires grad, so a missing detach would leave grad history.
    inp = torch.rand(3, 4)
    weight = torch.rand(4, 5, requires_grad=True)
    out = traced(inp, weight)
    self.assertEqual(foo(inp, weight), out)
    self.assertFalse(out.requires_grad)
    self.assertIsNone(out.grad_fn)
1966 
def test_trace_detach_inplace(self):
    """A traced in-place `detach_()` is recorded and cuts autograd history."""
    def foo(x, w):
        y = torch.matmul(x, w)
        y.detach_()
        return y

    traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))

    FileCheck().check("matmul").check("detach(").run(str(traced.graph))
    # `w` must require grad. Otherwise the requires_grad / grad_fn checks
    # below pass trivially whether or not detach_() had any effect.
    x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
    traced_result = traced(x, w)
    self.assertEqual(foo(x, w), traced_result)
    self.assertFalse(traced_result.requires_grad)
    self.assertIsNone(traced_result.grad_fn)
1981 
def test_trace_detach_onnx_erase(self):
    """Export a module ending in `.detach()` to ONNX and compare against the expect file.

    Judging by the test name, the detach is expected to be erased from the
    exported graph; the expect file holds the reference output.
    """
    class Mod(torch.nn.Module):
        def forward(self, x, w):
            return torch.matmul(x, w).detach()

    f = io.BytesIO()
    # The call opener below was missing, leaving a dangling argument list.
    self.assertExpected(torch.onnx.export_to_pretty_string(
        Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
1990 
def test_trace_slice_full_dim(self):
    """`x[0:5, 0]` traced on a 5-row input (where 0:5 spans the whole first
    dimension) must still slice 0:5 when replayed on a 6x3 input."""
    def foo(x):
        return x[0:5, 0] + 1.0

    traced = torch.jit.trace(foo, (torch.rand(5, 4),))
    probe = torch.rand(6, 3)
    expected = foo(probe)
    actual = traced(probe)
    self.assertEqual(expected, actual)
1998 
def test_export_dropout(self):
    """An eval-mode Dropout gives the same output before and after an export/import round trip."""
    dropout = torch.nn.Dropout()
    dropout.eval()

    traced = torch.jit.trace(dropout, (torch.rand(3, 4),), check_trace=False)
    round_tripped = self.getExportImportCopy(traced)
    sample = torch.randn(3, 4)
    self.assertEqual(traced(sample), round_tripped(sample))
2007 
def test_onnx_transpose_incomplete_tensor_type(self):
    # Smoke test to get us into the state where we are attempting to export
    # a transpose op, where the input is a TensorType rather than a
    # CompleteTensorType. This would previously not work, since we would
    # take the size of the input and use the length of its sizes as the
    # number of dimensions in the permutation.
    # There is no assertion: the test passes if the ONNX export below
    # completes without raising.

    # Scripted module: its graph supplies the transpose whose input type is
    # not fully specialized.
    class Foo(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x.contiguous().transpose(0, 1).sum()

    # Plain nn.Module wrapper so the scripted module can be traced into.
    class TraceMe(torch.nn.Module):
        def __init__(self):
            super(TraceMe, self).__init__()
            self.foo = Foo()

        def forward(self, x):
            return self.foo(x)

    tm = TraceMe()
    tm = torch.jit.trace(tm, torch.rand(3, 4))
    example_outputs = (tm(torch.rand(3, 4)),)
    f = io.BytesIO()
    torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
2032 
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_cuda_export_restore(self):
    """A CUDA ScriptModule with a submodule parameter survives export/import."""
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__()
            self.weight = nn.Parameter(torch.randn(3, 4))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__()
            self.mod = Sub()

        @torch.jit.script_method
        def forward(self, v):
            return self.mod(v)

    original = M()
    original.cuda()
    restored = self.getExportImportCopy(original)
    restored.cuda()
    sample = torch.rand(3, 4).cuda()
    self.assertEqual(original(sample), restored(sample))
2058 
def test_export_batchnorm(self):
    """Trace BatchNorm1d/2d (with and without affine) in eval and train mode
    and check that an export/import round trip gives identical outputs."""
    def make_input(module):
        # 2-D input for BatchNorm1d, 4-D input for BatchNorm2d.
        if isinstance(module, torch.nn.BatchNorm1d):
            return torch.randn(20, 100)
        return torch.randn(20, 100, 35, 45)

    for mode in ['eval', 'train']:
        for module in [
                torch.nn.BatchNorm1d(100),
                torch.nn.BatchNorm1d(100, affine=False),
                torch.nn.BatchNorm2d(100),
                torch.nn.BatchNorm2d(100, affine=False)]:
            # Calls module.eval() or module.train().
            getattr(module, mode)()
            traced = torch.jit.trace(module, (make_input(module),))
            imported = self.getExportImportCopy(traced)
            x = make_input(module)
            self.assertEqual(traced(x), imported(x))
2074 
def test_export_rnn(self):
    # For RNN and GRU, trace a module that packs a padded batch, runs the
    # recurrent layer and unpacks the result, then check the export/import
    # copy agrees with the traced module on a differently-shaped batch.
    for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
        # RNNTest reads the loop variable `clazz` through its closure; that is
        # safe only because it is instantiated in the same iteration below.
        class RNNTest(torch.nn.Module):
            def __init__(self):
                super(RNNTest, self).__init__()
                self.rnn = clazz

            def forward(self, x, lengths, h0):
                packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
                out, h = self.rnn(packed, h0)
                padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
                return padded_outs

        test = RNNTest()

        traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
        imported = self.getExportImportCopy(traced)
        # NB: We make sure to pass in a batch with a different max sequence
        # length to ensure that the argument stashing for pad_packed works
        # properly.
        x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
        self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
2097 
def test_export_lstm(self):
    # Same round trip as the RNN/GRU export test but for an LSTM, whose
    # hidden state is an (h0, c0) tuple passed as a single traced argument.
    class LSTMTest(torch.nn.Module):
        def __init__(self):
            super(LSTMTest, self).__init__()
            self.rnn = nn.LSTM(10, 20, 2)

        def forward(self, x, lengths, hiddens):
            h0, c0 = hiddens
            packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
            out, (h, c) = self.rnn(packed, (h0, c0))
            padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
            return padded_outs

    test = LSTMTest()

    traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
                                    torch.LongTensor([3, 2, 1]),
                                    (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
    imported = self.getExportImportCopy(traced)
    # Replay with a longer max sequence length (7 vs. 5) than was traced.
    x, lengths, h0, c0 = \
        torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
    self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
2120 
def test_trace_dict_input(self):
    """Trace through a submodule that takes and returns dicts of tensors."""
    class Foo(torch.nn.Module):
        def forward(self, x):
            return {'a': x['a'] * x['b']}

    class Bar(torch.nn.Module):
        def __init__(self):
            super(Bar, self).__init__()
            self.foo = Foo()

        def forward(self, a, b):
            return self.foo({'a': a, 'b': b})['a']

    inputs = (torch.rand(3), torch.rand(3))
    self.checkTrace(Bar(), inputs)
2137 
def test_trace_variable_instantiation(self):
    """Wrapping tensors in (nested) Variable(...) calls is traceable and
    generalizes to inputs of a different shape."""
    def random_foo(x):
        return Variable(Variable(x) + 1.0)

    traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))

    sample = torch.rand(5, 6)
    self.assertEqual(random_foo(sample), traced(sample))
2146 
def test_trace_slice_expr_complete_type(self):
    """A script function can slice the result of calling a traced function."""
    def random_foo(x):
        return x + 1.0

    # Name is referenced from the scripted body below; keep it as-is.
    random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))

    @torch.jit.script
    def random_bar(x):
        return random_foo_traced(x)[0:1]

    sample = torch.rand(3, 4)
    actual = random_bar(sample)
    expected = (sample + 1)[0:1]
    self.assertEqual(actual, expected)
2159 
def test_export_tensoroption_to(self):
    """ONNX export of a traced `x.new_tensor(x[0]).cpu() + x`, compared with the expect file."""
    def foo(x):
        return x.new_tensor(x[0]).cpu() + x

    # Example inputs are a 1-tuple; `(t)` without the comma is just `t`.
    traced = torch.jit.trace(foo, (torch.rand([2]),))
    example_outputs = traced(torch.rand([2]))

    f = io.BytesIO()
    self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
                                                            example_outputs=example_outputs))
2170 
def test_pretty_printer(self):
    # Script a set of small functions covering if/else, while loops,
    # loop-carried values, a call into plain Python, list literals and an
    # escaped string character, then compare each graph's pretty_print()
    # against its named expect file. The functions are only compiled, never
    # called here.
    @torch.jit.script
    def if_test(a, b):
        # FIXME: use 0 instead of a.
        # c = 0
        c = a
        if bool(a < b):
            c = b
        else:
            c = a
        return c

    @torch.jit.script
    def if_one(a, b):
        c = b
        if bool(a < b):
            c = a
        return c

    @torch.jit.script
    def while_test(a, i):
        while bool(i < 3):
            a *= a
            i += 1
        return a

    @torch.jit.script
    def while_if_test(a, b):
        c = 0
        while bool(a < 10):
            a = a + 1
            b = b + 1
            if bool(a > b):
                c = 2
            else:
                c = 3
        return a + 1 + c

    # `x` is defined before the loop and read inside it (via `z = x`).
    @torch.jit.script
    def loop_use_test(y):
        x = y + 1
        z = x + 5
        while bool(y < 8):
            y += 1
            z = x
        return x, z

    # Plain (unscripted) Python function called from script; presumably
    # printed as a Python op — see the expect file.
    def python_fn(x):
        return x + 10

    @torch.jit.script
    def python_op_name_test(y):
        return python_fn(y)

    # Indexing an empty list would fail at runtime; only the graph is checked.
    @torch.jit.script
    def empty_int_list_test(y):
        x = torch.jit.annotate(List[int], [])
        return x[0]

    # NB: despite its name this returns a non-empty float list literal.
    @torch.jit.script
    def empty_float_list_test(y):
        return [1.0, 2.0, 3.0]

    # "\016" is a non-printable character; checks how it is escaped.
    @torch.jit.script
    def print_weird_test(y):
        print("hi\016")

    self.assertExpected(if_test.graph.pretty_print(), "if_test")
    self.assertExpected(if_one.graph.pretty_print(), "if_one")
    self.assertExpected(while_test.graph.pretty_print(), "while_test")
    self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
    self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
    self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
    self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
    self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
    self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
2247 
def test_cu_escaped_number(self):
    # Same escaped "\016" character as in the pretty-printer test, but the
    # function is compiled from a source string via CompilationUnit.
    cu = torch.jit.CompilationUnit('''
        def foo(a):
            print("hi\016")
    ''')
    self.assertExpected(cu.foo.graph.pretty_print())
2254 
def test_import_method(self):
    # Round-trip a scripted function through its Python-printed source:
    # print it, import the text as methods on an empty ScriptModule, and
    # compare the resulting graph with the expect file.
    @torch.jit.script
    def foo(x, y):
        return 2 * x + y

    # Only the first element (the printed source) is used. NOTE(review): the
    # second element is presumably a tensor/constant table, left unused here
    # since [] is passed below — confirm.
    r, _ = foo._python_print()
    mod = torch.jit.ScriptModule()
    torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
    self.assertExpected(mod.graph.pretty_print())
2264 
def test_function_default_values(self):
    # Default argument values on script functions: tensor defaults taken from
    # enclosing locals (including an expression), overriding them, a None
    # default with an Optional type, scalar defaults matching a type comment,
    # and a mismatched default that must be rejected.
    outer_var = torch.tensor(20)
    outer_var2 = torch.tensor(30)
    a = torch.tensor(0.5)
    b = torch.tensor(10)

    @torch.jit.script
    def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
        return x + a + b + c

    self.assertEqual(
        simple_fn(torch.ones(1)),
        torch.ones(1) + 0.5 + 10 + (20 + 30))
    self.assertEqual(
        simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
        torch.ones(1) + 1 + 3 + 4)

    outer_c = torch.tensor(9)
    outer_flag = torch.tensor(False)

    @torch.jit.script
    def bool_fn(x, a=outer_c, flag=outer_flag):
        if bool(flag):
            result = x
        else:
            result = x + a
        return result

    self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
    self.assertEqual(
        bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
        torch.ones(1))

    @torch.jit.script
    def none_fn(x=None):
        # type: (Optional[int]) -> Optional[int]
        return x

    self.assertEqual(none_fn(), None)
    self.assertEqual(none_fn(1), 1)

    @torch.jit.script
    def hints(x, a=0.5, b=10):
        # type: (Tensor, float, int) -> Tensor
        return x + a + b

    self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)

    # Here the defaults (10, 0.5) disagree with the declared (float, int)
    # types, so scripting is expected to fail.
    with self.assertRaisesRegex(RuntimeError, "Expected a default value"):

        @torch.jit.script
        def hints_bad_types(x, a=10, b=0.5):  # noqa: T484
            # type: (Tensor, float, int) -> Tensor
            return x + a + b
2319 
def test_module_default_values(self):
    # A script_method may use a tensor from the enclosing scope as a default
    # argument value; calling without `other` uses that tensor (4).
    four = torch.tensor(4)

    class Test(torch.jit.ScriptModule):
        def __init__(self):
            super(Test, self).__init__()

        @torch.jit.script_method
        def forward(self, input, other=four):
            return input + other

    t = Test()
    self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
2333 
def test_warnings(self):
    # `warnings.warn` inside a script function compiles to an aten::warn node.
    import warnings

    @torch.jit.script
    def fn(x):
        if bool(x < 2):
            warnings.warn("x is less than 2")
        return x

    FileCheck().check("aten::warn").run(str(fn.graph))
2344 
def test_no_erroneous_warnings(self):
    # Neither compiling a function that contains a warn call nor running it
    # down the branch that skips the call may emit a warning.
    import warnings

    def fn(x):
        if bool(x > 0):
            warnings.warn('This should NOT be printed')
            x += 1
        return x

    with warnings.catch_warnings(record=True) as warns:
        fn_script = torch.jit.script(fn)
        # x = 0, so the `x > 0` branch holding the warn call is not taken.
        fn_script(torch.tensor(0))
    warns = [str(w.message) for w in warns]
    self.assertEqual(len(warns), 0)
2359 
@unittest.skipIf(sys.platform == "win32", "TODO: need to fix this test case for Windows")
def test_torch_load_error(self):
    # A file written by ScriptModule.save() must be rejected by torch.load()
    # with an error mentioning "is a zip".
    class J(torch.jit.ScriptModule):
        def __init__(self):
            super(J, self).__init__()

        @torch.jit.script_method
        def forward(self, input):
            return input + 100

    j = J()
    # The file is reopened by name while still open; skipped on Windows above.
    with tempfile.NamedTemporaryFile() as f:
        j.save(f.name)
        with self.assertRaisesRegex(RuntimeError, "is a zip"):
            torch.load(f.name)
2375 
def test_legacy_constructors(self):
    """Tracing a legacy constructor (`new_zeros`) emits exactly one warning with a fixed message."""
    def fn(x):
        return x.new_zeros(5, 5, requires_grad=False)

    with warnings.catch_warnings(record=True) as warns:
        # Example inputs as a 1-tuple; `(t)` without the comma is just `t`.
        torch.jit.trace(fn, (torch.ones(2, 2),))
    warns = [str(w.message) for w in warns]
    self.assertEqual(len(warns), 1)
    self.assertEqual(warns[0], "new_zeros is a legacy constructor and is not supported in the JIT.")
2385 
def test_python_bindings(self):
    """Smoke-test the Python bindings for JIT IR nodes, values and blocks.

    Scripts an LSTM loop, runs forward/backward once, then walks the graph
    specialized for those inputs. It checks that:
      * every node input/output and block input/output has a non-None type();
      * every block exposes returnNode()/paramNode() as torch._C.Node;
      * at least one node actually had a block.
    """
    lstm_cell = torch.jit.script(LSTMCellS)

    def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
        for i in range(x.size(0)):
            hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
        return hx

    slstm = torch.jit.script(lstm)

    inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
    slstm(*inputs).sum().backward()
    # NOTE(review): published as a module-level global; presumably for
    # inspection/debugging — confirm whether anything reads it.
    global fw_graph
    fw_graph = slstm.graph_for(*inputs)

    def check_typed(values):
        # Each IR value must expose a callable `type` returning a non-None type.
        for value in values:
            self.assertTrue(hasattr(value, 'type'))
            self.assertIsNotNone(value.type())

    tested_blocks = False
    for node in fw_graph.nodes():
        check_typed(node.outputs())
        check_typed(node.inputs())
        for block in node.blocks():
            tested_blocks = True
            self.assertTrue(hasattr(block, 'inputs'))
            self.assertTrue(hasattr(block, 'outputs'))
            check_typed(block.outputs())
            check_typed(block.inputs())
            self.assertTrue(hasattr(block, 'returnNode'))
            self.assertIs(type(block.returnNode()), torch._C.Node)
            self.assertTrue(hasattr(block, 'paramNode'))
            self.assertIs(type(block.paramNode()), torch._C.Node)
    self.assertTrue(tested_blocks)
2424 
2425 
# Generate random examples and wrap them in a BatchTensor.
def rand_batch(self, *dims):
    """Return ``(examples, batch)`` for ``dims[0]`` random examples.

    ``dims[0]`` is the number of examples. Each later entry is a
    ``(dynamic, size)`` pair: a dynamic dimension gets an independent random
    length in ``[1, size]`` per example, a static one is exactly ``size``.
    Each example has an extra leading dimension of 1 and requires grad. Bare
    ``()`` entries are ignored, so ``rand_batch(4, ())`` yields four
    1-element tensors. The BatchTensor's mask is the uint8 list of
    ``dynamic`` flags.
    """
    specs = [spec for spec in dims if spec != ()]
    batch_size, dim_specs = specs[0], specs[1:]

    examples = []
    for _ in range(batch_size):
        shape = [random.randint(1, size) if dynamic else size
                 for dynamic, size in dim_specs]
        examples.append(torch.rand(1, *shape, requires_grad=True))

    mask = torch.tensor([dynamic for dynamic, _ in dim_specs]).byte()
    return examples, BatchTensor(examples, mask)
2434 
def test_create_batchtensor(self):
    """BatchTensors built three different ways round-trip through examples()."""
    # 1) From a list of per-example tensors.
    examples, from_list = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
    self.assertEqual(examples, from_list.examples())

    # 2) From the (data, mask, dims) triple of an existing BatchTensor.
    data, mask, dims = from_list.get_data(), from_list.get_mask(), from_list.get_dims()
    from_parts = BatchTensor(data, mask, dims)
    self.assertEqual(examples, from_parts.examples())

    # 3) By expanding a single tensor to a batch of size 2.
    single = torch.rand(3, 4, 5)
    expanded = BatchTensor(single, 2)
    expected = single.unsqueeze(0)
    self.assertEqual([expected] * 2, expanded.examples())
2447 
def test_batch_elementwise_unary(self):
    """Batched tanh matches tanh applied to each example."""
    @torch.jit.batch(batch_size=4)
    def tanh(a):
        return torch.tanh(a)

    examples, batched = self.rand_batch(4, (True, 3), (False, 2))
    actual = tanh(batched)
    expected = [torch.tanh(x) for x in examples]
    self.assertEqual(expected, actual.examples())
2457 
def test_batch_elementwise_binary(self):
    """Batched addition: batch + batch, and batch + broadcast plain tensor."""
    @torch.jit.batch(batch_size=4)
    def add(a, b):
        return a + b

    # Same batch on both sides.
    examples, batched = self.rand_batch(4, (True, 3), (False, 2))
    actual = add(batched, batched)
    expected = [torch.add(x, x) for x in examples]
    self.assertEqual(expected, actual.examples())

    # A plain tensor broadcast against every example.
    examples, batched = self.rand_batch(4, (False, 3), (False, 2))
    other = torch.rand(3, 2)
    actual = add(batched, other)
    expected = [torch.add(x, other) for x in examples]
    self.assertEqual(expected, actual.examples())
2475 
def test_batch_mm(self):
    """Batched torch.mm against per-example mm, including a plain right operand."""
    @torch.jit.batch(batch_size=4)
    def mm(a, b):
        return torch.mm(a, b)

    left, left_batch = self.rand_batch(4, (True, 3), (False, 2))
    right, right_batch = self.rand_batch(4, (False, 2), (True, 3))
    actual = mm(left_batch, right_batch)
    # Examples carry a leading size-1 dim; strip it for mm and restore it.
    expected = [torch.mm(x.squeeze(0), y.squeeze(0)).unsqueeze(0)
                for x, y in zip(left, right)]
    self.assertEqual(expected, actual.examples())

    # Plain tensor broadcast as the right operand.
    plain = torch.rand(2, 4)
    actual = mm(left_batch, plain)
    expected = [torch.mm(x.squeeze(0), plain).unsqueeze(0) for x in left]
    self.assertEqual(expected, actual.examples())
2492 
def test_batch_matmul(self):
    """Batched torch.matmul for every combination of 1-D and 2-D operands."""
    @torch.jit.batch(batch_size=4)
    def matmul(a, b):
        return torch.matmul(a, b)

    # (lhs dim specs, rhs dim specs): 1x1, 1x2, 2x1 and 2x2 dimensional.
    cases = [
        (((False, 2),), ((False, 2),)),
        (((False, 2),), ((False, 2), (True, 3))),
        (((True, 3), (False, 2)), ((False, 2),)),
        (((True, 3), (False, 2)), ((False, 2), (True, 3))),
    ]
    for lhs_dims, rhs_dims in cases:
        lhs, lhs_batch = self.rand_batch(4, *lhs_dims)
        rhs, rhs_batch = self.rand_batch(4, *rhs_dims)
        expected = [torch.matmul(x.squeeze(0), y.squeeze(0)).unsqueeze(0)
                    for x, y in zip(lhs, rhs)]
        actual = matmul(lhs_batch, rhs_batch)
        self.assertEqual(expected, actual.examples())
2519 
def test_batch_select(self):
    """Batched torch.select(x, 1, 0), with a dynamic and then a static first data dim."""
    @torch.jit.batch(batch_size=4)
    def select(x):
        return torch.select(x, 1, 0)

    for dim_specs in [((True, 3), (True, 2)), ((False, 3), (True, 2))]:
        examples, batched = self.rand_batch(4, *dim_specs)
        actual = select(batched)
        expected = [torch.select(x, 1, 0) for x in examples]
        self.assertEqual(expected, actual.examples())
2534 
def test_batch_index_select(self):
    """Batched index_select along dim 1 with a different index per example."""
    @torch.jit.batch(batch_size=4)
    def index_select(x, ind):
        return x.index_select(1, ind)

    examples, batched = self.rand_batch(4, (False, 5), (True, 2))
    indices = [torch.randint(0, 4, (1,), dtype=torch.long) for _ in range(4)]
    index_batch = BatchTensor(indices, torch.tensor([]).byte())
    actual = index_select(batched, index_batch)
    expected = [torch.index_select(x, 1, idx) for x, idx in zip(examples, indices)]
    self.assertEqual(expected, actual.examples())
2546 
def test_batch_where(self):
    """Batched torch.where with a random per-example uint8 condition."""
    @torch.jit.batch(batch_size=4)
    def where(c, a, b):
        return torch.where(c, a, b)

    xs, batch = self.rand_batch(4, (False, 3), (False, 2))
    xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))

    dims = [4, (False, 3), (False, 2)]
    # Threshold rather than `torch.rand(...).byte()`: casting values in [0, 1)
    # to uint8 truncates them all to 0, which would leave the condition all
    # false and never exercise the first branch of where().
    xs_cond = [(torch.rand(1, 3, 2) > 0.5).byte() for _ in range(dims[0])]
    # uint8 mask, built the same way rand_batch builds its mask.
    batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]).byte())

    res_batch = where(batch_cond, batch, batch2)
    res = [torch.where(c, x, y) for c, x, y in zip(xs_cond, xs, xs2)]
    self.assertEqual(res, res_batch.examples())
2562 
def test_batch_argmax(self):
    """Batched argmax over dim 1, without and with an explicit keepdim=False."""
    @torch.jit.batch(batch_size=4)
    def argmax(a):
        return torch.argmax(a, 1)

    examples, batched = self.rand_batch(4, (True, 5), (True, 6))
    actual = argmax(batched)
    self.assertEqual([torch.argmax(x, 1) for x in examples], actual.examples())

    @torch.jit.batch(batch_size=4)
    def argmax(a):
        return torch.argmax(a, 1, False)

    actual = argmax(batched)
    self.assertEqual([torch.argmax(x, 1, False) for x in examples], actual.examples())
2580 
def test_batch_topk(self):
    """Batched topk values and indices along a static and a dynamic dimension."""
    @torch.jit.batch(batch_size=4)
    def topk(a):
        return torch.topk(a, 3, 1)

    examples, batched = self.rand_batch(4, (False, 5), (True, 6))

    # along static dim
    out = topk(batched)
    per_example = [torch.topk(x, 3, 1) for x in examples]
    self.assertEqual([pair[0] for pair in per_example], out[0].examples())
    self.assertEqual([pair[1] for pair in per_example], out[1].examples())

    @torch.jit.batch(batch_size=4)
    def topk(a):
        return torch.topk(a, 1, 2)

    # along dynamic dim
    out = topk(batched)
    per_example = [torch.topk(x, 1, 2) for x in examples]
    self.assertEqual([pair[0] for pair in per_example], out[0].examples())
    self.assertEqual([pair[1] for pair in per_example], out[1].examples())
2605 
def test_batch_softmax(self):
    """Batched softmax along a static and a dynamic dimension."""
    @torch.jit.batch(batch_size=4)
    def softmax(a):
        return torch.softmax(a, 1)

    examples, batched = self.rand_batch(4, (False, 5), (True, 6))

    # along static dim
    actual = softmax(batched)
    self.assertEqual([torch.softmax(x, 1) for x in examples], actual.examples())

    @torch.jit.batch(batch_size=4)
    def softmax(a):
        return torch.softmax(a, 2)

    # along dynamic dim
    actual = softmax(batched)
    self.assertEqual([torch.softmax(x, 2) for x in examples], actual.examples())
2626 
def test_batch_view(self):
    """A batched view([4, -1, 3]) matches view([1, -1, 3]) on each example."""
    @torch.jit.batch(batch_size=4)
    def view(a):
        return a.view([4, -1, 3])

    examples, batched = self.rand_batch(4, (True, 5), (False, 3))
    actual = view(batched)
    expected = [x.view([1, -1, 3]) for x in examples]
    self.assertEqual(expected, actual.examples())
2636 
def test_batch_cat(self):
    """Batched torch.cat of a batch with itself along dim 2."""
    @torch.jit.batch(batch_size=4)
    def cat2(a, b):
        return torch.cat([a, b], 2)

    examples, batched = self.rand_batch(4, (True, 5), (False, 3))
    actual = cat2(batched, batched)
    expected = [torch.cat([x, x], 2) for x in examples]
    self.assertEqual(expected, actual.examples())
2647 
def test_batch_sum(self):
    """A batched full sum yields each example's sum as a 1-element tensor."""
    @torch.jit.batch(batch_size=4)
    def batch_sum(a):
        return a.sum()

    examples, batched = self.rand_batch(4, (True, 5), (False, 3))
    actual = batch_sum(batched)
    expected = [x.sum().unsqueeze(0) for x in examples]
    self.assertEqual(expected, actual.examples())
2657 
def test_if_else(self):
    """Batched if/else on a tensor-vs-tensor condition matches per-example evaluation."""
    def single_if(a, b):
        if bool(a > b):
            a = a + b
        else:
            a = a - b
        return a

    batched_fn = torch.jit.batch(batch_size=4)(single_if)

    lhs, lhs_batch = self.rand_batch(4, ())
    rhs, rhs_batch = self.rand_batch(4, ())
    actual = batched_fn(lhs_batch, rhs_batch)
    expected = [single_if(x, y) for x, y in zip(lhs, rhs)]
    self.assertEqual(expected, actual.examples())

    # Also convert the scripted graph; the call completing is the check.
    scripted = torch.jit.script(single_if)
    torch.to_batch_graph(scripted.graph)
2676 
def test_if_else_with_scalar(self):
    """Batched if/else on a tensor-vs-scalar condition matches per-example evaluation."""
    def single_if(a, b):
        if bool(a > 0.1):
            a = a + b
        else:
            a = a - b
        return a

    batched_fn = torch.jit.batch(batch_size=4)(single_if)

    lhs, lhs_batch = self.rand_batch(4, ())
    rhs, rhs_batch = self.rand_batch(4, ())
    actual = batched_fn(lhs_batch, rhs_batch)
    expected = [single_if(x, y) for x, y in zip(lhs, rhs)]
    self.assertEqual(expected, actual.examples())

    # Also convert the scripted graph; the call completing is the check.
    scripted = torch.jit.script(single_if)
    torch.to_batch_graph(scripted.graph)
2695 
def test_if_noelse(self):
    """Batched if without else (tensor-vs-tensor condition)."""
    def single_if(a, b):
        if bool(a > b):
            a = a + b
        return a

    batched_fn = torch.jit.batch(batch_size=4)(single_if)

    lhs, lhs_batch = self.rand_batch(4, ())
    rhs, rhs_batch = self.rand_batch(4, ())
    actual = batched_fn(lhs_batch, rhs_batch)
    expected = [single_if(x, y) for x, y in zip(lhs, rhs)]
    self.assertEqual(expected, actual.examples())

    # Also convert the scripted graph; the call completing is the check.
    scripted = torch.jit.script(single_if)
    torch.to_batch_graph(scripted.graph)
2712 
def test_if_noelse_with_scalar(self):
    """Batched if without else (tensor-vs-scalar condition)."""
    def single_if(a, b):
        if bool(a > 0.1):
            a = a + b
        return a

    batched_fn = torch.jit.batch(batch_size=4)(single_if)

    lhs, lhs_batch = self.rand_batch(4, ())
    rhs, rhs_batch = self.rand_batch(4, ())
    actual = batched_fn(lhs_batch, rhs_batch)
    expected = [single_if(x, y) for x, y in zip(lhs, rhs)]
    self.assertEqual(expected, actual.examples())

    # Also convert the scripted graph; the call completing is the check.
    scripted = torch.jit.script(single_if)
    torch.to_batch_graph(scripted.graph)
2729 
def test_while(self):
    """Batched while loop with a per-example trip count matches per-example evaluation."""
    def single_while(a, b):
        while bool(a > b):
            a = a - b
        return a

    batched_fn = torch.jit.batch(batch_size=4)(single_while)

    lhs, lhs_batch = self.rand_batch(4, ())
    # Non-negative per-example step sizes, batched with an empty mask.
    rhs = [torch.abs(torch.rand(1)) for _ in range(4)]
    rhs_batch = BatchTensor(rhs, torch.tensor([]).byte())
    actual = batched_fn(lhs_batch, rhs_batch)
    expected = [single_while(x, y) for x, y in zip(lhs, rhs)]
    self.assertEqual(expected, actual.examples())

    # Also convert the scripted graph; the call completing is the check.
    scripted = torch.jit.script(single_while)
    torch.to_batch_graph(scripted.graph)
2747 
def test_for(self):
    """Batched constant-trip-count for loop matches per-example evaluation."""
    def single_for(x, y):
        for _ in range(10):
            x = x + y
        return x

    batched_fn = torch.jit.batch(batch_size=4)(single_for)

    lhs, lhs_batch = self.rand_batch(4, ())
    rhs, rhs_batch = self.rand_batch(4, ())
    actual = batched_fn(lhs_batch, rhs_batch)
    expected = [single_for(x, y) for x, y in zip(lhs, rhs)]
    self.assertEqual(expected, actual.examples())

    # Also convert the scripted graph; the call completing is the check.
    scripted = torch.jit.script(single_for)
    torch.to_batch_graph(scripted.graph)
2764 
def test_lstm(self):
    # Run an LSTM-style recurrence over a batch whose sequence length (dim 1
    # of each example) varies per example, and compare the batched result
    # with running the same function on each example separately.
    # NB: the cell update multiplies the candidate `c_t` by both gates and
    # never reads the previous cell state `c` (it is only reassigned). So this
    # is not a textbook LSTM, but both paths run the same function and the
    # comparison is still valid.
    def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
        for i in range(x_all.size(1)):
            x = x_all.select(1, i)
            i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
            f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
            o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
            # activations
            i_t = torch.sigmoid(i_t)
            f_t = torch.sigmoid(f_t)
            o_t = torch.sigmoid(o_t)
            # cell computations
            c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
            c_t = torch.tanh(c_t)
            c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
            h_t = torch.mul(o_t, torch.tanh(c_t))
            h = h_t
            c = c_t
        return h

    LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)

    batch_size, input_size, hidden_size = 4, 3, 2
    # (True, 4): per-example sequence length drawn from [1, 4].
    xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
    hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
    cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))

    # input to hidden weights
    w_xi = torch.rand(input_size, hidden_size)
    w_xf = torch.rand(input_size, hidden_size)
    w_xo = torch.rand(input_size, hidden_size)
    w_xc = torch.rand(input_size, hidden_size)
    # hidden to hidden weights
    w_hi = torch.rand(hidden_size, hidden_size)
    w_hf = torch.rand(hidden_size, hidden_size)
    w_ho = torch.rand(hidden_size, hidden_size)
    w_hc = torch.rand(hidden_size, hidden_size)
    # bias terms
    b_i = torch.rand(hidden_size)
    b_f = torch.rand(hidden_size)
    b_o = torch.rand(hidden_size)
    b_c = torch.rand(hidden_size)

    # Weights and biases are plain tensors shared by every example.
    ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
               w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
    ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
                     w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
    self.assertEqual(ys, ybs.examples())
2813 
def test_greedy_search(self):
    # Greedy decoding loop: each step runs an LSTM-style cell, picks the
    # argmax token and embeds it as the next input, for a per-example number
    # of iterations. The batched result is compared with running each example
    # separately.
    # NB: as in the LSTM batching test, `c` is reassigned but never read, so
    # the cell state does not carry across steps; both paths share the function.
    def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
               b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
        iter_count = torch.zeros_like(iter_num)
        while bool(iter_count < iter_num):
            iter_count = iter_count + 1
            # LSTM Cell
            i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
            f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
            o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
            # activations
            i_t = torch.sigmoid(i_t)
            f_t = torch.sigmoid(f_t)
            o_t = torch.sigmoid(o_t)
            # cell computations
            c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
            c_t = torch.tanh(c_t)
            c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
            h_t = torch.mul(o_t, torch.tanh(c_t))
            h = h_t
            c = c_t
            # calculate feature with max probability
            s_t = torch.matmul(h_t, w_hs) + b_s
            p_t = torch.softmax(s_t, 1)
            i_t = torch.argmax(p_t, 1)
            x = embed.index_select(1, i_t).squeeze(1)
        return h

    greedy_batch = torch.jit.batch(batch_size=4)(greedy)

    batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
    xs, batch = self.rand_batch(batch_size, (False, input_size))
    hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
    cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
    # Per-example embedding table with vocab_size rows of input_size features.
    embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
    # Per-example iteration counts drawn from [2, 5).
    iter_num = [torch.randint(2, 5, (1,)) for i in range(batch_size)]
    iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())

    # input to hidden weights
    w_xi = torch.rand(input_size, hidden_size)
    w_xf = torch.rand(input_size, hidden_size)
    w_xo = torch.rand(input_size, hidden_size)
    w_xc = torch.rand(input_size, hidden_size)
    # hidden to hidden weights
    w_hi = torch.rand(hidden_size, hidden_size)
    w_hf = torch.rand(hidden_size, hidden_size)
    w_ho = torch.rand(hidden_size, hidden_size)
    w_hc = torch.rand(hidden_size, hidden_size)
    # bias terms
    b_i = torch.rand(hidden_size)
    b_f = torch.rand(hidden_size)
    b_o = torch.rand(hidden_size)
    b_c = torch.rand(hidden_size)
    # hidden to vocab weights, bias
    w_hs = torch.rand(hidden_size, vocab_size)
    b_s = torch.rand(vocab_size)

    ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc,
                 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j]) for j in range(batch_size)]
    ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
                       w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
    self.assertEqual(ys, ybs.examples())
2876 
def test_beam_search(self):
    """Batched beam search (LSTM cell + top-k) must match per-example runs."""
    def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
             b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
        k = 5
        vocab_size = embed.size(1)
        iter_count = torch.zeros_like(iter_num)
        max_len = idx.size(2)
        while bool(iter_count < iter_num):
            iter_count = iter_count + 1
            # LSTM gate pre-activations
            gate_i = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
            gate_f = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
            gate_o = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
            # activations
            gate_i = torch.sigmoid(gate_i)
            gate_f = torch.sigmoid(gate_f)
            gate_o = torch.sigmoid(gate_o)
            # cell update; note the forget gate scales the fresh candidate,
            # not the incoming cell state c
            cell = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
            cell = torch.tanh(cell)
            cell = torch.mul(cell, gate_f) + torch.mul(gate_i, cell)
            hidden = torch.mul(gate_o, torch.tanh(cell))
            h = hidden
            c = cell
            # score every (beam, token) pair and keep the k best overall
            scores = torch.matmul(hidden, w_hs) + b_s
            scores = scores.view([1, scores.size(1) * scores.size(2)])
            probs = torch.softmax(scores, 1)
            top_prob, top_idx = torch.topk(probs, k, 1)
            if int(top_idx.dim()) > 1:
                flat_idx = top_idx.squeeze(0)
            else:
                flat_idx = top_idx
            # split each flat index into (source beam, token); presumably relies
            # on `/` truncating for integer tensors in this torch version
            new_y = torch.fmod(flat_idx, vocab_size)
            pre_y = flat_idx / vocab_size
            x = embed.index_select(1, new_y)
            h = hidden.index_select(1, pre_y)
            c = cell.index_select(1, pre_y)
            step = int(iter_count[0])
            idx = torch.cat([idx.narrow(2, 0, step).index_select(1, pre_y),
                             torch.fmod(top_idx, vocab_size).unsqueeze(-1),
                             idx.narrow(2, step, max_len - step)], 2)
        idx = idx.narrow(2, 0, max_len)
        return idx

    beam_batch = torch.jit.batch(batch_size=4)(beam)

    k = 5
    batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
    max_len = 5
    xs, batch = self.rand_batch(batch_size, (False, 1), (False, input_size))
    hx, h_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
    cx, c_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
    embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
    iter_num = [torch.randint(2, max_len + 1, (1,)) for _ in range(batch_size)]
    iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())

    # input-to-hidden, hidden-to-hidden, bias (i, f, o, c order) and vocab projection
    w_xi, w_xf, w_xo, w_xc = [torch.rand(input_size, hidden_size) for _ in range(4)]
    w_hi, w_hf, w_ho, w_hc = [torch.rand(hidden_size, hidden_size) for _ in range(4)]
    b_i, b_f, b_o, b_c = [torch.rand(1, hidden_size) for _ in range(4)]
    w_hs = torch.rand(hidden_size, vocab_size)
    b_s = torch.rand(1, vocab_size)
    weights = (w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
               b_i, b_f, b_o, b_c, w_hs, b_s)

    idx_batch = torch.jit.BatchTensor(torch.zeros([batch_size, k, max_len], dtype=torch.long),
                                      torch.zeros([batch_size, 1, max_len]).byte(),
                                      torch.tensor([0, 1]).byte())
    idx = [torch.zeros([1, k, max_len], dtype=torch.long) for _ in range(batch_size)]

    # per-example reference, trimmed to the number of steps actually run
    expected = []
    for j in range(batch_size):
        args = (xs[j], hx[j], cx[j], embed[j]) + weights + (iter_num[j], idx[j])
        expected.append(beam(*args).narrow(2, 0, int(iter_num[j])))
    actual = beam_batch(*((batch, h_batch, c_batch, embed_batch) + weights +
                          (iter_num_batch, idx_batch)))
    self.assertEqual(expected, actual.examples())
2964 
2965 
def execWrapper(code, glob, loc):
    """Execute ``code`` with globals ``glob`` and locals ``loc`` on Python 2 or 3."""
    if not PY2:
        exec(code, glob, loc)
        return
    # Python 2 spelling of the same operation: the exec statement with `in`.
    exec(code) in glob, loc
2971 
2972 
@contextmanager
def capture_stdout(self):
    """Capture everything written to OS-level file descriptor 1.

    Unlike redirecting ``sys.stdout`` this also sees output written by C++
    code. Yields a one-element list; once the ``with`` body finishes,
    element 0 holds the captured text. On Windows nothing is captured and
    the list stays ``['']``.
    """
    # No idea how to capture stdout from C++ on Windows
    if WINDOWS:
        yield ['']
        return
    import os
    import fcntl
    import errno
    sys.stdout.flush()
    stdout_fd = os.dup(1)
    r, w = os.pipe()
    try:
        # Point fd 1 at the pipe's write end w: dup is guaranteed to return
        # the lowest free fd, which is the 1 we just closed.
        os.close(1)
        os.dup(w)

        # NOTE: nothing drains the pipe until the body finishes, so output
        # larger than the OS pipe buffer would block the writer.
        captured_stdout = ['']
        yield captured_stdout
        sys.stdout.flush()  # Make sure that Python hasn't buffered anything

        # Drain the pipe without blocking; EAGAIN means it is empty. Raw bytes
        # are collected and decoded once, so a multi-byte character split
        # across two reads is not corrupted and non-ASCII output is accepted.
        fcntl.fcntl(r, fcntl.F_SETFL, os.O_NONBLOCK)
        chunks = []
        while True:
            try:
                chunk = os.read(r, 1000)
            except OSError as e:
                if e.errno != errno.EAGAIN:
                    raise
                break
            if not chunk:  # EOF: all write ends closed; avoid spinning forever
                break
            chunks.append(chunk)
        captured_stdout[0] = b''.join(chunks).decode('utf-8')
    finally:
        # Revert the change, and clean up all fds
        os.close(1)
        os.dup(stdout_fd)
        os.close(stdout_fd)
        os.close(r)
        os.close(w)
3013 
def checkScriptRaisesRegex(self, script, inputs, exception, regex,
                           optimize=True, outputs=None, capture_output=False):
    """
    Assert that ``script(*inputs)`` raises ``exception`` matching ``regex``
    under plain Python, the string frontend (a CompilationUnit built from the
    function's source text) and the Python AST frontend (torch.jit.script).

    ``outputs`` and ``capture_output`` are accepted but not used here.
    """
    # 1) eager Python
    with self.assertRaisesRegex(exception, regex):
        script(*inputs)

    # 2) string frontend: compile the function's own dedented source
    with self.assertRaisesRegex(exception, regex):
        compiled_unit = torch.jit.CompilationUnit(
            textwrap.dedent(inspect.getsource(script)), optimize)
        compiled_fn = getattr(compiled_unit, script.__name__)
        compiled_fn(*inputs)

    # 3) Python AST frontend
    with self.assertRaisesRegex(exception, regex):
        scripted_fn = torch.jit.script(script, optimize)
        scripted_fn(*inputs)
3033 
def test_training_param(self):
    """``self.training`` must be readable (twice) inside a script method."""
    class What(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            # type: (int) -> int
            if self.training:
                r = x
            else:
                r = x + 4
            # check double use of training
            if self.training:
                r = r + 1
            return r

    module = What()
    # default (training) mode takes both training branches: 3 -> 3 -> 4
    self.assertEqual(4, module(3))
    module.train(False)
    # eval mode takes the else branch only: 3 -> 7
    self.assertEqual(7, module(3))
3052 
def test_jitter_bug(self):
    # Compiling these two functions is the whole check; fn is never called.
    @torch.jit.script
    def fn2(input, kernel_size):
        # type: (Tensor, List[int]) -> Tensor
        if kernel_size[0] > 1:
            stride = [2]
        else:
            stride = kernel_size
        print(stride, kernel_size)
        return input

    @torch.jit.script
    def fn(input):
        # type: (Tensor) -> Tensor
        return fn2(input, [1])
3068 
def test_parser_kwargonly(self):
    """Keyword-only parameters parse, and must be passed by keyword."""
    cu = torch.jit.CompilationUnit('''
    def foo(x, *, y) -> Tuple[Tensor, Tensor]:
        return x, x
    def bar(x):
        return foo(x, y=x)
    ''')
    self.assertTrue('*' in cu.module._get_method('foo').pretty_print_schema())
    # passing the keyword-only `y` positionally must be rejected at compile time
    with self.assertRaisesRegex(RuntimeError, "not provided"):
        torch.jit.CompilationUnit('''
        def foo(x, *, y) -> Tuple[Tensor, Tensor]:
            return x, x
        def bar(x):
            return foo(x, x)
        ''')
3084 
def test_annoying_doubles(self):
    """Awkward float constants must round-trip through python_print exactly."""
    mod = types.ModuleType("temp")
    mod.inf = float("inf")
    mod.ninf = float("-inf")
    mod.nan = float("nan")

    with self.disableModuleHook():
        @torch.jit.script
        def foo():
            return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan

        printed, table = foo._get_method('forward').python_print()
        versioned = "op_version_set = 0\n{}".format(printed)
        imported = torch.jit.ScriptModule()
        torch._C._jit_import_methods(imported, versioned, table)
        expected = foo()
        actual = imported()
        # use precise assert, we are checking floating point details
        self.assertTrue(expected[:-1] == actual[:-1])
        # nan != nan, so the last element is checked separately
        self.assertTrue(math.isnan(expected[-1]) and math.isnan(actual[-1]))
3105 
def test_type_annotate(self):
    """``torch.jit.annotate`` for tensors, lists, floats and None values."""
    def foo(a):
        return torch.jit.annotate(torch.Tensor, a)

    def bar():
        a = torch.jit.annotate(List[int], [])
        for _ in range(10):
            a.append(4)
        return a

    def baz(a):
        return torch.jit.annotate(float, a)

    # test annotate none types
    def annotate_none():
        return torch.jit.annotate(Optional[torch.Tensor], None)

    def annotate_none_no_optional():
        return torch.jit.annotate(torch.Tensor, None)

    self.checkScript(foo, (torch.rand(3),))
    self.checkScript(bar, ())
    self.checkScript(baz, (torch.rand(()),))
    self.checkScript(annotate_none, ())
    self.checkScript(annotate_none_no_optional, ())
3134 
def test_robust_op_resolution(self):
    # Deliberately misleading alias: resolution must follow the bound function
    # object (torch.add), not the local name.
    neg = torch.add

    def stuff(x):
        return neg(x, x)

    self.checkScript(stuff, (torch.rand(3),))
3143 
def test_tuple_io(self):
    def stuff(x):
        # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
        first, second = x
        return second, first

    pair = (torch.rand(3), torch.rand(3))
    self.checkScript(stuff, (pair,))
3152 
def test_tuple_create_return(self):
    def stuff2(x):
        # type: (int) -> Tuple[Tensor, Tensor]
        result = (torch.ones(x), torch.zeros(x))
        return result

    size = 3
    self.checkScript(stuff2, (size,))
3159 
def test_list_io(self):
    def stuff3(x):
        # type: (List[int]) -> Tuple[Tensor, List[int]]
        return torch.ones(x), x

    shape = [3, 2]
    self.checkScript(stuff3, (shape,))
3165 
# to avoid defining sum_list in multiple tests
def get_sum_list_fn(self):
    """Return a scriptable function that sums a ``List[int]``."""
    def sum_list(a):
        # type: (List[int]) -> int
        total = 0
        for elem in a:
            total += elem

        return total

    return sum_list
3177 
def test_sum_list_diff_elms(self):
    fn = self.get_sum_list_fn()
    self.checkScript(fn, ([1, 2, 3, 4, 5],))
3180 
def test_sum_list_empty(self):
    fn = self.get_sum_list_fn()
    self.checkScript(fn, ([],))
3183 
def test_sum_list_one(self):
    fn = self.get_sum_list_fn()
    self.checkScript(fn, ([1],))
3186 
def test_sum_list_literal(self):
    """Iterating directly over a list literal."""
    def sum_list():
        # type: () -> int
        total = 0
        for elem in [1, 2, 3, 4, 5]:
            total += elem

        return total

    self.checkScript(sum_list, ())
3198 
def test_sum_list_wrong_type(self):
    # iterating over an int must be rejected when the function is compiled
    with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        @torch.jit.script
        def sum_list(a):
            # type: (int) -> int
            total = 0
            for elem in a:  # noqa: T484
                total += elem

            return total

        sum_list(1)
3212 
def test_bool_list_io(self):
    """Bool lists (passed in, literal, nested) must come back as Python bools."""
    @torch.jit.script
    def stuff4(x):
        # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
        return x, [True, False], [[True]]

    passed, literal, nested = stuff4([True])
    for li in (passed, literal, nested[0]):
        self.assertTrue(type(li[0]) == bool)
3223 
def test_nested_list(self):
    def foo(z):
        # type: (Tuple[int, List[List[int]]]) -> int
        scalar, rows = z
        return rows[0][1]

    arg = (1, [[1, 2], [3, 4]])
    self.checkScript(foo, (arg,))
3230 
def test_nested_list_construct(self):
    # concatenating two List[List[int]] literals
    def foo():
        return [[4]] + [[4, 5]]

    self.checkScript(foo, ())
3235 
def test_tensor_shape(self):
    def f(x):
        return x.shape

    inp = torch.empty(34, 56, 78)
    self.checkScript(f, (inp,))
3243 
def test_tensor_grad(self):
    """``requires_grad`` must be readable from script for both settings."""
    def f(x):
        return x.requires_grad

    for flag in (True, False):
        self.checkScript(f, (torch.tensor(1.0, requires_grad=flag),))
3253 
def test_tensor_dtype(self):
    """``x.dtype == torch.<dtype>`` must evaluate correctly in script."""
    x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
    x_long = torch.empty(34, 56, 78, dtype=torch.long)
    x_float32 = torch.empty(34, 56, 78, dtype=torch.float32)

    @torch.jit.script
    def byte(x):
        return x.dtype == torch.uint8

    @torch.jit.script
    def long(x):
        return x.dtype == torch.long

    @torch.jit.script
    def float32(x):
        return x.dtype == torch.float32

    # each predicate holds only for the tensor carrying its own dtype
    samples = (x_byte, x_long, x_float32)
    for predicate, matching in ((byte, x_byte), (long, x_long), (float32, x_float32)):
        for candidate in samples:
            if candidate is matching:
                self.assertTrue(predicate(candidate))
            else:
                self.assertFalse(predicate(candidate))
3280 
@unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
def test_tensor_device(self):
    """Device equality comparisons in script."""
    cpu = torch.empty(34, 56, 78, device='cpu')
    gpu = torch.empty(34, 56, 78, device='cuda')

    @torch.jit.script
    def same_device(x, y):
        return x.device == y.device

    for lhs, rhs, expected in ((cpu, cpu, True), (gpu, gpu, True), (cpu, gpu, False)):
        if expected:
            self.assertTrue(same_device(lhs, rhs))
        else:
            self.assertFalse(same_device(lhs, rhs))
3293 
@unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
def test_tensor_to_device(self):
    # round trip host -> GPU -> host, using both string and torch.device forms
    def to_device(x):
        return x.to(device="cuda").to(device=torch.device("cpu"))

    inp = torch.ones(3, 4)
    self.checkScript(to_device, (inp,))
3300 
def test_tensor_to_cpu(self):
    def to_cpu(x):
        return x.cpu()

    inp = torch.ones(3, 4)
    scripted = torch.jit.script(to_cpu)
    # eager and scripted results must land on the same device
    self.assertEqual(to_cpu(inp).device, scripted(inp).device)
    self.checkScript(to_cpu, (inp,))
3309 
@unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
def test_tensor_to_cuda(self):
    def to_cuda(x):
        return x.cuda()

    inp = torch.ones(3, 4)
    scripted = torch.jit.script(to_cuda)
    # eager and scripted results must land on the same device
    self.assertEqual(to_cuda(inp).device, scripted(inp).device)
    self.checkScript(to_cuda, (inp,))
3319 
def test_generic_list_errors(self):
    # mixing a nested Tensor list with a nested int list must fail to compile
    with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
        @torch.jit.script
        def foo(x):
            return [[x]] + [[1]]
3325 
def test_script_cu(self):
    """A trivial identity function compiled from a source string."""
    unit = torch.jit.CompilationUnit('''
    def foo(a):
        b = a
        return b
    ''')
    inp = Variable(torch.rand(1))
    self.assertEqual(inp, unit.foo(inp))
3334 
# The compilation unit ingests Python string literals, so an escape sequence
# meant for TorchScript needs its backslash escaped (\\n in Python = \n).
def test_string_cu(self):
    unit = torch.jit.CompilationUnit('''
    def foo(a):
        print(a, """a\\n\tb\\n""", 2, "a\
a")
        return a
    ''')
    graph_text = str(unit.foo.graph)
    # the backslash-newline continuation joins "a" and "a" into "aa"
    FileCheck().check("aa").check("a\\n\\tb\\n").run(graph_text)
3345 
def test_string_ops(self):
    """String concatenation and (in)equality in script."""
    def foo():
        ab = "a" + "b"
        return ab + ab, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab"

    self.checkScript(foo, ())
3352 
def test_string_new_line(self):
    """A raw newline inside a TorchScript string literal is a lexer error."""
    with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
        torch.jit.CompilationUnit('''
        def test_while(a):
            print("
                a")
            return a
        ''')
3361 
def test_string_single_escape(self):
    """A lone backslash escapes the closing quote, leaving the string unterminated."""
    # \\ in this Python literal becomes a single \ in the TorchScript source
    with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
        torch.jit.CompilationUnit('''
        def test_while(a):
            print("\\")
            return a
        ''')
3369 
def test_script_annotation(self):
    @torch.jit.script
    def foo(a):
        return a + a + a

    s = Variable(torch.rand(2))
    expected = s + s + s
    self.assertEqual(expected, foo(s))
3376 
def test_inf(self):
    """``float('inf')`` and ``float('-inf')`` must be usable inside script."""
    @torch.jit.script
    def foo(a):
        return a < float('inf')
    s = torch.rand(1)
    self.assertTrue(foo(s))

    @torch.jit.script
    def bar(a):
        return a > float('-inf')
    s = torch.rand(1)
    # exercise bar (the -inf case); previously foo was asserted twice
    self.assertTrue(bar(s))
3389 
def test_add(self):
    """Binary ``+`` followed by augmented ``+=``."""
    def func(a, b):
        c = a + b
        c += a
        return c

    a, b = [torch.rand(1, requires_grad=True) for _ in range(2)]
    self.checkScript(func, (a, b), optimize=True)
3399 
def test_mul(self):
    def func(a, b):
        return a * b

    a, b = [torch.rand(1, requires_grad=True) for _ in range(2)]
    self.checkScript(func, (a, b), optimize=True)
3407 
@unittest.skipIf(not PY35, "Python 3.5 needed")
def test_matmul_py3(self):
    # The `@` operator only parses on Python 3.5+, so the function lives in a
    # temporary file instead of being defined inline in this module.
    code = dedent("""
    def fn(a, b):
        return a @ b
    """)

    with tempfile.TemporaryDirectory() as tmp_dir:
        script_path = os.path.join(tmp_dir, 'script.py')
        with open(script_path, 'w') as f:
            f.write(code)
        fn = get_fn('test_matmul_py3', script_path)

        lhs = torch.rand(4, 3, requires_grad=True)
        rhs = torch.rand(3, 2, requires_grad=True)
        self.checkScript(fn, (lhs, rhs), optimize=True)
3424 
def test_pow(self):
    def func(a, b):
        return a ** b

    def func2(a, b, c, d):
        # ** is right-associative: c + a ** (b ** d)
        return c + a ** b ** d

    a, b, c, d = [torch.rand(1, requires_grad=True) for _ in range(4)]
    self.checkScript(func, (a, b), optimize=True)
    self.checkScript(func2, (a, b, c, d), optimize=True)
3438 
def test_triple(self):
    # float literal on the left-hand side of a tensor multiply
    def func(x):
        return 3. * x

    inp = torch.rand(1, dtype=torch.float, requires_grad=True)
    self.checkScript(func, [inp], optimize=True)
3445 
def test_slice(self):
    """Open-ended slices on either side."""
    def func(x):
        return x[:5]

    def func2(x):
        return x[5:]

    inp = torch.rand(10, dtype=torch.float, requires_grad=True)
    for sliced in (func, func2):
        self.checkScript(sliced, [inp], optimize=True)
3457 
def test_gather(self):
    # integer indexing of a 1-d tensor
    def func(x):
        return x[0]

    inp = torch.rand(10, dtype=torch.float, requires_grad=True)
    self.checkScript(func, [inp], optimize=True)
3464 
def test_random(self):
    """Scripted torch.normal must draw the same values as eager mode."""
    @torch.jit.script
    def f(mean, std):
        return torch.normal(mean, std)

    mean, std = torch.zeros(5, 5), torch.ones(5, 5)
    # fork_rng restores the RNG state on exit, so both draws start from
    # the same generator state
    with torch.random.fork_rng(devices=[]):
        eager_out = torch.normal(mean, std)
    with torch.random.fork_rng(devices=[]):
        scripted_out = f(mean, std)
    self.assertEqual(eager_out, scripted_out)
3476 
def _check_code(self, code_str, fn_name, inputs):
    """Run ``fn_name`` from ``code_str`` eagerly and via TorchScript; compare.

    Args:
        code_str: Python/TorchScript source defining a function ``fn_name``.
        fn_name: name of the function to look up in both environments.
        inputs: positional arguments passed to both versions.
    """
    scope = {}
    exec(code_str, globals(), scope)
    cu = torch.jit.CompilationUnit(code_str)
    # look the compiled function up by name instead of assuming it is `func`
    self.assertEqual(getattr(cu, fn_name)(*inputs), scope[fn_name](*inputs))
3482 
@unittest.skipIf(not RUN_CUDA, 'no CUDA')
def test_scriptmodule_releases_tensors_cuda(self):
    """Repeated scripted calls, with and without backward, must not leak CUDA tensors."""
    @torch.jit.script
    def fn(x, y):
        return x.sigmoid() * y.tanh()

    def run_once(backward=False):
        x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
        y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
        out = fn(x, y)
        if backward:
            out.sum().backward()

    for backward in (False, True):
        with self.assertLeaksNoCudaTensors():
            for _ in range(3):
                run_once(backward=backward)
3505 
def test_index(self):
    """Basic indexing and slicing must agree between Python and TorchScript."""
    def consec(size, start=0):
        # Tensor of shape `size` filled with start, start + 1, ... (the
        # `start` argument used to be silently ignored).
        numel = torch.tensor(size).prod().item()
        return torch.arange(start, start + numel).view(size)

    def check_indexing(indexing, tensor):
        template = dedent("""
        def func(x):
            return x{}
        """)

        self._check_code(template.format(indexing), "func", [tensor])

    def check_dynamic_indexing(indexing, tensor, value1, value2):
        value1 = torch.tensor(value1)
        value2 = torch.tensor(value2)

        template = dedent("""
        def func(x, value1, value2):
            i = int(value1)
            j = int(value2)
            return x{}
        """)

        self._check_code(template.format(indexing), "func", [tensor, value1, value2])

    # basic slices
    check_indexing('[0]', consec((3, 3)))
    check_indexing('[1]', consec((3, 3), 10))
    check_indexing('[2]', consec((3, 3), 19))
    check_indexing('[2]', consec((3,)))
    check_indexing('[-1]', consec((3, 3), 19))
    check_indexing('[0:2]', consec((3, 3, 3)))
    check_indexing('[1:-1]', consec((3, 3, 3)))
    check_indexing('[-3:-1]', consec((6, 3)))
    check_indexing('[1:]', consec((3, 3)))
    check_indexing('[:1]', consec((3, 3)))
    check_indexing('[:]', consec((3, 2)))

    # multi-dim: indexes
    check_indexing('[0, 1]', consec((3, 3)))
    check_indexing('[0, 1]', consec((3, 3, 2)))
    check_indexing('[1, 0, 2]', consec((3, 3, 3)))
    check_indexing('[2, -1]', consec((3, 3)))

    # multi-dim: mixed slicing and indexing
    check_indexing('[0, 1:2]', consec((3, 3)))
    check_indexing('[0, :1]', consec((3, 3, 2)))
    check_indexing('[1, 2:]', consec((3, 3, 3)))
    check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
    check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
    check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
    check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))

    # zero-sized slices
    check_indexing('[0:0]', consec((2, 2)))
    check_indexing('[0:0, 1]', consec((3, 3)))

    # trivial expression usage
    check_indexing('[1+1]', consec((3, 3)))
    check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))

    # dynamic expression usage
    check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
    check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
3572 
def test_tensor_item(self):
    """Scalars produced by ``Tensor.item()`` inside TorchScript."""
    def test_scalar_to_float_coercion(x):
        return x.item() == 1

    self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1.0),))
    self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1),))

    def test_scalar_cast(x):
        scalar = x.item()
        return int(scalar), float(scalar)

    # exercise explicit int()/float() casts of the item() result
    self.checkScript(test_scalar_cast, (torch.tensor(1.0),))
    self.checkScript(test_scalar_cast, (torch.tensor(1),))

    # passing the untyped item() result where an int is required must fail
    # at compile time with a hint to use int()/float()
    expected_str = r"Use int\(tensor\) or float\(tensor\) to retrieve"
    with self.assertRaisesRegex(RuntimeError, expected_str):
        @torch.jit.script
        def int_fn(a):
            # type: (int) -> int
            return a

        @torch.jit.script
        def test_error_msg(x):
            return int_fn(x.item())
3597 
def test_method_on_number(self):
    def func():
        num = 1
        return num.add(1)

    # numbers have no methods in TorchScript, so compilation must fail
    with self.assertRaisesRegex(RuntimeError, 'Cannot call methods on numbers'):
        torch.jit.script(func)
3604 
# testing implicit conversion of tensors to scalars to match function arguments
def test_scalar_to_num_conversions(self):
    """Implicit Tensor -> int/float conversion must match eager behavior."""
    @torch.jit.script
    def multiple_defs(x):
        c = 1
        x = x + c
        return x

    # adding a plain int constant must not insert an implicit conversion
    self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))

    @torch.jit.script
    def tensor_to_int_script(x, tensor):
        return x.unsqueeze(tensor)

    def tensor_to_int(x, tensor):
        return x.unsqueeze(tensor)

    @torch.jit.script
    def tensor_to_float_script(x, tensor):
        return x.addcmul(tensor, tensor, value=tensor)

    def tensor_to_float(x, tensor):
        return x.addcmul(tensor, tensor, value=tensor)

    x = torch.zeros(10)
    # float tensor, float tensor with grad, int tensor (can't set grad on an
    # int tensor), and a one-element 1-d tensor
    tensors = [torch.tensor(1.1),
               torch.tensor(1.1, requires_grad=True),
               torch.tensor(0),
               torch.tensor([2])]

    script_funs = [tensor_to_int_script, tensor_to_float_script]
    funs = [tensor_to_int, tensor_to_float]

    # return the result, or True if the call raised
    def test_func(func, x, tensor):
        try:
            return func(x, tensor)
        except (RuntimeError, TypeError):
            return True

    # assert result or exception equal for each (function, inputs)
    for tensor in tensors:
        for script_fun, fun in zip(script_funs, funs):
            self.assertEqual(test_func(script_fun, x, tensor), test_func(fun, x, tensor))
3653 
def test_tuple_to_opt_list(self):
    """A tuple literal must be accepted where Optional[List[int]] is expected."""
    @torch.jit.script
    def foo(x):
        # type: (Optional[List[int]]) -> int
        return 1

    # successful compilation is the check; tuple_call is never invoked
    @torch.jit.script
    def tuple_call():
        return foo((1, 2))
3663 
def test_advancedindex(self):
    """Advanced (tensor) indexing must agree between Python and TorchScript."""
    def consec(size, start=0):
        # Tensor of shape `size` filled with start, start + 1, ... (the
        # `start` argument used to be silently ignored).
        numel = torch.tensor(size).prod().item()
        return torch.arange(start, start + numel).view(size)

    def check_indexing(indexing, tensor, **kwargs):
        # Every keyword becomes an extra formal parameter of `func`; its value
        # is passed positionally in the same iteration order.
        template = dedent("""
        def func(x{formals}):
            return x{expr}
        """)

        formals = ''.join(', {}'.format(name) for name in kwargs)
        inputs = [tensor] + list(kwargs.values())
        self._check_code(template.format(formals=formals, expr=indexing),
                         "func", inputs)

    # Indexing with tensor (basic)
    check_indexing('[i]', consec((3, 3)), i=torch.tensor([0]))
    check_indexing('[i]', consec((3, 3)), i=torch.tensor(1))
    check_indexing('[i]', consec((3, 3)), i=torch.tensor([-2]))
    check_indexing('[i]', consec((3, 3), 2), i=torch.tensor([0, 0]))
    check_indexing('[i]', consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))

    # NB: indexing with tensors and indexing with sequences can be implemented
    # in a very similar way (sequences are converted to tensors), so only one
    # case needs to be tested extensively.
    # XXX: When we can index with sequences, replace these cases with
    # sequence indexing expressions; those are much easier to read.

    # Misc sequence advanced indexing
    inp = consec((4, 8, 5))
    to_check = [
        # [[0, 2], [1, 3]]
        ['[i, j]', {'i': [0, 2], 'j': [1, 3]}],
        # [[0, 2], [1, 3], [1, 1]]
        ['[i, j, k]', {'i': [0, 2], 'j': [1, 3], 'k': [1, 1]}],
        # [[0, 2], 1, [1, 1]]
        ['[i, j, k]', {'i': [0, 2], 'j': 1, 'k': [1, 1]}],
        # [:, :, [0, 3, 4]]
        ['[:, :, i]', {'i': [0, 3, 4]}],
        # [:, [0, 2, 3], 2:4]
        ['[:, i, 2:4]', {'i': [0, 2, 3]}],
        # [[2, 3], :, :]
        ['[i, :, :]', {'i': [2, 3]}],
        # [:, [0, 2, 3], [1, 3, 4]]
        ['[:, i, j]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
        # [:, [0], [1, 2, 4]]
        ['[:, i, j]', {'i': [0], 'j': [1, 2, 4]}],
        # [:, [0, 1, 3], [4]]
        ['[:, i, j]', {'i': [0, 1, 3], 'j': [4]}],
        # [:, [[0, 1], [1, 0]], [[2, 3]]]
        ['[:, i, j]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
        # [:, [[0, 1], [2, 3]], [[0]]]
        ['[:, i, j]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
        # [:, [[5, 6]], [[0, 3], [4, 4]]]
        ['[:, i, j]', {'i': [[5, 6]], 'j': [[0, 3], [4, 4]]}],
        # [[0, 2, 3], [1, 3, 4], :]
        ['[i, j, :]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
        # [0, [1, 2, 4], :]
        ['[i, j, :]', {'i': 0, 'j': [1, 2, 4]}],
        # [[0, 1, 3], 4, :]
        ['[i, j, :]', {'i': [0, 1, 3], 'j': 4}],
        # [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :]
        ['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 1], [3, 5]]}],
        # [[[0, 1], [1, 0]], [[2, 3]], :]
        ['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
        # [[[0, 1], [2, 3]], [[0]], :]
        ['[i, j, :]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
        # [[[2, 1]], [[0, 3], [4, 4]], :]
        ['[i, j, :]', {'i': [[2, 1]], 'j': [[0, 3], [4, 4]]}],
        # [[[2]], [[0, 3], [4, 1]], 0:2]
        ['[i, j, 0:2]', {'i': [[2]], 'j': [[0, 3], [4, 1]]}],
    ]

    for expr, argdict in to_check:
        tensordict = {k: torch.tensor(v) for (k, v) in argdict.items()}
        check_indexing(expr, inp, **tensordict)
3749 
def test_keyword(self):
    """Keyword arguments to builtin ops in script."""
    @torch.jit.script
    def func(x):
        return torch.sum(x, dim=0)

    inp = torch.rand(10, dtype=torch.float, requires_grad=True)
    scripted_out = func(inp)
    self.assertEqual(scripted_out, torch.sum(inp, dim=0))
3759 
def test_constant_pooling_none(self):
    """Typed None constants must be pooled to one node per type."""
    @torch.jit.script
    def typed_nones(a=None, b=None, c=None):
        # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]] # noqa
        return a, b, c

    @torch.jit.script
    def test(a):
        # type: (bool) -> None
        if a:
            print(typed_nones())
        else:
            print(typed_nones())

    # typed_nones() is called from both branches, yet each typed None
    # constant should appear exactly once in the graph
    graph_str = str(test.graph)
    for constant in ("bool? = prim::Constant", "int? = prim::Constant", "None = prim::Constant"):
        self.assertTrue(graph_str.count(constant) == 1)
3778 
def test_literal(self):
    """Tuple literals, nested unpacking, and tuples reassigned inside a loop."""
    def func1(a, b):
        pair = a, b
        first, second = pair
        return first + second

    def func2(a, b):
        nested = a, (a, b)
        outer, inner = nested
        left, right = inner
        return outer + left + right

    def func3(a, b):
        # type: (float, float) -> float
        nested = 0., (0., 0.)
        keep_going = True
        while keep_going:
            keep_going = False
            nested = a, (a, b)
        outer, inner = nested
        left, right = inner
        return outer + left + right

    a = torch.rand(1, requires_grad=True)
    b = torch.rand(1, requires_grad=True)
    self.checkScript(func1, (a, b), optimize=True)
    self.checkScript(func2, (a, b), optimize=True)
    self.checkScript(func3, (a.item(), b.item()), optimize=True)
3807 
def test_expand(self):
    """Broadcasting add in script, including the reduced gradient."""
    @torch.jit.script
    def func(x, y):
        return x + y

    lhs = torch.rand(2, 3, dtype=torch.float, requires_grad=True)
    rhs = torch.rand(3, dtype=torch.float, requires_grad=True)
    out = func(lhs, rhs)
    self.assertEqual(func(lhs, rhs), lhs + rhs)

    # rhs was broadcast over dim 0, so its gradient is the upstream grad summed over dim 0
    upstream = torch.randn(2, 3, dtype=torch.float)
    out.backward(upstream)
    self.assertEqual(lhs.grad, upstream)
    self.assertEqual(rhs.grad, upstream.sum(dim=0))
3822 
def test_sum(self):
    @torch.jit.script
    def func(x):
        return x.sum(dim=[4])

    @torch.jit.script
    def func2(x):
        return x.sum(dim=4)

    # test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
    scripted = (func, func2)
    for fn in scripted:
        self.run_pass('constant_propagation', fn.graph)
    for fn in scripted:
        torch._C._jit_pass_shape_analysis(
            fn.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
    for fn in scripted:
        out_type = fn.graph.findNode("aten::sum").output().type()
        self.assertTrue(out_type.kind() == "DimensionedTensorType")
3843 
def test_cat(self):
    """torch.cat with a keyword dim and with a tensor-valued dim."""
    @torch.jit.script
    def func(x):
        return torch.cat((x, x), dim=0)

    vec = torch.rand(10, dtype=torch.float, requires_grad=True)
    scripted_out = func(vec)
    self.assertEqual(scripted_out, torch.cat((vec, vec), dim=0))

    @torch.jit.script
    def func2(x, y):
        return torch.cat((x, x), y)

    mat = torch.rand([2, 2])
    dim = torch.tensor(1)
    scripted_out = func2(mat, dim)
    self.assertEqual(scripted_out, torch.cat((mat, mat), dim))
3859 
def test_cat_lifts(self):
    """Each cat graph contains an `int =` line, then a ListConstruct, then aten::cat, in that order."""
    @torch.jit.script
    def foo(x):
        return torch.cat([x, x], dim=1)

    @torch.jit.script
    def foo2(x):
        return torch.cat([], dim=1)

    @torch.jit.script
    def foo3(x):
        return torch.cat([x], dim=1)

    graphs = (foo.graph, foo2.graph, foo3.graph)
    for graph in graphs:
        checker = FileCheck().check("int =").check("ListConstruct").check("aten::cat")
        checker.run(str(graph))
3875 
def test_list_literal(self):
    """Reassigning list variables: compatible literals compile, mismatched element types are rejected."""
    def reassign():
        x = [1]
        if True:
            x = [2, 3]
        return

    def reassign_arity_change():
        x = [1]
        if True:
            x = [1, 2, 3]
        return

    def reassign_from_empty_literal():
        x = []
        if True:
            x = [1, 2, 3]
        return

    def reassign_from_empty_builtin():
        x = torch.jit.annotate(List[int], [])
        if True:
            x = [1, 2, 3]
        y = torch.jit.annotate(List[float], [])
        if True:
            y = [1.0, 2.0, 3.0]
        z = []
        if True:
            z = [torch.randn([1])]
        return

    def reassign_bad_type():
        x = [1]
        if True:
            x = [1.0]
        return

    def reassign_nested():
        x = torch.jit.annotate(List[int], [])
        if True:
            x = [1, 2, 3]
            if True:
                x = [1.0]
        return

    self.checkScript(reassign, (), optimize=False)
    self.checkScript(reassign_arity_change, (), optimize=False)
    # The un-annotated empty literal is typed Tensor[], so assigning ints afterwards is rejected.
    with self.assertRaisesRegex(RuntimeError, r"previously has type Tensor\[\]"):
        self.checkScript(reassign_from_empty_literal, (), optimize=False)
    self.checkScript(reassign_from_empty_builtin, (), optimize=False)
    for bad_fn in (reassign_bad_type, reassign_nested):
        with self.assertRaisesRegex(RuntimeError, "previously has type"):
            self.checkScript(bad_fn, (), optimize=False)
3929 
def test_list_gather(self):
    """Indexing a list literal with in-range, negative, and out-of-range indices."""
    def index():
        items = [1, 2, 3]
        return items[1]

    def negative_index():
        items = [1, 2, 3]
        return items[-1]

    def bad_index():
        items = [1, 2, 3]
        return items[4]

    def bad_negative_index():
        items = [1, 2, 3]
        return items[-5]

    for valid in (index, negative_index):
        self.checkScript(valid, ())
    for invalid in (bad_index, bad_negative_index):
        self.checkScriptRaisesRegex(invalid, (), IndexError,
                                    "list index out of range")
3956 
def test_tensor_len(self):
    """len() of a tensor agrees between script and eager."""
    def tensor_length(x):
        return len(x)

    sample = torch.ones(4, 5, 6)
    self.checkScript(tensor_length, [sample])
3962 
def test_list_len(self):
    """len() of non-empty and empty list literals."""
    def three_items():
        nums = [1, 2, 3]
        return len(nums) == 3

    def no_items():
        nums = []
        return len(nums) == 0

    for fn in (three_items, no_items):
        self.checkScript(fn, ())
3975 
def test_list_ops(self):
    """List ==, !=, +, list equality used as a condition, and equality over tensor lists.

    Each helper is compiled and compared against eager Python via checkScript.
    The last case compares lists of multi-element tensors, whose elementwise
    comparison cannot be reduced to a bool, so it is expected to raise.
    """
    def test_equality():
        a = [1, 2, 3]
        b = [1, 2, 3]
        return a == b

    self.checkScript(test_equality, (), optimize=True)

    def test_inequality():
        a = [1, 2, 3]
        b = [1, 2, 3]
        return a != b

    # This previously re-checked test_equality, so list `!=` was never exercised.
    self.checkScript(test_inequality, (), optimize=True)

    def test_non_equality():
        a = [1, 2, 3]
        b = [3]
        return a == b

    self.checkScript(test_non_equality, (), optimize=True)

    def test_non_inequality():
        a = [1, 2, 3]
        b = [3]
        return a != b

    # This previously re-checked test_non_equality instead.
    self.checkScript(test_non_inequality, (), optimize=True)

    def test_list_equality_as_cond():
        a = [1, 2, 3]
        b = [3]
        if a == b:
            c = 1
        else:
            c = 2
        return c

    self.checkScript(test_list_equality_as_cond, (), optimize=True)

    def test_list_add():
        a = [1, 2, 3]
        b = [2]
        c = a + b
        return c == [1, 2, 3, 2]

    self.checkScript(test_list_add, (), optimize=True)

    def test_list_add_empty():
        a = [1, 2, 3]
        b = torch.jit.annotate(List[int], [])
        c = a + b
        return c == [1, 2, 3]

    self.checkScript(test_list_add_empty, (), optimize=True)

    def test_tensor_list_equality():
        t1 = torch.ones([1, 1])
        t2 = torch.ones([1, 1])
        x = [t1, t2]
        y = [t2, t1]
        return x == y

    self.checkScript(test_tensor_list_equality, (), optimize=True)

    def test_invalid_list_equality():
        t1 = torch.ones([2, 2])
        t2 = torch.ones([2, 2])
        x = [t1, t2]
        y = [t2, t1]
        # will throw since the tensors have more than one element
        return x == y

    self.checkScriptRaisesRegex(
        test_invalid_list_equality,
        (),
        RuntimeError,
        "bool value of Tensor")
4054 
def test_list_slice(self):
    """Slicing list literals: regular, open-ended, negative, backward, and past-the-end bounds."""
    def test_regular_slice():
        a = [0, 1, 2, 3, 4]
        return a[2:3] == [2]
    self.checkScript(test_regular_slice, ())

    def test_open_ended_slice():
        a = [0, 1, 2, 3, 4]
        return a[2:] == [2, 3, 4]
    self.checkScript(test_open_ended_slice, ())

    def test_open_ended_slice2():
        a = [0, 1, 2, 3, 4]
        return a[:2] == [0, 1]
    self.checkScript(test_open_ended_slice2, ())

    def test_negative_slice():
        a = [0, 1, 2, 3, 4]
        return a[:-1] == [0, 1, 2, 3]
    self.checkScript(test_negative_slice, ())

    def test_negative_slice2():
        a = [0, 1, 2, 3, 4]
        return a[-3:-1] == [2, 3]
    self.checkScript(test_negative_slice2, ())

    def test_backward_slice():
        a = [0, 1, 2, 3, 4]
        return a[3:2] == torch.jit.annotate(List[int], [])
    self.checkScript(test_backward_slice, ())

    def test_over_slice():
        a = [0, 1, 2, 3, 4]
        return a[3:10] == [3, 4]
    # Previously test_backward_slice was checked a second time here, so the
    # past-the-end slice was never exercised.
    self.checkScript(test_over_slice, ())
4090 
def test_mutable_list_append(self):
    """append() grows a list literal in place, preserving order."""
    def append_twice():
        items = [0, 1]
        items.append(2)
        items.append(3)
        return items == [0, 1, 2, 3]

    self.checkScript(append_twice, ())
4098 
def test_mutable_list_append_2(self):
    """After rebinding a list name, append() acts on the new list only."""
    def append_after_rebind():
        seq = [0, 1]
        seq.append(2)
        seq = [1]
        seq.append(4)
        return seq == [1, 4]

    self.checkScript(append_after_rebind, ())
4107 
def test_mutable_list_append_if(self):
    """append() inside an if-branch mutates the list visible after the branch."""
    def append_in_branch():
        seq = [1]
        if True:
            seq.append(4)
        return seq == [1, 4]

    self.checkScript(append_in_branch, ())
4115 
def test_mutable_list_append_if_else(self):
    """Only the taken branch's append() is reflected in the list."""
    def append_in_else():
        seq = [1]
        if False:
            seq.append(4)
        else:
            seq.append(10)
        return seq == [1, 10]

    self.checkScript(append_in_else, ())
4125 
def test_mutable_list_append_loop(self):
    """append() inside a for-range loop accumulates every iteration's value."""
    def append_each_index():
        acc = torch.jit.annotate(List[int], [])
        for k in range(5):
            acc.append(k)

        return acc == [0, 1, 2, 3, 4]

    self.checkScript(append_each_index, ())
4134 
def test_mutable_list_append_loop_if(self):
    """append() in both branches of an if inside a loop."""
    def append_conditionally():
        acc = torch.jit.annotate(List[int], [])
        for k in range(5):
            if k > 3:
                acc.append(k)
            else:
                acc.append(0)

        return acc == [0, 0, 0, 0, 4]

    self.checkScript(append_conditionally, ())
4146 
def test_mutable_list_nested_loop(self):
    """append() from the body of a doubly nested loop."""
    def append_in_nested_loops():
        acc = torch.jit.annotate(List[int], [])
        for row in range(2):
            for col in range(2):
                acc.append(row + col)

        return acc == [0, 1, 1, 2]

    self.checkScript(append_in_nested_loops, ())
4156 
def test_mutable_list_function_inline(self):
    """A scripted callee that appends to its List[int] argument mutates the caller's list."""
    @torch.jit.script
    def append_four(y):
        # type: (List[int]) -> None
        y.append(4)

    @torch.jit.script
    def build():
        values = [1, 2, 3]
        append_four(values)
        return values

    expected = [1, 2, 3, 4]
    self.assertEqual(build(), expected)
4170 
def test_mutable_list_reverse_empty(self):
    """reverse() on an empty list leaves it empty."""
    def reverse_nothing():
        items = []
        items.reverse()

        return items == []

    self.checkScript(reverse_nothing, ())
4178 
def test_mutable_list_reverse(self):
    """reverse() reverses an int list in place."""
    def reverse_ints():
        items = [1, 2, 3, 4]
        items.reverse()

        return items == [4, 3, 2, 1]

    self.checkScript(reverse_ints, ())
4186 
def test_mutable_tensor_list_reverse(self):
    """reverse() reverses a list of tensors in place."""
    def reverse_tensors():
        items = [torch.tensor(1), torch.tensor(2)]
        items.reverse()

        return items == [torch.tensor(2), torch.tensor(1)]

    self.checkScript(reverse_tensors, ())
4194 
def test_mutable_list_pop_empty(self):
    """pop() on an empty List[int] raises at runtime."""
    @torch.jit.script
    def pop_from_empty():
        empty = torch.jit.annotate(List[int], [])
        return empty.pop()

    with self.assertRaisesRegex(RuntimeError, "pop from empty list"):
        pop_from_empty()
4203 
def test_mutable_list_pop(self):
    """pop() with no index returns the last element."""
    def pop_last():
        items = [1, 2, 3, 4]
        last = items.pop()

        return last == 4

    self.checkScript(pop_last, ())
4212 
def test_mutable_list_pop2(self):
    """pop() with no index shrinks the list by one element."""
    def test_pop2():
        a = [1, 2, 3, 4]
        # Only the resulting length matters here, so the popped value is discarded
        # (it was previously bound to an unused local).
        a.pop()

        return len(a) == 3

    self.checkScript(test_pop2, ())
4221 
def test_mutable_list_pop_at(self):
    """pop(i) returns the element at index i."""
    def pop_second():
        items = [1, 2, 3, 4]
        removed = items.pop(1)

        return removed == 2

    self.checkScript(pop_second, ())
4230 
def test_mutable_list_pop_at2(self):
    """pop(i) removes exactly one element from the list."""
    def test_pop_at2():
        a = [1, 2, 3, 4]
        # Only the resulting length matters here, so the popped value is discarded
        # (it was previously bound to an unused local).
        a.pop(1)

        return len(a) == 3

    self.checkScript(test_pop_at2, ())
4239 
def test_mutable_list_pop_at_negative(self):
    """pop() with a negative index counts from the end of the list."""
    def pop_second_from_end():
        items = [1, 2, 3, 4]
        removed = items.pop(-2)

        return removed == 3

    self.checkScript(pop_second_from_end, ())
4248 
def test_mutable_list_pop_at_negative2(self):
    """pop() with a negative index removes exactly one element."""
    def test_pop_at_negative2():
        a = [1, 2, 3, 4]
        # Only the resulting length matters here, so the popped value is discarded
        # (it was previously bound to an unused local).
        a.pop(-2)

        return len(a) == 3

    self.checkScript(test_pop_at_negative2, ())
4257 
def test_mutable_list_pop_slice(self):
    """Popping the last element matches slicing it off with [:-1]."""
    def pop_vs_slice():
        popped = [1, 2, 3, 4]
        sliced = [1, 2, 3, 4]

        popped.pop()
        sliced = sliced[:-1]

        return popped == sliced

    self.checkScript(pop_vs_slice, ())
4269 
@unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3")
def test_mutable_list_clear_empty(self):
    """clear() on an already-empty list is a no-op."""
    def clear_nothing():
        items = torch.jit.annotate(List[int], [])
        items.clear()

        return len(items) == 0

    self.checkScript(clear_nothing, ())
4278 
@unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3")
def test_mutable_list_clear(self):
    """clear() removes every element from a list."""
    def clear_all():
        items = [1, 2, 3, 4]
        items.clear()

        return len(items) == 0

    self.checkScript(clear_all, ())
4287 
def test_mutable_list_insert(self):
    """insert(i, v) places v before the element currently at index i."""
    def insert_middle():
        items = [1, 2, 3, 4]
        items.insert(2, 5)

        return items == [1, 2, 5, 3, 4]

    self.checkScript(insert_middle, ())
4295 
def test_mutable_list_insert_negative(self):
    """insert() with a negative index counts from the end."""
    def insert_before_last():
        items = [1, 2, 3, 4]
        items.insert(-1, 5)

        return items == [1, 2, 3, 5, 4]

    self.checkScript(insert_before_last, ())
4303 
def test_mutable_list_insert_neg_out_of_bounds(self):
    """insert() with a negative index below the start inserts at the front."""
    def insert_far_negative():
        items = [1, 2, 3, 4]
        items.insert(-10, 5)

        return items == [5, 1, 2, 3, 4]

    self.checkScript(insert_far_negative, ())
4311 
def test_mutable_list_insert_out_of_bounds(self):
    """insert() with an index past the end appends."""
    def insert_far_positive():
        items = [1, 2, 3, 4]
        items.insert(10, 5)

        return items == [1, 2, 3, 4, 5]

    self.checkScript(insert_far_positive, ())
4319 
def test_mutable_list_remove_not_existing(self):
    """remove() of a value absent from the list raises at runtime."""
    @torch.jit.script
    def remove_missing():
        items = [1, 2, 3, 4]
        items.remove(5)

        return items

    with self.assertRaisesRegex(RuntimeError, "x not in list"):
        remove_missing()
4330 
def test_mutable_list_remove(self):
    """remove(v) deletes the matching element from an int list."""
    def remove_three():
        items = [1, 2, 3, 4]
        items.remove(3)

        return items == [1, 2, 4]

    self.checkScript(remove_three, ())
4338 
def test_list_index_not_existing(self):
    """index() of a value absent from an int list raises at runtime."""
    @torch.jit.script
    def index_of_missing():
        items = [4, 1, 3, 2]
        pos = items.index(5)

        return pos

    with self.assertRaisesRegex(RuntimeError, "'5' is not in list"):
        index_of_missing()
4349 
def test_list_index(self):
    """index(v) returns the position of v in an int list."""
    def index_of_three():
        items = [4, 1, 3, 2]
        pos = items.index(3)

        return pos == 2

    self.checkScript(index_of_three, ())
4357 
def test_tensor_list_index(self):
    """index(t) finds an equal tensor in a list of tensors."""
    def index_of_tensor():
        items = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
        pos = items.index(torch.tensor(3))

        return pos == 2

    self.checkScript(index_of_tensor, ())
4365 
def test_tensor_list_index_not_existing(self):
    """index() of a tensor absent from a tensor list raises at runtime."""
    @torch.jit.script
    def index_of_missing_tensor():
        items = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
        pos = items.index(torch.tensor(5))

        return pos

    with self.assertRaisesRegex(RuntimeError, "is not in list"):
        index_of_missing_tensor()
4376 
def test_list_count(self):
    """count(v) returns the number of occurrences of v in an int list."""
    def count_fours():
        items = [4, 1, 4, 2, 4]
        hits = items.count(4)

        return hits == 3

    self.checkScript(count_fours, ())
4384 
def test_list_count_not_existing(self):
    """count() of an absent value returns 0."""
    def count_missing():
        items = [4, 1, 4, 2, 4]
        hits = items.count(5)

        return hits == 0

    self.checkScript(count_missing, ())
4392 
def test_tensor_list_count(self):
    """count(t) counts tensors equal to t in a tensor list."""
    def count_tensor_fours():
        items = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
        hits = items.count(torch.tensor(4))

        return hits == 3

    self.checkScript(count_tensor_fours, ())
4400 
def test_tensor_list_count_not_existing(self):
    """count() of a tensor absent from a tensor list returns 0."""
    def count_missing_tensor():
        items = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
        hits = items.count(torch.tensor(5))

        return hits == 0

    self.checkScript(count_missing_tensor, ())
4408 
def test_mutable_list_remove_tensor(self):
    """remove(t) deletes one matching tensor from a tensor list."""
    def remove_zero_tensor():
        items = [torch.ones(1), torch.zeros(1), torch.ones(2)]
        items.remove(torch.zeros(1))

        return len(items) == 2

    self.checkScript(remove_zero_tensor, ())
4416 
def test_mutable_list_remove2(self):
    """remove() of the only element leaves an empty list."""
    def remove_only_item():
        items = [1]
        items.remove(1)

        return len(items) == 0

    self.checkScript(remove_only_item, ())
4424 
def test_extend_list_mutable(self):
    """extend() on List[Tensor] arguments matches Python list concatenation."""
    @torch.jit.script
    def extend_list(a, b):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        a.extend(b)
        return a

    def tensor_lists():
        # Called per loop level so every iteration gets freshly built lists.
        return [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]

    for lhs in tensor_lists():
        for rhs in tensor_lists():
            self.assertEqual(extend_list(lhs, rhs), lhs + rhs)
4436 
def test_extend_list_immutable(self):
    """extend() on List[int] arguments matches Python list concatenation."""
    @torch.jit.script
    def extend_list(a, b):
        # type: (List[int], List[int]) -> List[int]
        a.extend(b)
        return a

    def int_lists():
        # Called per loop level so every iteration gets freshly built lists.
        return [[], [1], [1, 2, 3]]

    for lhs in int_lists():
        for rhs in int_lists():
            self.assertEqual(extend_list(lhs, rhs), lhs + rhs)
4448 
def test_copy_list_mutable(self):
    """copy() of a List[Tensor] compares equal to the source list."""
    @torch.jit.script
    def copy_list(a):
        # type: (List[Tensor]) -> List[Tensor]
        return a.copy()

    candidates = [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]
    for source in candidates:
        self.assertEqual(copy_list(source), source)
4457 
def test_copy_list_immutable(self):
    """copy() of a List[int] compares equal to the source list."""
    @torch.jit.script
    def copy_list(a):
        # type: (List[int]) -> List[int]
        return a.copy()

    candidates = [[], [1], [1, 2, 3]]
    for source in candidates:
        self.assertEqual(copy_list(source), source)
4466 
def test_func_call(self):
    """Functions defined in one script source can call each other."""
    script = '''
    def add(a, b):
        return a + b

    def mul(a, x):
        return a * x

    def func(alpha, beta, x, y):
        return add(mul(alpha, x), mul(beta, y))
    '''
    scale_x = torch.rand(1, dtype=torch.float, requires_grad=True)
    scale_y = torch.rand(1, dtype=torch.float, requires_grad=True)
    vec_x = torch.rand(3, dtype=torch.float, requires_grad=True)
    vec_y = torch.rand(3, dtype=torch.float, requires_grad=True)
    expected = scale_x * vec_x + scale_y * vec_y
    # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
    self.checkScript(script, [scale_x, scale_y, vec_x, vec_y], optimize=False, outputs=expected)
4485 
def test_resize_input_ops(self):
    """Shape analysis drops size info for any value that may alias a resized tensor.

    resize_ and resize_as_ resize the input tensor. Because shape analysis is
    flow invariant, any Tensor that can alias a resized Tensor is set to the
    base Tensor type, without size information.
    """
    # testing that value which is an input of a graph gets handled
    def out_op_graph_input():
        @torch.jit.script
        def test(x, y, z):
            torch.mul(x, y, out=z)
            return z

        example = (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1))
        torch._C._jit_pass_shape_analysis(test.graph, example, False)
        self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())

    def test_resize():
        @torch.jit.script
        def test(x):
            after_resize_alias = torch.zeros([2])
            for _i in range(5):
                b = x + 1
                f = [1]
                before_resize_alias = b.sub_(1)
                f.append(1)
                b.resize_(f)
                after_resize_alias = b.add_(1)
            return after_resize_alias

        graph = test.graph
        self.run_pass('constant_propagation', graph)
        torch._C._jit_pass_shape_analysis(graph, (torch.zeros(1, 1),), False)

        # first input and output of b.resize_ is b
        resize_node = graph.findNode("aten::resize_")
        self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
        self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())

        # correctly propagates to b's alias set, both before and after the resize
        for kind in ("aten::sub_", "aten::add_"):
            node = graph.findNode(kind)
            self.assertTrue(next(node.outputs()).type() == TensorType.get())

    def test_resize_as():
        @torch.jit.script
        def test(x):
            b = torch.zeros([2, 2])
            b.resize_as_(x)
            return b

        graph = test.graph
        self.run_pass('constant_propagation', graph)
        torch._C._jit_pass_shape_analysis(graph, (torch.zeros(1, 1),), False)

        # x doesn't alias a resized op so it shouldn't be set to base Tensor type
        self.assertTrue(next(graph.inputs()).type() != TensorType.get())
        # return is resized
        self.assertTrue(next(graph.outputs()).type() == TensorType.get())

    for case in (out_op_graph_input, test_resize, test_resize_as):
        case()
4553 
def test_view_shape_prop(self):
    """view(size=[-1]) in a CompilationUnit flattens a 10x10 input to 100 elements."""
    source = '''
    def test_view_shape_prop(a):
        return a.view(size=[-1])
    '''
    cu = torch.jit.CompilationUnit(source)
    flattened = cu.test_view_shape_prop(torch.zeros(10, 10))
    self.assertEqual(flattened, torch.zeros(100))
4564 
def test_view_listconstruct_shape_prop(self):
    """Shape analysis through view() with a size list built from size() calls.

    The output type must not be the plain, shapeless 'TensorType' kind.
    """
    def fn(x):
        B = x.size(0)
        C = x.size(1)
        T = x.size(2)
        return x.view(T, B, C)

    x = torch.randn(3, 1, 5, requires_grad=True)
    graph = torch.jit.script(fn).graph
    torch._C._jit_pass_shape_analysis(graph, (x,), False)
    # The output kind was previously also stored in an unused local; query it once.
    out_kind = next(graph.outputs()).type().kind()
    self.assertTrue(out_kind != 'TensorType')
4577 
def test_integral_shape_inference(self):
    """Dividing a LongTensor of ones by itself in script yields a tensor equal to ones(10, 10)."""
    source = '''
    def test_integral_shape_inference(a):
        return a / a
    '''
    cu = torch.jit.CompilationUnit(source)
    long_ones = torch.ones(10, 10).type(torch.LongTensor)
    self.assertEqual(cu.test_integral_shape_inference(long_ones), torch.ones(10, 10))
4587 
def test_fuser_multiple_blocks(self):
    """Repeated cat inside a while loop grows three outputs from 0 to 20 rows."""
    source = '''
    def test_fuser_multiple_blocks(this, that, theother, meme):
        i = 0
        while i < 20:
            this = torch.cat([this, meme], dim=0)
            that = torch.cat([that, meme], dim=0)
            theother = torch.cat([theother, meme], dim=0)
            i = i + 1
        return this, that, theother
    '''
    cu = torch.jit.CompilationUnit(source)

    args = [torch.ones(0, 10, 10)] * 3 + [torch.ones(1, 10, 10)]
    expected = [torch.ones(20, 10, 10)] * 3

    self.assertEqual(cu.test_fuser_multiple_blocks(*args), expected)
4605 
def test_dropout_script(self):
    """A module whose forward calls a traced function using dropout exports to ONNX."""
    example = torch.zeros(1, 2, 3, requires_grad=True)

    @_trace(example)
    def foo(x):
        x = torch.neg(x)
        return F.dropout(x)

    class MyDrop(nn.Module):
        def forward(self, x):
            return foo(x)

    buffer = io.BytesIO()
    torch.onnx.export(MyDrop(), (example,), buffer, verbose=False)
4621 
@unittest.skip("RuntimeError: VariableType::ID() not implemented")
def test_cast(self):
    """int() cast of a float tensor inside script (currently skipped)."""
    script = '''
    def to_int(x):
        return int(x)
    '''
    float_input = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
    expected = Variable(torch.IntTensor([1, 2]), requires_grad=True)
    self.checkScript(script, [float_input], optimize=True, outputs=[expected], func='to_int')
4631 
def test_python_frontend(self):
    """The frontend's parsed definition of `fn` matches the recorded expect file.

    `fn` is never executed; it only needs to parse. Its body must stay
    unchanged, because the expect file records its tree.
    """
    def fn(x, y, z):
        q = None
        q = x + y - z.sigmoid()
        print(q)
        w = -z
        if not x and not y and z:
            m = x if not z else y
        while x < y > z:
            q = x
        assert 1 == 1, "hello"
        return x

    # This line was missing, leaving `ast` undefined below.
    ast = torch.jit.frontend.get_jit_def(fn)
    self.assertExpected(str(ast))
4647 
@unittest.skipIf(not PY2, "Requires python 2")
def test_python_frontend_py2(self):
    """Python 2 parse of a bare `raise` statement matches the recorded expect file."""
    def fn():
        raise Exception("hello")
    # This line was missing, leaving `ast` undefined below.
    ast = torch.jit.frontend.get_jit_def(fn)
    self.assertExpected(str(ast))
4654 
@unittest.skipIf(PY2, "Requires python 3")
def test_python_frontend_py3(self):
    """Python 3 parse of a bare `raise` statement matches the recorded expect file."""
    def fn():
        raise Exception("hello")
    # This line was missing, leaving `ast` undefined below.
    ast = torch.jit.frontend.get_jit_def(fn)
    self.assertExpected(str(ast))
4661 
4662  def _make_scalar_vars(self, arr, dtype):
4663  return [torch.tensor(val, dtype=dtype) for val in arr]
4664 
def test_string_print(self):
    """print() of implicitly concatenated string literals matches eager output."""
    def print_concat(a):
        print(a, "a" 'b' '''c''' """d""", 2, 1.5)
        return a

    args = self._make_scalar_vars([1], torch.int64)
    self.checkScript(print_concat, args, capture_output=True)
4672 
def test_while(self):
    """A while loop with two loop-carried values and a tensor-valued condition."""
    def count_up(a, b, max):
        while bool(a < max):
            a = a + 1
            b = b + 1
        total = a + b
        return total

    args = self._make_scalar_vars([1, 1, 10], torch.int64)
    self.checkScript(count_up, args, optimize=True)
4683 
def test_fibb(self):
    """Nested while loops mixing loop-carried, inner-only, and never-mutated variables."""
    def fib_like(lim):
        first = 1
        second = 1
        i = 1
        somenum = 5
        dontmutateme = 3
        third = 0
        while bool(i < lim):
            third = first + second
            first = second
            second = third
            j = 0
            while j < 10:
                somenum = somenum * 2
                j = j + 1
                i = i + j
            i = i + dontmutateme

        st = second + third
        fs = first + second
        return third, st, fs

    args = self._make_scalar_vars([10], torch.int64)
    self.checkScript(fib_like, args, optimize=True)
4709 
def test_if(self):
    """If/else where each branch assigns a different variable."""
    def branchy(a, b):
        # type: (int, int) -> int
        d = 3
        if bool(a > 10):
            a = 3 + d
        else:
            b = 3 + d
            d = 4
        total = a + b
        return total

    args = self._make_scalar_vars([1, -1], torch.int64)
    self.checkScript(branchy, args, optimize=True)
4724 
def test_if_for_in_range(self):
    """If/else inside a for-range loop, with branch-local updates carried across iterations."""
    def loop_branchy(a, b):
        # type: (int, int) -> int
        d = 3
        for _ in range(20):
            if bool(a > 10):
                a = 3 + d
            else:
                b = 3 + d
                d = 4
            c = a + b
        return d

    args = self._make_scalar_vars([1, -1], torch.int64)
    self.checkScript(loop_branchy, args, optimize=True)
4739 
def test_if_noelse(self):
    """An if without an else that conditionally reassigns an argument."""
    def maybe_reassign(a, b):
        if bool(a > 10):
            a = 3 + b
        total = a + b
        return total

    args = self._make_scalar_vars([-1, 1], torch.int64)
    self.checkScript(maybe_reassign, args, optimize=True)
4749 
def test_if_is_none_dispatch(self):
    """`is None` / `is not None` tests are resolved at compile time when statically known.

    When both sides are always None, or one side can never be None, only the
    taken branch is emitted (one int constant). When a side is Optional, the
    full if statement is emitted (three int constants).
    """
    def count_int_constants(fn):
        return str(fn.graph).count(': int = prim::Constant')

    @torch.jit.script
    def test_lhs_none_rhs_none():
        # LHS, RHS both alwaysNone, dispatch always_none_branch
        # only emit one prim::Constant
        if None is None:
            return 1
        elif None is not None:
            return 2
        else:
            return 3

    self.assertEqual(count_int_constants(test_lhs_none_rhs_none), 1)

    @torch.jit.script
    def test_lhs_opt_rhs_none(lhs=None):
        # type: (Optional[Tensor]) -> int
        # LHS maybeNone: emit normal if stmt that contains 3 constants
        if lhs is not None:
            return 2
        elif lhs is None:
            return 1
        else:
            return 3

    self.assertEqual(count_int_constants(test_lhs_opt_rhs_none), 3)

    @torch.jit.script
    def test_lhs_none_rhs_opt(rhs=None):
        # type: (Optional[Tensor]) -> int
        # RHS maybeNone, emit normal if stmt that contains 3 constants
        if None is rhs:
            return 1
        elif None is not rhs:
            return 2
        else:
            return 3

    # Previously this re-checked test_lhs_opt_rhs_none, leaving this case unverified.
    self.assertEqual(count_int_constants(test_lhs_none_rhs_opt), 3)

    @torch.jit.script
    def test_lhs_never_rhs_none(lhs):
        # LHS neverNone, RHS alwaysNone dispatch never_none_branch
        # only emit one prim::Constant
        if lhs is None:
            return 1
        elif lhs is not None:
            return 2
        else:
            return 3

    self.assertEqual(count_int_constants(test_lhs_never_rhs_none), 1)

    @torch.jit.script
    def test_lhs_none_rhs_never(rhs):
        # LHS alwaysNone, RHS neverNone dispatch never_none_branch
        # only emit one prim::Constant
        if None is rhs:
            return 1
        elif None is not rhs:
            return 2
        else:
            return 3

    self.assertEqual(count_int_constants(test_lhs_none_rhs_never), 1)
4816 
def test_explicit_bool_cast(self):
    """Using a tensor directly as an if condition is rejected at compile time."""
    expected_error = "expected a boolean"
    with self.assertRaisesRegex(RuntimeError, expected_error):
        @torch.jit.script
        def test_bool_cast(a):
            if a:
                return a + 2
            return a + 1
4824 
def test_while_nonexistent_value(self):
    """Reading an undefined name inside a while body is a compile-time error."""
    with self.assertRaisesRegex(RuntimeError, "undefined value x"):
        # The CompilationUnit call opening this source string was missing.
        torch.jit.CompilationUnit('''
        def test_while(a, b):
            while bool(a < 10):
                a = a + x
                b = b + 1
            return a + b
        ''')
4834 
def test_while_nonexistent_cond_value(self):
    """Reading an undefined name in a while condition is a compile-time error."""
    with self.assertRaisesRegex(RuntimeError, "undefined value x"):
        # The CompilationUnit call opening this source string was missing.
        torch.jit.CompilationUnit('''
        def test_while(a, b):
            while a < x:
                a = a + 1
                b = b + 1
            return a + b
        ''')
4844 
def test_optional_refinement(self):
    """`is None` checks refine Optional[int] to int in the guarded region.

    Refinement must flow through ternaries, `and`/`not` combinations and `or`
    in the else branch. It must NOT apply when the check is in the wrong
    direction or stored in a named variable first; those cases are expected
    to fail to compile.
    """
    @torch.jit.script
    def test_if_none_assignment(x):
        # type: (Optional[int]) -> int
        if x is None:
            x = 1
        return x + 1

    self.assertEqual(test_if_none_assignment(1), 2)

    @torch.jit.script
    def test_ternary(x):
        # type: (Optional[int]) -> int
        x = x if x is not None else 2
        return x

    @torch.jit.script
    def test_not_none(x):
        # type: (Optional[int]) -> None
        if x is not None:
            print(x + 1)

    @torch.jit.script
    def test_and(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if x is not None and y is not None:
            print(x + y)

    @torch.jit.script
    def test_not(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if not (x is not None and y is not None):
            pass
        else:
            print(x + y)

    @torch.jit.script
    def test_bool_expression(x):
        # type: (Optional[int]) -> None
        if x is not None and x < 2:
            print(x + 1)

    @torch.jit.script
    def test_nested_bool_expression(x, y):
        # type: (Optional[int], Optional[int]) -> int
        if x is not None and x < 2 and y is not None:
            x = x + y
        else:
            x = 5
        return x + 2

    @torch.jit.script
    def test_or(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if y is None or x is None:
            pass
        else:
            print(x + y)

    # backwards compatibility
    @torch.jit.script
    def test_manual_unwrap_opt(x):
        # type: (Optional[int]) -> int
        if x is None:
            x = 1
        else:
            # This body was missing, leaving the else branch empty (a syntax error).
            x = torch.jit._unwrap_optional(x)
        return x  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
        @torch.jit.script
        def or_error(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if x is None or y is None:
                print(x + y)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
        @torch.jit.script
        def and_error(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if x is None and y is None:
                pass
            else:
                print(x + y)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
        @torch.jit.script
        def named_var(x):
            # type: (Optional[int]) -> None
            x_none = x is not None
            if x_none:
                print(x + 1)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
        @torch.jit.script
        def named_var_and(x, y):
            # type: (Optional[int], Optional[int]) -> None
            x_none = x is not None
            if y is not None and x_none:
                print(x + y)  # noqa: T484
4945 
def test_while_write_outer_then_read(self):
    """A loop body writes an outer variable and the value is read after the loop."""
    def write_then_read(a, b):
        while bool(a < 10):
            a = a + 1
            b = a + 1
        return a + b

    args = self._make_scalar_vars([42, 1337], torch.int64)
    self.checkScript(write_then_read, args, optimize=True)
4955 
def test_while_nest_if(self):
    """An if/else nested inside a while loop."""
    def loop_with_branch(a, b):
        # type: (int, int) -> int
        c = 0
        while a < 10:
            a = a + 1
            b = b + 1
            if a > b:
                c = -a
            else:
                c = -b
        return c + 1

    args = self._make_scalar_vars([-1234, 4321], torch.int64)
    self.checkScript(loop_with_branch, args, optimize=True)
4971 
def test_math_ops(self):
    """math.floor is callable from script."""
    def floor_of_constant():
        return math.floor(1.5)

    self.checkScript(floor_of_constant, ())
4978 
def test_if_nest_while(self):
    """A while loop nested inside an if."""
    def branch_with_loop(a, b):
        # type: (int, int) -> int
        c = 0
        if a > b:
            while a > b:
                b = b + 1
                c = -b
        return c

    args = self._make_scalar_vars([4321, 1234], torch.int64)
    self.checkScript(branch_with_loop, args, optimize=True)
4991 
def test_script_for_in_range(self):
    """Summing range(100) with augmented assignment yields 4950."""
    def sum_range():
        total = 0
        for k in range(100):
            total += k
        return total

    expected_sum = 4950
    self.checkScript(sum_range, (), outputs=expected_sum, optimize=True)
4999 
def test_script_for_in_range_dynamic(self):
    """An inner range whose bound is the outer loop variable."""
    def triangular_sums():
        total = 0
        for outer in range(100):
            partial = 0
            for inner in range(outer):
                partial += inner
            total += partial
        return total

    self.checkScript(triangular_sums, (), optimize=False)
5010 
def test_script_for_in_range_ast(self):
    """A directly scripted nested range loop sums to 161700."""
    @torch.jit.script
    def test_script_for_in_range_ast():
        c = 0
        for i in range(100):
            acc = 0
            for j in range(i):
                acc += j
            c += acc
        return c

    expected = 161700
    self.assertEqual(test_script_for_in_range_ast(), expected)
5023 
def test_script_for_in_range_if_ast(self):
    """Concatenating inside a for-range loop with a first-iteration branch yields 20 rows."""
    @torch.jit.script
    def test_script_for_in_range_if_ast(x):
        output = x
        for i in range(20):
            if i == 0:
                output = x.unsqueeze(0)
            else:
                output = torch.cat((output, x.unsqueeze(0)), dim=0)
        return output

    args = self._make_scalar_vars([0], torch.int64)
    stacked = test_script_for_in_range_if_ast(*args)
    self.assertEqual(stacked.shape[0], 20)
5037 
def test_script_optional_none(self):
    """Exercise None assignments, Optional[...] arguments and None defaults in script."""
    def none_stmt(x):
        output = None
        output = x
        return output

    def none_args(x):
        # type: (Optional[Tensor]) -> Optional[Tensor]
        return None

    self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True)
    self.checkScript(none_args, [None], optimize=True)

    # test undefined tensor None as default param
    def test_script_optional_tensor_none(x=None):
        # type: (Optional[Tensor]) -> Tensor
        res = torch.zeros(1, dtype=torch.int8)
        if x is None:
            res = res + 1
        else:
            res = x
        return res

    # test typical None as default param
    def test_script_optional_other_none(x=None):
        # type: (Optional[float]) -> float
        res = 2.0
        if x is None:
            res = res + 1.0
        else:
            res = x
        return res

    # Each function is compared eager-vs-scripted twice: once relying on the
    # None default and once with an explicit argument.
    cases = (
        (test_script_optional_tensor_none, lambda: (torch.zeros(1),)),
        (test_script_optional_other_none, lambda: (1.0,)),
    )
    for eager_fn, make_args in cases:
        compiled_fn = torch.jit.script(eager_fn)
        self.assertEqual(eager_fn(), compiled_fn())
        self.assertEqual(eager_fn(*make_args()), compiled_fn(*make_args()))
5080 
def test_script_clamp_none(self):
    """Exercise torch.clamp with one bound omitted or explicitly None."""
    def test_script_clamp_max_none(x):
        return torch.clamp(x, min=2, max=None)

    def test_script_clamp_max(x):
        return torch.clamp(x, max=2)

    def test_script_clamp_min_none(x):
        return torch.clamp(x, min=None, max=2)

    def test_script_clamp_min(x):
        return torch.clamp(x, min=2)

    sample_args = [torch.arange(0, 3)]
    clamp_fns = (
        test_script_clamp_max_none,
        test_script_clamp_max,
        test_script_clamp_min_none,
        test_script_clamp_min,
    )
    for clamp_fn in clamp_fns:
        self.checkScript(clamp_fn, sample_args, optimize=True)
5099 
def test_script_bool_constant(self):
    """Compile a string-defined function that returns the literal True."""
    script = '''
    def test_script_bool_constant():
        a = True
        return a
    '''
    # The expected output is given as 1; the returned True equals 1 in Python.
    outputs = [1]
    self.checkScript(script, [], outputs[0], True, 'test_script_bool_constant')
5108 
def test_ternary(self):
    """Exercise a conditional expression on both its true and false branches."""
    def func(a, b):
        c = 3
        c = a + b if bool(a > 3) else b
        return c

    # [5, 2] takes the true branch (5 > 3); [1, 0] takes the false branch.
    for values in ([5, 2], [1, 0]):
        scalar_args = self._make_scalar_vars(values, torch.int64)
        self.checkScript(func, scalar_args, optimize=True)
5119 
def test_print(self):
    """Exercise print() with tensor, int and list arguments inside script."""
    def func(x, y):
        q = (x + y).sigmoid()
        print(q, 1, 2, [1, 2], [1.0, 2.0])
        w = -q
        return w * w

    x = torch.arange(4., requires_grad=True)
    y = torch.arange(0., 8, 2, requires_grad=True)
    # capture_output=True asks checkScript (defined outside this chunk) to
    # capture the printed text as part of the check.
    self.checkScript(func, [x, y], optimize=True, capture_output=True)
5130 
def test_format(self):
    """Exercise str.format with zero, one and two placeholders inside script."""
    def func(x):
        print("{}, I'm a {}".format("Hello", "test"))
        print("format blank".format())
        print("stuff before {}".format("hi"))
        print("{} stuff after".format("hi"))
        return x + 1

    x = torch.arange(4., requires_grad=True)
    # capture_output=True so the printed text is captured by checkScript.
    self.checkScript(func, [x], optimize=True, capture_output=True)
5141 
def test_logical_short_circuit(self):
    """Check that `and`/`or` short-circuit in script.

    Operands guarded by a constant False/True must not be evaluated, and the
    resulting graph contains the expected prim::If structure.
    """
    @torch.jit.script
    def testNoThrows(t):
        c1 = 1
        if (False and bool(t[1])) or (True or bool(t[1])):
            c1 = 0
        return c1

    # t is empty, so evaluating t[1] would raise; returning 0 shows neither
    # guarded operand was evaluated.
    self.assertEqual(0, testNoThrows(torch.randn(0)))
    ifs = testNoThrows.graph.findAllNodes("prim::If", recurse=False)

    # three ifs at the top level, and the second one has a nested if for
    # the or (True or bool(t[1])) expression
    self.assertTrue(len(ifs) == 3)
    self.assertTrue(ifs[0].findNode("prim::If") is None)
    self.assertTrue(ifs[1].findNode("prim::If").findNode("prim::If") is None)
    self.assertTrue(ifs[2].findNode("prim::If") is None)

    @torch.jit.script
    def throwsOr(t):
        c0 = False or bool(t[1])
        print(c0)

    @torch.jit.script
    def throwsAnd(t):
        c0 = True and bool(t[1])
        print(c0)

    # Here the right-hand operand must be evaluated, so the out-of-range
    # index on the empty tensor surfaces as an error.
    t = torch.randn(0)
    with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
        throwsOr(t)
    with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
        throwsAnd(t)
5175 
def test_type_cast(self):
    """Check int/float/bool conversions against the matching Python builtins."""
    template = dedent('''
    def cast(v):
        # type: ({from_type}) -> {to_type}
        return {to_type}(v)
    ''')

    def check_cast(from_type, to_type, value, raises=False):
        source = template.format(from_type=from_type, to_type=to_type)
        # The Python builtin of the same name supplies the reference result.
        python_cast = getattr(builtins, to_type)
        expected = python_cast(value)
        if not raises:
            self.checkScript(source, (value,), name='cast', outputs=expected)
            return
        with self.assertRaisesRegex(RuntimeError, "Cannot cast"):
            torch.jit.CompilationUnit(source)

    cast_cases = [
        ('int', 'float', 1),
        ('int', 'bool', 1),
        ('int', 'bool', 0),
        ('float', 'int', 1.),
        ('float', 'bool', 1.),
        ('float', 'bool', 0.),
        ('bool', 'int', True),
        ('bool', 'float', True),
    ]
    for from_type, to_type, value in cast_cases:
        check_cast(from_type, to_type, value)
5202 
def test_multiple_assignment(self):
    """Exercise tuple unpacking of a value returned by a plain Python function."""
    # outer_func is an undecorated Python function called from script.
    def outer_func(x):
        return x * 2, x + 2

    @torch.jit.script
    def func(x):
        y, z = outer_func(x)
        return y + z

    x = torch.arange(4)
    self.assertEqual(func(x), x * 2 + x + 2)
5214 
def test_literals(self):
    """Exercise a list literal passed as a keyword argument (view(size=[...]))."""
    def func(a):
        return a.view(size=[1, 2, 3])

    a = torch.randn(6)
    self.checkScript(func, [a], optimize=True)
5221 
def test_return(self):
    """Check functions with no return, a bare return, one value and a tuple.

    Also check that a function annotated to return a Tensor but whose body
    returns nothing is rejected at compile time.
    """
    def no_return(a):
        a + 1

    def void_return(a):
        return

    def one_return(a):
        return a + 1.

    def multiple_returns(a):
        return a * 1., a * 2., a * 3.

    a = torch.randn(1, dtype=torch.float)
    self.checkScript(no_return, [a], optimize=True)
    self.checkScript(void_return, [a], optimize=True)
    self.checkScript(one_return, [a], optimize=True)
    self.checkScript(multiple_returns, [a], optimize=True)

    # The body falls off the end (yielding None), which contradicts the
    # declared Tensor return type, so compiling the source must fail.
    with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
        torch.jit.CompilationUnit('''
        def no_return_bad_annotation(a):
            # type: (Tensor) -> Tensor
            a + 1
        ''')
5247 
def test_error(self):
    """Check that runtime shape errors are reported as interpreter failures."""
    @torch.jit.script
    def foo(a):
        return a.t()
    # t() is applied to a 3-D tensor.
    s = Variable(torch.rand(5, 5, 5))
    # XXX: this error should not be raised during compilation/propagation; it
    # is expected to surface only when the interpreter runs, hence the
    # "failed in interpreter" match.
    with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
        foo(s)

    @torch.jit.script
    def bar(c, b):
        return c + b

    # Adding tensors of sizes 10 and 9 is expected to fail at run time.
    with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
        bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
5263 
def test_binop_unsupported_error(self):
    """Check that compiling an unsupported binary operator raises NotSupportedError."""
    with self.assertRaisesRegex(NotSupportedError, "unsupported binary operator:"):
        @torch.jit.script
        def binop(x, y):
            # Replace this with another unsupported op when/if it gets supported
            return x << y
5270 
def test_bitwise_ops(self):
    """Exercise &, ^ and | on int constants, bool arguments and tensors."""

    def int_test():
        return 2 & 3, 2 ^ 3, 2 | 3

    def bool_test(x, y):
        # type: (bool, bool) -> Tuple[bool, bool, bool]
        return x & y, x ^ y, x | y

    def tensor_test(x, y):
        return x & y, x ^ y, x | y

    self.checkScript(int_test, ())

    # Left operand fixed at True; right operand covers both values.
    for rhs in (False, True):
        self.checkScript(bool_test, (True, rhs))

    tensor_args = (torch.tensor(2), torch.tensor(3))
    self.checkScript(tensor_test, tensor_args)
5292 
def test_number_math(self):
    """Compare scalar binary operators and min/max between TorchScript and Python."""
    ops_template = dedent('''
    def func():
        return {scalar1} {op} {scalar2}
    ''')
    ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//']
    funcs_template = dedent('''
    def func():
        return {func}({scalar1}, {scalar2})
    ''')
    funcs = ['min', 'max']
    scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0']

    def run_test(code):
        # Python's own execution of the same source is the reference result.
        python_scope = {}
        execWrapper(code, globals(), python_scope)
        compiled = torch.jit.CompilationUnit(code)

        self.assertEqual(compiled.func(), python_scope['func']())

    # Every ordered pair of scalars: all operators first, then min/max.
    for scalar1, scalar2 in product(scalars, repeat=2):
        sources = [ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2)
                   for op in ops]
        sources += [funcs_template.format(func=fn_name, scalar1=scalar1, scalar2=scalar2)
                    for fn_name in funcs]
        for code in sources:
            run_test(code)
5321 
def test_number_div(self):
    """Check `/` on numbers for helpers defined with and without future division."""
    for future_fn in (div_int_future, div_float_future):
        self.checkScript(future_fn, (), optimize=True)

    # On Python 2, scripting the helpers without `from __future__ import
    # division` is expected to fail with an error naming that import; on
    # Python 3 they are checked like the others.
    for nofuture_fn in (div_int_nofuture, div_float_nofuture):
        if not PY2:
            self.checkScript(nofuture_fn, (), optimize=True)
            continue
        with self.assertRaisesRegex(RuntimeError, 'from __future__ import division'):
            torch.jit.script(nofuture_fn)
5334 
def test_floor_div(self):
    """Compare scripted int `//` with Python for small operands, including /0."""
    @torch.jit.script
    def foo(a, b):
        # type: (int, int) -> int
        return a // b

    for numerator, denominator in product(range(-8, 8), repeat=2):
        if denominator == 0:
            with self.assertRaisesRegex(RuntimeError, 'division by 0'):
                foo(numerator, denominator)
            continue
        self.assertEqual(foo(numerator, denominator), numerator // denominator)
5347 
def test_number_augassign(self):
    """Exercise augmented assignment (+=) on a plain int local."""
    def func():
        z = 1
        z += 2
        return z

    self.checkScript(func, (), optimize=True)
5355 
def test_number_neg(self):
    """Exercise negative int and float literals."""
    # int -> int
    def func1():
        return -8

    # float -> float
    def func2():
        return -3.14

    self.checkScript(func1, (), optimize=True)
    self.checkScript(func2, (), optimize=True)
5367 
def _test_tensor_number_math(self, device='cpu'):
    """Compare `tensor <op> number` and `number <op> tensor` between
    TorchScript and Python.

    Covers float, double and long tensors combined with int and float
    constants.

    Args:
        device: device on which the test tensors are allocated.
    """
    template = dedent('''
    def func(t):
        return {lhs} {op} {rhs}
    ''')

    def test(op, tensor, const, swap_args):
        # `tensor` is passed explicitly rather than read from the enclosing
        # loop variable, so the helper does not depend on when it is called.
        args = (const, 't') if swap_args else ('t', const)
        code = template.format(lhs=args[0], rhs=args[1], op=op)
        scope = {}
        execWrapper(code, globals(), scope)
        cu = torch.jit.CompilationUnit(code)
        self.assertEqual(cu.func(tensor), scope['func'](tensor))

    var_int = [2, -2]
    var_float = [1.4321, -1.2]

    ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/']

    float_tensor = torch.randn(5, 5, device=device)
    double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
    long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device)
    # Replace zeros in the integer tensor (presumably to keep it usable as a
    # divisor).
    long_tensor[long_tensor == 0] = 2

    tensors = [float_tensor, double_tensor, long_tensor]
    consts = var_int + var_float

    for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]):
        # FIXME: things like 2 / long_tensor are not implemented correctly
        # Look in torch/tensor.py to see how pytorch implements it.
        if op == '/' and tensor is long_tensor:
            continue

        # % operator does not take: const % tensor
        if op == '%' and swap_args:
            continue

        test(op, tensor, const, swap_args)
5409 
def test_tensor_number_math(self):
    """Run the shared tensor/number operator comparison on CPU tensors."""
    self._test_tensor_number_math(device='cpu')
5412 
def test_torch_tensor_bad_input(self):
    """Check error messages for invalid list inputs to torch.tensor in script."""
    # A list of None is rejected while compiling the decorated function.
    with self.assertRaisesRegex(RuntimeError, "Input list to torch.tensor must be of ints, floats, "
                                "or bools, got None"):
        @torch.jit.script
        def test():
            return torch.tensor([None])

    # An unannotated empty list is rejected while compiling as well.
    with self.assertRaisesRegex(RuntimeError, "Note: empty lists are constructed as Tensor"):
        @torch.jit.script
        def tmp():
            return torch.tensor([])

    # A ragged nested list compiles; the error is raised when it is called.
    @torch.jit.script
    def foo():
        return torch.tensor([[2, 2], [1]])
    with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
        foo()
5430 
@suppress_warnings
def test_torch_tensor_empty_list(self):
    """Exercise torch.tensor on annotated empty lists, flat and nested."""
    def func():
        return torch.tensor(torch.jit.annotate(List[int], []))
    cu = torch.jit.script(func)
    t1 = cu()
    t2 = func()

    # torchscript returns int tensor, python returns float tensor
    self.assertNotEqual(t1.dtype, t2.dtype)

    # `func` is redefined below; each version is checked before the next
    # definition replaces it.
    def func():
        li = torch.jit.annotate(List[int], [])
        return torch.tensor([li, li])

    self.checkScript(func, ())

    def func():
        li = torch.jit.annotate(List[int], [])
        return torch.tensor([[[li]]])

    self.checkScript(func, ())
5453 
def test_torch_tensor(self):
    """Compare torch.tensor(...) on literals between TorchScript and Python.

    Every list literal is combined with every dtype/device option and checked
    for equal values, dtype and device.
    """
    template = dedent('''
    def func():
        li = {list_create}
        return torch.tensor(li {options})
    ''')

    lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
             "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]

    dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
              ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
              ", dtype=torch.int", ", dtype=torch.long"]

    devices = ['', ", device='cpu'"]
    if RUN_CUDA:
        devices.append(", device='cuda'")

    option_pairs = [dtype + device for dtype in dtypes for device in devices]
    for li in lists:
        for option in option_pairs:
            # tensor from empty list is type float in python and annotated type in torchscript
            if "annotate" in li and "dtype" not in option:
                continue
            code = template.format(list_create=li, options=option)
            scope = {}
            # Same exec helper as the other exec-based comparisons in this file.
            execWrapper(code, globals(), scope)
            cu = torch.jit.CompilationUnit(code)
            t1 = cu.func()
            t2 = scope['func']()
            if t1.dtype == torch.float16:  # equality NYI for half tensor
                # assertEqual reports both strings on mismatch.
                self.assertEqual(str(t1), str(t2))
            else:
                self.assertEqual(t1, t2)
            self.assertEqual(t1.dtype, t2.dtype)
            self.assertEqual(t1.device, t2.device)
5490 
# adapted from test in test_torch
def test_tensor_to(self):
    """Check Tensor.to overloads under TorchScript.

    Covers aliasing vs. copying, device and dtype conversion, and gradients
    through the supported overloads.
    """
    template = dedent('''
    def func(t):
        cuda = "{cuda}"
        device = "{device}"
        non_blocking = {non_blocking}
        return {to_str}
    ''')

    def s(t, to_str, non_blocking=None, device=None, cuda=None):
        # Compile `to_str` as the return expression of a one-argument script
        # function (with cuda/device/non_blocking baked in as locals) and
        # apply it to t.
        device = device if device is not None else str(t.device)
        non_blocking = non_blocking if non_blocking is not None else False
        cuda = "cuda" if cuda is None else cuda
        code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
        cu = torch.jit.CompilationUnit(code)
        return cu.func(t)

    def test_copy_behavior(t, non_blocking=False):
        # Without copy=True a no-op conversion must return t itself.
        self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
        self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
        self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
        self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
        self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
        self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))

        devices = [t.device]
        if t.device.type == 'cuda':
            if t.device.index == -1:
                devices.append('cuda:{}'.format(torch.cuda.current_device()))
            elif t.device.index == torch.cuda.current_device():
                devices.append('cuda')
        for device in devices:
            self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
            self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
            self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
            self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
                                  non_blocking, device))

    t = torch.tensor(5)
    test_copy_behavior(t)

    self.assertEqual(t.device, s(t, "t.to('cpu')").device)
    self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
    self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
    self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
    self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
    self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
    self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
    self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
    self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())

    a = torch.tensor(5)
    # The CUDA checks need a GPU; skip them on CPU-only machines.
    if torch.cuda.is_available():
        for non_blocking in [True, False]:
            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
                b = torch.tensor(5., device=cuda)
                test_copy_behavior(b, non_blocking)
                self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
                self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device"))
                self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
                self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
                self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
                self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
                self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)

    # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
    t = torch.tensor(5).float().requires_grad_()
    out_ref = t.to(torch.float32)
    out = s(t, "t.to(torch.float32)")
    self.assertEqual(out_ref, out)

    grad_ref = torch.autograd.grad(out_ref.sum(), t)
    grad = torch.autograd.grad(out.sum(), t)
    self.assertEqual(grad_ref, grad)

    # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
    out_ref = t.to('cpu')
    out = s(t, "t.to('cpu')")
    self.assertEqual(out_ref, out)

    grad_ref = torch.autograd.grad(out_ref.sum(), t)
    grad = torch.autograd.grad(out.sum(), t)
    self.assertEqual(grad_ref, grad)

    # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
    @torch.jit.script
    def func2(t, t_ref):
        return t.to(t_ref)

    func2.debug_disable_autodiff_subgraph_inlining()

    t_ref = torch.tensor(4).double()
    out_ref = t.to(t_ref)
    out = func2(t, t_ref)
    grad_ref = torch.autograd.grad(out_ref.sum(), t)
    grad = torch.autograd.grad(out.sum(), t)
    self.assertEqual(grad_ref, grad)
5590 
@unittest.skipIf(not RUN_CUDA, "No CUDA")
def test_tensor_number_math_cuda(self):
    """Run the shared tensor/number operator comparison on CUDA tensors."""
    target_device = 'cuda'
    self._test_tensor_number_math(device=target_device)
5594 
def test_not(self):
    """Exercise the `not` operator applied to a bool derived from a tensor comparison."""
    # test not operator in python
    # TODO: add more tests when bool conversions ready
    def test_not_op(a):
        return not bool(a > 1)

    self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True)
5602 
def test_is_isnot(self):
    """Compare `is` / `is not` on literal operands between TorchScript and Python."""
    template = dedent('''
    def func():
        # type: () -> bool
        return {lhs} {op} {rhs}
    ''')

    def test(op, args):
        lhs, rhs = args[0], args[1]
        code = template.format(lhs=lhs, rhs=rhs, op=op)
        python_scope = {}
        execWrapper(code, globals(), python_scope)
        compiled = torch.jit.CompilationUnit(code)
        failure_msg = "Failed with op: {}, lhs: {}, rhs: {}".format(op, lhs, rhs)
        self.assertEqual(compiled.func(), python_scope['func'](), failure_msg)

    ops = ['is', 'is not']
    type_literals = [True, False, None, [1, 1]]

    # Every operator against every (lhs, rhs) combination of the literals.
    for op, lhs, rhs in product(ops, type_literals, type_literals):
        test(op, [lhs, rhs])
5629 
def test_isinstance(self):
    """Compare static isinstance checks between TorchScript and Python."""
    template = dedent('''
    def func(x):
        # type: ({type_hint}) -> bool
        return isinstance(x, {typ})
    ''')

    def test(inp, typ, type_hint):
        code = template.format(typ=typ, type_hint=type_hint)
        python_scope = {}
        execWrapper(code, globals(), python_scope)
        compiled = torch.jit.CompilationUnit(code)
        self.assertEqual(
            compiled.func(inp),
            python_scope['func'](inp),
            "Failed with typ: {}".format(typ)
        )

    # (input value, isinstance target, argument type annotation)
    cases = [
        (True, 'bool', 'bool'),
        (1, 'int', 'int'),
        (1.0, 'float', 'float'),
        (torch.tensor(1), 'torch.Tensor', 'Tensor'),
        ([1, 2], 'list', 'List[int]'),
        ((1.0,), 'tuple', 'Tuple[float]'),
        ([1, 2], '(list, tuple)', 'List[int]'),
        (1, '(int, float, bool)', 'int'),
    ]
    for inp, typ, type_hint in cases:
        test(inp, typ, type_hint)

    # test optional isintance check
    with self.assertRaisesRegex(RuntimeError, "Optional isinstance check is not supported"):
        @torch.jit.script
        def opt_func(x):
            # type: (Optional[int]) -> bool
            return isinstance(x, int)
5666 
def test_python_call(self):
    """Exercise calls from compiled source to a local Python function and to a
    function defined in the same CompilationUnit."""
    def pyfunc(a):
        return a * 3.0

    cu = torch.jit.CompilationUnit('''
    def other_func(a):
        return a + a

    def test_call_python(a):
        b = pyfunc(a)
        b = other_func(b)
        i = 0
        step = 1
        while i < 10:
            b = pyfunc(b)
            if bool(b > 3.0):
                b = pyfunc(b)
            i = 11
        return b
    ''')
    # 1 -> pyfunc 3 -> other_func 6 -> pyfunc 18 -> (18 > 3) pyfunc 54
    inputs = self._make_scalar_vars([1], torch.float)
    outputs = self._make_scalar_vars([54], torch.float)

    self.assertEqual(cu.test_call_python(*inputs), outputs[0])
5691 
def test_python_call_failure(self):
    """Check that compiled source referencing an undefined Python name
    (pyfunc2) fails to compile with an "undefined value" error."""
    def pyfunc(a):
        return a * 3.0

    # Compilation is expected to raise, so no result checks follow the
    # CompilationUnit call; any such statements would be unreachable.
    with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
        torch.jit.CompilationUnit('''
        def other_func(a):
            return a + a

        def test_call_python(a):
            b = pyfunc(a)
            b = other_func(b)
            i = 0
            step = 1
            while i < 10:
                b = pyfunc2(b)
                if b > 3.0:
                    b = pyfunc(b)
                i = 11
            return b
        ''')
5717 
def test_python_call_annotation(self):
    """Exercise calling a local Python function from an @torch.jit.script function."""
    def pyfunc(a):
        return a * 3.0

    @torch.jit.script
    def foo(a):
        return pyfunc(a) + pyfunc(a)

    # 1 * 3.0 + 1 * 3.0 == 6
    inputs = self._make_scalar_vars([1], torch.float)
    outputs = self._make_scalar_vars([6], torch.float)
    self.assertEqual(foo(*inputs), outputs[0])
5729 
def test_python_call_annoytation_failure(self):
    """Check that @torch.jit.script rejects a function that calls an undefined
    Python name (pyfunc2)."""
    def pyfunc(a):
        return a * 3.0

    # Scripting raises at decoration time, so no result checks follow it;
    # any such statements would be unreachable.
    with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
        @torch.jit.script
        def foo(a):
            return pyfunc2(a) + pyfunc(a)
5743 
def test_desugar_module(self):
    """Exercise module-qualified calls: torch.abs, torch.nn.functional.prelu and
    the same function through the F alias."""
    import torch.nn.functional as F

    def fn(x, slope):
        a = torch.abs(x)
        b = torch.nn.functional.prelu(x, slope)
        c = F.prelu(x, slope)
        return a, b, c

    x = torch.arange(-3., 4)
    slope = torch.tensor([0.5])
    self.checkScript(fn, [x, slope], optimize=True)
5756 
def test_script_docstring(self):
    """Check that a scripted function keeps only its leading docstring as __doc__."""
    @torch.jit.script
    def with_docstring(x):
        """test str"""
        y = x
        """y is the same as x"""
        return y
    # The second string literal is a plain expression statement, not the docstring.
    self.assertEqual(with_docstring.__doc__, 'test str')
5765 
def test_script_method_docstring(self):
    """Check that a script method keeps only its leading docstring as __doc__."""
    class A(torch.jit.ScriptModule):
        @torch.jit.script_method
        def with_docstring(self, x):
            """test str"""
            y = x
            """y is the same as x"""
            return y
    a = A()
    self.assertEqual(a.with_docstring.__doc__, 'test str')
5776 
@unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
                 'Quantized RNN requires FBGEMM. FBGEMM does not play'
                 ' well with UBSAN at the moment, so we skip the test if'
                 ' we are in a UBSAN environment.')
def test_rnn_cell_quantized(self):
    """Compare quantized LSTM/GRU/RNN cells against float reference copies.

    Each quantized cell is run after wrapping in a ScriptModule and
    round-tripping through export/import.
    """
    d_in, d_hid = 2, 2

    for cell in [
        torch.nn.LSTMCell(d_in, d_hid).float(),
        torch.nn.GRUCell(d_in, d_hid).float(),
        torch.nn.RNNCell(d_in, d_hid).float(),
    ]:
        if isinstance(cell, torch.nn.LSTMCell):
            num_chunks = 4
        elif isinstance(cell, torch.nn.GRUCell):
            num_chunks = 3
        elif isinstance(cell, torch.nn.RNNCell):
            num_chunks = 1

        # Replace parameter values s.t. the range of values is exactly
        # 255, thus we will have 0 quantization error in the quantized
        # GEMM call. This is for testing purposes.
        #
        # Note that the current implementation does not support
        # accumulation values outside of the range representable by a
        # 16 bit integer, instead resulting in a saturated value. We
        # must take care that in our test we do not end up with a dot
        # product that overflows the int16 range, e.g.
        # (255*127+255*127) = 64770. So, we hardcode the test values
        # here and ensure a mix of signedness.
        vals = [[100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155]]
        vals = vals[:d_hid * num_chunks]
        cell.weight_ih = torch.nn.Parameter(
            torch.tensor(vals, dtype=torch.float),
            requires_grad=False)
        cell.weight_hh = torch.nn.Parameter(
            torch.tensor(vals, dtype=torch.float),
            requires_grad=False)

        # Float reference copy, taken before quantization.
        ref = copy.deepcopy(cell)

        # Swap the float cell for its quantized counterpart; without this the
        # test would compare two float cells and the isinstance checks against
        # QuantizedLSTMCell below could never succeed.
        cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
        x = torch.tensor([[100, -155],
                          [-155, 100],
                          [100, -155]], dtype=torch.float)
        h0_vals = [[-155, 100],
                   [-155, 155],
                   [100, -155]]
        hx = torch.tensor(h0_vals, dtype=torch.float)
        if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
            cx = torch.tensor(h0_vals, dtype=torch.float)
            hiddens = (hx, cx)
        else:
            hiddens = hx

        # The wrapper's forward signature depends on whether the hidden state
        # is an (h, c) tuple (LSTM) or a single tensor.
        if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
            class ScriptWrapper(torch.jit.ScriptModule):
                def __init__(self, cell):
                    super(ScriptWrapper, self).__init__()
                    self.cell = cell

                @torch.jit.script_method
                def forward(self, x, hiddens):
                    # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
                    return self.cell(x, hiddens)
        else:

            class ScriptWrapper(torch.jit.ScriptModule):
                def __init__(self, cell):
                    super(ScriptWrapper, self).__init__()
                    self.cell = cell

                @torch.jit.script_method
                def forward(self, x, hiddens):
                    # type: (torch.Tensor, torch.Tensor) -> torch.Tensor
                    return self.cell(x, hiddens)

        cell = ScriptWrapper(cell)
        # Run once before export as a smoke test of the scripted module.
        outs = cell(x, hiddens)
        cell = self.getExportImportCopyWithPacking(cell)

        outs = cell(x, hiddens)
        ref_outs = ref(x, hiddens)

        self.assertEqual(len(outs), len(ref_outs))
        for out, ref_out in zip(outs, ref_outs):
            torch.testing.assert_allclose(out, ref_out)
5871 
def test_script_module(self):
    """Exercise a ScriptModule that combines several features.

    It mixes parameters, a method defined from a string, script methods, a
    script submodule and a plain nn.Module submodule.
    """
    class M1(torch.jit.ScriptModule):
        def __init__(self):
            super(M1, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class PModule(nn.Module):
        def __init__(self):
            super(PModule, self).__init__()
            self.a = nn.Parameter(torch.randn(2, 3))

        def forward(self, a):
            return self.a.mm(a)

    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(False)
            # test submodule
            self.sub = M1()
            self.sub2 = PModule()
            # test parameters
            self.weight = nn.Parameter(torch.randn(2, 3))
            self.bias = nn.Parameter(torch.randn(2))
            # test defining a method from a string
            self.define("""
                def hi(self, a):
                    return self.weight.mm(a)
            """)
            # test script methods

        @torch.jit.script_method
        def doit(self, input):
            # test use of parameter
            return self.weight.mm(input)

        @torch.jit.script_method
        def doit2(self, input):
            return self.weight.mm(input)

        @torch.jit.script_method
        def forward(self, input):
            a = self.doit(input)
            b = self.doit2(input)
            c = self.hi(input)
            d = self.sub2(input)
            return a + b + self.bias + self.sub(a) + c + d
    m2 = M2()
    input = torch.randn(3, 2)
    # Eager reference: doit, doit2 and hi all compute weight.mm(input), and
    # sub(a) == sub.weight + a.
    a = m2.weight.mm(input)
    b = m2.weight.mm(input)
    c = m2.weight.mm(input)
    d = m2.sub2.a.mm(input)
    ref = a + b + m2.bias + m2.sub.weight + a + c + d
    self.assertEqual(ref, m2.forward(input))
    # Reassigning parameters (and zeroing PModule's in place) must be picked
    # up by forward: with every parameter zero the result is all zeros.
    m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
    m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
    m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
    m2.sub2.a.data.zero_()
    self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
5935 
def test_filecheck(self):
    """Exercise each FileCheck directive on small strings.

    Covers check, check_count, check_same, check_next, check_dag and
    check_not, for both successful matches and the expected error messages.
    """
    def test_check():
        file = "232"
        FileCheck().check("2").check("3").check("2").run(file)
        FileCheck().check("232").run(file)

        with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
            FileCheck().check("22").run(file)
        with self.assertRaisesRegex(RuntimeError, "CHECK: 3"):
            FileCheck().check("3").check("3").run(file)

    test_check()

    def test_check_count():
        file = "22222"
        FileCheck().check_count("2", 5).run(file)
        FileCheck().check_count("22", 2).run(file)
        FileCheck().check_count("222", 1).run(file)

        with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
            FileCheck().check_count("2", 4, exactly=True).run(file)

        with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
            FileCheck().check_count("22", 3).run(file)

        with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"):
            FileCheck().check_count("2", 6).run(file)

    test_check_count()

    def test_check_same():
        file = "22\n33"
        FileCheck().check_same("22").run(file)

        with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
            FileCheck().check_same("33").run(file)

        file = "22 1 3"

        FileCheck().check("2").check_same("3").run(file)
        FileCheck().check_count("2", 2).check_same("3").run(file)

    test_check_same()

    def test_check_next():
        file = "\n1\n2\n3"
        FileCheck().check("1").check_next("2").check_next("3").run(file)
        FileCheck().check_next("1").check_next("2").check_next("3").run(file)

        with self.assertRaisesRegex(RuntimeError, "Expected to find"):
            FileCheck().check("1").check_next("2").run("12")

        with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
            FileCheck().check("1").check_next("2").run("1\n\n2")

    test_check_next()

    def test_check_dag():
        fc = FileCheck().check_dag("1").check_dag("2").check_not("2")
        fc.run("12")
        fc.run("21")

        fc = FileCheck()
        fc.check_not("3").check_dag("1").check_dag("2").check_not("3")
        fc.run("1 3 2")
        fc.run("2 3 1")

        fc = FileCheck().check_dag("1").check_dag("2").check("3")
        with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'):
            fc.run("1 3 2")

    test_check_dag()

    def test_check_not():
        FileCheck().check_not("2").check("1").run("12")
        FileCheck().check("2").check_not("2").run("12")

        with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
            FileCheck().check_not("2").check("1").run("21")

        with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
            FileCheck().check("2").check_not("1").run("21")

        # checks with distinct range matchings
        fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2")
        with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
            fb.run("22 2 22")

        fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2)
        with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
            fb.run("22 1 22")

    # Like the other nested checks, test_check_not must be invoked or its
    # assertions never run.
    test_check_not()
6027 
def test_script_module_call_noscript(self):
    """A script method calls a plain Python method of the same module.

    Python attributes it reads (``self.value``) can be changed after
    construction, and later calls observe the new value.
    """
    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            self.value = 1

        # Not decorated with @script_method: a regular Python method.
        def foo(self):
            return torch.ones(2, 2) + self.value

        @torch.jit.script_method
        def forward(self, input):
            return input + self.foo()

    m = M()
    input = torch.randn(2, 2)
    o = m(input)
    self.assertEqual(o, input + torch.ones(2, 2) + 1)
    # check that we can change python attributes
    # and that those changes are picked up in script methods
    m.value = 2
    o = m(input)
    self.assertEqual(o, input + torch.ones(2, 2) + 2)
6050 
def test_script_module_nochange_submodule(self):
    """A submodule used by a script method cannot be re-assigned later."""
    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            self.sub = nn.Linear(5, 5)

        @torch.jit.script_method
        def forward(self, input):
            return self.sub(input)

    m = M()
    input = torch.randn(1, 5, 5)
    o = m(input)
    self.assertEqual(o, m.sub(input))
    # Replacing the submodule after construction must be rejected.
    with self.assertRaisesRegex(RuntimeError, "cannot re-assign"):
        m.sub = nn.Linear(5, 5)
6067 
def test_script_inline_trace_multiple_args(self):
    """A script method calls a traced submodule that takes two arguments."""
    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)

        # Plain forward; this module is only used through torch.jit.trace.
        def forward(self, input, input2):
            return input + input2

    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(False)
            self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3)))

        @torch.jit.script_method
        def forward(self, inp):
            return self.m(inp, inp)

    m2 = M2()
    # Only checks that the call succeeds; the result is not compared.
    m2(torch.zeros(4, 3))
6087 
def test_script_module_const(self):
    """bool/int/float attributes named in __constants__ are usable in script."""
    class M(torch.jit.ScriptModule):

        __constants__ = ['b', 'i', 'c']

        def __init__(self):
            super(M, self).__init__(False)
            self.b = False
            self.i = 1
            self.c = 3.5

        @torch.jit.script_method
        def forward(self):
            return self.b, self.i, self.c

    m = M()
    o0, o1, o2 = m()
    # self.b is False; it is compared against 0 here.
    self.assertEqual(o0, 0)
    self.assertEqual(o1, 1)
    self.assertEqual(o2, 3.5)
6108 
def test_script_module_fail_const(self):
    """A plain attribute missing from __constants__ cannot be read in script.

    The error is raised when the module is constructed.
    """
    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            self.b = False

        @torch.jit.script_method
        def forward(self):
            return self.b
    with self.assertRaisesRegex(RuntimeError, "is not usable in a script method"):
        M()
6120 
def test_script_module_valid_consts(self):
    """Check which values may be assigned to attributes named in __constants__.

    Accepted: int, float and bool; a lambda (assigned without error here); a
    list, which is converted to a tuple; and a list holding a nested tuple.
    Rejected with TypeError at assignment: a list containing a module, a type
    object, and a tuple containing a dict.
    """
    # Captured so the assertions can run inside Foo.__init__.
    tester = self

    class Foo(torch.jit.ScriptModule):
        __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']

        def __init__(self):
            super(Foo, self).__init__(False)
            self.a = 1
            self.b = 1.2
            self.c = False
            with tester.assertRaisesRegex(
                    TypeError,
                    "'Linear' object for attribute 'd' is not a valid constant"):
                self.d = [nn.Linear(3, 4)]
            self.e = lambda x: x
            self.f = [3, 4, 5]
            # Lists assigned to constants are stored as tuples.
            tester.assertTrue(type(self.f) is tuple)
            self.g = [3, (3, 4), 5]
            with tester.assertRaisesRegex(TypeError, "not a valid constant"):
                self.h = type(1)
            with tester.assertRaisesRegex(TypeError, "not a valid constant"):
                self.i = (3, 4, {})

    # Construction runs all of the assertions above; the instance itself is unused.
    Foo()
6146 
def test_script_module_param_buffer_mutation(self):
    """A script method may mutate a registered buffer in place.

    The increment only happens in training mode; after eval() the buffer is
    returned unchanged.
    """
    # TODO: add param mutation test case after JIT support it
    class ModuleBufferMutate(torch.jit.ScriptModule):
        def __init__(self):
            super(ModuleBufferMutate, self).__init__(False)
            self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))

        @torch.jit.script_method
        def forward(self):
            if self.training:
                self.running_var += 1
            return self.running_var

    m = ModuleBufferMutate()
    self.assertEqual(m(), 1)
    # In eval mode the buffer is not incremented again.
    m.eval()
    self.assertEqual(m(), 1)
6164 
def test_script_module_for(self):
    """A script method can loop over a list constant (stored as a tuple)."""
    class M(torch.jit.ScriptModule):
        __constants__ = ['b']

        def __init__(self):
            super(M, self).__init__(False)
            self.b = [1, 2, 3, 4]

        @torch.jit.script_method
        def forward(self):
            sum = 0
            for i in self.b:
                sum += i
            return sum

    m = M()
    self.assertEqual(m(), 10)
6182 
def test_script_module_for2(self):
    """A script method can loop over a ModuleList declared in __constants__.

    The scripted loop must produce the same result as applying each
    submodule in Python.
    """
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):
        __constants__ = ['mods']

        def __init__(self):
            super(M, self).__init__(False)
            self.mods = nn.ModuleList([Sub() for _ in range(10)])

        @torch.jit.script_method
        def forward(self, v):
            for m in self.mods:
                v = m(v)
            return v

    # Use initialized data: torch.Tensor(2) would hand the comparison
    # arbitrary uninitialized memory (possibly NaN/inf).
    i = torch.randn(2)
    m = M()
    o = m(i)
    v = i
    for sub in m.mods:
        v = sub(v)
    self.assertEqual(o, v)
6213 
def test_script_module_const_submodule_fail(self):
    """Some non-constant attributes cannot be used from a script method.

    Two cases are checked: a plain list of submodules not named in
    __constants__, and a plain tensor attribute. Both raise at construction.
    """
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):
        def __init__(self):
            super(M, self).__init__(False)
            # Plain Python list (not nn.ModuleList), not in __constants__.
            self.mods = [Sub() for _ in range(10)]

        @torch.jit.script_method
        def forward(self):
            for _ in self.mods:
                print(1)
            return 4

    with self.assertRaisesRegex(RuntimeError, "did you forget to add it __constants__"):
        M()

    # Specialized error for Tensors
    class S(torch.jit.ScriptModule):
        # NOTE(review): this __init__ never calls super().__init__();
        # confirm that is intentional for this error-path test.
        def __init__(self):
            self.tensor_constant = torch.ones(2)

        @torch.jit.script_method
        def forward(self):
            return self.tensor_constant + 2

    with self.assertRaisesRegex(RuntimeError, "Tensors must be added to a module as a buffer or parameter"):
        S()
6249 
class DerivedStateModule(torch.jit.ScriptModule):
    """Module with a buffer ``derived`` computed from parameter ``param``.

    ``_pack`` replaces ``derived`` with a 1-element placeholder and
    ``_unpack`` recomputes it as ``-param``. Each method also sets a flag
    buffer (``pack_called`` / ``unpack_called``) so tests can observe that it
    ran. ``forward`` returns ``x + derived``.
    """
    def __init__(self):
        super(TestScript.DerivedStateModule, self).__init__()
        self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
        self.register_buffer('derived', torch.neg(self.param).detach().clone())

        # This is a flag so we can test that the pack method was called
        self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
        # This is a flag so we can test that the unpack method was called
        self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))

    @torch.jit.script_method
    def _pack(self):
        self.pack_called.set_(torch.ones(1, dtype=torch.long))
        self.derived.set_(torch.rand(1, dtype=torch.float).detach())

    @torch.jit.script_method
    def _unpack(self):
        self.unpack_called.set_(torch.ones(1, dtype=torch.long))
        self.derived.set_(torch.neg(self.param).detach())

    @torch.jit.script_method
    def forward(self, x):
        return x + self.derived
6274 
def test_pack_unpack_state(self):
    """Check _pack/_unpack around export/import of a DerivedStateModule.

    Serialization must call _pack before saving and _unpack afterwards, so
    the original module is left usable. The imported copy must also have
    been unpacked.
    """
    # The construction line was missing, leaving `sm` undefined below.
    sm = TestScript.DerivedStateModule()
    x = torch.rand(3, 4, dtype=torch.float)
    torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))

    # Test save path
    self.assertFalse(sm.pack_called.item())
    self.assertFalse(sm.unpack_called.item())
    imported = self.getExportImportCopyWithPacking(sm)
    # ensure pack was called before serialization
    self.assertTrue(sm.pack_called.item())
    # ensure unpack was called after serialization so as to leave the module in an initialized state
    self.assertTrue(sm.unpack_called.item())

    torch.testing.assert_allclose(sm.derived, torch.neg(sm.param))

    # Test load paths
    self.assertTrue(imported.unpack_called.item())
    torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
6294 
def test_pack_unpack_nested(self):
    """Apply _pack/_unpack to every module in a three-level hierarchy.

    Each level adds its own buffer (1, 2 and 3), so a zero input gives 6.
    After packing, every buffer is a 1-element zero and the output is zero.
    Unpacking restores the original result.
    """
    class SubSubMod(torch.jit.ScriptModule):
        def __init__(self):
            super(SubSubMod, self).__init__()
            self.register_buffer('buf', torch.ones(3, 4) * 3)

        # NOTE(review): the buffer is registered with the default dtype, but
        # _pack/_unpack install torch.double tensors; confirm this is intended.
        @torch.jit.script_method
        def _pack(self):
            self.buf.set_(torch.zeros(1, dtype=torch.double))

        @torch.jit.script_method
        def _unpack(self):
            self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)

        @torch.jit.script_method
        def forward(self, x):
            return x + self.buf

    class SubMod(torch.jit.ScriptModule):
        def __init__(self):
            super(SubMod, self).__init__()
            self.register_buffer('buf', torch.ones(3, 4) * 2)
            self.ssm = SubSubMod()

        @torch.jit.script_method
        def _pack(self):
            self.buf.set_(torch.zeros(1, dtype=torch.double))

        @torch.jit.script_method
        def _unpack(self):
            self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)

        @torch.jit.script_method
        def forward(self, x):
            return self.ssm(x + self.buf)

    class Mod(torch.jit.ScriptModule):
        def __init__(self):
            super(Mod, self).__init__()
            self.submod = SubMod()
            self.register_buffer('buf', torch.ones(3, 4) * 1)

        @torch.jit.script_method
        def _pack(self):
            self.buf.set_(torch.zeros(1, dtype=torch.double))

        @torch.jit.script_method
        def _unpack(self):
            self.buf.set_(torch.ones(3, 4, dtype=torch.double))

        @torch.jit.script_method
        def forward(self, x):
            return self.submod(x + self.buf)

    m = Mod()
    torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
    # Module.apply visits every submodule as well as m itself.
    m.apply(lambda s: s._pack())
    torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4))
    m.apply(lambda s: s._unpack())
    torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
6355 
def test_script_module_not_tuple(self):
    """Iterating an int constant in a script method fails at construction."""
    class M(torch.jit.ScriptModule):
        __constants__ = ['mods']

        def __init__(self):
            super(M, self).__init__(False)
            # Declared constant, but an int cannot be iterated.
            self.mods = 1

        @torch.jit.script_method
        def forward(self, v):
            for m in self.mods:
                print(m)
            return v
    with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        M()
6371 
def test_script_module_list_sequential_error(self):
    """Looping over a container not named in __constants__ fails.

    Both nn.Sequential and nn.ModuleList attributes must be declared in
    __constants__ before a script method can iterate them; otherwise
    construction raises.
    """
    class M(torch.jit.ScriptModule):
        def __init__(self, mod_list):
            super(M, self).__init__(False)
            self.mods = mod_list

        @torch.jit.script_method
        def forward(self, v):
            for m in self.mods:
                v = m(v)
            return v

    # Construction is expected to raise, so the instance is never bound.
    with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
        M(nn.Sequential(nn.ReLU()))
    with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
        M(nn.ModuleList([nn.ReLU()]))
6388 
def test_script_sequential_for(self):
    """A script method can use an nn.Sequential constant two ways.

    Looping over it and calling it directly must both match applying each
    submodule in Python.
    """
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):
        __constants__ = ['mods']

        def __init__(self):
            super(M, self).__init__(False)
            self.mods = nn.Sequential(Sub(), Sub(), Sub())

        @torch.jit.script_method
        def forward(self, v):
            for m in self.mods:
                v = m(v)
            return v

        @torch.jit.script_method
        def forward2(self, v):
            return self.mods(v)

    # Use initialized data: torch.Tensor(2) would hand the comparison
    # arbitrary uninitialized memory (possibly NaN/inf).
    i = torch.randn(2)
    m = M()
    o = m(i)
    v = i
    for sub in m.mods:
        v = sub(v)
    self.assertEqual(o, v)

    o2 = m.forward2(i)
    self.assertEqual(o2, v)
6426 
def test_script_sequential_multi_output_fail(self):
    """Calling an nn.Sequential that contains a multi-output module fails.

    ReturnMulti returns a 3-tuple, which the next Sub in the chain cannot
    accept as its single input.
    """
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class ReturnMulti(torch.jit.ScriptModule):
        def __init__(self):
            super(ReturnMulti, self).__init__(False)

        @torch.jit.script_method
        def forward(self, x):
            return x, x, x

    class HaveSequential(torch.jit.ScriptModule):
        __constants__ = ['someseq']

        def __init__(self):
            super(HaveSequential, self).__init__(False)
            self.someseq = nn.Sequential(
                Sub(),
                ReturnMulti(),
                Sub()
            )

        @torch.jit.script_method
        def forward(self, x):
            return self.someseq(x)

    # NOTE(review): the parentheses in this pattern form a regex group, so it
    # matches "Tensor, Tensor, Tensor" without requiring literal parens.
    with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
        hs = HaveSequential()
        i = torch.Tensor(2)
        hs(i)
6464 
def test_constant_insert_fail_lint(self):
    """Constant propagation runs without error on a graph with torch.tensor.

    The aten::tensor node must not be folded into a constant.
    """
    @torch.jit.script
    def foo(x):
        y = x + 1
        z = torch.tensor([[1.0, 2.5]])
        print(x, z)

    # check that it doesn't error
    self.run_pass('constant_propagation', foo.graph)
    self.assertTrue("aten::tensor" in str(foo.graph))  # not constant propped
6475 
def test_script_sequential_in_mod_list(self):
    """Nested Sequential inside a ModuleList is inlined into the graph.

    The five Sub modules (1 + 4 nested) must each contribute one aten::add,
    and no Python fallback op may appear.
    """
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):
        __constants__ = ['mods']

        def __init__(self):
            super(M, self).__init__(False)
            self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])

        @torch.jit.script_method
        def forward(self, v):
            for mod in self.mods:
                v = mod(v)
            return v

    m = M()
    graph = str(m.graph)
    self.assertTrue(graph.count("aten::add") == 5)
    self.assertTrue("python" not in graph)
6503 
def test_script_nested_mod_list(self):
    """A script method can run a nested double loop over module containers.

    The outer container is a ModuleList; its items are a ModuleList, a
    Sequential and another ModuleList. The graph must contain one aten::add
    per Sub (four in total) and no Python fallback op.
    """
    class Sub(torch.jit.ScriptModule):
        def __init__(self):
            super(Sub, self).__init__(False)
            self.weight = nn.Parameter(torch.randn(2))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):
        __constants__ = ['mods']

        def __init__(self):
            super(M, self).__init__(False)
            self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])

        @torch.jit.script_method
        def forward(self, v):
            for mod in self.mods:
                for m in mod:
                    v = m(v)
            return v

    m = M()
    graph = str(m.graph)
    self.assertTrue(graph.count("aten::add") == 4)
    self.assertTrue("python" not in graph)
6532 
def test_constant_as_attr(self):
    """An int constant can be passed as a keyword argument (dim=) in script."""
    class M(torch.jit.ScriptModule):
        __constants__ = ['dim']

        def __init__(self):
            super(M, self).__init__(False)
            self.dim = 1

        @torch.jit.script_method
        def forward(self, v):
            return torch.cat([v, v, v], dim=self.dim)
    v = torch.zeros(1, 1)
    self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
6546 
class StarTestSumStarred(torch.nn.Module):
    """Variadic module that returns the sum of all positional tensor inputs.

    The running total is accumulated with ``+=`` into the first input, so
    that tensor is modified in place and is itself the returned object.
    """
    def __init__(self):
        super(TestScript.StarTestSumStarred, self).__init__()

    def forward(self, *inputs):
        acc = inputs[0]
        # In-place accumulation into inputs[0] is deliberate (aliasing matters).
        for extra in inputs[1:]:
            acc += extra
        return acc
6556 
class StarTestReturnThree(torch.nn.Module):
    """Module whose forward returns its single input three times as a tuple."""
    def __init__(self):
        super(TestScript.StarTestReturnThree, self).__init__()

    def forward(self, rep):
        # A 3-tuple whose elements are all the same input object.
        return (rep,) * 3
6563 
def test_script_star_expr(self):
    """Starred call arguments work in a script method.

    The 3-tuple returned by a traced StarTestReturnThree is splatted into a
    traced variadic StarTestSumStarred via ``self.m(*tup)``.
    """

    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(True)
            # This assignment was missing, leaving a dangling continuation
            # line and an undefined `self.m` in forward().
            self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
                                     (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
            self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))

        @torch.jit.script_method
        def forward(self, rep):
            tup = self.g(rep)
            return self.m(*tup)

    m = M2()
    self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6580 
def test_script_star_expr_string(self):
    """Like test_script_star_expr, but forward is compiled from a string via define()."""
    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(True)
            # This assignment was missing, leaving a dangling continuation
            # line and an undefined `self.m` in the defined forward().
            self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
                                     (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
            self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))

            self.define('''
        def forward(self, rep):
            tup = self.g(rep)
            return self.m(*tup)
        ''')

    m = M2()
    self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6597 
class StarTestSumAndReturnThree(torch.nn.Module):
    """Variadic module that sums its inputs and returns the sum three times.

    The sum is accumulated with ``+=`` into the first input, so that tensor
    is modified in place and every element of the returned tuple is that
    same tensor.
    """
    def __init__(self):
        super(TestScript.StarTestSumAndReturnThree, self).__init__()

    def forward(self, *inputs):
        total = inputs[0]
        # In-place accumulation into inputs[0] is deliberate; the in-place
        # star-assign test depends on this aliasing.
        for extra in inputs[1:]:
            total += extra
        return (total,) * 3
6607 
def test_script_star_assign(self):
    """Starred assignment ``head, *tail = ...`` works in defined source.

    The right-hand side is a traced module's tuple output.
    """
    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(True)
            self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
            self.define('''
        def forward(self, rep):
            head, *tail = self.g(rep)
            return head
        ''')

    m = M2()
    self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6621 
def test_script_module_star_assign2(self):
    """``*head, tail = ...`` works on a traced module's output (out-of-place trace).

    Traced with _force_outplace=True, so summing three aliases of ``rep``
    gives 3 * rep.
    """
    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(True)
            # The traced module argument was missing from this call.
            self.g = torch.jit.trace(
                TestScript.StarTestSumAndReturnThree(),
                (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
                _force_outplace=True)
            self.define('''
        def forward(self, rep):
            *head, tail = self.g(rep, rep, rep)
            return tail
        ''')

    m = M2()
    self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
6638 
def test_script_module_star_assign2_inplace(self):
    """Same as test_script_module_star_assign2, but traced with _force_outplace=False.

    The in-place ``+=`` then acts on three aliases of ``rep``, so the
    result differs from the out-of-place variant.
    """
    class M2(torch.jit.ScriptModule):
        def __init__(self):
            super(M2, self).__init__(True)
            # The traced module argument was missing from this call.
            self.g = torch.jit.trace(
                TestScript.StarTestSumAndReturnThree(),
                (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
                _force_outplace=False)
            self.define('''
        def forward(self, rep):
            *head, tail = self.g(rep, rep, rep)
            return tail
        ''')

    m = M2()
    # since forward() makes three aliases to the input `rep` before passing
    # it to StarTestSumAndReturnThree(), in-place behavior will be different
    # than the above out of place.
    self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
6658 
def test_script_module_star_assign_fail_pythonop(self):
    """Starred assignment from a plain Python function call is rejected.

    The call is ``a, *b = myfunc()`` inside defined source, and it must
    raise when M2 is constructed.
    """

    with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        class M2(torch.jit.ScriptModule):
            def __init__(self):
                super(M2, self).__init__(True)

                # Plain Python function, resolved from the defining scope.
                def myfunc():
                    return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)

                self.define('''
            def forward(self, rep):
                a, *b = myfunc()
                return a
            ''')

        m = M2()
        m(torch.zeros(4, 3))
6677 
def test_script_module_star_assign_fail_builtin(self):
    """Starred assignment from a builtin that returns one tensor is rejected.

    The call is ``a, *b = torch.neg(rep)`` inside defined source.
    """
    with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        class M2(torch.jit.ScriptModule):
            def __init__(self):
                super(M2, self).__init__(True)

                self.define('''
            def forward(self, rep):
                a, *b = torch.neg(rep)
                return a
            ''')

        m = M2()
        m(torch.zeros(4, 3))
6692 
def test_pack_padded_pad_packed_trace(self):
    """Trace pack_padded_sequence followed by pad_packed_sequence.

    The traced module must match eager mode in outputs and gradients. Since
    the padding is zero, the round trip must also reproduce the input. The
    module must additionally export to ONNX.
    """
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    T, B, C = 3, 5, 7

    class PadPackedWrapper(torch.nn.Module):
        def __init__(self):
            super(PadPackedWrapper, self).__init__()

        def forward(self, x, seq_lens):
            x = pack_padded_sequence(x, seq_lens)
            x, _ = pad_packed_sequence(x)
            return x

    # x is indexed as [time, batch, channel]; seq_lens lists the per-batch
    # lengths in descending order.
    x = np.ones((T, B, C))
    seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
    # set padding value so we can test equivalence
    for b in range(B):
        if seq_lens[b] < T:
            x[seq_lens[b]:, b, :] = 0
    seq_lens = torch.from_numpy(seq_lens)
    # Legacy Variable wrapper; used here only to get requires_grad=True.
    x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)

    m = PadPackedWrapper()
    m_traced = torch.jit.trace(m, (x, seq_lens,))

    y = m(x, seq_lens)
    loss = torch.sum(y)
    loss.backward()
    grad = x.grad.clone()
    # Reset the accumulated gradient before the traced backward pass.
    x.grad.zero_()

    y_traced = m_traced(x, seq_lens)
    loss_traced = torch.sum(y_traced)
    loss_traced.backward()
    grad_traced = x.grad.clone()

    self.assertEqual(y_traced, x)
    self.assertEqual(y_traced, y)
    self.assertEqual(grad, grad_traced)

    f = io.BytesIO()
    torch.onnx._export(m, (x, seq_lens), f, verbose=False)
6735 
def test_script_outputs(self):
    """Bad tuple unpacking in script raises at compile time.

    Two cases are checked: unpacking a single tensor, and unpacking a
    3-tuple into two names.
    """
    with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        @torch.jit.script
        def foo(a):
            c, d = a + a
            return c + d

    @torch.jit.script
    def return3():
        return 1, 2, 3

    with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
        @torch.jit.script
        def bind2():
            a, b = return3()
            print(a)
            print(b)
6753 
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_script_get_device_cuda(self):
    """Tensor.get_device() in script returns the CUDA device index (0 here)."""
    @torch.jit.script
    def foo(a):
        return a.get_device()

    v = torch.randn(1, device='cuda')
    self.assertEqual(foo(v), 0)
6762 
def test_script_chunk(self):
    """torch.chunk with keyword args unpacks into two tensors in script.

    The first chunk must match eager mode.
    """
    @torch.jit.script
    def foo(a):
        b, c = torch.chunk(a, dim=0, chunks=2)
        return b
    v = torch.rand(10, 3)
    self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
6770 
def test_rnn_trace_override(self):
    """Trace RNN/LSTM/GRU over packed padded sequences.

    For each cell type, the traced module must match eager mode in outputs
    and gradients, and the eager module must export to ONNX.
    """
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
    num_layers = 3
    T, B, C = 11, 5, 7

    class RNNTraceWrapper(torch.nn.Module):
        def __init__(self, cell_type):
            super(RNNTraceWrapper, self).__init__()
            if cell_type == 'RNN':
                self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
            elif cell_type == 'LSTM':
                self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
            elif cell_type == 'GRU':
                self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)

        def forward(self, x, seq_lens):
            x = pack_padded_sequence(x, seq_lens)
            x, _ = self.rnn(x)
            x, _ = pad_packed_sequence(x)
            return x

    for cell_type in ['RNN', 'LSTM', 'GRU']:
        x = torch.ones(T, B, C, requires_grad=True)
        # Per-batch lengths in descending order; the longest equals T.
        seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))

        m = RNNTraceWrapper(cell_type)
        m_traced = torch.jit.trace(m, (x, seq_lens,))

        y = m(x, seq_lens)
        loss = torch.sum(y)
        loss.backward()
        grad = x.grad.clone()
        # Reset the accumulated gradient before the traced backward pass.
        x.grad.zero_()

        y_traced = m_traced(x, seq_lens)
        loss_traced = torch.sum(y_traced)
        loss_traced.backward()
        grad_traced = x.grad.clone()

        self.assertEqual(y_traced, y)
        self.assertEqual(grad, grad_traced)

        f = io.BytesIO()
        torch.onnx._export(m, (x, seq_lens), f, verbose=False)
6815 
def test_python_call_non_tensor(self):
    """Script code calls a Python function with non-tensor types.

    The function takes and returns ints and tuples, with types declared in
    its ``# type:`` comment.
    """
    def foo(a, b, c):
        # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
        d, e = c
        return b + e, a + d

    @torch.jit.script
    def bar():
        x = torch.ones(3, 4)
        a, b = foo(x, 3, (x, 3))
        return a, b

    # b + e = 3 + 3 and a + d = x + x = ones + 1.
    self.assertEqual((6, torch.ones(3, 4) + 1), bar())
6829 
def test_python_call_non_tensor_wrong(self):
    """A Python function that breaks its declared return type errors at call time.

    It declares ``-> Tensor`` but returns a tuple.
    """
    with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
        def foo():
            # type: () -> Tensor
            return ((3, 4),)  # noqa: T484

        @torch.jit.script
        def bar():
            return foo()

        bar()
6841 
def test_tuples(self):
    """Tuple construction, reassignment and unpacking through control flow.

    Checks that the script result matches Python, and that reassigning a
    tuple variable to an int is rejected.
    """
    def foo(i):
        a = (i + 4, i * 2)
        c = a
        # some nonsense with if-statements and loops to check
        # that tuple lowering doesn't fail
        if True:
            c = (i * 9, i + 1)
        t0, t1 = c
        while False:
            t0, t1 = c
            c = (t1, t0)
        x = (1,)
        y = 1,
        return t0, x, y

    v = torch.rand(10, 3)
    self.checkScript(foo, (v,))

    with self.assertRaisesRegex(RuntimeError, r"variable 'a' previously has type \(Tensor, Tensor\)"):
        @torch.jit.script
        def mixtypes(x):
            a = (x, x)
            if True:
                a = 4
6867 
6868  def test_if_tuple_sizes(self):
6869  with self.assertRaisesRegex(RuntimeError, "Type mismatch"):
6870  @torch.jit.script
6871  def diff_tuple_sizes(x):
6872  if False:
6873  c0 = ((x, x), (x, x, x))
6874  else<