1 from __future__
import division
10 from contextlib
import contextmanager
11 from itertools
import product, chain
19 from common_utils
import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
20 skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
21 freeze_rng_state, set_rng_seed
22 from common_nn
import module_tests, new_module_tests, criterion_tests
23 from textwrap
import dedent
24 from functools
import wraps
41 from common_methods_invocations
import method_tests
as autograd_method_tests
42 from common_methods_invocations
import create_input, unpack_variables, \
43 exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
45 from torch._C import TensorType, TupleType, FloatType, IntType, \
46 ListType, StringType, DictType
47 from copy
import deepcopy
49 from typing
import List, Dict, Optional, Tuple
52 from torch
import Tensor
62 load_tests = load_tests
66 HAS_TORCHVISION =
True 68 HAS_TORCHVISION =
False 71 skipIfNoTorchVision = unittest.skipIf(
not HAS_TORCHVISION,
"no torchvision")
74 RUN_CUDA_HALF = RUN_CUDA
76 CUDA_VERSION = torch._C._cuda_getCompiledVersion()
79 if (CUDA_VERSION < 8000
and major >= 6)
or (CUDA_VERSION < 9000
and major >= 7):
81 if (CUDA_VERSION < 9000
or major < 6):
86 PY35 = sys.version_info >= (3, 5)
87 WINDOWS = sys.platform ==
'win32' 92 def TemporaryFileName():
96 f = tempfile.NamedTemporaryFile(delete=
False)
104 def TemporaryFileName():
105 with tempfile.NamedTemporaryFile()
as f:
def LSTMCellF(input, hx, cx, *params):
    """Flat-argument adapter around ``LSTMCell``.

    Packs ``hx`` and ``cx`` into a single ``(hx, cx)`` tuple, passes it as
    the second argument of ``LSTMCell`` and forwards ``params`` unchanged.
    Returns whatever ``LSTMCell`` returns.
    """
    hidden = (hx, cx)
    return LSTMCell(input, hidden, *params)
113 def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
115 gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
117 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
118 ingate = torch.sigmoid(ingate)
119 forgetgate = torch.sigmoid(forgetgate)
120 cellgate = torch.tanh(cellgate)
121 outgate = torch.sigmoid(outgate)
123 cy = (forgetgate * cx) + (ingate * cellgate)
124 hy = outgate * torch.tanh(cy)
def LSTMCellC(*args, **kwargs):
    """Call ``LSTMCellF`` and join its two results into a single tensor.

    All positional and keyword arguments are forwarded untouched. The result
    of ``LSTMCellF`` is unpacked into exactly two values, which are then
    concatenated with ``torch.cat`` using its default dimension.
    """
    pair = LSTMCellF(*args, **kwargs)
    hy, cy = pair
    return torch.cat((hy, cy))
133 def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
134 gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
135 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
136 ingate = torch.sigmoid(ingate)
137 forgetgate = torch.sigmoid(forgetgate)
138 cellgate = torch.tanh(cellgate)
139 outgate = torch.sigmoid(outgate)
140 cy = (forgetgate * cx) + (ingate * cellgate)
141 hy = outgate * torch.tanh(cy)
146 def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
150 gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias
152 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
153 ingate = ingate.sigmoid()
154 forgetgate = forgetgate.sigmoid()
155 cellgate = cellgate.tanh()
156 outgate = outgate.sigmoid()
157 cy = (forgetgate * cx) + (ingate * cellgate)
158 hy = outgate * cy.tanh()
def canonical(graph):
    """Return ``str()`` of the JIT canonicalize pass result for ``graph``."""
    canonical_graph = torch._C._jit_pass_canonicalize(graph)
    return str(canonical_graph)
166 def get_lstm_inputs(device, training=False, seq_length=None):
167 input_shape = (3, 10)
if seq_length
is None else (seq_length, 3, 10)
168 input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training)
169 hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
170 cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
171 module = nn.LSTMCell(10, 20).to(device, torch.float)
173 params = tuple(module.parameters())
175 params = tuple(p.requires_grad_(
False)
for p
in module.parameters())
176 return (input, hx, cx) + params
179 def get_milstm_inputs(device, training=False):
183 x = torch.randn(minibatch, input_size, device=device, dtype=torch.float)
184 hx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
185 cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
187 ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training)
188 hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training)
189 alpha = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
190 ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
191 hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
192 bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
193 return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias
196 def get_fn(file_name, script_path):
197 import importlib.util
198 spec = importlib.util.spec_from_file_location(file_name, script_path)
199 module = importlib.util.module_from_spec(spec)
200 spec.loader.exec_module(module)
205 def get_execution_plan(graph_executor_state):
206 execution_plans = list(graph_executor_state.execution_plans.values())
207 num_plans = len(execution_plans)
209 raise RuntimeError(
'This test assumes this GraphExecutor should ' 210 'only have one execution plan, got: {}'.format(num_plans))
211 return execution_plans[0]
214 def get_grad_executor(plan_state, diff_graph_idx=None):
215 if diff_graph_idx
is None:
216 nodes = list(plan_state.graph.nodes())
217 if len(nodes) == 1
or (len(nodes) == 2
and nodes[1].kind() ==
"prim::TupleConstruct"):
220 raise RuntimeError(
"Can't get a grad_executor for a non-differentiable graph")
221 grad_executors = list(plan_state.code.grad_executors())
222 return grad_executors[diff_graph_idx
or 0]
225 def backward_graph(script_module, diff_graph_idx=None):
227 raise RuntimeError(
'Expected ScriptModule')
228 ge_state = script_module.get_debug_state()
229 fwd_plan = get_execution_plan(ge_state)
230 grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
231 bwd_plan = get_execution_plan(grad_executor.get_debug_state())
234 return bwd_plan.graph.copy()
238 def _trace(*args, **kwargs):
244 def enable_cpu_fuser(fn):
245 def wrapper(*args, **kwargs):
246 torch._C._jit_override_can_fuse_on_cpu(
True)
250 torch._C._jit_override_can_fuse_on_cpu(
False)
255 _do_cuda_memory_leak_check =
True 256 _restored_warnings =
False 262 if not JitTestCase._restored_warnings:
264 JitTestCase._restored_warnings =
True 270 torch._C._jit_set_emit_module_hook(
None)
271 torch._C._jit_clear_class_registry()
274 def disableModuleHook(self):
275 torch._C._jit_set_emit_module_hook(
None)
279 def emitModuleHook(self, module):
280 def copy_structure_and_params(m):
282 for name, v
in m._get_parameters():
283 c._register_parameter(name, v,
False)
284 for name, the_type, v
in m._get_attributes():
285 c._register_attribute(name, the_type, v)
286 for name, s
in m._get_modules():
287 c._register_module(name, copy_structure_and_params(s))
293 pp, constant_table = module._python_print()
294 except RuntimeError
as e:
296 if "could not export python function" not in se
and \
297 "closures are not exportable" not in se:
301 ppv =
"op_version_set = 0\n{}".format(pp)
302 sm = copy_structure_and_params(module)
303 torch._C._jit_import_methods(sm, ppv, constant_table)
304 pp2, _ = sm._python_print()
306 self.assertMultiLineEqual(pp, pp2)
308 def getExportImportCopy(self, m, also_test_file=True, map_location=None):
309 buffer = io.BytesIO()
314 if not also_test_file:
317 with TemporaryFileName()
as fname:
321 def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
322 buffer = io.BytesIO()
323 m.apply(
lambda s: s._pack()
if s._has_method(
'_pack')
else None)
325 m.apply(
lambda s: s._unpack()
if s._has_method(
'_unpack')
else None)
328 imported.apply(
lambda s: s._unpack()
if s._has_method(
'_unpack')
else None)
330 if not also_test_file:
336 f = tempfile.NamedTemporaryFile(delete=
False)
339 imported.save(f.name)
344 result.apply(
lambda s: s._unpack()
if s._has_method(
'_unpack')
else None)
def assertGraphContains(self, graph, kind):
    """Assert that at least one node of ``graph`` reports the given ``kind``.

    Walks ``graph.nodes()`` and stops at the first node whose ``kind()``
    equals ``kind``; the outcome is handed to ``self.assertTrue``.
    """
    found = False
    for node in graph.nodes():
        if node.kind() == kind:
            found = True
            break
    self.assertTrue(found)
