import contextlib
import copy
import difflib
import functools
import inspect
import math
import os
import pickle
import sys
import textwrap
import types
import warnings
import weakref

import torch
import torch._jit_internal as _jit_internal
import torch.backends.cudnn as cudnn
import torch.testing

from torch import Tensor
from torch.autograd import Variable, function
from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
from torch.serialization import validate_cuda_device
from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
from torch._six import raise_from, with_metaclass, get_function_from_type, \
    string_classes
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
    _list_with_default
from collections import defaultdict, OrderedDict, namedtuple

if sys.version_info[0] > 2:
    import pathlib
def _parse_env(name, default, true_message, false_message):
    value = os.environ.get(name)
    if value is None:
        return default
    if value.lower() in {'1', 'true', 'yes'}:
        return True
    elif value.lower() in {'0', 'false', 'no'}:
        return False
    raise ValueError('Unknown setting of {}. Try using 0 or 1.'.format(name))


_enabled = _parse_env('PYTORCH_JIT', True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED")
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten
_jit_script_compile = torch._C._jit_script_compile
_jit_script_class_compile = torch._C._jit_script_class_compile
BatchTensor = torch._C._jit.BatchTensor

Future = torch._C.Future
@contextlib.contextmanager
def scope(scope_name):
    tracing_state = torch._C._get_tracing_state()
    if tracing_state:
        tracing_state.push_scope(scope_name)
    try:
        yield
    finally:
        if tracing_state:
            tracing_state.pop_scope()


DEFAULT_EXTRA_FILES_MAP = torch._C.ExtraFilesMap()
def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP):
    r"""
    Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`.

    All previously saved modules, no matter their device, are first loaded onto CPU,
    and then are moved to the devices they were saved from. If this fails (e.g. because
    the run time system doesn't have certain devices), an exception is raised.
    However, storages can be dynamically remapped to an alternative set of devices
    using the ``map_location`` argument. Compared to :func:`torch.load`, ``map_location``
    in this function is simplified: it only accepts a string (e.g., 'cpu', 'cuda:0')
    or a ``torch.device`` (e.g., ``torch.device('cpu')``).

    Arguments:
        f: a file-like object (has to implement read, readline, tell, and seek),
            or a string containing a file name
        map_location: can be a string (e.g., 'cpu', 'cuda:0') or a device (e.g.,
            torch.device('cpu'))
        _extra_files: map from filename to content. The extra filenames given in
            the map would be loaded and their content would be stored in the
            provided map.

    Returns:
        A ``ScriptModule`` object.

    Example:
        >>> torch.jit.load('scriptmodule.pt')
        >>> # Load ScriptModule from io.BytesIO object
        >>> with open('scriptmodule.pt', 'rb') as f:
        ...     buffer = io.BytesIO(f.read())
        >>> # Load all tensors to the original device
        >>> torch.jit.load(buffer)
        >>> # Load all tensors onto CPU, using a device
        >>> torch.jit.load(buffer, map_location=torch.device('cpu'))
        >>> # Load all tensors onto CPU, using a string
        >>> torch.jit.load(buffer, map_location='cpu')
        >>> # Load with extra files.
        >>> files = {'metadata.json': ''}
        >>> torch.jit.load('scriptmodule.pt', _extra_files=files)
        >>> print(files['metadata.json'])
    """
    m = ScriptModule()

    def module_lookup(names):
        curr = m
        for name in names:
            if not hasattr(curr, name):
                setattr(curr, name, ScriptModule())
            curr = getattr(curr, name)
        return curr

    if isinstance(f, string_classes):
        if not os.path.exists(f):
            raise ValueError("The provided filename {} does not exist".format(f))
    if isinstance(map_location, string_classes):
        map_location = torch.device(map_location)
    elif not (map_location is None or isinstance(map_location, torch.device)):
        raise ValueError("map_location should be either None, string or torch.device, "
                         "but got type: " + str(type(map_location)))
    if str(map_location).startswith('cuda'):
        validate_cuda_device(map_location)

    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        torch._C.import_ir_module(module_lookup, f, map_location, _extra_files)
    else:
        torch._C.import_ir_module_from_buffer(module_lookup, f.read(), map_location, _extra_files)

    return m
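

# A minimal usage sketch (not used by the library itself): round-tripping a
# ScriptModule through ``torch.jit.save`` / ``torch.jit.load`` with an in-memory
# buffer. The traced ``nn.Linear`` and its shapes are assumptions made only for
# this illustration.
def _example_load_roundtrip():
    import io

    traced = torch.jit.trace(torch.nn.Linear(3, 2), torch.rand(1, 3))
    buffer = io.BytesIO()
    torch.jit.save(traced, buffer)
    buffer.seek(0)
    # Remap all storages onto the CPU while loading.
    return torch.jit.load(buffer, map_location='cpu')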
def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):
    """
    Saves a ScriptModule to a file.

    Arguments:
        m: a ScriptModule to save
        f: a file-like object (has to implement write and flush) or a string
            containing a file name
        _extra_files: Map from filename to contents which will be stored as part of 'f'

    If you are using Python 2, torch.save does NOT support StringIO.StringIO
    as a valid file-like object. This is because the write method should return
    the number of bytes written; StringIO.write() does not do this.

    Please use something like io.BytesIO instead.

    Example:
        >>> m = torch.jit.ScriptModule()
        >>> # Save to file
        >>> torch.jit.save(m, 'scriptmodule.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.jit.save(m, buffer)
        >>> # Save with extra files
        >>> extra_files = torch._C.ExtraFilesMap()
        >>> extra_files['foo.txt'] = 'bar'
        >>> torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
    """
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        m.save(f, _extra_files=_extra_files)
    else:
        ret = m.save_to_buffer(_extra_files=_extra_files)
        f.write(ret)
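

# A minimal usage sketch of the ``_extra_files`` hooks documented above; the
# file name and JSON payload are assumptions made only for this illustration.
def _example_save_with_extra_files(path='scriptmodule.pt'):
    traced = torch.jit.trace(torch.nn.Linear(3, 2), torch.rand(1, 3))
    extra_files = torch._C.ExtraFilesMap()
    extra_files['metadata.json'] = '{"version": 1}'
    torch.jit.save(traced, path, _extra_files=extra_files)

    # On load, pre-populate the map with the file names whose contents we want back.
    files = {'metadata.json': ''}
    torch.jit.load(path, _extra_files=files)
    return files['metadata.json']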
def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, return_inputs=False):
    """
    Trace a function or model, returning a tuple consisting of both the
    *trace* of an execution and the original return value. If ``return_inputs``,
    the trace inputs are also returned as part of the tuple.

    Tracing is guaranteed not to change the semantics of the function/module
    that is traced.

    Arguments:
        f (torch.nn.Module or function): the function or module to be traced.
        args (tuple or Tensor): the positional arguments to pass to the
            function/module to be traced. A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        kwargs (dict): the keyword arguments to pass to the function/module.

    Example: Trace a cell.

        >>> trace, out = get_trace_graph(nn.LSTMCell(), (input, hidden))
    """
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
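

# A minimal sketch of calling this internal helper directly: it returns the
# recorded trace together with the original outputs. The LSTMCell sizes are
# assumptions made only for this illustration.
def _example_get_trace_graph():
    cell = torch.nn.LSTMCell(10, 20)
    inputs = (torch.rand(3, 10), (torch.rand(3, 20), torch.rand(3, 20)))
    trace, out = get_trace_graph(cell, inputs)
    return trace.graph(), out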
def _unique_state_dict(module, keep_vars=False):
    state_dict = module.state_dict(keep_vars=keep_vars)
    filtered_dict = type(state_dict)()
    seen_ids = set()
    for k, v in state_dict.items():
        if id(v) in seen_ids:
            continue
        seen_ids.add(id(v))
        filtered_dict[k] = v
    return filtered_dict
def _create_interpreter_name_lookup_fn(frames_up=1):
    def _get_interpreter_name_for_var(var):
        frame = inspect.currentframe()
        i = 0
        while i < frames_up + 1:
            frame = frame.f_back
            i += 1

        f_locals = frame.f_locals
        f_globals = frame.f_globals

        for k, v in f_locals.items():
            if isinstance(v, torch.Tensor) and var is v:
                return k if k != 'self' else ''
        for k, v in f_globals.items():
            if isinstance(v, torch.Tensor) and var is v:
                return k if k != 'self' else ''
        return ''
    return _get_interpreter_name_for_var
class LegacyTracedModule(Module):
    def __init__(self, inner, force_outplace=False, return_inputs=False):
        super(LegacyTracedModule, self).__init__()
        # inner may be a Module or an arbitrary callable
        self.inner = inner
        self._force_outplace = force_outplace
        self._return_inputs = return_inputs

    def forward(self, *args):
        in_vars, in_desc = _flatten(args)
        # NOTE: use full state, because we need it for BatchNorm export
        module_state = list(_unique_state_dict(self, keep_vars=True).values())
        trace, all_trace_inputs = torch._C._tracer_enter(*(in_vars + module_state))
        ret_inputs = tuple(x.clone() for x in all_trace_inputs)
        torch._C._tracer_set_get_unique_name_fn(_create_interpreter_name_lookup_fn())
        try:
            trace_inputs = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
            out = self.inner(*trace_inputs)
            out_vars, _ = _flatten(out)
            torch._C._tracer_exit(tuple(out_vars))
        except Exception:
            torch._C._tracer_abandon()
            raise
        if self._return_inputs:
            return trace, out, ret_inputs
        return trace, out
def _clone_inputs(args):
    def clone_input(a):
        if a is None:
            return None
        elif isinstance(a, torch.Tensor):
            v = Variable(a.data.clone(), requires_grad=a.requires_grad)
            if a.grad is not None:
                v.grad = clone_input(v.grad)
            return v
        else:
            return a.clone()
    return function._nested_map(lambda x: isinstance(x, torch.Tensor),
                                clone_input, condition_msg="tensors")(args)
_JIT_DUMP = os.environ.get('PYTORCH_JIT_DUMP', False)
_JIT_TIME = os.environ.get('PYTORCH_JIT_TIME', False)
_JIT_DISABLE = os.environ.get('PYTORCH_JIT_DISABLE', False)
_JIT_STATS = os.environ.get('PYTORCH_JIT_STATS', False)
def _dump_trace(trace_name, pass_name, input_key, trace):
    if not _JIT_DUMP:
        return

    import torch.contrib._graph_vis as graph_vis

    filename = "{}_{}".format(trace_name, pass_name)
    with open(filename + ".ir", "w") as f:
        f.write("Input key: {}\n\n{}".format(input_key, str(trace)))
    graph_vis.write(trace.graph(), filename + ".html")
@contextlib.contextmanager
def _time(trace_name, name, time=True):
    if (not _JIT_TIME and not time) or not torch.cuda.is_available():
        yield
        return
    stream = torch.cuda.current_stream()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    stream.record_event(start)
    try:
        yield
    finally:
        stream.record_event(end)
        end.synchronize()
        print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
def verify(model, args, loss_fn=torch.sum, devices=None):
    """
    Verify that a JIT compiled model has the same behavior as its uncompiled
    version along with its backwards pass. If your model returns multiple
    outputs, you must also specify a `loss_fn` to produce a loss for which
    the backwards will be computed.

    This function has side-effects (e.g., it executes your model / saves and loads
    parameters), so don't expect the model to come out exactly the same as what
    you passed in.

    Arguments:
        model (compiled torch.nn.Module or function): the module/function to be
            verified. The module/function definition MUST have been decorated with
            `@torch.jit.compile`.
        args (tuple or Tensor): the positional arguments to pass to the
            compiled function/module to be verified. A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        loss_fn (function, optional): the loss function to be applied to
            the output of the model, before backwards is invoked. By default,
            we assume that a model returns a single result, and we :func:`torch.sum`
            before calling backwards; if this is inappropriate, you can pass your
            own loss function. Note that if a model returns a tuple of results,
            these are passed as separate positional arguments to `loss_fn`.
        devices (iterable of device IDs, optional): the GPU devices which the
            compiled module will be run on. This determines the RNG state we
            must save when running both compiled and uncompiled versions of the model.
    """
    if not isinstance(model, torch._C.CompiledFunction):
        raise TypeError("Cannot verify an uncompiled module. Add @torch.jit.compile to compile it")
    is_module = isinstance(model, Module)

    if not isinstance(args, tuple):
        args = (args,)

    saved_args = _clone_inputs(args)
    if is_module:
        saved_state = copy.deepcopy(model.state_dict())

    def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
        params = list(model.parameters()) if is_module else []
        in_vars, _ = _flatten((args, params))
        compiled_fn = model
        if force_trace:
            compiled_fn.clear_cache()
        if assert_compiled:
            hits = compiled_fn.hits
        out = model(*args)
        if assert_compiled and compiled_fn.hits == hits:
            raise RuntimeError("failed to use the compiled function")
        if not isinstance(out, tuple):
            out = (out, )
        if loss_fn == torch.sum and len(out) != 1:
            raise ValueError(("Model returns {} outputs, but default loss function "
                              "(torch.sum) can only handle a single output").format(len(out)))
        out_vars, _ = _flatten(out)
        saved_outs = [v.data.clone() for v in out_vars]
        loss = loss_fn(*out)
        grads = torch.autograd.grad([loss], in_vars)
        saved_grads = [v.data.clone() for v in grads]
        return (saved_outs, saved_grads)

    with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
        uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
        assert model.has_trace_for(*args)

    if is_module:
        model.load_state_dict(saved_state)
    compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

    _verify_equal(uncompiled_outs, compiled_outs)
    _verify_equal(uncompiled_grads, compiled_grads)
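

# A minimal sketch of the workflow described in the docstring above, assuming
# the legacy ``@torch.jit.compile`` decorator it refers to is available (it is
# not defined in this file).
def _example_verify():
    @torch.jit.compile
    def doubler(x):
        return x * 2

    verify(doubler, (torch.rand(2, 2, requires_grad=True),), loss_fn=torch.sum)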
def _verify_equal(xs, ys):
    for x, y in zip(xs, ys):
        if x.sub(y).abs().max() > 1e-6:
            raise RuntimeError("JIT and real computation mismatch")
def indent(s):
    return '\n'.join(['\t' + line for line in s.splitlines()])
class TracingCheckError(Exception):
    def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
        self.message = 'Tracing failed sanity checks!\n'
        if extra_msg is not None:
            self.message += extra_msg + '\n'
        if graph_diff_error is not None:
            self.message += 'ERROR: Graphs differed across invocations!\n'
            self.message += indent(graph_diff_error) + '\n'
        if tensor_compare_error is not None:
            self.message += 'ERROR: Tensor-valued Constant nodes differed in value ' \
                            'across invocations. This often indicates that the tracer has' \
                            ' encountered untraceable code.\n'
            self.message += indent(tensor_compare_error) + '\n'
        super(TracingCheckError, self).__init__(self.message)
def _check_trace(check_inputs, func, executor_options, module, check_tolerance, force_outplace):
    # Note: tracing is independent of optimizations, which consume the trace
    executor_options['optimize'] = False
    for inputs in check_inputs:
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)
        check_mod = torch.jit.trace(
            func,
            _clone_inputs(inputs),
            check_trace=False,
            _force_outplace=force_outplace,
            **executor_options)

        def graph_diagnostic_info():
            mod_canonicalized = torch._C._jit_pass_canonicalize(module.graph)
            torch._C._jit_pass_erase_shape_information(mod_canonicalized)
            check_canonicalized = torch._C._jit_pass_canonicalize(check_mod.graph)
            torch._C._jit_pass_erase_shape_information(check_canonicalized)

            graph_diff_errors = None
            if str(mod_canonicalized) != str(check_canonicalized):
                graph_diff = difflib.ndiff(str(mod_canonicalized).splitlines(True),
                                           str(check_canonicalized).splitlines(True))
                graph_diff_errors = 'Graph diff:\n' + indent(''.join(graph_diff)) + '\n'

                for n_mod, n_check in zip(mod_canonicalized.nodes(), check_canonicalized.nodes()):
                    if str(n_mod) != str(n_check):
                        graph_diff_errors += 'First diverging operator:\n'
                        node_diff = difflib.ndiff(str(n_mod).splitlines(True),
                                                  str(n_check).splitlines(True))
                        source_printout = 'Node diff:\n' + indent(''.join(node_diff)) + '\n'
                        mod_stack = n_mod.getSourceLocation()
                        if mod_stack:
                            source_printout += 'Trace source location:\n' + indent(mod_stack) + '\n'
                        check_stack = n_check.getSourceLocation()
                        if check_stack:
                            source_printout += 'Check source location:\n' + indent(check_stack) + '\n'
                        graph_diff_errors += source_printout

                        break  # For now, only print out the first pair of nodes that diverges

            tensor_compare_errors = None
            # Check Tensor-valued constant nodes
            for n_mod, n_check in zip(mod_canonicalized.nodes(), check_canonicalized.nodes()):
                if n_mod.kind() != n_check.kind():
                    break  # Graphs have already diverged

                if n_mod.kind() == 'prim::Constant' and not (n_mod.mustBeNone() or n_check.mustBeNone()):
                    if n_mod.kindOf('value') != 't' or n_check.kindOf('value') != 't':
                        continue  # Only compare Tensor-valued constants

                    mod_tensor_val = n_mod.t('value')
                    check_tensor_val = n_check.t('value')

                    try:
                        torch.testing.assert_allclose(mod_tensor_val, check_tensor_val)
                    except (RuntimeError, AssertionError) as e:
                        if tensor_compare_errors is None:
                            tensor_compare_errors = ''
                        tensor_compare_errors += 'Node:\n' + indent(str(n_mod)) + '\n'
                        compare_stack = n_mod.getSourceLocation()
                        if compare_stack:
                            tensor_compare_errors += 'Source Location:\n' + indent(compare_stack) + '\n'
                        tensor_compare_errors += 'Comparison exception: ' + indent(str(e))

                        break  # For now, only print the first diverging pair

            return graph_diff_errors, tensor_compare_errors

        def wrap_retval(x):
            return x if isinstance(x, tuple) else (x,)

        def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
            try:
                outs = wrap_retval(mod(*_clone_inputs(inputs)))
                outs = [out for out in outs if isinstance(out, torch.Tensor)]
                return outs
            except Exception as e:
                raise TracingCheckError(*graph_diagnostic_info(),
                                        extra_msg='Encountered an exception while running the ' + running_what +
                                                  ' with test inputs.\nException:\n' + indent(str(e)))

        def maybe_warn_nondeterministic():
            nondeterm_ops = [op for op in module.graph.nodes() if op.isNondeterministic()]
            if len(nondeterm_ops) > 0:
                nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
                nondeterministic_ops_warning += "Did you forget call .eval() on your model? Nodes:\n"
                nondeterministic_ops_warning += "\n".join([indent(str(op)) for op in nondeterm_ops][:20])
                nondeterministic_ops_warning += "\nThis may cause errors in trace checking. To disable trace checking,"\
                                                " pass check_trace=False to torch.jit.trace()"
                warnings.warn(nondeterministic_ops_warning, category=TracerWarning, stacklevel=5)

        def compare_outputs(original, reference, match_what):
            all_ok = True
            for i, (orig, ref) in enumerate(zip(original, reference)):
                try:
                    torch.testing.assert_allclose(orig.double(), ref.double(), rtol=check_tolerance,
                                                  atol=torch.testing._get_default_tolerance(orig, ref)[1])
                except AssertionError as e:
                    maybe_warn_nondeterministic()
                    warnings.warn('Output nr ' + str(i + 1) + '. of the traced function does not match '
                                  'the corresponding output of the ' + match_what + '. Detailed error:\n' + str(e),
                                  category=TracerWarning, stacklevel=4)
                    all_ok = False

            return all_ok

        traced_outs = run_mod_and_filter_tensor_outputs(module, inputs, 'trace')
        fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, 'Python function')
        if compare_outputs(traced_outs, fn_outs, 'Python function'):
            check_outs = run_mod_and_filter_tensor_outputs(check_mod, inputs, 'repeated trace')
            compare_outputs(traced_outs, check_outs, 'repeated trace')

        diag_info = graph_diagnostic_info()
        if any(info is not None for info in diag_info):
            raise TracingCheckError(*diag_info)
class TracerWarning(Warning):
    @staticmethod
    def ignore_lib_warnings():
        warnings.filterwarnings('ignore', category=TracerWarning, module='torch.(?!jit)')


# We ignore the tracer warnings coming from inside the library, because all our
# shape checks in nn will trigger them.
TracerWarning.ignore_lib_warnings()
torch._C._tracer_warn_use_python()
def trace(func,
          example_inputs,
          optimize=True,
          check_trace=True,
          check_inputs=None,
          check_tolerance=1e-5,
          _force_outplace=False,
          _module_class=None):
    """
    Trace a function and return an executable trace that will be optimized
    using just-in-time compilation.

    Tracing only correctly records functions and modules which are not data
    dependent (i.e., do not have conditionals on data in tensors) and do not have
    any untracked external dependencies (e.g., perform input/output or
    access global variables). If you trace such models, you may silently get
    incorrect results on subsequent invocations of the model. The tracer
    will try to emit warnings when doing something that may cause an
    incorrect trace to be produced.

    Arguments:
        func (callable or torch.nn.Module): a Python function or torch.nn.Module
            that will be run with example_inputs. Arguments and returns of func
            must be Tensors or (possibly nested) tuples that contain Tensors.
        example_inputs (tuple): a tuple of example inputs that will be passed to the function
            while tracing. The resulting trace can be run with
            inputs of different types and shapes assuming the traced operations
            support those types and shapes. example_inputs may also be a single
            Tensor in which case it is automatically wrapped in a tuple.

    Keyword arguments:
        optimize (bool, optional): whether or not to apply optimizations. Default: ``True``.
        check_trace (bool, optional): check if the same inputs run through
            traced code produce the same outputs. Default: ``True``. You might want
            to disable this if, for example, your network contains non-
            deterministic ops or if you are sure that the network is correct despite
            a checker failure.
        check_inputs (list of tuples, optional): A list of tuples of input arguments that should be used
            to check the trace against what is expected. Each tuple
            is equivalent to a set of input arguments that would
            be specified in ``example_inputs``. For best results, pass in a
            set of checking inputs representative of the space of
            shapes and types of inputs you expect the network to see.
            If not specified, the original ``example_inputs`` are used for checking.
        check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
            This can be used to relax the checker strictness in the event that
            results diverge numerically for a known reason, such as operator fusion.

    Returns:
        A ``ScriptModule`` object with a single ``forward()`` method containing the traced code.
        When func is a ``torch.nn.Module``, the returned ``ScriptModule`` will have the same set of
        sub-modules and parameters as func.

    Example:
        >>> def f(x):
        ...     return x * 2
        >>> traced_f = torch.jit.trace(f, torch.rand(1))
    """
    executor_options = {'optimize': bool(optimize)}
    # Special case for the common case of passing a single Tensor
    if isinstance(example_inputs, torch.Tensor):
        example_inputs = (example_inputs,)
    # Done primarily so that weird iterables fail here and not pybind11 code
    elif not isinstance(example_inputs, tuple):
        example_inputs = tuple(example_inputs)

    if _module_class is None:
        _module_class = TopLevelTracedModule
    module = _module_class(func, **executor_options)

    var_lookup_fn = _create_interpreter_name_lookup_fn(0)
    module._create_method_from_trace('forward', func, example_inputs,
                                     var_lookup_fn, _force_outplace)

    # Check the trace against new traces created from user-specified inputs
    if check_trace:
        if check_inputs is not None:
            _check_trace(check_inputs, func, executor_options, module, check_tolerance, _force_outplace)
        else:
            _check_trace([example_inputs], func, executor_options, module, check_tolerance, _force_outplace)

    return module
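

# A minimal sketch of tracing a function and re-checking the trace against
# additional input shapes via ``check_inputs``; the function body and shapes
# are assumptions made only for this illustration.
def _example_trace_with_check_inputs():
    def scale_and_shift(x):
        return x * 2 + 1

    return torch.jit.trace(
        scale_and_shift,
        torch.rand(2, 3),
        check_inputs=[(torch.rand(4, 5),), (torch.rand(1,),)])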
class CompilationUnit(object):
    def __init__(self, lang=None, optimize=True, _frames_up=0):
        self.module = torch._C.ScriptModule()
        self.module._set_optimized(optimize)
        if lang is not None:
            self.define(lang, _frames_up=_frames_up + 1)

    def define(self, lang, rcb=None, _frames_up=0):
        if not rcb:
            rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
        self.module._define(lang, rcb, False)

    def __getattr__(self, attr):
        return self.module._get_method(attr)
def _try_get_dispatched_fn(fn):
    if not callable(fn):
        return None
    return _jit_internal.boolean_dispatched.get(fn)


def _try_get_overloaded_fn(fn):
    if not hasattr(fn, '__self__') or not isinstance(fn.__self__, ScriptModule):
        # Only allow overloads for bound methods
        return None
    overloads = fn.__self__._overloads.get(fn.__name__, None)
    if overloads is None:
        return None
    return [getattr(fn.__self__, overload) for overload in overloads]


def _try_compile_weak_script(fn):
    entry = _jit_internal.compiled_weak_fns.get(fn)
    if entry is None:
        return None
    if entry["status"] == _jit_internal.COMPILATION_PENDING:
        compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
        del entry["rcb"]
        _jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
        entry["status"] = _jit_internal.COMPILED
        return compiled_fn
    else:
        return entry["compiled_fn"]
def script(obj, optimize=True, _frames_up=0, _rcb=None):
    if not _enabled:
        return obj
    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
    if inspect.isclass(obj):
        mod = ScriptClass(obj.__name__)
        ast = get_jit_class_def(obj)
        _jit_script_class_compile(mod, ast, _rcb)
    else:
        mod = ScriptModule()
        ast = get_jit_def(obj)
        _jit_script_compile(mod, ast, _rcb, get_default_args(obj))
    # Forward the docstring of the original callable
    mod.__doc__ = obj.__doc__
    return mod
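

# A minimal sketch of scripting a function with data-dependent control flow,
# which tracing cannot capture; the function body mirrors the documented
# example and is otherwise an assumption made only for this illustration.
def _example_script_function():
    @torch.jit.script
    def pick_larger(x, y):
        if x.max() > y.max():
            r = x
        else:
            r = y
        return r

    return pick_larger(torch.rand(3), torch.rand(3))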
ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))


def script_method(fn, _rcb=None):
    if not _enabled:
        return fn
    # createResolutionCallback uses frames_up=2 to reach the scope enclosing the
    # class body in which the decorated method is being defined.
    if _rcb is None:
        _rcb = _jit_internal.createResolutionCallback(frames_up=2)
    ast = get_jit_def(fn, self_name="ScriptModule")
    return ScriptMethodStub(_rcb, ast, fn)
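

# A minimal sketch of using ``script_method`` on a ``ScriptModule`` subclass,
# following the documented pattern; the module below is an assumption made only
# for this illustration.
def _example_script_method_module():
    class MyModule(ScriptModule):
        def __init__(self, N, M):
            super(MyModule, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(N, M))

        @script_method
        def forward(self, input):
            return self.weight.mv(input)

    return MyModule(2, 3)(torch.rand(3))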
def _try_get_weak_module(mod):
    """
    Get the WeakScriptModuleProxy corresponding to mod if it exists.
    """
    if not isinstance(mod, Module):
        return None
    return _jit_internal.weak_modules.get(mod)


def _try_get_ignored_op(fn):
    if not callable(fn):
        return False
    if hasattr(fn, '__func__'):
        fn = fn.__func__
    return fn in _jit_internal.ignored_fns


def _is_weak_type(cls):
    """
    Check if a type has been annotated with `weak_module`.
    """
    return cls in _jit_internal.weak_types
def batch(batch_size=1, optimize=True, _frames_up=0):
    def decorator(fn):
        if not _enabled:
            return fn
        import torch.jit.batchop
        mod = script(fn, optimize, _frames_up)
        res_graph = torch.to_batch_graph(mod.graph)
        res_mod = ScriptModule()
        res_mod._create_method_from_graph('forward', res_graph)

        def wrapper(*args):
            new_args = []
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    arg = BatchTensor(arg, batch_size)
                if isinstance(arg, BatchTensor):
                    new_args.extend([arg.get_data(), arg.get_mask(), arg.get_dims()])
                else:
                    new_args.append(arg)
            res = res_mod(*new_args)
            if len(res) % 3 != 0:
                raise RuntimeError("non-batched-tensor output is not supported yet")
            result = [BatchTensor(*res[i * 3: i * 3 + 3]) for i in range(len(res) // 3)]
            if len(result) == 1:
                return result[0]
            return result
        wrapper.__doc__ = fn.__doc__
        return wrapper
    return decorator
class OrderedDictWrapper(object):
    def __init__(self, module):
        self.module_ref = weakref.ref(module)

    @property
    def module(self):
        r = self.module_ref()
        if r is None:
            raise RuntimeError("_parameters or _modules alive after module is dead")
        return r

    def keys(self):
        return [k for k, v in self.items()]

    def values(self):
        return [v for k, v in self.items()]

    def __delitem__(self, k):
        raise RuntimeError("cannot delete methods or parameters of a script module")

    def items(self):
        raise NotImplementedError

    def __contains__(self, k):
        raise NotImplementedError

    def __getitem__(self, k):
        raise NotImplementedError

    def __setitem__(self, k, v):
        raise NotImplementedError
class OrderedModuleDict(OrderedDictWrapper):
    def __init__(self, module):
        super(OrderedModuleDict, self).__init__(module)
        # Script modules are subclassed in Python, and the C++ script::Module
        # class will not hold references to them, so keep the Python values in
        # a plain dict as well to always return the same Python object.
        self._python_modules = OrderedDict()

    def items(self):
        r = self._python_modules.items()
        return r

    def __contains__(self, k):
        return k in self._python_modules

    def __setitem__(self, k, v):
        if k in self._python_modules:
            raise RuntimeError("cannot re-assign modules in a ScriptModule")
        if isinstance(v, ScriptModule):
            self.module._register_module(k, v)
        self._python_modules[k] = v

    def __getitem__(self, k):
        return self._python_modules[k]
class OrderedParameterDict(OrderedDictWrapper):
    def __init__(self, module):
        super(OrderedParameterDict, self).__init__(module)

    def items(self):
        return [(name, param) for name, param in self.module._get_parameters()]

    def __setitem__(self, k, v):
        self.module._register_parameter(k, v, False)

    def __contains__(self, k):
        return self.module._has_parameter(k)

    def __getitem__(self, k):
        if k not in self:
            raise KeyError(k)
        return self.module._get_parameter(k)
class OrderedBufferDict(OrderedDictWrapper):
    def __init__(self, module):
        super(OrderedBufferDict, self).__init__(module)

    def items(self):
        return [(name, param) for name, _, param in
                self.module._get_attributes() if isinstance(param, torch.Tensor)]

    def __setitem__(self, k, v):
        self.module._register_buffer(k, v)

    def __contains__(self, k):
        return self.module._has_buffer(k)

    def __getitem__(self, k):
        if k not in self:
            raise KeyError(k)
        return self.module._get_buffer(k)
_constant_types = (bool, float, int, str, type(None), types.FunctionType, torch.device, torch.layout, torch.dtype)


def _get_valid_constant(attr, v):
    if isinstance(v, _constant_types):
        return v
    elif isinstance(v, tuple) or isinstance(v, list):
        return tuple(_get_valid_constant(attr, x) for x in v)
    constants = ", ".join(typ.__name__ for typ in _constant_types)
    raise TypeError(textwrap.dedent("""
        '{}' object for attribute '{}' is not a valid constant.
        Valid constants are:
          1. a nn.ModuleList
          2. a value of type {{{}}}
          3. a list or tuple of (2)
        """.format(type(v).__name__, attr, constants)))
def _create_methods_from_stubs(self, stubs):
    defs = [m.def_ for m in stubs]
    rcbs = [m.resolution_callback for m in stubs]
    defaults = [get_default_args(m.original_method) for m in stubs]
    self._create_methods(defs, rcbs, defaults)
class ScriptMeta(type(torch._C.ScriptModule)):
    # ScriptMeta must inherit from the pybind11 metaclass of torch._C.ScriptModule
    # so that ScriptModule can subclass both torch._C.ScriptModule and nn.Module.
    def __init__(cls, name, bases, attrs):
        # find all the script methods declared on this class
        cls._original_methods = {}
        methods = []
        for k, v in sorted(attrs.items()):
            if isinstance(v, ScriptMethodStub):
                delattr(cls, k)
                methods.append(v)
                cls._original_methods[v.original_method.__name__] = v.original_method
        # after the user's __init__ has run, register all the script methods
        # with the module
        original_init = getattr(cls, '__init__', lambda self: None)
        super_constants = getattr(super(cls), '_constants_set', set())
        cls._constants_set = set(getattr(cls, '__constants__', ())).union(super_constants)
        cls._overloads = dict(getattr(cls, '__overloads__', {}))

        @functools.wraps(original_init)
        def init_then_register(self, *args, **kwargs):
            # ensure the pybind object is initialized exactly once, before the
            # most-derived __init__ is called
            if cls is type(self):
                torch._C.ScriptModule.__init__(self)
            original_init(self, *args, **kwargs)
            _create_methods_from_stubs(self, methods)

        cls.__init__ = init_then_register
        return super(ScriptMeta, cls).__init__(name, bases, attrs)
class ScriptModule(with_metaclass(ScriptMeta, torch._C.ScriptModule, Module)):
    r"""
    The core data structure in TorchScript is the ``ScriptModule``. It is an
    analogue of torch's nn.Module and represents an entire model as a tree of
    submodules. Like normal modules, each individual module in a ScriptModule can
    have submodules, parameters, and methods. In nn.Modules methods are implemented
    as Python functions, but in ScriptModules methods are typically implemented as
    *TorchScript* functions, a statically-typed subset of Python that contains all
    of PyTorch's built-in Tensor operations. This difference allows your
    ScriptModule's code to run without the need for a Python interpreter.

    ScriptModules and the TorchScript functions inside of them can be created in
    two ways:

    **Tracing:**

        Using ``torch.jit.trace``, you can take an existing module or Python
        function, provide example inputs, and we run the function, recording the
        operations performed on all the tensors. We turn the resulting recording
        into a TorchScript method that is installed as the ``forward`` method of a
        ScriptModule. This module also contains any parameters that the original
        module had as well.

        Example::

            import torch
            def foo(x, y):
                return 2 * x + y
            traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

        .. note::
            Tracing a *function* will produce a ``ScriptModule`` with a single
            ``forward`` method that implements that function, and that contains
            no parameters.

        Example::

            import torch
            import torchvision
            traced_net = torch.jit.trace(torchvision.models.resnet18(),
                                         torch.rand(1, 3, 224, 224))

        .. note::

            Tracing only records operations done when the given function is run on the given
            tensors. Therefore, the returned ``ScriptModule`` will always run the same traced
            graph on any input. This has some important implications when your module is
            expected to run different sets of operations, depending on the input and/or the
            module state. For example,

            + Tracing will not record any control-flow like if statements or loops. When
              this control-flow is constant across your module, this is fine and it often
              just inlines configuration decisions. But sometimes the control-flow is
              actually part of the model itself. For instance, a recurrent network is
              a loop over the (possibly dynamic) length of an input sequence.

            + In the returned ``ScriptModule``, operations that have different behaviors
              in ``training`` and ``eval`` modes will always behave as if they are in the
              mode they were in during tracing, no matter which mode the ``ScriptModule``
              is in.

            In cases like these, tracing would not be appropriate and scripting is a better
            choice.

    **Scripting:**

        You can write TorchScript code directly using Python syntax. You do this
        using the ``torch.jit.script`` annotation (for functions) or
        ``torch.jit.script_method`` annotation (for methods) on subclasses of
        ScriptModule. With this annotation the body of the annotated function is
        directly translated into TorchScript. TorchScript itself is a subset of
        the Python language, so not all features in Python work, but we provide
        enough functionality to compute on tensors and do control-dependent
        operations.

        Example::

            import torch
            @torch.jit.script
            def foo(x, y):
                if x.max() > y.max():
                    r = x
                else:
                    r = y
                return r

        .. note::
            A script *function* annotation will construct a ScriptModule
            with a single ``forward`` method that implements that function,
            and that contains no parameters.

        Example::

            import torch
            class MyModule(torch.jit.ScriptModule):
                def __init__(self, N, M):
                    super(MyModule, self).__init__()
                    self.weight = torch.nn.Parameter(torch.rand(N, M))

                @torch.jit.script_method
                def forward(self, input):
                    return self.weight.mv(input)

        Example::

            import torch
            import torch.nn as nn
            import torch.nn.functional as F
            from torch.jit import ScriptModule, script_method, trace

            class MyScriptModule(ScriptModule):
                def __init__(self):
                    super(MyScriptModule, self).__init__()
                    # trace produces a ScriptModule's conv1 and conv2
                    self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
                    self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

                @script_method
                def forward(self, input):
                    input = F.relu(self.conv1(input))
                    input = F.relu(self.conv2(input))
                    return input
    """

    def __init__(self, optimize=True):
        Module.__init__(self)
        self._set_optimized(optimize)
        self._parameters = OrderedParameterDict(self)
        self._buffers = OrderedBufferDict(self)
        self._modules = OrderedModuleDict(self)
    def __getattr__(self, attr):
        if self._has_method(attr):
            if attr in self.__class__._original_methods:
                original_method = self.__class__._original_methods[attr]
                script_method = self._get_method(attr)
                return functools.wraps(original_method)(script_method)
            else:
                return self._get_method(attr)
        if attr == 'graph' and self._has_method('forward'):
            return self.__getattr__('forward').graph
        return Module.__getattr__(self, attr)
    def __setattr__(self, attr, value):
        if attr not in self._constants_set:
            if isinstance(value, Module) and _is_weak_type(type(value)):
                # Compile the weak script module before assigning it
                value = _make_strong(value)
            if attr == 'training':
                if self._has_buffer('training'):
                    self.__dict__['training'] = value
                    self._get_buffer('training').fill_(int(value))
                    return
            if isinstance(value, Attribute):
                the_type = torch.jit.annotations.ann_to_type(value.type)
                try:
                    self._register_attribute(attr, the_type, value.value)
                except RuntimeError:
                    raise RuntimeError("Could not register attribute '{}' of type '{}' for a value of type '{}'"
                                       .format(attr, value.type, type(value.value)))
                return
            return super(ScriptModule, self).__setattr__(attr, value)

        if hasattr(self, attr):
            raise RuntimeError("attempting to re-assign constant '{}'".format(attr))

        def conv_module_to_const(module_value):
            if not isinstance(module_value, (ModuleList, Sequential)):
                return module_value
            for i in range(len(module_value)):
                module_value[i] = conv_module_to_const(module_value[i])
            if isinstance(module_value, Sequential):
                return _ConstSequential(module_value)
            else:
                return _ConstModuleList(module_value)

        if isinstance(value, (ModuleList, Sequential)):
            # Special case for lists of modules: each submodule must be registered
            # with this module, so wrap them in a constant module list / sequential.
            super(ScriptModule, self).__setattr__(attr, conv_module_to_const(value))
        else:
            super(ScriptModule, self).__setattr__(attr, _get_valid_constant(attr, value))
    def __dir__(self):
        return sorted(Module.__dir__(self) + self._method_names())

    def define(self, lang):
        # frames_up=1 so names in ``lang`` resolve in the caller's surrounding scope
        rcb = _jit_internal.createResolutionCallback(frames_up=1)
        self._define(lang, rcb, True)
    def copy(self):
        m = ScriptModule()

        def module_lookup(names):
            curr = m
            for name in names:
                if not hasattr(curr, name):
                    setattr(curr, name, ScriptModule())
                curr = getattr(curr, name)
            return curr
        self._copy_into(module_lookup, {}, [])
        return m

    def __getstate__(self):
        raise pickle.PickleError(
            "ScriptModules cannot be saved using torch.save. " +
            "Mixed serialization of script and non-script modules is not supported. " +
            "For purely script modules use my_script_module.save(<filename>) instead.")
class WeakScriptModuleProxy(ScriptModule):
    def __init__(self, original, stubs):
        # Guard the behavior of __setattr__ and __getattr__ until
        # ScriptModule.__init__ has finished running
        self.__dict__['_initialized'] = False
        super(WeakScriptModuleProxy, self).__init__()

        self.__dict__["_original"] = weakref.ref(original)

        # Copy Parameters / Modules / Buffers
        for name in dir(original):
            item = getattr(original, name)
            if item is None and name in original._parameters:
                # XXX: treat None parameters as plain module attributes
                object.__setattr__(self, name, item)
            elif isinstance(item, Parameter) or (isinstance(item, Module) and item is not self):
                ScriptModule.__setattr__(self, name, item)
        for name in original._buffers:
            if original._buffers[name] is None:
                object.__setattr__(self, name, None)
            else:
                self.register_buffer(name, original._buffers[name])

        # Copy constants
        self.__dict__["_constants_set"] = set(getattr(original, "__constants__", []))

        # Copy overloads
        self.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))

        self.__dict__["_initialized"] = True
        _create_methods_from_stubs(self, stubs)

    def __getattr__(self, attr):
        try:
            return ScriptModule.__getattr__(self, attr)
        except AttributeError:
            if self.__dict__["_initialized"]:
                return getattr(self.__dict__["_original"](), attr)
            else:
                # Only fall back to the original module once __init__ is done
                raise AttributeError("Weak module has no attribute '{}'".format(attr))

    def __setattr__(self, attr, value):
        # Once constructed, no new properties can be set
        if not self.__dict__["_initialized"]:
            # Allow properties to be set while ScriptModule.__init__ is running
            return ScriptModule.__setattr__(self, attr, value)

        if hasattr(self, attr):
            return ScriptModule.__setattr__(self, attr, value)
        else:
            raise AttributeError("Cannot set new attribute '{}' on "
                                 "weak script module once it has been "
                                 "created".format(attr))
class ScriptClass(ScriptModule):
    def __init__(self, name):
        super(ScriptClass, self).__init__()
        self._name = name


if not _enabled:
    # With the JIT disabled (e.g. PYTORCH_JIT=0), fall back to plain
    # nn.Module-backed stubs so that code importing ScriptModule / ScriptClass
    # keeps working.
    class ScriptModule(torch.nn.Module):
        def __init__(self, optimize=True):
            super(ScriptModule, self).__init__()

    class ScriptClass(ScriptModule):
        def __init__(self, name):
            super(ScriptClass, self).__init__()
def _get_weak_stubs(cls):
    """
    Calls script_method for each method on the type of the object passed in and
    returns the generated ScriptMethodStubs.
    """
    stubs = []
    for name in dir(cls):
        func = get_function_from_type(cls, name)
        if func in _jit_internal.weak_script_methods:
            entry = _jit_internal.weak_script_methods[func]
            stub = script_method(entry["original_method"], entry["rcb"])
            stubs.append(stub)
    return stubs
def _make_strong(mod):
    """
    Converts a weak module into a subclass of ScriptModule.
    """
    if mod in _jit_internal.weak_modules:
        return _jit_internal.weak_modules[mod]

    stubs = _jit_internal.weak_types.get(type(mod))["method_stubs"]

    if stubs is None:
        # Generate stubs and store them on weak_types in case this type is used again
        stubs = _get_weak_stubs(type(mod))
        _jit_internal.weak_types[type(mod)]["method_stubs"] = stubs

    proxy = WeakScriptModuleProxy(mod, stubs)

    _jit_internal.weak_modules[mod] = proxy

    return proxy
def _get_methods(cls):
    # In Python 3 unbound methods are functions, but in Python 2 they are methods
    return inspect.getmembers(cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x))
_compiled_methods_whitelist = {
    'forward', 'register_buffer', 'register_parameter', 'add_module',
    '_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
    'state_dict', 'load_state_dict', '_load_from_state_dict',
    '_named_members', 'parameters', 'named_parameters',
    'buffers', 'named_buffers', 'children', 'named_children', 'modules',
    'named_modules', 'zero_grad', 'share_memory', '_get_name', 'extra_repr',
    '_slow_forward', '_tracing_name', 'eval', 'train',
}
def _make_fail(name):
    def fail(self, *args, **kwargs):
        raise RuntimeError(name + " is not supported on ScriptModules")
    return fail


for name, method in _get_methods(torch.nn.Module):
    if name.startswith('__'):
        continue
    if name not in ScriptModule.__dict__ and name not in _compiled_methods_whitelist:
        setattr(ScriptModule, method.__name__, _make_fail(name))
class TracedModule(ScriptModule):
    __frozen = False

    def __init__(self, orig, id_set=None, optimize=True):
        # XXX: orig can be a nn.Module or a function!
        super(TracedModule, self).__init__(optimize=optimize)
        if id_set is None:
            id_set = set()

        if not isinstance(orig, torch.nn.Module):
            self._name = orig.__name__
            orig = torch.nn.Module()
        else:
            self._name = 'TracedModule[' + type(orig).__name__ + ']'

        def check_unique(param):
            if param in id_set:
                raise ValueError("TracedModules don't support parameter sharing between modules")
            id_set.add(param)

        for name, param in orig._parameters.items():
            if param is not None:
                self._parameters[name] = param
                check_unique(param)
        for name, buf in orig._buffers.items():
            if buf is not None:
                self._buffers[name] = buf
                check_unique(buf)

        if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
            raise ValueError("Modules that have hooks assigned can't be compiled")

        for name, submodule in orig._modules.items():
            if isinstance(submodule, ScriptModule) and not isinstance(submodule, TracedModule):
                self._modules[name] = submodule.copy()
            else:
                self._modules[name] = TracedModule(submodule, id_set, optimize=optimize)

        self._freeze()

    def forward(self, *args, **kwargs):
        raise RuntimeError('Trace submodules cannot be called.')

    def _freeze(self):
        self.__frozen = True

    def _get_name(self):
        return self._name

    def __setattr__(self, attr, value):
        if not self.__frozen or hasattr(self, attr):
            return super(TracedModule, self).__setattr__(attr, value)
        raise RuntimeError("Cannot set new properties on a traced module.")
class TopLevelTracedModule(TracedModule):
    def forward(self, *args, **kwargs):
        return self._get_method('forward')(*args, **kwargs)
class _ConstModuleList(ScriptModule):
    def __init__(self, modules):
        super(_ConstModuleList, self).__init__()
        for i, module in enumerate(modules):
            if _is_weak_type(type(module)):
                module = _make_strong(module)
            self.add_module(str(i), module)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return _ConstModuleList(list(self._modules.values())[idx])
        else:
            if not (-len(self) <= idx < len(self)):
                raise IndexError('index {} is out of range'.format(idx))
            if idx < 0:
                idx += len(self)
            return self._modules[str(idx)]

    def __len__(self):
        return len(self._modules)

    def __iter__(self):
        return iter(self._modules.values())

    def __dir__(self):
        keys = super(_ConstModuleList, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys
class _ConstSequential(_ConstModuleList):
    __constants__ = ['mods']

    def __init__(self, mods):
        super(_ConstSequential, self).__init__(mods._modules.values())

        # The forward method is built via self.define rather than a
        # @script_method so that its source is available even when only
        # .pyc files are shipped.
        self.define("""
        def forward(self, input):
            for m in self:
                input = m(input)
            return input
        """)


_builtin_table = None

_modules_containing_builtins = (torch, torch._C._nn)
def _unwrap_optional(x):
    assert x is not None, "Unwrapping null optional"
    return x


def _get_builtin_table():
    global _builtin_table
    if _builtin_table is not None:
        return _builtin_table
    _builtin_table = {}

    def register_all(mod):
        for name in dir(mod):
            v = getattr(mod, name)
            if callable(v):
                _builtin_table[id(v)] = "aten::" + name
    for mod in _modules_containing_builtins:
        register_all(mod)

    _builtin_table[id(warnings.warn)] = "aten::warn"
    _builtin_table[id(_single)] = "aten::_single"
    _builtin_table[id(_pair)] = "aten::_pair"
    _builtin_table[id(_triple)] = "aten::_triple"
    _builtin_table[id(_quadruple)] = "aten::_quadruple"
    _builtin_table[id(_list_with_default)] = "aten::list_with_default"
    _builtin_table[id(_unwrap_optional)] = "aten::_unwrap_optional"
    _builtin_table[id(cudnn.is_acceptable)] = "aten::cudnn_is_acceptable"
    _builtin_table[id(torch._C._infer_size)] = "aten::_infer_size"
    _builtin_table[id(torch.nn.functional._no_grad_embedding_renorm_)] = "aten::_no_grad_embedding_renorm_"

    _builtin_table[id(math.floor)] = "aten::floor"
    _builtin_table[id(torch.nn.functional.interpolate)] = "aten::__interpolate"
    _builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
    _builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"
    _builtin_table[id(torch.nn.functional.upsample_bilinear)] = "aten::__upsample_bilinear"
    _builtin_table[id(torch.nn.functional.assert_int_or_pair)] = "aten::_assert_int_or_pair"
    _builtin_table[id(torch.nn.utils.rnn.get_packed_sequence)] = "aten::_pack_sequence"

    return _builtin_table
def _register_builtin(fn, op):
    _get_builtin_table()[id(fn)] = op


def _find_builtin(fn):
    return _get_builtin_table().get(id(fn))


_register_builtin(len, 'aten::len')
_register_builtin(_wait, 'aten::wait')
Error = torch._C.JITException


class _disable_tracing(object):
    def __enter__(self):
        self.state = torch._C._get_tracing_state()
        torch._C._set_tracing_state(None)

    def __exit__(self, *args):
        torch._C._set_tracing_state(self.state)
def annotate(the_type, the_value):
    # This is a no-op in Python; it only has meaning for the script compiler,
    # which uses the declared type for ``the_value``.
    return the_value


class Attribute(object):
    def __init__(self, value, the_type):
        self.value = value
        self.type = the_type
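

# A minimal sketch of ``annotate`` inside a scripted function: it is a no-op in
# Python but tells the script compiler the type of an otherwise untyped empty
# list. Assumes ``typing.List`` is available (Python 3); the function body is
# an assumption made only for this illustration.
def _example_annotate():
    from typing import List

    @torch.jit.script
    def make_list(x):
        result = torch.jit.annotate(List[torch.Tensor], [])
        result.append(x)
        return result

    return make_list(torch.rand(2))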
if not torch._C._jit_init():
    raise RuntimeError("JIT initialization failed")