1 from __future__
import division
10 from contextlib
import contextmanager
11 from itertools
import product, chain
19 from common_utils
import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
20 skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
21 freeze_rng_state, set_rng_seed
22 from common_nn
import module_tests, new_module_tests, criterion_tests
23 from textwrap
import dedent
24 from functools
import wraps
41 from common_methods_invocations
import method_tests
as autograd_method_tests
42 from common_methods_invocations
import create_input, unpack_variables, \
43 exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
45 from torch._C import TensorType, TupleType, FloatType, IntType, \
46 ListType, StringType, DictType
47 from copy
import deepcopy
49 from typing
import List, Dict, Optional, Tuple
52 from torch
import Tensor
62 load_tests = load_tests
66 HAS_TORCHVISION =
True 68 HAS_TORCHVISION =
False 71 skipIfNoTorchVision = unittest.skipIf(
not HAS_TORCHVISION,
"no torchvision")
74 RUN_CUDA_HALF = RUN_CUDA
76 CUDA_VERSION = torch._C._cuda_getCompiledVersion()
79 if (CUDA_VERSION < 8000
and major >= 6)
or (CUDA_VERSION < 9000
and major >= 7):
81 if (CUDA_VERSION < 9000
or major < 6):
86 PY35 = sys.version_info >= (3, 5)
87 WINDOWS = sys.platform ==
'win32' 92 def TemporaryFileName():
96 f = tempfile.NamedTemporaryFile(delete=
False)
104 def TemporaryFileName():
105 with tempfile.NamedTemporaryFile()
as f:
109 def LSTMCellF(input, hx, cx, *params):
110 return LSTMCell(input, (hx, cx), *params)
113 def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
115 gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
117 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
118 ingate = torch.sigmoid(ingate)
119 forgetgate = torch.sigmoid(forgetgate)
120 cellgate = torch.tanh(cellgate)
121 outgate = torch.sigmoid(outgate)
123 cy = (forgetgate * cx) + (ingate * cellgate)
124 hy = outgate * torch.tanh(cy)
128 def LSTMCellC(*args, **kwargs):
129 hy, cy = LSTMCellF(*args, **kwargs)
130 return torch.cat((hy, cy))
133 def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
134 gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
135 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
136 ingate = torch.sigmoid(ingate)
137 forgetgate = torch.sigmoid(forgetgate)
138 cellgate = torch.tanh(cellgate)
139 outgate = torch.sigmoid(outgate)
140 cy = (forgetgate * cx) + (ingate * cellgate)
141 hy = outgate * torch.tanh(cy)
146 def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
150 gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias
152 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
153 ingate = ingate.sigmoid()
154 forgetgate = forgetgate.sigmoid()
155 cellgate = cellgate.tanh()
156 outgate = outgate.sigmoid()
157 cy = (forgetgate * cx) + (ingate * cellgate)
158 hy = outgate * cy.tanh()
162 def canonical(graph):
163 return str(torch._C._jit_pass_canonicalize(graph))
166 def get_lstm_inputs(device, training=False, seq_length=None):
167 input_shape = (3, 10)
if seq_length
is None else (seq_length, 3, 10)
168 input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training)
169 hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
170 cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
171 module = nn.LSTMCell(10, 20).to(device, torch.float)
173 params = tuple(module.parameters())
175 params = tuple(p.requires_grad_(
False)
for p
in module.parameters())
176 return (input, hx, cx) + params
179 def get_milstm_inputs(device, training=False):
183 x = torch.randn(minibatch, input_size, device=device, dtype=torch.float)
184 hx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
185 cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
187 ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training)
188 hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training)
189 alpha = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
190 ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
191 hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
192 bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
193 return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias
196 def get_fn(file_name, script_path):
197 import importlib.util
198 spec = importlib.util.spec_from_file_location(file_name, script_path)
199 module = importlib.util.module_from_spec(spec)
200 spec.loader.exec_module(module)
205 def get_execution_plan(graph_executor_state):
206 execution_plans = list(graph_executor_state.execution_plans.values())
207 num_plans = len(execution_plans)
209 raise RuntimeError(
'This test assumes this GraphExecutor should ' 210 'only have one execution plan, got: {}'.format(num_plans))
211 return execution_plans[0]
214 def get_grad_executor(plan_state, diff_graph_idx=None):
215 if diff_graph_idx
is None:
216 nodes = list(plan_state.graph.nodes())
217 if len(nodes) == 1
or (len(nodes) == 2
and nodes[1].kind() ==
"prim::TupleConstruct"):
220 raise RuntimeError(
"Can't get a grad_executor for a non-differentiable graph")
221 grad_executors = list(plan_state.code.grad_executors())
222 return grad_executors[diff_graph_idx
or 0]
225 def backward_graph(script_module, diff_graph_idx=None):
227 raise RuntimeError(
'Expected ScriptModule')
228 ge_state = script_module.get_debug_state()
229 fwd_plan = get_execution_plan(ge_state)
230 grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
231 bwd_plan = get_execution_plan(grad_executor.get_debug_state())
234 return bwd_plan.graph.copy()
238 def _trace(*args, **kwargs):
244 def enable_cpu_fuser(fn):
245 def wrapper(*args, **kwargs):
246 torch._C._jit_override_can_fuse_on_cpu(
True)
250 torch._C._jit_override_can_fuse_on_cpu(
False)
255 _do_cuda_memory_leak_check =
True 256 _restored_warnings =
False 262 if not JitTestCase._restored_warnings:
264 JitTestCase._restored_warnings =
True 270 torch._C._jit_set_emit_module_hook(
None)
271 torch._C._jit_clear_class_registry()
274 def disableModuleHook(self):
275 torch._C._jit_set_emit_module_hook(
None)
279 def emitModuleHook(self, module):
280 def copy_structure_and_params(m):
282 for name, v
in m._get_parameters():
283 c._register_parameter(name, v,
False)
284 for name, the_type, v
in m._get_attributes():
285 c._register_attribute(name, the_type, v)
286 for name, s
in m._get_modules():
287 c._register_module(name, copy_structure_and_params(s))
293 pp, constant_table = module._python_print()
294 except RuntimeError
as e:
296 if "could not export python function" not in se
and \
297 "closures are not exportable" not in se:
301 ppv =
"op_version_set = 0\n{}".format(pp)
302 sm = copy_structure_and_params(module)
303 torch._C._jit_import_methods(sm, ppv, constant_table)
304 pp2, _ = sm._python_print()
306 self.assertMultiLineEqual(pp, pp2)
308 def getExportImportCopy(self, m, also_test_file=True, map_location=None):
309 buffer = io.BytesIO()
314 if not also_test_file:
317 with TemporaryFileName()
as fname:
321 def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
322 buffer = io.BytesIO()
323 m.apply(
lambda s: s._pack()
if s._has_method(
'_pack')
else None)
325 m.apply(
lambda s: s._unpack()
if s._has_method(
'_unpack')
else None)
328 imported.apply(
lambda s: s._unpack()
if s._has_method(
'_unpack')
else None)
330 if not also_test_file:
336 f = tempfile.NamedTemporaryFile(delete=
False)
339 imported.save(f.name)
344 result.apply(
lambda s: s._unpack()
if s._has_method(
'_unpack')
else None)
347 def assertGraphContains(self, graph, kind):
348 self.assertTrue(any(n.kind() == kind
for n
in graph.nodes()))
350 def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
351 def perform_assert(graph, kind, actual, expected, consider_subgraphs):
352 if actual == expected:
354 subgraph =
'including' if consider_subgraphs
else 'excluding' 355 raise AssertionError(
356 '{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format(
357 graph, actual, kind, subgraph, expected))
359 if consider_subgraphs:
360 strgraph = str(graph)
361 count = strgraph.count(kind) - strgraph.count(
'with {}'.format(kind))
362 perform_assert(graph, kind, count, num_kind_nodes,
366 nodes = [node
for node
in graph.nodes()
367 if node.kind() == kind]
368 perform_assert(graph, kind, len(nodes), num_kind_nodes,
371 def assertExpectedONNXGraph(self, trace, *args, **kwargs):
375 def assertExpectedGraph(self, trace, *args, **kwargs):
376 if isinstance(trace, torch._C.Graph):
379 graph = trace.graph()
381 torch._C._jit_pass_lint(graph)
382 torch._C._jit_pass_dce(graph)
383 torch._C._jit_pass_lint(graph)
384 graph = torch._C._jit_pass_canonicalize(graph)
385 torch._C._jit_pass_lint(graph)
388 def run_pass(self, name, trace):
389 if isinstance(trace, torch._C.Graph):
394 graph = trace.graph()
396 torch._C._jit_pass_lint(graph)
397 result = getattr(
torch._C,
'_jit_pass_' + name)(graph)
398 if result
is not None:
400 torch._C._jit_pass_lint(graph)
403 trace.set_graph(graph)
406 def checkScript(self,
412 capture_output=
False,
414 check_expected=
False):
415 if isinstance(script, str):
417 ge = getattr(cu, name)
420 with self.capture_stdout()
as captured:
421 outputs = script(*inputs)
423 outputs = script(*inputs)
425 source = textwrap.dedent(inspect.getsource(script))
434 check_expected=check_expected)
439 with self.capture_stdout()
as captured:
440 outputs_ge = ge(*inputs)
444 outputs_ge = ge(*inputs)
452 def checkTrace(self, func, reference_tensors, input_tensors=None,
453 optimize=
True, drop=
None, allow_unused=
False, verbose=
False,
454 inputs_require_grads=
True, check_tolerance=1e-5, export_import=
True,
455 _force_outplace=
False):
464 return sum(math.log(i + 2) * v.sum()
for i, v
in enumerate(vs)
if v
is not None)
465 if input_tensors
is None:
466 input_tensors = reference_tensors
468 nograd_inputs = reference_tensors
469 if inputs_require_grads:
470 recording_inputs = [t.clone().requires_grad_()
for t
in reference_tensors]
472 recording_inputs = reference_tensors
474 if isinstance(func, torch._C.Graph):
475 ge = torch._C.GraphExecutor(func, optimize)
477 ge =
torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
478 _force_outplace=_force_outplace)
487 outputs = func(*nograd_inputs)
488 outputs_ge = ge(*nograd_inputs)
492 outputs = func(*recording_inputs)
493 if inputs_require_grads:
495 allow_unused=allow_unused)
497 outputs_ge = ge(*recording_inputs)
498 if inputs_require_grads:
500 allow_unused=allow_unused)
502 if inputs_require_grads:
507 outputs = func(*recording_inputs)
509 if inputs_require_grads:
511 allow_unused=allow_unused)
512 if inputs_require_grads:
513 l2 = (allSum(grads) * l1)
516 if inputs_require_grads:
517 recording_inputs = [Variable(t, requires_grad=
True)
518 for t
in reference_tensors]
520 outputs_ge = ge(*recording_inputs)
521 l1_ge = allSum(outputs_ge)
522 if inputs_require_grads:
524 l1_ge, recording_inputs, create_graph=
True, allow_unused=allow_unused)
526 if inputs_require_grads:
527 l2_ge = (allSum(grads_ge) * l1_ge)
531 if inputs_require_grads:
533 for g2, g2_ge
in zip(grads2, grads2_ge):
534 if g2
is None and g2_ge
is None:
536 self.assertTrue(torch.allclose(g2, g2_ge, atol=8e-4, rtol=8e-4))
540 def createScriptModuleFromGraph(self, trace):
541 graph = trace
if isinstance(trace, torch._C.Graph)
else trace.graph()
543 m._create_method_from_graph(
"forward", graph)
546 def assertExportImport(self, trace, inputs):
550 def assertExportImportModule(self, m, inputs):
555 def runAndSaveRNG(self, func, inputs, kwargs=None):
556 kwargs = kwargs
if kwargs
else {}
557 with freeze_rng_state():
558 results = func(*inputs, **kwargs)
565 super(FooToPickle, self).__init__()
571 @unittest.skip(
"Requires a lot of RAM")
574 gig = int(1024 * 1024 * 1024 / 4)
576 m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float))
578 m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float))
580 m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float))
582 m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float))
586 self.
assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
588 def test_simple(self):
593 return torch.sigmoid(torch.tanh(x * (x + y)))
597 def test_restore_device(self):
600 cpu_device_str =
'cpu' 601 m.p0 = nn.Parameter(
torch.tensor([0.3], dtype=torch.float,
602 device=cpu_device_str))
603 m.register_buffer(
'b0',
torch.tensor([0.9], dtype=torch.float,
604 device=cpu_device_str))
606 self.
assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
607 self.
assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
608 self.assertFalse(m2.p0.is_cuda)
609 self.assertFalse(m2.b0.is_cuda)
611 def test_model_save_error(self):
612 with TemporaryFileName()
as fname:
616 def test_single_tuple_trace(self):
622 assert f2(x) == jit_f2(x)
624 @unittest.skipIf(
not RUN_CUDA,
"restore device requires CUDA")
625 def test_restore_device_cuda(self):
628 super(MyModule, self).__init__(
False)
629 self.register_buffer(
'b0', torch.randn(1, 3))
630 self.
p0 = nn.Parameter(torch.randn(2, 3))
632 @torch.jit.script_method
633 def forward(self, x):
634 return x + self.b0 + self.
p0 640 self.assertTrue(m.p0.is_cuda)
641 self.assertTrue(m.b0.is_cuda)
645 self.
assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
646 self.
assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
647 self.
assertEqual(str(m2.p0.device), cuda_device_str)
648 self.
assertEqual(str(m2.b0.device), cuda_device_str)
651 cpu_device_str =
'cpu' 653 self.
assertEqual(str(m3.p0.device), cpu_device_str)
654 self.
assertEqual(str(m3.b0.device), cpu_device_str)
658 m3, map_location=torch.device(
'cuda:0'))
664 origin_result = m(input)
669 @unittest.skipIf(
not RUN_CUDA,
"restore device requires CUDA")
670 def test_restore_shared_storage_on_cuda(self):
671 whole_tensor = torch.randn(4, 5, dtype=torch.float, device=
'cpu')
673 m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
674 m.register_buffer(
'b0', whole_tensor.narrow(0, 3, 1))
676 self.
assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
677 self.
assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
678 self.assertTrue(m2.p0.is_cuda)
679 self.assertTrue(m2.b0.is_cuda)
680 self.assertTrue(m2.p0.is_shared())
681 self.assertTrue(m2.b0.is_shared())
682 self.
assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())
684 def test_typeas_trace_check(self):
693 def test_peephole(self):
702 FileCheck().check(
"type_as").run(str(tf.graph))
704 FileCheck().check_not(
"type_as").run(str(tf.graph))
707 self.
run_pass(
'peephole', tf2.graph)
710 def test_peephole_dynamic(self):
716 torch._C._jit_pass_peephole(fn.graph)
719 @unittest.skipIf(
not RUN_CUDA,
"cpp tests require CUDA")
720 def test_peephole_cuda(self):
730 self.
run_pass(
'peephole', trace.graph)
733 self.
run_pass(
'peephole', trace.graph)
734 self.assertTrue(len(list(trace.graph.nodes())) == 0)
736 def test_index(self):
747 def test_disabled(self):
748 torch.jit._enabled =
False 753 self.assertIs(
torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f)
757 @torch.jit.script_method
765 self.assertTrue(inspect.ismethod(MyModule.method)
or inspect.isfunction(MyModule.method))
767 torch.jit._enabled =
True 769 def test_train_eval(self):
770 class Sub(nn.Module):
771 def forward(self, input):
778 def __init__(self, module):
779 super(MyModule, self).__init__()
782 @torch.jit.script_method
783 def forward(self, input):
784 return self.
module(input) + 1
787 input = torch.rand(3, 4)
793 input = torch.randn(6, 10)
794 batchnorm = nn.BatchNorm1d(10)
795 dropout = nn.Dropout(p=0.2)
797 m_batchnorm = MyModule(batchnorm)
798 self.
assertEqual(batchnorm(input) + 1, m_batchnorm(input))
801 self.
assertEqual(batchnorm(input) + 1, m_batchnorm(input))
803 m_dropout = MyModule(dropout)
808 def test_diff_subgraph_clones_constants(self):
811 return x + x + y + x + y + x + y + x + y + x
813 def count_constants(graph):
814 return sum(node.kind() ==
'prim::Constant' for node
in graph.nodes())
816 graph = f.graph.copy()
818 self.
run_pass(
'create_autodiff_subgraphs', graph)
819 nodes = list(graph.nodes())
821 self.
assertEqual(count_constants(nodes[1].g(
'Subgraph')), 1)
828 def test_index_constant(self):
842 def test_scopes(self):
851 out = torch.tanh(out)
852 out = torch.sigmoid(out)
857 def test_scopes_intermediate_node(self):
858 class Net(nn.Module):
859 def forward(self, x):
860 return F.log_softmax(x, dim=0)
863 t = torch.ones(2, requires_grad=
True)
868 FileCheck().check(
"onnx::LogSoftmax").check(
"scope: Net").run(str(trace))
870 def test_scopes_identity_node(self):
872 class Net(nn.Module):
875 super(Net, self).__init__()
877 nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
878 nn.ReLU(inplace=
True),
879 nn.MaxPool2d(kernel_size=3, stride=2),
882 def forward(self, x):
888 t = torch.ones(1, 3, 227, 227, requires_grad=
True)
895 FileCheck().check(
"Net/Sequential[features]/Conv2d[0]").check(
"ReLU").check(
"MaxPool").run(str(trace))
897 def test_canonicalize_tensor_iterator(self):
898 x = torch.randn(4, 4)
909 graph = traced.graph_for(x)
912 self.assertTrue(str(traced.graph_for(x)).count(
': int = prim::Constant') == 5)
915 @unittest.skip(
"Need to be adjusted to Graph Executor")
917 """Different arg configurations should trigger different traces""" 918 x = Variable(torch.FloatTensor(4, 4).uniform_())
919 x_double = Variable(x.data.double())
920 x_grad = Variable(x.data.clone(), requires_grad=
True)
921 y = Variable(torch.randn(4))
932 x_cuda = Variable(x.data.cuda())
940 x_cuda_1 = Variable(x.data.cuda(1))
943 ([x_cuda, x_cuda_1],),
946 @torch.jit.compile(nderivs=0)
948 in_vars, _ = torch._C._jit_flatten(args)
949 return in_vars[0] + 1
951 for i, config
in enumerate(configurations):
952 self.assertFalse(fn.has_trace_for(*config))
954 self.assertTrue(fn.has_trace_for(*config))
955 for unk_config
in configurations[i + 1:]:
956 self.assertFalse(fn.has_trace_for(*unk_config))
964 w = (x + y) * (x + y) * (x + y)
965 t = torch.tanh(w) + torch.tanh(w)
966 z = (x + y) * (x + y) * (x + y) + t
972 FileCheck().check_count(
"add", 1).check_count(
"mul", 2, do_exactly) \
973 .check_count(
"tanh", 1, do_exactly).check_count(
"add", 2, do_exactly).check_next(
"return") \
978 def test_recursive_cse(self):
990 FileCheck().check(
"block").check_not(
"aten::add").check_not(
"aten::gt").run(str(graph))
992 def test_shape_analysis_broadcast(self):
996 x = torch.randn(3, 1, 5, requires_grad=
True)
997 y = torch.randn(4, 1, 8, 5, requires_grad=
True)
1000 torch._C._jit_pass_complete_shape_analysis(graph, (x, y),
False)
1001 FileCheck().check(
"Double(4, 3, 8, 5)").run(str(graph))
1004 @unittest.skip(
"verify needs to be updated to work with GraphExecutors")
1011 z = torch.sigmoid(x * (x + y))
1012 w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
1018 def test_constant(self):
1019 x = torch.randn(2, 2, requires_grad=
True)
1024 self.
checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=
True),))
1026 def test_legacy_fail(self):
1027 class MyLegacyFn(Function):
1028 def forward(self, x):
1031 def backward(self, grad_output):
1038 def test_inplace_transplant(self):
1049 FileCheck().check_count(
"aten::clone", 1, exactly=
True) \
1050 .check_count(
"aten::add_", 2, exactly=
True) \
1051 .check_next(
"return").run(str(trace))
1054 def test_inplace_flags(self):
1055 class InplaceFn(Function):
1057 def forward(ctx, x):
1062 def backward(ctx, go):
1065 class RegularFn(Function):
1067 def forward(ctx, x):
1071 def backward(ctx, go):
1077 y = RegularFn.apply(x)
1078 y = InplaceFn.apply(y)
1079 y = InplaceFn.apply(y)
1080 y = RegularFn.apply(y)
1085 ops = [n
for n
in trace.graph().nodes()]
1087 self.assertTrue(op.hasAttribute(
'inplace'))
1088 inplace_flags = [
False,
True,
True,
False]
1089 for op, is_inplace
in zip(ops, inplace_flags):
1092 def test_inplace_check(self):
1093 class MyInplaceFn(Function):
1095 def forward(self, x):
1101 def backward(self, grad):
1105 return MyInplaceFn.apply(x)
1107 x = torch.randn(5, 5)
1108 ge = torch._C.GraphExecutor(fn, (x,),
lambda var:
'', _force_outplace=
True)
1112 def do_trace_size(self, requires_grad):
1114 return x.view(x.shape[1] * 2, x.size(0), 2)
1116 x = torch.randn(5, 2, 4, requires_grad=requires_grad)
1117 y = torch.randn(4, 8, 4, requires_grad=requires_grad)
1124 def test_trace_size(self):
1129 def test_trace_size_with_grad(self):
1132 def test_trace_casts(self):
1135 lambda x: x.float(),
1137 lambda x: x.to(device=
'cpu'),
1138 lambda x: x.to(dtype=torch.int64),
1139 lambda x: x.to(device=
'cpu', dtype=torch.float),
1143 def assertContainsCast(trace):
1144 self.
assertEqual(sum(n.kind() ==
'aten::to' for n
in trace.graph.nodes()), 1)
1148 assertContainsCast(trace)
1149 x = torch.randn(2, 2)
1152 def to_tensor(x, y):
1155 to_tensor_trace =
torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
1156 assertContainsCast(to_tensor_trace)
1157 x, y = torch.randn(2, 2), torch.randn(1, 10)
1158 self.
assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
1160 def test_trace_warn(self):
1171 for _
in torch.ones(4, 4):
1175 with warnings.catch_warnings(record=
True)
as warns:
1177 warns = [str(w.message)
for w
in warns]
1179 self.assertIn(
'a Python integer', warns[0])
1180 self.assertIn(
'a Python boolean', warns[1])
1181 self.assertIn(
'a Python index', warns[2])
1182 self.assertIn(
'a Python float', warns[3])
1183 self.assertIn(
'a Python list', warns[4])
1184 self.assertIn(
'a NumPy array', warns[5])
1185 self.assertIn(
'Iterating over', warns[6])
1187 def test_trace_tuple(self):
1189 return x, (x * y[1], x * y[0])
1191 x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
1195 FileCheck().check_count(
"prim::TupleConstruct", 2, exactly=
True).check_next(
"return") \
1196 .run(str(traced_fn.graph))
1199 def test_trace_random(self):
1201 return torch.normal(mean, std)
1203 traced =
torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=
False)
1204 mean, std = torch.zeros(5, 5), torch.ones(5, 5)
1206 output = f(mean, std)
1207 traced_output = traced(mean, std)
1210 def test_trace_tensor_factory(self):
1212 inputs_require_grads = kwargs.pop(
'inputs_require_grads',
True)
1215 return x + torch.ones(2, 3, **kwargs)
1217 input_kwargs = kwargs.copy()
1218 if 'out' in input_kwargs:
1219 del input_kwargs[
'out']
1220 input = torch.ones(2, 3, **input_kwargs)
1221 self.
checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
1224 self.assertTrue(
"ones" in str(tfn.graph))
1226 run(dtype=torch.int, inputs_require_grads=
False)
1229 run(device=
"cuda:0")
1230 if RUN_CUDA_MULTI_GPU:
1231 run(device=
"cuda:1")
1233 def test_trace_indexed_assignment(self):
1238 example = torch.rand(3, 4)
1239 self.
checkTrace(stuff, (example, example[0] + 1))
1242 @unittest.expectedFailure
1244 """Check that outputs of traced functions retain the original structure and nesting""" 1246 return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
1251 @unittest.expectedFailure
1253 """Check that inputs to traced functions are flattened""" 1259 inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
1263 @unittest.skip(
"Need to instrument GraphExecutors a bit more")
1264 def test_flags(self):
1265 x, y = torch.randn(2, 2)
1266 y = Variable(torch.randn(2, 2))
1270 return (x * x + y * y + x * y).sum()
1273 for rx, ry
in product((
True,
False), repeat=2):
1274 x.requires_grad = rx
1275 y.requires_grad = ry
1277 self.assertFalse(fn.has_trace_for(x, y))
1280 self.assertFalse(fn.has_trace_for(x, y))
1281 for v, name, compute
in [(x,
'x', rx), (y,
'y', ry)]:
1285 expected_grad = grads.setdefault(name, grad_v)
1287 self.
assertEqual(fn.has_trace_for(x, y), rx
or ry)
1289 def test_python_ir(self):
1294 return torch.sigmoid(torch.tanh(x * (x + y)))
1298 self.
run_pass(
'canonicalize', trace)
1300 g2 = torch._C.Graph()
1302 for node
in g.inputs():
1303 g_to_g2[node] = g2.addInput()
1304 for node
in g.nodes():
1305 n_ = g2.createClone(node,
lambda x: g_to_g2[x])
1307 for o, no
in zip(node.outputs(), n_.outputs()):
1310 for node
in g.outputs():
1311 g2.registerOutput(g_to_g2[node])
1313 t_node = g2.create(
"prim::TensorTest").t_(
"a", torch.ones([2, 2]))
1315 g2.appendNode(t_node)
1316 self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t(
"a")))
1317 for node
in g.nodes():
1318 self.assertTrue(g2.findNode(node.kind())
is not None)
1320 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
1321 @unittest.skipIf(
not RUN_CUDA,
"cpp tests require CUDA")
1323 def test_cpp_cuda(self):
1324 from cpp.jit import tests_setup
1326 torch._C._jit_run_cpp_tests()
1327 tests_setup.shutdown()
1329 def test_batchnorm(self):
1330 x = torch.ones(2, 2, 2, 2)
1332 _force_outplace=
True, return_inputs=
True)
1336 def test_dropout(self):
1337 x = torch.ones(2, 2)
1344 @unittest.skipIf(
not RUN_CUDA,
"test_dropout_cuda require CUDA")
1345 def test_dropout_cuda(self):
1348 x = torch.ones(4, 4).cuda().requires_grad_()
1354 with freeze_rng_state():
1358 with freeze_rng_state():
1365 def test_conv(self):
1366 x = torch.ones(20, 16, 50, 40)
1371 def test_repeated_input(self):
1375 ge = self.
checkTrace(fn, [torch.randn(2, 2)] * 2)
1376 inputs = set(ge.graph.inputs())
1377 self.assertTrue(len(inputs) == 2)
1379 def test_repeated_output(self):
1384 ge = self.
checkTrace(fn, [torch.randn(2, 2)
for _
in range(2)])
1385 tuple_output = list(ge.graph.outputs())[0]
1386 tuple_inputs = list(tuple_output.node().inputs())
1387 self.assertTrue(tuple_inputs[0] == tuple_inputs[1])
1389 @skipIfNoTorchVision
1390 def test_alexnet(self):
1391 x = torch.ones(1, 3, 224, 224)
1392 model = torchvision.models.AlexNet()
1400 def test_inplace_copy(self):
1401 x = torch.randn(4, 4, requires_grad=
True)
1404 out = Variable(torch.zeros(x.size()))
1414 def test_shared_param(self):
1415 class MyModule(torch.nn.Module):
1417 super(MyModule, self).__init__()
1418 self.
b = self.
a = nn.Parameter(torch.randn(2, 2))
1420 def forward(self, x):
1421 return x * self.
a + self.
b 1426 self.
assertEqual(len(list(trace.graph().inputs())), 2)
1427 FileCheck().check(
"mul").check(
"add").run(str(trace))
1429 def test_trace_c10_ops(self):
1430 class MyModel(torch.nn.Module):
1432 super(MyModel, self).__init__()
1434 def forward(self, scores, bbox_deltas, im_info, anchors):
1435 a, b = torch.ops._caffe2.GenerateProposals(
1436 (scores), (bbox_deltas), (im_info), (anchors),
1437 2.0, 6000, 300, 0.7, 16,
True, -90, 90, 1.0,
1445 scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
1446 bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
1447 dtype=torch.float32)
1448 bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
1449 im_info = torch.ones(img_count, 3, dtype=torch.float32)
1450 anchors = torch.ones(A, 4, dtype=torch.float32)
1451 inputs = (scores, bbox_deltas, im_info, anchors)
1456 def test_nested_inplace(self):
1457 x = torch.randn(2, 2)
1459 lambda x: F.threshold(x, 0, 0, inplace=
True), (x, ), return_inputs=
True)
1462 FileCheck().check(
"threshold_").run(str(trace))
1465 def run_ge_tests(self, optimize, use_cuda):
1467 t = torch.rand(*args).float()
1472 [rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
1476 b, a), [rand(1), rand(1)], optimize=optimize)
1481 self.
checkTrace(foo, [rand(1)], optimize=optimize)
1484 lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
1487 self.
checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
1490 (a - 2 * b) + b, [rand(1), rand(1)],
1493 def test_ge_unoptimized(self):
1496 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
1498 def test_ge_optimized(self):
1501 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
1502 @unittest.skipIf(
not RUN_CUDA,
"requires CUDA")
1503 def test_ge_cuda(self):
1509 return a * b / (a - b) + b
1511 a, b = V(torch.rand(1)), V(torch.rand(1))
1512 ge = torch._C.GraphExecutor(foo, (a, b),
lambda var:
'')
1513 a, b = V(torch.rand(1), requires_grad=
True), V(
1514 torch.rand(1), requires_grad=
True)
1518 l2 = (da * db + db * db)
1525 l3 = (da2 * db2 + db2 * db2)
1529 def test_trace_annotation(self):
1530 @_trace(torch.rand(1))
1534 x = torch.randn(5, 5)
1537 def test_trace_script(self):
1551 expected = func1((a, b))
1553 result = traced((a, b))
1556 expected = func2((a, b))
1558 result = traced((a, b))
1561 def test_einsum(self):
1563 return torch.einsum(
'i,j->ij', (x, y))
1567 fns = [traced, script]
1568 x, y = torch.randn(10), torch.randn(2)
1569 for fn
in [traced, script]:
1573 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
1574 @unittest.skipIf(
not RUN_CUDA,
"calls .cuda()")
1575 def test_traced_module_cuda(self):
1576 class Model(nn.Module):
1577 def __init__(self, num_features, num_layers):
1578 super(Model, self).__init__()
1580 layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
1581 for _
in range(num_layers)]
1582 self.
submodule = nn.Sequential(*chain(*layers))
1584 def forward(self, x):
1590 x = torch.randn(2, 5)
1598 linear_submodule = next(iter(traced_model.submodule._modules.values()))
1601 with self.assertRaises(AttributeError):
1602 linear_submodule.in_features
1603 linear_submodule.weight
1604 with self.assertRaises(RuntimeError):
1605 traced_model.asdf = 4
1606 linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
1607 with self.assertRaises(RuntimeError):
1608 del linear_submodule.weight
1611 with self.assertRaises(RuntimeError):
1615 linear_submodule.cuda()
1616 traced_model.float().cuda()
1617 cuda_out = traced_model(x.float().cuda())
1619 cpu_out = traced_model(x.float())
1621 traced_model.to(
'cuda')
1622 cuda_out = traced_model(x.float().cuda())
1623 traced_model.to(
'cpu')
1624 cpu_out = traced_model(x.float())
1626 traced_model.double()
1629 state = {k: v.clone()
for k, v
in traced_model.state_dict().items()}
1630 new_state = {k: v.clone().fill_(1)
for k, v
in state.items()}
1631 out = traced_model(x)
1632 traced_model.load_state_dict(new_state)
1633 out_ones = traced_model(x)
1634 traced_model.load_state_dict(state)
1635 out_state = traced_model(x)
1639 def test_export_no_reorder(self):
1641 return a * b / (a - 2 * b) + b
1643 recording_inputs = [
torch.tensor([0.55619788169860839844], dtype=torch.float32, requires_grad=
True),
1644 torch.tensor([0.25947844982147216797], dtype=torch.float32, requires_grad=
True)]
1649 outputs_ge1 = ge1(*recording_inputs)
1650 outputs_ge2 = ge2(*recording_inputs)
1654 self.assertTrue(outputs_ge1 == outputs_ge2)
1655 self.assertTrue(grad_ge1 == grad_ge2)
1657 def test_python_function(self):
1658 class MyFn(Function):
1660 def forward(ctx, x):
1664 def backward(ctx, grad_output):
1667 @_trace(torch.zeros(2))
1669 return MyFn.apply(x + 2) + 3
1672 y = torch.randn(2, 2, requires_grad=
True)
1676 def test_python_function_tup(self):
1677 class MyFn(Function):
1679 def forward(ctx, x):
1683 def backward(ctx, grad_output):
1684 return grad_output, grad_output
1686 @_trace(torch.zeros(2))
1688 a, b = MyFn.apply(x + 2)
1691 y = torch.randn(2, 2, requires_grad=
True)
1695 def test_decompose_addmm(self):
1696 def does_decompose():
1698 def addmm(mat, mat1, mat2, alpha, beta):
1699 a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
1700 b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
1704 mat = torch.randn(2, 2)
1705 mat1 = torch.randn(2, 4)
1706 mat2 = torch.randn(4, 2)
1707 alpha = torch.FloatTensor([123.0])
1708 beta = torch.FloatTensor([321.0])
1710 out_ref = addmm(mat, mat1, mat2, alpha, beta)
1711 self.
run_pass(
'canonicalize_ops', addmm.graph)
1712 out_test = addmm(mat, mat1, mat2, alpha, beta)
1714 FileCheck().check_not(
"addmm").run(str(addmm.graph))
1716 def doesnt_decompose():
1718 def addmm(mat, mat1, mat2, alpha, beta):
1719 a = mat.addmm(mat1, mat2)
1720 b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
1722 orig = str(addm.graph)
1723 self.
run_pass(
'canonicalize_ops', addmm.graph)
1724 self.assertTrue(orig == str(addmm.graph))
1726 def test_index_put(self):
1727 ten = torch.zeros(3, 3)
1728 mask = torch.Tensor([[
True,
True,
True],
1729 [
True,
False,
False],
1730 [
True,
True,
False]]).byte()
1732 def test_fn(ten, mask):
1733 ten[mask] = torch.ones(6)
1738 ten = torch.rand(3, 3)
1739 self.
assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
1741 def test_sparse_tensors_error(self):
1743 return torch.sparse.FloatTensor(2, 3)
1747 output = get_sparse()
1748 return output, input
1751 sparse(get_sparse())
1756 def test_tuple_specialization(self):
1763 t = torch.randn(2, 2), torch.randn(2, 2)
1765 graph = f.graph_for(t)
1766 input_types = list(next(graph.inputs()).type().elements())
1767 for t
in input_types:
1768 self.
assertEqual(t.kind(),
'DimensionedTensorType')
1770 def test_constant_prop_simple(self):
1772 def constant_prop(input_int):
1776 return b - input_int
1778 out_ref = constant_prop(2)
1779 self.
run_pass(
'constant_propagation', constant_prop.graph)
1780 out_test = constant_prop(2)
1782 graph_str = str(constant_prop.graph)
1783 self.assertTrue(
"aten::add" not in graph_str
and "aten::mul" not in graph_str)
1784 const = constant_prop.graph.findNode(
"prim::Constant").output().toIValue()
1787 def test_constant_prop_nested(self):
1789 def constant_prop(a):
1797 self.
run_pass(
'constant_propagation', constant_prop.graph)
1800 if_node = constant_prop.graph.findNode(
"prim::If")
1801 for block
in if_node.blocks():
1802 for node
in block.nodes():
1803 self.assertTrue(node.kind() ==
"prim::Constant")
1805 def test_constant_prop_print(self):
1807 def constant_prop(input_tensor):
1811 return b + input_tensor
1813 self.
run_pass(
'constant_propagation', constant_prop.graph)
1814 graph = constant_prop.graph
1815 print_node = graph.findNode(
"prim::Print")
1816 self.assertTrue(print_node.input().toIValue() == 6)
1818 def test_constant_prop_rand(self):
1820 def constant_prop():
1821 a = torch.randn([3])
1825 self.
run_pass(
'constant_propagation', constant_prop.graph)
1826 self.assertTrue(
"aten::randn" in str(constant_prop.graph))
1828 def test_constant_prop_none(self):
1835 def constant_prop():
1838 if (a
is None and b
is None):
1844 self.
run_pass(
'constant_propagation', constant_prop.graph)
1845 graph_str = str(constant_prop.graph)
1846 self.assertTrue(graph_str.count(
"prim::Constant") == 1)
1848 def test_constant_prop_if_inline(self):
1850 def constant_prop():
1860 self.
run_pass(
'constant_propagation', constant_prop.graph)
1862 def test_trace_records_names(self):
1865 quick_brown_fox = torch.neg(baz)
1867 yeet = quick_brown_fox - 3.14
1871 graph_str = str(traced.graph)
1872 assert 'bar' in graph_str
1873 assert 'baz' in graph_str
1874 assert 'quick_brown_fox' in graph_str
1876 def test_constant_prop_if_constant(self):
1878 def constant_prop(a, b):
1895 return a + c0 + c1 + c2
1897 graph = constant_prop.graph
1898 self.
run_pass(
'constant_propagation', graph)
1899 ifs = graph.findAllNodes(
"prim::If", recurse=
False)
1900 snd_if_inlined = len(ifs) == 1
1901 self.assertTrue(snd_if_inlined)
1903 self.assertTrue(first_if.outputsSize() == 2)
1904 second_if = first_if.findNode(
"prim::If", recurse=
False)
1905 self.assertTrue(second_if.outputsSize() == 1)
1906 self.assertTrue(second_if.findNode(
"prim::If")
is None)
1908 def test_constant_prop_loop_constant(self):
1910 def constant_prop(cond, iter):
1917 for _
in range(iter):
1925 for _i
in range(-4):
1929 self.
run_pass(
'constant_propagation', constant_prop.graph)
1930 graph = canonical(constant_prop.graph)
1931 self.assertTrue(graph.count(
"removed") == 0)
1932 self.assertTrue(graph.count(
"stays") == 1)
1933 self.assertTrue(graph.count(
"prim::Print") == 4)
1935 def test_constant_prop_remove_output(self):
1937 def constant_prop(iter):
1942 for i
in range(iter):
1950 graph = constant_prop.graph
1951 self.
run_pass(
'constant_propagation', graph)
1952 self.assertTrue(graph.findNode(
"prim::Loop").outputsSize() == 2)
1954 def test_trace_detach(self):
1956 return torch.matmul(x, w).detach()
1960 FileCheck().check(
"matmul").check(
"detach").run(str(traced.graph))
1961 x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=
True)
1962 traced_result = traced(x, w)
1964 self.assertFalse(traced_result.requires_grad)
1965 self.assertIsNone(traced_result.grad_fn)
1967 def test_trace_detach_inplace(self):
1969 y = torch.matmul(x, w)
1975 FileCheck().check(
"matmul").check(
"detach(").run(str(traced.graph))
1976 x, w = torch.rand(3, 4), torch.rand(4, 5)
1977 traced_result = traced(x, w)
1979 self.assertFalse(traced_result.requires_grad)
1980 self.assertIsNone(traced_result.grad_fn)
1982 def test_trace_detach_onnx_erase(self):
1983 class Mod(torch.nn.Module):
1984 def forward(self, x, w):
1985 return torch.matmul(x, w).detach()
1989 Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
1991 def test_trace_slice_full_dim(self):
1993 return x[0:5, 0] + 1.0
1996 test_x = torch.rand(6, 3)
1999 def test_export_dropout(self):
2000 test = torch.nn.Dropout()
2003 traced =
torch.jit.trace(test, (torch.rand(3, 4),), check_trace=
False)
2005 x = torch.randn(3, 4)
2008 def test_onnx_transpose_incomplete_tensor_type(self):
2015 @torch.jit.script_method
2016 def forward(self, x):
2017 return x.contiguous().transpose(0, 1).sum()
2019 class TraceMe(torch.nn.Module):
2021 super(TraceMe, self).__init__()
2024 def forward(self, x):
2029 example_outputs = (tm(torch.rand(3, 4)),)
2033 @unittest.skipIf(
not RUN_CUDA,
"requires CUDA")
2034 def test_cuda_export_restore(self):
2037 super(Sub, self).__init__()
2038 self.
weight = nn.Parameter(torch.randn(3, 4))
2040 @torch.jit.script_method
2041 def forward(self, thing):
2042 return self.
weight + thing
2046 super(M, self).__init__()
2049 @torch.jit.script_method
2050 def forward(self, v):
2056 input = torch.rand(3, 4).cuda()
2059 def test_export_batchnorm(self):
2060 for mode
in [
'eval',
'train']:
2062 torch.nn.BatchNorm1d(100),
2063 torch.nn.BatchNorm1d(100, affine=
False),
2064 torch.nn.BatchNorm2d(100),
2065 torch.nn.BatchNorm2d(100, affine=
False)]:
2066 getattr(clazz, mode)()
2067 input = torch.randn(20, 100)
if isinstance(clazz, torch.nn.BatchNorm1d)
else \
2068 torch.randn(20, 100, 35, 45)
2071 x = torch.randn(20, 100)
if isinstance(clazz, torch.nn.BatchNorm1d)
else \
2072 torch.randn(20, 100, 35, 45)
2075 def test_export_rnn(self):
2076 for clazz
in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
2077 class RNNTest(torch.nn.Module):
2079 super(RNNTest, self).__init__()
2082 def forward(self, x, lengths, h0):
2084 out, h = self.
rnn(packed, h0)
2090 traced =
torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
2095 x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
2096 self.
assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
2098 def test_export_lstm(self):
2099 class LSTMTest(torch.nn.Module):
2101 super(LSTMTest, self).__init__()
2102 self.
rnn = nn.LSTM(10, 20, 2)
2104 def forward(self, x, lengths, hiddens):
2107 out, (h, c) = self.
rnn(packed, (h0, c0))
2114 torch.LongTensor([3, 2, 1]),
2115 (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
2117 x, lengths, h0, c0 = \
2118 torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
2119 self.
assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
2121 def test_trace_dict_input(self):
2122 class Bar(torch.nn.Module):
2124 super(Bar, self).__init__()
2127 def forward(self, a, b):
2128 return self.
foo({
'a': a,
'b': b})[
'a']
2130 class Foo(torch.nn.Module):
2131 def forward(self, x):
2132 return {
'a': x[
'a'] * x[
'b']}
2134 x = (torch.rand(3), torch.rand(3))
2138 def test_trace_variable_instantiation(self):
2140 return Variable(Variable(x) + 1.0)
2144 x = torch.rand(5, 6)
2145 self.
assertEqual(random_foo(x), random_foo_traced(x))
2147 def test_trace_slice_expr_complete_type(self):
2155 return random_foo_traced(x)[0:1]
2157 x = torch.rand(3, 4)
2160 def test_export_tensoroption_to(self):
2162 return x.new_tensor(x[0]).cpu() + x
2165 example_outputs = traced(torch.rand([2]))
2169 example_outputs=example_outputs))
2171 def test_pretty_printer(self):
2191 def while_test(a, i):
2198 def while_if_test(a, b):
2210 def loop_use_test(y):
2222 def python_op_name_test(y):
2226 def empty_int_list_test(y):
2231 def empty_float_list_test(y):
2232 return [1.0, 2.0, 3.0]
2235 def print_weird_test(y):
2240 self.
assertExpected(while_test.graph.pretty_print(),
"while_test")
2241 self.
assertExpected(while_if_test.graph.pretty_print(),
"while_if_test")
2242 self.
assertExpected(loop_use_test.graph.pretty_print(),
"loop_use_test")
2243 self.
assertExpected(python_op_name_test.graph.pretty_print(),
"python_op_name_test")
2244 self.
assertExpected(empty_int_list_test.graph.pretty_print(),
"empty_int_list_test")
2245 self.
assertExpected(empty_float_list_test.graph.pretty_print(),
"empty_float_list_test")
2246 self.
assertExpected(print_weird_test.graph.pretty_print(),
"print_weird_test")
2248 def test_cu_escaped_number(self):
2255 def test_import_method(self):
2260 r, _ = foo._python_print()
2262 torch._C._jit_import_methods(mod,
"op_version_set = 0\n{}".format(r), [])
2265 def test_function_default_values(self):
2272 def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
2273 return x + a + b + c
2276 simple_fn(torch.ones(1)),
2277 torch.ones(1) + 0.5 + 10 + (20 + 30))
2280 torch.ones(1) + 1 + 3 + 4)
2286 def bool_fn(x, a=outer_c, flag=outer_flag):
2293 self.
assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
2299 def none_fn(x=None):
2307 def hints(x, a=0.5, b=10):
2311 self.
assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
2316 def hints_bad_types(x, a=10, b=0.5):
2320 def test_module_default_values(self):
2325 super(Test, self).__init__()
2327 @torch.jit.script_method
2328 def forward(self, input, other=four):
2329 return input + other
2332 self.
assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
2334 def test_warnings(self):
2340 warnings.warn(
"x is less than 2")
2343 FileCheck().check(
"aten::warn").run(str(fn.graph))
2345 def test_no_erroneous_warnings(self):
2350 warnings.warn(
'This should NOT be printed')
2354 with warnings.catch_warnings(record=
True)
as warns:
2357 warns = [str(w.message)
for w
in warns]
2360 @unittest.skipIf(sys.platform ==
"win32",
"TODO: need to fix this test case for Windows")
2361 def test_torch_load_error(self):
2364 super(J, self).__init__()
2366 @torch.jit.script_method
2367 def forward(self, input):
2371 with tempfile.NamedTemporaryFile()
as f:
2376 def test_legacy_constructors(self):
2378 return x.new_zeros(5, 5, requires_grad=
False)
2380 with warnings.catch_warnings(record=
True)
as warns:
2382 warns = [str(w.message)
for w
in warns]
2384 self.
assertEqual(warns[0],
"new_zeros is a legacy constructor and is not supported in the JIT.")
2386 def test_python_bindings(self):
2389 def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
2390 for i
in range(x.size(0)):
2391 hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
2396 inputs = get_lstm_inputs(
'cpu', training=
True, seq_length=10)
2397 slstm(*inputs).sum().backward()
2399 fw_graph = slstm.graph_for(*inputs)
2400 nodes = [n
for n
in fw_graph.nodes()]
2401 tested_blocks =
False 2403 for output
in [o
for o
in node.outputs()]:
2404 self.assertTrue(hasattr(output,
'type'))
2405 self.assertTrue(output.type()
is not None)
2406 for input
in [i
for i
in node.inputs()]:
2407 self.assertTrue(hasattr(input,
'type'))
2408 self.assertTrue(input.type()
is not None)
2409 for block
in [b
for b
in node.blocks()]:
2410 tested_blocks =
True 2411 self.assertTrue(hasattr(block,
'inputs'))
2412 self.assertTrue(hasattr(block,
'outputs'))
2413 for output
in [o
for o
in block.outputs()]:
2414 self.assertTrue(hasattr(output,
'type'))
2415 self.assertTrue(output.type()
is not None)
2416 for input
in [i
for i
in block.inputs()]:
2417 self.assertTrue(hasattr(input,
'type'))
2418 self.assertTrue(input.type()
is not None)
2419 self.assertTrue(hasattr(block,
'returnNode'))
2420 self.assertTrue(type(block.returnNode()) == torch._C.Node)
2421 self.assertTrue(hasattr(block,
'paramNode'))
2422 self.assertTrue(type(block.paramNode()) == torch._C.Node)
2423 self.assertTrue(tested_blocks)
2428 def rand_batch(self, *dims):
2429 dims = [dim
for dim
in dims
if dim != ()]
2430 xs = [torch.rand(1, *(random.randint(1, size)
if b
else size
for b, size
in dims[1:]),
2431 requires_grad=
True)
for i
in range(dims[0])]
2432 xb = BatchTensor(xs,
torch.tensor([b
for b, d
in dims[1:]]).byte())
2435 def test_create_batchtensor(self):
2437 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2), (
True, 5))
2440 batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
2443 xs = torch.rand(3, 4, 5)
2444 batch3 = BatchTensor(xs, 2)
2445 xs = xs.unsqueeze(0)
2448 def test_batch_elementwise_unary(self):
2451 return torch.tanh(a)
2453 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2454 res_batch = tanh(batch)
2455 res = [torch.tanh(xs[j])
for j
in range(4)]
2458 def test_batch_elementwise_binary(self):
2463 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2464 xs2, batch2 = xs, batch
2465 res_batch = add(batch, batch2)
2466 res = [torch.add(xs[j], xs2[j])
for j
in range(4)]
2470 xs, batch = self.
rand_batch(4, (
False, 3), (
False, 2))
2471 b = torch.rand(3, 2)
2472 res_batch = add(batch, b)
2473 res = [torch.add(xs[j], b)
for j
in range(4)]
2476 def test_batch_mm(self):
2479 return torch.mm(a, b)
2481 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2482 xs2, batch2 = self.
rand_batch(4, (
False, 2), (
True, 3))
2483 res_batch = mm(batch, batch2)
2484 res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0)
for j
in range(4)]
2488 b = torch.rand(2, 4)
2489 res_batch = mm(batch, b)
2490 res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0)
for j
in range(4)]
2493 def test_batch_matmul(self):
2496 return torch.matmul(a, b)
2498 def matmul_test(xs, batch, xs2, batch2):
2499 ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0)
for j
in range(4)]
2500 ybs = matmul(batch, batch2)
2506 matmul_test(xs, batch, xs2, batch2)
2509 xs2, batch2 = self.
rand_batch(4, (
False, 2), (
True, 3))
2510 matmul_test(xs, batch, xs2, batch2)
2512 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2514 matmul_test(xs, batch, xs2, batch2)
2516 xs, batch = self.
rand_batch(4, (
True, 3), (
False, 2))
2517 xs2, batch2 = self.
rand_batch(4, (
False, 2), (
True, 3))
2518 matmul_test(xs, batch, xs2, batch2)
2520 def test_batch_select(self):
2523 return torch.select(x, 1, 0)
2525 xs, batch = self.
rand_batch(4, (
True, 3), (
True, 2))
2526 res_batch = select(batch)
2527 res = [torch.select(xs[j], 1, 0)
for j
in range(4)]
2530 xs, batch = self.
rand_batch(4, (
False, 3), (
True, 2))
2531 res_batch = select(batch)
2532 res = [torch.select(xs[j], 1, 0)
for j
in range(4)]
2535 def test_batch_index_select(self):
2537 def index_select(x, ind):
2538 return x.index_select(1, ind)
2540 xs, batch = self.
rand_batch(4, (
False, 5), (
True, 2))
2541 ind = [torch.randint(0, 4, (1,), dtype=torch.long)
for i
in range(4)]
2543 res_batch = index_select(batch, ind_batch)
2544 res = [torch.index_select(xs[j], 1, ind[j])
for j
in range(4)]
2547 def test_batch_where(self):
2550 return torch.where(c, a, b)
2552 xs, batch = self.
rand_batch(4, (
False, 3), (
False, 2))
2553 xs2, batch2 = self.
rand_batch(4, (
False, 3), (
False, 2))
2555 dims = [4, (
False, 3), (
False, 2)]
2556 xs_cond = [torch.rand(1, 3, 2).byte()
for i
in range(dims[0])]
2557 batch_cond = BatchTensor(xs_cond,
torch.tensor([b
for b, d
in dims[1:]]))
2559 res_batch = where(batch_cond, batch, batch2)
2560 res = [torch.where(xs_cond[j], xs[j], xs2[j])
for j
in range(4)]
2563 def test_batch_argmax(self):
2566 return torch.argmax(a, 1)
2568 xs, batch = self.
rand_batch(4, (
True, 5), (
True, 6))
2569 res_batch = argmax(batch)
2570 res = [torch.argmax(xs[j], 1)
for j
in range(4)]
2575 return torch.argmax(a, 1,
False)
2577 res_batch = argmax(batch)
2578 res = [torch.argmax(xs[j], 1,
False)
for j
in range(4)]
2581 def test_batch_topk(self):
2584 return torch.topk(a, 3, 1)
2586 xs, batch = self.
rand_batch(4, (
False, 5), (
True, 6))
2589 res_batch = topk(batch)
2590 res = [torch.topk(xs[j], 3, 1)[0]
for j
in range(4)]
2591 res_idx = [torch.topk(xs[j], 3, 1)[1]
for j
in range(4)]
2593 self.
assertEqual(res_idx, res_batch[1].examples())
2597 return torch.topk(a, 1, 2)
2600 res_batch = topk(batch)
2601 res = [torch.topk(xs[j], 1, 2)[0]
for j
in range(4)]
2602 res_idx = [torch.topk(xs[j], 1, 2)[1]
for j
in range(4)]
2604 self.
assertEqual(res_idx, res_batch[1].examples())
2606 def test_batch_softmax(self):
2609 return torch.softmax(a, 1)
2611 xs, batch = self.
rand_batch(4, (
False, 5), (
True, 6))
2614 res_batch = softmax(batch)
2615 res = [torch.softmax(xs[j], 1)
for j
in range(4)]
2620 return torch.softmax(a, 2)
2623 res_batch = softmax(batch)
2624 res = [torch.softmax(xs[j], 2)
for j
in range(4)]
2627 def test_batch_view(self):
2630 return a.view([4, -1, 3])
2632 xs, batch = self.
rand_batch(4, (
True, 5), (
False, 3))
2633 res_batch = view(batch)
2634 res = [xs[j].view([1, -1, 3])
for j
in range(4)]
2637 def test_batch_cat(self):
2640 return torch.cat([a, b], 2)
2642 xs, batch = self.
rand_batch(4, (
True, 5), (
False, 3))
2643 xs2, batch2 = xs, batch
2644 res_batch = cat2(batch, batch2)
2645 res = [torch.cat([xs[j], xs2[j]], 2)
for j
in range(4)]
2648 def test_batch_sum(self):
2653 xs, batch = self.
rand_batch(4, (
True, 5), (
False, 3))
2654 res_batch = batch_sum(batch)
2655 res = [xs[j].sum().unsqueeze(0)
for j
in range(4)]
2658 def test_if_else(self):
2659 def single_if(a, b):
2670 res_batch = batch_if(batch_a, batch_b)
2671 res = [single_if(a[j], b[j])
for j
in range(4)]
2675 torch.to_batch_graph(script_if.graph)
2677 def test_if_else_with_scalar(self):
2678 def single_if(a, b):
2689 res_batch = batch_if(batch_a, batch_b)
2690 res = [single_if(a[j], b[j])
for j
in range(4)]
2694 torch.to_batch_graph(script_if.graph)
2696 def test_if_noelse(self):
2697 def single_if(a, b):
2706 res_batch = batch_if(batch_a, batch_b)
2707 res = [single_if(a[j], b[j])
for j
in range(4)]
2711 torch.to_batch_graph(script_if.graph)
2713 def test_if_noelse_with_scalar(self):
2714 def single_if(a, b):
2723 res_batch = batch_if(batch_a, batch_b)
2724 res = [single_if(a[j], b[j])
for j
in range(4)]
2728 torch.to_batch_graph(script_if.graph)
2730 def test_while(self):
2731 def single_while(a, b):
2739 b = [torch.abs(torch.rand(1))
for i
in range(4)]
2741 res_batch = batch_while(batch_a, batch_b)
2742 res = [single_while(a[j], b[j])
for j
in range(4)]
2746 torch.to_batch_graph(script_while.graph)
2749 def single_for(x, y):
2758 res_batch = batch_for(batch_a, batch_b)
2759 res = [single_for(a[j], b[j])
for j
in range(4)]
2763 torch.to_batch_graph(script_for.graph)
2765 def test_lstm(self):
2766 def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
2767 for i
in range(x_all.size(1)):
2768 x = x_all.select(1, i)
2769 i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2770 f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2771 o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2773 i_t = torch.sigmoid(i_t)
2774 f_t = torch.sigmoid(f_t)
2775 o_t = torch.sigmoid(o_t)
2777 c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2778 c_t = torch.tanh(c_t)
2779 c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2780 h_t = torch.mul(o_t, torch.tanh(c_t))
2787 batch_size, input_size, hidden_size = 4, 3, 2
2788 xs, batch = self.
rand_batch(batch_size, (
True, 4), (
False, input_size))
2789 hx, h_batch = self.
rand_batch(batch_size, (
False, hidden_size))
2790 cx, c_batch = self.
rand_batch(batch_size, (
False, hidden_size))
2793 w_xi = torch.rand(input_size, hidden_size)
2794 w_xf = torch.rand(input_size, hidden_size)
2795 w_xo = torch.rand(input_size, hidden_size)
2796 w_xc = torch.rand(input_size, hidden_size)
2798 w_hi = torch.rand(hidden_size, hidden_size)
2799 w_hf = torch.rand(hidden_size, hidden_size)
2800 w_ho = torch.rand(hidden_size, hidden_size)
2801 w_hc = torch.rand(hidden_size, hidden_size)
2803 b_i = torch.rand(hidden_size)
2804 b_f = torch.rand(hidden_size)
2805 b_o = torch.rand(hidden_size)
2806 b_c = torch.rand(hidden_size)
2808 ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
2809 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
for j
in range(batch_size)]
2810 ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
2811 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
2814 def test_greedy_search(self):
2815 def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2816 b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
2817 iter_count = torch.zeros_like(iter_num)
2818 while bool(iter_count < iter_num):
2819 iter_count = iter_count + 1
2821 i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2822 f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2823 o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2825 i_t = torch.sigmoid(i_t)
2826 f_t = torch.sigmoid(f_t)
2827 o_t = torch.sigmoid(o_t)
2829 c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2830 c_t = torch.tanh(c_t)
2831 c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2832 h_t = torch.mul(o_t, torch.tanh(c_t))
2836 s_t = torch.matmul(h_t, w_hs) + b_s
2837 p_t = torch.softmax(s_t, 1)
2838 i_t = torch.argmax(p_t, 1)
2839 x = embed.index_select(1, i_t).squeeze(1)
2844 batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
2845 xs, batch = self.
rand_batch(batch_size, (
False, input_size))
2846 hx, h_batch = self.
rand_batch(batch_size, (
False, hidden_size))
2847 cx, c_batch = self.
rand_batch(batch_size, (
False, hidden_size))
2848 embed, embed_batch = self.
rand_batch(batch_size, (
False, vocab_size), (
False, input_size))
2849 iter_num = [torch.randint(2, 5, (1,))
for i
in range(batch_size)]
2850 iter_num_batch = BatchTensor(iter_num,
torch.tensor([]).byte())
2853 w_xi = torch.rand(input_size, hidden_size)
2854 w_xf = torch.rand(input_size, hidden_size)
2855 w_xo = torch.rand(input_size, hidden_size)
2856 w_xc = torch.rand(input_size, hidden_size)
2858 w_hi = torch.rand(hidden_size, hidden_size)
2859 w_hf = torch.rand(hidden_size, hidden_size)
2860 w_ho = torch.rand(hidden_size, hidden_size)
2861 w_hc = torch.rand(hidden_size, hidden_size)
2863 b_i = torch.rand(hidden_size)
2864 b_f = torch.rand(hidden_size)
2865 b_o = torch.rand(hidden_size)
2866 b_c = torch.rand(hidden_size)
2868 w_hs = torch.rand(hidden_size, vocab_size)
2869 b_s = torch.rand(vocab_size)
2871 ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc,
2872 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j])
for j
in range(batch_size)]
2873 ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
2874 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
2877 def test_beam_search(self):
2878 def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2879 b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
2881 vocab_size = embed.size(1)
2882 iter_count = torch.zeros_like(iter_num)
2883 max_len = idx.size(2)
2884 while bool(iter_count < iter_num):
2885 iter_count = iter_count + 1
2887 i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
2888 f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
2889 o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
2891 i_t = torch.sigmoid(i_t)
2892 f_t = torch.sigmoid(f_t)
2893 o_t = torch.sigmoid(o_t)
2895 c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
2896 c_t = torch.tanh(c_t)
2897 c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
2898 h_t = torch.mul(o_t, torch.tanh(c_t))
2902 s_t = torch.matmul(h_t, w_hs) + b_s
2903 s_t = s_t.view([1, s_t.size(1) * s_t.size(2)])
2904 p_t = torch.softmax(s_t, 1)
2905 prob_t, idx_t = torch.topk(p_t, k, 1)
2906 if(int(idx_t.dim()) > 1):
2907 idx_t_tmp = idx_t.squeeze(0)
2910 new_y = torch.fmod(idx_t_tmp, vocab_size)
2911 pre_y = idx_t_tmp / vocab_size
2912 x = embed.index_select(1, new_y)
2913 h = h_t.index_select(1, pre_y)
2914 c = c_t.index_select(1, pre_y)
2915 iter = int(iter_count[0])
2916 idx = torch.cat([idx.narrow(2, 0, iter).index_select(1, pre_y),
2917 torch.fmod(idx_t, vocab_size).unsqueeze(-1),
2918 idx.narrow(2, iter, max_len - iter)], 2)
2919 idx = idx.narrow(2, 0, max_len)
2925 batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
2927 xs, batch = self.
rand_batch(batch_size, (
False, 1), (
False, input_size))
2928 hx, h_batch = self.
rand_batch(batch_size, (
False, 1), (
False, hidden_size))
2929 cx, c_batch = self.
rand_batch(batch_size, (
False, 1), (
False, hidden_size))
2930 embed, embed_batch = self.
rand_batch(batch_size, (
False, vocab_size), (
False, input_size))
2931 iter_num = [torch.randint(2, max_len + 1, (1,))
for i
in range(batch_size)]
2932 iter_num_batch = BatchTensor(iter_num,
torch.tensor([]).byte())
2935 w_xi = torch.rand(input_size, hidden_size)
2936 w_xf = torch.rand(input_size, hidden_size)
2937 w_xo = torch.rand(input_size, hidden_size)
2938 w_xc = torch.rand(input_size, hidden_size)
2940 w_hi = torch.rand(hidden_size, hidden_size)
2941 w_hf = torch.rand(hidden_size, hidden_size)
2942 w_ho = torch.rand(hidden_size, hidden_size)
2943 w_hc = torch.rand(hidden_size, hidden_size)
2945 b_i = torch.rand(1, hidden_size)
2946 b_f = torch.rand(1, hidden_size)
2947 b_o = torch.rand(1, hidden_size)
2948 b_c = torch.rand(1, hidden_size)
2950 w_hs = torch.rand(hidden_size, vocab_size)
2951 b_s = torch.rand(1, vocab_size)
2954 torch.zeros([batch_size, 1, max_len]).byte(),
2956 idx = [torch.zeros([1, k, max_len], dtype=torch.long)
for _
in range(batch_size)]
2958 ys = [beam(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
2959 b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j], idx[j]).narrow(2, 0, int(iter_num[j]))
2960 for j
in range(batch_size)]
2961 ybs = beam_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
2962 w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch)
2966 def execWrapper(code, glob, loc):
2968 exec(code)
in glob, loc
2970 exec(code, glob, loc)
2975 def capture_stdout(self):
2984 stdout_fd = os.dup(1)
2991 captured_stdout = [
'']
2992 yield captured_stdout
2996 fcntl.fcntl(r, fcntl.F_SETFL, os.O_NONBLOCK)
3000 total_stdout += os.read(r, 1000).decode(
'ascii')
3001 except OSError
as e:
3002 if e.errno != errno.EAGAIN:
3005 captured_stdout[0] = total_stdout
3015 optimize=
True, outputs=
None, capture_output=
False):
3017 Checks that a given function will throw the correct exception, 3018 when executed with normal python, the string frontend, and the AST frontend 3025 source = textwrap.dedent(inspect.getsource(script))
3027 ge = getattr(cu, script.__name__)
3034 def test_training_param(self):
3036 @torch.jit.script_method
3037 def forward(self, x):
3053 def test_jitter_bug(self):
3055 def fn2(input, kernel_size):
3057 if kernel_size[0] > 1:
3060 _stride = kernel_size
3061 print(_stride, kernel_size)
3067 return fn2(input, [1])
3069 def test_parser_kwargonly(self):
3071 def foo(x, *, y) -> Tuple[Tensor, Tensor]: 3076 self.assertTrue(
'*' in cu.module._get_method(
'foo').pretty_print_schema())
3079 def foo(x, *, y) -> Tuple[Tensor, Tensor]: 3085 def test_annoying_doubles(self):
3086 mod = types.ModuleType(
"temp")
3087 mod.inf = float(
"inf")
3088 mod.ninf = float(
"-inf")
3089 mod.nan = float(
"nan")
3094 return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan
3096 pp, table = foo._get_method(
'forward').python_print()
3097 ppv =
"op_version_set = 0\n{}".format(pp)
3099 torch._C._jit_import_methods(sm, ppv, table)
3103 self.assertTrue(r[:-1] == r2[:-1])
3104 self.assertTrue(math.isnan(r[-1])
and math.isnan(r2[-1]))
3106 def test_type_annotate(self):
3126 def annotate_none():
3129 def annotate_none_no_optional():
3135 def test_robust_op_resolution(self):
3141 a = (torch.rand(3),)
3144 def test_tuple_io(self):
3150 a = (torch.rand(3), torch.rand(3))
3153 def test_tuple_create_return(self):
3156 a = (torch.ones(x), torch.zeros(x))
3160 def test_list_io(self):
3163 return torch.ones(x), x
3167 def get_sum_list_fn(self):
3178 def test_sum_list_diff_elms(self):
3181 def test_sum_list_empty(self):
3184 def test_sum_list_one(self):
3187 def test_sum_list_literal(self):
3192 for i
in [1, 2, 3, 4, 5]:
3199 def test_sum_list_wrong_type(self):
3213 def test_bool_list_io(self):
3217 return x, [
True,
False], [[
True]]
3219 li_1, li_2, li_3 = stuff4([
True])
3221 for li
in [li_1, li_2, li_3]:
3222 self.assertTrue(type(li[0]) == type(
True))
3224 def test_nested_list(self):
3231 def test_nested_list_construct(self):
3233 return [[4]] + [[4, 5]]
3236 def test_tensor_shape(self):
3237 x = torch.empty(34, 56, 78)
3244 def test_tensor_grad(self):
3249 return x.requires_grad
3254 def test_tensor_dtype(self):
3255 x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
3256 x_long = torch.empty(34, 56, 78, dtype=torch.long)
3257 x_float32 = torch.empty(34, 56, 78, dtype=torch.float32)
3261 return x.dtype == torch.uint8
3265 return x.dtype == torch.long
3269 return x.dtype == torch.float32
3271 self.assertTrue(byte(x_byte))
3272 self.assertFalse(byte(x_long))
3273 self.assertFalse(byte(x_float32))
3274 self.assertFalse(long(x_byte))
3275 self.assertTrue(long(x_long))
3276 self.assertFalse(long(x_float32))
3277 self.assertFalse(float32(x_byte))
3278 self.assertFalse(float32(x_long))
3279 self.assertTrue(float32(x_float32))
3281 @unittest.skipIf(
not RUN_CUDA,
"device tests require CUDA")
3282 def test_tensor_device(self):
3283 cpu = torch.empty(34, 56, 78, device=
'cpu')
3284 gpu = torch.empty(34, 56, 78, device=
'cuda')
3287 def same_device(x, y):
3288 return x.device == y.device
3290 self.assertTrue(same_device(cpu, cpu))
3291 self.assertTrue(same_device(gpu, gpu))
3292 self.assertFalse(same_device(cpu, gpu))
3294 @unittest.skipIf(
not RUN_CUDA,
"device tests require CUDA")
3295 def test_tensor_to_device(self):
3297 return x.to(device=
"cuda").to(device=torch.device(
"cpu"))
3301 def test_tensor_to_cpu(self):
3305 x = torch.ones(3, 4)
3307 self.
assertEqual(to_cpu(x).device, script_fn(x).device)
3310 @unittest.skipIf(
not RUN_CUDA,
"device tests require CUDA")
3311 def test_tensor_to_cuda(self):
3315 x = torch.ones(3, 4)
3317 self.
assertEqual(to_cuda(x).device, script_fn(x).device)
3320 def test_generic_list_errors(self):
3324 return [[x]] + [[1]]
3326 def test_script_cu(self):
3332 a = Variable(torch.rand(1))
3337 def test_string_cu(self):
3340 print(a, """a\\n\tb\\n""", 2, "a\ 3344 FileCheck().check(
"aa").check(
"a\\n\\tb\\n").run(str(cu.foo.graph))
3346 def test_string_ops(self):
3349 return a + a,
"ab" ==
"b",
"ab" !=
"b",
"ab" ==
"ab",
"ab" !=
"ab" 3353 def test_string_new_line(self):
3362 def test_string_single_escape(self):
3370 def test_script_annotation(self):
3374 s = Variable(torch.rand(2))
3380 return a < float(
'inf')
3382 self.assertTrue(foo(s))
3386 return a > float(
'-inf')
3388 self.assertTrue(foo(s))
3396 a = torch.rand(1, requires_grad=
True)
3397 b = torch.rand(1, requires_grad=
True)
3404 a = torch.rand(1, requires_grad=
True)
3405 b = torch.rand(1, requires_grad=
True)
3408 @unittest.skipIf(
not PY35,
"Python 3.5 needed")
3409 def test_matmul_py3(self):
3415 with tempfile.TemporaryDirectory()
as tmp_dir:
3416 script_path = os.path.join(tmp_dir,
'script.py')
3417 with open(script_path,
'w')
as f:
3419 fn = get_fn(
'test_matmul_py3', script_path)
3421 a = torch.rand(4, 3, requires_grad=
True)
3422 b = torch.rand(3, 2, requires_grad=
True)
3429 def func2(a, b, c, d):
3430 return c + a ** b ** d
3432 a = torch.rand(1, requires_grad=
True)
3433 b = torch.rand(1, requires_grad=
True)
3434 c = torch.rand(1, requires_grad=
True)
3435 d = torch.rand(1, requires_grad=
True)
3437 self.
checkScript(func2, (a, b, c, d), optimize=
True)
3439 def test_triple(self):
3443 x = torch.rand(1, dtype=torch.float, requires_grad=
True)
3446 def test_slice(self):
3450 x = torch.rand(10, dtype=torch.float, requires_grad=
True)
3458 def test_gather(self):
3462 x = torch.rand(10, dtype=torch.float, requires_grad=
True)
3465 def test_random(self):
3468 return torch.normal(mean, std)
3470 mean, std = torch.zeros(5, 5), torch.ones(5, 5)
3472 output = torch.normal(mean, std)
3474 script_output = f(mean, std)
3477 def _check_code(self, code_str, fn_name, inputs):
3479 exec(code_str, globals(), scope)
3481 self.
assertEqual(cu.func(*inputs), scope[fn_name](*inputs))
3483 @unittest.skipIf(
not RUN_CUDA,
'no CUDA')
3484 def test_scriptmodule_releases_tensors_cuda(self):
3487 return x.sigmoid() * y.tanh()
3489 def test(backward=False):
3490 x = torch.randn(3, 3, dtype=torch.double, device=
'cuda', requires_grad=
True)
3491 y = torch.randn(3, 3, dtype=torch.double, device=
'cuda', requires_grad=
True)
3494 out.sum().backward()
3506 def test_index(self):
3507 def consec(size, start=0):
3509 return torch.arange(numel).view(size)
3511 def check_indexing(indexing, tensor):
3512 template = dedent(
""" 3517 self.
_check_code(template.format(indexing),
"func", [tensor])
3519 def check_dynamic_indexing(indexing, tensor, value1, value2):
3523 template = dedent(
""" 3524 def func(x, value1, value2): 3530 self.
_check_code(template.format(indexing),
"func", [tensor, value1, value2])
3533 check_indexing(
'[0]', consec((3, 3)))
3534 check_indexing(
'[1]', consec((3, 3), 10))
3535 check_indexing(
'[2]', consec((3, 3), 19))
3536 check_indexing(
'[2]', consec((3,)))
3537 check_indexing(
'[-1]', consec((3, 3), 19))
3538 check_indexing(
'[0:2]', consec((3, 3, 3)))
3539 check_indexing(
'[1:-1]', consec((3, 3, 3)))
3540 check_indexing(
'[-3:-1]', consec((6, 3)))
3541 check_indexing(
'[1:]', consec((3, 3)))
3542 check_indexing(
'[:1]', consec((3, 3)))
3543 check_indexing(
'[:]', consec((3, 2)))
3546 check_indexing(
'[0, 1]', consec((3, 3)))
3547 check_indexing(
'[0, 1]', consec((3, 3, 2)))
3548 check_indexing(
'[1, 0, 2]', consec((3, 3, 3)))
3549 check_indexing(
'[2, -1]', consec((3, 3)))
3552 check_indexing(
'[0, 1:2]', consec((3, 3)))
3553 check_indexing(
'[0, :1]', consec((3, 3, 2)))
3554 check_indexing(
'[1, 2:]', consec((3, 3, 3)))
3555 check_indexing(
'[-1, 1:, 0]', consec((3, 3, 3, 3)))
3556 check_indexing(
'[1:, -1, 0]', consec((3, 3, 3, 3)))
3557 check_indexing(
'[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
3558 check_indexing(
'[-1, 1:, 0]', consec((3, 3, 3, 3)))
3559 check_indexing(
'[-1, :, 0, 2]', consec((3, 3, 3, 3)))
3562 check_indexing(
'[0:0]', consec((2, 2)))
3563 check_indexing(
'[0:0, 1]', consec((3, 3)))
3566 check_indexing(
'[1+1]', consec((3, 3)))
3567 check_indexing(
'[1:(0 + 2)]', consec((3, 3, 3)))
3570 check_dynamic_indexing(
"[i + j]", consec((3, 3)), 0, 1)
3571 check_dynamic_indexing(
"[i:j, i]", consec((3, 3, 2)), 0, 2)
3573 def test_tensor_item(self):
3574 def test_scalar_to_float_coercion(x):
3575 return x.item() == 1
3580 def test_scalar_cast(x):
3582 return int(scalar), float(scalar)
3587 expected_str =
r"Use int\(tensor\) or float\(tensor\) to retrieve" 3595 def test_error_msg(x):
3596 return int_fn(x.item())
3598 def test_method_on_number(self):
3606 def test_scalar_to_num_conversions(self):
3608 def multiple_defs(x):
3613 self.assertTrue(
"ImplicitTensorToNum" not in str(multiple_defs.graph))
3616 def tensor_to_int_script(x, tensor):
3617 return x.unsqueeze(tensor)
3619 def tensor_to_int(x, tensor):
3620 return x.unsqueeze(tensor)
3623 def tensor_to_float_script(x, tensor):
3624 return x.addcmul(tensor, tensor, value=tensor)
3626 def tensor_to_float(x, tensor):
3627 return x.addcmul(tensor, tensor, value=tensor)
3636 script_funs = [tensor_to_int_script, tensor_to_float_script]
3637 funs = [tensor_to_int, tensor_to_float]
3640 def test_func(func, x, tensor):
3642 result = func(x, tensor)
3643 except RuntimeError
as e:
3645 except TypeError
as e:
3650 for tensor
in tensors:
3651 for i
in range(len(script_funs)):
3652 self.
assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
3654 def test_tuple_to_opt_list(self):
3664 def test_advancedindex(self):
3665 def consec(size, start=0):
3667 return torch.arange(numel).view(size)
3669 def check_indexing(indexing, tensor, **kwargs):
3670 indices_dict = kwargs
3672 template = dedent(
""" 3673 def func(x{formals}): 3679 for formal, value
in indices_dict.items():
3680 formals.append(formal)
3681 values.append(value)
3683 formals =
''.join(map(
', {}'.format, formals))
3684 inputs = [tensor] + values
3685 self.
_check_code(template.format(formals=formals, expr=indexing),
3689 check_indexing(
'[i]', consec((3, 3)), i=
torch.tensor([0]))
3690 check_indexing(
'[i]', consec((3, 3)), i=
torch.tensor(1))
3691 check_indexing(
'[i]', consec((3, 3)), i=
torch.tensor([-2]))
3692 check_indexing(
'[i]', consec((3, 3), 2), i=
torch.tensor([0, 0]))
3693 check_indexing(
'[i]', consec((3, 3, 2, 2)), i=
torch.tensor([0, -2, 1]))
3702 inp = consec((4, 8, 5))
3705 [
'[i, j]', {
'i': [0, 2],
'j': [1, 3]}],
3707 [
'[i, j, k]', {
'i': [0, 2],
'j': [1, 3],
'k': [1, 1]}],
3709 [
'[i, j, k]', {
'i': [0, 2],
'j': 1,
'k': [1, 1]}],
3711 [
'[:, :, i]', {
'i': [0, 3, 4]}],
3713 [
'[:, i, 2:4]', {
'i': [0, 2, 3]}],
3715 [
'[i, :, :]', {
'i': [2, 3]}],
3717 [
'[:, i, j]', {
'i': [0, 2, 3],
'j': [1, 3, 4]}],
3719 [
'[:, i, j]', {
'i': [0],
'j': [1, 2, 4]}],
3721 [
'[:, i, j]', {
'i': [0, 1, 3],
'j': [4]}],
3723 [
'[:, i, j]', {
'i': [[0, 1], [1, 0]],
'j': [[2, 3]]}],
3725 [
'[:, i, j]', {
'i': [[0, 1], [2, 3]],
'j': [[0]]}],
3727 [
'[:, i, j]', {
'i': [[5, 6]],
'j': [[0, 3], [4, 4]]}],
3729 [
'[i, j, :]', {
'i': [0, 2, 3],
'j': [1, 3, 4]}],
3731 [
'[i, j, :]', {
'i': 0,
'j': [1, 2, 4]}],
3733 [
'[i, j, :]', {
'i': [0, 1, 3],
'j': 4}],
3735 [
'[i, j, :]', {
'i': [[0, 1], [1, 0]],
'j': [[2, 1], [3, 5]]}],
3737 [
'[i, j, :]', {
'i': [[0, 1], [1, 0]],
'j': [[2, 3]]}],
3739 [
'[i, j, :]', {
'i': [[0, 1], [2, 3]],
'j': [[0]]}],
3741 [
'[i, j, :]', {
'i': [[2, 1]],
'j': [[0, 3], [4, 4]]}],
3743 [
'[i, j, 0:2]', {
'i': [[2]],
'j': [[0, 3], [4, 1]]}],
3746 for expr, argdict
in to_check:
3747 tensordict = {k:
torch.tensor(v)
for (k, v)
in argdict.items()}
3748 check_indexing(expr, inp, **tensordict)
3750 def test_keyword(self):
3753 return torch.sum(x, dim=0)
3755 x = torch.rand(10, dtype=torch.float, requires_grad=
True)
3757 y2 = torch.sum(x, dim=0)
3760 def test_constant_pooling_none(self):
3762 def typed_nones(a=None, b=None, c=None):
3770 print(typed_nones())
3772 print(typed_nones())
3774 graph_str = str(test.graph)
3775 self.assertTrue(graph_str.count(
"bool? = prim::Constant") == 1)
3776 self.assertTrue(graph_str.count(
"int? = prim::Constant") == 1)
3777 self.assertTrue(graph_str.count(
"None = prim::Constant") == 1)
3779 def test_literal(self):
3802 a = torch.rand(1, requires_grad=
True)
3803 b = torch.rand(1, requires_grad=
True)
3806 self.
checkScript(func3, (a.item(), b.item()), optimize=
True)
3808 def test_expand(self):
3813 x = torch.rand(2, 3, dtype=torch.float, requires_grad=
True)
3814 y = torch.rand(3, dtype=torch.float, requires_grad=
True)
3818 grad = torch.randn(2, 3, dtype=torch.float)
3826 return x.sum(dim=[4])
3833 self.
run_pass(
'constant_propagation', func.graph)
3834 self.
run_pass(
'constant_propagation', func2.graph)
3835 torch._C._jit_pass_shape_analysis(
3836 func.graph, (torch.zeros(1, 1, 1, 1, 4),),
False)
3837 torch._C._jit_pass_shape_analysis(
3838 func2.graph, (torch.zeros(1, 1, 1, 1, 4),),
False)
3839 self.assertTrue(func.graph.findNode(
"aten::sum").output().type().kind()
3840 ==
"DimensionedTensorType")
3841 self.assertTrue(func2.graph.findNode(
"aten::sum").output().type().kind()
3842 ==
"DimensionedTensorType")
3847 return torch.cat((x, x), dim=0)
3849 x = torch.rand(10, dtype=torch.float, requires_grad=
True)
3850 self.
assertEqual(func(x), torch.cat((x, x), dim=0))
3854 return torch.cat((x, x), y)
3856 x = torch.rand([2, 2])
3858 self.
assertEqual(func2(x, y), torch.cat((x, x), y))
3860 def test_cat_lifts(self):
3863 return torch.cat([x, x], dim=1)
3867 return torch.cat([], dim=1)
3871 return torch.cat([x], dim=1)
3873 for g
in [foo.graph, foo2.graph, foo3.graph]:
3874 FileCheck().check(
"int =").check(
"ListConstruct").check(
"aten::cat").run(str(g))
3876 def test_list_literal(self):
3884 def reassign_arity_change():
3889 self.
checkScript(reassign_arity_change, (), optimize=
False)
3891 def reassign_from_empty_literal():
3897 self.
checkScript(reassign_from_empty_literal, (), optimize=
False)
3899 def reassign_from_empty_builtin():
3908 z = [torch.randn([1])]
3910 self.
checkScript(reassign_from_empty_builtin, (), optimize=
False)
3912 def reassign_bad_type():
3918 self.
checkScript(reassign_bad_type, (), optimize=
False)
3920 def reassign_nested():
3928 self.
checkScript(reassign_nested, (), optimize=
False)
3930 def test_list_gather(self):
3937 def negative_index():
3948 "list index out of range")
3950 def bad_negative_index():
3955 "list index out of range")
3957 def test_tensor_len(self):
3963 def test_list_len(self):
3976 def test_list_ops(self):
3977 def test_equality():
3982 self.
checkScript(test_equality, (), optimize=
True)
3984 def test_inequality():
3989 self.
checkScript(test_equality, (), optimize=
True)
3991 def test_non_equality():
3996 self.
checkScript(test_non_equality, (), optimize=
True)
3998 def test_non_inequality():
4003 self.
checkScript(test_non_equality, (), optimize=
True)
4005 def test_list_equality_as_cond():
4014 self.
checkScript(test_list_equality_as_cond, (), optimize=
True)
4016 def test_list_add():
4020 return c == [1, 2, 3, 2]
4022 self.
checkScript(test_list_add, (), optimize=
True)
4024 def test_list_add_empty():
4028 return c == [1, 2, 3]
4030 self.
checkScript(test_list_add_empty, (), optimize=
True)
4032 def test_tensor_list_equality():
4033 t1 = torch.ones([1, 1])
4034 t2 = torch.ones([1, 1])
4039 self.
checkScript(test_tensor_list_equality, (), optimize=
True)
4041 def test_invalid_list_equality():
4042 t1 = torch.ones([2, 2])
4043 t2 = torch.ones([2, 2])
4050 test_invalid_list_equality,
4053 "bool value of Tensor")
4055 def test_list_slice(self):
4056 def test_regular_slice():
4058 return a[2:3] == [2]
4061 def test_open_ended_slice():
4063 return a[2:] == [2, 3, 4]
4066 def test_open_ended_slice2():
4068 return a[:2] == [0, 1]
4071 def test_negative_slice():
4073 return a[:-1] == [0, 1, 2, 3]
4076 def test_negative_slice2():
4078 return a[-3:-1] == [2, 3]
4081 def test_backward_slice():
4086 def test_over_slice():
4088 return a[3:10] == [3, 4]
4091 def test_mutable_list_append(self):
4096 return a == [0, 1, 2, 3]
4099 def test_mutable_list_append_2(self):
4100 def test_append_2():
4108 def test_mutable_list_append_if(self):
4109 def test_append_if():
4116 def test_mutable_list_append_if_else(self):
4117 def test_append_if_else():
4126 def test_mutable_list_append_loop(self):
4127 def test_append_loop():
4132 return a == [0, 1, 2, 3, 4]
4135 def test_mutable_list_append_loop_if(self):
4136 def test_append_loop_if():
4144 return a == [0, 0, 0, 0, 4]
4147 def test_mutable_list_nested_loop(self):
4148 def test_nested_loop():
4154 return a == [0, 1, 1, 2]
4157 def test_mutable_list_function_inline(self):
4171 def test_mutable_list_reverse_empty(self):
4172 def test_reverse_empty():
4179 def test_mutable_list_reverse(self):
4184 return a == [4, 3, 2, 1]
4187 def test_mutable_tensor_list_reverse(self):
4188 def test_tensor_reverse():
4195 def test_mutable_list_pop_empty(self):
4197 def test_pop_empty():
4204 def test_mutable_list_pop(self):
4213 def test_mutable_list_pop2(self):
4222 def test_mutable_list_pop_at(self):
4231 def test_mutable_list_pop_at2(self):
4240 def test_mutable_list_pop_at_negative(self):
4241 def test_pop_at_negative():
4249 def test_mutable_list_pop_at_negative2(self):
4250 def test_pop_at_negative2():
4258 def test_mutable_list_pop_slice(self):
4259 def test_pop_slice():
4270 @unittest.skipIf(sys.version_info < (3, 3),
"clear not supported in version < 3.3")
4271 def test_mutable_list_clear_empty(self):
4272 def test_clear_empty():
4279 @unittest.skipIf(sys.version_info < (3, 3),
"clear not supported in version < 3.3")
4280 def test_mutable_list_clear(self):
4288 def test_mutable_list_insert(self):
4289 def test_list_insert():
4293 return a == [1, 2, 5, 3, 4]
4296 def test_mutable_list_insert_negative(self):
4297 def test_list_insert_negative():
4301 return a == [1, 2, 3, 5, 4]
4304 def test_mutable_list_insert_neg_out_of_bounds(self):
4305 def test_list_insert_neg_out_of_bounds():
4309 return a == [5, 1, 2, 3, 4]
4310 self.
checkScript(test_list_insert_neg_out_of_bounds, ())
4312 def test_mutable_list_insert_out_of_bounds(self):
4313 def test_list_insert_out_of_bounds():
4317 return a == [1, 2, 3, 4, 5]
4318 self.
checkScript(test_list_insert_out_of_bounds, ())
4320 def test_mutable_list_remove_not_existing(self):
4322 def test_list_remove_not_existing():
4329 test_list_remove_not_existing()
4331 def test_mutable_list_remove(self):
4332 def test_list_remove():
4336 return a == [1, 2, 4]
4339 def test_list_index_not_existing(self):
4341 def list_index_not_existing():
4348 list_index_not_existing()
4350 def test_list_index(self):
4358 def test_tensor_list_index(self):
4359 def tensor_list_index():
4366 def test_tensor_list_index_not_existing(self):
4368 def tensor_list_index_not_existing():
4375 tensor_list_index_not_existing()
4377 def test_list_count(self):
4385 def test_list_count_not_existing(self):
4386 def list_count_not_existing():
4393 def test_tensor_list_count(self):
4394 def tensor_list_count():
4401 def test_tensor_list_count_not_existing(self):
4402 def tensor_list_count_not_existing():
4407 self.
checkScript(tensor_list_count_not_existing, ())
4409 def test_mutable_list_remove_tensor(self):
4410 def test_list_remove_tensor():
4411 a = [torch.ones(1), torch.zeros(1), torch.ones(2)]
4412 a.remove(torch.zeros(1))
4417 def test_mutable_list_remove2(self):
4418 def test_list_remove2():
4425 def test_extend_list_mutable(self):
4427 def extend_list(a, b):
4433 for l
in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4434 for r
in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4437 def test_extend_list_immutable(self):
4439 def extend_list(a, b):
4445 for l
in [[], [1], [1, 2, 3]]:
4446 for r
in [[], [1], [1, 2, 3]]:
4449 def test_copy_list_mutable(self):
4455 for l
in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
4458 def test_copy_list_immutable(self):
4464 for l
in [[], [1], [1, 2, 3]]:
4467 def test_func_call(self):
4475 def func(alpha, beta, x, y): 4476 return add(mul(alpha, x), mul(beta, y)) 4478 alpha = torch.rand(1, dtype=torch.float, requires_grad=
True)
4479 beta = torch.rand(1, dtype=torch.float, requires_grad=
True)
4480 x = torch.rand(3, dtype=torch.float, requires_grad=
True)
4481 y = torch.rand(3, dtype=torch.float, requires_grad=
True)
4482 outputs = alpha * x + beta * y
4484 self.
checkScript(script, [alpha, beta, x, y], optimize=
False, outputs=outputs)
4486 def test_resize_input_ops(self):
4492 def out_op_graph_input():
4495 torch.mul(x, y, out=z)
4498 torch._C._jit_pass_shape_analysis(
4499 test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)),
False)
4500 self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
4501 out_op_graph_input()
4506 after_resize_alias = torch.zeros([2])
4510 before_resize_alias = b.sub_(1)
4514 after_resize_alias = b.add_(1)
4515 return after_resize_alias
4518 self.
run_pass(
'constant_propagation', g)
4519 torch._C._jit_pass_shape_analysis(
4520 g, (torch.zeros(1, 1),),
False)
4521 resize_node = g.findNode(
"aten::resize_")
4523 self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
4524 self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())
4527 before_resize = g.findNode(
"aten::sub_")
4528 self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())
4530 after_resize = g.findNode(
"aten::add_")
4531 self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())
4535 def test_resize_as():
4538 b = torch.zeros([2, 2])
4543 self.
run_pass(
'constant_propagation', g)
4544 torch._C._jit_pass_shape_analysis(
4545 g, (torch.zeros(1, 1),),
False)
4548 self.assertTrue(next(g.inputs()).type() != TensorType.get())
4550 self.assertTrue(next(g.outputs()).type() == TensorType.get())
4554 def test_view_shape_prop(self):
4556 def test_view_shape_prop(a): 4557 return a.view(size=[-1]) 4559 inputs = [torch.zeros(10, 10)]
4560 outputs = torch.zeros(100)
4562 real_outs = cu.test_view_shape_prop(*inputs)
4565 def test_view_listconstruct_shape_prop(self):
4570 return x.view(T, B, C)
4572 x = torch.randn(3, 1, 5, requires_grad=
True)
4574 torch._C._jit_pass_shape_analysis(graph, (x,),
False)
4575 a = next(graph.outputs()).type().kind()
4576 self.assertTrue(next(graph.outputs()).type().kind() !=
'TensorType')
4578 def test_integral_shape_inference(self):
4580 def test_integral_shape_inference(a): 4583 inputs = [torch.ones(10, 10).type(torch.LongTensor)]
4584 outputs = torch.ones(10, 10)
4586 self.
assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
4588 def test_fuser_multiple_blocks(self):
4590 def test_fuser_multiple_blocks(this, that, theother, meme): 4593 this = torch.cat([this, meme], dim=0) 4594 that = torch.cat([that, meme], dim=0) 4595 theother = torch.cat([theother, meme], dim=0) 4597 return this, that, theother 4600 inputs = [torch.ones(0, 10, 10)] * 3
4601 inputs += [torch.ones(1, 10, 10)]
4602 outputs = [torch.ones(20, 10, 10)] * 3
4604 self.
assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
4606 def test_dropout_script(self):
4608 eg = torch.zeros(1, 2, 3, requires_grad=
True)
4615 class MyDrop(nn.Module):
4616 def forward(self, x):
4622 @unittest.skip(
"RuntimeError: VariableType::ID() not implemented")
4623 def test_cast(self):
4628 x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=
True)
4629 out = Variable(torch.IntTensor([1, 2]), requires_grad=
True)
4630 self.
checkScript(script, [x], optimize=
True, outputs=[out], func=
'to_int')
4632 def test_python_frontend(self):
4635 q = x + y - z.sigmoid()
4638 if not x
and not y
and z:
4639 m = x
if not z
else y
4642 assert 1 == 1,
"hello" 4648 @unittest.skipIf(
not PY2,
"Requires python 2")
4649 def test_python_frontend_py2(self):
4651 raise Exception(
"hello")
4655 @unittest.skipIf(PY2,
"Requires python 3")
4656 def test_python_frontend_py3(self):
4658 raise Exception(
"hello")
4662 def _make_scalar_vars(self, arr, dtype):
4665 def test_string_print(self):
4667 print(a,
"a" 'b' '''c''' """d""", 2, 1.5)
4671 self.
checkScript(func, inputs, capture_output=
True)
4673 def test_while(self):
4674 def func(a, b, max):
4675 while bool(a < max):
4684 def test_fibb(self):
4692 while bool(i < lim):
4693 third = first + second
4698 somenum = somenum * 2
4701 i = i + dontmutateme
4705 return third, st, fs
4725 def test_if_for_in_range(self):
4740 def test_if_noelse(self):
4750 def test_if_is_none_dispatch(self):
4753 def test_lhs_none_rhs_none():
4758 elif None is not None:
4763 self.assertTrue(str(test_lhs_none_rhs_none.graph).count(
': int = prim::Constant') == 1)
4766 def test_lhs_opt_rhs_none(lhs=None):
4776 self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(
': int = prim::Constant') == 3)
4779 def test_lhs_none_rhs_opt(rhs=None):
4784 elif None is not rhs:
4789 self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(
': int = prim::Constant') == 3)
4792 def test_lhs_never_rhs_none(lhs):
4797 elif lhs
is not None:
4802 self.assertTrue(str(test_lhs_never_rhs_none.graph).count(
': int = prim::Constant') == 1)
4805 def test_lhs_none_rhs_never(rhs):
4810 elif None is not rhs:
4815 self.assertTrue(str(test_lhs_none_rhs_never.graph).count(
': int = prim::Constant') == 1)
4817 def test_explicit_bool_cast(self):
4820 def test_bool_cast(a):
4825 def test_while_nonexistent_value(self):
4828 def test_while(a, b): 4835 def test_while_nonexistent_cond_value(self):
4838 def test_while(a, b): 4845 def test_optional_refinement(self):
4847 def test_if_none_assignment(x):
4856 def test_ternary(x):
4858 x = x
if x
is not None else 2
4862 def test_not_none(x):
4870 if x
is not None and y
is not None:
4876 if not (x
is not None and y
is not None):
4882 def test_bool_expression(x):
4884 if x
is not None and x < 2:
4888 def test_nested_bool_expression(x, y):
4890 if x
is not None and x < 2
and y
is not None:
4899 if y
is None or x
is None:
4906 def test_manual_unwrap_opt(x):
4918 if x
is None or y
is None:
4923 def and_error(x, y):
4925 if x
is None and y
is None:
4934 x_none = x
is not None 4940 def named_var_and(x, y):
4942 x_none = x
is not None 4943 if y
is not None and x_none:
4946 def test_while_write_outer_then_read(self):
4956 def test_while_nest_if(self):
4972 def test_math_ops(self):
4975 return math.floor(1.5)
4979 def test_if_nest_while(self):
4992 def test_script_for_in_range(self):
4995 for i
in range(100):
4998 self.
checkScript(fn, (), outputs=4950, optimize=
True)
5000 def test_script_for_in_range_dynamic(self):
5003 for i
in range(100):
5011 def test_script_for_in_range_ast(self):
5013 def test_script_for_in_range_ast():
5015 for i
in range(100):
5022 self.
assertEqual(test_script_for_in_range_ast(), 161700)
5024 def test_script_for_in_range_if_ast(self):
5026 def test_script_for_in_range_if_ast(x):
5030 output = x.unsqueeze(0)
5032 output = torch.cat((output, x.unsqueeze(0)), dim=0)
5036 self.
assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
5038 def test_script_optional_none(self):
5048 self.
checkScript(none_stmt, [torch.arange(0, 2)], optimize=
True)
5049 self.
checkScript(none_args, [
None], optimize=
True)
5052 def test_script_optional_tensor_none(x=None):
5054 res = torch.zeros(1, dtype=torch.int8)
5061 fn = test_script_optional_tensor_none
5064 self.
assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1)))
5067 def test_script_optional_other_none(x=None):
5076 fn = test_script_optional_other_none
5081 def test_script_clamp_none(self):
5082 def test_script_clamp_max_none(x):
5083 return torch.clamp(x, min=2, max=
None)
5085 def test_script_clamp_max(x):
5086 return torch.clamp(x, max=2)
5088 def test_script_clamp_min_none(x):
5089 return torch.clamp(x, min=
None, max=2)
5091 def test_script_clamp_min(x):
5092 return torch.clamp(x, min=2)
5094 input = [torch.arange(0, 3)]
5095 self.
checkScript(test_script_clamp_max_none, input, optimize=
True)
5096 self.
checkScript(test_script_clamp_max, input, optimize=
True)
5097 self.
checkScript(test_script_clamp_min_none, input, optimize=
True)
5098 self.
checkScript(test_script_clamp_min, input, optimize=
True)
5100 def test_script_bool_constant(self):
5102 def test_script_bool_constant(): 5107 self.
checkScript(script, [], outputs[0],
True,
'test_script_bool_constant')
5109 def test_ternary(self):
5112 c = a + b
if bool(a > 3)
else b
5117 self.
checkScript(func, inputs_true, optimize=
True)
5118 self.
checkScript(func, inputs_false, optimize=
True)
5120 def test_print(self):
5122 q = (x + y).sigmoid()
5123 print(q, 1, 2, [1, 2], [1.0, 2.0])
5127 x = torch.arange(4., requires_grad=
True)
5128 y = torch.arange(0., 8, 2, requires_grad=
True)
5129 self.
checkScript(func, [x, y], optimize=
True, capture_output=
True)
5131 def test_format(self):
5133 print(
"{}, I'm a {}".format(
"Hello",
"test"))
5134 print(
"format blank".format())
5135 print(
"stuff before {}".format(
"hi"))
5136 print(
"{} stuff after".format(
"hi"))
5139 x = torch.arange(4., requires_grad=
True)
5140 self.
checkScript(func, [x], optimize=
True, capture_output=
True)
5142 def test_logical_short_circuit(self):
5144 def testNoThrows(t):
5146 if (
False and bool(t[1]))
or (
True or bool(t[1])):
5151 ifs = testNoThrows.graph.findAllNodes(
"prim::If", recurse=
False)
5155 self.assertTrue(len(ifs) == 3)
5156 self.assertTrue(ifs[0].findNode(
"prim::If")
is None)
5157 self.assertTrue(ifs[1].findNode(
"prim::If").findNode(
"prim::If")
is None)
5158 self.assertTrue(ifs[2].findNode(
"prim::If")
is None)
5162 c0 =
False or bool(t[1])
5167 c0 =
True and bool(t[1])
5171 with self.
assertRaisesRegex(RuntimeError,
"index 1 out of range for tensor of size"):
5173 with self.
assertRaisesRegex(RuntimeError,
"index 1 out of range for tensor of size"):
5176 def test_type_cast(self):
5177 template = dedent(
''' 5179 # type: ({from_type}) -> {to_type} 5183 def check_cast(from_type, to_type, value, raises=False):
5184 code = template.format(from_type=from_type, to_type=to_type)
5185 expected = getattr(builtins, to_type)(value)
5190 self.
checkScript(code, (value,), name=
'cast', outputs=expected)
5192 check_cast(
'int',
'float', 1)
5193 check_cast(
'int',
'bool', 1)
5194 check_cast(
'int',
'bool', 0)
5196 check_cast(
'float',
'int', 1.)
5197 check_cast(
'float',
'bool', 1.)
5198 check_cast(
'float',
'bool', 0.)
5200 check_cast(
'bool',
'int',
True)
5201 check_cast(
'bool',
'float',
True)
5203 def test_multiple_assignment(self):
5209 y, z = outer_func(x)
5215 def test_literals(self):
5217 return a.view(size=[1, 2, 3])
5222 def test_return(self):
5232 def multiple_returns(a):
5233 return a * 1., a * 2., a * 3.
5235 a = torch.randn(1, dtype=torch.float)
5236 self.checkScript(no_return, [a], optimize=
True)
5237 self.checkScript(void_return, [a], optimize=
True)
5238 self.checkScript(one_return, [a], optimize=
True)
5239 self.checkScript(multiple_returns, [a], optimize=
True)
5241 with self.assertRaisesRegex(RuntimeError,
"but is actually of type None"):
5243 def no_return_bad_annotation(a): 5244 # type: (Tensor) -> Tensor 5248 def test_error(self):
5252 s = Variable(torch.rand(5, 5, 5))
5254 with self.assertRaisesRegex(RuntimeError,
"failed in interpreter"):
5261 with self.assertRaisesRegex(RuntimeError,
"failed in interpreter"):
5262 bar(Variable(torch.rand(10), requires_grad=
True), Variable(torch.rand(9), requires_grad=
True))
5264 def test_binop_unsupported_error(self):
5265 with self.assertRaisesRegex(NotSupportedError,
"unsupported binary operator:"):
5271 def test_bitwise_ops(self):
5274 return 2 & 3, 2 ^ 3, 2 | 3
5276 self.checkScript(int_test, ())
5278 def bool_test(x, y):
5280 return x & y, x ^ y, x | y
5282 self.checkScript(bool_test, (
True,
False))
5283 self.checkScript(bool_test, (
True,
True))
5285 def tensor_test(x, y):
5286 return x & y, x ^ y, x | y
5291 self.checkScript(tensor_test, (x, y))
5293 def test_number_math(self):
5294 ops_template = dedent(
''' 5296 return {scalar1} {op} {scalar2} 5298 ops = [
'+',
'-',
'*',
'%',
'<',
'<=',
'>',
'>=',
'==',
'!=',
'//']
5299 funcs_template = dedent(
''' 5301 return {func}({scalar1}, {scalar2}) 5303 funcs = [
'min',
'max']
5304 scalars = [
'7',
'2',
'3',
'-3',
'3.14',
'0.125',
'-0.5',
'2.0',
'-2.0']
5305 scalar_pairs = [(scalar1, scalar2)
for scalar1
in scalars
for scalar2
in scalars]
5309 execWrapper(code, globals(), scope)
5312 self.assertEqual(cu.func(), scope[
'func']())
5314 for scalar1, scalar2
in scalar_pairs:
5316 code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2)
5319 code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2)
5322 def test_number_div(self):
5323 self.checkScript(div_int_future, (), optimize=
True)
5324 self.checkScript(div_float_future, (), optimize=
True)
5327 with self.assertRaisesRegex(RuntimeError,
'from __future__ import division'):
5329 with self.assertRaisesRegex(RuntimeError,
'from __future__ import division'):
5332 self.checkScript(div_int_nofuture, (), optimize=
True)
5333 self.checkScript(div_float_nofuture, (), optimize=
True)
5335 def test_floor_div(self):
5340 for i
in range(-8, 8):
5341 for j
in range(-8, 8):
5343 self.assertEqual(foo(i, j), i // j)
5345 with self.assertRaisesRegex(RuntimeError,
'division by 0'):
5348 def test_number_augassign(self):
5354 self.checkScript(func, (), optimize=
True)
5356 def test_number_neg(self):
5365 self.checkScript(func1, (), optimize=
True)
5366 self.checkScript(func2, (), optimize=
True)
5368 def _test_tensor_number_math(self, device='cpu'):
5369 template = dedent(
''' 5371 return {lhs} {op} {rhs} 5374 def test(op, const, swap_args):
5379 code = template.format(lhs=args[0], rhs=args[1], op=op)
5381 execWrapper(code, globals(), scope)
5383 self.assertEqual(cu.func(tensor), scope[
'func'](tensor))
5386 var_float = [1.4321, -1.2]
5388 ops = [
'+',
'-',
'*',
'%',
'<',
'<=',
'>',
'>=',
'==',
'!=',
'/']
5390 float_tensor = torch.randn(5, 5, device=device)
5391 double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
5392 long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device)
5393 long_tensor[long_tensor == 0] = 2
5395 tensors = [float_tensor, double_tensor, long_tensor]
5396 consts = var_int + var_float
5398 for op, tensor, const, swap_args
in product(ops, tensors, consts, [
True,
False]):
5401 if op ==
'/' and tensor.data_ptr() == long_tensor.data_ptr():
5405 if op ==
'%' and swap_args
is True:
5408 test(op, const, swap_args)
5410 def test_tensor_number_math(self):
5411 self._test_tensor_number_math()
5413 def test_torch_tensor_bad_input(self):
5414 with self.assertRaisesRegex(RuntimeError,
"Input list to torch.tensor must be of ints, floats, " 5415 "or bools, got None"):
5420 with self.assertRaisesRegex(RuntimeError,
"Note: empty lists are constructed as Tensor"):
5428 with self.assertRaisesRegex(RuntimeError,
"Expected sequence of length"):
5432 def test_torch_tensor_empty_list(self):
5440 self.assertNotEqual(t1.dtype, t2.dtype)
5446 self.checkScript(func, ())
5452 self.checkScript(func, ())
5454 def test_torch_tensor(self):
5455 template = dedent(
''' 5458 return torch.tensor(li {options}) 5461 lists = [
"2.5",
"4",
"True",
"False",
"[2]",
"[-.5]",
"[False, True, False]",
"[2, 2]",
5462 "torch.jit.annotate(List[int], [])",
"[2.5, 2.5]",
"[[2], [2]]",
"[[-.5], [2.2]]",
"[[False], [True]]"]
5464 dtypes = [
"",
", dtype=torch.float",
", dtype=torch.double",
", dtype=torch.half",
5465 ", dtype=torch.uint8",
", dtype=torch.int8",
", dtype=torch.short",
5466 ", dtype=torch.int",
", dtype=torch.long"]
5468 devices = [
'',
", device='cpu'"]
5470 devices.append(
", device='cuda'")
5472 option_pairs = [dtype + device
for dtype
in dtypes
for device
in devices]
5474 for option
in option_pairs:
5476 if "annotate" in li
and "dtype" not in option:
5478 code = template.format(list_create=li, options=option)
5480 exec(code, globals(), scope)
5483 t2 = scope[
'func']()
5484 if t1.dtype == torch.float16:
5485 self.assertTrue(str(t1) == str(t2))
5487 self.assertEqual(t1, t2)
5488 self.assertEqual(t1.dtype, t2.dtype)
5489 self.assertEqual(t1.device, t2.device)
5492 def test_tensor_to(self):
5493 template = dedent(
''' 5497 non_blocking = {non_blocking} 5501 def s(t, to_str, non_blocking=None, device=None, cuda=None):
5502 device = device
if device
is not None else str(t.device)
5503 non_blocking = non_blocking
if non_blocking
is not None else False 5504 cuda =
"cuda" if cuda
is None else cuda
5505 code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
5510 def test_copy_behavior(t, non_blocking=False):
5511 self.assertIs(t, s(t,
't.to(t, non_blocking=non_blocking)', non_blocking))
5512 self.assertIs(t, s(t,
't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
5513 self.assertIs(t, s(t,
't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
5514 self.assertIsNot(t, s(t,
't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
5515 self.assertIsNot(t, s(t,
't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
5516 self.assertIsNot(t, s(t,
't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))
5518 devices = [t.device]
5519 if t.device.type ==
'cuda':
5520 if t.device.index == -1:
5523 devices.append(
'cuda')
5524 for device
in devices:
5525 self.assertIs(t, s(t,
't.to(device, non_blocking=non_blocking)', non_blocking, device))
5526 self.assertIs(t, s(t,
't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
5527 self.assertIsNot(t, s(t,
't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
5528 self.assertIsNot(t, s(t,
't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
5529 non_blocking, device))
5532 test_copy_behavior(t)
5534 self.assertEqual(t.device, s(t,
"t.to('cpu')").device)
5535 self.assertEqual(t.device, s(t,
"t.to('cpu', dtype=torch.float32)").device)
5536 self.assertIs(torch.float32, s(t,
"t.to('cpu', dtype=torch.float32)").dtype)
5537 self.assertEqual(t.device, s(t,
"t.to(torch.float32)").device)
5538 self.assertIs(torch.float32, s(t,
"t.to(dtype=torch.float32)").dtype)
5539 self.assertEqual(t.data_ptr(), s(t,
"t.to('cpu')").data_ptr())
5540 self.assertEqual(t.data_ptr(), s(t,
"t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
5541 self.assertEqual(t.data_ptr(), s(t,
"t.to('cpu', copy=False)").data_ptr())
5542 self.assertNotEqual(t.data_ptr(), s(t,
"t.to('cpu', copy=True)").data_ptr())
5546 for non_blocking
in [
True,
False]:
5549 test_copy_behavior(b, non_blocking)
5550 self.assertEqual(b.device, s(b,
"t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
5551 self.assertEqual(a.device, s(b,
"t.to('cpu', non_blocking=non_blocking).device"))
5552 self.assertEqual(b.device, s(b,
"t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
5553 self.assertIs(torch.int32, s(b,
"t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
5554 self.assertEqual(a.device, s(b,
"t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
5555 self.assertIs(torch.int32, s(b,
"t.to(dtype=torch.int32)").dtype)
5556 self.assertEqual(b.device, s(b,
"t.to(dtype=torch.int32)").device)
5560 out_ref = t.to(torch.float32)
5561 out = s(t,
"t.to(torch.float32)")
5562 self.assertEqual(out_ref, out)
5566 self.assertEqual(grad_ref, grad)
5569 out_ref = t.to(
'cpu')
5570 out = s(t,
"t.to('cpu')")
5571 self.assertEqual(out_ref, out)
5575 self.assertEqual(grad_ref, grad)
5579 def func2(t, t_ref):
5582 func2.debug_disable_autodiff_subgraph_inlining()
5585 out_ref = t.to(t_ref)
5586 out = func2(t, t_ref)
5589 self.assertEqual(grad_ref, grad)
5591 @unittest.skipIf(
not RUN_CUDA,
"No CUDA")
5592 def test_tensor_number_math_cuda(self):
5593 self._test_tensor_number_math(device=
'cuda')
5599 return not bool(a > 1)
5601 self.checkScript(test_not_op, (
torch.tensor(2), ), optimize=
True)
5603 def test_is_isnot(self):
5605 template = dedent(
''' 5608 return {lhs} {op} {rhs} 5612 code = template.format(lhs=args[0], rhs=args[1], op=op)
5614 execWrapper(code, globals(), scope)
5619 "Failed with op: {}, lhs: {}, rhs: {}" 5620 .format(op, args[0], args[1])
5623 ops = [
'is',
'is not']
5624 type_literals = [
True,
False,
None, [1, 1]]
5627 for op, lhs, rhs
in product(ops, type_literals, type_literals):
5628 test(op, [lhs, rhs])
5630 def test_isinstance(self):
5632 template = dedent(
''' 5634 # type: ({type_hint}) -> bool 5635 return isinstance(x, {typ}) 5638 def test(inp, typ, type_hint):
5639 code = template.format(typ=typ, type_hint=type_hint)
5641 execWrapper(code, globals(), scope)
5646 "Failed with typ: {}" 5650 inputs = [
True, 1, 1.0,
torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
5651 type_literals = [
'bool',
'int',
'float',
'torch.Tensor',
'list',
'tuple',
5652 '(list, tuple)',
'(int, float, bool)']
5653 type_annotations = [
'bool',
'int',
'float',
'Tensor',
'List[int]',
'Tuple[float]',
5657 for inp, typ, type_hint
in zip(inputs, type_literals, type_annotations):
5658 test(inp, typ, type_hint)
5661 with self.assertRaisesRegex(RuntimeError,
"Optional isinstance check is not supported"):
5665 return isinstance(x, int)
5667 def test_python_call(self):
5675 def test_call_python(a): 5687 inputs = self._make_scalar_vars([1], torch.float)
5688 outputs = self._make_scalar_vars([54], torch.float)
5690 self.assertEqual(cu.test_call_python(*inputs), outputs[0])
5692 def test_python_call_failure(self):
5693 with self.assertRaisesRegex(RuntimeError,
"undefined value pyfunc2"):
5701 def test_call_python(a): 5713 inputs = self._make_scalar_vars([1], torch.float)
5714 outputs = self._make_scalar_vars([54], torch.float)
5716 self.assertEqual(cu.test_call_python(*inputs), outputs)
5718 def test_python_call_annotation(self):
5724 return pyfunc(a) + pyfunc(a)
5726 inputs = self._make_scalar_vars([1], torch.float)
5727 outputs = self._make_scalar_vars([6], torch.float)
5728 self.assertEqual(foo(*inputs), outputs[0])
5730 def test_python_call_annoytation_failure(self):
5731 with self.assertRaisesRegex(RuntimeError,
"undefined value pyfunc2"):
5737 return pyfunc2(a) + pyfunc(a)
5739 inputs = self._make_scalar_vars([1], torch.float)
5740 outputs = self._make_scalar_vars([6], torch.float)
5742 self.assertEqual(foo(*inputs), outputs[0])
5744 def test_desugar_module(self):
5750 c = F.prelu(x, slope)
5753 x = torch.arange(-3., 4)
5755 self.checkScript(fn, [x, slope], optimize=
True)
5757 def test_script_docstring(self):
5759 def with_docstring(x):
5762 """y is the same as x""" 5764 self.assertEqual(with_docstring.__doc__,
'test str')
5766 def test_script_method_docstring(self):
5768 @torch.jit.script_method
5769 def with_docstring(self, x):
5772 """y is the same as x""" 5775 self.assertEqual(a.with_docstring.__doc__,
'test str')
5777 @unittest.skipIf(TEST_WITH_UBSAN
or not torch.fbgemm_is_cpu_supported(),
5778 'Quantized RNN requires FBGEMM. FBGEMM does not play' 5779 ' well with UBSAN at the moment, so we skip the test if' 5780 ' we are in a UBSAN environment.')
5781 def test_rnn_cell_quantized(self):
5785 torch.nn.LSTMCell(d_in, d_hid).float(),
5786 torch.nn.GRUCell(d_in, d_hid).float(),
5787 torch.nn.RNNCell(d_in, d_hid).float(),
5789 if isinstance(cell, torch.nn.LSTMCell):
5791 elif isinstance(cell, torch.nn.GRUCell):
5793 elif isinstance(cell, torch.nn.RNNCell):
5807 vals = [[100, -155],
5815 vals = vals[:d_hid * num_chunks]
5816 cell.weight_ih = torch.nn.Parameter(
5818 requires_grad=
False)
5819 cell.weight_hh = torch.nn.Parameter(
5821 requires_grad=
False)
5823 ref = copy.deepcopy(cell)
5828 [100, -155]], dtype=torch.float)
5829 h0_vals = [[-155, 100],
5841 def __init__(self, cell):
5842 super(ScriptWrapper, self).__init__()
5845 @torch.jit.script_method
5846 def forward(self, x, hiddens):
5848 return self.cell(x, hiddens)
5852 def __init__(self, cell):
5853 super(ScriptWrapper, self).__init__()
5856 @torch.jit.script_method
5857 def forward(self, x, hiddens):
5859 return self.cell(x, hiddens)
5861 cell = ScriptWrapper(cell)
5862 outs = cell(x, hiddens)
5863 cell = self.getExportImportCopyWithPacking(cell)
5865 outs = cell(x, hiddens)
5866 ref_outs = ref(x, hiddens)
5868 self.assertEqual(len(outs), len(ref_outs))
5869 for out, ref_out
in zip(outs, ref_outs):
5872 def test_script_module(self):
5875 super(M1, self).__init__(
False)
5876 self.weight = nn.Parameter(torch.randn(2))
5878 @torch.jit.script_method
5879 def forward(self, thing):
5880 return self.weight + thing
5882 class PModule(nn.Module):
5884 super(PModule, self).__init__()
5885 self.a = nn.Parameter(torch.randn(2, 3))
5887 def forward(self, a):
5892 super(M2, self).__init__(
False)
5895 self.sub2 = PModule()
5897 self.weight = nn.Parameter(torch.randn(2, 3))
5898 self.bias = nn.Parameter(torch.randn(2))
5902 return self.weight.mm(a) 5906 @torch.jit.script_method
5907 def doit(self, input):
5909 return self.weight.mm(input)
5911 @torch.jit.script_method
5912 def doit2(self, input):
5913 return self.weight.mm(input)
5915 @torch.jit.script_method
5916 def forward(self, input):
5917 a = self.doit(input)
5918 b = self.doit2(input)
5920 d = self.sub2(input)
5921 return a + b + self.bias + self.sub(a) + c + d
5923 input = torch.randn(3, 2)
5924 a = m2.weight.mm(input)
5925 b = m2.weight.mm(input)
5926 c = m2.weight.mm(input)
5927 d = m2.sub2.a.mm(input)
5928 ref = a + b + m2.bias + m2.sub.weight + a + c + d
5929 self.assertEqual(ref, m2.forward(input))
5930 m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
5931 m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
5932 m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
5933 m2.sub2.a.data.zero_()
5934 self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
5936 def test_filecheck(self):
5939 FileCheck().check(
"2").check(
"3").check(
"2").run(file)
5940 FileCheck().check(
"232").run(file)
5942 with self.assertRaisesRegex(RuntimeError,
'Expected to find "22"'):
5943 FileCheck().check(
"22").run(file)
5944 with self.assertRaisesRegex(RuntimeError,
"CHECK: 3"):
5945 FileCheck().check(
"3").check(
"3").run(file)
5949 def test_check_count():
5951 FileCheck().check_count(
"2", 5).run(file)
5952 FileCheck().check_count(
"22", 2).run(file)
5953 FileCheck().check_count(
"222", 1).run(file)
5955 with self.assertRaisesRegex(RuntimeError,
'Expected to not find'):
5956 FileCheck().check_count(
"2", 4, exactly=
True).run(file)
5958 with self.assertRaisesRegex(RuntimeError,
'Expected to find "22"'):
5959 FileCheck().check_count(
"22", 3).run(file)
5961 with self.assertRaisesRegex(RuntimeError,
"CHECK-COUNT-6: 2"):
5962 FileCheck().check_count(
"2", 6).run(file)
5966 def test_check_same():
5968 FileCheck().check_same(
"22").run(file)
5970 with self.assertRaisesRegex(RuntimeError,
"Expected to not find"):
5971 FileCheck().check_same(
"33").run(file)
5975 FileCheck().check(
"2").check_same(
"3").run(file)
5976 FileCheck().check_count(
"2", 2).check_same(
"3").run(file)
5980 def test_check_next():
5982 FileCheck().check(
"1").check_next(
"2").check_next(
"3").run(file)
5983 FileCheck().check_next(
"1").check_next(
"2").check_next(
"3").run(file)
5985 with self.assertRaisesRegex(RuntimeError,
"Expected to find"):
5986 FileCheck().check(
"1").check_next(
"2").run(
"12")
5988 with self.assertRaisesRegex(RuntimeError,
"Expected to not find"):
5989 FileCheck().check(
"1").check_next(
"2").run(
"1\n\n2")
5993 def test_check_dag():
5994 fc = FileCheck().check_dag(
"1").check_dag(
"2").check_not(
"2")
5999 fc.check_not(
"3").check_dag(
"1").check_dag(
"2").check_not(
"3")
6003 fc = FileCheck().check_dag(
"1").check_dag(
"2").check(
"3")
6004 with self.assertRaisesRegex(RuntimeError,
'Expected to find "3" but did not find it'):
6009 def test_check_not():
6010 FileCheck().check_not(
"2").check(
"1").run(
"12")
6011 FileCheck().check(
"2").check_not(
"2").run(
"12")
6013 with self.assertRaisesRegex(RuntimeError,
'Expected to not find "2"'):
6014 FileCheck().check_not(
"2").check(
"1").run(
"21")
6016 with self.assertRaisesRegex(RuntimeError,
'Expected to not find "1"'):
6017 FileCheck().check(
"2").check_not(
"1").run(
"21")
6020 fb = FileCheck().check_count(
"2", 2).check_count(
"2", 2).check_not(
"2")
6021 with self.assertRaisesRegex(RuntimeError,
'Expected to not find "2"'):
6024 fb = FileCheck().check_count(
"2", 2).check_not(
"1").check_count(
"2", 2)
6025 with self.assertRaisesRegex(RuntimeError,
'Expected to not find "1"'):
6028 def test_script_module_call_noscript(self):
6031 super(M, self).__init__(
False)
6035 return torch.ones(2, 2) + self.value
6037 @torch.jit.script_method
6038 def forward(self, input):
6039 return input + self.foo()
6042 input = torch.randn(2, 2)
6044 self.assertEqual(o, input + torch.ones(2, 2) + 1)
6049 self.assertEqual(o, input + torch.ones(2, 2) + 2)
6051 def test_script_module_nochange_submodule(self):
6054 super(M, self).__init__(
False)
6055 self.sub = nn.Linear(5, 5)
6057 @torch.jit.script_method
6058 def forward(self, input):
6059 return self.sub(input)
6062 input = torch.randn(1, 5, 5)
6064 self.assertEqual(o, m.sub(input))
6065 with self.assertRaisesRegex(RuntimeError,
"cannot re-assign"):
6066 m.sub = nn.Linear(5, 5)
6068 def test_script_inline_trace_multiple_args(self):
6071 super(M, self).__init__(
False)
6073 def forward(self, input, input2):
6074 return input + input2
6078 super(M2, self).__init__(
False)
6081 @torch.jit.script_method
6082 def forward(self, inp):
6083 return self.m(inp, inp)
6086 m2(torch.zeros(4, 3))
6088 def test_script_module_const(self):
6091 __constants__ = [
'b',
'i',
'c']
6094 super(M, self).__init__(
False)
6099 @torch.jit.script_method
6101 return self.b, self.i, self.c
6105 self.assertEqual(o0, 0)
6106 self.assertEqual(o1, 1)
6107 self.assertEqual(o2, 3.5)
6109 def test_script_module_fail_const(self):
6112 super(M, self).__init__(
False)
6115 @torch.jit.script_method
6118 with self.assertRaisesRegex(RuntimeError,
"is not usable in a script method"):
6121 def test_script_module_valid_consts(self):
6125 __constants__ = [
'a',
'b',
'c',
'd',
'e',
'f',
'g',
'h',
'i']
6128 super(Foo, self).__init__(
False)
6132 with tester.assertRaisesRegex(
6134 "'Linear' object for attribute 'd' is not a valid constant"):
6135 self.d = [nn.Linear(3, 4)]
6136 self.e =
lambda x: x
6138 tester.assertTrue(type(self.f)
is tuple)
6139 self.g = [3, (3, 4), 5]
6140 with tester.assertRaisesRegex(TypeError,
"not a valid constant"):
6142 with tester.assertRaisesRegex(TypeError,
"not a valid constant"):
6147 def test_script_module_param_buffer_mutation(self):
6151 super(ModuleBufferMutate, self).__init__(
False)
6152 self.register_buffer(
'running_var',
torch.tensor(0, dtype=torch.long))
6154 @torch.jit.script_method
6157 self.running_var += 1
6158 return self.running_var
6160 m = ModuleBufferMutate()
6161 self.assertEqual(m(), 1)
6163 self.assertEqual(m(), 1)
6165 def test_script_module_for(self):
6167 __constants__ = [
'b']
6170 super(M, self).__init__(
False)
6171 self.b = [1, 2, 3, 4]
6173 @torch.jit.script_method
6181 self.assertEqual(m(), 10)
6183 def test_script_module_for2(self):
6186 super(Sub, self).__init__(
False)
6187 self.weight = nn.Parameter(torch.randn(2))
6189 @torch.jit.script_method
6190 def forward(self, thing):
6191 return self.weight + thing
6194 __constants__ = [
'mods']
6197 super(M, self).__init__(
False)
6198 self.mods = nn.ModuleList([Sub()
for i
in range(10)])
6200 @torch.jit.script_method
6201 def forward(self, v):
6212 self.assertEqual(o, v)
6214 def test_script_module_const_submodule_fail(self):
6217 super(Sub, self).__init__(
False)
6218 self.weight = nn.Parameter(torch.randn(2))
6220 @torch.jit.script_method
6221 def forward(self, thing):
6222 return self.weight + thing
6226 super(M, self).__init__(
False)
6227 self.mods = [Sub()
for _
in range(10)]
6229 @torch.jit.script_method
6235 with self.assertRaisesRegex(RuntimeError,
"did you forget to add it __constants__"):
6241 self.tensor_constant = torch.ones(2)
6243 @torch.jit.script_method
6245 return self.tensor_constant + 2
6247 with self.assertRaisesRegex(RuntimeError,
"Tensors must be added to a module as a buffer or parameter"):
6253 self.
param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
6254 self.register_buffer(
'derived', torch.neg(self.
param).detach().clone())
6257 self.register_buffer(
'pack_called', torch.zeros(1, dtype=torch.long))
6259 self.register_buffer(
'unpack_called', torch.zeros(1, dtype=torch.long))
6261 @torch.jit.script_method
6263 self.pack_called.set_(torch.ones(1, dtype=torch.long))
6264 self.derived.set_(torch.rand(1, dtype=torch.float).detach())
6266 @torch.jit.script_method
6268 self.unpack_called.set_(torch.ones(1, dtype=torch.long))
6269 self.derived.set_(torch.neg(self.
param).detach())
6271 @torch.jit.script_method
6272 def forward(self, x):
6273 return x + self.derived
6275 def test_pack_unpack_state(self):
6277 x = torch.rand(3, 4, dtype=torch.float)
6281 self.assertFalse(sm.pack_called.item())
6282 self.assertFalse(sm.unpack_called.item())
6283 imported = self.getExportImportCopyWithPacking(sm)
6285 self.assertTrue(sm.pack_called.item())
6287 self.assertTrue(sm.unpack_called.item())
6292 self.assertTrue(imported.unpack_called.item())
6295 def test_pack_unpack_nested(self):
6298 super(SubSubMod, self).__init__()
6299 self.register_buffer(
'buf', torch.ones(3, 4) * 3)
6301 @torch.jit.script_method
6303 self.buf.set_(torch.zeros(1, dtype=torch.double))
6305 @torch.jit.script_method
6307 self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
6309 @torch.jit.script_method
6310 def forward(self, x):
6315 super(SubMod, self).__init__()
6316 self.register_buffer(
'buf', torch.ones(3, 4) * 2)
6317 self.ssm = SubSubMod()
6319 @torch.jit.script_method
6321 self.buf.set_(torch.zeros(1, dtype=torch.double))
6323 @torch.jit.script_method
6325 self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
6327 @torch.jit.script_method
6328 def forward(self, x):
6329 return self.ssm(x + self.buf)
6333 super(Mod, self).__init__()
6334 self.submod = SubMod()
6335 self.register_buffer(
'buf', torch.ones(3, 4) * 1)
6337 @torch.jit.script_method
6339 self.buf.set_(torch.zeros(1, dtype=torch.double))
6341 @torch.jit.script_method
6343 self.buf.set_(torch.ones(3, 4, dtype=torch.double))
6345 @torch.jit.script_method
6346 def forward(self, x):
6347 return self.submod(x + self.buf)
6351 m.apply(
lambda s: s._pack())
6353 m.apply(
lambda s: s._unpack())
6356 def test_script_module_not_tuple(self):
6358 __constants__ = [
'mods']
6361 super(M, self).__init__(
False)
6364 @torch.jit.script_method
6365 def forward(self, v):
6369 with self.assertRaisesRegex(RuntimeError,
"cannot be used as a tuple"):
6372 def test_script_module_list_sequential_error(self):
6374 def __init__(self, mod_list):
6375 super(M, self).__init__(
False)
6376 self.mods = mod_list
6378 @torch.jit.script_method
6379 def forward(self, v):
6384 with self.assertRaisesRegex(RuntimeError,
"Did you forget to add it to __constants"):
6385 a = M(nn.Sequential(nn.ReLU()))
6386 with self.assertRaisesRegex(RuntimeError,
"Did you forget to add it to __constants"):
6387 a = M(nn.ModuleList([nn.ReLU()]))
6389 def test_script_sequential_for(self):
6392 super(Sub, self).__init__(
False)
6393 self.weight = nn.Parameter(torch.randn(2))
6395 @torch.jit.script_method
6396 def forward(self, thing):
6397 return self.weight + thing
6400 __constants__ = [
'mods']
6403 super(M, self).__init__(
False)
6404 self.mods = nn.Sequential(Sub(), Sub(), Sub())
6406 @torch.jit.script_method
6407 def forward(self, v):
6412 @torch.jit.script_method
6413 def forward2(self, v):
6422 self.assertEqual(o, v)
6425 self.assertEqual(o2, v)
6427 def test_script_sequential_multi_output_fail(self):
6430 super(Sub, self).__init__(
False)
6431 self.weight = nn.Parameter(torch.randn(2))
6433 @torch.jit.script_method
6434 def forward(self, thing):
6435 return self.weight + thing
6439 super(ReturnMulti, self).__init__(
False)
6441 @torch.jit.script_method
6442 def forward(self, x):
6446 __constants__ = [
'someseq']
6449 super(HaveSequential, self).__init__(
False)
6450 self.someseq = nn.Sequential(
6456 @torch.jit.script_method
6457 def forward(self, x):
6458 return self.someseq(x)
6460 with self.assertRaisesRegex(RuntimeError,
"(Tensor, Tensor, Tensor)"):
6461 hs = HaveSequential()
6465 def test_constant_insert_fail_lint(self):
6473 self.run_pass(
'constant_propagation', foo.graph)
6474 self.assertTrue(
"aten::tensor" in str(foo.graph))
6476 def test_script_sequential_in_mod_list(self):
6479 super(Sub, self).__init__(
False)
6480 self.weight = nn.Parameter(torch.randn(2))
6482 @torch.jit.script_method
6483 def forward(self, thing):
6484 return self.weight + thing
6487 __constants__ = [
'mods']
6490 super(M, self).__init__(
False)
6491 self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])
6493 @torch.jit.script_method
6494 def forward(self, v):
6495 for mod
in self.mods:
6500 graph = str(m.graph)
6501 self.assertTrue(graph.count(
"aten::add") == 5)
6502 self.assertTrue(
"python" not in graph)
6504 def test_script_nested_mod_list(self):
6507 super(Sub, self).__init__(
False)
6508 self.weight = nn.Parameter(torch.randn(2))
6510 @torch.jit.script_method
6511 def forward(self, thing):
6512 return self.weight + thing
6515 __constants__ = [
'mods']
6518 super(M, self).__init__(
False)
6519 self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])
6521 @torch.jit.script_method
6522 def forward(self, v):
6523 for mod
in self.mods:
6529 graph = str(m.graph)
6530 self.assertTrue(graph.count(
"aten::add") == 4)
6531 self.assertTrue(
"python" not in graph)
6533 def test_constant_as_attr(self):
6535 __constants__ = [
'dim']
6538 super(M, self).__init__(
False)
6541 @torch.jit.script_method
6542 def forward(self, v):
6543 return torch.cat([v, v, v], dim=self.dim)
6544 v = torch.zeros(1, 1)
6545 self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
6551 def forward(self, *inputs):
6553 for i
in range(1, len(inputs)):
6561 def forward(self, rep):
6562 return rep, rep, rep
6564 def test_script_star_expr(self):
6568 super(M2, self).__init__(
True)
6570 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
6573 @torch.jit.script_method
6574 def forward(self, rep):
6579 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6581 def test_script_star_expr_string(self):
6584 super(M2, self).__init__(
True)
6586 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
6590 def forward(self, rep): 6596 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6602 def forward(self, *inputs):
6604 for i
in range(1, len(inputs)):
6606 return output, output, output
6608 def test_script_star_assign(self):
6611 super(M2, self).__init__(
True)
6614 def forward(self, rep): 6615 head, *tail = self.g(rep) 6620 self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
6622 def test_script_module_star_assign2(self):
6625 super(M2, self).__init__(
True)
6628 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
6629 _force_outplace=
True)
6631 def forward(self, rep): 6632 *head, tail = self.g(rep, rep, rep) 6637 self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
6639 def test_script_module_star_assign2_inplace(self):
6642 super(M2, self).__init__(
True)
6645 (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
6646 _force_outplace=
False)
6648 def forward(self, rep): 6649 *head, tail = self.g(rep, rep, rep) 6657 self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
6659 def test_script_module_star_assign_fail_pythonop(self):
6661 with self.assertRaisesRegex(RuntimeError,
"cannot be used as a tuple"):
6664 super(M2, self).__init__(
True)
6667 return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
6670 def forward(self, rep): 6676 m(torch.zeros(4, 3))
6678 def test_script_module_star_assign_fail_builtin(self):
6679 with self.assertRaisesRegex(RuntimeError,
"cannot be used as a tuple"):
6682 super(M2, self).__init__(
True)
6685 def forward(self, rep): 6686 a, *b = torch.neg(rep) 6691 m(torch.zeros(4, 3))
6693 def test_pack_padded_pad_packed_trace(self):
6697 class PadPackedWrapper(torch.nn.Module):
6699 super(PadPackedWrapper, self).__init__()
6701 def forward(self, x, seq_lens):
6702 x = pack_padded_sequence(x, seq_lens)
6703 x, _ = pad_packed_sequence(x)
6706 x = np.ones((T, B, C))
6707 seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
6711 x[seq_lens[b]:, b, :] = 0
6712 seq_lens = torch.from_numpy(seq_lens)
6713 x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=
True)
6715 m = PadPackedWrapper()
6721 grad = x.grad.clone()
6724 y_traced = m_traced(x, seq_lens)
6725 loss_traced = torch.sum(y_traced)
6726 loss_traced.backward()
6727 grad_traced = x.grad.clone()
6729 self.assertEqual(y_traced, x)
6730 self.assertEqual(y_traced, y)
6731 self.assertEqual(grad, grad_traced)
6736 def test_script_outputs(self):
6737 with self.assertRaisesRegex(RuntimeError,
"cannot be used as a tuple"):
6747 with self.assertRaisesRegex(RuntimeError,
"too many values to unpack"):
6754 @unittest.skipIf(
not RUN_CUDA,
"requires CUDA")
6755 def test_script_get_device_cuda(self):
6758 return a.get_device()
6760 v = torch.randn(1, device=
'cuda')
6761 self.assertEqual(foo(v), 0)
6763 def test_script_chunk(self):
6766 b, c = torch.chunk(a, dim=0, chunks=2)
6768 v = torch.rand(10, 3)
6769 self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
6771 def test_rnn_trace_override(self):
6776 class RNNTraceWrapper(torch.nn.Module):
6777 def __init__(self, cell_type):
6778 super(RNNTraceWrapper, self).__init__()
6779 if cell_type ==
'RNN':
6780 self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
6781 elif cell_type ==
'LSTM':
6782 self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
6783 elif cell_type ==
'GRU':
6784 self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)
6786 def forward(self, x, seq_lens):
6787 x = pack_padded_sequence(x, seq_lens)
6789 x, _ = pad_packed_sequence(x)
6792 for cell_type
in [
'RNN',
'LSTM',
'GRU']:
6793 x = torch.ones(T, B, C, requires_grad=
True)
6794 seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
6796 m = RNNTraceWrapper(cell_type)
6802 grad = x.grad.clone()
6805 y_traced = m_traced(x, seq_lens)
6806 loss_traced = torch.sum(y_traced)
6807 loss_traced.backward()
6808 grad_traced = x.grad.clone()
6810 self.assertEqual(y_traced, y)
6811 self.assertEqual(grad, grad_traced)
6816 def test_python_call_non_tensor(self):
6824 x = torch.ones(3, 4)
6825 a, b = foo(x, 3, (x, 3))
6828 self.assertEqual((6, torch.ones(3, 4) + 1), bar())
6830 def test_python_call_non_tensor_wrong(self):
6831 with self.assertRaisesRegex(RuntimeError,
r"but instead got value of type tuple"):
6842 def test_tuples(self):
6858 v = torch.rand(10, 3)
6859 self.checkScript(foo, (v,))
6861 with self.assertRaisesRegex(RuntimeError,
r"variable 'a' previously has type \(Tensor, Tensor\)"):
6868 def test_if_tuple_sizes(self):
6869 with self.assertRaisesRegex(RuntimeError,
"Type mismatch"):
6871 def diff_tuple_sizes(x):
6873 c0 = ((x, x), (x, x, x))
6875 c0 = ((x, x, x), (x, x))
6878 def test_if_different_type(self):
6879 with self.assertRaisesRegex(RuntimeError,
"Type mismatch: c0 is set to type int " 6880 "in the true branch and type float in the false branch:"):
6882 def diff_type_used():
6889 with self.assertRaisesRegex(RuntimeError,
"variable 'c0' previously has type float"):
6891 def diff_existing_type(x):
6899 def diff_type_unused():
6908 def test_if_list_cat(self):
6912 if bool(x.sum() < 1):
6918 b = torch.zeros(2, 4)
6919 test_list.graph.propagate_shapes((b,),
False)
6921 def test_if_supertype(self):
6923 def tensor_unifying(x, y, z):
6932 a = torch.zeros(2, 2, dtype=torch.float)
6933 b = torch.zeros(2, 4, dtype=torch.long)
6934 c = torch.zeros(2, 4, dtype=torch.float)
6936 tensor_unifying.graph.propagate_shapes((a, b, c),
False)
6937 if_outputs = list(tensor_unifying.graph.findNode(
"prim::If").outputs())
6938 self.assertTrue(if_outputs[0].type().str() ==
"Float(*, *)")
6939 self.assertTrue(if_outputs[1].type().str() ==
"Tensor")
6940 self.assertTrue(if_outputs[2].type().str() ==
"Tensor")
6942 def test_list_unify(self):
6946 with self.assertRaisesRegex(RuntimeError,
"int[] in the true branch and type None[]"):
6948 def list_optional_fails(x):
6957 def list_tensors(x):
6960 a = torch.zeros([1, 1])
6963 a = torch.zeros([1, 2])
6967 self.run_pass(
'constant_propagation', list_tensors.graph)
6969 m._create_method_from_graph(
"forward", list_tensors.graph)
6971 self.getExportImportCopy(m)
6973 def test_type_annotations_repeated_list(self):
6978 self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
6979 self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
6982 def float_fn_call():
6983 print(float_fn(1.0, 1.0))
6984 print(float_fn(1.0, (1.0, 1.0, 1.0)))
6990 self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
6991 self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
6996 print(int_fn((1, 1, 1)))
6998 with self.assertRaisesRegex(RuntimeError,
"must be a positive integer:"):
7005 with self.assertRaisesRegex(RuntimeError,
"Unknown type constructor"):
7008 # type: (int, Tuple[int, int[2]]) -> List[int] 7009 return x # noqa: T484 7012 def test_ntuple_builtins(self):
7016 return _single(1), _pair(2), _triple(3), _quadruple(4)
7019 return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
7021 self.checkScript(test_ints, ())
7022 self.checkScript(test_floats, ())
7024 def test_embedding_renorm_grad_error(self):
7028 def embedding_norm(input, embedding_matrix, max_norm):
7029 F.embedding(input, embedding_matrix, max_norm=0.01)
7032 def embedding_norm_script(input, embedding_matrix, max_norm):
7034 F.embedding(input, embedding_matrix, max_norm=0.01)
7036 for _
in [embedding_norm, embedding_norm_script]:
7038 embedding_matrix = torch.randn(10, 3)
7040 var1 = torch.randn(10, 3, requires_grad=
True)
7041 var2 = var1.detach().requires_grad_()
7042 output1 = var1 * embedding_matrix
7043 output2 = var2 * embedding_matrix
7045 output1.sum().backward()
7047 ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
7048 with self.assertRaisesRegex(RuntimeError,
"modified"):
7049 output2.sum().backward()
7051 def test_type_annotations(self):
7054 return x, x * 2, x * 3
7056 with self.assertRaisesRegex(RuntimeError,
r"need 4 values .* found only 3"):
7059 x, y, z, w = fn(x, x)
7061 with self.assertRaisesRegex(RuntimeError,
r"too many values .* need 2 but found 3"):
7074 def fn_string(str, strpair):
7076 str1, str2 = strpair
7077 return str, 2, str1, str2
7079 x = torch.ones(2, 2)
7080 self.checkScript(fn_unpack, (x,), optimize=
True)
7081 self.checkScript(fn_index, (x,), optimize=
True)
7082 self.checkScript(fn_string, (
"1", (
"3",
"4")), optimize=
True)
7084 def test_type_annotations_varargs(self):
7085 def fn_varargs(x, *args):
7086 return args[0]
if args
else x
7089 return fn_varargs(x)
7092 return fn_varargs(x, y)
7095 return fn_varargs(x, y, z)
7097 x, y, z = [torch.randn(2, 2)
for _
in range(3)]
7098 self.checkScript(fn1, (x, y, z), optimize=
True)
7099 self.checkScript(fn2, (x, y, z), optimize=
True)
7100 self.checkScript(fn3, (x, y, z), optimize=
True)
7102 @unittest.skipIf(
not PY35,
"Python 3.5 needed")
7103 def test_type_annotation_py3(self):
7104 import importlib.util
7108 from torch import Tensor 7109 from typing import Tuple 7111 def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]: 7112 return (x, y + z, z) 7115 with tempfile.TemporaryDirectory()
as tmp_dir:
7116 script_path = os.path.join(tmp_dir,
'script.py')
7117 with open(script_path,
'w')
as f:
7119 fn = get_fn(
'test_type_annotation_py3', script_path)
7121 with self.assertRaisesRegex(RuntimeError,
r"expected a value of type Tensor for argument" 7122 r" '0' but found \(Tensor, Tensor\)"):
7125 x, y = fn((x, x), x, x)
7128 with self.assertRaisesRegex(RuntimeError,
r"too many values .* need 2 but found 3"):
7134 with self.assertRaisesRegex(RuntimeError,
r"need 4 values .* found only 3"):
7137 x, y, z, w = fn(x, x, x)
7141 y, z, w = fn(x, x, x)
7144 self.checkScript(good_fn, (torch.ones(2, 2),), optimize=
True)
7146 def test_type_annotation_module(self):
7152 def bar(self, x, y):
7156 def baz(self, x, y):
7159 class ModuleTooMany(BaseModule):
7160 @torch.jit.script_method
7161 def method(self, x):
7162 return self.foo(x, x)
7164 class ModuleTooFew(BaseModule):
7165 @torch.jit.script_method
7166 def method(self, x):
7169 class ModuleTooManyAssign(BaseModule):
7170 @torch.jit.script_method
7171 def method(self, x):
7172 y, z, w = self.bar(x, x)
7175 class ModuleDefault(BaseModule):
7176 @torch.jit.script_method
7177 def method(self, x):
7181 with self.assertRaisesRegex(RuntimeError,
"expected at most 1 arguments but found 2"):
7183 with self.assertRaisesRegex(RuntimeError,
"argument 1 not provided"):
7185 with self.assertRaisesRegex(RuntimeError,
"need 3 values .* found only 2"):
7186 ModuleTooManyAssign()
7187 with self.assertRaisesRegex(RuntimeError,
"argument 1 not provided."):
7190 def test_script_define_order(self):
7195 @torch.jit.script_method
7196 def call_foo(self, input):
7197 return self.foo(input)
7199 @torch.jit.script_method
7200 def foo(self, input):
7203 self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
7205 def test_script_define_order_recursive_fail(self):
7210 @torch.jit.script_method
7211 def call_foo(self, input):
7212 return self.foo(input)
7214 @torch.jit.script_method
7215 def foo(self, input):
7216 self.call_foo(input)
7218 with self.assertRaisesRegex(RuntimeError,
'called recursively involving'):
7221 def test_script_kwargs_fn_call(self):
7226 @torch.jit.script_method
7227 def call_foo(self, input):
7228 return self.foo(input=input, bar=1)
7230 @torch.jit.script_method
7231 def foo(self, bar, input):
7235 self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
7237 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
7238 def test_trace_of_script(self):
7246 a = torch.ones(1, dtype=torch.float)
7248 @_trace(torch.zeros(1, dtype=torch.float))
7250 return foo(b - 1.0, a) + 1.0
7253 self.assertTrue(
"Dynamic" not in str(use.graph))
7255 self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
7256 self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
7258 def test_if_define(self):
7283 a = torch.ones(1, dtype=torch.long)
7284 b = torch.zeros(1, dtype=torch.long)
7285 self.assertEqual(1, foo(a))
7286 self.assertEqual(2, foo(b))
7287 self.assertEqual(1, foo2(a))
7288 self.assertEqual(2, foo2(b))
7289 self.assertEqual(1, foo3(a))
7290 self.assertEqual(2, foo3(b))
7292 def test_script_module_export_submodule(self):
7295 super(M1, self).__init__(
False)
7296 self.weight = nn.Parameter(torch.randn(2))
7298 @torch.jit.script_method
7299 def forward(self, thing):
7300 return self.weight + thing
7304 super(M2, self).__init__(
False)
7307 self.weight = nn.Parameter(torch.randn(2, 3))
7308 self.bias = nn.Parameter(torch.randn(2))
7311 return self.weight.mm(a) 7314 @torch.jit.script_method
7315 def doit(self, input):
7316 return self.weight.mm(input)
7318 @torch.jit.script_method
7319 def doit2(self, input):
7320 return self.weight.mm(input)
7322 @torch.jit.script_method
7323 def doit3(self, input):
7324 return input + torch.ones([1], dtype=torch.double)
7326 @torch.jit.script_method
7327 def forward(self, input):
7328 a = self.doit(input)
7329 b = self.doit2(input)
7331 return a + b + self.bias + c
7334 m_import = self.getExportImportCopy(m_orig)
7336 input = torch.randn(3, 2)
7337 self.assertEqual(m_orig.doit(input), m_import.doit(input))
7338 self.assertEqual(m_orig.hi(input), m_import.hi(input))
7339 self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
7340 self.assertEqual(m_orig.forward(input), m_import.forward(input))
7342 @skipIfNoTorchVision
7343 def test_script_module_trace_resnet18(self):
7344 x = torch.ones(1, 3, 224, 224)
7345 m_orig =
torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
7346 m_import = self.getExportImportCopy(m_orig)
7348 input = torch.randn(1, 3, 224, 224, requires_grad=
True)
7349 output_orig = m_orig(input)
7350 output_orig.sum().backward()
7351 grad_orig = input.grad.clone()
7354 output_import = m_import(input)
7355 output_import.sum().backward()
7356 grad_import = input.grad.clone()
7358 self.assertEqual(output_orig, output_import)
7359 self.assertEqual(grad_orig, grad_import)
7361 @skipIfNoTorchVision
7362 def test_script_module_script_resnet(self):
7363 def conv1x1(in_planes, out_planes, stride=1):
7364 """1x1 convolution""" 7365 return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=
False)
7367 def conv3x3(in_planes, out_planes, stride=1):
7368 """3x3 convolution with padding""" 7369 return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7370 padding=1, bias=
False)
7374 __constants__ = [
'downsample']
7376 def __init__(self, inplanes, planes, stride=1, downsample=None):
7377 super(BasicBlock, self).__init__()
7378 self.conv1 = conv3x3(inplanes, planes, stride)
7379 self.bn1 = nn.BatchNorm2d(planes)
7380 self.relu = nn.ReLU(inplace=
True)
7381 self.conv2 = conv3x3(planes, planes)
7382 self.bn2 = nn.BatchNorm2d(planes)
7383 self.downsample = downsample
7384 self.stride = stride
7386 @torch.jit.script_method
7387 def forward(self, x):
7392 out = self.relu(out)
7394 out = self.conv2(out)
7397 if self.downsample
is not None:
7398 residual = self.downsample(x)
7401 out = self.relu(out)
7406 __constants__ = [
'layer1',
'layer2',
'layer3',
'layer4']
7408 def __init__(self, block, layers, num_classes=1000):
7409 super(ResNet, self).__init__()
7411 self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
7413 self.bn1 = nn.BatchNorm2d(64)
7414 self.relu = nn.ReLU(inplace=
True)
7415 self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
7416 self.layer1 = self._make_layer(block, 64, layers[0])
7417 self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
7418 self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
7419 self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
7420 self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
7421 self.fc = nn.Linear(512 * block.expansion, num_classes)
7423 for m
in self.modules():
7424 if isinstance(m, nn.Conv2d):
7425 nn.init.kaiming_normal_(m.weight, mode=
'fan_out', nonlinearity=
'relu')
7426 elif isinstance(m, nn.BatchNorm2d):
7427 nn.init.constant_(m.weight, 1)
7428 nn.init.constant_(m.bias, 0)
7430 def _make_layer(self, block, planes, blocks, stride=1):
7432 if stride != 1
or self.inplanes != planes * block.expansion:
7433 downsample = nn.Sequential(
7434 conv1x1(self.inplanes, planes * block.expansion, stride),
7435 nn.BatchNorm2d(planes * block.expansion),
7439 layers.append(block(self.inplanes, planes, stride, downsample))
7440 self.inplanes = planes * block.expansion
7441 for _
in range(1, blocks):
7442 layers.append(block(self.inplanes, planes))
7444 return nn.Sequential(*layers)
7446 @torch.jit.script_method
7447 def forward(self, x):
7459 x = x.view(x.size(0), -1)
7464 resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
7466 resnet18_imported = self.getExportImportCopy(resnet18)
7468 input = torch.randn(1, 3, 224, 224, requires_grad=
True)
7469 output_orig = resnet18(input)
7470 output_orig.sum().backward()
7471 grad_orig = input.grad.clone()
7473 output_import = resnet18_imported(input)
7474 output_import.sum().backward()
7475 grad_import = input.grad.clone()
7477 self.assertEqual(output_orig, output_import)
7478 self.assertEqual(grad_orig, grad_import)
7480 def test_script_module_export_tensor_type(self):
7483 def __init__(self, type):
7484 super(M, self).__init__(
False)
7485 self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
7487 @torch.jit.script_method
7491 for type
in [torch.float, torch.double]:
7493 m_import = self.getExportImportCopy(m_orig)
7495 self.assertTrue(m_orig.param.storage().size() == 25)
7496 self.assertEqual(m_orig.foo(), m_import.foo())
7497 self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
7499 @unittest.skipIf(
not RUN_CUDA,
"testing cuda tensors require CUDA")
7500 def test_script_module_export_tensor_cuda(self):
7504 super(M, self).__init__(
False)
7505 self.param = torch.nn.Parameter(torch.zeros((5, 5), device=
'cuda:0').random_())
7507 @torch.jit.script_method
7512 m_import = self.getExportImportCopy(m_orig)
7514 self.assertTrue(m_orig.param.storage().size() == 25)
7515 self.assertTrue(m_import.foo().device == torch.device(
'cuda:0'))
7516 self.assertEqual(m_orig.foo(), m_import.foo())
7517 self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
7519 def test_script_module_export_blocks(self):
7521 def __init__(self, n, m):
7522 super(M, self).__init__()
7523 self.weight = torch.nn.Parameter(torch.rand(n, m))
7525 @torch.jit.script_method
7526 def forward(self, input):
7527 if bool(input.sum() > 0):
7528 output = self.weight.mv(input)
7530 output = self.weight + input
7533 m_orig = M(200, 200)
7534 m_import = self.getExportImportCopy(m_orig)
7537 self.assertEqual(m_orig(t), m_import(t))
7539 def test_script_module_export_shared_storage(self):
7543 super(M, self).__init__(
False)
7544 self.param1 = torch.nn.Parameter(torch.rand(5, 5))
7545 self.param2 = torch.nn.Parameter(self.param1[3])
7546 self.param3 = torch.nn.Parameter(torch.rand(5, 5))
7547 self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
7549 @torch.jit.script_method
7551 return self.param1 + self.param2 + self.param3 + self.param4
7554 m_import = self.getExportImportCopy(m_orig)
7556 self.assertEqual(m_orig.foo(), m_import.foo())
7557 self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
7558 self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
7560 def test_onnx_export_script_module(self):
7563 super(ModuleToExport, self).__init__()
7565 @torch.jit.script_method
7566 def forward(self, x):
7570 mte = ModuleToExport()
7571 outputs = mte(torch.zeros(1, 2, 3))
7573 mte, (torch.zeros(1, 2, 3),),
None, verbose=
False,
7574 example_outputs=outputs))
7576 def test_trace_nested_datatypes(self):
7579 return [[x + 1, x - 1], [x + 2, x - 2]]
7583 return list_stuff[0][0], list_stuff[1][1]
7586 x = torch.rand(5, 6)
7587 self.assertEqual(bar(x), traced(x))
7590 def test_onnx_export_func_with_warnings(self):
7592 def func_with_warning(inp):
7595 class WarningTest(torch.nn.Module):
7597 super(WarningTest, self).__init__()
7599 def forward(self, x):
7600 return func_with_warning(x)
7602 outputs = WarningTest()(torch.randn(42))
7605 WarningTest(), torch.randn(42),
None, verbose=
False,
7606 example_outputs=outputs)
7608 def test_onnx_export_script_python_fail(self):
7611 super(ModuleToInline, self).__init__()
7613 def forward(self, x):
7618 super(ModuleToExport, self).__init__()
7619 self.mod = ModuleToInline()
7621 @torch.jit.script_method
7622 def forward(self, x):
7626 mte = ModuleToExport()
7627 outputs = mte(torch.zeros(1, 2, 3))
7629 with self.assertRaisesRegex(RuntimeError,
"Couldn't export Python operator"):
7631 example_outputs=outputs)
7633 def test_onnx_export_script_inline_trace(self):
7636 super(ModuleToInline, self).__init__()
7638 def forward(self, x):
7643 super(ModuleToExport, self).__init__()
7646 @torch.jit.script_method
7647 def forward(self, x):
7651 mte = ModuleToExport()
7652 outputs = mte(torch.zeros(1, 2, 3))
7654 mte, (torch.zeros(1, 2, 3),),
None, verbose=
False,
7655 example_outputs=outputs))
7657 def test_onnx_export_script_inline_script(self):
7660 super(ModuleToInline, self).__init__()
7662 @torch.jit.script_method
7663 def forward(self, x):
7668 super(ModuleToExport, self).__init__()
7669 self.mod = ModuleToInline()
7671 @torch.jit.script_method
7672 def forward(self, x):
7676 mte = ModuleToExport()
7677 outputs = mte(torch.zeros(1, 2, 3))
7679 mte, (torch.zeros(1, 2, 3),),
None, verbose=
False,
7680 example_outputs=outputs))
7682 def test_onnx_export_script_module_loop(self):
7685 super(ModuleToExport, self).__init__()
7687 @torch.jit.script_method
7688 def forward(self, x):
7696 mte = ModuleToExport()
7697 outputs = mte(torch.zeros(1, 2, 3))
7699 mte, (torch.zeros(1, 2, 3),),
None, verbose=
False,
7700 example_outputs=outputs))
7702 def test_onnx_export_script_truediv(self):
7705 super(ModuleToExport, self).__init__()
7707 @torch.jit.script_method
7708 def forward(self, x):
7712 mte = ModuleToExport()
7713 outputs = mte(torch.zeros(1, 2, 3))
7715 mte, (torch.zeros(1, 2, 3),),
None, verbose=
False,
7716 example_outputs=outputs))
7718 def test_onnx_raw_export_script_truediv(self):
7721 super(ModuleToExport, self).__init__()
7723 @torch.jit.script_method
7724 def forward(self, x):
7728 mte = ModuleToExport()
7729 outputs = mte(torch.zeros(1, 2, 3))
7731 mte, (torch.zeros(1, 2, 3),),
None, verbose=
False,
7732 example_outputs=outputs, export_raw_ir=
True))
7734 def test_onnx_export_script_non_alpha_add_sub(self):
7737 super(ModuleToExport, self).__init__()
7739 @torch.jit.script_method
7740 def forward(self, x):
7744 mte = ModuleToExport()
7745 outputs = torch.LongTensor([mte(torch.rand(3, 4))])
7747 mte, (torch.rand(3, 4),),
None, verbose=
False,
7748 example_outputs=outputs))
7750 def test_onnx_export_script_module_if(self):
7753 super(ModuleToExport, self).__init__()
7755 @torch.jit.script_method
7756 def forward(self, x):
7757 if bool(torch.sum(x) > 0):
7761 mte = ModuleToExport()
7762 outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
7764 mte, (torch.zeros(1, 2, 3),),
None, verbose=
False,
7765 example_outputs=outputs))
7767 def test_onnx_export_script_inline_params(self):
7770 super(ModuleToInline, self).__init__()
7771 self.m = torch.nn.Parameter(torch.ones(3, 3))
7772 self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
7774 @torch.jit.script_method
7775 def forward(self, x):
7776 return torch.mm(x, self.m)
7780 super(ModuleToExport, self).__init__()
7781 self.mod = ModuleToInline()
7782 self.param = torch.nn.Parameter(torch.ones(3, 4))
7784 @torch.jit.script_method
7785 def forward(self, x):
7787 return torch.mm(y, self.param)
7789 mte = ModuleToExport()
7790 result = mte(torch.zeros(2, 3))
7791 reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
7792 self.assertEqual(result, reference)
7794 mte, (torch.ones(2, 3),),
None, verbose=
False,
7795 example_outputs=result, propagate=
True))
7797 def test_trace_with_size(self):
7798 @_trace(torch.zeros(1, 1))
7809 self.assertEqual(8, bar(torch.ones(1, 1)))
7811 def test_tracing_slicing(self):
7812 @_trace(torch.zeros(10))
7823 a = torch.arange(0, 8)
7824 b = torch.arange(0, 20)
7825 self.assertEqual(foo_trace(a), foo_script(a))
7826 self.assertEqual(foo_trace(a), foo(a))
7827 self.assertNotEqual(foo_trace(a), foo_trace(b))
7829 def test_tracing_indexing(self):
7830 @_trace(torch.zeros(10))
7841 a = torch.arange(0, 8)
7842 b = torch.arange(0, 20)
7843 self.assertEqual(foo_script(a), foo_trace(a))
7844 self.assertEqual(foo_trace(a), foo(a))
7845 self.assertNotEqual(foo_trace(a), foo_trace(b))
7847 def test_index_select_shape_prop(self):
7851 return torch.index_select(x, index=y, dim=1)
7853 a = torch.zeros(2, 2)
7854 b = torch.zeros(4, dtype=torch.long)
7855 torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b),
False)
7856 FileCheck().check(
"Double(2, 4)").run(str(foo.graph))
7858 def test_onnx_export_speculate(self):
7861 def __init__(self, m):
7862 super(Foo, self).__init__()
7865 @torch.jit.script_method
7866 def forward(self, x):
7871 c = torch.sum(x) > 4
7881 linear =
torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float))
7888 outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
7890 outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
7894 (torch.ones(1, 10, dtype=torch.float), ),
7895 None, verbose=
False, example_outputs=outputs_f1)
7896 self.assertExpected(onnx_ish, subname=
'f1')
7899 (torch.ones(1, 10, dtype=torch.float), ),
7900 None, verbose=
False, example_outputs=outputs_f2)
7901 self.assertExpected(onnx_ish, subname=
'f2')
7903 def test_onnx_export_shape_reshape(self):
7904 class Foo(torch.nn.Module):
7905 def forward(self, x):
7907 x = x.repeat(5, 1, 1)
7913 outputs = foo(torch.zeros(1, 2, 3))
7916 example_outputs=outputs)
7917 self.assertExpected(s)
7919 def test_shape_analysis_loop(self):
7936 self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=
False)
7938 def test_intlist_args(self):
7948 x = torch.randn(8, 8, 8)
7949 self.checkScript(func_1, [x], optimize=
True)
7950 self.checkScript(func_2, [x], optimize=
True)
7951 self.checkScript(func_3, [x], optimize=
True)
7953 def test_wrong_implicit_expand(self):
7955 @_trace(torch.zeros(3), torch.zeros(1))
7961 self.assertEqual(a + b, foo(a, b))
7963 def test_builtin_args_fails(self):
7965 with self.assertRaisesRegex(RuntimeError,
'expected at most'):
7968 torch.sum(a, a, a, a)
7970 with self.assertRaisesRegex(RuntimeError,
'argument self not provided'):
7975 with self.assertRaisesRegex(RuntimeError,
'specified twice'):
7978 torch.sum(a, self=a)
7980 with self.assertRaisesRegex(RuntimeError,
'not provided'):
7985 with self.assertRaisesRegex(RuntimeError,
'for argument \'tensors\' but found Tensor'):
7990 with self.assertRaisesRegex(RuntimeError,
r'argument \'tensors\' but found int\[\]'):
7995 with self.assertRaisesRegex(RuntimeError,
'Lists must contain only a single type'):
7998 a.expand(size=[3, [4]])
8000 with self.assertRaisesRegex(RuntimeError,
'xpected a value of type Tensor for argument \'self\''):
8005 def test_builtin_args(self):
8009 return torch.cat([a, a])
8011 self.checkScript(t0, (torch.zeros(1, 1),))
8015 return torch.cat(dim=1, tensors=[a, a])
8017 self.checkScript(t1, (torch.zeros(1, 1, 2),))
8025 return torch.sum(a, dim=b, keepdim=
False)
8027 self.checkScript(t2, (torch.zeros(1, 1, 2),))
8029 def test_parser_type_annotations(self):
8031 def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]: 8035 self.assertExpected(cu.__getattr__(
'foo').pretty_print_schema())
8037 def test_parser_type_annotations_comment(self):
8040 # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor] 8044 self.assertExpected(cu.__getattr__(
'foo').pretty_print_schema())
8046 def test_parser_type_annotations_unknown_type(self):
8047 with self.assertRaisesRegex(RuntimeError,
r'Unknown type name Foo'):
8049 def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]: 8053 def test_parser_type_annotations_subscript_non_ident(self):
8054 with self.assertRaisesRegex(RuntimeError,
r'Subscripted type must be a type identifier'):
8056 def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]: 8060 def test_parser_type_annotations_subscript_tensor(self):
8061 with self.assertRaisesRegex(RuntimeError,
r'Unknown type constructor Tensor'):
8063 def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: 8067 def test_parser_type_annotations_incompatible_expression(self):
8068 with self.assertRaisesRegex(RuntimeError,
r'Expression of type \+ cannot be used in a type expression'):
8070 def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]: 8074 def test_gather_dynamic_index(self):
8079 return gather1 + gather2
8081 self.checkScript(t, (torch.zeros(3, 2, 3),))
8083 def test_slice_dynamic_index(self):
8088 slice2 = x[zero:one]
8089 return slice1 + slice2
8091 self.checkScript(t, (torch.zeros(3, 2, 3),))
8094 """ This test checks several things: 8095 1. An expand node was inserted before the addmm operating on the 8097 2. The fused form of addmm appears in the ultimate graph that's 8099 3. A sum op was emitted for accumulating gradients along the 0th 8100 (expanded) dimension of the bias term. 8101 4. The correct symbolic representation for the backward pass of the 8102 mm operator was emitted (x.t() -> mm) 8104 TODO: we should actually check these conditions once we have a way 8105 to dump the GraphExecutor state. Namely the processed forward graph 8106 and the backward graph. 8109 def addmm_grad_test(b, x, w):
8110 return torch.addmm(b, x, w)
8113 w_init = torch.rand(2, 5)
8114 b_init = torch.rand(5)
8115 x = torch.rand(3, 2)
8124 y = addmm_grad_test(b, x, w)
8128 b_ref = b_init.clone()
8129 b_ref.requires_grad_()
8130 w_ref = w_init.clone()
8131 w_ref.requires_grad_()
8132 y_ref = torch.addmm(b_ref, x, w_ref)
8133 y_ref.sum().backward()
8138 def test_zeros(self):
8140 __constants__ = [
'd']
8143 self.
d = torch.device(
'cpu')
8145 @torch.jit.script_method
8147 return torch.zeros([1, 1, 2], dtype=torch.float, device=self.
d, layout=torch.strided)
8151 self.
assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
8153 def test_vararg_zeros(self):
8155 return torch.zeros(3, 4, 5, dtype=torch.int)
8159 def test_rand(self):
8161 a = torch.rand([3, 4])
8166 def test_erase_number_types(self):
8174 FileCheck().check(
"int = prim::Constant").check(
"aten::add_").run(str(graph))
8175 self.
run_pass(
'remove_inplace_ops', graph)
8176 self.
run_pass(
'erase_number_types', graph)
8178 FileCheck().check_not(
"int = prim::Constant").check_not(
"aten::add_").run(str(graph))
8180 def test_mm_batching(self):
8183 def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
8184 for i
in range(x.size(0)):
8185 hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
8190 inputs = get_lstm_inputs(
'cpu', training=
True, seq_length=10)
8191 slstm(*inputs).sum().backward()
8193 fw_graph = slstm.graph_for(*inputs)
8194 bw_graph = backward_graph(slstm, diff_graph_idx=0)
8195 self.assertTrue(
'prim::MMBatchSide' in str(fw_graph))
8196 self.assertTrue(
'prim::MMTreeReduce' in str(bw_graph))
8198 sout = slstm(*inputs)
8204 def test_loop_unrolling(self):
8207 for i
in range(int(x)):
8212 self.
run_pass(
'loop_unrolling', graph)
8214 FileCheck().check(
"prim::Loop").check_count(
"aten::sub", unroll_factor) \
8215 .check(
"prim::Loop").check(
"aten::sub").run(str(graph))
8218 def test_loop_unrolling_const(self):
8231 def check(fn, name):
8233 self.
run_pass(
'loop_unrolling', graph)
8235 FileCheck().check_not(
"prim::Loop'").run(str(graph))
8238 check(fn,
'add_const')
8239 check(fn2,
'add_iter')
8241 def test_loop_unrolling_nested(self):
8245 for j
in range(int(x)):
8250 self.
run_pass(
'loop_unrolling', graph)
8253 FileCheck().check(
"prim::Loop").check(
"prim::Loop").check_count(
'aten::sub', unroll_factor) \
8254 .check(
"prim::Loop").check(
"aten::sub").run(str(graph))
8257 def test_loop_unroll_unused_counter(self):
8260 for _
in range(int(x)):
8265 self.
run_pass(
'loop_unrolling', graph)
8266 FileCheck().check(
"prim::Loop").check_not(
"aten::add").check(
"return") \
8269 def test_loop_unroll_negative(self):
8272 for _
in range(int(x)):
8283 def test_where(self):
8285 return torch.where(x > 0.0, x, y)
8287 self.
checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
8289 def test_where_method(self):
8291 return x.where(x > 0.0, y)
8293 self.
checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
8295 def test_reassign_module_lhs(self):
8296 with self.
assertRaisesRegex(RuntimeError,
'Cannot re-assign \'self\' because it has type value and self is' 8297 ' not a first-class value. Only reassignments to first-class values are allowed'):
8299 @torch.jit.script_method
8300 def forward(self, x):
8307 def test_reassign_module_rhs(self):
8308 with self.
assertRaisesRegex(RuntimeError,
'Cannot re-assign \'x\' to a value of type module because x is not a' 8309 ' first-class value. Only reassignments to first-class values are allowed'):
8311 @torch.jit.script_method
8312 def forward(self, x):
8319 def test_unknown_builtin(self):
8322 def unknown_builtin(x):
8325 def test_return_tuple(self):
8326 def return_tuple(x):
8331 def test_method_no_self(self):
8334 @torch.jit.script_method
8336 return torch.zeros(3, 4)
8340 def test_return_stmt_not_at_end(self):
8348 def test_for_range_no_arg(self):
8349 with self.
assertRaisesRegex(RuntimeError,
r'range\(\) expects 1 argument but got 0'):
8351 def range_no_arg(x):
8356 def test_list_iterables(self):
8357 with self.
assertRaisesRegex(RuntimeError,
'List of iterables is not supported currently'):
8359 def list_iterables(x): 8360 for i, j in [2, 3, 4], [5, 6, 7]: 8366 def test_for_tuple_unpack(self):
8367 with self.
assertRaisesRegex(RuntimeError,
'Iteration variable unpacking is not supported'):
8369 def for_tuple_unpack(x, y): 8370 for i, j in [[3, 4], [5, 6], [7, 8]]: 8376 def test_single_starred_lhs(self):
8377 with self.
assertRaisesRegex(RuntimeError,
'A Starred expression may only appear on the lhs within the presence' 8378 ' of another non-starred expression'):
8380 def single_starred_lhs(x): 8386 def test_singleton_tuple_unpack(self):
8392 def test_multi_reduction(self):
8395 'augmented assignment can only have one LHS expression'):
8397 def multi_reduction(x): 8402 def test_invalid_call_arguments(self):
8405 def invalid_call_arguments(x):
8406 return torch.unsqueeze(3, 4, 5, 6, 7, 8)
8408 def test_invalid_lhs_assignment(self):
8411 def invalid_lhs_assignment(x): 8416 def test_multi_starred_expr_lhs(self):
8417 with self.
assertRaisesRegex(RuntimeError,
'Only one starred expression is allowed on the lhs'):
8419 def multi_starred_expr_lhs(): 8420 a, *b, *c = [1, 2, 3, 4, 5, 6] 8424 def test_pack_tuple_into_non_var(self):
8425 with self.
assertRaisesRegex(RuntimeError,
'Cannot pack a tuple into a non-variable'):
8427 def pack_tuple_into_non_var(x): 8432 def test_print_kwargs(self):
8433 with self.
assertRaisesRegex(RuntimeError,
'print doesn\'t accept any keyword arguments'):
8435 def print_kwargs(x): 8436 print(x, flush=True) 8440 def test_builtin_use_as_value(self):
8443 def builtin_use_as_value(x):
8446 def test_wrong_use_as_tuple(self):
8452 def wrong_use_as_tuple(self):
8456 def test_wrong_attr_lookup(self):
8457 with self.
assertRaisesRegex(RuntimeError,
'attribute lookup is not defined on builtin'):
8459 def wrong_attr_lookup(self, x):
8460 a = x.unsqueeze.myattr
8463 def test_wrong_use_as_callable(self):
8466 def wrong_use_as_callable(x):
8469 def test_python_val_doesnt_have_attr(self):
8473 def python_val_doesnt_have_attr():
8478 def test_wrong_module_attr_lookup(self):
8479 with self.
assertRaisesRegex(RuntimeError,
'python value of type \'type\' cannot be used as a value:'):
8483 def wrong_module_attr_lookup():
8486 def test_wrong_method_call_inputs(self):
8490 @torch.jit.script_method
8491 def foo(self, x, y):
8494 @torch.jit.script_method
8495 def forward(self, x, y):
8499 def test_single_starred_expr_for_loop(self):
8504 for *a in [1, 2, 3]: 8509 def test_duplicate(self):
8519 def test_call_ge(self):
8520 with self.
assertRaisesRegex(RuntimeError,
'expected at most 1 arguments but found 3'):
8521 @_trace(torch.zeros(1, 2, 3))
8527 return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
8529 def test_wrong_return_type(self):
8530 with self.
assertRaisesRegex(RuntimeError,
'but instead got value of type tuple'):
8533 return torch.zeros(3, 4), torch.zeros(4, 5)
8536 def wrong_return_type():
8541 def test_call_python_fn_from_tracing_fn(self):
8545 @_trace(torch.rand(3, 4))
8547 return python_fn(x) + 1
8551 FileCheck().check(
"aten::neg").run(str(traced_fn.graph))
8553 def test_call_python_mod_from_tracing_fn(self):
8554 class PythonMod(torch.nn.Module):
8556 super(PythonMod, self).__init__()
8557 self.
param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=
False)
8559 def forward(self, x):
8560 return torch.mm(x, self.
param)
8564 @_trace(torch.rand(3, 4))
8570 self.assertTrue(len(list(traced_fn.graph.inputs())) == 1)
8571 FileCheck().check(
"aten::mm").check(
"aten::add").run(str(traced_fn.graph))
8573 def test_call_traced_fn_from_tracing_fn(self):
8574 @_trace(torch.rand(3, 4))
8578 @_trace(torch.rand(3, 4))
8580 return traced_fn1(x) + 1
8582 FileCheck().check(
"aten::neg").check_same(
"scope: traced_fn1").check(
"aten::add") \
8583 .run(str(traced_fn.graph))
8585 def test_call_traced_mod_from_tracing_fn(self):
8586 class TracedModule(torch.nn.Module):
8588 super(TracedModule, self).__init__()
8589 self.
param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=
False)
8591 def forward(self, x):
8592 return torch.mm(x, self.
param)
8596 @_trace(torch.rand(3, 4))
8602 FileCheck().check(
"prim::Constant[value=<Tensor>]").check(
"aten::mm") \
8603 .check(
"aten::add").run(str(traced_fn.graph))
8605 def test_call_script_fn_from_tracing_fn(self):
8610 @_trace(torch.rand(3, 4))
8612 return script_fn(x) + 1
8614 FileCheck().check(
"aten::neg").check(
"aten::add").run(str(traced_fn.graph))
8616 def test_call_script_mod_from_tracing_fn(self):
8620 super(ScriptMod, self).__init__()
8621 self.
param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=
False)
8623 @torch.jit.script_method
8624 def forward(self, x):
8631 @_trace(torch.rand(3, 4))
8636 FileCheck().check(
"prim::Constant[value=<Tensor>]").check(
"Loop") \
8637 .run(str(traced_fn.graph))
8639 def test_call_python_fn_from_traced_module(self):
8643 class TracedModule(torch.nn.Module):
8645 super(TracedModule, self).__init__()
8646 self.
param = torch.nn.Parameter(torch.rand(4, 3))
8648 def forward(self, x):
8649 return torch.mm(python_fn(x), self.
param)
8656 self.assertTrue(len(list(tm.graph.inputs())) == 2)
8657 FileCheck().check(
"aten::neg").check(
"aten::mm").run(str(tm.graph))
8659 def test_call_python_mod_from_traced_module(self):
8660 class PythonModule(torch.nn.Module):
8662 super(PythonModule, self).__init__()
8663 self.
param = torch.nn.Parameter(torch.rand(5, 7))
8665 def forward(self, x):
8666 return torch.mm(x, self.
param)
8668 class TracedModule(torch.nn.Module):
8670 super(TracedModule, self).__init__()
8671 self.
param = torch.nn.Parameter(torch.rand(4, 5))
8672 self.
mod = PythonModule()
8674 def forward(self, x):
8675 return self.
mod(torch.mm(x, self.
param)) + 1.0
8681 self.assertTrue(len(list(tm.graph.inputs())) == 3)
8682 FileCheck().check_not(
"value=<Tensor>").check_count(
"aten::mm", 2).check(
"aten::add") \
8685 def test_call_traced_fn_from_traced_module(self):
8686 @_trace(torch.rand(3, 4))
8690 class TracedModule(torch.nn.Module):
8692 super(TracedModule, self).__init__()
8693 self.
param = torch.nn.Parameter(torch.rand(4, 5))
8695 def forward(self, x):
8696 return traced_fn(torch.mm(x, self.
param))
8700 FileCheck().check(
"aten::mm").check_same(
"scope: TracedModule") \
8701 .check_next(
"aten::neg").check(
"scope: TracedModule/traced_fn") \
8704 def test_trace_hierarchy(self):
8710 super(AnotherScriptMod, self).__init__()
8711 self.
param = torch.nn.Parameter(torch.rand(1, 2, 3))
8713 @torch.jit.script_method
8715 return torch.zeros(4, 5)
8719 super(SomeScriptMod, self).__init__()
8720 self.
asm = AnotherScriptMod()
8722 @torch.jit.script_method
8724 return torch.zeros(3, 4)
8726 @torch.jit.script_method
8728 return torch.zeros(4, 3)
8730 class TraceMe(torch.nn.Module):
8732 super(TraceMe, self).__init__()
8733 self.
ssm = SomeScriptMod()
8735 def forward(self, x):
8736 return self.ssm.bar() + x
8743 self.assertTrue(traced.ssm._has_method(
'foo'))
8744 self.assertTrue(hasattr(traced.ssm,
'foo'))
8748 self.assertTrue(imported.ssm._has_method(
'foo'))
8749 self.assertTrue(hasattr(imported.ssm,
'foo'))
8751 self.assertTrue(imported.ssm.asm._has_method(
'bar'))
8752 self.assertTrue(hasattr(imported.ssm.asm,
'bar'))
8754 self.assertTrue(imported.ssm.asm._has_parameter(
'param'))
8755 self.assertTrue(hasattr(imported.ssm.asm,
'param'))
8757 def test_trace_parameter(self):
8758 class Param(nn.Module):
8760 super(Param, self).__init__()
8761 self.register_parameter(
"bias", nn.Parameter(torch.Tensor(4, 4)))
8763 def forward(self, x):
8767 def __init__(self, model):
8768 super(M3, self).__init__(
False)
8771 @torch.jit.script_method
8772 def forward(self, x):
8775 class M2(nn.Module):
8776 def __init__(self, model):
8777 super(M2, self).__init__()
8780 def forward(self, x):
8784 def __init__(self, model):
8785 super(M1, self).__init__(
False)
8788 @torch.jit.script_method
8789 def forward(self, x):
8792 module = M1(Param())
8796 def test_call_traced_module_from_traced_module(self):
8797 class TracedModule1(torch.nn.Module):
8799 super(TracedModule1, self).__init__()
8800 self.
param = torch.nn.Parameter(torch.rand(5, 7))
8802 def forward(self, x):
8803 return torch.mm(x, self.
param)
8805 class TracedModule(torch.nn.Module):
8807 super(TracedModule, self).__init__()
8808 self.
param = torch.nn.Parameter(torch.rand(4, 5))
8811 def forward(self, x):
8812 return self.
mod(torch.mm(x, self.
param)) + 1.0
8818 self.assertTrue(len(list(tm.graph.inputs())) == 3)
8819 FileCheck().check_count(
"aten::mm", 2).check(
"aten::add").run(str(tm.graph))
8821 def test_call_script_fn_from_traced_module(self):
8826 class TracedModule(torch.nn.Module):
8828 super(TracedModule, self).__init__()
8829 self.
param = torch.nn.Parameter(torch.rand(4, 5))
8831 def forward(self, x):
8832 return traced_fn(torch.mm(x, self.
param))
8836 FileCheck().check(
"aten::mm").check(
"aten::neg").run(str(tm.graph))
8838 def test_call_script_module_from_traced_module(self):
8841 super(ScriptMod, self).__init__()
8842 self.
param_foo = torch.nn.Parameter(torch.rand(5, 7))
8844 @torch.jit.script_method
8845 def forward(self, x):
8848 class TracedModule(torch.nn.Module):
8850 super(TracedModule, self).__init__()
8851 self.
param = torch.nn.Parameter(torch.rand(4, 5))
8852 self.
mod = ScriptMod()
8854 def forward(self, x):
8855 return self.
mod(torch.mm(x, self.
param)) + 1.0
8861 self.assertTrue(len(list(tm.graph.inputs())) == 3)
8862 FileCheck().check_count(
"aten::mm", 2).check(
"aten::add").run(str(tm.graph))
8864 def test_call_python_fn_from_script_fn(self):
8870 return python_fn(x) + 1
8876 FileCheck().check(
"python_fn").run(str(script_fn.graph))
8878 def test_call_python_mod_from_script_fn(self):
8879 class PythonModule(torch.nn.Module):
8881 super(PythonModule, self).__init__()
8882 self.
param = torch.nn.Parameter(torch.rand(5, 7))
8884 def forward(self, x):
8885 return torch.mm(x, self.
param)
8895 FileCheck().check(
"python_value").check(
"aten::add").run(str(script_fn.graph))
8897 def test_call_traced_fn_from_script_fn(self):
8898 @_trace(torch.rand(3, 4))
8904 return traced_fn(x) + 1
8908 FileCheck().check(
"aten::neg").check(
"aten::add").run(str(script_fn.graph))
8910 def test_call_traced_mod_from_script_fn(self):
8911 class TracedModule(torch.nn.Module):
8913 super(TracedModule, self).__init__()
8915 def forward(self, x):
8916 return torch.mm(x, torch.zeros(4, 3))
8924 FileCheck().check(
"aten::zeros").check_same(
"scope: TracedModule").check(
"aten::mm") \
8925 .check(
"aten::add").run(str(script_fn.graph))
8927 def test_call_script_fn_from_script_fn(self):
8934 return script_fn1(x) + 1
8938 FileCheck().check(
"aten::neg").run(str(script_fn.graph))
8940 def test_call_script_mod_from_script_fn(self):
8943 super(ScriptMod, self).__init__()
8945 @torch.jit.script_method
8946 def forward(self, x):
8947 return torch.mm(x, torch.zeros([4, 3]))
8955 FileCheck().check(
"zeros").check(
"aten::mm").check(
"add").run(str(script_fn.graph))
8957 def test_call_python_fn_from_script_module(self):
8963 super(ScriptMod, self).__init__()
8964 self.
param = torch.nn.Parameter(torch.rand(4, 3))
8966 @torch.jit.script_method
8967 def forward(self, x):
8968 return python_fn(torch.mm(x, self.
param))
8971 FileCheck().check(
"aten::mm").check(
"python_fn") \
8972 .run(str(sm.__getattr__(
'forward').graph))
8974 def test_call_python_mod_from_script_module(self):
8975 class PythonMod(torch.nn.Module):
8977 super(PythonMod, self).__init__()
8978 self.
param = torch.nn.Parameter(torch.rand(3, 5))
8980 def forward(self, x):
8981 return torch.mm(x, self.
param)
8985 super(ScriptMod, self).__init__()
8986 self.
param = torch.nn.Parameter(torch.rand(4, 3))
8987 self.
pm = PythonMod()
8989 @torch.jit.script_method
8990 def forward(self, x):
8991 return self.
pm(torch.mm(x, self.
param))
8996 FileCheck().check(
"aten::mm").check(
"python_value").run(str(sm.graph))
8998 def test_call_tracing_fn_from_script_module(self):
8999 @_trace(torch.rand(3, 3))
9005 super(ScriptMod, self).__init__()
9006 self.
param = torch.nn.Parameter(torch.rand(4, 3))
9008 @torch.jit.script_method
9009 def forward(self, x):
9010 return traced_fn(torch.mm(x, self.
param))
9013 FileCheck().check(
"aten::mm").check(
"aten::neg").run(str(sm.__getattr__(
'forward').graph))
9015 def test_call_tracing_mod_from_script_module(self):
9016 class TracedMod(torch.nn.Module):
9018 super(TracedMod, self).__init__()
9019 self.
param = torch.nn.Parameter(torch.rand(3, 5))
9021 def forward(self, x):
9022 return torch.mm(x, self.
param)
9026 super(ScriptMod, self).__init__()
9027 self.
param = torch.nn.Parameter(torch.rand(4, 3))
9030 @torch.jit.script_method
9031 def forward(self, x):
9032 return self.
tm(torch.mm(x, self.
param))
9038 self.assertTrue(len(list(sm.graph.inputs())) == 3)
9039 FileCheck().check(
"aten::mm").check(
"aten::mm").run(str(sm.graph))
9041 def test_call_script_fn_from_script_module(self):
9048 super(ScriptMod, self).__init__()
9049 self.
param = torch.nn.Parameter(torch.rand(4, 3))
9051 @torch.jit.script_method
9052 def forward(self, x):
9053 return script_fn(torch.mm(x, self.
param))
9056 graph = (sm.__getattr__(
'forward').graph)
9057 FileCheck().check(
"aten::mm").check(
"aten::neg").run(str(graph))
9059 def test_call_script_mod_from_script_module(self):
9062 super(ScriptMod1, self).__init__()
9063 self.
param = torch.nn.Parameter(torch.rand(3, 5))
9065 @torch.jit.script_method
9066 def forward(self, x):
9067 return torch.mm(x, self.
param)
9071 super(ScriptMod, self).__init__()
9072 self.
param = torch.nn.Parameter(torch.rand(4, 3))
9073 self.
tm = ScriptMod1()
9075 @torch.jit.script_method
9076 def forward(self, x):
9077 return self.
tm(torch.mm(x, self.
param))
9084 FileCheck().check_count(
'%', 3).check(
":").check_count(
"mm", 2).run(str(sm.graph))
9086 def test_module_with_params_called_fails(self):
9087 with self.
assertRaisesRegex(RuntimeError,
"Attempted to inline a Module with parameters. Stateful " 9088 "modules to be inlined must be submodules of the callee."):
9091 super(ScriptMod, self).__init__()
9092 self.
param = torch.nn.Parameter(torch.rand(3, 3))
9094 @torch.jit.script_method
9095 def forward(self, x):
9096 return torch.mm(x, self.
param)
9104 def test_index_put_trace_with_view(self):
9105 @_trace(torch.rand(100),
torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
9106 def test_index_put(target, indices, rhs):
9107 target[indices] = rhs
9110 FileCheck().check(
"aten::view").check(
"index_put_").run(str(test_index_put.graph))
9112 def test_index_put_trace_without_view(self):
9113 @_trace(torch.rand(100),
torch.tensor([1, 2, 3, 4]), torch.rand(4))
9114 def test_index_put(target, indices, rhs):
9115 target[indices] = rhs
9118 FileCheck().check_not(
"aten::view").check(
"index_put_").run(str(test_index_put.graph))
9120 def test_tuple_indexing(self):
9132 FileCheck().check_count(
"TupleIndex", 2, exactly=
True).run(str(tuple_comp.graph))
9134 with self.
assertRaisesRegex(RuntimeError,
"tuple indices must be integer constants"):
9136 def test_non_constant_input(a):
9144 def test_indexing_float():
9148 "tuple indices must")
9150 def test_indexing_out_of_bounds_pos():
9157 def test_indexing_out_of_bounds_neg():
9164 def test_namedtuple_attr(self):
9166 return x.max(dim=1).indices + torch.max(x, dim=1).indices
9168 self.
checkScript(f, (torch.rand(20, 20, 20),), optimize=
True)
9173 return x.max(dim=1).unknown_symbol
9175 with self.
assertRaisesRegex(RuntimeError,
"Getting attributes of tuples is not supported"):
9178 print((x, x, x).__doc__)
9181 def test_tuple_slicing(self):
9193 slices = tuple_graph.findAllNodes(
"prim::TupleSlice")
9194 num_outputs = set(map(
lambda x: len(x.output().type().elements()), slices))
9196 self.assertTrue(num_outputs == {2, 4})
9197 self.
run_pass(
'lower_all_tuples', tuple_graph)
9198 self.assertTrue(
'Tuple' not in str(tuple_graph))
9203 def test_indexing_end_out_of_bounds():
9207 self.
assertEqual(test_indexing_end_out_of_bounds(), ())
9209 def test_unwrap_optional_builtin(self):
9229 with self.
assertRaisesRegex(RuntimeError,
"cannot match an Optional\\[T\\] to None"):
9235 def test_indexing_error(self):
9236 with self.
assertRaisesRegex(RuntimeError,
"only supported on lists, dictionaries, tensors, and tuples"):
9238 def test_wrong_type():
9242 def test_annotated_script_fn(self):
9248 self.
assertExpected(foo.__getattr__(
'forward').pretty_print_schema())
9250 def test_annotated_script_method(self):
9252 @torch.jit.script_method
9253 def forward(self, x, y):
9259 self.
assertExpected(sm.__getattr__(
'forward').pretty_print_schema())
9261 def test_annotated_script_fn_return_mismatch(self):
9268 def test_annotated_script_fn_arg_mismatch(self):
9275 def test_script_non_tensor_args_outputs(self):
9279 return float((x + y).sum())
9281 x = torch.ones(2, 2)
9283 self.assertIsInstance(z, float)
9286 @unittest.skip(
'https://github.com/pytorch/pytorch/issues/9595')
9287 def test_inline_and_run_annotated_script_fn(self):
9289 def to_inline(x, y):
9295 return to_inline((x, x), x)
9297 x = torch.rand(3, 4)
9300 def test_file_format_serialization(self):
9302 filename = tempfile.mktemp()
9303 writer = torch._C.PyTorchFileWriter(filename)
9306 buffers = [os.urandom(size)
for size
in [random.randint(1, 100)
for i
in range(20)]]
9308 for i, buf
in enumerate(buffers):
9309 writer.write_record(str(i), buf, len(buf))
9312 serialized_offsets = pickle.dumps(offsets)
9313 writer.write_record(
"meta", serialized_offsets, len(serialized_offsets))
9314 writer.write_end_of_file()
9316 reader = torch._C.PyTorchFileReader(filename)
9317 serialized_offsets_read = reader.get_record(
"meta")
9318 parsed_serialized_offsets = pickle.loads(serialized_offsets)
9320 for i, offset
in enumerate(parsed_serialized_offsets):
9321 data = reader.get_record(str(offset))
9322 assert(data == buffers[i])
9325 def type_input_return_pairs(self):
9327 (
'Tensor',
'Tensor'),
9328 (
'torch.Tensor',
'Tensor'),
9332 (
'BroadcastingList3[float]',
'List[float]'),
9333 (
'BroadcastingList2[int]',
'List[int]'),
9334 (
'List[int]',
'List[int]'),
9335 (
'Optional[int]',
'Optional[int]'),
9339 def format_code(self, code, pair):
9340 return code.format(input=pair[0], output=pair[1])
9349 def test_annot_string_py3_fn(self):
9351 def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 9357 test_str.append(cu.__getattr__(
'foo').pretty_print_schema())
9361 def test_annot_string_py3_method(self):
9364 super(TestModule, self).__init__()
9367 def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 9374 test_str.append(tm.__getattr__(
'foo').pretty_print_schema())
9378 def test_annot_string_mypy_fn(self):
9381 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 9387 test_str.append(cu.__getattr__(
'foo').pretty_print_schema())
9391 def test_annot_string_mypy_method(self):
9394 super(TestModule, self).__init__()
9397 def foo(self, x, y): 9398 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 9406 test_str.append(tm.__getattr__(
'foo').pretty_print_schema())
9411 def _get_py3_code(self, code, fn_name):
9412 with tempfile.TemporaryDirectory()
as tmp_dir:
9413 script_path = os.path.join(tmp_dir,
'script.py')
9414 with open(script_path,
'w')
as f:
9416 import importlib.util
9417 spec = importlib.util.spec_from_file_location(fn_name, script_path)
9418 module = importlib.util.module_from_spec(spec)
9419 spec.loader.exec_module(module)
9420 fn = getattr(module, fn_name)
9424 @unittest.skipIf(
not PY35,
"Python 3.5 needed")
9425 def test_annot_ast_py3_fn(self):
9427 from typing import Tuple, List, Optional 9428 from torch import Tensor 9429 from torch.jit.annotations import BroadcastingList2, BroadcastingList3 9432 def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 9438 test_str.append(fn.__getattr__(
'forward').pretty_print_schema())
9442 @unittest.skipIf(
not PY35,
"Python 3.5 needed")
9443 def test_annot_ast_py3_method(self):
9445 from typing import Tuple, List, Optional 9446 from torch import Tensor 9447 from torch.jit.annotations import BroadcastingList2, \\ 9450 class FooModule(torch.jit.ScriptModule): 9451 @torch.jit.script_method 9452 def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 9454 instance = FooModule() 9460 test_str.append(fn.__getattr__(
'foo').pretty_print_schema())
9464 @unittest.skipIf(
not PY35,
"Python 3.5 needed")
9465 def test_annot_ast_mypy_fn(self):
9470 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 9477 test_str.append(fn.__getattr__(
'forward').pretty_print_schema())
9481 @unittest.skipIf(
not PY35,
"Python 3.5 needed")
9482 def test_annot_ast_mypy_method(self):
9485 class FooModule(torch.jit.ScriptModule): 9486 @torch.jit.script_method 9487 def foo(self, x, y): 9488 # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 9490 instance = FooModule() 9496 test_str.append(fn.__getattr__(
'foo').pretty_print_schema())
9499 def test_method_casts_script(self):
9501 'byte',
'char',
'double',
'float',
'int',
'long',
'short' 9504 for cast_type
in cast_types:
9507 return x.{cast_type}() 9508 '''.format(cast_type=cast_type))
9510 x = torch.rand(3, 4, 5) * 128
9511 cu_result = cu.cast_to(x)
9512 reference = getattr(x, cast_type)()
9515 def test_listconstruct_erasure(self):
9516 class FooMod(torch.nn.Module):
9517 def forward(self, x):
9524 FooMod(), (torch.rand(3, 4),), f,
9525 operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
9527 def test_trace_checker_arange_as_constant(self):
9529 @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
9531 y = torch.arange(0, x.shape[0]).double()
9532 return x + y.unsqueeze(1)
9535 def test_trace_checker_dot_data(self):
9537 r'across invocations'):
9538 @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
9544 def test_trace_checker_control_flow(self):
9546 for _
in range(x.size(0)):
9551 torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
9554 def test_trace_checker_memoization(self):
9557 if not hasattr(foo,
'cache'):
9558 foo.cache = torch.neg(x)
9559 return x + foo.cache
9561 traced =
torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
9565 if not TEST_WITH_UBSAN
and torch.fbgemm_is_cpu_supported():
9566 def test_int8_quantization_module(self):
9569 class FooBar(torch.nn.Module):
9571 super(FooBar, self).__init__()
9572 self.
linear1 = torch.nn.Linear(K1, N1).float()
9574 def forward(self, x):
9579 fb.linear1.weight = torch.nn.Parameter(
9580 torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=
False)
9581 fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=
False)
9583 fb_ref.linear1.weight = torch.nn.Parameter(fb.linear1.weight.clone(), requires_grad=
False)
9584 fb_ref.linear1.bias = torch.nn.Parameter(fb.linear1.bias.clone(), requires_grad=
False)
9587 x = (torch.rand(1, K1).float() - 0.5) / 10.0
9596 def checkTracerWarning(self, *args, **kwargs):
9597 with warnings.catch_warnings(record=
True)
as warns:
9599 self.assertGreater(len(warns), 0)
9601 self.assertIn(
"cause the trace to be incorrect", str(warn.message))
9603 def test_trace_checker_slice_lhs(self):
9606 x[i, :] = torch.zeros(4)
9611 def test_trace_checker_inplace_on_view(self):
9613 x.view(-1).add_(-x.view(-1))
9618 check_inputs=[torch.rand(5, 6)],
9619 _force_outplace=
True),
9620 'Output nr 1. of the traced function does not match the ' 9621 'corresponding output of the Python function')
9623 def test_lhs_index_fails(self):
9629 def test_lhs_index_trivial(self):
9633 self.
checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=
False)
9635 def test_inplace_warn(self):
9637 x.view(-1).add_(-x.view(-1))
9642 def test_trace_checker_dropout_train(self):
9644 return torch.dropout(x, p=0.5, train=
True)
9647 'Output nr 1. of the traced function does not match the ' 9648 'corresponding output of the Python function')
9650 'Trace had nondeterministic nodes')
9652 def test_trace_checker_dropout_notrain(self):
9653 input = torch.rand(3, 4)
9657 return torch.dropout(x, p=0.5, train=
False)
9661 def test_export_dynamic_slice(self):
9663 @torch.jit.script_method
9664 def forward(self, x):
9666 for i
in range(x.size(1)):
9667 retval += torch.sum(x[0:i], dim=0)
9670 mod = DynamicSliceExportMod()
9672 input = torch.rand(3, 4, 5)
9673 example_outs = mod(input)
9677 DynamicSliceExportMod(), (input,), f, example_outputs=example_outs)
9680 def test_string_frontend_elif(self):
9682 def elif_test(niter : int): 9684 for i in range(niter): 9685 if i % 3 == 0 and i % 5 == 0: 9696 self.
checkScript(code, (101,), name=
'elif_test', outputs=3028)
9698 def test_pyop_exception_message(self):
9701 super(Foo, self).__init__()
9702 self.
conv = nn.Conv2d(1, 10, kernel_size=5)
9704 @torch.jit.script_method
9705 def forward(self, x):
9709 with self.
assertRaisesRegex(RuntimeError,
"Expected 4-dimensional input for 4-dimensional weight"):
9710 foo(torch.ones([123]))
9712 def test_builtin_error_messsage(self):
9718 return x.masked_fill(
True)
9720 with self.
assertRaisesRegex(RuntimeError,
"This op may not exist or may not be currently " 9721 "supported in TorchScript"):
9724 torch.set_grad_enabled(
True)
9727 def test_exceptions(self):
9743 raise ArbitraryError(a,
"hi")
9745 raise ArbitraryError
9754 def foo_except_used():
9775 raise Exception(
"Hi")
9778 def test_assertions(self):
9781 assert bool(cond), "hi" 9791 assert bool(cond),
"hi" 9798 def test_weak_script_function(self):
9802 def not_a_script_fn(x):
9806 def even_more_inner(x):
9811 return not_a_script_fn(x) + x + even_more_inner(x)
9814 def strong_script_fn(x):
9815 if bool(x.norm() > 2):
9817 return x + 4 + inner(x)
9819 @torch._jit_internal.weak_script
9820 def weak_script_fn_inner(x):
9821 return x + 6 + not_a_script_fn(x)
9823 @torch._jit_internal.weak_script
9824 def weak_script_fn(x):
9825 return x + 5 + weak_script_fn_inner(x) + weak_script_fn_inner(x)
9828 x = not_a_script_fn(x)
9829 x = strong_script_fn(x)
9830 return weak_script_fn(x)
9832 input = torch.randn(3, 4, 5)
9835 def test_python_op_exception(self):
9837 raise Exception(
"bad!")
9846 def test_trace_contiguous(self):
9848 return x[:, :, ::2].contiguous().view(12)
9850 x = torch.rand(2, 3, 4)
9853 self.
assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
9861 def test_trace_contiguous_short_circuit(self):
9863 return x.contiguous()
9865 x = torch.rand(2, 3, 4)
9867 FileCheck().check(
"aten::contiguous").run(str(traced.graph))
9869 def test_weak_module(self):
9871 @torch._jit_internal.weak_module
9872 class Weak(torch.nn.Module):
9873 __constants__ = [
'number']
9876 super(Weak, self).__init__()
9879 def python_op_in_weak_module(self, x):
9882 @torch._jit_internal.weak_script_method
9883 def forward(self, x):
9884 return 55 + self.
number + self.python_op_in_weak_module(x)
9887 __constants__ = [
'number']
9890 super(OtherStrong, self).__init__()
9893 def python_op_in_strong_module(self, x):
9896 @torch.jit.script_method
9897 def forward(self, x):
9898 return x + self.
number + self.python_op_in_strong_module(x)
9902 super(Passthrough, self).__init__()
9905 @torch.jit.script_method
9906 def forward(self, x):
9911 expected_result = 55 + 199 + (x + 123)
9917 python_result = weak_mod(x)
9918 strong_mod = Passthrough()
9919 script_result = strong_mod(x)
9926 super(Strong, self).__init__()
9928 self.
strong = OtherStrong()
9930 @torch.jit.script_method
9931 def forward(self, x):
9935 strong_mod = Strong()
9936 strong_mod2 = Strong()
9938 expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
9939 script_result = strong_mod(x)
9940 script_result2 = strong_mod2(x)
9944 def test_weak_module_parameters_and_buffers(self):
9945 weights = torch.randn(10, 10)
9946 bias = torch.randn(10)
9947 weights2 = torch.randn(10, 10)
9948 bias2 = torch.randn(10)
9950 @torch._jit_internal.weak_module
9951 class TestLinear(torch.nn.Module):
9952 def __init__(self, in_features, out_features):
9953 super(TestLinear, self).__init__()
9956 self.
weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
9957 self.
bias = torch.nn.Parameter(torch.Tensor(out_features))
9958 self.register_buffer(
'counter', torch.ones(out_features))
9959 self.reset_parameters()
9961 def reset_parameters(self):
9963 if self.
bias is not None:
9965 bound = 1 / math.sqrt(fan_in)
9968 @torch._jit_internal.weak_script_method
9969 def forward(self, input):
9970 return F.linear(input, self.
weight, self.
bias) + self.counter
9975 super(Strong, self).__init__()
9976 self.
fc1 = TestLinear(10, 10)
9977 self.fc1.weight = torch.nn.Parameter(weights)
9978 self.fc1.bias = torch.nn.Parameter(bias)
9979 self.
fc2 = TestLinear(10, 10)
9980 self.fc2.weight = torch.nn.Parameter(weights2)
9981 self.fc2.bias = torch.nn.Parameter(bias2)
9983 @torch.jit.script_method
9984 def forward(self, x):
9985 return x + self.
fc1(x) + self.
fc1(x) + self.
fc2(x)
9987 strong_mod = Strong()
9990 inp = torch.ones(10)
9991 lin = torch.nn.Linear(10, 10)
9992 lin.weight = torch.nn.Parameter(weights)
9993 lin.bias = torch.nn.Parameter(bias)
9994 lin2 = torch.nn.Linear(10, 10)
9995 lin2.weight = torch.nn.Parameter(weights2)
9996 lin2.bias = torch.nn.Parameter(bias2)
9997 expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
9999 self.
assertEqual(strong_mod(inp), expected_result)
10002 def test_weak_module_nested(self):
10003 @torch._jit_internal.weak_module
10004 class OtherWeak(torch.nn.Module):
10005 __constants__ = [
'constant']
10007 def __init__(self, in_features, out_features):
10008 super(OtherWeak, self).__init__()
10011 self.
weight = torch.nn.Parameter(torch.ones(out_features, in_features))
10012 self.
bias = torch.nn.Parameter(torch.ones(out_features))
10015 @torch._jit_internal.weak_script_method
10016 def forward(self, x):
10021 def __init__(self):
10022 super(OtherStrong, self).__init__()
10024 @torch.jit.script_method
10025 def forward(self, x):
10028 @torch._jit_internal.weak_module
10029 class Weak(torch.nn.Module):
10030 def __init__(self, in_features, out_features):
10031 super(Weak, self).__init__()
10034 self.
weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
10035 self.
bias = torch.nn.Parameter(2 * torch.ones(out_features))
10039 @torch._jit_internal.weak_script_method
10040 def forward(self, x):
10045 __constants__ = [
'constant']
10047 def __init__(self):
10048 super(Strong, self).__init__()
10049 self.
weak = Weak(10, 10)
10051 @torch.jit.script_method
10052 def forward(self, x):
10053 return x + self.
weak(x)
10055 strong_mod = Strong()
10056 inp = torch.randn(10)
10057 result = strong_mod(inp)
10058 expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
10059 + F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
10060 + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
10063 def test_weak_module_submodule(self):
10064 @torch._jit_internal.weak_module
10065 class Weak(torch.nn.Module):
10066 def __init__(self):
10067 super(Weak, self).__init__()
10068 self.
param = torch.nn.Parameter(100 * torch.ones(5))
10070 @torch._jit_internal.weak_script_method
10071 def forward(self, x):
10072 return x + self.
param 10077 def __init__(self):
10078 super(OtherStrong, self).__init__()
10082 @torch.jit.script_method
10083 def forward(self, x):
10084 return x + self.
weak(x)
10087 def __init__(self):
10088 super(Strong, self).__init__()
10091 @torch.jit.script_method
10092 def forward(self, x):
10093 return self.
weak(x) + weak(x)
10095 other_strong_mod = OtherStrong()
10097 self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
10099 with self.
assertRaisesRegex(RuntimeError,
"Attempted to inline a Module with param"):
10100 strong_mod = Strong()
10102 def test_weak_module_copying(self):
10103 class Submodule(torch.nn.Module):
10104 def __init__(self):
10105 super(Submodule, self).__init__()
10107 def forward(self, x):
10110 @torch._jit_internal.weak_module
10111 class Weak(torch.nn.Module):
10112 def __init__(self, in_features, out_features):
10113 super(Weak, self).__init__()
10114 self.
weight = torch.nn.Parameter(torch.ones(out_features, in_features))
10115 self.
bias = torch.nn.Parameter(torch.ones(out_features))
10116 self.register_buffer(
"buffer", torch.ones(out_features))
10119 @torch._jit_internal.weak_script_method
10120 def forward(self, x):
10121 return F.linear(x, self.
weight, self.
bias) \
10125 def __init__(self, weak):
10126 super(Strong, self).__init__()
10129 @torch.jit.script_method
10130 def forward(self, x):
10131 return self.
weak(x)
10133 inp = torch.ones(5, 5) * 5
10134 weak_mod = Weak(5, 5)
10135 strong_mod = Strong(weak_mod)
10140 self.assertIs(strong_mod.weak.weight, weak_mod.weight)
10141 self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
10142 self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)
10145 weak_mod.new_attribute = 10
10146 self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)
10148 weak_mod.weight.data += torch.ones(5, 5) * 100
10149 self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
10152 weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
10153 self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
10155 def test_backend_cudnn_enabled(self):
10159 if torch.backends.cudnn.enabled:
10165 def test_inplace_add(self):
10171 self.
checkScript(foo, (torch.rand(3), torch.rand(3)))
10173 def test_add_out(self):
10177 torch.add(c, b, out=e)
10179 self.
checkScript(foo, (torch.rand(3), torch.rand(3)))
10181 def test_augmented_assign(self):
10188 self.
checkScript(foo, (torch.rand(3), torch.rand(3)))
10190 def test_pass(self):
10193 for _i
in range(3):
10203 def test_optional_conversion(self):
10205 def other_fn(x=None):
10217 def unify_to_optional(x):
10234 def broadcast_opt_list(x):
10239 def opt_list_tuple_caller(x):
10241 return opt_list(x) + broadcast_opt_list(x)
10243 self.
assertEqual(opt_list_tuple_caller((2., 3.)), 4)
10245 def test_lhs_indexing(self):
10250 self.
checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10252 def test_lhs_advanced_indexing_assignment(self):
10258 self.
checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
10260 def test_lhs_advanced_indexing_augmented_assignment(self):
10266 self.
checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
10268 def test_lhs_indexing_list(self):
10273 self.
checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10275 def test_inplace_copy_script(self):
10277 a = torch.rand(3, 4)
10282 def test_lhs_indexing_increment(self):
10286 self.
checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10288 def test_lhs_indexing_increment_list(self):
10294 self.
checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10296 def test_lhs_indexing_increment_list_prim(self):
10303 def test_lhs_indexing_multi(self):
10306 foo, a[0], bar = (1, b, 3)
10308 self.
checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
10310 def test_bool_dispatch(self):
10312 def kwarg_false(x):
10314 return F.max_pool1d(x, 1, 1, return_indices=
False)
10315 self.
checkScript(kwarg_false, (torch.randn(3, 3, 3),))
10319 return F.max_pool1d(x, 1, 1, return_indices=
True)
10320 self.
checkScript(kwarg_true, (torch.randn(3, 3, 3),))
10322 def full_kwarg_false(x):
10324 return F.max_pool1d(x, 1, 1, ceil_mode=
False, return_indices=
False)
10325 self.
checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
10327 def full_kwarg_true(x):
10329 return F.max_pool1d(x, 1, 1, ceil_mode=
False, return_indices=
True)
10330 self.
checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
10332 def use_default(x):
10334 return F.max_pool1d(x, 1, 1)
10335 self.
checkScript(use_default, (torch.randn(3, 3, 3),))
10339 return F.max_pool1d(x, 1, 1, 0, 1,
False,
False)
10340 self.
checkScript(arg_false, (torch.randn(3, 3, 3),))
10344 return F.max_pool1d(x, 1, 1, 0, 1,
False,
True)
10345 self.
checkScript(arg_true, (torch.randn(3, 3, 3),))
10347 def test_infer_size(self):
10352 return _infer_size(x.size(), y.size())
10354 self.
checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
10356 def test_mutable_dce(self):
10359 a = torch.rand(2, 3)
10360 a += torch.rand(2, 3)
10361 b = torch.rand(2, 3)
10362 b += torch.rand(2, 3)
10366 FileCheck().check_count(
"aten::rand", 2, exactly=
True) \
10367 .check_count(
"aten::add", 1, exactly=
True).run(str(foo.graph))
10369 def test_mutable_dce_block(self):
10372 a = torch.rand(2, 3)
10373 a += torch.rand(2, 3)
10374 b = torch.rand(2, 3)
10375 if bool(a > torch.zeros(2, 3)):
10376 b += torch.rand(2, 3)
10377 a += torch.rand(2, 3)
10381 FileCheck().check(
"prim::If").check_count(
"aten::rand", 1, exactly=
True) \
10382 .run(str(foo.graph))
10384 def test_mutable_dce_graph_input(self):
10387 a += torch.rand(2, 3)
10390 FileCheck().check(
"aten::rand").check(
"aten::add").run(str(foo.graph))
10392 def test_mutable_dce_list(self):
10398 b = torch.rand(2, 3)
10399 c += torch.rand(2, 3)
10403 FileCheck().check_count(
"aten::rand", 2, exactly=
True).run(str(foo.graph))
10405 def test_mutable_dce_loop(self):
10411 b = torch.rand(2, 3)
10413 dead = torch.rand(2, 3)
10415 c += torch.rand(2, 3)
10419 FileCheck().check(
"prim::Loop").check_not(
"aten::rand").check(
"aten::select") \
10420 .check_count(
"aten::rand", 1, exactly=
True).run(str(foo.graph))
10422 def test_mutable_dce_wildcards(self):
10424 x = torch.ones(2, 3)
10428 x.add_(torch.ones(2, 3))
10433 def test_cpp_function_tensor_str(self):
10434 x = torch.randn(2, 2)
10435 scale = torch.randn(2, 2, requires_grad=
True)
10436 shift = torch.randn(2, 2, requires_grad=
True)
10439 def fn(x, scale, shift):
10440 return scale * x + shift
10443 print(fn(x, scale, shift))
10445 def test_non_final_return(self):
10452 raise RuntimeError(
"nope")
10470 def nest_early_ret(x):
10486 def not_early_ret(x):
10497 def not_total_ret(x):
10509 def nest_while_ret(x):
10517 def nest_for_ret(x):
10523 def test_overloading(self):
10524 @torch._jit_internal.weak_module
10525 class W(torch.nn.Module):
10526 __overloads__ = {
'forward': [
'forward_tuple',
'forward_tensor']}
10528 def __init__(self):
10529 super(W, self).__init__()
10531 @torch._jit_internal.weak_script_method
10532 def forward_tuple(self, x):
10536 def forward(self, x):
10538 if isinstance(x, tuple):
10539 return self.forward_tuple(x)
10541 return self.forward_tensor(x)
10543 @torch._jit_internal.weak_script_method
10544 def forward_tensor(self, x):
10549 def __init__(self):
10550 super(S, self).__init__()
10553 @torch.jit.script_method
10554 def forward(self, x):
10555 return self.
weak(x) + self.
weak((x, x))
10565 def test_select_after_chunk(self):
10567 chunked = torch.chunk(x, 1)
10574 def test_nn_LSTM(self):
10578 def __init__(self):
10579 super(S, self).__init__()
10580 self.
x = torch.nn.LSTM(5, 5)
10582 @torch.jit.script_method
10583 def forward(self, input):
10585 return self.
x(input)
10587 eager_out = self.
runAndSaveRNG(
lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0]
10588 script_out = self.
runAndSaveRNG(
lambda x: S()(x), (input,))[0]
10592 def test_list_python_op(self):
10593 def python_list_op(lst):
10599 return python_list_op(lst)
10601 self.
checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
10603 def test_ignore_decorator(self):
10605 def __init__(self):
10606 super(M, self).__init__()
10607 tensor = torch.zeros(1, requires_grad=
False)
10608 self.register_buffer(
'some_state', torch.nn.Parameter(tensor))
10610 @torch.jit.script_method
10611 def forward(self, x):
10612 self.ignored_code(x)
10616 def ignored_code(self, x):
10623 self.
assertEqual(m.some_state, torch.zeros(1) + 100)
10626 pp, constants = m._python_print()
10628 ppv =
"op_version_set = 0\n{}".format(pp)
10629 torch._C._jit_import_methods(printed, ppv, constants)
10630 self.assertIn(
'IgnoredPythonOp', ppv)
10631 self.assertNotIn(
'ignored_code', ppv)
10633 with self.
assertRaisesRegex(torch.jit.Error,
"This Python function is annotated to be ignored"):
10634 printed(torch.ones(1))
10636 def test_view_write(self):
10645 self.
checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
10647 def test_dict_view(self):
10655 self.
checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
10657 def test_dict_ops(self):
10658 d = {
'a': torch.ones(1),
'b': torch.ones(1) + 1,
'c': torch.ones(1) + 2}
10663 return list(x.keys())
10670 return list(x.values())
10672 self.
assertEqual(set(values(d)), set(d.values()))
10680 def test_dict(self):
10685 self.
checkScript(simple, ({
'item': 20,
'other_item': 120},))
10691 self.
checkScript(index, ({
'item': 20,
'other_item': 120},))
10693 def type_default():
10700 def missing_index(x):
10705 missing_index({
'item': 20,
'other_item': 120})
10709 return torch.jit.annotate(Dict[int, float], {}) 10711 return torch.jit.annotate(Dict[int, float], {10: 1.2}) 10719 return torch.jit.annotate(Dict[int, float], {10: 1.2, 11: 1.3}) 10721 self.
assertEqual({10: 1.2, 11: 1.3}, cu.literal3())
10723 def list_of_dicts():
10725 return [{
'word': torch.ones(2) + 3}, {
'other word': torch.ones(1) + 2}]
10729 def test_dict_mutability(self):
10739 def test_dict_membership(self):
10748 def optional(x, y):
10758 def bad_types(x, y):
10762 def dict_to_python(self):
10763 def python_lookup(my_dict, keys):
10765 return [my_dict[k]
for k
in keys]
10767 def fn(my_dict, keys):
10769 return python_lookup(my_dict, keys)
10771 a_dict = {
'a': torch.ones(1),
'b': torch.ones(1) + 1,
'c': torch.ones(1) + 2}
10774 def test_module_attrs(self):
10776 def __init__(self, table):
10777 super(M, self).__init__()
10781 @torch.jit.script_method
10782 def forward(self, key):
10784 return self.
table[key] + self.
x 10789 m = M({char : torch.ones(1) + ord(char) - ord(
"a")
for char
in "abcdefg"})
10792 def test_tensor_import_export(self):
10800 self.
run_pass(
'constant_propagation', foo.graph)
10802 m._create_method_from_graph(
"forward", foo.graph)
10805 def test_attribute_serialization(self):
10807 def __init__(self):
10808 super(M, self).__init__()
10817 @torch.jit.script_method
10825 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: TemporaryFileName support for Windows or Sandcastle")
10826 def test_attribute_unpickling(self):
10830 def __init__(self):
10831 super(M, self).__init__()
10840 @torch.jit.script_method
10844 class TensorID(object):
10845 def __setstate__(self, id):
10848 class IntList(object):
10849 def __setstate__(self, data):
10852 class JitUnpickler(pickle.Unpickler):
10853 def find_class(self, module, name):
10854 if not module ==
'__main__':
10857 if name ==
'TensorID':
10859 elif name ==
'IntList':
10862 with TemporaryFileName()
as fname:
10864 archive_name = os.path.basename(os.path.normpath(fname))
10865 archive = zipfile.ZipFile(fname,
'r') 10866 pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
10867 JitUnpickler(io.BytesIO(pickled_data)).load()
10869 def test_submodule_attribute_serialization(self):
10871 def __init__(self, list_data):
10872 super(S, self).__init__()
10876 @torch.jit.script_method
10881 def __init__(self):
10882 super(M, self).__init__()
10885 self.
s1 = S([(1, 2)])
10886 self.
s2 = S([(4, 5)])
10888 @torch.jit.script_method
10890 return (self.
table, self.
tensor, self.s1.table, self.s2.list, self.s1.list)
10896 def test_optional_tuple(self):
10908 def test_split(self):
10909 def split_two(tensor):
10910 a, b, c = torch.split(tensor, 2, dim=1)
10912 x = torch.randn(3, 6)
10913 y = torch.randn(3, 6)
10918 def __init__(self):
10919 super(MnistNet, self).__init__()
10920 self.
conv1 = nn.Conv2d(1, 10, kernel_size=5)
10921 self.
conv2 = nn.Conv2d(10, 20, kernel_size=5)
10923 self.
fc1 = nn.Linear(320, 50)
10924 self.
fc2 = nn.Linear(50, 10)
10926 def forward(self, x):
10927 x = F.relu(F.max_pool2d(self.
conv1(x), 2))
10929 x = x.view(-1, 320)
10930 x = F.relu(self.
fc1(x))
10931 x = F.dropout(x, training=self.training)
10933 return F.log_softmax(x, dim=1)
10938 def _test_dcgan_models(self, device, check_export_import=True):
10939 class DCGANGenerator(nn.Module):
10940 def __init__(self, nz, ngf, nc):
10941 super(DCGANGenerator, self).__init__()
10942 self.
main = nn.Sequential(
10944 nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=
False),
10945 nn.BatchNorm2d(ngf * 8),
10948 nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=
False),
10949 nn.BatchNorm2d(ngf * 4),
10952 nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=
False),
10953 nn.BatchNorm2d(ngf * 2),
10956 nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=
False),
10957 nn.BatchNorm2d(ngf),
10960 nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=
False),
10965 def forward(self, input):
10966 return self.
main(input)
10968 class DCGANDiscriminator(nn.Module):
10969 def __init__(self, nc, ndf):
10970 super(DCGANDiscriminator, self).__init__()
10971 self.
main = nn.Sequential(
10973 nn.Conv2d(nc, ndf, 4, 2, 1, bias=
False),
10974 nn.LeakyReLU(0.2, inplace=
True),
10976 nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=
False),
10977 nn.BatchNorm2d(ndf * 2),
10978 nn.LeakyReLU(0.2, inplace=
True),
10980 nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=
False),
10981 nn.BatchNorm2d(ndf * 4),
10982 nn.LeakyReLU(0.2, inplace=
True),
10984 nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=
False),
10985 nn.BatchNorm2d(ndf * 8),
10986 nn.LeakyReLU(0.2, inplace=
True),
10988 nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=
False),
10992 def forward(self, input):
10993 return self.
main(input).view(-1, 1).squeeze(1)
10995 bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
10996 self.
checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
10997 (torch.rand(bs, nz, 1, 1, device=device),),
10998 export_import=check_export_import)
10999 example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
11000 self.
checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
11001 export_import=check_export_import)
11003 def test_dcgan_models(self):
11006 @unittest.skipIf(
not RUN_CUDA,
"no CUDA")
11007 def test_dcgan_models_cuda(self):
11012 def _test_neural_style(self, device, check_export_import=True):
11013 class TransformerNet(torch.nn.Module):
11014 def __init__(self):
11015 super(TransformerNet, self).__init__()
11017 self.
conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
11018 self.
in1 = torch.nn.InstanceNorm2d(32, affine=
True)
11019 self.
conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
11020 self.
in2 = torch.nn.InstanceNorm2d(64, affine=
True)
11021 self.
conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
11022 self.
in3 = torch.nn.InstanceNorm2d(128, affine=
True)
11024 self.
res1 = ResidualBlock(128)
11025 self.
res2 = ResidualBlock(128)
11026 self.
res3 = ResidualBlock(128)
11027 self.
res4 = ResidualBlock(128)
11028 self.
res5 = ResidualBlock(128)
11030 self.
deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
11031 self.
in4 = torch.nn.InstanceNorm2d(64, affine=
True)
11032 self.
deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
11033 self.
in5 = torch.nn.InstanceNorm2d(32, affine=
True)
11034 self.
deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
11036 self.
relu = torch.nn.ReLU()
11038 def forward(self, X):
11052 class ConvLayer(torch.nn.Module):
11053 def __init__(self, in_channels, out_channels, kernel_size, stride):
11054 super(ConvLayer, self).__init__()
11055 reflection_padding = kernel_size // 2
11056 self.
reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
11057 self.
conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
11059 def forward(self, x):
11064 class ResidualBlock(torch.nn.Module):
11066 introduced in: https://arxiv.org/abs/1512.03385 11067 recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html 11070 def __init__(self, channels):
11071 super(ResidualBlock, self).__init__()
11072 self.
conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
11073 self.
in1 = torch.nn.InstanceNorm2d(channels, affine=
True)
11074 self.
conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
11075 self.
in2 = torch.nn.InstanceNorm2d(channels, affine=
True)
11076 self.
relu = torch.nn.ReLU()
11078 def forward(self, x):
11082 out = out + residual
11085 class UpsampleConvLayer(torch.nn.Module):
11086 """UpsampleConvLayer 11087 Upsamples the input and then does a convolution. This method gives better results 11088 compared to ConvTranspose2d. 11089 ref: http://distill.pub/2016/deconv-checkerboard/ 11092 def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
11093 super(UpsampleConvLayer, self).__init__()
11096 self.
upsample_layer = torch.nn.Upsample(mode=
'nearest', scale_factor=upsample)
11097 reflection_padding = kernel_size // 2
11098 self.
reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
11099 self.
conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
11101 def forward(self, x):
11109 self.
checkTrace(TransformerNet(), (torch.rand(5, 3, 16, 16),), export_import=check_export_import)
11111 def test_neural_style(self):
11114 @unittest.skipIf(
not RUN_CUDA,
"no CUDA")
11115 def test_neural_style_cuda(self):
11120 def _test_mnist(self, device, check_export_import=True):
11122 self.
checkTrace(
MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
11123 export_import=check_export_import)
11125 def test_mnist(self):
11128 @unittest.skipIf(
not RUN_CUDA,
"no CUDA")
11129 def test_mnist_cuda(self):
11131 self.
_test_mnist(self, device=
'cuda', check_export_import=
False)
11133 @unittest.skipIf(
not RUN_CUDA,
"no CUDA")
11134 def test_mnist_training_leaks_no_memory_cuda(self):
11137 traced_net =
torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device=
'cuda')],
11141 for _
in range(iters):
11143 inp = torch.randn(5, 1, 28, 28, device=
'cuda')
11144 out = traced_net(inp)
11147 out.sum().backward()
11150 traced_net.zero_grad()
11159 def _test_reinforcement_learning(self, device, test_export_import=True):
11160 class Policy(nn.Module):
11161 def __init__(self):
11162 super(Policy, self).__init__()
11163 self.
affine1 = nn.Linear(4, 128)
11164 self.
affine2 = nn.Linear(128, 2)
11166 def forward(self, x):
11168 action_scores = self.
affine2(x)
11169 return F.softmax(action_scores, dim=1)
11171 self.
checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
11172 export_import=test_export_import)
11174 def test_reinforcement_learning(self):
11177 @unittest.skipIf(
not RUN_CUDA,
"no CUDA")
11178 def test_reinforcement_learning_cuda(self):
11183 def _test_snli(self, device, check_export_import=True, quantized=False):
11184 class Bottle(nn.Module):
11186 def forward(self, input):
11187 if len(input.size()) <= 2:
11188 return super(Bottle, self).forward(input)
11189 size = input.size()[:2]
11190 out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
11191 return out.view(size[0], size[1], -1)
11193 class Linear(Bottle, nn.Linear):
11196 class Encoder(nn.Module):
11198 def __init__(self, config):
11199 super(Encoder, self).__init__()
11200 self.config = config
11201 input_size = config.d_proj
if config.projection
else config.d_embed
11202 dropout = 0
if config.n_layers == 1
else config.dp_ratio
11203 self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
11204 num_layers=config.n_layers, dropout=dropout,
11205 bidirectional=config.birnn)
11207 def forward(self, inputs):
11208 batch_size = inputs.size()[1]
11209 state_shape = self.config.n_cells, batch_size, self.config.d_hidden
11210 h0 = c0 = inputs.new_zeros(state_shape)
11211 outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
11212 return ht[-1]
if not self.config.birnn
else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
11214 class SNLIClassifier(nn.Module):
11216 def __init__(self, config):
11217 super(SNLIClassifier, self).__init__()
11218 self.config = config
11219 self.embed = nn.Embedding(config.n_embed, config.d_embed)
11220 self.projection = Linear(config.d_embed, config.d_proj)
11221 self.encoder = Encoder(config)
11222 self.dropout = nn.Dropout(p=config.dp_ratio)
11223 self.relu = nn.ReLU()
11224 seq_in_size = 2 * config.d_hidden
11225 if self.config.birnn:
11227 lin_config = [seq_in_size] * 2
11228 self.out = nn.Sequential(
11229 Linear(*lin_config),
11232 Linear(*lin_config),
11235 Linear(*lin_config),
11238 Linear(seq_in_size, config.d_out))
11240 def forward(self, premise, hypothesis):
11241 prem_embed = self.embed(premise)
11242 hypo_embed = self.embed(hypothesis)
11243 if self.config.fix_emb:
11244 prem_embed = prem_embed.detach()
11245 hypo_embed = hypo_embed.detach()
11246 if self.config.projection:
11247 prem_embed = self.relu(self.projection(prem_embed))
11248 hypo_embed = self.relu(self.projection(hypo_embed))
11249 premise = self.encoder(prem_embed)
11250 hypothesis = self.encoder(hypo_embed)
11251 scores = self.out(torch.cat([premise, hypothesis], 1))
11267 premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
11268 hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)
11271 snli = SNLIClassifier(Config()).cpu()
11275 self.checkTrace(snli, (premise, hypothesis), inputs_require_grads=
False,
11276 export_import=
False)
11278 self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
11279 inputs_require_grads=
False, export_import=check_export_import)
11281 def test_snli(self):
11282 self._test_snli(self, device=
'cpu')
11284 if not TEST_WITH_UBSAN
and torch.fbgemm_is_cpu_supported():
11285 def test_snli_quantized(self):
11286 self._test_snli(self, device=
'cpu', quantized=
True)
11288 @unittest.skipIf(
not RUN_CUDA,
"no CUDA")
11289 def test_snli_cuda(self):
11291 self._test_snli(self, device=
'cuda', check_export_import=
False)
11294 def _test_super_resolution(self, device, check_export_import=True):
11297 class Net(nn.Module):
11299 def __init__(self, upscale_factor):
11300 super(Net, self).__init__()
11302 self.relu = nn.ReLU()
11303 self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
11304 self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
11305 self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
11306 self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
11307 self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
11309 def forward(self, x):
11310 x = self.relu(self.conv1(x))
11311 x = self.relu(self.conv2(x))
11312 x = self.relu(self.conv3(x))
11313 x = self.pixel_shuffle(self.conv4(x))
11316 net = Net(upscale_factor=4).to(device)
11317 self.checkTrace(net, (torch.rand(5, 1, 32, 32, device=device),),
11318 export_import=check_export_import)
11320 def test_super_resolution(self):
11321 self._test_super_resolution(self, device=
'cpu')
11323 @unittest.skipIf(
not RUN_CUDA,
'no CUDA')
11324 def test_super_resolution_cuda(self):
11326 self._test_super_resolution(self, device=
'cuda', check_export_import=
False)
11329 def test_time_sequence_prediction(self):
11331 def __init__(self):
11332 super(Sequence, self).__init__()
11333 self.lstm1 = nn.LSTMCell(1, 51)
11334 self.lstm2 = nn.LSTMCell(51, 51)
11335 self.linear = nn.Linear(51, 1)
11341 def test_lstm1(self, input, hx, cx):
11343 return self.lstm1(input, (hx, cx))
11345 def test_lstm2(self, input, hx, cx):
11347 return self.lstm2(input, (hx, cx))
11351 def test_tensor(self):
11354 @torch.jit.script_method
11355 def forward(self, input):
11358 outputs = self.test_tensor()
11359 h_t = torch.zeros((3, 51), dtype=torch.double)
11360 c_t = torch.zeros((3, 51), dtype=torch.double)
11361 h_t2 = torch.zeros((3, 51), dtype=torch.double)
11362 c_t2 = torch.zeros((3, 51), dtype=torch.double)
11364 output = torch.zeros([3, 51])
11369 a, b, c, d = input.chunk(input.size(1), dim=1)
11370 for input_t
in (a, b, c, d):
11371 h_t, c_t = self.test_lstm1(input_t, h_t, c_t)
11372 h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
11373 output = self.linear(h_t2)
11374 outputs = torch.cat((outputs, output), 1)
11375 for _
in range(future):
11376 h_t, c_t = self.test_lstm1(output, h_t, c_t)
11377 h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
11378 output = self.linear(h_t2)
11379 outputs = torch.cat((outputs, output), 1)
11383 self.checkTrace(Sequence(), (torch.rand(3, 4),),
11384 export_import=
False)
11387 def _test_vae(self, device, check_export_import=True, quantized=False):
11388 class VAE(nn.Module):
11389 def __init__(self):
11390 super(VAE, self).__init__()
11392 self.fc1 = nn.Linear(784, 400)
11393 self.fc21 = nn.Linear(400, 20)
11394 self.fc22 = nn.Linear(400, 20)
11395 self.fc3 = nn.Linear(20, 400)
11396 self.fc4 = nn.Linear(400, 784)
11398 def encode(self, x):
11399 h1 = F.relu(self.fc1(x))
11400 return self.fc21(h1), self.fc22(h1)
11402 def reparameterize(self, mu, logvar):
11404 std = torch.exp(0.5 * logvar)
11405 eps = torch.randn_like(std)
11406 return eps.mul(std).add_(mu)
11410 def decode(self, z):
11411 h3 = F.relu(self.fc3(z))
11412 return torch.sigmoid(self.fc4(h3))
11414 def forward(self, x):
11415 mu, logvar = self.encode(x.view(-1, 784))
11416 z = self.reparameterize(mu, logvar)
11417 return self.decode(z), mu, logvar
11420 vae = VAE().to(device).eval()
11424 self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device),),
11425 export_import=
False, allow_unused=
True,
11426 inputs_require_grads=
False)
11429 self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
11430 export_import=check_export_import)
11432 def test_vae(self):
11433 self._test_vae(self, device=
'cpu')
11435 if not TEST_WITH_UBSAN
and torch.fbgemm_is_cpu_supported():
11436 def test_vae_quantized(self):
11437 self._test_vae(self, device=
'cpu', quantized=
True)
11439 @unittest.skipIf(
not RUN_CUDA,
"no CUDA")
11440 def test_vae_cuda(self):
11442 self._test_vae(self, device=
'cuda', check_export_import=
False)
11448 def __init__(self):
11451 def forward(self, x):
11452 return x.transpose(0, 1)
11454 def test_protobuf(self):
11456 fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=
True)
11459 export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)
11461 def test_zipfile(self):
11463 fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=
True)
11466 export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)
11468 def test_compressed_zipfile(self):
11470 fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=
True)
11473 export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)
11475 def test_directory(self):
11477 fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=
True)
11478 d = tempfile.mkdtemp()
11480 export_type=torch.onnx.ExportTypes.DIRECTORY)
11483 def test_onnx_multiple_return(self):
11492 def test_aten_fallback(self):
11493 class ModelWithAtenNotONNXOp(nn.Module):
11494 def forward(self, x, y):
11496 defg = torch.qr(abcd)
11499 x = torch.rand(3, 4)
11500 y = torch.rand(3, 4)
11503 ModelWithAtenNotONNXOp(), (x, y), f,
11504 operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
11505 self.assertExpected(exported)
11510 def test_onnx_aten(self):
11511 class ModelWithAtenFmod(nn.Module):
11512 def forward(self, x, y):
11513 return torch.fmod(x, y)
11516 x = torch.randn(3, 4, dtype=torch.float32)
11517 y = torch.randn(3, 4, dtype=torch.float32)
11519 ModelWithAtenFmod(), (x, y), f,
11520 operator_export_type=OperatorExportTypes.ONNX_ATEN)
11521 self.assertExpected(exported)
11529 'test___getitem___adv_index',
11530 'test___getitem___adv_index_beg',
11531 'test___getitem___adv_index_comb',
11532 'test___getitem___adv_index_dup',
11533 'test___getitem___adv_index_sub',
11534 'test___getitem___adv_index_sub_2',
11535 'test___getitem___adv_index_sub_3',
11536 'test___getitem___adv_index_var',
11540 EXCLUDE_TYPE_CHECK = {
11544 'test_slogdet_1x1_neg_det',
11545 'test_slogdet_1x1_pos_det',
11546 'test_slogdet_distinct_singular_values',
11547 'test_slogdet_neg_det',
11548 'test_slogdet_pos_det',
11549 'test_slogdet_symmetric',
11550 'test_slogdet_symmetric_pd',
11556 'test_norm_fro_default',
11563 'test_nn_ctc_loss',
11569 EXCLUDE_PYTHON_PRINT = {
11571 'test_nn_max_unpool1d',
11572 'test_nn_max_unpool2d',
11573 'test_nn_max_unpool3d',
11574 'test_nn_max_pool1d',
11575 'test_nn_max_pool2d',
11576 'test_nn_max_pool3d',
11577 'test_nn_max_pool1d_with_indices',
11580 EXCLUDE_SCRIPT_MODULES = {
11581 'test_nn_AdaptiveAvgPool2d_tuple_none',
11582 'test_nn_AdaptiveAvgPool3d_tuple_none',
11583 'test_nn_AdaptiveMaxPool2d_tuple_none',
11584 'test_nn_AdaptiveMaxPool3d_tuple_none',
11587 DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
11588 'test_nn_avg_pool2d',
11589 'test_nn_adaptive_avg_pool1d',
11590 'test_nn_adaptive_avg_pool2d',
11591 'test_nn_adaptive_avg_pool3d',
11592 'test_nn_batch_norm',
11593 'test_nn_embedding',
11594 'test_nn_log_softmax',
11596 'test_nn_softmax_with_all_args',
11597 'test_nn_threshold',
11598 'test_nn_nll_loss',
11608 def partial_apply_nontensors(fn, args, **kwargs):
11609 source = [
't' if isinstance(arg, torch.Tensor)
else 's' for arg
in args]
11611 def new_fn(*tensors_):
11612 tensors = iter(tensors_)
11613 return fn(*(args[i]
if s ==
's' else next(tensors)
for i, s
in enumerate(source)), **kwargs)
11615 return new_fn, [arg
for arg
in args
if isinstance(arg, torch.Tensor)]
11622 def create_traced_fn(self, fn,
11623 disable_autodiff_subgraph_inlining=
False):
11624 def traced_fn(*inputs, **kwargs):
11625 fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
11627 self.assertExportImport(traced.graph, inputs_tensors)
11628 if disable_autodiff_subgraph_inlining:
11629 traced.debug_disable_autodiff_subgraph_inlining()
11630 output = traced(*inputs_tensors)
11631 traced_fn.last_graph = traced.graph_for(*inputs_tensors)
11635 script_template =
''' 11636 def the_method({}): 11640 script_method_template =
''' 11646 def get_constant(x):
11648 return 'float(\'inf\')' if PY2
else 'math.inf' 11650 return 'float(\'-inf\')' if PY2
else '-math.inf' 11654 def get_script_args(args):
11659 if isinstance(arg, torch.Tensor):
11660 name =
'i{}'.format(len(formals))
11661 formals.append(name)
11662 actuals.append(name)
11663 tensors.append(arg)
11664 elif isinstance(arg, str):
11665 actuals.append(
"'{}'".format(arg))
11667 actuals.append(str(get_constant(arg)))
11668 return (formals, tensors, actuals)
11674 def create_script_fn(self, method_name, func_type, output_process_fn,
11675 disable_autodiff_subgraph_inlining=
False):
11676 def script_fn(*args, **kwargs):
11677 formals, tensors, actuals = get_script_args(args)
11679 for k, v
in kwargs.items():
11680 kwargs_str +=
', ' + k +
'=' + str(v)
11681 if func_type ==
'functional':
11682 call =
'torch.{}({}{})'.format(method_name,
', '.join(actuals), kwargs_str)
11683 elif func_type ==
'method':
11684 call =
'{}.{}({}{})'.format(actuals[0], method_name,
', '.join(actuals[1:]), kwargs_str)
11685 elif func_type ==
'nn_functional':
11686 call =
'torch.nn.functional.{}({}{})'.format(method_name,
', '.join(actuals), kwargs_str)
11688 raise 'Unsupported function type' 11690 script = script_template.format(
', '.join(formals), call)
11693 if disable_autodiff_subgraph_inlining:
11694 CU.the_method.debug_disable_autodiff_subgraph_inlining()
11695 self.assertExportImport(CU.the_method.graph, tensors)
11696 output = output_process_fn(CU.the_method(*tensors))
11697 script_fn.last_graph = CU.the_method.graph_for(*tensors)
11702 def check_alias_annotation(method_name, args, kwargs):
11703 formals, tensors, actuals = get_script_args(args)
11705 for k, v
in kwargs.items():
11706 kwargs_str +=
', ' + k +
'=' + str(v)
11707 call =
'{}.{}({}{})'.format(actuals[0], method_name,
', '.join(actuals[1:]), kwargs_str)
11708 script = script_template.format(
', '.join(formals), call)
11710 torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name)
11713 def check_output_types(self, func, ref_outputs, args, kwargs):
11714 graph = getattr(func,
'last_graph',
None)
11715 types = [o.type()
for o
in graph.outputs()]
11716 self.assertTrue(len(types) == 1)
11718 torch._C._jit_assert_is_instance(ref_outputs, t)
11721 def check_against_reference(self, func, reference_func, args, kwargs=None,
11722 allow_unused=
True, check_types=
True, no_grad=
False):
11723 kwargs = kwargs
if kwargs
else {}
11726 if isinstance(vs, torch.Tensor):
11728 return sum((i + 1) * v.sum()
11729 for i, v
in enumerate(vs)
11730 if v
is not None and v.dtype.is_floating_point)
11732 def clone_inputs(requires_grad):
11734 arg.detach().clone().requires_grad_(requires_grad
and arg.requires_grad)
11735 if isinstance(arg, torch.Tensor)
else arg
for arg
in args
11737 return inputs, [input
for input
in inputs
if isinstance(input, torch.Tensor)
and input.requires_grad]
11739 nograd_inputs, nograd_tensors = clone_inputs(
False)
11740 recording_inputs, recording_tensors = clone_inputs(
True)
11743 outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
11744 outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
11745 self.assertEqual(outputs, outputs_test)
11748 check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
11755 outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
11757 allow_unused=allow_unused)
11759 outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
11761 allow_unused=allow_unused)
11762 self.assertEqual(outputs, outputs_test)
11763 self.assertEqual(grads, grads_test)
11766 if self._testMethodName
in nn_functional_single_grad:
11769 outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
11770 l1 = allSum(outputs)
11772 allow_unused=allow_unused)
11773 l2 = (allSum(grads) * l1)
11776 recording_inputs, recording_tensors = clone_inputs(
True)
11778 outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
11779 l1_test = allSum(outputs_test)
11781 l1_test, recording_tensors, create_graph=
True, allow_unused=allow_unused)
11782 l2_test = (allSum(grads_test) * l1_test)
11785 self.assertEqual(outputs, outputs_test)
11786 self.assertEqual(grads, grads_test)
11787 for g2, g2_test
in zip(grads2, grads2_test):
11788 if g2
is None and g2_test
is None:
11790 self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
11794 def assertAllFused(self, graph, except_for=()):
11795 if [n.kind()
for n
in graph.nodes()] == [
'prim::DifferentiableGraph']:
11796 graph = next(graph.nodes()).g(
'Subgraph')
11797 allowed_nodes = {
'prim::Constant',
'prim::FusionGroup'} | set(except_for)
11798 self.assertTrue(all(node.kind()
in allowed_nodes
for node
in graph.nodes()),
11799 'got {}'.format(graph))
11800 self.assertTrue([node.kind()
for node
in graph.nodes()].count(
'prim::FusionGroup') == 1)
11802 def _test_fused_abs(self, device='cpu'):
11808 a = torch.randn(5, device=device)
11809 self.assertEqual(func(a), a.abs() * 2)
11810 self.assertAllFused(func.graph_for(a))
11812 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
11814 def test_abs_cpu(self):
11815 self._test_fused_abs()
11817 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11818 @unittest.skipIf(
not RUN_CUDA,
"requires CUDA")
11820 def test_abs_cuda(self):
11821 self._test_fused_abs(device=
"cuda")
11823 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11824 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
11825 def test_arg_configurations_smoke_cuda(self):
11831 z1, z2 = (x + y).chunk(2, dim=1)
11834 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
11835 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
11837 self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
11839 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11840 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
11842 def test_broadcast_cuda(self):
11843 def scaleshift(x, scale, shift):
11844 return x * scale + shift
11847 torch.randn(4, 4, dtype=torch.float, device=
'cuda'),
11848 torch.randn(4, dtype=torch.float, device=
'cuda'),
11849 torch.randn(4, dtype=torch.float, device=
'cuda'),
11851 ge = self.checkTrace(scaleshift, inputs)
11852 self.assertAllFused(ge.graph_for(*inputs))
11854 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11855 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
11856 @unittest.skipIf(
not RUN_CUDA_HALF,
"no half support")
11857 def test_cuda_half(self):
11858 x = torch.randn(4, 4, dtype=torch.half, device=
'cuda')
11859 y = torch.randn(4, 4, dtype=torch.half, device=
'cuda')
11862 self.fn_test_comparison_gt_lt,
11868 inputs = (x.float(), y.float())
11869 fusion_inputs = (x, y)
11871 local_inputs = [t.clone().requires_grad_()
for t
in inputs]
11872 local_fusion_inputs = [t.clone().requires_grad_()
for t
in fusion_inputs]
11875 fusion =
torch.jit.trace(fn, local_fusion_inputs, check_trace=
False, optimize=
True)
11876 outputs = fn(*local_inputs)
11877 fusion_outputs = fusion(*local_fusion_inputs)
11878 outputs_half = [t.half()
for t
in outputs]
11879 self.assertEqual(outputs_half, fusion_outputs)
11882 for output, fusion_output
in zip(outputs_half, fusion_outputs):
11884 output.float().sum(), local_inputs, allow_unused=
True, retain_graph=
True)
11886 fusion_output.sum(), local_fusion_inputs, allow_unused=
True, retain_graph=
True)
11887 grads_half = [t.half()
for t
in grads]
11888 self.assertEqual(grads_half, fusion_grads)
11890 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11891 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
11893 def test_checks_cat_inputs(self):
11899 return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
11903 x = torch.randn(2, 4, dtype=torch.float, device=
'cuda')
11904 y = torch.randn(1, 4, dtype=torch.float, device=
'cuda')
11906 self.assertEqual(f(x, y).shape, (3, 4))
11907 self.assertAllFused(f.graph_for(x, y))
11909 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11910 @unittest.skipIf(
not RUN_CUDA,
"No CUDA")
11912 def test_chunk_cuda(self):
11914 a, b, c = x.chunk(3, 1)
11917 inputs = [torch.randn(10, 6, dtype=torch.float, device=
'cuda')]
11919 ge = self.checkScript(fn, inputs)
11920 graph = ge.graph_for(*inputs)
11921 self.assertAllFused(graph)
11922 FileCheck().check(
"prim::ConstantChunk[chunks=3, dim=1]").run(str(graph))
11925 def _test_chunk_correctness(self, device='cpu'):
11927 x0, x1, x2, x3 = x.chunk(4, 0)
11928 return x0 + x1 + x2 + x3
11931 x0, x1, x2, x3 = x.chunk(4, 1)
11932 return x0 + x1 + x2 + x3
11934 def chunk_4_last(x):
11935 x0, x1, x2, x3 = x.chunk(4, 2)
11936 return x0 + x1 + x2 + x3
11938 fns = [chunk_4_0, chunk_4_1, chunk_4_last]
11941 torch.randn(4, 4, 4, dtype=torch.float, device=device),
11944 torch.randn(12, 8, 16, dtype=torch.float, device=device),
11947 torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
11950 for tensor
in tensors:
11952 self.checkScript(fn, [tensor])
11954 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
11956 def test_chunk_correctness(self):
11957 return self._test_chunk_correctness(self,
'cpu')
11959 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11960 @unittest.skipIf(
not RUN_CUDA,
"No CUDA")
11961 def test_chunk_correctness_cuda(self):
11962 return self._test_chunk_correctness(self,
'cuda')
11964 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11965 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
11967 def test_chunk_distributes_cuda(self):
11969 z1, z2 = (x + y).chunk(2, dim=1)
11972 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
11973 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
11975 ge = self.checkTrace(f, (x, y))
11976 graph = ge.graph_for(x, y)
11977 FileCheck().check(
"broadcast_tensors").check(
'with prim::FusionGroup_0') \
11978 .check_count(
'ConstantChunk', 2, exactly=
True).run(str(graph))
11980 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
11981 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
11983 def test_chunk_motion_deduplicates_inputs(self):
11986 z0, z1 = z.chunk(2)
11991 z0, z1 = z.chunk(2)
11995 torch.tensor([1.1, 1.2], device=
'cuda', dtype=torch.float),
11997 for func
in [func1, func2]:
11998 module = self.checkScript(func, inputs)
11999 forward_graph = module.graph_for(*inputs)
12000 self.assertGraphContainsExactly(forward_graph,
'prim::FusionGroup', 1)
12001 fusion_group = list(forward_graph.nodes())[-1]
12002 self.assertEqual(len(list(fusion_group.inputs())), 1)
12004 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12005 @unittest.skipIf(
not RUN_CUDA,
"No CUDA")
12007 def test_chunk_multiple_cuda(self):
12010 def fn(s, x, y, z):
12011 z1, z2 = z.chunk(2, 2)
12012 x1, x2, x3 = x.chunk(3, 1)
12013 y1, y2 = y.chunk(2, 0)
12014 return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
12017 torch.randn(5, 2, 3, dtype=torch.float, device=
'cuda'),
12018 torch.randn(5, 6, 3, dtype=torch.float, device=
'cuda'),
12019 torch.randn(10, 2, 3, dtype=torch.float, device=
'cuda'),
12020 torch.randn(5, 2, 6, dtype=torch.float, device=
'cuda'),
12023 ge = self.checkScript(fn, inputs)
12024 self.assertAllFused(ge.graph_for(*inputs))
12026 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12027 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12029 def test_clamp(self):
12031 return torch.clamp(a + b, min=0, max=2)
12034 return torch.clamp(a + b, min=0, max=float(
'inf'))
12036 def funcOptMin(a, b):
12037 return torch.clamp(a + b, max=2)
12039 def funcOptMax(a, b):
12040 return torch.clamp(a + b, min=0)
12042 a = torch.randn(4, 4, dtype=torch.float, device=
'cuda', requires_grad=
True)
12043 b = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12046 funcs = (func2, funcInf, funcOptMin, funcOptMax)
12047 for f, inputs
in product(funcs, [[a, b], [a, nan]]):
12048 inp1, inp2 = inputs
12049 s = self.checkScript(f, (inp1, inp2))
12050 self.assertAllFused(s.graph_for(inp1, inp2), except_for={
'aten::size'})
12054 graph = backward_graph(s)
12055 self.assertAllFused(graph)
12057 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12058 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12060 def test_comparison_eq_ne(self):
12062 mask = (x == 0).type_as(x)
12064 mask = (x != 0).type_as(x)
12068 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12069 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12071 ge = self.checkTrace(f, (x, y))
12072 self.assertAllFused(ge.graph_for(x, y))
12075 def fn_test_comparison_gt_lt(x, y):
12076 mask = (x > 0).type_as(x)
12078 mask = (x < 0).type_as(x)
12082 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12083 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12085 def test_comparison_gt_lt_cuda(self):
12086 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12087 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12089 ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
12090 self.assertAllFused(ge.graph_for(x, y))
12092 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12093 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12095 def test_comparison_ge_le_cuda(self):
12097 mask = (x >= 0).type_as(x)
12099 mask = (x <= 0).type_as(x)
12103 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12104 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12106 ge = self.checkTrace(f, (x, y))
12107 self.assertAllFused(ge.graph_for(x, y))
12108 x.requires_grad_(
True)
12109 y.requires_grad_(
True)
12110 self.assertAllFused(ge.graph_for(x, y), except_for=(
"aten::size",
"prim::BroadcastSizes"))
12112 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12113 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12115 def test_concat_cuda(self):
12116 hx = torch.randn(3, 20, dtype=torch.float, device=
'cuda')
12117 cx = torch.randn(3, 20, dtype=torch.float, device=
'cuda')
12120 return torch.cat((hx + cx, hx * cx))
12122 ge = self.checkTrace(foo, (hx, cx))
12123 graph = ge.graph_for(hx, cx)
12124 self.assertAllFused(graph)
12125 FileCheck().check(
"FusedConcat").check_next(
"return").run(str(graph))
12127 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12128 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12130 def test_concat_invariant_cuda(self):
12136 w = torch.cat([x1, y1])
12139 x = torch.randn(2, 2, dtype=torch.float, device=
'cuda')
12140 y = torch.randn(2, 2, dtype=torch.float, device=
'cuda')
12141 z = torch.randn(4, 2, dtype=torch.float, device=
'cuda')
12142 ge = self.checkTrace(fn, (x, y, z))
12143 graph = ge.graph_for(x, y, z)
12144 self.assertAllFused(graph, except_for={
'aten::add'})
12145 FileCheck().check(
"FusedConcat").check_next(
"return").run(str(graph))
12148 def fn_test_exp(x, y):
12149 return (x + .5 * y).exp()
12151 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12152 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12154 def test_exp_cuda(self):
12155 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12156 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12158 ge = self.checkTrace(self.fn_test_exp, (x, y))
12159 self.assertAllFused(ge.graph_for(x, y))
12161 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12162 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12164 def test_fuse_batch_norm(self):
12167 def __init__(self, optimize=True):
12168 super(ResLike, self).__init__(optimize)
12169 self.bn = nn.BatchNorm2d(16)
12171 @torch.jit.script_method
12172 def forward(self, x, y):
12173 return y + torch.relu(self.bn(x))
12175 model = ResLike().cuda()
12176 model_noopt = ResLike(optimize=
False).cuda()
12177 model_noopt.load_state_dict(model.state_dict())
12178 x = torch.randn(2, 16, 8, 8, device=
'cuda')
12179 y = torch.randn(2, 16, 8, 8, device=
'cuda')
12181 with torch.no_grad():
12183 graph = model.graph_for(x, y)
12186 out_noopt = model_noopt(x, y)
12187 rep_noopt = str(model_noopt.graph_for(x, y))
12188 self.assertEqual(out, out_noopt, prec=3e-5)
12191 self.assertIn(
'aten::batch_norm_update_stats', rep)
12192 self.assertNotIn(
'aten::batch_norm(', rep)
12193 self.assertIn(
'aten::batch_norm(', rep_noopt)
12197 fusion_groups = [node
for node
in graph.nodes()
if node.kind() ==
'prim::FusionGroup']
12198 self.assertEqual(len(fusion_groups), 1)
12199 fused_graph = fusion_groups[0].g(
'Subgraph')
12200 self.assertTrue(any(node.kind() ==
'aten::sqrt' for node
in fused_graph.nodes()))
12202 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12203 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12205 def test_threshold(self):
12207 return torch.threshold(x, 0, -10) + x + x + x
12209 x =
torch.tensor([-1, -0.5, 0, 1, 2, 3], device=
'cuda')
12212 self.assertEqual(f(x), scripted(x))
12213 self.assertAllFused(scripted.graph_for(x))
12215 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
12217 def test_fuser_deduplication(self):
12221 return torch.sigmoid(x + y)
12223 b = torch.randn(5, 5, requires_grad=
True)
12224 a = torch.randn(5, 5, requires_grad=
True)
12225 s = self.checkScript(f, (a, b))
12226 self.assertAllFused(s.graph_for(a, b), except_for={
'aten::size'})
12230 graph = backward_graph(s)
12231 self.assertAllFused(graph)
12233 self.assertEqual(ga.data_ptr(), gb.data_ptr())
12235 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
12237 def test_fuser_iou(self):
12240 def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
12241 ltx = torch.max(b1x1, b2x1)
12242 lty = torch.max(b1y1, b2y1)
12243 rbx = torch.min(b1x2, b2x2)
12244 rby = torch.min(b1y2, b2y2)
12246 w = (rbx - ltx).clamp(min=0, max=float(
'inf'))
12247 h = (rby - lty).clamp(min=0, max=float(
'inf'))
12250 area1 = (b1x2 - b1x1) * (b1y2 - b1y2)
12251 area2 = (b2x2 - b2x1) * (b2y2 - b2y2)
12252 iou = inter / (area1 + area2 - inter)
12255 box1 = torch.randn(5, 4, requires_grad=
True)
12256 box2 = torch.randn(5, 4, requires_grad=
True)
12258 b1x1 = box1[:, 0].unsqueeze(1)
12259 b1y1 = box1[:, 1].unsqueeze(1)
12260 b1x2 = box1[:, 2].unsqueeze(1)
12261 b1y2 = box1[:, 3].unsqueeze(1)
12262 b2x1 = box2[:, 0].unsqueeze(0)
12263 b2y1 = box2[:, 1].unsqueeze(0)
12264 b2x2 = box2[:, 2].unsqueeze(0)
12265 b2y2 = box2[:, 3].unsqueeze(0)
12267 s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
12268 self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
12269 except_for={
'aten::size',
'prim::BroadcastSizes'})
12271 c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
12273 graph = backward_graph(s)
12274 self.assertAllFused(graph, except_for={
'aten::size',
'prim::BroadcastSizes'})
12276 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12277 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12278 @unittest.skipIf(
not RUN_CUDA_MULTI_GPU,
"needs non-zero device")
12281 def test_fusion_reuse_multi_gpu(self):
12283 return x * y * x * y
12286 torch.randn(4, 4, dtype=torch.float),
12287 torch.randn(4, 4, dtype=torch.float),
12289 inputs_cuda0 = [x.cuda(0)
for x
in inputs_cpu]
12290 inputs_cuda1 = [y.cuda(1)
for y
in inputs_cpu]
12293 ge = self.checkScript(fn, inputs_cpu)
12294 self.assertAllFused(ge.graph_for(*inputs_cpu))
12298 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12299 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12300 @unittest.skipIf(
not RUN_CUDA_MULTI_GPU,
"needs non-zero device")
12303 def test_kernel_cache_multi_gpu(self):
12304 def not_fusible(x):
12308 x_out = x * x * x * x * x
12309 y_out = y * y * y * y * y
12310 z_out = z * z * z * z * z
12311 return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
12314 torch.randn(4, 4, dtype=torch.float),
12315 torch.randn(4, 4, dtype=torch.float, device=
'cuda:0'),
12316 torch.randn(4, 4, dtype=torch.float, device=
'cuda:1'),
12319 prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
12323 ge = self.checkScript(fn, inputs)
12324 self.assertGraphContainsExactly(
12325 ge.graph_for(*inputs),
'prim::FusionGroup', 3,
True)
12326 new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
12328 self.assertEqual(new_cache_size - prev_cache_size, 1)
12330 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12331 @unittest.skipIf(
not RUN_CUDA_MULTI_GPU,
"needs non-zero device")
12333 def test_nonzero_device_cuda(self):
12334 device =
'cuda:' + str(1)
12335 x =
torch.tensor([0.4], dtype=torch.float, device=device)
12336 y =
torch.tensor([0.7], dtype=torch.float, device=device)
12339 return torch.sigmoid(torch.tanh(x * (x + y) + x))
12341 ge = self.checkTrace(doit, (x, y))
12342 self.assertAllFused(ge.graph_for(x, y))
12344 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12345 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12347 def test_lstm_cuda(self):
12348 inputs = get_lstm_inputs(
'cuda', training=
True)
12349 module = self.checkScript(LSTMCellS, inputs)
12350 forward_graph = module.graph_for(*inputs)
12351 self.assertGraphContainsExactly(
12352 forward_graph,
'prim::FusionGroup', 1, consider_subgraphs=
True)
12353 self.assertTrue(len(list(forward_graph.nodes())) == 2)
12355 FileCheck().check(
"DifferentiableGraph").check_next(
"TupleConstruct") \
12356 .check_next(
"return").run(str(forward_graph))
12358 hy, cy = module(*inputs)
12359 (hy + cy).sum().backward()
12360 backward = backward_graph(module)
12361 FileCheck().check(
"FusionGroup_0").check_next(
"FusionGroup_1") \
12362 .check_not(
"FusionGroup_2").run(str(backward))
12364 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12365 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12367 def test_lstm_concat_cuda(self):
12368 inputs = get_lstm_inputs(
'cuda')
12369 ge = self.checkTrace(LSTMCellC, inputs)
12370 graph = ge.graph_for(*inputs)
12371 FileCheck().check(
"FusedConcat").check_next(
"return").run(str(graph))
12373 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12374 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12376 def test_lstm_gates_permutations_cuda(self):
12379 choices = [
'x.mm(w_ih.t())',
'hx.mm(w_hh.t())',
'b_ih',
'b_hh']
12380 template = dedent(
''' 12381 def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): 12382 gates = {} + {} + {} + {} 12383 ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 12384 return ingate * forgetgate * cellgate * outgate 12386 for permutation
in itertools.permutations(choices, len(choices)):
12387 code = template.format(*permutation)
12389 exec(code, globals(), scope)
12392 inputs = get_lstm_inputs(
'cuda', training=
False)
12393 self.assertEqual(cu.cell(*inputs), scope[
'cell'](*inputs))
12394 forward_graph = cu.cell.graph_for(*inputs)
12395 self.assertGraphContainsExactly(forward_graph,
'prim::FusionGroup', 1)
12398 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12399 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12401 def test_lstm_traced_cuda(self):
12402 inputs = get_lstm_inputs(
'cuda')
12403 ge = self.checkTrace(LSTMCellF, inputs)
12404 graph = ge.graph_for(*inputs)
12405 FileCheck().check_not(
"Chunk").check_not(
"aten::add").check_not(
"aten::sigmoid") \
12406 .check_not(
"aten::tanh").check(
"FusionGroup").check_next(
"TupleConstruct") \
12407 .check_next(
"return").check_not(
"FusionGroup_1").run(str(graph))
12409 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
12410 @unittest.skip(
"Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
12412 def test_lstm_traced_cpu(self):
12413 inputs = get_lstm_inputs(
'cpu')
12415 ge = self.checkTrace(LSTMCellF, inputs)
12416 graph = ge.graph_for(*inputs)
12417 FileCheck.check(
"FusionGroup").run(str(graph))
12418 except RuntimeError
as e:
12419 if 'Failed to compile' in e.args[0]:
12420 warnings.warn(
'CPU fuser test has failed! This is not a hard failure, ' 12421 'because the kernels sometimes trigger bugs in compilers ' 12422 '(most notably GCC 7.2).')
12423 raise unittest.SkipTest(
'Failed to compile')
12427 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12428 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12430 def test_milstm_cuda(self):
12431 inputs = get_milstm_inputs(
'cuda', training=
True)
12432 module = self.checkScript(MiLSTMCell, inputs)
12433 forward_graph = module.graph_for(*inputs)
12434 self.assertGraphContainsExactly(
12435 forward_graph,
'prim::FusionGroup', 1, consider_subgraphs=
True)
12436 FileCheck().check(
"DifferentiableGraph").check_next(
"TupleConstruct") \
12437 .check_next(
"return").check(
"FusionGroup").run(str(forward_graph))
12438 hy, cy = module(*inputs)
12439 (hy + cy).sum().backward()
12441 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12442 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12444 def test_rand_cuda(self):
12446 __constants__ = [
'd']
12448 def __init__(self):
12449 self.d = torch.device(
'cuda')
12451 @torch.jit.script_method
12452 def create(self, x):
12453 return x * x + x + torch.rand_like(x)
12455 x = torch.zeros([3, 4, 5], dtype=torch.float, device=
'cuda')
12459 self.assertNotEqual(out1, out2)
12460 self.assertTrue(torch.all(out1 >= 0))
12461 self.assertTrue(torch.all(out1 < 1))
12462 self.assertTrue(torch.all(out2 >= 0))
12463 self.assertTrue(torch.all(out2 < 1))
12464 self.assertAllFused(m.create.graph_for(x))
12467 def fn_test_relu(x, y):
12468 return F.relu(x + .5 * y)
12470 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12471 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12473 def test_relu_cuda(self):
12474 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12475 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12477 ge = self.checkTrace(self.fn_test_relu, (x, y))
12478 self.assertAllFused(ge.graph_for(x, y))
12480 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12481 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12483 def test_erf_cuda(self):
12484 def fn_test_erf(x):
12485 return F.relu(torch.erf(x) - torch.erfc(x))
12487 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12488 ge = self.checkTrace(fn_test_erf, (x,))
12489 self.assertAllFused(ge.graph_for(x))
12490 x.requires_grad_(
True)
12491 self.assertAllFused(ge.graph_for(x), except_for=(
"aten::size",
"prim::BroadcastSizes"))
12493 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12494 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12496 def test_rand_broadcast_cuda(self):
12497 def fn_test_rand(x, y):
12498 r = torch.rand_like(y)
12501 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12502 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12504 out = script_f(x, y)
12505 self.assertAllFused(script_f.graph_for(x, y))
12506 x.requires_grad_(
True)
12507 out = script_f(x, y)
12508 self.assertAllFused(script_f.graph_for(x, y), except_for=(
"aten::size",
"prim::BroadcastSizes"))
12510 x = torch.ones(4, 4, dtype=torch.float, device=
'cuda')
12511 y = torch.ones(4, dtype=torch.float, device=
'cuda')
12512 out = script_f(x, y)
12513 self.assertEqual(out[0], out[1])
12515 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
12517 def test_scalar(self):
12521 x =
torch.tensor(0.1, dtype=torch.float, device=
'cpu')
12523 ge = self.checkScript(fn, (x, y))
12524 self.assertAllFused(ge.graph_for(x, y))
12526 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12527 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12529 def test_small_constant_cuda(self):
12530 def fn_test_small_constant(x, y):
12531 return (1e-8 * x + 5e-9 * y) * 1e8
12532 x = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12533 y = torch.randn(4, 4, dtype=torch.float, device=
'cuda')
12535 ge = self.checkTrace(fn_test_small_constant, (x, y))
12536 self.assertAllFused(ge.graph_for(x, y))
12538 @unittest.skipIf(IS_WINDOWS,
"NYI: fuser support for Windows")
12539 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12541 def test_tensor_scalar_ops_cuda(self):
12542 def should_fuse(x):
12549 def should_not_fuse(x, z):
12553 inputs = [torch.randn(2, 2, dtype=torch.float, device=
'cuda')]
12554 ge = self.checkScript(should_fuse, inputs)
12555 self.assertAllFused(ge.graph_for(*inputs))
12558 torch.randn(2, 2, dtype=torch.float, device=
'cuda'),
12561 ge = self.checkScript(should_not_fuse, inputs)
12562 self.assertGraphContainsExactly(
12563 ge.graph_for(*inputs),
'prim::FusionGroup', 0, consider_subgraphs=
True)
12565 @unittest.skipIf(IS_WINDOWS
or IS_SANDCASTLE,
"NYI: fuser support for Windows or Sandcastle")
12567 def test_where_and_typing(self):
12570 res = torch.where(mask, x, y)
12575 x = torch.randn(4, 4, dtype=torch.double)
12576 y = torch.randn(4, 4, dtype=torch.double)
12578 result1, result2 = script_f(x, y)
12579 expected1, expected2 = f(x, y)
12580 self.assertEqual(result1, expected1)
12581 self.assertEqual(result2, expected2)
12582 self.assertAllFused(script_f.graph_for(x, y), except_for={
'prim::TupleConstruct'})
12584 @unittest.skipIf(
not IS_WINDOWS,
"Test that the fuser is disabled on Windows")
12585 @unittest.skipIf(
not RUN_CUDA,
"fuser requires CUDA")
12586 def test_windows_cuda(self):
12587 def scaleshift(x, scale, shift):
12588 return x * scale + shift
12591 torch.randn(4, 4, dtype=torch.float, device=
'cuda'),
12592 torch.randn(4, dtype=torch.float, device=
'cuda'),
12593 torch.randn(4, dtype=torch.float, device=
'cuda'),
12596 ge = self.checkScript(scaleshift, inputs)
12597 self.assertGraphContainsExactly(
12598 ge.graph_for(*inputs),
'prim::FusionGroup', 0, consider_subgraphs=
True)
12612 def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
12614 ge.debug_disable_autodiff_subgraph_inlining()
12615 inputs = [torch.randn(size, requires_grad=
True)
for size
in input_sizes]
12617 return ge.graph_for(*inputs)
12619 def assertGraphSize(self, graph, size):
12620 self.assertEqual(len(list(graph.nodes())), size)
12622 def test_simple_merge(self):
12629 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
12631 self.assertGraphSize(graph, 1)
12632 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 1)
12634 def test_simple_no_merge(self):
12642 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
12644 self.assertGraphSize(graph, 2)
12645 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 1)
12647 def test_does_not_merge_unrelated(self):
12649 def fn(w, x, y, z):
12654 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
12656 self.assertGraphSize(graph, 3)
12657 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 2)
12659 def test_merges_without_cycles(self):
12669 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
12671 self.assertGraphSize(graph, 1)
12672 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 1)
12674 def test_merges_dense(self):
12684 return a + c, b + d
12686 graph = self._perform_ad_subgraph_slicing(fn, 2, 2)
12688 self.assertGraphSize(graph, 2)
12689 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 1)
12691 def test_does_not_create_cycles(self):
12701 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
12703 self.assertGraphSize(graph, 3)
12704 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 2)
12706 def test_merges_up(self):
12710 def fn(w, x, y, z):
12716 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
12718 self.assertGraphSize(graph, 3)
12719 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 1)
12721 def test_merges_down(self):
12725 def fn(v, w, x, y):
12731 graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
12733 self.assertGraphSize(graph, 3)
12734 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 1)
12736 def test_respects_lexical_scoping(self):
12744 graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
12748 self.assertGraphContainsExactly(graph,
'prim::DifferentiableGraph', 2)
12753 def test_dynamic_op_registry(self):
12755 self.assertTrue(hasattr(torch,
'ops'))
12757 if '_test' in torch.ops.__dict__:
12758 torch.ops.__dict__.pop(
'_test')
12761 self.assertNotIn(
'_test', torch.ops.__dict__)
12763 self.assertIn(
'_test', torch.ops.__dict__)
12764 self.assertEqual(type(torch.ops._test), _OpNamespace)
12766 self.assertNotIn(
'leaky_relu', torch.ops._test.__dict__)
12767 op = torch.ops._test.leaky_relu
12768 self.assertTrue(callable(op))
12769 self.assertIn(
'leaky_relu', torch.ops._test.__dict__)
12770 op2 = torch.ops._test.leaky_relu
12771 self.assertEqual(op, op2)
12773 def test_simply_calling_an_operator(self):
12774 input = torch.randn(100)
12775 output = torch.ops.aten.relu(input)
12776 self.assertEqual(output, input.relu())
12778 def test_default_arguments_are_used(self):
12779 output = torch.ops._test.leaky_relu(
torch.tensor([-1.0, 1.0]))
12782 def test_only_kwargs(self):
12783 output = torch.ops._test.leaky_relu(self=
torch.tensor(-1.0))
12786 def test_passing_too_many_args(self):
12787 with self.assertRaisesRegex(
12789 r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)" 12791 torch.ops.aten.relu(1, 2)
12793 def test_passing_too_few_args(self):
12794 with self.assertRaisesRegex(
12796 r"aten::relu\(\) is missing value for argument 'self'." 12798 torch.ops.aten.relu()
12800 def test_passing_one_positional_but_not_the_second(self):
12801 with self.assertRaisesRegex(
12803 r"aten::transpose\(\) is missing value for argument 'dim0'." 12805 torch.ops.aten.transpose(torch.ones(5, 5))
12807 def test_passing_an_argument_both_as_positional_and_kwarg(self):
12808 with self.assertRaisesRegex(
12810 "Argument 'self' specified both as positional and keyword argument" 12812 torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))
12814 def test_passing_unknown_kwargs(self):
12815 with self.assertRaisesRegex(
12817 "Unknown keyword argument 'foo' for operator '_test::leaky_relu'" 12819 torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))
12821 def test_passing_and_returning_lists(self):
12823 a, b = torch.rand(5), torch.rand(5)
12824 output = torch.ops._test.cat([a, b])
12825 output_ref = torch.cat([a, b])
12826 self.assertEqual(output, output_ref)
12828 def test_calling_scripted_custom_op(self):
12831 return torch.ops.aten.relu(x)
12832 input = torch.ones(5, 5)
12833 self.assertEqual(func(input), input.relu())
12835 def test_calling_traced_custom_op(self):
12836 input = torch.ones(5, 5)
12838 self.assertEqual(func(input), input.relu())
12840 def test_script_graph_for_custom_ops_matches_traced_graph(self):
12841 input = torch.ones(5, 5)
12843 self.assertExpectedInline(canonical(trace.graph),
'''\ 12844 graph(%0 : Double(5, 5)): 12845 %1 : Double(5, 5) = aten::relu(%0) 12849 def test_script_graph_contains_custom_op(self):
12852 return torch.ops.aten.relu(x)
12853 self.assertExpectedInline(canonical(func.graph),
'''\ 12854 graph(%x : Tensor): 12855 %1 : Tensor = aten::relu(%x) 12859 def test_generic_list(self):
12860 self.assertEqual(torch.ops._test.get_first([[
'hello']]),
'hello')
12867 class TestJitGeneratedModule(JitTestCase):
12871 class TestJitGeneratedFunctional(JitTestCase):
12877 UBSAN_BLACKLISTED_TESTS = [
12878 "test___rdiv___constant",
12879 "test___rdiv___scalar_constant",
12881 "test_addcdiv_broadcast_all",
12882 "test_addcdiv_broadcast_rhs",
12883 "test_addcdiv_scalar",
12884 "test_addcdiv_scalar_broadcast_lhs",
12885 "test_addcdiv_scalar_broadcast_rhs",
12886 "test_addcdiv_scalar_scale",
12887 "test_addcdiv_scalar_scale_broadcast_lhs",
12888 "test_addcdiv_scalar_scale_broadcast_rhs",
12889 "test_addcdiv_scale",
12890 "test_addcdiv_scale_broadcast_all",
12891 "test_addcdiv_scale_broadcast_rhs",
12892 "test_add_broadcast_all",
12893 "test_add_broadcast_lhs",
12894 "test_add_broadcast_rhs",
12895 "test_add_constant",
12897 "test_add_scalar_broadcast_lhs",
12898 "test_add_scalar_broadcast_rhs",
12900 "test_div_broadcast_all",
12901 "test_div_broadcast_lhs",
12902 "test_div_broadcast_rhs",
12904 "test_div_scalar_broadcast_lhs",
12905 "test_div_scalar_broadcast_rhs",
12907 "test_rsqrt_scalar",
12910 "test_reciprocal_scalar",
12918 EXCLUDE_MODULE_EXPORT_IMPORT = {
12923 'AdaptiveAvgPool2d',
12924 'AdaptiveAvgPool3d',
12943 nn_functional_tests = [
12944 (
'conv1d', (S, S, S), ((S, S, S),)),
12945 (
'conv2d', (S, S, S, S), ((S, S, S, S),)),
12946 (
'conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
12947 (
'conv_transpose1d', (S, S, S), ((S, S, S),)),
12948 (
'conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
12949 (
'conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
12950 (
'conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
12951 (
'avg_pool1d', (S, S, S), (3,)),
12952 (
'avg_pool2d', (S, S, S, S), (3,)),
12953 (
'avg_pool3d', (S, S, S, S, S), (3,)),
12954 (
'fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
12955 (
'max_pool1d', (S, S, S), (2, 1)),
12956 (
'max_pool1d', (S, S, S), (2, 1, 1, 1,
False,
True),
'with_indices'),
12957 (
'max_pool2d', (S, S, S, S), (2, 1)),
12958 (
'max_pool3d', (S, S, S, S, S), (2, 1)),
12962 (
'lp_pool1d', (S, S, S), (2., 3, 2,)),
12963 (
'lp_pool2d', (S, S, S, S), (2., 3, 2,)),
12964 (
'adaptive_max_pool1d', (S, S, S), (5,)),
12965 (
'adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
12966 (
'adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
12967 (
'adaptive_avg_pool1d', (S, S, S), (5,)),
12968 (
'adaptive_avg_pool2d', (S, S, S, S), ([5, 7],)),
12969 (
'adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
12970 (
'dropout', (S, S, S), (0.5,)),
12971 (
'alpha_dropout', (S, S, S), (0.5,)),
12972 (
'dropout2d', (S, S, S), (0.5,)),
12973 (
'dropout3d', (S, S, S), (0.5,)),
12974 (
'feature_alpha_dropout', (S, S, S), (0.5,)),
12975 (
'threshold', (S, S, S), (0.1, 2.),),
12976 (
'threshold', (S, S, S), (0.1, 2.,
True),
'inplace'),
12977 (
'relu', (S, S, S), (),),
12978 (
'relu', (S, S, S), (),
'inplace'),
12979 (
'glu', (S - 1, S - 1, S - 1), (),),
12980 (
'hardtanh', (S, S, S), (-0.5, 0.5),),
12981 (
'hardtanh', (S, S, S), (-0.5, 0.5,
True),
'inplace'),
12982 (
'relu6', (S, S, S), (),),
12983 (
'relu6', (S, S, S), (
True),
'inplace'),
12984 (
'elu', (S, S, S), (0.9,),),
12985 (
'elu', (S, S, S), (0.9,
True),
'inplace'),
12986 (
'selu', (S, S, S), (),),
12987 (
'selu', (S, S, S), (
True),
'inplace'),
12988 (
'celu', (S, S, S), (0.9,),),
12989 (
'celu', (S, S, S), (0.9,
True),
'inplace'),
12990 (
'leaky_relu', (S, S, S), (0.02,),),
12991 (
'leaky_relu', (S, S, S), (0.02,),
'inplace'),
12992 (
'rrelu', (S, S), (0.1, 0.3,
False),),
12993 (
'rrelu', (S, S), (0.1, 0.3,
False,
True),
'inplace'),
12994 (
'hardshrink', (S, S, S), (0.4,),),
12995 (
'tanhshrink', (S, S, S), (),),
12996 (
'softsign', (S, S, S), (),),
12997 (
'softplus', (S, S, S), (),),
12998 (
'softmin', (S, S, S), (0,),),
12999 (
'softmax', (S, S, S), (0,),),
13000 (
'softmax', (S, S, S), (0, 3, torch.double),
'with_all_args'),
13001 (
'tanh', (S, S, S), (),),
13002 (
'sigmoid', (S, S, S), (),),
13003 (
'log_softmax', (S, S, S), (0,),),
13004 (
'linear', (S, S), ((M, S),),),
13005 (
'bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
13006 (
'embedding',
torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ),),
13007 (
'embedding_bag',
torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3),
torch.tensor([0, 4]),),),
13008 (
'batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),),
13009 (
'instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
13010 (
'layer_norm', (S, S, S, S), ([5],),),
13011 (
'layer_norm', (S, S, S, S), ([5], (S,)),
'with_only_weight'),
13012 (
'layer_norm', (S, S, S, S), ([5],
None, (S,)),
'with_only_bias'),
13013 (
'layer_norm', (S, S, S, S), ([5], (S,), (S,)),
'with_weight_and_bias'),
13014 (
'group_norm', (S, S, S), (1, torch.rand(5),),),
13015 (
'local_response_norm', (S, S, S), (2, ),),
13016 (
'nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (
torch.tensor([1, 0, 4]),),),
13017 (
'poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
13018 (
'poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),
True,
True),
'full'),
13019 (
'kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
13020 (
'cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
13021 (
'binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
13022 (
'smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13023 (
'l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13024 (
'mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13025 (
'smooth_l1_loss', (3, S), ((torch.rand(3, S)),),
'with_grad'),
13026 (
'l1_loss', (3, S), ((torch.rand(3, S)),),
'with_grad'),
13027 (
'mse_loss', (3, S), ((torch.rand(3, S)),),
'with_grad'),
13028 (
'margin_ranking_loss', (3, S), ((3, S), (S,)),),
13029 (
'hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13030 (
'soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13031 (
'multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
13032 (
'cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
13033 (
'pixel_shuffle', (1, 9, 4, 4), (3,),),
13034 (
'affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
13035 (
'pad', (3, 3, 4, 2), ([1, 1],),),
13036 (
'pairwise_distance', (S, S), ((S, S),),),
13037 (
'pdist', (S, S), (),),
13038 (
'cosine_similarity', (S, S), ((S, S),),),
13039 (
'triplet_margin_loss', (S, S), ((S, S), (S, S)),),
13040 (
'normalize', (S, S, S), (),),
13041 (
'unfold', (S, S, S, S), ([2, 3]),),
13042 (
'fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
13043 (
'grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
13044 (
'gumbel_softmax', (S, S), (2.,),),
13045 (
'gumbel_softmax', (S, S), (2.,
True,),
'hard'),
13047 (
'multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
13048 1, 1., non_differentiable(torch.randn(S))),),
13049 (
'binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
13050 non_differentiable(torch.randn(3, 2))),),
13051 (
'binary_cross_entropy', torch.randn(3, 2).sigmoid(),
13052 (non_differentiable(torch.rand(3, 2)),
13053 non_differentiable(torch.randn(3, 2)),
None,
None,
'mean'),
'size_average'),
13054 (
'ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
13055 (torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
13056 torch.randint(1, S, (S,), dtype=torch.long))),
13057 (
'upsample', torch.randn(S, S, M, M), (
None, 2),
'with_scale'),
13058 (
'upsample', torch.randn(S, S, M, M), (4,),
'with_size'),
13059 (
'interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,),
'nearest_4d'),
13060 (
'interpolate', torch.randn(S, S, M, M), (
None, 2.),
'nearest_4d_with_scale'),
13061 (
'interpolate', torch.randn(S, S, M, M), (4,),
'nearest_4d_with_size'),
13062 (
'interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,),
'area_4d'),
13063 (
'interpolate', torch.randn(S, S, M, M), (
None, 2.),
'area_4d_with_scale'),
13064 (
'interpolate', torch.randn(S, S, M, M), (4,),
'area_4d_with_size'),
13065 (
'interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,),
'bilinear_4d'),
13066 (
'interpolate', torch.randn(S, S, M, M), (
None, 2.),
'bilinear_4d_with_scale'),
13067 (
'interpolate', torch.randn(S, S, M, M), (4,),
'bilinear_4d_with_size'),
13068 (
'interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,),
'bicubic_4d'),
13069 (
'interpolate', torch.randn(S, S, M, M), (
None, 2.),
'bicubic_4d_with_scale'),
13070 (
'interpolate', torch.randn(S, S, M, M), (4,),
'bicubic_4d_with_size'),
13071 (
'interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,),
'nearest_3d'),
13072 (
'interpolate', torch.randn(S, M, M), (
None, 2.),
'nearest_3d_with_scale'),
13073 (
'interpolate', torch.randn(S, M, M), (4,),
'nearest_3d_with_size'),
13074 (
'interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,),
'area_3d'),
13075 (
'interpolate', torch.randn(S, M, M), (
None, 2.),
'area_3d_with_scale'),
13076 (
'interpolate', torch.randn(S, M, M), (4,),
'area_3d_with_size'),
13077 (
'interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,),
'linear_3d'),
13078 (
'interpolate', torch.randn(S, M, M), (
None, 2.),
'linear_3d_with_scale'),
13079 (
'interpolate', torch.randn(S, M, M), (4,),
'linear_3d_with_size'),
13080 (
'interpolate', torch.randn(S, M, M, M, M), (
None, 2.),
'nearest_5d_with_scale'),
13081 (
'interpolate', torch.randn(S, M, M, M, M), (4,),
'nearest_5d_with_size'),
13082 (
'interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,),
'area_5d'),
13083 (
'interpolate', torch.randn(S, M, M, M, M), (
None, 2.),
'area_5d_with_scale'),
13084 (
'interpolate', torch.randn(S, M, M, M, M), (4,),
'area_5d_with_size'),
13085 (
'interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,),
'trilinear_5d'),
13086 (
'interpolate', torch.randn(S, M, M, M, M), (
None, 2.),
'trilinear_5d_with_scale'),
13087 (
'interpolate', torch.randn(S, M, M, M, M), (4,),
'trilinear_5d_with_size'),
13092 nn_functional_single_grad = frozenset(
'test_nn_' + name
for name
in [
13094 'multilabel_margin_loss',
13096 'multi_margin_loss',
13097 'binary_cross_entropy',
13098 'binary_cross_entropy_size_average',
13105 additional_module_tests = [
13107 'module_name':
'Bilinear',
13108 'constructor_args': (S, S, M),
13109 'input_size': (S, S),
13110 'extra_args': ((S, S),)
13113 'module_name':
'RNNCell',
13114 'constructor_args': (S, S),
13115 'input_size': (S, S),
13118 'module_name':
'LSTMCell',
13119 'constructor_args': (S, S),
13120 'input_size': (S, S),
13123 'module_name':
'GRUCell',
13124 'constructor_args': (S, S),
13125 'input_size': (S, S),
13130 def add_autograd_test(
13137 output_process_fn=
lambda x: x,
13139 basic_test_name =
'test_' + name
13140 if variant_name !=
'':
13141 basic_test_name +=
'_' + variant_name
13143 for dim_perm
in product([-1, 1], repeat=len(dim_args_idx)):
13144 test_name = basic_test_name
13145 new_args = [arg * dim_perm[dim_args_idx.index(i)]
if i
in dim_args_idx
else arg
for i, arg
in enumerate(args)]
13146 test_name = basic_test_name +
''.join(
'_neg' + str(i)
for i, idx
in enumerate(dim_perm)
if idx < 0)
13147 new_args = tuple(new_args)
13151 def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name,
13152 output_process_fn=output_process_fn):
13155 is_magic_method = name[:2] ==
'__' and name[-2:] ==
'__' 13156 is_inplace = name[-1] ==
"_" and not is_magic_method
13157 self_variable = create_input((self_size,))[0][0]
13160 self_variable.requires_grad =
False 13162 args_variable, kwargs_variable = create_input(args, requires_grad=
not is_inplace, call_kwargs=kwargs)
13163 self_tensor = deepcopy(self_variable.data)
13164 args_tensor = deepcopy(unpack_variables(args_variable))
13166 def fn(*inputs, **kwargs):
13167 output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
13168 return output_process_fn(output)
13170 check_types = test_name
not in EXCLUDE_TYPE_CHECK
13172 if not is_inplace
and name
not in EXCLUDE_GRADCHECK
and not exclude_tensor_method(name, test_name):
13176 if test_name
not in EXCLUDE_TRACED:
13177 check_against_reference(self,
13178 create_traced_fn(self, fn,
13179 disable_autodiff_subgraph_inlining=
True),
13180 fn, (self_variable,) + args_variable, kwargs_variable,
13181 check_types=check_types)
13183 if not is_magic_method
and test_name
not in EXCLUDE_SCRIPT:
13184 check_against_reference(self,
13185 create_script_fn(self, name,
'method', output_process_fn,
13186 disable_autodiff_subgraph_inlining=
True),
13187 fn, (self_variable,) + args_variable, kwargs_variable,
13188 check_types=check_types)
13191 if hasattr(torch, name)
and name
not in EXCLUDE_FUNCTIONAL:
13192 def fn(*inputs, **kwargs):
13193 output = getattr(torch, name)(*inputs, **kwargs)
13194 return output_process_fn(output)
13196 f_args_variable = (self_variable,) + args_variable
13197 f_args_tensor = (self_tensor,) + args_tensor
13199 if not is_inplace
and test_name
not in EXCLUDE_TRACED:
13200 check_against_reference(self,
13201 create_traced_fn(self, fn,
13202 disable_autodiff_subgraph_inlining=
True),
13203 fn, f_args_variable, kwargs_variable, check_types=check_types)
13205 if not is_inplace
and test_name
not in EXCLUDE_SCRIPT:
13206 check_against_reference(self,
13207 create_script_fn(self, name,
'functional', output_process_fn,
13208 disable_autodiff_subgraph_inlining=
True),
13209 fn, f_args_variable, kwargs_variable,
13210 check_types=check_types)
13213 if is_inplace
and test_name
not in EXCLUDE_SCRIPT:
13214 check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable)
13217 inplace_name = name +
'_' 13219 broadcast_skip_inplace =
'broadcast_lhs' in test_name
or 'broadcast_all' in test_name
13220 if hasattr(torch.ones(1), inplace_name)
and not broadcast_skip_inplace:
13221 check(inplace_name)
13223 post_add_test(test_name, skipTestIf, do_test, TestJitGeneratedAutograd)
13226 def suppress_warnings(fn):
13228 def wrapper(*args, **kwargs):
13229 with warnings.catch_warnings(record=
True):
13230 return fn(*args, **kwargs)
13234 def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=(),
13235 output_process_fn=
lambda x: x, kwargs=
None):
13236 test_name =
'test_nn_' + name
13238 if variant_name !=
'':
13239 test_name = test_name +
'_' + variant_name
13241 no_grad = variant_name ==
'inplace' 13244 def do_test(self, name=name, args=args, test_name=test_name):
13245 torch.manual_seed(2)
13247 self_variable = create_input((self_size,))[0][0]
13250 args_variable, kwargs_variable = create_input(args, call_kwargs=kwargs)
13252 self_tensor = deepcopy(self_variable.data)
13253 args_tensor = deepcopy(unpack_variables(args_variable))
13256 output_variable = getattr(F, name)(self_variable, *args_variable, **kwargs_variable)
13258 def fn(*inputs, **kwargs):
13259 output = getattr(F, name)(*inputs, **kwargs)
13260 return output_process_fn(output)
13262 f_args_variable = (self_variable,) + args_variable
13263 f_args_tensor = (self_tensor,) + args_tensor
13265 if test_name
not in EXCLUDE_SCRIPT:
13266 disable_ad_subgraph_inlining = test_name
in DISABLE_AUTODIFF_SUBGRAPH_INLINING
13269 script_fn = create_script_fn(self, name,
'nn_functional', output_process_fn,
13270 disable_autodiff_subgraph_inlining=disable_ad_subgraph_inlining)
13271 check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
13273 if test_name
in EXCLUDE_PYTHON_PRINT:
13274 with self.disableModuleHook():
13279 post_add_test(test_name, skipTestIf, do_test, TestJitGeneratedFunctional)
13282 def add_nn_module_test(*args, **kwargs):
13283 if 'module_name' in kwargs:
13284 name = kwargs[
'module_name']
13285 elif 'fullname' in kwargs:
13286 name = kwargs[
'fullname']
13287 elif 'constructor' in kwargs:
13288 name = kwargs[
'constructor'].__name__
13290 no_grad =
False if 'no_grad' not in kwargs
else kwargs[
'no_grad']
13292 module_name = name.split(
"_")[0]
13294 module = getattr(
torch.nn, module_name,
None)
13295 if module
is None or torch._jit_internal.weak_types.get(module)
is None:
13298 if 'desc' in kwargs
and 'eval' in kwargs[
'desc']:
13303 if 'desc' in kwargs:
13304 test_name =
"{}_{}".format(test_name, kwargs[
'desc'])
13305 test_name =
'test_nn_{}'.format(test_name)
13309 if test_name
in EXCLUDE_SCRIPT_MODULES:
13311 if 'constructor' in kwargs:
13312 nn_module = kwargs[
'constructor']
13314 nn_module = getattr(
torch.nn, name)
13316 if "FunctionalModule" in str(nn_module):
13319 if 'constructor_args_fn' in kwargs:
13320 constructor_args = kwargs[
'constructor_args_fn']()
13322 constructor_args = kwargs.get(
'constructor_args', ())
13326 def create_script_module(*args, **kwargs):
13327 formals, tensors, actuals = get_script_args(args)
13329 method_args =
', '.join([
'self'] + actuals)
13330 call_args_str =
', '.join(actuals)
13331 call =
"self.submodule({})".format(call_args_str)
13332 script = script_method_template.format(method_args, call)
13334 submodule_constants = []
13335 if kwargs.get(
'is_constant'):
13336 submodule_constants = [
'submodule']
13340 __constants__ = submodule_constants
13342 def __init__(self):
13343 super(TheModule, self).__init__()
13344 self.submodule = nn_module(*constructor_args)
13346 if module_name
in EXCLUDE_MODULE_EXPORT_IMPORT:
13347 with self.disableModuleHook():
13348 module = TheModule()
13349 module.define(script)
13350 create_script_module.last_graph = module.graph
13351 mod = module(*args)
13353 module = TheModule()
13354 module.define(script)
13355 self.assertExportImportModule(module, tensors)
13356 create_script_module.last_graph = module.graph
13357 mod = module(*args)
13362 def create_nn_module(*args, **kwargs):
13363 module = nn_module(*constructor_args)
13364 return module(*args)
13367 if 'input_fn' in kwargs:
13368 input = kwargs[
'input_fn']()
13370 input = (kwargs[
'input_size'],)
13373 if 'extra_args' in kwargs:
13374 input = input + kwargs[
'extra_args']
13376 if 'target_size' in kwargs:
13377 input = input + (kwargs[
'target_size'],)
13378 elif 'target_fn' in kwargs:
13381 input = input + (kwargs[
'target_fn'](),)
13383 args_variable, kwargs_variable = create_input(input)
13384 f_args_variable = deepcopy(unpack_variables(args_variable))
13387 check_against_reference(self, create_script_module, create_nn_module, f_args_variable, no_grad=no_grad)
13389 post_add_test(test_name, (), do_test, TestJitGeneratedModule)
13392 def post_add_test(test_name, skipTestIf, do_test, test_class):
13393 assert not hasattr(test_class, test_name),
'Two tests have the same name: ' + test_name
13395 for skip
in skipTestIf:
13396 do_test = skip(do_test)
13398 if not (TEST_WITH_UBSAN
and test_name
in UBSAN_BLACKLISTED_TESTS):
13399 setattr(test_class, test_name, do_test)
13402 class TestAsync(JitTestCase):
13403 def test_async_python(self):
13406 return torch.neg(x)
13408 x = torch.rand(3, 4)
13414 def test_async_parsing(self):
13418 return [torch.neg(x), x.t()]
13425 Future[List[Tensor]],
13428 futures.append(future)
13435 x = torch.rand(3, 3)
13437 self.assertEqual(len(result), 3)
13439 def test_async_script(self):
13442 return torch.neg(x), x
13444 x = torch.rand(3, 4)
13447 def wait_script(x):
13453 y, y_hat = wait_script(x)
13455 self.assertEqual(y, y_hat)
13457 def test_async_script_capture(self):
13459 __constants__ = [
'const']
13461 def __init__(self):
13462 super(Mod, self).__init__(
False)
13464 self.param = nn.Parameter(torch.randn(2, 2))
13466 @torch.jit.script_method
13467 def foo(self, x1, x2):
13468 return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param
13470 @torch.jit.script_method
13471 def wait_script(self, x1, x2):
13473 y_hat = self.foo(x1, x2)
13477 x1 = torch.rand(3, 4)
13478 x2 = torch.rand(5, 6)
13481 y, y_hat = m.wait_script(x1, x2)
13483 self.assertEqual(y, y_hat)
13485 def test_async_script_nested(self):
13488 return torch.neg(x), x
13490 x = torch.rand(3, 4)
13493 def wait_script(x):
13500 def wait_script_nest(x):
13504 y, y_hat = wait_script_nest(x)
13506 self.assertEqual(y, y_hat)
13508 def test_async_script_no_script_mod(self):
13509 x = torch.rand(3, 4)
13511 with self.assertRaisesRegex(RuntimeError,
'cannot call a value'):
13513 def wait_script(x):
13517 def test_async_script_multi_waits(self):
13520 return torch.neg(x).t() + x
13523 def wait_script(x):
13531 x = torch.rand(2, 2)
13532 y1, y2 = wait_script(x)
13533 self.assertEqual(y1, y2)
13535 def test_async_script_multi_forks(self):
13538 return torch.neg(x).t() + x
13542 return torch.neg(x).t() + x + torch.neg(y).t()
13546 return torch.neg(z).t() + y.t() + x
13548 x1 = torch.rand(10, 10)
13549 x2 = torch.rand(10, 10)
13550 x3 = torch.rand(10, 10)
13553 def wait_script(x1, x2, x3):
13567 y1, y2, y3 = wait_script(x1, x2, x3)
13568 self.assertEqual(y1, foo1(x1))
13569 self.assertEqual(y2, foo2(x1, x2))
13570 self.assertEqual(y3, foo3(x1, x2, x3))
13572 def test_async_script_trace(self):
13573 class Traced(nn.Module):
13574 def __init__(self):
13575 super(Traced, self).__init__()
13577 def forward(self, x):
13578 return (torch.neg(x), x)
13581 def __init__(self):
13582 super(Mod, self).__init__(
False)
13583 x = torch.rand(3, 3)
13586 @torch.jit.script_method
13587 def forward(self, x):
13596 tensor_list.append(tensor_tuple[0])
13597 tensor_list.append(tensor_single)
13600 return (tensor_list, tensor_tuple, tensor_tuple[1])
13602 class TupleCl(nn.Module):
13603 def __init__(self):
13604 super(TupleCl, self).__init__()
13605 self.module = Mod()
13607 def forward(self, x):
13610 list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
13613 x = torch.rand(3, 3)
13617 self.assertGraphContainsExactly(module.graph, kind=
'prim::fork', num_kind_nodes=2)
13619 self.assertGraphContainsExactly(module.graph, kind=
'aten::neg', num_kind_nodes=1)
13620 self.assertGraphContainsExactly(module.graph, kind=
'aten::neg', num_kind_nodes=3, consider_subgraphs=
True)
13623 self.assertEqual(module(x), (y, y, y, y, x, x))
13625 def test_async_script_error(self):
13626 x = torch.rand(3, 4)
13634 def wait_script(x):
13639 def wait_script_nest(x):
13644 error_msg =
'The size.*must match the size of tensor' 13645 with self.assertRaisesRegex(Exception, error_msg):
13649 with self.assertRaisesRegex(Exception, error_msg):
13653 x = torch.rand(3, 4, 5)
13654 with self.assertRaisesRegex(Exception,
'expects a tensor with <= 2 dimensions'):
13655 wait_script_nest(x)
13657 def test_async_grad_guard_with_grad(self):
13661 return y.requires_grad
13668 return (requires_grad_in_fork, z.requires_grad)
13670 x = torch.randn(3, requires_grad=
True)
13672 with torch.enable_grad():
13673 (inside_fork, after_wait) = bar(x)
13675 self.assertEqual(inside_fork,
True)
13676 self.assertEqual(after_wait,
True)
13678 def test_async_grad_guard_no_grad(self):
13682 return y.requires_grad
13689 return (requires_grad_in_fork, z.requires_grad)
13691 x = torch.randn(3, requires_grad=
True)
13693 with torch.no_grad():
13694 (inside_fork, after_wait) = bar(x)
13696 self.assertEqual(inside_fork,
False)
13697 self.assertEqual(after_wait,
False)
13699 def test_trace_fork_wait(self):
13701 return x.neg(), x.neg() + 1
13706 return vals[0], vals[1], x - 1
13709 x = torch.rand(3, 4)
13710 self.assertEqual(fn(x), traced(x))
13712 self.assertGraphContainsExactly(traced.graph, kind=
'prim::fork', num_kind_nodes=1)
13713 self.assertGraphContainsExactly(traced.graph, kind=
'aten::wait', num_kind_nodes=1)
13714 self.assertGraphContainsExactly(traced.graph, kind=
'aten::neg', num_kind_nodes=2, consider_subgraphs=
True)
13716 def test_trace_fork_wait_leaking(self):
13720 my_list.append(x + 1)
13728 with self.assertRaisesRegex(RuntimeError,
'did not have observable data dependence with trace inputs; ' 13729 'this probably indicates your program cannot be understood ' 13733 def test_trace_fork_wait_inline(self):
13735 return x + 1, x + 2
13743 torch._C._jit_pass_inline_fork_wait(traced.graph)
13744 torch._C._jit_pass_dce(traced.graph)
13745 self.assertGraphContainsExactly(traced.graph, kind=
'prim::fork', num_kind_nodes=0)
13746 self.assertGraphContainsExactly(traced.graph, kind=
'aten::wait', num_kind_nodes=0)
13747 self.assertGraphContainsExactly(traced.graph, kind=
'aten::add', num_kind_nodes=2)
13749 def test_trace_fork_wait_inline_onnx(self):
13751 return torch.neg(x), torch.neg(x)
13753 class MyMod(torch.nn.Module):
13754 def forward(self, x):
13763 def test_save_load_with_extra_files(self):
13765 @torch.jit.script_method
13766 def forward(self, a):
13769 expected_extra_files = torch._C.ExtraFilesMap()
13770 expected_extra_files[
'foo'] =
'bar' 13774 with TemporaryFileName()
as fname:
13775 m.save(fname, _extra_files=expected_extra_files)
13776 extra_files = torch._C.ExtraFilesMap()
13777 extra_files[
'foo'] =
'' 13779 self.assertEqual(
'bar', extra_files[
'foo'])
13783 extra_files[
'foo'] =
'' 13785 self.assertEqual(
'bar', extra_files[
'foo'])
13788 buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
13789 extra_files = torch._C.ExtraFilesMap()
13790 extra_files[
'foo'] =
'' 13792 self.assertEqual(
'bar', extra_files[
'foo'])
13795 buffer = io.BytesIO()
13798 extra_files = torch._C.ExtraFilesMap()
13799 extra_files[
'foo'] =
'' 13801 self.assertEqual(
'bar', extra_files[
'foo'])
13804 with self.assertRaises(RuntimeError):
13805 extra_files[
'bar'] =
'' 13809 class TestDataParallel(JitTestCase):
13810 class Mpy(torch.nn.Module):
13811 def __init__(self):
13812 super(TestDataParallel.Mpy, self).__init__()
13813 self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
13814 nn.ReLU(), nn.Linear(2, 2))
13816 def forward(self, input):
13817 return self.m(input)
13819 class Mpy1(torch.nn.Module):
13820 def __init__(self, block):
13821 super(TestDataParallel.Mpy1, self).__init__()
13824 def forward(self, input):
13825 return self.m.forward(input)
13827 class Mpy2(torch.nn.Module):
13828 def __init__(self, block1, block2):
13829 super(TestDataParallel.Mpy2, self).__init__()
13833 def forward(self, input):
13834 x = self.m1.forward(input)
13839 __constants__ = [
'm']
13841 def __init__(self):
13842 super(TestDataParallel.Msm, self).__init__(
False)
13843 self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
13844 nn.ReLU(), nn.Linear(2, 2))
13846 @torch.jit.script_method
13847 def forward(self, input):
13848 return self.m(input)
13851 def __init__(self, block):
13852 super(TestDataParallel.Msm1, self).__init__(
False)
13855 @torch.jit.script_method
13856 def forward(self, input):
13857 x = self.block(input)
13860 def check_replicas(self, module, replicas, input_shape=(2, 2)):
13861 input = torch.randn(input_shape).cuda()
13862 expected_output = module(input).data
13863 for i, replica
in enumerate(replicas):
13864 for p
in replica.parameters():
13865 self.assertEqual(p.get_device(), i)
13866 for b
in replica.buffers():
13867 self.assertEqual(b.get_device(), i)
13868 replica_input = input.cuda(i)
13869 self.assertEqual(replica(replica_input).data, expected_output)
13871 @unittest.skipIf(
not RUN_CUDA_MULTI_GPU,
"multi-GPU not supported")
13873 def test_python_submodule_exception(self):
13874 module = self.Msm1(self.Mpy()).cuda()
13875 msg =
"Cannot replicate.*" 13876 with self.assertRaisesRegex(Exception, msg):
13877 dp.replicate(module, {0, 1})
13879 @unittest.skipIf(
not RUN_CUDA_MULTI_GPU,
"multi-GPU not supported")
13881 def test_python_submodule_script(self):
13882 module = self.Mpy1(self.Msm()).cuda()
13883 replicas = dp.replicate(module, {0, 1})
13884 self.check_replicas(module, replicas)
13886 @unittest.skipIf(
not RUN_CUDA_MULTI_GPU,
"multi-GPU not supported")
13888 def test_shared_module(self):
13891 module = self.Mpy2(p1, s).cuda()
13892 replicas = dp.replicate(module, {0, 1})
13893 self.check_replicas(module, replicas)
13895 @unittest.skipIf(
not RUN_CUDA_MULTI_GPU,
"multi-GPU not supported")
13897 def test_traced_module(self):
13898 module =
torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda()
13899 replicas = dp.replicate(module, {0, 1})
13900 self.check_replicas(module, replicas)
13902 @unittest.skipIf(
not RUN_CUDA_MULTI_GPU,
"multi-GPU not supported")
13904 def test_tensor_sharing(self):
13905 module = self.Msm1(self.Msm()).cuda()
13906 replica = dp.replicate(module, {0, 1})
13907 optimizer = optim.SGD(module.parameters(), lr=1, momentum=1)
13908 x = torch.ones(2, 2, requires_grad=
True).cuda()
13909 first_forward = module.forward(x)
13910 first_forward.sum().backward()
13912 second_forward = module.forward(first_forward)
13916 r0_forward = replica[0].forward(x)
13917 self.assertEqual(second_forward, r0_forward)
13921 x1 = torch.ones(2, 2, requires_grad=
True).cuda(device=1)
13922 r1_forward = replica[1].forward(x1)
13923 self.assertEqual(first_forward, r1_forward)
13926 class TestClassType(JitTestCase):
13927 def test_get_with_method(self):
13930 def __init__(self, x):
13933 def getFooTest(self):
13939 return foo.getFooTest()
13941 input = torch.ones(2, 3)
13942 self.assertEqual(fn(input), input)
13944 def test_get_attr(self):
13947 def __init__(self, x):
13955 input = torch.ones(2, 3)
13956 self.assertEqual(fn(input), input)
13958 def test_set_attr_in_method(self):
13961 def __init__(self, x):
13965 def incFooTest(self, y):
13967 self.foo = self.foo + y
13976 self.assertEqual(fn(1), 3)
13978 def test_set_attr_type_mismatch(self):
13979 with self.assertRaisesRegex(RuntimeError,
"Wrong type for attribute assignment"):
13982 def __init__(self, x):
13986 def test_get_attr_not_initialized(self):
13987 with self.assertRaisesRegex(RuntimeError,
"Tried to access to nonexistent attribute"):
13990 def __init__(self, x):
13993 def get_non_initialized(self):
13996 def test_set_attr_non_initialized(self):
13997 with self.assertRaisesRegex(RuntimeError,
"Tried to set nonexistent attribute"):
14000 def __init__(self, x):
14003 def set_non_initialized(self, y):
14006 def test_type_annotations(self):
14007 with self.assertRaisesRegex(RuntimeError,
"expected a value of type bool"):
14010 def __init__(self, x):
14020 def test_conditional_set_attr(self):
14021 with self.assertRaisesRegex(RuntimeError,
"assignment cannot be in a control-flow block"):
14024 def __init__(self, x):
14028 def test_class_type_as_param(self):
14031 def __init__(self, x):
14044 input = torch.ones(1)
14045 self.assertEqual(fn2(input), input)
14047 def test_out_of_order_methods(self):
14050 def __init__(self, x):
14052 self.x = self.get_stuff(x)
14054 def get_stuff(self, y):
14062 input = torch.ones(1)
14063 self.assertEqual(fn(input), input + input)
14065 def test_save_load_with_classes(self):
14068 def __init__(self, x):
14075 @torch.jit.script_method
14076 def forward(self, a):
14082 buffer = io.BytesIO()
14087 torch._C._jit_clear_class_registry()
14092 input = torch.rand(2, 3)
14093 output = m_loaded(input)
14094 self.assertEqual(input, output)
14096 def test_save_load_with_classes_nested(self):
14098 class FooNestedTest:
14099 def __init__(self, y):
14103 class FooNestedTest2:
14104 def __init__(self, y):
14106 self.nested = FooNestedTest(y)
14110 def __init__(self, x):
14111 self.class_attr = FooNestedTest(x)
14112 self.class_attr2 = FooNestedTest2(x)
14113 self.x = self.class_attr.y + self.class_attr2.y
14116 @torch.jit.script_method
14117 def forward(self, a):
14123 buffer = io.BytesIO()
14128 torch._C._jit_clear_class_registry()
14133 input = torch.rand(2, 3)
14134 output = m_loaded(input)
14135 self.assertEqual(2 * input, output)
14138 for test
in autograd_method_tests():
14139 add_autograd_test(*test)
14141 for test
in nn_functional_tests:
14142 add_nn_functional_test(*test)
14144 for test
in module_tests + new_module_tests + additional_module_tests:
14145 add_nn_module_test(**test)
14147 for test
in criterion_tests:
14148 test[
'no_grad'] =
True 14149 add_nn_module_test(**test)
14151 if __name__ ==
'__main__':
def export_to_pretty_string(args, kwargs)
def _calculate_fan_in_and_fan_out(tensor)
def assertEqual(self, x, y, prec=None, message='', allow_inf=False)
def _test_dcgan_models(self, device, check_export_import=True)
def createScriptModuleFromGraph(self, trace)
def assert_allclose(actual, expected, rtol=None, atol=None, equal_nan=True)
def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP)
def reshape_from_tensor_shape(x, shape)
def run_ge_tests(self, optimize, use_cuda)
def checkTracerWarning(self, args, kwargs)
def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None)
def format_code(self, code, pair)
def _test_mnist(self, device, check_export_import=True)
def assertExpected(self, s, subname=None)
def verify(model, args, loss_fn=torch.sum, devices=None)
def rand_batch(self, dims)
def _test_reinforcement_learning(self, device, test_export_import=True)
def annotate(the_type, the_value)
def pack_sequence(sequences, enforce_sorted=True)
def script(obj, optimize=True, _frames_up=0, _rcb=None)
def disableModuleHook(self)
def trace(func, example_inputs, optimize=True, check_trace=True, check_inputs=None, check_tolerance=1e-5, _force_outplace=False, _module_class=None)
def test_arg_configurations(self)
def test_addmm_grad(self)
def assertNotEqual(self, x, y, prec=None, message='')
def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, return_inputs=False)
def get_sum_list_fn(self)
def type_input_return_pairs(self)
def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)
def get_jit_def(fn, self_name=None)
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
def test_input_flatten(self)
def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices")
def quantize_linear_modules(module)
def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP)
Module caffe2.python.helpers.train.
def runAndSaveRNG(self, func, inputs, kwargs=None)
def run_pass(self, name, trace)
def _export(args, kwargs)
def set_training(args, kwargs)
def assertExportImportModule(self, m, inputs)
def batch(batch_size=1, optimize=True, _frames_up=0)
def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
def _check_code(self, code_str, fn_name, inputs)
def get_device_capability(device=None)
def checkTrace(self, func, reference_tensors, input_tensors=None, optimize=True, drop=None, allow_unused=False, verbose=False, inputs_require_grads=True, check_tolerance=1e-5, export_import=True, _force_outplace=False)
def ignore_lib_warnings()
def test_output_unflatten(self)
def do_trace_size(self, requires_grad)
def uniform_(tensor, a=0, b=1)
def assertLeaksNoCudaTensors(self, name=None)
def checkScript(self, script, inputs, optimize=True, outputs=None, name='func', capture_output=False, frames_up=1, check_expected=False)
def _get_py3_code(self, code, fn_name)
def _export_to_pretty_string(args, kwargs)
Module caffe2.python.helpers.dropout.
def dropout(input, p=0.5, training=True, inplace=False)
def _test_neural_style(self, device, check_export_import=True)
def checkScriptRaisesRegex(self, script, inputs, exception, regex, optimize=True, outputs=None, capture_output=False)
def assertExportImport(self, trace, inputs)
def getExportImportCopy(self, m, also_test_file=True, map_location=None)
def _make_scalar_vars(self, arr, dtype)
def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False, only_inputs=True, allow_unused=False)
def assertGraphContains(self, graph, kind)
def assertExpectedGraph(self, trace, args, kwargs)
def emitModuleHook(self, module)
def quantize_rnn_cell_modules(module)
def _optimize_trace(trace, operator_export_type)
def assertWarnsRegex(self, callable, regex, msg='')