350 def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
351 def perform_assert(graph, kind, actual, expected, consider_subgraphs):
352 if actual == expected:
354 subgraph =
'including' if consider_subgraphs
else 'excluding' 355 raise AssertionError(
356 '{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format(
357 graph, actual, kind, subgraph, expected))
359 if consider_subgraphs:
360 strgraph = str(graph)
361 count = strgraph.count(kind) - strgraph.count(
'with {}'.format(kind))
362 perform_assert(graph, kind, count, num_kind_nodes,
366 nodes = [node
for node
in graph.nodes()
367 if node.kind() == kind]
368 perform_assert(graph, kind, len(nodes), num_kind_nodes,
371 def assertExpectedONNXGraph(self, trace, *args, **kwargs):
375 def assertExpectedGraph(self, trace, *args, **kwargs):
376 if isinstance(trace, torch._C.Graph):
379 graph = trace.graph()
381 torch._C._jit_pass_lint(graph)
382 torch._C._jit_pass_dce(graph)
383 torch._C._jit_pass_lint(graph)
384 graph = torch._C._jit_pass_canonicalize(graph)
385 torch._C._jit_pass_lint(graph)
388 def run_pass(self, name, trace):
389 if isinstance(trace, torch._C.Graph):
394 graph = trace.graph()
396 torch._C._jit_pass_lint(graph)
397 result = getattr(
torch._C,
'_jit_pass_' + name)(graph)
398 if result
is not None:
400 torch._C._jit_pass_lint(graph)
403 trace.set_graph(graph)
406 def checkScript(self,
412 capture_output=
False,
414 check_expected=
False):
415 if isinstance(script, str):
417 ge = getattr(cu, name)
420 with self.capture_stdout()
as captured:
421 outputs = script(*inputs)
423 outputs = script(*inputs)
425 source = textwrap.dedent(inspect.getsource(script))
434 check_expected=check_expected)
439 with self.capture_stdout()
as captured:
440 outputs_ge = ge(*inputs)
444 outputs_ge = ge(*inputs)
452 def checkTrace(self, func, reference_tensors, input_tensors=None,
453 optimize=
True, drop=
None, allow_unused=
False, verbose=
False,
454 inputs_require_grads=
True, check_tolerance=1e-5, export_import=
True,
455 _force_outplace=
False):
464 return sum(math.log(i + 2) * v.sum()
for i, v
in enumerate(vs)
if v
is not None)
465 if input_tensors
is None:
466 input_tensors = reference_tensors
468 nograd_inputs = reference_tensors
469 if inputs_require_grads:
470 recording_inputs = [t.clone().requires_grad_()
for t
in reference_tensors]
472 recording_inputs = reference_tensors
474 if isinstance(func, torch._C.Graph):
475 ge = torch._C.GraphExecutor(func, optimize)
477 ge =
torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
478 _force_outplace=_force_outplace)
487 outputs = func(*nograd_inputs)
488 outputs_ge = ge(*nograd_inputs)
492 outputs = func(*recording_inputs)
493 if inputs_require_grads:
495 allow_unused=allow_unused)
497 outputs_ge = ge(*recording_inputs)
498 if inputs_require_grads:
500 allow_unused=allow_unused)
502 if inputs_require_grads:
507 outputs = func(*recording_inputs)
509 if inputs_require_grads:
511 allow_unused=allow_unused)
512 if inputs_require_grads:
513 l2 = (allSum(grads) * l1)
516 if inputs_require_grads:
517 recording_inputs = [Variable(t, requires_grad=
True)
518 for t
in reference_tensors]
520 outputs_ge = ge(*recording_inputs)
521 l1_ge = allSum(outputs_ge)
522 if inputs_require_grads:
524 l1_ge, recording_inputs, create_graph=
True, allow_unused=allow_unused)
526 if inputs_require_grads:
527 l2_ge = (allSum(grads_ge) * l1_ge)
531 if inputs_require_grads:
533 for g2, g2_ge
in zip(grads2, grads2_ge):
534 if g2
is None and g2_ge
is None:
536 self.assertTrue(torch.allclose(g2, g2_ge, atol=8e-4, rtol=8e-4))
540 def createScriptModuleFromGraph(self, trace):
541 graph = trace
if isinstance(trace, torch._C.Graph)
else trace.graph()
543 m._create_method_from_graph(
"forward", graph)
546 def assertExportImport(self, trace, inputs):
550 def assertExportImportModule(self, m, inputs):
555 def runAndSaveRNG(self, func, inputs, kwargs=None):
556 kwargs = kwargs
if kwargs
else {}
557 with freeze_rng_state():
558 results = func(*inputs, **kwargs)
565 super(FooToPickle, self).__init__()
571 @unittest.skip(
"Requires a lot of RAM")
574 gig = int(1024 * 1024 * 1024 / 4)
576 m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float))
578 m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float))
580 m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float))
582 m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float))
586 self.
assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
588 def test_simple(self):
593 return torch.sigmoid(torch.tanh(x * (x + y)))
597 def test_restore_device(self):
600 cpu_device_str =
'cpu' 601 m.p0 = nn.Parameter(
torch.tensor([0.3], dtype=torch.float,
602 device=cpu_device_str))
603 m.register_buffer(
'b0',
torch.tensor([0.9], dtype=torch.float,
604 device=cpu_device_str))
606 self.
assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
607 self.
assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
608 self.assertFalse(m2.p0.is_cuda)
609 self.assertFalse(m2.b0.is_cuda)
611 def test_model_save_error(self):
612 with TemporaryFileName()
as fname:
616 def test_single_tuple_trace(self):
622 assert f2(x) == jit_f2(x)
624 @unittest.skipIf(
not RUN_CUDA,
"restore device requires CUDA")
625 def test_restore_device_cuda(self):
628 super(MyModule, self).__init__(
False)
629 self.register_buffer(
'b0', torch.randn(1, 3))
630 self.
p0 = nn.Parameter(torch.randn(2, 3))
632 @torch.jit.script_method
633 def forward(self, x):
634 return x + self.b0 + self.
p0 640 self.assertTrue(m.p0.is_cuda)
641 self.assertTrue(m.b0.is_cuda)
645 self.
assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
646 self.
assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
647 self.
assertEqual(str(m2.p0.device), cuda_device_str)
648 self.
assertEqual(str(m2.b0.device), cuda_device_str)
651 cpu_device_str =
'cpu' 653 self.
assertEqual(str(m3.p0.device), cpu_device_str)
654 self.
assertEqual(str(m3.b0.device), cpu_device_str)
658 m3, map_location=torch.device(
'cuda:0'))
664 origin_result = m(input)
669 @unittest.skipIf(
not RUN_CUDA,
"restore device requires CUDA")
670 def test_restore_shared_storage_on_cuda(self):
671 whole_tensor = torch.randn(4, 5, dtype=torch.float, device=
'cpu')
673 m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
674 m.register_buffer(
'b0', whole_tensor.narrow(0, 3, 1))
676 self.
assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
677 self.
assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
678 self.assertTrue(m2.p0.is_cuda)
679 self.assertTrue(m2.b0.is_cuda)
680 self.assertTrue(m2.p0.is_shared())
681 self.assertTrue(m2.b0.is_shared())
682 self.
assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())
684 def test_typeas_trace_check(self):
693 def test_peephole(self):
702 FileCheck().check(
"type_as").run(str(tf.graph))
704 FileCheck().check_not(
"type_as").run(str(tf.graph))
707 self.
run_pass(
'peephole', tf2.graph)
710 def test_peephole_dynamic(self):
716 torch._C._jit_pass_peephole(fn.graph)
719 @unittest.skipIf(
not RUN_CUDA,
"cpp tests require CUDA")
720 def test_peephole_cuda(self):
730 self.
run_pass(
'peephole', trace.graph)
733 self.
run_pass(
'peephole', trace.graph)
734 self.assertTrue(len(list(trace.graph.nodes())) == 0)
736 def test_index(self):
747 def test_disabled(self):
748 torch.jit._enabled =
False 753 self.assertIs(
torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f)
757 @torch.jit.script_method
765 self.assertTrue(inspect.ismethod(MyModule.method)
or inspect.isfunction(MyModule.method))
767 torch.jit._enabled =
True 769 def test_train_eval(self):
770 class Sub(nn.Module):
771 def forward(self, input):
778 def __init__(self, module):
779 super(MyModule, self).__init__()
782 @torch.jit.script_method
783 def forward(self, input):
784 return self.
module(input) + 1
787 input = torch.rand(3, 4)
793 input = torch.randn(6, 10)
794 batchnorm = nn.BatchNorm1d(10)
795 dropout = nn.Dropout(p=0.2)
797 m_batchnorm = MyModule(batchnorm)
798 self.
assertEqual(batchnorm(input) + 1, m_batchnorm(input))
801 self.
assertEqual(batchnorm(input) + 1, m_batchnorm(input))
803 m_dropout = MyModule(dropout)
808 def test_diff_subgraph_clones_constants(self):
811 return x + x + y + x + y + x + y + x + y + x
813 def count_constants(graph):
814 return sum(node.kind() ==
'prim::Constant' for node
in graph.nodes())
816 graph = f.graph.copy()
818 self.
run_pass(
'create_autodiff_subgraphs', graph)
819 nodes = list(graph.nodes())
821 self.
assertEqual(count_constants(nodes[1].g(
'Subgraph')), 1)
828 def test_index_constant(self):
842 def test_scopes(self):
851 out = torch.tanh(out)
852 out = torch.sigmoid(out)
857 def test_scopes_intermediate_node(self):
858 class Net(nn.Module):
859 def forward(self, x):
860 return F.log_softmax(x, dim=0)
863 t = torch.ones(2, requires_grad=
True)
868 FileCheck().check(
"onnx::LogSoftmax").check(
"scope: Net").run(str(trace))
870 def test_scopes_identity_node(self):
872 class Net(nn.Module):
875 super(Net, self).__init__()
877 nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
878 nn.ReLU(inplace=
True),
879 nn.MaxPool2d(kernel_size=3, stride=2),
882 def forward(self, x):
888 t = torch.ones(1, 3, 227, 227, requires_grad=
True)
895 FileCheck().check(
"Net/Sequential[features]/Conv2d[0]").check(
"ReLU").check(
"MaxPool").run(str(trace))
897 def test_canonicalize_tensor_iterator(self):
898 x = torch.randn(4, 4)
909 graph = traced.graph_for(x)
912 self.assertTrue(str(traced.graph_for(x)).count(
': int = prim::Constant') == 5)
915 @unittest.skip(
"Need to be adjusted to Graph Executor")
917 """Different arg configurations should trigger different traces""" 918 x = Variable(torch.FloatTensor(4, 4).uniform_())
919 x_double = Variable(x.data.double())
920 x_grad = Variable(x.data.clone(), requires_grad=
True)
921 y = Variable(torch.randn(4))
932 x_cuda = Variable(x.data.cuda())
940 x_cuda_1 = Variable(x.data.cuda(1))
943 ([x_cuda, x_cuda_1],),
946 @torch.jit.compile(nderivs=0)
948 in_vars, _ = torch._C._jit_flatten(args)
949 return in_vars[0] + 1
951 for i, config
in enumerate(configurations):
952 self.assertFalse(fn.has_trace_for(*config))
954 self.assertTrue(fn.has_trace_for(*config))
955 for unk_config
in configurations[i + 1:]:
956 self.assertFalse(fn.has_trace_for(*unk_config))
964 w = (x + y) * (x + y) * (x + y)
965 t = torch.tanh(w) + torch.tanh(w)
966 z = (x + y) * (x + y) * (x + y) + t
972 FileCheck().check_count(
"add", 1).check_count(
"mul", 2, do_exactly) \
973 .check_count(
"tanh", 1, do_exactly).check_count(
"add", 2, do_exactly).check_next(
"return") \
978 def test_recursive_cse(self):
990 FileCheck().check(
"block").check_not(
"aten::add").check_not(
"aten::gt").run(str(graph))
992 def test_shape_analysis_broadcast(self):
996 x = torch.randn(3, 1, 5, requires_grad=
True)
997 y = torch.randn(4, 1, 8, 5, requires_grad=
True)
1000 torch._C._jit_pass_complete_shape_analysis(graph, (x, y),
False)
1001 FileCheck().check(
"Double(4, 3, 8, 5)").run(str(graph))
1004 @unittest.skip(
"verify needs to be updated to work with GraphExecutors")
1011 z = torch.sigmoid(x * (x + y))
1012 w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
1018 def test_constant(self):
1019 x = torch.randn(2, 2, requires_grad=
True)
1024 self.
checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=
True),))
1026 def test_legacy_fail(self):
1027 class MyLegacyFn(Function):
1028 def forward(self, x):
1031 def backward(self, grad_output):
1038 def test_inplace_transplant(self):
1049 FileCheck().check_count(
"aten::clone", 1, exactly=
True) \
1050 .check_count(
"aten::add_", 2, exactly=
True) \
1051 .check_next(
"return").run(str(trace))
1054 def test_inplace_flags(self):
1055 class InplaceFn(Function):
1057 def forward(ctx, x):
1062 def backward(ctx, go):
1065 class RegularFn(Function):
1067 def forward(ctx, x):
1071 def backward(ctx, go):
1077 y = RegularFn.apply(x)
1078 y = InplaceFn.apply(y)
1079 y = InplaceFn.apply(y)
1080 y = RegularFn.apply(y)
1085 ops = [n
for n
in trace.graph().nodes()]
1087 self.assertTrue(op.hasAttribute(
'inplace'))
1088 inplace_flags = [
False,
True,
True,
False]
1089 for op, is_inplace
in zip(ops, inplace_flags):
1092 def test_inplace_check(self):
1093 class MyInplaceFn(Function):
1095 def forward(self, x):
1101 def backward(self, grad):
1105 return MyInplaceFn.apply(x)
1107 x = torch.randn(5, 5)
1108 ge = torch._C.GraphExecutor(fn, (x,),
lambda var:
'', _force_outplace=
True)
1112 def do_trace_size(self, requires_grad):
1114 return x.view(x.shape[1] * 2, x.size(0), 2)
1116 x = torch.randn(5, 2, 4, requires_grad=requires_grad)
1117 y = torch.randn(4, 8, 4, requires_grad=requires_grad)
1124 def test_trace_size(self):
1129 def test_trace_size_with_grad(self):
1132 def test_trace_casts(self):
1135 lambda x: x.float(),
1137 lambda x: x.to(device=
'cpu'),
1138 lambda x: x.to(dtype=torch.int64),
1139 lambda x: x.to(device=
'cpu', dtype=torch.float),
1143 def assertContainsCast(trace):
1144 self.
assertEqual(sum(n.kind() ==
'aten::to' for n
in trace.graph.nodes()), 1)
1148 assertContainsCast(trace)
1149 x = torch.randn(2, 2)
1152 def to_tensor(x, y):
1155 to_tensor_trace =
torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
1156 assertContainsCast(to_tensor_trace)
1157 x, y = torch.randn(2, 2), torch.randn(1, 10)
1158 self.
assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
1160 def test_trace_warn(self):
1171 for _
in torch.ones(4, 4):
1175 with warnings.catch_warnings(record=
True)
as warns:
1177 warns = [str(w.message)
for w
in warns]
1179 self.assertIn(
'a Python integer', warns[0])
1180 self.assertIn(
'a Python boolean', warns[1])
1181 self.assertIn(
'a Python index', warns[2])
1182 self.assertIn(
'a Python float', warns[3])
1183 self.assertIn(
'a Python list', warns[4])
1184 self.assertIn(
'a NumPy array', warns[5])
1185 self.assertIn(
'Iterating over', warns[6])
1187 def test_trace_tuple(self):
1189 return x, (x * y[1], x * y[0])
1191 x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
1195 FileCheck().check_count(
"prim::TupleConstruct", 2, exactly=
True).check_next(
"return") \
1196 .run(str(traced_fn.graph))
1199 def test_trace_random(self):
1201 return torch.normal(mean, std)
1203 traced =
torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=
False)
1204 mean, std = torch.zeros(5, 5), torch.ones(5, 5)
1206 output = f(mean, std)
1207 traced_output = traced(mean, std)
1210 def test_trace_tensor_factory(self):
1212 inputs_require_grads = kwargs.pop(
'inputs_require_grads',
True)
1215 return x + torch.ones(2, 3, **kwargs)
1217 input_kwargs = kwargs.copy()
1218 if 'out' in input_kwargs:
1219 del input_kwargs[
'out']
1220 input = torch.ones(2, 3, **input_kwargs)
1221 self.
checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
1224 self.assertTrue(
"ones" in str(tfn.graph))
1226 run(dtype=torch.int, inputs_require_grads=
False)
1229 run(device=
"cuda:0")
1230 if RUN_CUDA_MULTI_GPU:
1231 run(device=
"cuda:1")
1233 def test_trace_indexed_assignment(self):
1238 example = torch.rand(3, 4)
1239 self.
checkTrace(stuff, (example, example[0] + 1))
1242 @unittest.expectedFailure
1244 """Check that outputs of traced functions retain the original structure and nesting""" 1246 return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
1251 @unittest.expectedFailure
1253 """Check that inputs to traced functions are flattened""" 1259 inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
1263 @unittest.skip(
"Need to instrument GraphExecutors a bit more")
1264 def test_flags(self):
1265 x, y = torch.randn(2, 2)
1266 y = Variable(torch.randn(2, 2))
1270 return (x * x + y * y + x * y).sum()
1273 for rx, ry
in product((
True,
False), repeat=2):
1274 x.requires_grad = rx
1275 y.requires_grad = ry
1277 self.assertFalse(fn.has_trace_for(x, y))
1280 self.assertFalse(fn.has_trace_for(x, y))
1281 for v, name, compute
in [(x,
'x', rx), (y,
'y', ry)]:
1285 expected_grad = grads.setdefault(name, grad_v)
1287 self.
assertEqual(fn.has_trace_for(x, y), rx
or ry)
1289 def test_python_ir(self):
1294 return torch.sigmoid(torch.tanh(x * (x + y)))
1298 self.
run_pass(
'canonicalize', trace)
1300 g2 = torch._C.Graph()
1302 for node
in g.inputs():
1303 g_to_g2[node] = g2.addInput()
1304 for node
in g.nodes():
1305 n_ = g2.createClone(node,
lambda x: g_to_g2[x])
1307 for o, no
in zip(node.outputs(), n_.outputs()):
1310 for node
in g.outputs():
1311 g2.registerOutput(g_to_g2[node])
1313 t_node = g2.create(
"prim::TensorTest").t_(
"a", torch.ones([2, 2]))
1315 g2.appendNode(t_node)
1316 self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t(
"a")))
1317 for node
in g.nodes():
1318 self.assertTrue(g2.findNode(node.kind())
is not None)
1320 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
1321 @unittest.skipIf(
not RUN_CUDA,
"cpp tests require CUDA")
1323 def test_cpp_cuda(self):
1324 from cpp.jit import tests_setup
1326 torch._C._jit_run_cpp_tests()
1327 tests_setup.shutdown()
1329 def test_batchnorm(self):
1330 x = torch.ones(2, 2, 2, 2)
1332 _force_outplace=
True, return_inputs=
True)
1336 def test_dropout(self):
1337 x = torch.ones(2, 2)
1344 @unittest.skipIf(
not RUN_CUDA,
"test_dropout_cuda require CUDA")
1345 def test_dropout_cuda(self):
1348 x = torch.ones(4, 4).cuda().requires_grad_()
1354 with freeze_rng_state():
1358 with freeze_rng_state():
1365 def test_conv(self):
1366 x = torch.ones(20, 16, 50, 40)
1371 def test_repeated_input(self):
1375 ge = self.
checkTrace(fn, [torch.randn(2, 2)] * 2)
1376 inputs = set(ge.graph.inputs())
1377 self.assertTrue(len(inputs) == 2)
1379 def test_repeated_output(self):
1384 ge = self.
checkTrace(fn, [torch.randn(2, 2)
for _
in range(2)])
1385 tuple_output = list(ge.graph.outputs())[0]
1386 tuple_inputs = list(tuple_output.node().inputs())
1387 self.assertTrue(tuple_inputs[0] == tuple_inputs[1])
1389 @skipIfNoTorchVision
1390 def test_alexnet(self):
1391 x = torch.ones(1, 3, 224, 224)
1392 model = torchvision.models.AlexNet()
1400 def test_inplace_copy(self):
1401 x = torch.randn(4, 4, requires_grad=
True)
1404 out = Variable(torch.zeros(x.size()))
1414 def test_shared_param(self):
1415 class MyModule(torch.nn.Module):
1417 super(MyModule, self).__init__()
1418 self.
b = self.
a = nn.Parameter(torch.randn(2, 2))
1420 def forward(self, x):
1421 return x * self.
a + self.
b 1426 self.
assertEqual(len(list(trace.graph().inputs())), 2)
1427 FileCheck().check(
"mul").check(
"add").run(str(trace))
1429 def test_trace_c10_ops(self):
1430 class MyModel(torch.nn.Module):
1432 super(MyModel, self).__init__()
1434 def forward(self, scores, bbox_deltas, im_info, anchors):
1435 a, b = torch.ops._caffe2.GenerateProposals(
1436 (scores), (bbox_deltas), (im_info), (anchors),
1437 2.0, 6000, 300, 0.7, 16,
True, -90, 90, 1.0,
1445 scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
1446 bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
1447 dtype=torch.float32)
1448 bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
1449 im_info = torch.ones(img_count, 3, dtype=torch.float32)
1450 anchors = torch.ones(A, 4, dtype=torch.float32)
1451 inputs = (scores, bbox_deltas, im_info, anchors)
1456 def test_nested_inplace(self):
1457 x = torch.randn(2, 2)
1459 lambda x: F.threshold(x, 0, 0, inplace=
True), (x, ), return_inputs=
True)
1462 FileCheck().check(
"threshold_").run(str(trace))
1465 def run_ge_tests(self, optimize, use_cuda):
1467 t = torch.rand(*args).float()
1472 [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
1476 b, a), [rand(1), rand(1)], optimize=optimize)
1481 self.
checkTrace(foo, [rand(1)], optimize=optimize)
1484 lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
1487 self.
checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
1490 (a - 2 * b) + b, [rand(1), rand(1)],
1493 def test_ge_unoptimized(self):
1496 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
1498 def test_ge_optimized(self):
1501 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
1502 @unittest.skipIf(
not RUN_CUDA,
"requires CUDA")
1503 def test_ge_cuda(self):
1509 return a * b / (a - b) + b
1511 a, b = V(torch.rand(1)), V(torch.rand(1))
1512 ge = torch._C.GraphExecutor(foo, (a, b),
lambda var:
'')
1513 a, b = V(torch.rand(1), requires_grad=
True), V(
1514 torch.rand(1), requires_grad=
True)
1518 l2 = (da * db + db * db)
1525 l3 = (da2 * db2 + db2 * db2)
1529 def test_trace_annotation(self):
1530 @_trace(torch.rand(1))
1534 x = torch.randn(5, 5)
1537 def test_trace_script(self):
1551 expected = func1((a, b))
1553 result = traced((a, b))
1556 expected = func2((a, b))
1558 result = traced((a, b))
1561 def test_einsum(self):
1563 return torch.einsum(
'i,j->ij', (x, y))
1567 fns = [traced, script]
1568 x, y = torch.randn(10), torch.randn(2)
1569 for fn
in [traced, script]:
1573 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
1574 @unittest.skipIf(
not RUN_CUDA,
"calls .cuda()")
1575 def test_traced_module_cuda(self):
1576 class Model(nn.Module):
1577 def __init__(self, num_features, num_layers):
1578 super(Model, self).__init__()
1580 layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
1581 for _
in range(num_layers)]
1582 self.
submodule = nn.Sequential(*chain(*layers))
1584 def forward(self, x):
1590 x = torch.randn(2, 5)
1598 linear_submodule = next(iter(traced_model.submodule._modules.values()))
1601 with self.assertRaises(AttributeError):
1602 linear_submodule.in_features
1603 linear_submodule.weight
1604 with self.assertRaises(RuntimeError):
1605 traced_model.asdf = 4
1606 linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
1607 with self.assertRaises(RuntimeError):
1608 del linear_submodule.weight
1611 with self.assertRaises(RuntimeError):
1615 linear_submodule.cuda()
1616 traced_model.float().cuda()
1617 cuda_out = traced_model(x.float().cuda())
1619 cpu_out = traced_model(x.float())
1621 traced_model.to(
'cuda')
1622 cuda_out = traced_model(x.float().cuda())
1623 traced_model.to(
'cpu')
1624 cpu_out = traced_model(x.float())
1626 traced_model.double()
1629 state = {k: v.clone()
for k, v
in traced_model.state_dict().items()}
1630 new_state = {k: v.clone().fill_(1)
for k, v
in state.items()}
1631 out = traced_model(x)
1632 traced_model.load_state_dict(new_state)
1633 out_ones = traced_model(x)
1634 traced_model.load_state_dict(state)
1635 out_state = traced_model(x)
1639 def test_export_no_reorder(self):
1641 return a * b / (a - 2 * b) + b
1643 recording_inputs = [
torch.tensor([0.55619788169860839844], dtype=torch.float32, requires_grad=
True),
1644 torch.tensor([0.25947844982147216797], dtype=torch.float32, requires_grad=
True)]
1649 outputs_ge1 = ge1(*recording_inputs)
1650 outputs_ge2 = ge2(*recording_inputs)
1654 self.assertTrue(outputs_ge1 == outputs_ge2)
1655 self.assertTrue(grad_ge1 == grad_ge2)
1657 def test_python_function(self):
1658 class MyFn(Function):
1660 def forward(ctx, x):
1664 def backward(ctx, grad_output):
1667 @_trace(torch.zeros(2))
1669 return MyFn.apply(x + 2) + 3
1672 y = torch.randn(2, 2, requires_grad=
True)
1676 def test_python_function_tup(self):
1677 class MyFn(Function):
1679 def forward(ctx, x):
1683 def backward(ctx, grad_output):
1684 return grad_output, grad_output
1686 @_trace(torch.zeros(2))
1688 a, b = MyFn.apply(x + 2)
1691 y = torch.randn(2, 2, requires_grad=
True)
1695 def test_decompose_addmm(self):
1696 def does_decompose():
1698 def addmm(mat, mat1, mat2, alpha, beta):
1699 a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
1700 b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
1704 mat = torch.randn(2, 2)
1705 mat1 = torch.randn(2, 4)
1706 mat2 = torch.randn(4, 2)
1707 alpha = torch.FloatTensor([123.0])
1708 beta = torch.FloatTensor([321.0])
1710 out_ref = addmm(mat, mat1, mat2, alpha, beta)
1711 self.
run_pass(
'canonicalize_ops', addmm.graph)
1712 out_test = addmm(mat, mat1, mat2, alpha, beta)
1714 FileCheck().check_not(
"addmm").run(str(addmm.graph))
1716 def doesnt_decompose():
1718 def addmm(mat, mat1, mat2, alpha, beta):
1719 a = mat.addmm(mat1, mat2)
1720 b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
1722 orig = str(addm.graph)
1723 self.
run_pass(
'canonicalize_ops', addmm.graph)
1724 self.assertTrue(orig == str(addmm.graph))
1726 def test_index_put(self):
1727 ten = torch.zeros(3, 3)
1728 mask = torch.Tensor([[
True,
True,
True],
1729 [
True,
False,
False],
1730 [
True,
True,
False]]).byte()
1732 def test_fn(ten, mask):
1733 ten[mask] = torch.ones(6)
1738 ten = torch.rand(3, 3)
1739 self.
assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
1741 def test_sparse_tensors_error(self):
1743 return torch.sparse.FloatTensor(2, 3)
1747 output = get_sparse()
1748 return output, input
1751 sparse(get_sparse())
1756 def test_tuple_specialization(self):
1763 t = torch.randn(2, 2), torch.randn(2, 2)
1765 graph = f.graph_for(t)
1766 input_types = list(next(graph.inputs()).type().elements())
1767 for t
in input_types:
1768 self.
assertEqual(t.kind(),
'DimensionedTensorType')
1770 def test_constant_prop_simple(self):
1772 def constant_prop(input_int):
1776 return b - input_int
1778 out_ref = constant_prop(2)
1779 self.
run_pass(
'constant_propagation', constant_prop.graph)
1780 out_test = constant_prop(2)
1782 graph_str = str(constant_prop.graph)
1783 self.assertTrue(
"aten::add" not in graph_str
and "aten::mul" not in graph_str)
1784 const = constant_prop.graph.findNode(
"prim::Constant").output().toIValue()
1787 def test_constant_prop_nested(self):
1789 def constant_prop(a):
1797 self.
run_pass(
'constant_propagation', constant_prop.graph)
1800 if_node = constant_prop.graph.findNode(
"prim::If")
1801 for block
in if_node.blocks():
1802 for node
in block.nodes():
1803 self.assertTrue(node.kind() ==
"prim::Constant")
1805 def test_constant_prop_print(self):
1807 def constant_prop(input_tensor):
1811 return b + input_tensor
1813 self.
run_pass(
'constant_propagation', constant_prop.graph)
1814 graph = constant_prop.graph
1815 print_node = graph.findNode(
"prim::Print")
1816 self.assertTrue(print_node.input().toIValue() == 6)
1818 def test_constant_prop_rand(self):
1820 def constant_prop():
1821 a = torch.randn([3])
1825 self.
run_pass(
'constant_propagation', constant_prop.graph)
1826 self.assertTrue(
"aten::randn" in str(constant_prop.graph))
1828 def test_constant_prop_none(self):
1835 def constant_prop():
1838 if (a
is None and b
is None):
1844 self.
run_pass(
'constant_propagation', constant_prop.graph)
1845 graph_str = str(constant_prop.graph)
1846 self.assertTrue(graph_str.count(
"prim::Constant") == 1)
1848 def test_constant_prop_if_inline(self):
1850 def constant_prop():
1860 self.
run_pass(
'constant_propagation', constant_prop.graph)
1862 def test_trace_records_names(self):
1865 quick_brown_fox = torch.neg(baz)
1867 yeet = quick_brown_fox - 3.14
1871 graph_str = str(traced.graph)
1872 assert 'bar' in graph_str
1873 assert 'baz' in graph_str
1874 assert 'quick_brown_fox' in graph_str
1876 def test_constant_prop_if_constant(self):
1878 def constant_prop(a, b):
1895 return a + c0 + c1 + c2
1897 graph = constant_prop.graph
1898 self.
run_pass(
'constant_propagation', graph)
1899 ifs = graph.findAllNodes(
"prim::If", recurse=
False)
1900 snd_if_inlined = len(ifs) == 1
1901 self.assertTrue(snd_if_inlined)
1903 self.assertTrue(first_if.outputsSize() == 2)
1904 second_if = first_if.findNode(
"prim::If", recurse=
False)
1905 self.assertTrue(second_if.outputsSize() == 1)
1906 self.assertTrue(second_if.findNode(
"prim::If")
is None)
1908 def test_constant_prop_loop_constant(self):
1910 def constant_prop(cond, iter):
1917 for _
in range(iter):
1925 for _i
in range(-4):
1929 self.
run_pass(
'constant_propagation', constant_prop.graph)
1930 graph = canonical(constant_prop.graph)
1931 self.assertTrue(graph.count(
"removed") == 0)
1932 self.assertTrue(graph.count(
"stays") == 1)
1933 self.assertTrue(graph.count(
"prim::Print") == 4)
1935 def test_constant_prop_remove_output(self):
1937 def constant_prop(iter):
1942 for i
in range(iter):
1950 graph = constant_prop.graph
1951 self.
run_pass(
'constant_propagation', graph)
1952 self.assertTrue(graph.findNode(
"prim::Loop").outputsSize() == 2)
1954 def test_trace_detach(self):
1956 return torch.matmul(x, w).detach()
1960 FileCheck().check(
"matmul").check(
"detach").run(str(traced.graph))
1961 x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=
True)
1962 traced_result = traced(x, w)
1964 self.assertFalse(traced_result.requires_grad)
1965 self.assertIsNone(traced_result.grad_fn)
1967 def test_trace_detach_inplace(self):
1969 y = torch.matmul(x, w)
1975 FileCheck().check(
"matmul").check(
"detach(").run(str(traced.graph))
1976 x, w = torch.rand(3, 4), torch.rand(4, 5)
1977 traced_result = traced(x, w)
1979 self.assertFalse(traced_result.requires_grad)
1980 self.assertIsNone(traced_result.grad_fn)
1982 def test_trace_detach_onnx_erase(self):
1983 class Mod(torch.nn.Module):
1984 def forward(self, x, w):
1985 return torch.matmul(x, w).detach()
1989 Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
1991 def test_trace_slice_full_dim(self):
1993 return x[0:5, 0] + 1.0
1996 test_x = torch.rand(6, 3)
1999 def test_export_dropout(self):
2000 test = torch.nn.Dropout()
2003 traced =
torch.jit.trace(test, (torch.rand(3, 4),), check_trace=
False)
2005 x = torch.randn(3, 4)
2008 def test_onnx_transpose_incomplete_tensor_type(self):
2015 @torch.jit.script_method
2016 def forward(self, x):
2017 return x.contiguous().transpose(0, 1).sum()
2019 class TraceMe(torch.nn.Module):
2021 super(TraceMe, self).__init__()
2024 def forward(self, x):
2029 example_outputs = (tm(torch.rand(3, 4)),)
2033 @unittest.skipIf(
not RUN_CUDA,
"requires CUDA")
2034 def test_cuda_export_restore(self):
2037 super(Sub, self).__init__()
2038 self.
weight = nn.Parameter(torch.randn(3, 4))
2040 @torch.jit.script_method
2041 def forward(self, thing):
2042 return self.
weight + thing
2046 super(M, self).__init__()
2049 @torch.jit.script_method
2050 def forward(self, v):
2056 input = torch.rand(3, 4).cuda()
2059 def test_export_batchnorm(self):
2060 for mode
in [
'eval',
'train']:
2062 torch.nn.BatchNorm1d(100),
2063 torch.nn.BatchNorm1d(100, affine=
False),
2064 torch.nn.BatchNorm2d(100),
2065 torch.nn.BatchNorm2d(100, affine=
False)]:
2066 getattr(clazz, mode)()
2067 input = torch.randn(20, 100)
if isinstance(clazz, torch.nn.BatchNorm1d)
else \
2068 torch.randn(20, 100, 35, 45)
2071 x = torch.randn(20, 100)
if isinstance(clazz, torch.nn.BatchNorm1d)
else \
2072 torch.randn(20, 100, 35, 45)
2075 def test_export_rnn(self):
2076 for clazz
in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
2077 class RNNTest(torch.nn.Module):
2079 super(RNNTest, self).__init__()
2082 def forward(self, x, lengths, h0):
2084 out, h = self.
rnn(packed, h0)
2090 traced =
torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
2095 x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
2096 self.
assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
2098 def test_export_lstm(self):
2099 class LSTMTest(torch.nn.Module):
2101 super(LSTMTest, self).__init__()
2102 self.
rnn = nn.LSTM(10, 20, 2)
2104 def forward(self, x, lengths, hiddens):
2107 out, (h, c) = self.
rnn(packed, (h0, c0))
2114 torch.LongTensor([3, 2, 1]),
2115 (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
2117 x, lengths, h0, c0 = \
2118 torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
2119 self.
assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
2121 def test_trace_dict_input(self):
2122 class Bar(torch.nn.Module):
2124 super(Bar, self).__init__()
2127 def forward(self, a, b):
2128 return self.
foo({
'a': a,
'b': b})[
'a']
2130 class Foo(torch.nn.Module):
2131 def forward(self, x):
2132 return {
'a': x[
'a'] * x[
'b']}
2134 x = (torch.rand(3), torch.rand(3))
2138 def test_trace_variable_instantiation(self):
2140 return Variable(Variable(x) + 1.0)
2144 x = torch.rand(5, 6)
2145 self.
assertEqual(random_foo(x), random_foo_traced(x))
2147 def test_trace_slice_expr_complete_type(self):
2155 return random_foo_traced(x)[0:1]
2157 x = torch.rand(3, 4)
2160 def test_export_tensoroption_to(self):
2162 return x.new_tensor(x[0]).cpu() + x
2165 example_outputs = traced(torch.rand([2]))
2169 example_outputs=example_outputs))
2171 def test_pretty_printer(self):
2191 def while_test(a, i):
2198 def while_if_test(a, b):
2210 def loop_use_test(y):
2222 def python_op_name_test(y):
2226 def empty_int_list_test(y):
2231 def empty_float_list_test(y):
2232 return [1.0, 2.0, 3.0]
2235 def print_weird_test(y):
2240 self.
assertExpected(while_test.graph.pretty_print(),
"while_test")
2241 self.
assertExpected(while_if_test.graph.pretty_print(),
"while_if_test")
2242 self.
assertExpected(loop_use_test.graph.pretty_print(),
"loop_use_test")
2243 self.
assertExpected(python_op_name_test.graph.pretty_print(),
"python_op_name_test")
2244 self.
assertExpected(empty_int_list_test.graph.pretty_print(),
"empty_int_list_test")
2245 self.
assertExpected(empty_float_list_test.graph.pretty_print(),
"empty_float_list_test")
2246 self.
assertExpected(print_weird_test.graph.pretty_print(),
"print_weird_test")
2248 def test_cu_escaped_number(self):
2255 def test_import_method(self):
2260 r, _ = foo._python_print()
2262 torch._C._jit_import_methods(mod,
"op_version_set = 0\n{}".format(r), [])
2265 def test_function_default_values(self):
2272 def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
2273 return x + a + b + c
2276 simple_fn(torch.ones(1)),
2277 torch.ones(1) + 0.5 + 10 + (20 + 30))
2280 torch.ones(1) + 1 + 3 + 4)
2286 def bool_fn(x, a=outer_c, flag=outer_flag):
2293 self.
assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
2299 def none_fn(x=None):
2307 def hints(x, a=0.5, b=10):
2311 self.
assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
2316 def hints_bad_types(x, a=10, b=0.5):
2320 def test_module_default_values(self):
2325 super(Test, self).__init__()
2327 @torch.jit.script_method
2328 def forward(self, input, other=four):
2329 return input + other
2332 self.
assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
2334 def test_warnings(self):
2340 warnings.warn(
"x is less than 2")
2343 FileCheck().check(
"aten::warn").run(str(fn.graph))
2345 def test_no_erroneous_warnings(self):
2350 warnings.warn(
'This should NOT be printed')
2354 with warnings.catch_warnings(record=
True)
as warns:
2357 warns = [str(w.message)
for w
in warns]
2360 @unittest.skipIf(sys.platform ==
"win32",
"TODO: need to fix this test case for Windows")
2361 def test_torch_load_error(self):
2364 super(J, self).__init__()
2366 @torch.jit.script_method
2367 def forward(self, input):
2371 with tempfile.NamedTemporaryFile()
as f:
2376 def test_legacy_constructors(self):
2378 return x.new_zeros(5, 5, requires_grad=
False)
2380 with warnings.catch_warnings(record=
True)
as warns:
2382 warns = [str(w.message)
for w
in warns]
2384 self.
assertEqual(warns[0],
"new_zeros is a legacy constructor and is not supported in the JIT.")
2386 def test_python_bindings(self):
2389 def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
2390 for i
in range(x.size(0)):
2391 hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
2396 inputs = get_lstm_inputs(
'cpu', training=
True, seq_length=10)
2397 slstm(*inputs).sum().backward()
2399 fw_graph = slstm.graph_for(*inputs)
2400 nodes = [n
for n
in fw_graph.nodes()]
2401 tested_blocks =
False 2403 for output
in [o
for o
in node.outputs()]:
2404 self.assertTrue(hasattr(output,
'type'))
2405 self.assertTrue(output.type()
is not None)
2406 for input
in [i
for i
in node.inputs()]:
2407 self.assertTrue(hasattr(input,
'type'))
2408 self.assertTrue(input.type()
is not None)
2409 for block
in [b
for b
in node.blocks()]:
2410 tested_blocks =
True 2411 self.assertTrue(hasattr(block,
'inputs'))
2412 self.assertTrue(hasattr(block,
'outputs'))
2413 for output
in [o
for o
in block.outputs()]:
2414 self.assertTrue(hasattr(output,
'type'))
2415 self.assertTrue(output.type()
is not None)
2416 for input
in [i
for i
in block.inputs()]:
2417 self.assertTrue(hasattr(input,
'type'))
2418 self.assertTrue(input.type()
is not None)
2419 self.assertTrue(hasattr(block,
'returnNode'))
2420 self.assertTrue(type(block.returnNode()) == torch._C.Node)
2421 self.assertTrue(hasattr(block,
'paramNode'))
2422 self.assertTrue(type(block.paramNode()) == torch._C.Node)
2423 self.assertTrue(tested_blocks)
2428 def rand_batch(self, *dims):
2429 dims = [dim
for dim
in dims
if dim != ()]
2430 xs = [torch.rand(1, *(random.randint(1, size)
if b
else size
for b, size
in dims[1:]),
2431 requires_grad=
True)
for i
in range(dims[0])]
2432 xb = BatchTensor(xs,
torch.tensor([b
for b, d
in dims[1:]]).byte())
2435 def test_create_batchtensor(self):
2437 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2), (
True, 5))
2440 batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
2443 xs = torch.rand(3, 4, 5)
2444 batch3 = BatchTensor(xs, 2)
2445 xs = xs.unsqueeze(0)
2448 def test_batch_elementwise_unary(self):
2451 return torch.tanh(a)
2453 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2454 res_batch = tanh(batch)
2455 res = [torch.tanh(xs[j])
for j
in range(4)]
2458 def test_batch_elementwise_binary(self):
2463 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2464 xs2, batch2 = xs, batch
2465 res_batch = add(batch, batch2)
2466 res = [torch.add(xs[j], xs2[j])
for j
in range(4)]
2470 xs, batch = self.
rand_batch(4, (
False, 3), (
False, 2))
2471 b = torch.rand(3, 2)
2472 res_batch = add(batch, b)
2473 res = [torch.add(xs[j], b)
for j
in range(4)]
2476 def test_batch_mm(self):
2479 return torch.mm(a, b)
2481 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2482 xs2, batch2 = self.
rand_batch(4, (
False, 2), (
True, 3))
2483 res_batch = mm(batch, batch2)
2484 res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0)
for j
in range(4)]
2488 b = torch.rand(2, 4)
2489 res_batch = mm(batch, b)
2490 res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0)
for j
in range(4)]
2493 def test_batch_matmul(self):
2496 return torch.matmul(a, b)
2498 def matmul_test(xs, batch, xs2, batch2):
2499 ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0)
for j
in range(4)]
2500 ybs = matmul(batch, batch2)
2506 matmul_test(xs, batch, xs2, batch2)
2509 xs2, batch2 = self.
rand_batch(4, (
False, 2), (
True, 3))
2510 matmul_test(xs, batch, xs2, batch2)
2512 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2514 matmul_test(xs, batch, xs2, batch2)
2516 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2517 xs2, batch2 = self.
rand_batch(4, (
False, 2), (
True, 3))
2518 matmul_test(xs, batch, xs2, batch2)
2520 def test_batch_select(self):
2523 return torch.select(x, 1, 0)
2525 xs, batch = self.
rand_batch(4, (
True, 3), (
True, 2))
2526 res_batch = select(batch)
2527 res = [torch.select(xs[j], 1, 0)
for j
in range(4)]
2530 xs, batch = self.
rand_batch(4, (
False, 3), (
True, 2))
2531 res_batch = select(batch)
2532 res = [torch.select(xs[j], 1, 0)
for j
in range(4)]
2535 def test_batch_index_select(self):
2537 def index_select(x, ind):
2538 return x.index_select(1, ind)
2540 xs, batch = self.
rand_batch(4, (
False, 5), (
True, 2))
2541 ind = [torch.randint(0, 4, (1,), dtype=torch.long)
for i
in range(4)]
2543 res_batch = index_select(batch, ind_batch)
2544 res = [torch.index_select(xs[j], 1, ind[j])
for j
in range(4)]
2547 def test_batch_where(self):
2550 return torch.where(c, a, b)
2552 xs, batch = self.
rand_batch(4, (
False, 3), (
False, 2))
2553 xs2, batch2 = self.
rand_batch(4, (
False, 3), (
False, 2))
2555 dims = [4, (
False, 3), (
False, 2)]
2556 xs_cond = [torch.rand(1, 3, 2).byte()
for i
in range(dims[0])]
2557 batch_cond = BatchTensor(xs_cond,
torch.tensor([b
for b, d
in dims[1:]]))
2559 res_batch = where(batch_cond, batch, batch2)
2560 res = [torch.where(xs_cond[j], xs[j], xs2[j])
for j
in range(4)]
2563 def test_batch_argmax(self):
2566 return torch.argmax(a, 1)
2568 xs, batch = self.
rand_batch(4, (
True, 5), (
True, 6))
2569 res_batch = argmax(batch)
2570 res = [torch.argmax(xs[j], 1)
for j
in range(4)]
2575 return torch.argmax(a, 1,
False)
2577 res_batch = argmax(batch)
2578 res = [torch.argmax(xs[j], 1,
False)
for j
in range(4)]
2581 def test_batch_topk(self):
2584 return torch.topk(a, 3, 1)
2586 xs, batch = self.
rand_batch(4, (
False, 5), (
True, 6))
2589 res_batch = topk(batch)
2590 res = [torch.topk(xs[j], 3, 1)[0]
for j
in range(4)]
2591 res_idx = [torch.topk(xs[j], 3, 1)[1]
for j
in range(4)]
2593 self.
assertEqual(res_idx, res_batch[1].examples())
2597 return torch.topk(a, 1, 2)
2600 res_batch = topk(batch)
2601 res = [torch.topk(xs[j], 1, 2)[0]
for j
in range(4)]
2602 res_idx = [torch.topk(xs[j], 1, 2)[1]
for j
in range(4)]
2604 self.
assertEqual(res_idx, res_batch[1].examples())
2606 def test_batch_softmax(self):
2609 return torch.softmax(a, 1)
2611 xs, batch = self.
rand_batch(4, (
False, 5), (
True, 6))
2614 res_batch = softmax(batch)
2615 res = [torch.softmax(xs[j], 1)
for j
in range(4)]
2620 return torch.softmax(a, 2)
2623 res_batch = softmax(batch)
2624 res = [torch.softmax(xs[j], 2)
for j
in range(4)]
2627 def test_batch_view(self):
2630 return a.view([4, -1, 3])
2632 xs, batch = self.
rand_batch(4, (
True, 5), (
False, 3))
2633 res_batch = view(batch)
2634 res = [xs[j].view([1, -1, 3])
for j
in range(4)]
2637 def test_batch_cat(self):
2640 return torch.cat([a, b], 2)
2642 xs, batch = self.
rand_batch(4, (
True, 5), (
False, 3))
2643 xs2, batch2 = xs, batch
2644 res_batch = cat2(batch, batch2)
2645 res = [torch.cat([xs[j], xs2[j]], 2)
for j
in range(4)]
2648 def test_batch_sum(self):
2653 xs, batch = self.
rand_batch(4, (
True, 5), (
False, 3))
2654 res_batch = batch_sum(batch)
2655 res = [xs[j].sum().unsqueeze(0)
for j
in range(4)]
2658 def test_if_else(self):
2659 def single_if(a, b):
2670 res_batch = batch_if(batch_a, batch_b)
2671 res = [single_if(a[j], b[j])
for j
in range(4)]
2675 torch.to_batch_graph(script_if.graph)
2677 def test_if_else_with_scalar(self):
2678 def single_if(a, b):
2689 res_batch = batch_if(batch_a, batch_b)
2690 res = [single_if(a[j], b[j])
for j
in range(4)]
2694 torch.to_batch_graph(script_if.graph)
2696 def test_if_noelse(self):
2697 def single_if(a, b):
2706 res_batch = batch_if(batch_a, batch_b)
2707 res = [single_if(a[j], b[j])
for j
in range(4)]
2711 torch.to_batch_graph(script_if.graph)
2713 def test_if_noelse_with_scalar(self):
2714 def single_if(a, b):
2723 res_batch = batch_if(batch_a, batch_b)
2724 res = [single_if(a[j], b[j])
for j
in range(4)]
2728 torch.to_batch_graph(script_if.graph)
2730 def test_while(self):
2731 def single_while(a, b):
2739 b = [torch.abs(torch.rand(1))
for i
in range(4)]
2741 res_batch = batch_while(batch_a, batch_b)
2742 res = [single_while(a[j], b[j])
for j
in range(4)]
2746 torch.to_batch_graph(script_while.graph)
2749 def single_for(x, y):
2758 res_batch = batch_for(batch_a, batch_b)
2759 res = [single_for(a[j], b[j])
for j
in range(4)]
2763 torch.to_batch_graph(script_for.graph)
2765 def test_lstm(self):
2766 def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
2767 for i
in range(x_all.size(1)):
2768 x = x_all.select(1, i)
2769 i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2770 f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2771 o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2773 i_t = torch.sigmoid(i_t)
2774 f_t = torch.sigmoid(f_t)
2775 o_t = torch.sigmoid(o_t)
2777 c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2778 c_t = torch.tanh(c_t)
2779 c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2780 h_t = torch.mul(o_t, torch.tanh(c_t))
2787 batch_size, input_size, hidden_size = 4, 3, 2
2788 xs, batch = self.
rand_batch(batch_size, (
True, 4), (
False, input_size))
2789 hx, h_batch = self.
rand_batch(batch_size, (
False, hidden_size))
2790 cx, c_batch = self.
rand_batch(batch_size, (
False, hidden_size))
2793 w_xi = torch.rand(input_size, hidden_size)
2794 w_xf = torch.rand(input_size, hidden_size)
2795 w_xo = torch.rand(input_size, hidden_size)
2796 w_xc = torch.rand(input_size, hidden_size)
2798 w_hi = torch.rand(hidden_size, hidden_size)
2799 w_hf = torch.rand(hidden_size, hidden_size)
2800 w_ho = torch.rand(hidden_size, hidden_size)
2801 w_hc = torch.rand(hidden_size, hidden_size)
2803 b_i = torch.rand(hidden_size)
2804 b_f = torch.rand(hidden_size)
2805 b_o = torch.rand(hidden_size)
2806 b_c = torch.rand(hidden_size)
2808 ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
2809 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
for j
in range(batch_size)]
2810 ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
2811 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
2814 def test_greedy_search(self):
2815 def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2816 b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
2817 iter_count = torch.zeros_like(iter_num)
2818 while bool(iter_count < iter_num):
2819 iter_count = iter_count + 1
2821 i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2822 f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2823 o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2825 i_t = torch.sigmoid(i_t)
2826 f_t = torch.sigmoid(f_t)
2827 o_t = torch.sigmoid(o_t)
2829 c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2830 c_t = torch.tanh(c_t)
2831 c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2832 h_t = torch.mul(o_t, torch.tanh(c_t))
2836 s_t = torch.matmul(h_t, w_hs) + b_s
2837 p_t = torch.softmax(s_t, 1)
2838 i_t = torch.argmax(p_t, 1)
2839 x = embed.index_select(1, i_t).squeeze(1)
2844 batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
2845 xs, batch = self.
rand_batch(batch_size, (
False, input_size))
2846 hx, h_batch = self.
rand_batch(batch_size, (
False, hidden_size))
2847 cx, c_batch = self.
rand_batch(batch_size, (
False, hidden_size))
2848 embed, embed_batch = self.
rand_batch(batch_size, (
False, vocab_size), (
False, input_size))
2849 iter_num = [torch.randint(2, 5, (1,))
for i
in range(batch_size)]
2850 iter_num_batch = BatchTensor(iter_num,
torch.tensor([]).byte())
2853 w_xi = torch.rand(input_size, hidden_size)
2854 w_xf = torch.rand(input_size, hidden_size)
2855 w_xo = torch.rand(input_size, hidden_size)
2856 w_xc = torch.rand(input_size, hidden_size)
2858 w_hi = torch.rand(hidden_size, hidden_size)
2859 w_hf = torch.rand(hidden_size, hidden_size)
2860 w_ho = torch.rand(hidden_size, hidden_size)
2861 w_hc = torch.rand(hidden_size, hidden_size)
2863 b_i = torch.rand(hidden_size)
2864 b_f = torch.rand(hidden_size)
2865 b_o = torch.rand(hidden_size)
2866 b_c = torch.rand(hidden_size)
2868 w_hs = torch.rand(hidden_size, vocab_size)
2869 b_s = torch.rand(vocab_size)
2871 ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc,
2872 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j])
for j
in range(batch_size)]
2873 ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
2874 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
2877 def test_beam_search(self):
2878 def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2879 b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
2881 vocab_size = embed.size(1)
2882 iter_count = torch.zeros_like(iter_num)
2883 max_len = idx.size(2)
2884 while bool(iter_count < iter_num):
2885 iter_count = iter_count + 1
2887 i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2888 f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2889 o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2891 i_t = torch.sigmoid(i_t)
2892 f_t = torch.sigmoid(f_t)
2893 o_t = torch.sigmoid(o_t)
2895 c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2896 c_t = torch.tanh(c_t)
2897 c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2898 h_t = torch.mul(o_t, torch.tanh(c_t))
2902 s_t = torch.matmul(h_t, w_hs) + b_s
2903 s_t = s_t.view([1, s_t.size(1) * s_t.size(2)])
2904 p_t = torch.softmax(s_t, 1)
2905 prob_t, idx_t = torch.topk(p_t, k, 1)
2906 if(int(idx_t.dim()) > 1):
2907 idx_t_tmp = idx_t.squeeze(0)
2910 new_y = torch.fmod(idx_t_tmp, vocab_size)
2911 pre_y = idx_t_tmp / vocab_size
2912 x = embed.index_select(1, new_y)
2913 h = h_t.index_select(1, pre_y)
2914 c = c_t.index_select(1, pre_y)
2915 iter = int(iter_count[0])
2916 idx = torch.cat([idx.narrow(2, 0, iter).index_select(1, pre_y),
2917 torch.fmod(idx_t, vocab_size).unsqueeze(-1),
2918 idx.narrow(2, iter, max_len - iter)], 2)
2919 idx = idx.narrow(2, 0, max_len)
2925 batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
2927 xs, batch = self.
rand_batch(batch_size, (
False, 1), (
False, input_size))
2928 hx, h_batch = self.
rand_batch(batch_size, (
False, 1), (
False, hidden_size))
2929 cx, c_batch = self.
rand_batch(batch_size, (
False, 1), (
False, hidden_size))
2930 embed, embed_batch = self.
rand_batch(batch_size, (
False, vocab_size), (
False, input_size))
2931 iter_num = [torch.randint(2, max_len + 1, (1,))
for i
in range(batch_size)]
2932 iter_num_batch = BatchTensor(iter_num,
torch.tensor([]).byte())
2935 w_xi = torch.rand(input_size, hidden_size)
2936 w_xf = torch.rand(input_size, hidden_size)
2937 w_xo = torch.rand(input_size, hidden_size)
2938 w_xc = torch.rand(input_size, hidden_size)
2940 w_hi = torch.rand(hidden_size, hidden_size)
2941 w_hf = torch.rand(hidden_size, hidden_size)
2942 w_ho = torch.rand(hidden_size, hidden_size)
2943 w_hc = torch.rand(hidden_size, hidden_size)
2945 b_i = torch.rand(1, hidden_size)
2946 b_f = torch.rand(1, hidden_size)
2947 b_o = torch.rand(1, hidden_size)
2948 b_c = torch.rand(1, hidden_size)
2950 w_hs = torch.rand(hidden_size, vocab_size)
2951 b_s = torch.rand(1, vocab_size)
2954 torch.zeros([batch_size, 1, max_len]).byte(),
2956 idx = [torch.zeros([1, k, max_len], dtype=torch.long)
for _
in range(batch_size)]
2958 ys = [beam(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2959 b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j], idx[j]).narrow(2, 0, int(iter_num[j]))
2960 for j
in range(batch_size)]
2961 ybs = beam_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
2962 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch)
2966 def execWrapper(code, glob, loc):
2968 exec(code)
in glob, loc
2970 exec(code, glob, loc)
2975 def capture_stdout(self):
2984 stdout_fd = os.dup(1)
2991 captured_stdout = [
'']
2992 yield captured_stdout
2996 fcntl.fcntl(r, fcntl.F_SETFL, os.O_NONBLOCK)
3000 total_stdout += os.read(r, 1000).decode(
'ascii')
3001 except OSError
as e:
3002 if e.errno != errno.EAGAIN:
3005 captured_stdout[0] = total_stdout
3015 optimize=
True, outputs=
None, capture_output=
False):
3017 Checks that a given function will throw the correct exception, 3018 when executed with normal python, the string frontend, and the AST frontend 3025 source = textwrap.dedent(inspect.getsource(script))
3027 ge = getattr(cu, script.__name__)
3034 def test_training_param(self):
3036 @torch.jit.script_method
3037 def forward(self, x):
3053 def test_jitter_bug(self):
3055 def fn2(input, kernel_size):
3057 if kernel_size[0] > 1:
3060 _stride = kernel_size
3061 print(_stride, kernel_size)
3067 return fn2(input, [1])
3069 def test_parser_kwargonly(self):
3071 def foo(x, *, y) -> Tuple[Tensor, Tensor]: 3076 self.assertTrue(
'*' in cu.module._get_method(
'foo').pretty_print_schema())
3079 def foo(x, *, y) -> Tuple[Tensor, Tensor]: 3085 def test_annoying_doubles(self):
3086 mod = types.ModuleType(
"temp")
3087 mod.inf = float(
"inf")
3088 mod.ninf = float(
"-inf")
3089 mod.nan = float(
"nan")
3094 return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan
3096 pp, table = foo._get_method(
'forward').python_print()
3097 ppv =
"op_version_set = 0\n{}".format(pp)
3099 torch._C._jit_import_methods(sm, ppv, table)
3103 self.assertTrue(r[:-1] == r2[:-1])
3104 self.assertTrue(math.isnan(r[-1])
and math.isnan(r2[-1]))
3106 def test_type_annotate(self):
3126 def annotate_none():
3129 def annotate_none_no_optional():
3135 def test_robust_op_resolution(self):
3141 a = (torch.rand(3),)
3144 def test_tuple_io(self):
3150 a = (torch.rand(3), torch.rand(3))
3153 def test_tuple_create_return(self):
3156 a = (torch.ones(x), torch.zeros(x))
3160 def test_list_io(self):
3163 return torch.ones(x), x
3167 def get_sum_list_fn(self):
3178 def test_sum_list_diff_elms(self):
3181 def test_sum_list_empty(self):
3184 def test_sum_list_one(self):
3187 def test_sum_list_literal(self):
3192 for i
in [1, 2, 3, 4, 5]:
3199 def test_sum_list_wrong_type(self):
3213 def test_bool_list_io(self):
3217 return x, [
True,
False], [[
True]]
3219 li_1, li_2, li_3 = stuff4([
True])
3221 for li
in [li_1, li_2, li_3]:
3222 self.assertTrue(type(li[0]) == type(
True))
3224 def test_nested_list(self):
3231 def test_nested_list_construct(self):
3233 return [[4]] + [[4, 5]]
3236 def test_tensor_shape(self):
3237 x = torch.empty(34, 56, 78)
3244 def test_tensor_grad(self):
3249 return x.requires_grad
3254 def test_tensor_dtype(self):
3255 x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
3256 x_long = torch.empty(34, 56, 78, dtype=torch.long)
3257 x_float32 = torch.empty(34, 56, 78, dtype=torch.float32)
3261 return x.dtype == torch.uint8
3265 return x.dtype == torch.long
3269 return x.dtype == torch.float32
3271 self.assertTrue(byte(x_byte))
3272 self.assertFalse(byte(x_long))
3273 self.assertFalse(byte(x_float32))
3274 self.assertFalse(long(x_byte))
3275 self.assertTrue(long(x_long))
3276 self.assertFalse(long(x_float32))
3277 self.assertFalse(float32(x_byte))
3278 self.assertFalse(float32(x_long))
3279 self.assertTrue(float32(x_float32))
3281 @unittest.skipIf(
not RUN_CUDA,
"device tests require CUDA")
3282 def test_tensor_device(self):
3283 cpu = torch.empty(34, 56, 78, device=
'cpu')
3284 gpu = torch.empty(34, 56, 78, device=
'cuda')
3287 def same_device(x, y):
3288 return x.device == y.device
3290 self.assertTrue(same_device(cpu, cpu))
3291 self.assertTrue(same_device(gpu, gpu))
3292 self.assertFalse(same_device(cpu, gpu))
3294 @unittest.skipIf(
not RUN_CUDA,
"device tests require CUDA")
3295 def test_tensor_to_device(self):
3297 return x.to(device=
"cuda").to(device=torch.device(
"cpu"))
3301 def test_tensor_to_cpu(self):
3305 x = torch.ones(3, 4)
3307 self.
assertEqual(to_cpu(x).device, script_fn(x).device)
3310 @unittest.skipIf(
not RUN_CUDA,
"device tests require CUDA")
3311 def test_tensor_to_cuda(self):
3315 x = torch.ones(3, 4)
3317 self.
assertEqual(to_cuda(x).device, script_fn(x).device)
3320 def test_generic_list_errors(self):
3324 return [[x]] + [[1]]
3326 def test_script_cu(self):
3332 a = Variable(torch.rand(1))
3337 def test_string_cu(self):
3340 print(a, """a\\n\tb\\n""", 2, "a\ 3344 FileCheck().check(
"aa").check(
"a\\n\\tb\\n").run(str(cu.foo.graph))
3346 def test_string_ops(self):
3349 return a + a,
"ab" ==
"b",
"ab" !=
"b",
"ab" ==
"ab",
"ab" !=
"ab" 3353 def test_string_new_line(self):
3362 def test_string_single_escape(self):
3370 def test_script_annotation(self):
3374 s = Variable(torch.rand(2))
3380 return a < float(
'inf')
3382 self.assertTrue(foo(s))
3386 return a > float(
'-inf')
3388 self.assertTrue(foo(s))
3396 a = torch.rand(1, requires_grad=
True)
3397 b = torch.rand(1, requires_grad=
True)
3404 a = torch.rand(1, requires_grad=
True)
3405 b = torch.rand(1, requires_grad=
True)
3408 @unittest.skipIf(
not PY35,
"Python 3.5 needed")
3409 def test_matmul_py3(self):
3415 with tempfile.TemporaryDirectory()
as tmp_dir:
3416 script_path = os.path.join(tmp_dir,
'script.py')
3417 with open(script_path,
'w')
as f:
3419 fn = get_fn(
'test_matmul_py3', script_path)
3421 a = torch.rand(4, 3, requires_grad=
True)
3422 b = torch.rand(3, 2, requires_grad=
True)
3429 def func2(a, b, c, d):
3430 return c + a ** b ** d
3432 a = torch.rand(1, requires_grad=
True)
3433 b = torch.rand(1, requires_grad=
True)
3434 c = torch.rand(1, requires_grad=
True)
3435 d = torch.rand(1, requires_grad=
True)
3437 self.
checkScript(func2, (a, b, c, d), optimize=
True)
3439 def test_triple(self):
3443 x = torch.rand(1, dtype=torch.float, requires_grad=
True)
3446 def test_slice(self):
3450 x = torch.rand(10, dtype=torch.float, requires_grad=
True)
3458 def test_gather(self):
3462 x = torch.rand(10, dtype=torch.float, requires_grad=
True)
3465 def test_random(self):
3468 return torch.normal(mean, std)
3470 mean, std = torch.zeros(5, 5), torch.ones(5, 5)
3472 output = torch.normal(mean, std)
3474 script_output = f(mean, std)
3477 def _check_code(self, code_str, fn_name, inputs):
3479 exec(code_str, globals(), scope)
3481 self.
assertEqual(cu.func(*inputs), scope[fn_name](*inputs))
3483 @unittest.skipIf(
not RUN_CUDA,
'no CUDA')
3484 def test_scriptmodule_releases_tensors_cuda(self):
3487 return x.sigmoid() * y.tanh()
3489 def test(backward=False):
3490 x = torch.randn(3, 3, dtype=torch.double, device=
'cuda', requires_grad=
True)
3491 y = torch.randn(3, 3, dtype=torch.double, device=
'cuda', requires_grad=
True)
3494 out.sum().backward()
3506 def test_index(self):
3507 def consec(size, start=0):
3509 return torch.arange(numel).view(size)
3511 def check_indexing(indexing, tensor):
3512 template = dedent(
""" 3517 self.
_check_code(template.format(indexing),
"func", [tensor])
3519 def check_dynamic_indexing(indexing, tensor, value1, value2):
3523 template = dedent(
""" 3524 def func(x, value1, value2): 3530 self.
_check_code(template.format(indexing),
"func", [tensor, value1, value2])
3533 check_indexing(
'[0]', consec((3, 3)))
3534 check_indexing(
'[1]', consec((3, 3), 10))
3535 check_indexing(
'[2]', consec((3, 3), 19))
3536 check_indexing(
'[2]', consec((3,)))
3537 check_indexing(
'[-1]', consec((3, 3), 19))
3538 check_indexing(
'[0:2]', consec((3, 3, 3)))
3539 check_indexing(
'[1:-1]', consec((3, 3, 3)))
3540 check_indexing(
'[-3:-1]', consec((6, 3)))
3541 check_indexing(
'[1:]', consec((3, 3)))
3542 check_indexing(
'[:1]', consec((3, 3)))
3543 check_indexing(
'[:]', consec((3, 2)))
3546 check_indexing(
'[0, 1]', consec((3, 3)))
3547 check_indexing(
'[0, 1]', consec((3, 3, 2)))
3548 check_indexing(
'[1, 0, 2]', consec((3, 3, 3)))
3549 check_indexing(
'[2, -1]', consec((3, 3)))
3552 check_indexing(
'[0, 1:2]', consec((3, 3)))
3553 check_indexing(
'[0, :1]', consec((3, 3, 2)))
3554 check_indexing(
'[1, 2:]', consec((3, 3, 3)))
3555 check_indexing(
'[-1, 1:, 0]', consec((3, 3, 3, 3)))
3556 check_indexing(
'[1:, -1, 0]', consec((3, 3, 3, 3)))
3557 check_indexing(
'[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
3558 check_indexing(
'[-1, 1:, 0]', consec((3, 3, 3, 3)))
3559 check_indexing(
'[-1, :, 0, 2]', consec((3, 3, 3, 3)))
3562 check_indexing(
'[0:0]', consec((2, 2)))
3563 check_indexing(
'[0:0, 1]', consec((3, 3)))
3566 check_indexing(
'[1+1]', consec((3, 3)))
3567 check_indexing(
'[1:(0 + 2)]', consec((3, 3, 3)))
3570 check_dynamic_indexing(
"[i + j]", consec((3, 3)), 0, 1)
3571 check_dynamic_indexing(
"[i:j, i]", consec((3, 3, 2)), 0, 2)
3573 def test_tensor_item(self):
def test_scalar_to_float_coercion(x):
    # Calls x.item() and returns the result of comparing that value against
    # the int literal 1.
    return x.item() == 1
3580 def test_scalar_cast(x):
3582 return int(scalar), float(scalar)
3587 expected_str =
r"Use int\(tensor\) or float\(tensor\) to retrieve" 3595 def test_error_msg(x):
3596 return int_fn(x.item())
3598 def test_method_on_number(self):
3606 def test_scalar_to_num_conversions(self):
3608 def multiple_defs(x):
3613 self.assertTrue(
"ImplicitTensorToNum" not in str(multiple_defs.graph))
def tensor_to_int_script(x, tensor):
    # Forwards `tensor` as the only argument of x.unsqueeze(). The body is
    # the same as tensor_to_int; this one is collected in `script_funs` and
    # its result is compared against the matching entry of `funs`.
    # NOTE(review): presumably compiled with torch.jit.script -- the
    # decorator line is not visible here, confirm.
    return x.unsqueeze(tensor)
def tensor_to_int(x, tensor):
    # Reference twin of tensor_to_int_script (the two are paired through the
    # `funs` / `script_funs` lists): hand `tensor` to x.unsqueeze() as its
    # single argument and return whatever that call produces.
    unsqueezed = x.unsqueeze(tensor)
    return unsqueezed
def tensor_to_float_script(x, tensor):
    # Passes `tensor` three times to x.addcmul(): as both positional
    # operands and as the `value` keyword. The body is the same as
    # tensor_to_float; this one is collected in `script_funs`.
    # NOTE(review): presumably compiled with torch.jit.script -- the
    # decorator line is not visible here, confirm.
    return x.addcmul(tensor, tensor, value=tensor)
def tensor_to_float(x, tensor):
    # Reference twin of tensor_to_float_script (paired through the `funs` /
    # `script_funs` lists): `tensor` is used for both positional operands of
    # x.addcmul() and for its `value` keyword.
    combined = x.addcmul(tensor, tensor, value=tensor)
    return combined
3636 script_funs = [tensor_to_int_script, tensor_to_float_script]
3637 funs = [tensor_to_int, tensor_to_float]
3640 def test_func(func, x, tensor):
3642 result = func(x, tensor)
3643 except RuntimeError
as e:
3645 except TypeError
as e:
3650 for tensor
in tensors:
3651 for i
in range(len(script_funs)):
3652 self.
assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
3654 def test_tuple_to_opt_list(self):
3664 def test_advancedindex(self):
3665 def consec(size, start=0):
3667 return torch.arange(numel).view(size)
3669 def check_indexing(indexing, tensor, **kwargs):
3670 indices_dict = kwargs
3672 template = dedent(
""" 3673 def func(x{formals}): 3679 for formal, value
in indices_dict.items():
3680 formals.append(formal)
3681 values.append(value)
3683 formals =
''.join(map(
', {}'.format, formals))
3684 inputs = [tensor] + values
3685 self.
_check_code(template.format(formals=formals, expr=indexing),
3689 check_indexing(
'[i]', consec((3, 3)), i=
torch.tensor([0]))
3690 check_indexing(
'[i]', consec((3, 3)), i=
torch.tensor(1))
3691 check_indexing(
'[i]', consec((3, 3)), i=
torch.tensor([-2]))
3692 check_indexing(
'[i]', consec((3, 3), 2), i=
torch.tensor([0, 0]))
3693 check_indexing(
'[i]', consec((3, 3, 2, 2)), i=
torch.tensor([0, -2, 1]))
3702 inp = consec((4, 8, 5))
3705 [
'[i, j]', {
'i': [0, 2],
'j': [1, 3]}],
3707 [
'[i, j, k]', {
'i': [0, 2],
'j': [1, 3],
'k': [1, 1]}],
3709 [
'[i, j, k]', {
'i': [0, 2],
'j': 1,
'k': [1, 1]}],
3711 [
'[:, :, i]', {
'i': [0, 3, 4]}],
3713 [
'[:, i, 2:4]', {
'i': [0, 2, 3]}],
3715 [
'[i, :, :]', {
'i': [2, 3]}],
3717 [
'[:, i, j]', {
'i': [0, 2, 3],
'j': [1, 3, 4]}],
3719 [
'[:, i, j]', {
'i': [0],
'j': [1, 2, 4]}],
3721 [
'[:, i, j]', {
'i': [0, 1, 3],
'j': [4]}],
3723 [
'[:, i, j]', {
'i': [[0, 1], [1, 0]],
'j': [[2, 3]]}],
3725 [
'[:, i, j]', {
'i': [[0, 1], [2, 3]],
'j': [[0]]}],
3727 [
'[:, i, j]', {
'i': [[5, 6]],
'j': [[0, 3], [4, 4]]}],
3729 [
'[i, j, :]', {
'i': [0, 2, 3],
'j': [1, 3, 4]}],
3731 [
'[i, j, :]', {
'i': 0,
'j': [1, 2, 4]}],
3733 [
'[i, j, :]', {
'i': [0, 1, 3],
'j': 4}],
3735 [
'[i, j, :]', {
'i': [[0, 1], [1, 0]],
'j': [[2, 1], [3, 5]]}],
3737 [
'[i, j, :]', {
'i': [[0, 1], [1, 0]],
'j': [[2, 3]]}],
3739 [
'[i, j, :]', {
'i': [[0, 1], [2, 3]],
'j': [[0]]}],
3741 [
'[i, j, :]', {
'i': [[2, 1]],
'j': [[0, 3], [4, 4]]}],
3743 [
'[i, j, 0:2]', {
'i': [[2]],
'j': [[0, 3], [4, 1]]}],
3746 for expr, argdict
in to_check:
3747 tensordict = {k:
torch.tensor(v)
for (k, v)
in argdict.items()}
3748 check_indexing(expr, inp, **tensordict)
3750 def test_keyword(self):
3753 return torch.sum(x, dim=0)
3755 x = torch.rand(10, dtype=torch.float, requires_grad=
True)
3757 y2 = torch.sum(x, dim=0)
3760 def test_constant_pooling_none(self):
3762 def typed_nones(a=None, b=None, c=None):
3770 print(typed_nones())
3772 print(typed_nones())
3774 graph_str = str(test.graph)
3775 self.assertTrue(graph_str.count(
"bool? = prim::Constant") == 1)
3776 self.assertTrue(graph_str.count(
"int? = prim::Constant") == 1)
3777 self.assertTrue(graph_str.count(
"None = prim::Constant") == 1)
3779 def test_literal(self):
3802 a = torch.rand(1, requires_grad=
True)
3803 b = torch.rand(1, requires_grad=
True)
3806 self.
checkScript(func3, (a.item(), b.item()), optimize=
True)
3808 def test_expand(self):
3813 x = torch.rand(2, 3, dtype=torch.float, requires_grad=
True)
3814 y = torch.rand(3, dtype=torch.float, requires_grad=
True)
3818 grad = torch.randn(2, 3, dtype=torch.float)
3826 return x.sum(dim=[4])
3833 self.
run_pass(
'constant_propagation', func.graph)
3834 self.
run_pass(
'constant_propagation', func2.graph)
3835 torch._C._jit_pass_shape_analysis(
3836 func.graph, (torch.zeros(1, 1, 1, 1, 4),),
False)
3837 torch._C._jit_pass_shape_analysis(
3838 func2.graph, (torch.zeros(1, 1, 1, 1, 4),),
False)
3839 self.assertTrue(func.graph.findNode(
"aten::sum").output().type().kind()
3840 ==
"DimensionedTensorType")
3841 self.assertTrue(func2.graph.findNode(
"aten::sum").output().type().kind()
3842 ==
"DimensionedTensorType")
3847 return torch.cat((x, x), dim=0)
3849 x = torch.rand(10, dtype=torch.float, requires_grad=
True)
3850 self.
assertEqual(func(x), torch.cat((x, x), dim=0))
3854 return torch.cat((x, x), y)
3856 x = torch.rand([2, 2])
3858 self.
assertEqual(func2(x, y), torch.cat((x, x), y))
3860 def test_cat_lifts(self):
3863 return torch.cat([x, x], dim=1)
3867 return torch.cat([], dim=1)
3871 return torch.cat([x], dim=1)
3873 for g
in [foo.graph, foo2.graph, foo3.graph]:
3874 FileCheck().check(
"int =").check(
"ListConstruct").check(
"aten::cat").run(str(g))
3876 def test_list_literal(self):
3884 def reassign_arity_change():
3889 self.
checkScript(reassign_arity_change, (), optimize=
False)
3891 def reassign_from_empty_literal():
3897 self.
checkScript(reassign_from_empty_literal, (), optimize=
False)
3899 def reassign_from_empty_builtin():
3908 z = [torch.randn([1])]
3910 self.
checkScript(reassign_from_empty_builtin, (), optimize=
False)
3912 def reassign_bad_type():
3918 self.
checkScript(reassign_bad_type, (), optimize=
False)
3920 def reassign_nested():
3928 self.
checkScript(reassign_nested, (), optimize=
False)
3930 def test_list_gather(self):
3937 def negative_index():
3948 "list index out of range")
3950 def bad_negative_index():
3955 "list index out of range")
3957 def test_tensor_len(self):
3963 def test_list_len(self):
3976 def test_list_ops(self):
3977 def test_equality():
3982 self.
checkScript(test_equality, (), optimize=
True)
3984 def test_inequality():
# Exercise test_inequality, the helper defined directly above. test_equality
# already has its own checkScript call, so passing it here again would leave
# test_inequality defined but never run.
self.checkScript(test_inequality, (), optimize=True)
3991 def test_non_equality():
3996 self.
checkScript(test_non_equality, (), optimize=
True)
3998 def test_non_inequality():
# Exercise test_non_inequality, the helper defined directly above.
# test_non_equality already has its own checkScript call, so passing it here
# again would leave test_non_inequality defined but never run.
self.checkScript(test_non_inequality, (), optimize=True)
4005 def test_list_equality_as_cond():
4014 self.
checkScript(test_list_equality_as_cond, (), optimize=
True)
4016 def test_list_add():
4020 return c == [1, 2, 3, 2]
4022 self.
checkScript(test_list_add, (), optimize=
True)
4024 def test_list_add_empty():
4028 return c == [1, 2, 3]
4030 self.
checkScript(test_list_add_empty, (), optimize=
True)
4032 def test_tensor_list_equality():
4033 t1 = torch.ones([1, 1])
4034 t2 = torch.ones([1, 1])
4039 self.
checkScript(test_tensor_list_equality, (), optimize=
True)
4041 def test_invalid_list_equality():
4042 t1 = torch.ones([2, 2])
4043 t2 = torch.ones([2, 2])
4050 test_invalid_list_equality,
4053 "bool value of Tensor")
4055 def test_list_slice(self):
4056 def test_regular_slice():
4058 return a[2:3] == [2]
4061 def test_open_ended_slice():
4063 return a[2:] == [2, 3, 4]
4066 def test_open_ended_slice2():
4068 return a[:2] == [0, 1]
4071 def test_negative_slice():
4073 return a[:-1] == [0, 1, 2, 3]
4076 def test_negative_slice2():
4078 return a[-3:-1] == [2, 3]
4081 def test_backward_slice():
4086 def test_over_slice():
4088 return a[3:10] == [3, 4]
4091 def test_mutable_list_append(self):
4096 return a == [0, 1, 2, 3]
4099 def test_mutable_list_append_2(self):
4100 def test_append_2():
4108 def test_mutable_list_append_if(self):
4109 def test_append_if():
4116 def test_mutable_list_append_if_else(self):
4117 def test_append_if_else():
4126 def test_mutable_list_append_loop(self):
4127 def test_append_loop():
4132 return a == [0, 1, 2, 3, 4]
4135 def test_mutable_list_append_loop_if(self):
4136 def test_append_loop_if():
4144 return a == [0, 0, 0, 0, 4]
4147 def test_mutable_list_nested_loop(self):
4148 def test_nested_loop():
4154 return a == [0, 1, 1, 2]
4157 def test_mutable_list_function_inline(self):
4171 def test_mutable_list_reverse_empty(self):
4172 def test_reverse_empty():
4179 def test_mutable_list_reverse(self):
4184 return a == [4, 3, 2, 1]
4187 def test_mutable_tensor_list_reverse(self):
4188 def test_tensor_reverse():
4195 def test_mutable_list_pop_empty(self):
4197 def test_pop_empty():
4204 def test_mutable_list_pop(self):
4213 def test_mutable_list_pop2(self):
4222 def test_mutable_list_pop_at(self):
4231 def test_mutable_list_pop_at2(self):
4240 def test_mutable_list_pop_at_negative(self):
4241 def test_pop_at_negative():
4249 def test_mutable_list_pop_at_negative2(self):
4250 def test_pop_at_negative2():
4258 def test_mutable_list_pop_slice(self):
4259 def test_pop_slice():
4270 @unittest.skipIf(sys.version_info < (3, 3),
"clear not supported in version < 3.3")
4271 def test_mutable_list_clear_empty(self):
4272 def test_clear_empty():
4279 @unittest.skipIf(sys.version_info < (3, 3),
"clear not supported in version < 3.3")
4280 def test_mutable_list_clear(self):
4288 def test_mutable_list_insert(self):
4289 def test_list_insert():
4293 return a == [1, 2, 5, 3, 4]
4296 def test_mutable_list_insert_negative(self):
4297 def test_list_insert_negative():
4301 return a == [1, 2, 3, 5, 4]
4304 def test_mutable_list_insert_neg_out_of_bounds(self):
4305 def test_list_insert_neg_out_of_bounds():
4309 return a == [5, 1, 2, 3, 4]
4310 self.
checkScript(test_list_insert_neg_out_of_bounds, ())
4312 def test_mutable_list_insert_out_of_bounds(self):
4313 def test_list_insert_out_of_bounds():
4317 return a == [1, 2, 3, 4, 5]
4318 self.
checkScript(test_list_insert_out_of_bounds, ())
4320 def test_mutable_list_remove_not_existing(self):
4322 def test_list_remove_not_existing():
4329 test_list_remove_not_existing()
4331 def test_mutable_list_remove(self):
4332 def test_list_remove():
4336 return a == [1, 2, 4]
4339 def test_list_index_not_existing(self):
4341 def list_index_not_existing():
4348 list_index_not_existing()
4350 def test_list_index(self):
4358 def test_tensor_list_index(self):
4359 def tensor_list_index():
4366 def test_tensor_list_index_not_existing(self):
4368 def tensor_list_index_not_existing():
4375 tensor_list_index_not_existing()
4377 def test_list_count(self):
4385 def test_list_count_not_existing(self):
4386 def list_count_not_existing():
4393 def test_tensor_list_count(self):
4394 def tensor_list_count():
4401 def test_tensor_list_count_not_existing(self):
4402 def tensor_list_count_not_existing():
4407 self.
checkScript(tensor_list_count_not_existing, ())
4409 def test_mutable_list_remove_tensor(self):
4410 def test_list_remove_tensor():
4411 a = [torch.ones(1), torch.zeros(1), torch.ones(2)]
4412 a.remove(torch.zeros(1))
4417 def test_mutable_list_remove2(self):
4418 def test_list_remove2():
4425 def test_extend_list_mutable(self):
4427 def extend_list(a, b):
4433 for l
in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4434 for r
in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4437 def test_extend_list_immutable(self):
4439 def extend_list(a, b):
4445 for l
in [[], [1], [1, 2, 3]]:
4446 for r
in [[], [1], [1, 2, 3]]:
4449 def test_copy_list_mutable(self):
4455 for l
in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4458 def test_copy_list_immutable(self):
4464 for l
in [[], [1], [1, 2, 3]]:
4467 def test_func_call(self):
4475 def func(alpha, beta, x, y): 4476 return add(mul(alpha, x), mul(beta, y)) 4478 alpha = torch.rand(1, dtype=torch.float, requires_grad=
True)
4479 beta = torch.rand(1, dtype=torch.float, requires_grad=
True)
4480 x = torch.rand(3, dtype=torch.float, requires_grad=
True)
4481 y = torch.rand(3, dtype=torch.float, requires_grad=
True)
4482 outputs = alpha * x + beta * y
4484 self.
checkScript(script, [alpha, beta, x, y], optimize=
False, outputs=outputs)
4486 def test_resize_input_ops(self):
4492 def out_op_graph_input():
4495 torch.mul(x, y, out=z)
4498 torch._C._jit_pass_shape_analysis(
4499 test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)),
False)
4500 self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
4501 out_op_graph_input()
4506 after_resize_alias = torch.zeros([2])
4510 before_resize_alias = b.sub_(1)
4514 after_resize_alias = b.add_(1)
4515 return after_resize_alias
4518 self.
run_pass(
'constant_propagation', g)
4519 torch._C._jit_pass_shape_analysis(
4520 g, (torch.zeros(1, 1),),
False)
4521 resize_node = g.findNode(
"aten::resize_")
4523 self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
4524 self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())
4527 before_resize = g.findNode(
"aten::sub_")
4528 self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())
4530 after_resize = g.findNode(
"aten::add_")
4531 self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())
4535 def test_resize_as():
4538 b = torch.zeros([2, 2])
4543 self.
run_pass(
'constant_propagation', g)
4544 torch._C._jit_pass_shape_analysis(
4545 g, (torch.zeros(1, 1),),
False)
4548 self.assertTrue(next(g.inputs()).type() != TensorType.get())
4550 self.assertTrue(next(g.outputs()).type() == TensorType.get())
4554 def test_view_shape_prop(self):
4556 def test_view_shape_prop(a): 4557 return a.view(size=[-1]) 4559 inputs = [torch.zeros(10, 10)]
4560 outputs = torch.zeros(100)
4562 real_outs = cu.test_view_shape_prop(*inputs)
4565 def test_view_listconstruct_shape_prop(self):
4570 return x.view(T, B, C)
4572 x = torch.randn(3, 1, 5, requires_grad=
True)
4574 torch._C._jit_pass_shape_analysis(graph, (x,),
False)
4575 a = next(graph.outputs()).type().kind()
4576 self.assertTrue(next(graph.outputs()).type().kind() !=
'TensorType')
4578 def test_integral_shape_inference(self):
4580 def test_integral_shape_inference(a): 4583 inputs = [torch.ones(10, 10).type(torch.LongTensor)]
4584 outputs = torch.ones(10, 10)
4586 self.
assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
4588 def test_fuser_multiple_blocks(self):
4590 def test_fuser_multiple_blocks(this, that, theother, meme): 4593 this = torch.cat([this, meme], dim=0) 4594 that = torch.cat([that, meme], dim=0) 4595 theother = torch.cat([theother, meme], dim=0) 4597 return this, that, theother 4600 inputs = [torch.ones(0, 10, 10)] * 3
4601 inputs += [torch.ones(1, 10, 10)]
4602 outputs = [torch.ones(20, 10, 10)] * 3
4604 self.
assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
4606 def test_dropout_script(self):
4608 eg = torch.zeros(1, 2, 3, requires_grad=
True)
4615 class MyDrop(nn.Module):
4616 def forward(self, x):
4622 @unittest.skip(
"RuntimeError: VariableType::ID() not implemented")
4623 def test_cast(self):
4628 x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=
True)
4629 out = Variable(torch.IntTensor([1, 2]), requires_grad=
True)
4630 self.
checkScript(script, [x], optimize=
True, outputs=[out], func=
'to_int')
4632 def test_python_frontend(self):
4635 q = x + y - z.sigmoid()
4638 if not x
and not y
and z:
4639 m = x
if not z
else y
4642 assert 1 == 1,
"hello" 4648 @unittest.skipIf(
not PY2,
"Requires python 2")
4649 def test_python_frontend_py2(self):
4651 raise Exception(
"hello")
4655 @unittest.skipIf(PY2,
"Requires python 3")
4656 def test_python_frontend_py3(self):
4658 raise Exception(
"hello")
4662 def _make_scalar_vars(self, arr, dtype):
4665 def test_string_print(self):
4667 print(a,
"a" 'b' '''c''' """d""", 2, 1.5)
4671 self.
checkScript(func, inputs, capture_output=
True)
4673 def test_while(self):
4674 def func(a, b, max):
4675 while bool(a < max):
4684 def test_fibb(self):
4692 while bool(i < lim):
4693 third = first + second
4698 somenum = somenum * 2
4701 i = i + dontmutateme
4705 return third, st, fs
4725 def test_if_for_in_range(self):
4740 def test_if_noelse(self):
4750 def test_if_is_none_dispatch(self):
4753 def test_lhs_none_rhs_none():
4758 elif None is not None:
4763 self.assertTrue(str(test_lhs_none_rhs_none.graph).count(
': int = prim::Constant') == 1)
4766 def test_lhs_opt_rhs_none(lhs=None):
4776 self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(
': int = prim::Constant') == 3)
4779 def test_lhs_none_rhs_opt(rhs=None):
4784 elif None is not rhs:
# Inspect the graph of test_lhs_none_rhs_opt, the function defined directly
# above; test_lhs_opt_rhs_none has its own assertion after its definition.
# NOTE(review): the expected count of 3 is carried over unchanged on the
# assumption that the mirrored `None is rhs` form yields the same number of
# int constants -- confirm by running the test.
self.assertTrue(str(test_lhs_none_rhs_opt.graph).count(': int = prim::Constant') == 3)
4792 def test_lhs_never_rhs_none(lhs):
4797 elif lhs
is not None:
4802 self.assertTrue(str(test_lhs_never_rhs_none.graph).count(
': int = prim::Constant') == 1)
4805 def test_lhs_none_rhs_never(rhs):
4810 elif None is not rhs:
4815 self.assertTrue(str(test_lhs_none_rhs_never.graph).count(
': int = prim::Constant') == 1)
4817 def test_explicit_bool_cast(self):
4820 def test_bool_cast(a):
4825 def test_while_nonexistent_value(self):
4828 def test_while(a, b): 4835 def test_while_nonexistent_cond_value(self):
4838 def test_while(a, b): 4845 def test_optional_refinement(self):
4847 def test_if_none_assignment(x):
4856 def test_ternary(x):
4858 x = x
if x
is not None else 2
4862 def test_not_none(x):
4870 if x
is not None and y
is not None:
4876 if not (x
is not None and y
is not None):
4882 def test_bool_expression(x):
4884 if x
is not None and x < 2:
4888 def test_nested_bool_expression(x, y):
4890 if x
is not None and x < 2
and y
is not None:
4899 if y
is None or x
is None:
4906 def test_manual_unwrap_opt(x):
4918 if x
is None or y
is None:
4923 def and_error(x, y):
4925 if x
is None and y
is None:
4934 x_none = x
is not None 4940 def named_var_and(x, y):
4942 x_none = x
is not None 4943 if y
is not None and x_none:
4946 def test_while_write_outer_then_read(self):
4956 def test_while_nest_if(self):
4972 def test_math_ops(self):
4975 return math.floor(1.5)
4979 def test_if_nest_while(self):
4992 def test_script_for_in_range(self):
4995 for i
in range(100):
4998 self.
checkScript(fn, (), outputs=4950, optimize=
True)
5000 def test_script_for_in_range_dynamic(self):
5003 for i
in range(100):
5011 def test_script_for_in_range_ast(self):
5013 def test_script_for_in_range_ast():
5015 for i
in range(100):
5022 self.
assertEqual(test_script_for_in_range_ast(), 161700)
5024 def test_script_for_in_range_if_ast(self):
5026 def test_script_for_in_range_if_ast(x):
5030 output = x.unsqueeze(0)
5032 output = torch.cat((output, x.unsqueeze(0)), dim=0)
5036 self.
assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
5038 def test_script_optional_none(self):
5048 self.
checkScript(none_stmt, [torch.arange(0, 2)], optimize=
True)
5049 self.
checkScript(none_args, [
None], optimize=
True)
5052 def test_script_optional_tensor_none(x=None):
5054 res = torch.zeros(1, dtype=torch.int8)
5061 fn = test_script_optional_tensor_none
5064 self.
assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1)))
5067 def test_script_optional_other_none(x=None):
5076 fn = test_script_optional_other_none
5081 def test_script_clamp_none(self):
5082 def test_script_clamp_max_none(x):
5083 return torch.clamp(x, min=2, max=
None)
5085 def test_script_clamp_max(x):
5086 return torch.clamp(x, max=2)
5088 def test_script_clamp_min_none(x):
5089 return torch.clamp(x, min=
None, max=2)
5091 def test_script_clamp_min(x):
5092 return torch.clamp(x, min=2)
5094 input = [torch.arange(0, 3)]
5095 self.
checkScript(test_script_clamp_max_none, input, optimize=
True)
5096 self.
checkScript(test_script_clamp_max, input, optimize=
True)
5097 self.
checkScript(test_script_clamp_min_none, input, optimize=
True)
5098 self.
checkScript(test_script_clamp_min, input, optimize=
True)
5100 def test_script_bool_constant(self):
5102 def test_script_bool_constant(): 5107 self.
checkScript(script, [], outputs[0],
True,
'test_script_bool_constant')
5109 def test_ternary(self):
5112 c = a + b
if bool(a > 3)
else b
5117 self.
checkScript(func, inputs_true, optimize=
True)
5118 self.
checkScript(func, inputs_false, optimize=
True)
5120 def test_print(self):
5122 q = (x + y).sigmoid()
5123 print(q, 1, 2, [1, 2], [1.0, 2.0])
5127 x = torch.arange(4., requires_grad=
True)
5128 y = torch.arange(0., 8, 2, requires_grad=
True)
5129 self.
checkScript(func, [x, y], optimize=
True, capture_output=
True)
5131 def test_format(self):
5133 print(
"{}, I'm a {}".format(
"Hello",
"test"))
5134 print(
"format blank".format())
5135 print(
"stuff before {}".format(
"hi"))
5136 print(
"{} stuff after".format(
"hi"))
5139 x = torch.arange(4., requires_grad=
True)
5140 self.
checkScript(func, [x], optimize=
True, capture_output=
True)
5142 def test_logical_short_circuit(self):
5144 def testNoThrows(t):
5146 if (
False and bool(t[1]))
or (
True or bool(t[1])):
5151 ifs = testNoThrows.graph.findAllNodes(
"prim::If", recurse=
False)
5155 self.assertTrue(len(ifs) == 3)
5156 self.assertTrue(ifs[0].findNode(
"prim::If")
is None)
5157 self.assertTrue(ifs[1].findNode(
"prim::If").findNode(
"prim::If")
is None)
5158 self.assertTrue(ifs[2].findNode(
"prim::If")
is None)
5162 c0 =
False or bool(t[1])
5167 c0 =
True and bool(t[1])
5171 with self.
assertRaisesRegex(RuntimeError,
"index 1 out of range for tensor of size"):
5173 with self.
assertRaisesRegex(RuntimeError,
"index 1 out of range for tensor of size"):
5176 def test_type_cast(self):
5177 template = dedent(
''' 5179 # type: ({from_type}) -> {to_type} 5183 def check_cast(from_type, to_type, value, raises=False):
5184 code = template.format(from_type=from_type, to_type=to_type)
5185 expected = getattr(builtins, to_type)(value)
5190 self.
checkScript(code, (value,), name=
'cast', outputs=expected)
5192 check_cast(
'int',
'float', 1)
5193 check_cast(
'int',
'bool', 1)
5194 check_cast(
'int',
'bool', 0)
5196 check_cast(
'float',
'int', 1.)
5197 check_cast(
'float',
'bool', 1.)
5198 check_cast(
'float',
'bool', 0.)
5200 check_cast(
'bool',
'int',
True)
5201 check_cast(
'bool',
'float',
True)
5203 def test_multiple_assignment(self):
5209 y, z = outer_func(x)
5215 def test_literals(self):
5217 return a.view(size=[1, 2, 3])
5222 def test_return(self):
def multiple_returns(a):
    # Scale `a` by 1., 2. and 3. in turn and hand the three products back
    # together as one tuple.
    times_one = a * 1.
    times_two = a * 2.
    times_three = a * 3.
    return times_one, times_two, times_three
5235 a = torch.randn(1, dtype=torch.float)
5236 self.checkScript(no_return, [a], optimize=
True)
5237 self.checkScript(void_return, [a], optimize=
True)
5238 self.checkScript(one_return, [a], optimize=
True)
5239 self.checkScript(multiple_returns, [a], optimize=
True)
5241 with self.assertRaisesRegex(RuntimeError,
"but is actually of type None"):
5243 def no_return_bad_annotation(a): 5244 # type: (Tensor) -> Tensor 5248 def test_error(self):
5252 s = Variable(torch.rand(5, 5, 5))
5254 with self.assertRaisesRegex(RuntimeError,
"failed in interpreter"):
5261 with self.assertRaisesRegex(RuntimeError,
"failed in interpreter"):
5262 bar(Variable(torch.rand(10), requires_grad=
True), Variable(torch.rand(9), requires_grad=
True))
5264 def test_binop_unsupported_error(self):
5265 with self.assertRaisesRegex(NotSupportedError,
"unsupported binary operator:"):
5271 def test_bitwise_ops(self):
5274 return 2 & 3, 2 ^ 3, 2 | 3
5276 self.checkScript(int_test, ())
5278 def bool_test(x, y):
5280 return x & y, x ^ y, x | y
5282 self.checkScript(bool_test, (
True,
False))
5283 self.checkScript(bool_test, (
True,
True))
def tensor_test(x, y):
    # Apply the bitwise operators &, ^ and | to the two operands (in that
    # order) and return the three results as a tuple.
    and_result = x & y
    xor_result = x ^ y
    or_result = x | y
    return and_result, xor_result, or_result
5291 self.checkScript(tensor_test, (x, y))
5293 def test_number_math(self):
5294 ops_template = dedent(
''' 5296 return {scalar1} {op} {scalar2} 5298 ops = [
'+',
'-',
'*',
'%',
'<',
'<=',
'>',
'>=',
'==',
'!=',
'//']
5299 funcs_template = dedent(
''' 5301 return {func}({scalar1}, {scalar2}) 5303 funcs = [
'min',
'max']
5304 scalars = [
'7',
'2',
'3',
'-3',
'3.14',
'0.125',
'-0.5',
'2.0',
'-2.0']
5305 scalar_pairs = [(scalar1, scalar2)
for scalar1
in scalars
for scalar2
in scalars]
5309 execWrapper(code, globals(), scope)
5312 self.assertEqual(cu.func(), scope[
'func']())
5314 for scalar1, scalar2
in scalar_pairs:
5316 code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2)
5319 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2)
5322 def test_number_div(self):
5323 self.checkScript(div_int_future, (), optimize=
True)
5324 self.checkScript(div_float_future, (), optimize=
True)
5327 with self.assertRaisesRegex(RuntimeError,
'from __future__ import division'):
5329 with self.assertRaisesRegex(RuntimeError,
'from __future__ import division'):
5332 self.checkScript(div_int_nofuture, (), optimize=
True)
5333 self.checkScript(div_float_nofuture, (), optimize=
True)
5335 def test_floor_div(self):
5340 for i
in range(-8, 8):
5341 for j
in range(-8, 8):
5343 self.assertEqual(foo(i, j), i // j)
5345 with self.assertRaisesRegex(RuntimeError,
'division by 0'):
5348 def test_number_augassign(self):
5354 self.checkScript(func, (), optimize=
True)
5356 def test_number_neg(self):
5365 self.checkScript(func1, (), optimize=
True)
5366 self.checkScript(func2, (), optimize=
True)
5368 def _test_tensor_number_math(self, device='cpu'):
5369 template = dedent(
''' 5371 return {lhs} {op} {rhs} 5374 def test(op, const, swap_args):
5379 code = template.format(lhs=args[0], rhs=args[1], op=op)
5381 execWrapper(code, globals(), scope)
5383 self.assertEqual(cu.func(tensor), scope[
'func'](tensor))
5386 var_float = [1.4321, -1.2]
5388 ops = [
'+',
'-',
'*',
'%',
'<',
'<=',
'>',
'>=',
'==',
'!=',
'/']
5390 float_tensor = torch.randn(5, 5, device=device)
5391 double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
5392 long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device)
5393 long_tensor[long_tensor == 0] = 2
5395 tensors = [float_tensor, double_tensor, long_tensor]
5396 consts = var_int + var_float
5398 for op, tensor, const, swap_args
in product(ops, tensors, consts, [
True,
False]):
5401 if op ==
'/' and tensor.data_ptr() == long_tensor.data_ptr():
5405 if op ==
'%' and swap_args
is True:
5408 test(op, const, swap_args)
5410 def test_tensor_number_math(self):
5411 self._test_tensor_number_math()
5413 def test_torch_tensor_bad_input(self):
5414 with self.assertRaisesRegex(RuntimeError,
"Input list to torch.tensor must be of ints, floats, " 5415 "or bools, got None"):
5420 with self.assertRaisesRegex(RuntimeError,
"Note: empty lists are constructed as Tensor"):
5428 with self.assertRaisesRegex(RuntimeError,
"Expected sequence of length"):
5432 def test_torch_tensor_empty_list(self):
5440 self.assertNotEqual(t1.dtype, t2.dtype)
5446 self.checkScript(func, ())
5452 self.checkScript(func, ())
5454 def test_torch_tensor(self):
5455 template = dedent(
''' 5458 return torch.tensor(li {options}) 5461 lists = [
"2.5",
"4",
"True",
"False",
"[2]",
"[-.5]",
"[False, True, False]",
"[2, 2]",
5462 "torch.jit.annotate(List[int], [])",
"[2.5, 2.5]",
"[[2], [2]]",
"[[-.5], [2.2]]",
"[[False], [True]]"]
5464 dtypes = [
"",
", dtype=torch.float",
", dtype=torch.double",
", dtype=torch.half",
5465 ", dtype=torch.uint8",
", dtype=torch.int8",
", dtype=torch.short",
5466 ", dtype=torch.int",
", dtype=torch.long"]
5468 devices = [
'',
", device='cpu'"]
5470 devices.append(
", device='cuda'")
5472 option_pairs = [dtype + device
for dtype
in dtypes
for device
in devices]
5474 for option
in option_pairs:
5476 if "annotate" in li
and "dtype" not in option:
5478 code = template.format(list_create=li, options=option)
5480 exec(code, globals(), scope)
5483 t2 = scope[
'func']()
5484 if t1.dtype == torch.float16:
5485 self.assertTrue(str(t1) == str(t2))
5487 self.assertEqual(t1, t2)
5488 self.assertEqual(t1.dtype, t2.dtype)
5489 self.assertEqual(t1.device, t2.device)
5492 def test_tensor_to(self):
5493 template = dedent(
''' 5497 non_blocking = {non_blocking} 5501 def s(t, to_str, non_blocking=None, device=None, cuda=None):
5502 device = device
if device
is not None else str(t.device)
5503 non_blocking = non_blocking
if non_blocking
is not None else False 5504 cuda =
"cuda" if cuda
is None else cuda
5505 code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
5510 def test_copy_behavior(t, non_blocking=False):
5511 self.assertIs(t, s(t,
't.to(t, non_blocking=non_blocking)', non_blocking))
5512 self.assertIs(t, s(t,
't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
5513 self.assertIs(t, s(t,
't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
5514 self.assertIsNot(t, s(t,
't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
5515 self.assertIsNot(t, s(t,
't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
5516 self.assertIsNot(t, s(t,
't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))
5518 devices = [t.device]
5519 if t.device.type ==
'cuda':
5520 if t.device.index == -1:
5523 devices.append(
'cuda')
5524 for device
in devices:
5525 self.assertIs(t, s(t,
't.to(device, non_blocking=non_blocking)', non_blocking, device))
5526 self.assertIs(t, s(t,
't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
5527 self.assertIsNot(t, s(t,
't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
5528 self.assertIsNot(t, s(t,
't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
5529 non_blocking, device))
5532 test_copy_behavior(t)
5534 self.assertEqual(t.device, s(t,
"t.to('cpu')").device)
5535 self.assertEqual(t.device, s(t,
"t.to('cpu', dtype=torch.float32)").device)
5536 self.assertIs(torch.float32, s(t,
"t.to('cpu', dtype=torch.float32)").dtype)
5537 self.assertEqual(t.device, s(t,
"t.to(torch.float32)").device)
5538 self.assertIs(torch.float32, s(t,
"t.to(dtype=torch.float32)").dtype)
5539 self.assertEqual(t.data_ptr(), s(t,
"t.to('cpu')").data_ptr())
5540 self.assertEqual(t.data_ptr(), s(t,
"t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
5541 self.assertEqual(t.data_ptr(), s(t,
"t.to('cpu', copy=False)").data_ptr())
5542 self.assertNotEqual(t.data_ptr(), s(t,
"t.to('cpu', copy=True)").data_ptr())
5546 for non_blocking
in [
True,
False]:
5549 test_copy_behavior(b, non_blocking)
5550 self.assertEqual(b.device, s(b,
"t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
5551 self.assertEqual(a.device, s(b,
"t.to('cpu', non_blocking=non_blocking).device"))
5552 self.assertEqual(b.device, s(b,
"t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
5553 self.assertIs(torch.int32, s(b,
"t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
5554 self.assertEqual(a.device, s(b,
"t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
5555 self.assertIs(torch.int32, s(b,
"t.to(dtype=torch.int32)").dtype)
5556 self.assertEqual(b.device, s(b,
"t.to(dtype=torch.int32)").device)
5560 out_ref = t.to(torch.float32)
5561 out = s(t,
"t.to(torch.float32)")
5562 self.assertEqual(out_ref, out)
5566 self.assertEqual(grad_ref, grad)
5569 out_ref = t.to(
'cpu')
5570 out = s(t,
"t.to('cpu')")
5571 self.assertEqual(out_ref, out)
5575 self.assertEqual(grad_ref, grad)
5579 def func2(t, t_ref):
5582 func2.debug_disable_autodiff_subgraph_inlining()
5585 out_ref = t.to(t_ref)
5586 out = func2(t, t_ref)
5589 self.assertEqual(grad_ref, grad)
5591 @unittest.skipIf(
not RUN_CUDA,
"No CUDA")
5592 def test_tensor_number_math_cuda(self):
5593 self._test_tensor_number_math(device=
'cuda')
5599 return not bool(a > 1)
5601 self.checkScript(test_not_op, (
torch.tensor(2), ), optimize=
True)
5603 def test_is_isnot(self):
5605 template = dedent(
''' 5608 return {lhs} {op} {rhs} 5612 code = template.format(lhs=args[0], rhs=args[1], op=op)
5614 execWrapper(code, globals(), scope)
5619 "Failed with op: {}, lhs: {}, rhs: {}" 5620 .format(op, args[0], args[1])
5623 ops = [
'is',
'is not']
5624 type_literals = [
True,
False,
None, [1, 1]]
5627 for op, lhs, rhs
in product(ops, type_literals, type_literals):
5628 test(op, [lhs, rhs])
5630 def test_isinstance(self):
5632 template = dedent(
''' 5634 # type: ({type_hint}) -> bool 5635 return isinstance(x, {typ}) 5638 def test(inp, typ, type_hint):
5639 code = template.format(typ=typ, type_hint=type_hint)
5641 execWrapper(code, globals(), scope)
5646 "Failed with typ: {}" 5650 inputs = [
True, 1, 1.0,
torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
5651 type_literals = [
'bool',
'int',
'float',
'torch.Tensor',
'list',
'tuple',
5652 '(list, tuple)',
'(int, float, bool)']
5653 type_annotations = [
'bool',
'int',
'float',
'Tensor',
'List[int]',
'Tuple[float]',
5657 for inp, typ, type_hint
in zip(inputs, type_literals, type_annotations):
5658 test(inp, typ, type_hint)
5661 with self.assertRaisesRegex(RuntimeError,
"Optional isinstance check is not supported"):
5665 return isinstance(x, int)
5667 def test_python_call(self):
5675 def test_call_python(a): 5687 inputs = self._make_scalar_vars([1], torch.float)
5688 outputs = self._make_scalar_vars([54], torch.float)
5690 self.assertEqual(cu.test_call_python(*inputs), outputs[0])
5692 def test_python_call_failure(self):
5693 with self.assertRaisesRegex(RuntimeError,
"undefined value pyfunc2"):
5701 def test_call_python(a): 5713 inputs = self._make_scalar_vars([1], torch.float)
5714 outputs = self._make_scalar_vars([54], torch.float)
5716 self.assertEqual(cu.test_call_python(*inputs), outputs)
5718 def test_python_call_annotation(self):
5724 return pyfunc(a) + pyfunc(a)
5726 inputs = self._make_scalar_vars([1], torch.float)
5727 outputs = self._make_scalar_vars([6], torch.float)
5728 self.assertEqual(foo(*inputs), outputs[0])
5730 def test_python_call_annoytation_failure(self):
5731 with self.assertRaisesRegex(RuntimeError,
"undefined value pyfunc2"):
5737 return pyfunc2(a) + pyfunc(a)
5739 inputs = self._make_scalar_vars([1], torch.float)
5740 outputs = self._make_scalar_vars([6], torch.float)
5742 self.assertEqual(foo(*inputs), outputs[0])
5744 def test_desugar_module(self):
5750 c = F.prelu(x, slope)
5753 x = torch.arange(-3., 4)
5755 self.checkScript(fn, [x, slope], optimize=
True)
5757 def test_script_docstring(self):
5759 def with_docstring(x):
5762 """y is the same as x""" 5764 self.assertEqual(with_docstring.__doc__,
'test str')
5766 def test_script_method_docstring(self):
5768 @torch.jit.script_method
5769 def with_docstring(self, x):
5772 """y is the same as x""" 5775 self.assertEqual(a.with_docstring.__doc__,
'test str')
5777 @unittest.skipIf(TEST_WITH_UBSAN
or not torch.fbgemm_is_cpu_supported(),
5778 'Quantized RNN requires FBGEMM. FBGEMM does not play' 5779 ' well with UBSAN at the moment, so we skip the test if' 5780 ' we are in a UBSAN environment.')
5781 def test_rnn_cell_quantized(self):
5785 torch.nn.LSTMCell(d_in, d_hid).float(),
5786 torch.nn.GRUCell(d_in, d_hid).float(),
5787 torch.nn.RNNCell(d_in, d_hid).float(),
5789 if isinstance(cell, torch.nn.LSTMCell):
5791 elif isinstance(cell, torch.nn.GRUCell):
5793 elif isinstance(cell, torch.nn.RNNCell):
5807 vals = [[100, -155],
5815 vals = vals[:d_hid * num_chunks]
5816 cell.weight_ih = torch.nn.Parameter(
5818 requires_grad=
False)
5819 cell.weight_hh = torch.nn.Parameter(
5821 requires_grad=
False)
5823 ref = copy.deepcopy(cell)
5828 [100, -155]], dtype=torch.float)
5829 h0_vals = [[-155, 100],
5841 def __init__(self, cell):
5842 super(ScriptWrapper, self).__init__()
5845 @torch.jit.script_method
5846 def forward(self, x, hiddens):
5848 return self.cell(x, hiddens)
5852 def __init__(self, cell):
5853 super(ScriptWrapper, self).__init__()
5856 @torch.jit.script_method
5857 def forward(self, x, hiddens):
5859 return self.cell(x, hiddens)
5861 cell = ScriptWrapper(cell)
5862 outs = cell(x, hiddens)
5863 cell = self.getExportImportCopyWithPacking(cell)
5865 outs = cell(x, hiddens)
5866 ref_outs = ref(x, hiddens)
5868 self.assertEqual(len(outs), len(ref_outs))
5869 for out, ref_out
in zip(outs, ref_outs):
5872 def test_script_module(self):
5875 super(M1, self).__init__(
False)
5876 self.weight = nn.Parameter(torch.randn(2))
5878 @torch.jit.script_method
5879 def forward(self, thing):
5880 return self.weight + thing
5882 class PModule(nn.Module):
5884 super(PModule, self).__init__()
5885 self.a = nn.Parameter(torch.randn(2, 3))
5887 def forward(self, a):
5892 super(M2, self).__init__(
False)
5895 self.sub2 = PModule()
5897 self.weight = nn.Parameter(torch.randn(2, 3))
5898 self.bias = nn.Parameter(torch.randn(2))
5902 return self.weight.mm(a) 5906 @torch.jit.script_method
5907 def doit(self, input):
5909 return self.weight.mm(input)
5911 @torch.jit.script_method
5912 def doit2(self, input):
5913 return self.weight.mm(input)
5915 @torch.jit.script_method
5916 def forward(self, input):
5917 a = self.doit(input)
5918 b = self.doit2(input)
5920 d = self.sub2(input)
5921 return a + b + self.bias + self.sub(a) + c + d
5923 input = torch.randn(3, 2)
5924 a = m2.weight.mm(input)
5925 b = m2.weight.mm(input)
5926 c = m2.weight.mm(input)
5927 d = m2.sub2.a.mm(input)
5928 ref = a + b + m2.bias + m2.sub.weight + a + c + d
5929 self.assertEqual(ref, m2.forward(input))
5930 m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
5931 m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
5932 m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
5933 m2.sub2.a.data.zero_()
5934 self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
5936 def test_filecheck(self):
5939 FileCheck().check(
"2").check(
"3").check(
"2").run(file)
5940 FileCheck().check(
"232").run(file)
5942 with self.assertRaisesRegex(RuntimeError,
'Expected to find "22"'):
5943 FileCheck().check(
"22").run(file)
5944 with self.assertRaisesRegex(RuntimeError,
"CHECK: 3"):
5945 FileCheck().check(
"3").check(
"3").run(file)
5949 def test_check_count():
5951 FileCheck().check_count(
"2", 5).run(file)
5952 FileCheck().check_count(
"22", 2).run(file)
5953 FileCheck().check_count(
"222", 1).run(file)
5955 with self.assertRaisesRegex(RuntimeError,
'Expected to not find'):
5956 FileCheck().check_count(
"2", 4, exactly=
True).run(file)
5958 with self.assertRaisesRegex(RuntimeError,
'Expected to find "22"'):
5959 FileCheck().check_count(
"22", 3).run(file)
5961 with self.assertRaisesRegex(RuntimeError,
"CHECK-COUNT-6: 2"):
5962 FileCheck().check_count(
"2", 6).run(file)
5966 def test_check_same():
5968 FileCheck().check_same(
"22").run(file)
5970 with self.assertRaisesRegex(RuntimeError,
"Expected to not find"):
5971 FileCheck().check_same(
"33").run(file)
5975 FileCheck().check(
"2").check_same(
"3").run(file)
5976 FileCheck().check_count(
"2", 2).check_same(
"3").run(file)
5980 def test_check_next():
5982 FileCheck().check(
"1").check_next(
"2").check_next(
"3").run(file)
5983 FileCheck().check_next(
"1").check_next(
"2").check_next(
"3").run(file)
5985 with self.assertRaisesRegex(RuntimeError,
"Expected to find"):
5986 FileCheck().check(
"1").check_next(
"2").run(
"12")
5988 with self.assertRaisesRegex(RuntimeError,
"Expected to not find"):
5989 FileCheck().check(
"1").check_next(
"2").run(
"1\n\n2")
5993 def test_check_dag():
5994 fc = FileCheck().check_dag(
"1").check_dag(
"2").check_not(
"2")
5999 fc.check_not(
"3").check_dag(
"1").check_dag(
"2").check_not(
"3")
6003 fc = FileCheck().check_dag(
"1").check_dag(
"2").check(
"3")
6004 with self.assertRaisesRegex(RuntimeError,
'Expected to find "3" but did not find it'):
6009 def test_check_not():
6010 FileCheck().check_not(
"2").check(
"1").run(
"12")
6011 FileCheck().check(
"2").check_not(
"2").run(
"12")
6013 with self.assertRaisesRegex(RuntimeError,
'Expected to not find "2"'):
6014 FileCheck().check_not(
"2").check(
"1").run(
"21")
6016 with self.assertRaisesRegex(RuntimeError,
'Expected to not find "1"'):
6017 FileCheck().check(
"2").check_not(
"1").run(
"21")
6020 fb = FileCheck().check_count(
"2", 2).check_count(
"2", 2).check_not(
"2")
6021 with self.assertRaisesRegex(RuntimeError,
'Expected to not find "2"'):
6024 fb = FileCheck().check_count(
"2", 2).check_not(
"1").check_count(
"2", 2)
6025 with self.assertRaisesRegex(RuntimeError,
'Expected to not find "1"'):
6028 def test_script_module_call_noscript(self):
6031 super(M, self).__init__(
False)
6035 return torch.ones(2, 2) + self.value
6037 @torch.jit.script_method
6038 def forward(self, input):
6039 return input + self.foo()
6042 input = torch.randn(2, 2)
6044 self.assertEqual(o, input + torch.ones(2, 2) + 1)
6049 self.assertEqual(o, input + torch.ones(2, 2) + 2)
6051 def test_script_module_nochange_submodule(self):
6054 super(M, self).__init__(
False)
6055 self.sub = nn.Linear(5, 5)
6057 @torch.jit.script_method
6058 def forward(self, input):
6059 return self.sub(input)
6062 input = torch.randn(1, 5, 5)
6064 self.assertEqual(o, m.sub(input))
6065 with self.assertRaisesRegex(RuntimeError,
"cannot re-assign"):
6066 m.sub = nn.Linear(5, 5)
6068 def test_script_inline_trace_multiple_args(self):
6071 super(M, self).__init__(
False)
6073 def forward(self, input, input2):
6074 return input + input2
6078 super(M2, self).__init__(
False)
6081 @torch.jit.script_method
6082 def forward(self, inp):
6083 return self.m(inp, inp)
6086 m2(torch.zeros(4, 3))
6088 def test_script_module_const(self):
6091 __constants__ = [
'b',
'i',
'c']
6094 super(M, self).__init__(
False)
6099 @torch.jit.script_method
6101 return self.b, self.i, self.c
6105 self.assertEqual(o0, 0)
6106 self.assertEqual(o1, 1)
6107 self.assertEqual(o2, 3.5)
6109 def test_script_module_fail_const(self):
6112 super(M, self).__init__(
False)
6115 @torch.jit.script_method
6118 with self.assertRaisesRegex(RuntimeError,
"is not usable in a script method"):
6121 def test_script_module_valid_consts(self):
6125 __constants__ = [
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i']
6128 super(Foo, self).__init__(
False)
6132 with tester.assertRaisesRegex(
6134 "'Linear' object for attribute 'd' is not a valid constant"):
6135 self.d = [nn.Linear(3, 4)]
6136 self.e =
lambda x: x
6138 tester.assertTrue(type(self.f)
is tuple)
6139 self.g = [3, (3, 4), 5]
6140 with tester.assertRaisesRegex(TypeError,
"not a valid constant"):
6142 with tester.assertRaisesRegex(TypeError,
"not a valid constant"):
6147 def test_script_module_param_buffer_mutation(self):
6151 super(ModuleBufferMutate, self).__init__(
False)
6152 self.register_buffer(
'running_var',
torch.tensor(0, dtype=torch.long))
6154 @torch.jit.script_method
6157 self.running_var += 1
6158 return self.running_var
6160 m = ModuleBufferMutate()
6161 self.assertEqual(m(), 1)
6163 self.assertEqual(m(), 1)
6165 def test_script_module_for(self):
6167 __constants__ = [
'b']
6170 super(M, self).__init__(
False)
6171 self.b = [1, 2, 3, 4]
6173 @torch.jit.script_method
6181 self.assertEqual(m(), 10)
6183 def test_script_module_for2(self):
6186 super(Sub, self).__init__(
False)
6187 self.weight = nn.Parameter(torch.randn(2))
6189 @torch.jit.script_method
6190 def forward(self, thing):
6191 return self.weight + thing
6194 __constants__ = [
'mods']
6197 super(M, self).__init__(
False)
6198 self.mods = nn.ModuleList([Sub()
for i
in range(10)])
6200 @torch.jit.script_method
6201 def forward(self, v):
6212 self.assertEqual(o, v)
6214 def test_script_module_const_submodule_fail(self):
6217 super(Sub, self).__init__(
False)
6218 self.weight = nn.Parameter(torch.randn(2))
6220 @torch.jit.script_method
6221 def forward(self, thing):
6222 return self.weight + thing
6226 super(M, self).__init__(
False)
6227 self.mods = [Sub()
for _
in range(10)]
6229 @torch.jit.script_method
6235 with self.assertRaisesRegex(RuntimeError,
"did you forget to add it __constants__"):
6241 self.tensor_constant = torch.ones(2)
6243 @torch.jit.script_method
6245 return self.tensor_constant + 2
6247 with self.assertRaisesRegex(RuntimeError,
"Tensors must be added to a module as a buffer or parameter"):
6253 self.
param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
6254 self.register_buffer(
'derived', torch.neg(self.
param).detach().clone())
6257 self.register_buffer(
'pack_called', torch.zeros(1, dtype=torch.long))
6259 self.register_buffer(
'unpack_called', torch.zeros(1, dtype=torch.long))
6261 @torch.jit.script_method
6263 self.pack_called.set_(torch.ones(1, dtype=torch.long))
6264 self.derived.set_(torch.rand(1, dtype=torch.float).detach())
6266 @torch.jit.script_method
6268 self.unpack_called.set_(torch.ones(1, dtype=torch.long))
6269 self.derived.set_(torch.neg(self.
param).detach())
6271 @torch.jit.script_method
6272 def forward(self, x):
6273 return x + self.derived
6275 def test_pack_unpack_state(self):
6277 x = torch.rand(3, 4, dtype=torch.float)
6281 self.assertFalse(sm.pack_called.item())
6282 self.assertFalse(sm.unpack_called.item())
6283 imported = self.getExportImportCopyWithPacking(sm)
6285 self.assertTrue(sm.pack_called.item())
6287 self.assertTrue(sm.unpack_called.item())
6292 self.assertTrue(imported.unpack_called.item())
6295 def test_pack_unpack_nested(self):
6298 super(SubSubMod, self).__init__()
6299 self.register_buffer(
'buf', torch.ones(3, 4) * 3)
6301 @torch.jit.script_method
6303 self.buf.set_(torch.zeros(1, dtype=torch.double))
6305 @torch.jit.script_method
6307 self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
6309 @torch.jit.script_method
6310 def forward(self, x):
6315 super(SubMod, self).__init__()
6316 self.register_buffer(
'buf', torch.ones(3, 4) * 2)
6317 self.ssm = SubSubMod()
6319 @torch.jit.script_method
6321 self.buf.set_(torch.zeros(1, dtype=torch.double))
6323 @torch.jit.script_method
6325 self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
6327 @torch.jit.script_method
6328 def forward(self, x):
6329 return self.ssm(x + self.buf)
6333 super(Mod, self).__init__()
6334 self.submod = SubMod()
6335 self.register_buffer(
'buf', torch.ones(3, 4) * 1)
6337 @torch.jit.script_method
6339 self.buf.set_(torch.zeros(1, dtype=torch.double))
6341 @torch.jit.script_method
6343 self.buf.set_(torch.ones(3, 4, dtype=torch.double))
6345 @torch.jit.script_method
6346 def forward(self, x):
6347 return self.submod(x + self.buf)
6351 m.apply(
lambda s: s._pack())
6353 m.apply(
lambda s: s._unpack())
6356 def test_script_module_not_tuple(self):
6358 __constants__ = [
'mods']
6361 super(M, self).__init__(
False)
6364 @torch.jit.script_method
6365 def forward(self, v):
6369 with self.assertRaisesRegex(RuntimeError,
"cannot be used as a tuple"):
6372 def test_script_module_list_sequential_error(self):
6374 def __init__(self, mod_list):
6375 super(M, self).__init__(
False)
6376 self.mods = mod_list
6378 @torch.jit.script_method
6379 def forward(self, v):
6384 with self.assertRaisesRegex(RuntimeError,
"Did you forget to add it to __constants"):
6385 a = M(nn.Sequential(nn.ReLU()))
6386 with self.assertRaisesRegex(RuntimeError,
"Did you forget to add it to __constants"):
6387 a = M(nn.ModuleList([nn.ReLU()]))
6389 def test_script_sequential_for(self):
6392 super(Sub, self).__init__(
False)
6393 self.weight = nn.Parameter(torch.randn(2))
6395 @torch.jit.script_method
6396 def forward(self, thing):
6397 return self.weight + thing
6400 __constants__ = [
'mods']
6403 super(M, self).__init__(
False)
6404 self.mods = nn.Sequential(Sub(), Sub(), Sub())
6406 @torch.jit.script_method
6407 def forward(self, v):
6412 @torch.jit.script_method
6413 def forward2(self, v):
6422 self.assertEqual(o, v)
6425 self.assertEqual(o2, v)
6427 def test_script_sequential_multi_output_fail(self):
6430 super(Sub, self).__init__(
False)
6431 self.weight = nn.Parameter(torch.randn(2))
6433 @torch.jit.script_method
6434 def forward(self, thing):
6435 return self.weight + thing
6439 super(ReturnMulti, self).__init__(
False)
6441 @torch.jit.script_method
6442 def forward(self, x):
6446 __constants__ = [
'someseq']
6449 super(HaveSequential, self).__init__(
False)
6450 self.someseq = nn.Sequential(
6456 @torch.jit.script_method
6457 def forward(self, x):
6458 return self.someseq(x)
6460 with self.assertRaisesRegex(RuntimeError,
"(Tensor, Tensor, Tensor)"):
6461 hs = HaveSequential()
6465 def test_constant_insert_fail_lint(self):
6473 self.run_pass(
'constant_propagation', foo.graph)
6474 self.assertTrue(
"aten::tensor" in str(foo.graph))
6476 def test_script_sequential_in_mod_list(self):
6479 super(Sub, self).__init__(
False)
6480 self.weight = nn.Parameter(torch.randn(2))
6482 @torch.jit.script_method
6483 def forward(self, thing):
6484 return self.weight + thing
6487 __constants__ = [
'mods']
6490 super(M, self).__init__(
False)
6491 self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])
6493 @torch.jit.script_method
6494 def forward(self, v):
6495 for mod
in self.mods:
6500 graph = str(m.graph)
6501 self.assertTrue(graph.count(
"aten::add") == 5)
6502 self.assertTrue(
"python" not in graph)
6504 def test_script_nested_mod_list(self):
6507 super(Sub, self).__init__(
False)
6508 self.weight = nn.Parameter(torch.randn(2))
6510 @torch.jit.script_method
6511 def forward(self, thing):
6512 return self.weight + thing
6515 __constants__ = [
'mods']
6518 super(M, self).__init__(
False)
6519 self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])
6521 @torch.jit.script_method
6522 def forward(self, v):
6523 for mod
in self.mods:
6529 graph = str(m.graph)
6530 self.assertTrue(graph.count(
"aten::add") == 4)
6531 self.assertTrue(
"python" not in graph)
6533 def test_constant_as_attr(self):
6535 __constants__ = [
'dim']
6538 super(M, self).__init__(
False)
6541 @torch.jit.script_method
6542 def forward(self, v):
6543 return torch.cat([v, v, v], dim=self.dim)
6544 v = torch.zeros(1, 1)
6545 self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
6551 def forward(self, *inputs):
6553 for i
in range(1, len(inputs)):
6561 def forward(self, rep):
6562 return rep, rep, rep
6564 def test_script_star_expr(self):
6568 super(M2, self).__init__(
True)
6570 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
6573 @torch.jit.script_method
6574 def forward(self, rep):
6579 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6581 def test_script_star_expr_string(self):
6584 super(M2, self).__init__(
True)
6586 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
6590 def forward(self, rep): 6596 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6602 def forward(self, *inputs):
6604 for i
in range(1, len(inputs)):
6606 return output, output, output
6608 def test_script_star_assign(self):
6611 super(M2, self).__init__(
True)
6614 def forward(self, rep): 6615 head, *tail = self.g(rep) 6620 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6622 def test_script_module_star_assign2(self):
6625 super(M2, self).__init__(
True)
6628 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
6629 _force_outplace=
True)
6631 def forward(self, rep): 6632 *head, tail = self.g(rep, rep, rep) 6637 self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
6639 def test_script_module_star_assign2_inplace(self):
6642 super(M2, self).__init__(
True)
6645 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
6646 _force_outplace=
False)
6648 def forward(self, rep): 6649 *head, tail = self.g(rep, rep, rep) 6657 self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
6659 def test_script_module_star_assign_fail_pythonop(self):
6661 with self.assertRaisesRegex(RuntimeError,
"cannot be used as a tuple"):
6664 super(M2, self).__init__(
True)
6667 return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
6670 def forward(self, rep): 6676 m(torch.zeros(4, 3))
6678 def test_script_module_star_assign_fail_builtin(self):
6679 with self.assertRaisesRegex(RuntimeError,
"cannot be used as a tuple"):
6682 super(M2, self).__init__(
True)
6685 def forward(self, rep): 6686 a, *b = torch.neg(rep) 6691 m(torch.zeros(4, 3))
6693 def test_pack_padded_pad_packed_trace(self):
6697 class PadPackedWrapper(torch.nn.Module):
6699 super(PadPackedWrapper, self).__init__()
6701 def forward(self, x, seq_lens):
6702 x = pack_padded_sequence(x, seq_lens)
6703 x, _ = pad_packed_sequence(x)
6706 x = np.ones((T, B, C))
6707 seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
6711 x[seq_lens[b]:, b, :] = 0
6712 seq_lens = torch.from_numpy(seq_lens)
6713 x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=
True)
6715 m = PadPackedWrapper()
6721 grad = x.grad.clone()
6724 y_traced = m_traced(x, seq_lens)
6725 loss_traced = torch.sum(y_traced)
6726 loss_traced.backward()
6727 grad_traced = x.grad.clone()
6729 self.assertEqual(y_traced, x)
6730 self.assertEqual(y_traced, y)
6731 self.assertEqual(grad, grad_traced)
6736 def test_script_outputs(self):
6737 with self.assertRaisesRegex(RuntimeError,
"cannot be used as a tuple"):
6747 with self.assertRaisesRegex(RuntimeError,
"too many values to unpack"):
6754 @unittest.skipIf(
not RUN_CUDA,
"requires CUDA")
6755 def test_script_get_device_cuda(self):
6758 return a.get_device()
6760 v = torch.randn(1, device=
'cuda')
6761 self.assertEqual(foo(v), 0)
6763 def test_script_chunk(self):
6766 b, c = torch.chunk(a, dim=0, chunks=2)
6768 v = torch.rand(10, 3)
6769 self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
6771 def test_rnn_trace_override(self):
6776 class RNNTraceWrapper(torch.nn.Module):
6777 def __init__(self, cell_type):
6778 super(RNNTraceWrapper, self).__init__()
6779 if cell_type ==
'RNN':
6780 self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
6781 elif cell_type ==
'LSTM':
6782 self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
6783 elif cell_type ==
'GRU':
6784 self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)
6786 def forward(self, x, seq_lens):
6787 x = pack_padded_sequence(x, seq_lens)
6789 x, _ = pad_packed_sequence(x)
6792 for cell_type
in [
'RNN',
'LSTM',
'GRU']:
6793 x = torch.ones(T, B, C, requires_grad=
True)
6794 seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
6796 m = RNNTraceWrapper(cell_type)
6802 grad = x.grad.clone()
6805 y_traced = m_traced(x, seq_lens)
6806 loss_traced = torch.sum(y_traced)
6807 loss_traced.backward()
6808 grad_traced = x.grad.clone()
6810 self.assertEqual(y_traced, y)
6811 self.assertEqual(grad, grad_traced)
6816 def test_python_call_non_tensor(self):
6824 x = torch.ones(3, 4)
6825 a, b = foo(x, 3, (x, 3))
6828 self.assertEqual((6, torch.ones(3, 4) + 1), bar())
6830 def test_python_call_non_tensor_wrong(self):
6831 with self.assertRaisesRegex(RuntimeError,
r"but instead got value of type tuple"):
6842 def test_tuples(self):
6858 v = torch.rand(10, 3)
6859 self.checkScript(foo, (v,))
6861 with self.assertRaisesRegex(RuntimeError,
r"variable 'a' previously has type \(Tensor, Tensor\)"):
6868 def test_if_tuple_sizes(self):
6869 with self.assertRaisesRegex(RuntimeError,
"Type mismatch"):
6871 def diff_tuple_sizes(x):
6873 c0 = ((x, x), (x, x, x))