Caffe2 - Python API
A deep learning, cross-platform ML framework
torch/jit/__init__.py
1 import torch._C
2 from torch import Tensor
3 from torch.autograd import Variable, function
4 from torch.serialization import validate_cuda_device
5 from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
6 from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
7 import torch.backends.cudnn as cudnn
8 import torch.jit.annotations
9 import torch._jit_internal as _jit_internal
10 from torch._six import raise_from, with_metaclass, get_function_from_type, \
11  string_classes
12 from torch._jit_internal import ignore
13 from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
14  _list_with_default
15 import torch.testing
16 
17 import math
18 from collections import defaultdict, OrderedDict, namedtuple
19 import textwrap
20 import sys
21 import warnings
22 import itertools
23 import weakref
24 import types
25 import contextlib
26 import os
27 import functools
28 import copy
29 import numbers
30 import collections
31 import re
32 import inspect
33 import pickle
34 if sys.version_info[0] > 2:
35  import pathlib
36 
37 
38 def _parse_env(name, default, true_message, false_message):
39  value = os.environ.get(name)
40  if value is None:
41  return default
42  if value.lower() in {'1', 'true', 'yes'}:
43  return True
44  elif value.lower() in {'0', 'false', 'no'}:
45  return False
46  if value == '1v':
47  print(true_message)
48  return True
49  elif value == '0v':
50  print(false_message)
51  return False
52  raise ValueError('Unknown setting of {}. Try using 0 or 1.'.format(name))
53 
54 
55 _enabled = _parse_env('PYTORCH_JIT', True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED")
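A brief sketch of the protocol `_parse_env` implements for PYTORCH_JIT above: '1'/'true'/'yes' enable, '0'/'false'/'no' disable, and the '1v'/'0v' variants also print the corresponding message; anything else raises ValueError.

    import os
    os.environ['PYTORCH_JIT'] = '0'
    assert _parse_env('PYTORCH_JIT', True, 'jit on', 'jit off') is False
    os.environ['PYTORCH_JIT'] = '1v'   # verbose form: prints 'jit on', returns True
    assert _parse_env('PYTORCH_JIT', True, 'jit on', 'jit off') is True
    del os.environ['PYTORCH_JIT']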
56 _flatten = torch._C._jit_flatten
57 _unflatten = torch._C._jit_unflatten
58 _jit_script_compile = torch._C._jit_script_compile
59 _jit_script_class_compile = torch._C._jit_script_class_compile
60 BatchTensor = torch._C._jit.BatchTensor
61 
62 Future = torch._C.Future
63 _fork = torch._C.fork
64 _wait = torch._C.wait
65 
66 
67 @contextlib.contextmanager
68 def scope(scope_name):
69  tracing_state = torch._C._get_tracing_state()
70  if tracing_state:
71  tracing_state.push_scope(scope_name)
72  try:
73  yield
74  finally:
75  if tracing_state:
76  tracing_state.pop_scope()
77 
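`scope` has no effect unless a trace is active; while tracing, it labels the nodes recorded inside the block with the given scope name. A sketch (the function and scope name are illustrative):

    def fn(x):
        with scope('preprocess'):
            x = x * 2
        return x + 1

    traced = torch.jit.trace(fn, torch.rand(3))  # nodes from the block carry the 'preprocess' scope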
78 
79 DEFAULT_EXTRA_FILES_MAP = torch._C.ExtraFilesMap()
80 
81 
82 def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP):
83  r"""
84  Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`
85 
86  All previously saved modules, no matter their device, are first loaded onto CPU,
87  and then are moved to the devices they were saved from. If this fails (e.g. because
88  the run time system doesn't have certain devices), an exception is raised.
89  However, storages can be dynamically remapped to an alternative set of devices
90  using the `map_location` argument. Compared to :func:`torch.load`, `map_location`
91  in this function is simplified: it only accepts a string (e.g., 'cpu', 'cuda:0')
92  or a torch.device (e.g., torch.device('cpu')).
93 
94  Arguments:
95  f: a file-like object (has to implement read, readline, tell, and seek),
96  or a string containing a file name
97  map_location: can be a string (e.g., 'cpu', 'cuda:0') or a device (e.g.,
98  torch.device('cpu'))
99  _extra_files: map from filename to content. The extra
100  filenames given in the map will be loaded and their content
101  will be stored in the provided map.
102 
103 
104  Returns:
105  A ``ScriptModule`` object.
106 
107  Example:
108  >>> torch.jit.load('scriptmodule.pt')
109  # Load ScriptModule from io.BytesIO object
110  >>> with open('scriptmodule.pt', 'rb') as f:
111  ...     buffer = io.BytesIO(f.read())
112  # Load all tensors to the original device
113  >>> torch.jit.load(buffer)
114  # Load all tensors onto CPU, using a device
115  >>> torch.jit.load(buffer, map_location=torch.device('cpu'))
116  # Load all tensors onto CPU, using a string
117  >>> torch.jit.load(buffer, map_location='cpu')
118  # Load with extra files.
119  >>> files = {'metadata.json' : ''}
120  >>> torch.jit.load('scriptmodule.pt', _extra_files = files)
121  >>> print(files['metadata.json'])
122  """
123  m = ScriptModule()
124 
125  def module_lookup(names):
126  curr = m
127  for name in names:
128  if not hasattr(curr, name):
129  setattr(curr, name, ScriptModule())
130  curr = getattr(curr, name)
131  return curr
132  if isinstance(f, string_classes):
133  if not os.path.exists(f):
134  raise ValueError("The provided filename {} does not exist".format(f))
135  if isinstance(map_location, string_classes):
136  map_location = torch.device(map_location)
137  elif not (map_location is None or
138  isinstance(map_location, torch.device)):
139  raise ValueError("map_location should be either None, string or torch.device, "
140  "but got type: " + str(type(map_location)))
141  if (str(map_location).startswith('cuda')):
142  validate_cuda_device(map_location)
143 
144  if isinstance(f, str) or \
145  (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
146  (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
147  torch._C.import_ir_module(module_lookup, f, map_location, _extra_files)
148  else:
149  torch._C.import_ir_module_from_buffer(module_lookup, f.read(), map_location, _extra_files)
150 
151  return m
152 
153 
154 def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):
155  """
156  Saves a ScriptModule to a file.
157 
158  Args:
159  m: a ScriptModule to save
160  f: a file-like object (has to implement write and flush) or a string
161  containing a file name
162  _extra_files: Map from filename to contents which will be stored as part of 'f'
163 
164  .. warning::
165  If you are using Python 2, torch.save does NOT support StringIO.StringIO
166  as a valid file-like object. This is because the write method should return
167  the number of bytes written; StringIO.write() does not do this.
168 
169  Please use something like io.BytesIO instead.
170 
171  Example:
172  >>> m = torch.jit.ScriptModule()
173  >>> # Save to file
174  >>> torch.jit.save(m, 'scriptmodule.pt')
175  >>> # Save to io.BytesIO buffer
176  >>> buffer = io.BytesIO()
177  >>> torch.jit.save(m, buffer)
178  >>> # Save with extra files
179  >>> extra_files = torch._C.ExtraFilesMap()
180  >>> extra_files['foo.txt'] = 'bar'
181  >>> torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
182  """
183  if isinstance(f, str) or \
184  (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
185  (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
186  m.save(f, _extra_files=_extra_files)
187  else:
188  ret = m.save_to_buffer(_extra_files=_extra_files)
189  f.write(ret)
190 
191 
192 def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, return_inputs=False):
193  """
194  Trace a function or model, returning a tuple consisting of both the
195  *trace* of an execution and the original return value. If ``return_inputs``
196  is True, the trace inputs are also returned as part of the tuple.
197 
198  Tracing is guaranteed not to change the semantics of the function/module
199  that is traced.
200 
201  Arguments:
202  f (torch.nn.Module or function): the function or module
203  to be traced.
204  args (tuple or Tensor): the positional arguments to pass to the
205  function/module to be traced. A non-tuple is assumed to
206  be a single positional argument to be passed to the model.
207  kwargs (dict): the keyword arguments to pass to the function/module
208  to be traced.
209 
210  Example: Trace a cell.
211 
212  >>> trace, out = torch.jit.get_trace_graph(nn.LSTMCell(), (input, hidden))
213  >>> print(trace)
214  """
215  if kwargs is None:
216  kwargs = {}
217  if not isinstance(args, tuple):
218  args = (args,)
219  return LegacyTracedModule(f, _force_outplace, return_inputs)(*args, **kwargs)
220 
221 
222 def _unique_state_dict(module, keep_vars=False):
223  state_dict = module.state_dict(keep_vars=keep_vars)
224  filtered_dict = type(state_dict)()
225  seen_ids = set()
226  for k, v in state_dict.items():
227  if id(v) in seen_ids:
228  continue
229  seen_ids.add(id(v))
230  filtered_dict[k] = v
231  return filtered_dict
232 
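The id()-based filtering above matters when a parameter is shared: the same object reached under two names is kept only once. A sketch (with keep_vars=True the values are the Parameter objects themselves, as in the tracing path below):

    class Shared(torch.nn.Module):
        def __init__(self):
            super(Shared, self).__init__()
            self.a = torch.nn.Linear(2, 2)
            self.b = torch.nn.Linear(2, 2)
            self.b.weight = self.a.weight  # one Parameter object, two names

    sd = _unique_state_dict(Shared(), keep_vars=True)
    assert 'a.weight' in sd and 'b.weight' not in sd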
233 
234 def _create_interpreter_name_lookup_fn(frames_up=1):
235  def _get_interpreter_name_for_var(var):
236  frame = inspect.currentframe()
237  i = 0
238  while i < frames_up + 1:
239  frame = frame.f_back
240  i += 1
241 
242  f_locals = frame.f_locals
243  f_globals = frame.f_globals
244 
245  for k, v in f_locals.items():
246  if isinstance(v, torch.Tensor) and var is v:
247  return k if k != 'self' else ''
248  for k, v in f_globals.items():
249  if isinstance(v, torch.Tensor) and var is v:
250  return k if k != 'self' else ''
251  return ''
252  return _get_interpreter_name_for_var
253 
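What the returned lookup function does: it walks `frames_up + 1` frames up from its own frame, then scans that frame's locals and globals for the given tensor, returning the variable name it is bound to ('' for `self` or no match). A sketch with frames_up=0:

    lookup = _create_interpreter_name_lookup_fn(frames_up=0)

    def caller():
        my_input = torch.rand(2)
        return lookup(my_input)

    assert caller() == 'my_input'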
254 
255 class LegacyTracedModule(Module):
256  def __init__(self, inner, force_outplace=False, return_inputs=False):
257  super(LegacyTracedModule, self).__init__()
258  # inner may be a Module, or it may be an arbitrary callable
259  # If it's a Module, we get its parameters automatically, which lets
260  # us avoid a special casing functions versus modules.
261  self.inner = inner
262  self._force_outplace = force_outplace
263  self._return_inputs = return_inputs
264 
265  def forward(self, *args):
266  in_vars, in_desc = _flatten(args)
267  # NOTE: use full state, because we need it for BatchNorm export
268  # This differs from the compiler path, which doesn't support it at the moment.
269  module_state = list(_unique_state_dict(self, keep_vars=True).values())
270  trace, all_trace_inputs = torch._C._tracer_enter(*(in_vars + module_state))
271  ret_inputs = tuple(x.clone() for x in all_trace_inputs)
272  torch._C._tracer_set_force_outplace(self._force_outplace)
273  torch._C._tracer_set_get_unique_name_fn(_create_interpreter_name_lookup_fn())
274  try:
275  trace_inputs = _unflatten(all_trace_inputs[:len(in_vars)], in_desc)
276  out = self.inner(*trace_inputs)
277  out_vars, _ = _flatten(out)
278  torch._C._tracer_exit(tuple(out_vars))
279  except Exception:
280  torch._C._tracer_abandon()
281  raise
282  if self._return_inputs:
283  return trace, out, ret_inputs
284  else:
285  return trace, out
286 
287 
288 def _clone_inputs(args):
289  def clone_input(a):
290  if a is None:
291  return None
292  elif isinstance(a, torch.Tensor):
293  # TODO: figure out one liner to .clone() and set requires_grad
294  v = Variable(a.data.clone(), requires_grad=a.requires_grad)
295  if a.grad is not None:
296  v.grad = clone_input(a.grad)
297  return v
298  else:
299  return a.clone()
300  return function._nested_map(lambda x: isinstance(x, torch.Tensor),
301  clone_input, condition_msg="tensors")(args)
302 
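A sketch of `_clone_inputs`: tensors anywhere in the (possibly nested) tuple are cloned with requires_grad preserved; None passes through and lists/tuples are rebuilt:

    x = torch.rand(3, requires_grad=True)
    cx, n = _clone_inputs((x, None))
    assert cx is not x and cx.requires_grad and n is None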
303 
304 # This is purely for developer debugging. We are not going to advertise it.
305 _JIT_DUMP = os.environ.get('PYTORCH_JIT_DUMP', False)
306 _JIT_TIME = os.environ.get('PYTORCH_JIT_TIME', False) # CUDA-only timing
307 _JIT_DISABLE = os.environ.get('PYTORCH_JIT_DISABLE', False)
308 _JIT_STATS = os.environ.get('PYTORCH_JIT_STATS', False)
309 
310 
311 def _dump_trace(trace_name, pass_name, input_key, trace):
312  if not _JIT_DUMP:
313  return
314 
315  import torch.contrib._graph_vis as graph_vis
316 
317  filename = "{}_{}".format(trace_name, pass_name)
318  # TODO: Also paste out the backtrace when the trace was compiled
319  # (and maybe also when it was run?)
320  with open(filename + ".ir", "w") as f:
321  f.write("Input key: {}\n\n{}".format(input_key, str(trace)))
322  graph_vis.write(trace.graph(), filename + ".html")
323 
324 
325 @contextlib.contextmanager
326 def _time(trace_name, name, time=True):
327  if (not _JIT_TIME and not time) or not torch.cuda.is_available():
328  yield
329  return
330  stream = torch.cuda.current_stream()
331  start = torch.cuda.Event(enable_timing=True)
332  end = torch.cuda.Event(enable_timing=True)
333  stream.record_event(start)
334  try:
335  yield
336  finally:
337  stream.record_event(end)
338  end.synchronize()
339  print("{} {} time: {} ms".format(trace_name, name, start.elapsed_time(end)))
340 
341 
342 def verify(model, args, loss_fn=torch.sum, devices=None):
343  """
344  Verify that a JIT compiled model has the same behavior as its uncompiled
345  version along with its backwards pass. If your model returns multiple
346  outputs, you must also specify a `loss_fn` to produce a loss for which
347  the backwards will be computed.
348 
349  This function has side-effects (e.g., it executes your model / saves and loads
350  parameters), so don't expect the model to come out exactly the same as what
351  you passed in.
352 
353  Arguments:
354  model (compiled torch.nn.Module or function): the module/function to be
355  verified. The module/function definition MUST have been decorated with
356  `@torch.jit.compile`.
357  args (tuple or Tensor): the positional arguments to pass to the
358  compiled function/module to be verified. A non-tuple is assumed to
359  be a single positional argument to be passed to the model.
360  loss_fn (function, optional): the loss function to be applied to
361  the output of the model, before backwards is invoked. By default,
362  we assume that a model returns a single result, and we :func:`torch.sum` it
363  before calling backwards; if this is inappropriate, you can pass your
364  own loss function. Note that if a model returns a tuple of results,
365  these are passed as separate positional arguments to `loss_fn`.
366  devices (iterable of device IDs, optional): the GPU devices which the
367  compiled module will be run on. This determines the RNG state we
368  must save when running both compiled and uncompiled versions of the model.
369  """
370  # TODO: In principle, we track device information in our trace, so it
371  # should be possible to check if our execution actually obeyed the 'devices'
372  # the user provided.
373 
374  # TODO: Consider adding a utility function to torch.jit to test
375  # for this case
376  if not isinstance(model, torch._C.CompiledFunction):
377  raise TypeError("Cannot verify an uncompiled module. Add @torch.jit.compile to compile it")
378  is_module = isinstance(model, Module)
379 
380  if not isinstance(args, tuple):
381  args = (args,)
382 
383  saved_args = _clone_inputs(args)
384  if is_module:
385  saved_state = copy.deepcopy(model.state_dict())
386 
387  def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
388  params = list(model.parameters()) if is_module else []
389  in_vars, _ = _flatten((args, params))
390  # We use a special API to reset the trace and compile it from scratch.
391  compiled_fn = model
392  if force_trace:
393  compiled_fn.clear_cache()
394  if assert_compiled:
395  hits = compiled_fn.hits
396  out = model(*args)
397  if assert_compiled and compiled_fn.hits == hits:
398  raise RuntimeError("failed to use the compiled function")
399  if not isinstance(out, tuple):
400  out = (out, )
401  if loss_fn == torch.sum and len(out) != 1:
402  raise ValueError(("Model returns {} outputs, but default loss function "
403  "(torch.sum) can only handle a single output").format(len(out)))
404  out_vars, _ = _flatten(out)
405  saved_outs = [v.data.clone() for v in out_vars]
406  loss = loss_fn(*out)
407  grads = torch.autograd.grad([loss], in_vars)
408  # TODO: I'm not sure if the clone here is necessary but it is safer
409  saved_grads = [v.data.clone() for v in grads]
410  return (saved_outs, saved_grads)
411 
412  with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
413  uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
414  assert model.has_trace_for(*args)
415 
416  if is_module:
417  model.load_state_dict(saved_state)
418  compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)
419 
420  _verify_equal(uncompiled_outs, compiled_outs)
421  _verify_equal(uncompiled_grads, compiled_grads)
422 
423 
424 def _verify_equal(xs, ys):
425  for x, y in zip(xs, ys):
426  if x.sub(y).abs().max() > 1e-6:
427  raise RuntimeError("JIT and real computation mismatch")
428 
429 
430 def indent(s):
431  return '\n'.join(['\t' + line for line in s.splitlines()])
432 
433 
434 class TracingCheckError(Exception):
435  def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
436  self.message = 'Tracing failed sanity checks!\n'
437  if extra_msg is not None:
438  self.message += extra_msg + '\n'
439  if graph_diff_error is not None:
440  self.message += 'ERROR: Graphs differed across invocations!\n'
441  self.message += indent(graph_diff_error) + '\n'
442  if tensor_compare_error is not None:
443  self.message += 'ERROR: Tensor-valued Constant nodes differed in value ' \
444  'across invocations. This often indicates that the tracer has' \
445  ' encountered untraceable code.\n'
446  self.message += indent(tensor_compare_error) + '\n'
447  super(TracingCheckError, self).__init__(self.message)
448 
449 
450 # Check the traced module against a set of user-provided validation inputs
451 @torch.no_grad()
452 def _check_trace(check_inputs, func, executor_options, module, check_tolerance, force_outplace):
453  # Note: tracing is independent of optimizations, which consume the trace
454  executor_options['optimize'] = False
455  for inputs in check_inputs:
456  if isinstance(inputs, torch.Tensor):
457  inputs = (inputs,)
458  check_mod = torch.jit.trace(
459  func,
460  _clone_inputs(inputs),
461  check_trace=False,
462  _force_outplace=force_outplace,
463  **executor_options)
464 
465  def graph_diagnostic_info():
466  mod_canonicalized = torch._C._jit_pass_canonicalize(module.graph)
467  torch._C._jit_pass_erase_shape_information(mod_canonicalized)
468  check_canonicalized = torch._C._jit_pass_canonicalize(check_mod.graph)
469  torch._C._jit_pass_erase_shape_information(check_canonicalized)
470 
471  graph_diff_errors = None
472  if str(mod_canonicalized) != str(check_canonicalized):
473  import difflib
474  graph_diff = difflib.ndiff(str(mod_canonicalized).splitlines(True),
475  str(check_canonicalized).splitlines(True))
476  graph_diff_errors = 'Graph diff:\n' + indent(''.join(graph_diff)) + '\n'
477 
478  for n_mod, n_check in zip(mod_canonicalized.nodes(), check_canonicalized.nodes()):
479  if str(n_mod) != str(n_check):
480  graph_diff_errors += 'First diverging operator:\n'
481  node_diff = difflib.ndiff(str(n_mod).splitlines(True),
482  str(n_check).splitlines(True))
483  source_printout = 'Node diff:\n' + indent(''.join(node_diff)) + '\n'
484  mod_stack = n_mod.getSourceLocation()
485  if mod_stack:
486  source_printout += 'Trace source location:\n' + indent(mod_stack) + '\n'
487  check_stack = n_check.getSourceLocation()
488  if check_stack:
489  source_printout += 'Check source location:\n' + indent(check_stack) + '\n'
490  graph_diff_errors += source_printout
491 
492  break # For now, only print out the first pair of nodes that diverges
493 
494  tensor_compare_errors = None
495  # Check Tensor-valued constant nodes
496  for n_mod, n_check in zip(mod_canonicalized.nodes(), check_canonicalized.nodes()):
497  if n_mod.kind() != n_check.kind():
498  break # Graphs have already diverged
499 
500  if n_mod.kind() == 'prim::Constant' and not (n_mod.mustBeNone() or n_check.mustBeNone()):
501  if n_mod.kindOf('value') != 't' or n_check.kindOf('value') != 't':
502  continue
503 
504  mod_tensor_val = n_mod.t('value')
505  check_tensor_val = n_check.t('value')
506 
507  try:
508  torch.testing.assert_allclose(mod_tensor_val, check_tensor_val)
509  except (RuntimeError, AssertionError) as e:
510  if tensor_compare_errors is None:
511  tensor_compare_errors = ''
512  tensor_compare_errors += 'Node:\n' + indent(str(n_mod)) + '\n'
513  compare_stack = n_mod.getSourceLocation()
514  if compare_stack:
515  tensor_compare_errors += 'Source Location:\n' + indent(compare_stack) + '\n'
516  tensor_compare_errors += 'Comparison exception: ' + indent(str(e))
517 
518  break # For now, only print the first diverging pair
519 
520  return graph_diff_errors, tensor_compare_errors
521 
522  def wrap_retval(x):
523  return x if isinstance(x, tuple) else (x,)
524 
525  def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
526  try:
527  outs = wrap_retval(mod(*_clone_inputs(inputs)))
528  outs = [out for out in outs if isinstance(out, torch.Tensor)]
529  return outs
530  except Exception as e:
531  raise TracingCheckError(*graph_diagnostic_info(),
532  extra_msg='Encountered an exception while running the ' + running_what +
533  ' with test inputs.\nException:\n' + indent(str(e)))
534 
535  has_warned = [False]
536 
537  def maybe_warn_nondeterministic():
538  if has_warned[0]:
539  return
540  has_warned[0] = True
541  nondeterm_ops = [op for op in module.graph.nodes() if op.isNondeterministic()]
542  if len(nondeterm_ops) > 0:
543  nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
544  nondeterministic_ops_warning += "Did you forget call .eval() on your model? Nodes:\n"
545  nondeterministic_ops_warning += "\n".join([indent(str(op)) for op in nondeterm_ops][:20])
546  nondeterministic_ops_warning += "\nThis may cause errors in trace checking. To disable trace checking,"\
547  " pass check_trace=False to torch.jit.trace()"
548  warnings.warn(nondeterministic_ops_warning, category=TracerWarning, stacklevel=5)
549 
550  def compare_outputs(original, reference, match_what):
551  all_ok = True
552  for i, (orig, ref) in enumerate(zip(original, reference)):
553  try:
554  torch.testing.assert_allclose(orig.double(), ref.double(), rtol=check_tolerance,
555  atol=torch.testing._get_default_tolerance(orig, ref)[1])
556  except AssertionError as e:
557  maybe_warn_nondeterministic()
558  warnings.warn('Output nr ' + str(i + 1) + '. of the traced function does not match '
559  'the corresponding output of the ' + match_what + '. Detailed error:\n' + str(e),
560  category=TracerWarning, stacklevel=4)
561  all_ok = False
562 
563  return all_ok
564 
565  traced_outs = run_mod_and_filter_tensor_outputs(module, inputs, 'trace')
566  fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, 'Python function')
567  if compare_outputs(traced_outs, fn_outs, 'Python function'):
568  check_outs = run_mod_and_filter_tensor_outputs(check_mod, inputs, 'repeated trace')
569  compare_outputs(traced_outs, check_outs, 'repeated trace')
570 
571  diag_info = graph_diagnostic_info()
572  if any(info is not None for info in diag_info):
573  raise TracingCheckError(*diag_info)
574 
575 
576 class TracerWarning(Warning):
577  @staticmethod
578  def ignore_lib_warnings():
579  # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
580  warnings.filterwarnings('ignore', category=TracerWarning, module='torch.(?!jit)')
581 
582 
583 # We ignore the tracer warnings coming from inside the library, because all our shape
584 # checks in nn will trigger them.
585 TracerWarning.ignore_lib_warnings()
586 torch._C._tracer_warn_use_python()
587 
588 
589 def trace(func,
590  example_inputs,
591  optimize=True,
592  check_trace=True,
593  check_inputs=None,
594  check_tolerance=1e-5,
595  _force_outplace=False,
596  _module_class=None):
597  """
598  Trace a function and return an executable trace that will be optimized
599  using just-in-time compilation.
600 
601  .. warning::
602 
603  Tracing only correctly records functions and modules which are not data
604  dependent (e.g., do not have conditionals on data in tensors) and do not have
605  any untracked external dependencies (e.g., perform input/output or
606  access global variables). If you trace such models, you may silently get
607  incorrect results on subsequent invocations of the model. The tracer
608  will try to emit warnings when doing something that may cause an
609  incorrect trace to be produced.
610 
611  Arguments:
612  func (callable or torch.nn.Module): a Python function or torch.nn.Module
613  that will be run with example_inputs.
614  Arguments and returns of func must be tensors
615  or (possibly nested) tuples that
616  contain tensors.
617  example_inputs (tuple): a tuple of example inputs that will be passed to the function
618  while tracing. The resulting trace can be run with
619  inputs of different types and shapes assuming the traced operations
620  support those types and shapes. example_inputs may also be a single
621  Tensor in which case it is automatically wrapped in a tuple
622 
623  Keyword arguments:
624  optimize (bool, optional): whether or not to apply optimizations. Default: ``True``.
625  check_trace (bool, optional): check if the same inputs run through
626  traced code produce the same outputs. Default: ``True``. You might want
627  to disable this if, for example, your network contains non-
628  deterministic ops or if you are sure that the network is correct despite
629  a checker failure.
630 
631  check_inputs (list of tuples, optional): A list of tuples of input arguments that should be used
632  to check the trace against what is expected. Each tuple
633  is equivalent to a set of input arguments that would
634  be specified in ``args``. For best results, pass in a
635  set of checking inputs representative of the space of
636  shapes and types of inputs you expect the network to see.
637  If not specified, the original ``args`` is used for checking
638  check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
639  This can be used to relax the checker strictness in the event that
640  results diverge numerically for a known reason, such as operator fusion.
641 
642  Returns:
643  A ``ScriptModule`` object with a single ``forward()`` method containing the traced code.
644  When func is a ``torch.nn.Module``, the returned ``ScriptModule`` will have the same set of
645  sub-modules and parameters as func.
646 
647  Example:
648  >>> def f(x):
649  ... return x * 2
650  >>> traced_f = torch.jit.trace(f, torch.rand(1))
651 
652  """
653  if not _enabled:
654  return func
655  executor_options = {'optimize': bool(optimize)}
656  # Special case for common case of passing a single Tensor
657  if isinstance(example_inputs, torch.Tensor):
658  example_inputs = (example_inputs,)
659  # done primarily so that weird iterables fail here and not pybind11 code
660  elif not isinstance(example_inputs, tuple):
661  example_inputs = tuple(example_inputs)
662  if _module_class:
663  module = _module_class(func, **executor_options)
664  else:
665  module = TopLevelTracedModule(func, **executor_options)
666  var_lookup_fn = _create_interpreter_name_lookup_fn(0)
667  module._create_method_from_trace('forward', func, example_inputs,
668  var_lookup_fn, _force_outplace)
669 
670  # Check the trace against new traces created from user-specified inputs
671  if check_trace:
672  if check_inputs is not None:
673  _check_trace(check_inputs, func, executor_options, module, check_tolerance, _force_outplace)
674  else:
675  _check_trace([example_inputs], func, executor_options, module, check_tolerance, _force_outplace)
676 
677  return module
678 
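A sketch of `check_inputs` in use: the trace recorded from example_inputs is re-checked against additional input shapes (all values here are illustrative):

    def f(x):
        return x * 2

    traced = torch.jit.trace(
        f,
        torch.rand(3),
        check_inputs=[(torch.rand(4),), (torch.rand(5, 5),)])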
679 
680 class CompilationUnit(object):
681  def __init__(self, lang=None, optimize=True, _frames_up=0):
682  self.module = torch._C.ScriptModule()
683  self.module._set_optimized(optimize)
684  if lang is not None:
685  self.define(lang, _frames_up=_frames_up + 1)
686  self.optimize = optimize
687 
688  def define(self, lang, rcb=None, _frames_up=0):
689  if not rcb:
690  rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
691  self.module._define(lang, rcb, False)
692 
693  def __getattr__(self, attr):
694  return self.module._get_method(attr)
695 
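`CompilationUnit` compiles free-standing TorchScript source; each function defined in the string becomes a callable attribute via `__getattr__` above. A minimal sketch (the source string is illustrative; unannotated arguments default to Tensor):

    cu = CompilationUnit('''
def add_one(x):
    return x + 1
''')
    print(cu.add_one(torch.zeros(2)))  # tensor([1., 1.])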
696 
697 def _try_get_dispatched_fn(fn):
698  if not callable(fn):
699  return None
700  return _jit_internal.boolean_dispatched.get(fn)
701 
702 
703 def _try_get_overloaded_fn(fn):
704  if not hasattr(fn, '__self__') or not isinstance(fn.__self__, ScriptModule):
705  # Only allow overloads for bound methods
706  return None
707  overloads = fn.__self__._overloads.get(fn.__name__, None)
708  if overloads is None:
709  return None
710  return [getattr(fn.__self__, overload) for overload in overloads]
711 
712 
713 def _try_compile_weak_script(fn):
714  entry = _jit_internal.compiled_weak_fns.get(fn)
715  if entry is None:
716  return None
717  if entry["status"] == _jit_internal.COMPILATION_PENDING:
718  compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
719  del entry["rcb"]
720  _jit_internal.compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
721  entry["status"] = _jit_internal.COMPILED
722  return compiled_fn
723  else:
724  return entry["compiled_fn"]
725 
726 
727 def script(obj, optimize=True, _frames_up=0, _rcb=None):
728  if not _enabled:
729  return obj
730  if _rcb is None:
731  _rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
732  if inspect.isclass(obj):
733  mod = ScriptClass(obj.__name__)
734  ast = get_jit_class_def(obj)
735  _jit_script_class_compile(mod, ast, _rcb)
736  else:
737  mod = ScriptModule()
738  ast = get_jit_def(obj)
739  _jit_script_compile(mod, ast, _rcb, get_default_args(obj))
740  # Forward docstrings
741  mod.__doc__ = obj.__doc__
742  return mod
743 
744 
745 ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
746 
747 
748 def script_method(fn, _rcb=None):
749  if not _enabled:
750  return fn
751  # NOTE: we need to traverse two frames here because the meta-class frame
752  # for ScriptModule will be present, as opposed to invoking @script on
753  # a function or invoking define() on a CompilationUnit.
754  # The stack will look like:
755  #
756  # 0. createResolutionCallback()
757  # 1. script_method()
758  # 2. ScriptModule metaclass frame
759  # 3. Surrounding scope
760  #
761  # createResolutionCallback internally adds 1 to get us to the scope of this
762  # function (the calling function). Adding 2 gets us to the proper surrounding scope.
763  if _rcb is None:
764  _rcb = _jit_internal.createResolutionCallback(frames_up=2)
765  ast = get_jit_def(fn, self_name="ScriptModule")
766  return ScriptMethodStub(_rcb, ast, fn)
767 
768 
769 def _try_get_weak_module(mod):
770  """
771  Get the WeakScriptModuleProxy corresponding to mod if it exists
772  """
773  if not isinstance(mod, Module):
774  return None
775  return _jit_internal.weak_modules.get(mod)
776 
777 
778 def _try_get_ignored_op(fn):
779  if not callable(fn):
780  return False
781  if hasattr(fn, '__func__'):
782  fn = fn.__func__
783  return fn in _jit_internal.ignored_fns
784 
785 
786 def _is_weak_type(cls):
787  """
788  Check if a type has been annotated with `weak_module`
789  """
790  return cls in _jit_internal.weak_types
791 
792 
793 def batch(batch_size=1, optimize=True, _frames_up=0):
794  def decorator(fn):
795  if not _enabled:
796  return fn
797  import torch.jit.batchop
798  mod = script(fn, optimize, _frames_up)
799  res_graph = torch.to_batch_graph(mod.graph)
800  res_mod = ScriptModule()
801  res_mod._create_method_from_graph('forward', res_graph)
802 
803  def wrapper(*args):
804  new_args = []
805  for arg in args:
806  if isinstance(arg, torch.Tensor):
807  arg = BatchTensor(arg, batch_size)
808  if isinstance(arg, BatchTensor):
809  new_args.extend([arg.get_data(), arg.get_mask(), arg.get_dims()])
810  else:
811  new_args.append(arg)
812  res = res_mod(*new_args)
813  if len(res) % 3 != 0:
814  raise RuntimeError("non-batched-tensor output is not supported yet")
816  result = [BatchTensor(*res[i * 3: i * 3 + 3]) for i in range(len(res) // 3)]
817  if len(result) == 1:
818  return result[0]
819  return result
820  wrapper.__doc__ = fn.__doc__
821  return wrapper
822  return decorator
823 
824 
825 # These OrderedDictWrapper classes replace the actual OrderedDicts in
826 # module with versions that get/set properties inside of script::Module.
827 # This allows us to reuse most of nn.Module while still storing the
828 # data in C++.
829 # Each OrderedDict needs to support:
830 # x not in view
831 # x in view
832 # view[name] = ...
833 # view.values()
834 # del view[name]
835 # view.items()
836 # view.keys()
837 # len(view)
838 
839 class OrderedDictWrapper(object):
840  def __init__(self, module):
841  self.module_ref = weakref.ref(module)
842 
843  @property
844  def module(self):
845  r = self.module_ref()
846  if r is None:
847  raise RuntimeError("_parameters or _modules alive after module is dead")
848  return r
849 
850  def keys(self):
851  return [k for k, v in self.items()]
852 
853  def values(self):
854  return [v for k, v in self.items()]
855 
856  def __delitem__(self, k):
857  raise RuntimeError("cannot delete methods or parameters of a script module")
858 
859  def items(self):
860  raise NotImplementedError
861 
862  def __contains__(self, k):
863  raise NotImplementedError
864 
865  def __getitem__(self, k):
866  raise NotImplementedError
867 
868  def __setitem__(self, k, v):
869  raise NotImplementedError
870 
871 
872 class OrderedModuleDict(OrderedDictWrapper):
873  def __init__(self, module):
874  super(OrderedModuleDict, self).__init__(module)
875  # contains _both_ script modules and non-script python-only modules
876 
877  # because script modules are subclassed in python and the
878  # C++ script::Module class will not hold references to them,
879  # to ensure that you always get the same python value here
880  # we store it in the python dict as well
881  self._python_modules = OrderedDict()
882 
883  def items(self):
884  r = self._python_modules.items()
885  return r
886 
887  def __contains__(self, k):
888  return k in self._python_modules
889 
890  def __setitem__(self, k, v):
891  if k in self._python_modules:
892  raise RuntimeError("cannot re-assign modules in a ScriptModule")
893  if isinstance(v, ScriptModule):
894  self.module._register_module(k, v)
895 
896  self._python_modules[k] = v
897 
898  def __getitem__(self, k):
899  return self._python_modules[k]
900 
901 
902 class OrderedParameterDict(OrderedDictWrapper):
903  def __init__(self, module):
904  super(OrderedParameterDict, self).__init__(module)
905 
906  def items(self):
907  return [(name, param) for name, param in self.module._get_parameters()]
908 
909  def __setitem__(self, k, v):
910  self.module._register_parameter(k, v, False)
911 
912  def __contains__(self, k):
913  return self.module._has_parameter(k)
914 
915  def __getitem__(self, k):
916  if k not in self:
917  raise KeyError(k)
918  return self.module._get_parameter(k)
919 
920 
921 class OrderedBufferDict(OrderedDictWrapper):
922  def __init__(self, module):
923  super(OrderedBufferDict, self).__init__(module)
924 
925  def items(self):
926  return [(name, param) for name, _, param in
927  self.module._get_attributes() if isinstance(param, torch.Tensor)]
928 
929  def __setitem__(self, k, v):
930  self.module._register_buffer(k, v)
931 
932  def __contains__(self, k):
933  return self.module._has_buffer(k)
934 
935  def __getitem__(self, k):
936  if k not in self:
937  raise KeyError(k)
938  return self.module._get_buffer(k)
939 
940 # base types that can be constants
941 # in addition, tuples and lists of these base types are also considered constants
942 # If you edit this list, then you also need to edit the handlers in
943 # ConstantValue in jit/script/init.cpp
944 _constant_types = (bool, float, int, str, type(None), types.FunctionType, torch.device, torch.layout, torch.dtype)
945 
946 
947 def _get_valid_constant(attr, v):
948  if isinstance(v, _constant_types):
949  return v
950  elif isinstance(v, tuple) or isinstance(v, list):
951  return tuple(_get_valid_constant(attr, x) for x in v)
952  constants = ", ".join(typ.__name__ for typ in _constant_types)
953  raise TypeError(textwrap.dedent("""
954  '{}' object for attribute '{}' is not a valid constant.
955  Valid constants are:
956  1. a nn.ModuleList
957  2. a value of type {{{}}}
958  3. a list or tuple of (2)
959  """.format(type(v).__name__, attr, constants)))
960 
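`_get_valid_constant` runs when an attribute named in a ScriptModule's ``__constants__`` is assigned (see ``__setattr__`` below); valid values are frozen (lists become tuples) and inlined into compiled methods. A sketch:

    class Scaled(torch.jit.ScriptModule):
        __constants__ = ['scale']

        def __init__(self):
            super(Scaled, self).__init__()
            self.scale = 2.0  # validated by _get_valid_constant

        @torch.jit.script_method
        def forward(self, x):
            return x * self.scale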
961 
962 def _create_methods_from_stubs(self, stubs):
963  defs = [m.def_ for m in stubs]
964  rcbs = [m.resolution_callback for m in stubs]
965  defaults = [get_default_args(m.original_method) for m in stubs]
966  self._create_methods(defs, rcbs, defaults)
967 
968 # For each user-defined class that subclasses ScriptModule, this metaclass
969 # (1) finds all the methods annotated with @script_method
970 # in a ScriptModule and removes them from the class attributes, and
971 # (2) puts a wrapper around the class's __init__ method to register
972 # all of the script_methods with the module after the original __init__
973 # has run. This has to occur after the user-defined __init__ so that
974 # submodules and parameters are initialized _before_ the script compiler
975 # resolves references to `self.param` or `self.module`.
976 
977 
978 class ScriptMeta(type(torch._C.ScriptModule)):
979  # this has to inherit from pybind11's metaclass otherwise we get
980  # issues because ScriptModule inherits from torch._C.ScriptModule,
981  # a pybind11 type
982  def __init__(cls, name, bases, attrs):
983  # find all the script methods
984  cls._original_methods = {}
985  methods = []
986  for k, v in sorted(attrs.items()):
987  if isinstance(v, ScriptMethodStub):
988  delattr(cls, k)
989  methods.append(v)
990  cls._original_methods[v.original_method.__name__] = v.original_method
991  # after the user's __init__ register all the script methods
992  # with the module
993  original_init = getattr(cls, '__init__', lambda self: None)
994  super_constants = getattr(super(cls), '_constants_set', set())
995  cls._constants_set = set(getattr(cls, '__constants__', ())).union(super_constants)
996  cls._overloads = dict(getattr(cls, '__overloads__', {}))
997 
998  @functools.wraps(original_init)
999  def init_then_register(self, *args, **kwargs):
1000  # ensure even if the user forgets to call super that
1001  # the pybind object is initialized so it will not segfault
1002  # run this once, before the most-derived __init__ is called
1003  if cls is type(self):
1004  torch._C.ScriptModule.__init__(self)
1005  original_init(self, *args, **kwargs)
1006  _create_methods_from_stubs(self, methods)
1007 
1008  cls.__init__ = init_then_register
1009  return super(ScriptMeta, cls).__init__(name, bases, attrs)
1010 
1011 
1012 if _enabled:
1013  class ScriptModule(with_metaclass(ScriptMeta, torch._C.ScriptModule, Module)):
1014  r"""
1015  The core data structure in TorchScript is the ``ScriptModule``. It is an
1016  analogue of torch's nn.Module and represents an entire model as a tree of
1017  submodules. Like normal modules, each individual module in a ScriptModule can
1018  have submodules, parameters, and methods. In nn.Modules methods are implemented
1019  as Python functions, but in ScriptModules methods are typically implemented as
1020  *TorchScript* functions, a statically-typed subset of Python that contains all
1021  of PyTorch's built-in Tensor operations. This difference allows your
1022  ScriptModule's code to run without the need for a Python interpreter.
1023 
1024  ScriptModules and the TorchScript functions inside of them can be created in
1025  two ways:
1026 
1027  **Tracing:**
1028 
1029  Using ``torch.jit.trace``, you can take an existing module or python
1030  function, provide example inputs, and we run the function, recording the
1031  operations performed on all the tensors. We turn the resulting recording
1032  into a TorchScript method that is installed as the ``forward`` method of a
1033  ScriptModule. This module also contains any parameters that the original
1034  module had.
1035 
1036  Example::
1037 
1038  import torch
1039  def foo(x, y):
1040  return 2*x + y
1041  traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
1042 
1043  .. note::
1044  Tracing a *function* will produce a ``ScriptModule`` with a single
1045  ``forward`` method that implements that function, and that contains
1046  no parameters.
1047 
1048  Example::
1049 
1050  import torch
1051  import torchvision
1052  traced_net = torch.jit.trace(torchvision.models.resnet18(),
1053  torch.rand(1, 3, 224, 224))
1054 
1055  .. note::
1056 
1057  Tracing only records operations done when the given function is run on the given
1058  tensors. Therefore, the returned ``ScriptModule`` will always run the same traced
1059  graph on any input. This has some important implications when your module is
1060  expected to run different sets of operations, depending on the input and/or the
1061  module state. For example,
1062 
1063  + Tracing will not record any control-flow like if statements or loops. When
1064  this control-flow is constant across your module, this is fine and it often
1065  just inlines configuration decisions. But sometimes the control-flow is
1066  actually part of the model itself. For instance, a recurrent network is
1067  a loop over the (possibly dynamic) length of an input sequence.
1068 
1069  + In the returned ``ScriptModule``, operations that have different behaviors
1070  in ``training`` and ``eval`` modes will always behave as if it is in the
1071  mode it was in during tracing, no matter which mode the ``ScriptModule``
1072  is in.
1073 
1074  In cases like these, tracing would not be appropriate and scripting is a better
1075  choice.
1076 
1077  **Scripting:**
1078 
1079  You can write TorchScript code directly using Python syntax. You do this
1080  using the ``torch.jit.script`` annotation (for functions) or
1081  ``torch.jit.script_method`` annotation (for methods) on subclasses of
1082  ScriptModule. With this annotation the body of the annotated function is
1083  directly translated into TorchScript. TorchScript itself is a subset of
1084  the Python language, so not all features in python work, but we provide
1085  enough functionality to compute on tensors and do control-dependent
1086  operations.
1087 
1088  Example::
1089 
1090  import torch
1091  @torch.jit.script
1092  def foo(x, y):
1093  if x.max() > y.max():
1094  r = x
1095  else:
1096  r = y
1097  return r
1098 
1099  .. note::
1100  A script *function* annotation will construct a ScriptModule
1101  with a single ``forward`` method that implements that function,
1102  and that contains no parameters.
1103 
1104  Example::
1105 
1106  import torch
1107  class MyModule(torch.jit.ScriptModule):
1108  def __init__(self, N, M):
1109  super(MyModule, self).__init__()
1110  self.weight = torch.nn.Parameter(torch.rand(N, M))
1111 
1112  @torch.jit.script_method
1113  def forward(self, input):
1114  return self.weight.mv(input)
1115 
1116  Example::
1117 
1118  import torch
1119  import torch.nn as nn
1120  import torch.nn.functional as F
1121  from torch.jit import ScriptModule, script_method, trace
1122 
1123  class MyScriptModule(ScriptModule):
1124  def __init__(self):
1125  super(MyScriptModule, self).__init__()
1126  # trace produces a ScriptModule's conv1 and conv2
1127  self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
1128  self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
1129 
1130  @script_method
1131  def forward(self, input):
1132  input = F.relu(self.conv1(input))
1133  input = F.relu(self.conv2(input))
1134  return input
1135  """
1136 
1137  def __init__(self, optimize=True):
1138  # must be before Module.init since the field is used in __getattr__
1139  Module.__init__(self)
1140  self._set_optimized(optimize)
1141  self._parameters = OrderedParameterDict(self)
1142  self._buffers = OrderedBufferDict(self)
1143  self._modules = OrderedModuleDict(self)
1144 
1145  def __getattr__(self, attr):
1146  if self._has_method(attr):
1147  if attr in self.__class__._original_methods:
1148  original_method = self.__class__._original_methods[attr]
1149  script_method = self._get_method(attr)
1150  return functools.wraps(original_method)(script_method)
1151  else:
1152  return self._get_method(attr)
1153  if attr == 'graph' and self._has_method('forward'):
1154  return self.__getattr__('forward').graph
1155  return Module.__getattr__(self, attr)
1156 
1157  def __setattr__(self, attr, value):
1158  if attr not in self._constants_set:
1159  if isinstance(value, Module) and _is_weak_type(type(value)):
1160  # Compile weak script module
1161  value = _make_strong(value)
1162  if attr == 'training':
1163  if self._has_buffer('training'):
1164  self.__dict__['training'] = value
1165  self._get_buffer('training').fill_(int(value))
1166  return
1167  if isinstance(value, Attribute):
1168  the_type = torch.jit.annotations.ann_to_type(value.type)
1169  try:
1170  self._register_attribute(attr, the_type, value.value)
1171  except RuntimeError:
1172  raise RuntimeError("Could not register attribute '{}' of type '{}' for a value of type '{}'"
1173  .format(attr, value.type, type(value.value)))
1174  return
1175  return super(ScriptModule, self).__setattr__(attr, value)
1176 
1177  if hasattr(self, attr):
1178  raise RuntimeError("attempting to re-assign constant '{}'".format(attr))
1179 
1180  def conv_module_to_const(module_value):
1181  if not isinstance(module_value, (ModuleList, Sequential)):
1182  return module_value
1183  for i in range(len(module_value)):
1184  module_value[i] = conv_module_to_const(module_value[i])
1185  if isinstance(module_value, Sequential):
1186  return _ConstSequential(module_value)
1187  else:
1188  return _ConstModuleList(module_value)
1189 
1190  if isinstance(value, (ModuleList, Sequential)):
1191  # special case for list of modules. Modules need to be registered with their
1192  # parent module. To do this, we create a ConstModuleList, which is itself a module, that
1193  # contains each of these modules as submodules. The ConstModuleList then
1194  # is set as an attribute of the parent module.
1195  super(ScriptModule, self).__setattr__(attr, conv_module_to_const(value))
1196  else:
1197  super(ScriptModule, self).__setattr__(attr, _get_valid_constant(attr, value))
1198 
1199  def __dir__(self):
1200  return sorted(Module.__dir__(self) + self._method_names())
1201 
1202  def define(self, lang):
1203  # We use frames_up=1 to get to the proper surrounding scope. The stack
1204  # will look like:
1205  # 0. createResolutionCallback
1206  # 1. define()
1207  # 2. surrounding scope.
1208  #
1209  # createResolutionCallback internally adds 1 to get us to our frame, then
1210  # we add 1 to get to the proper surrounding scope.
1211  rcb = _jit_internal.createResolutionCallback(frames_up=1)
1212  self._define(lang, rcb, True)
1213 
1214  def copy(self):
1215  m = ScriptModule()
1216 
1217  def module_lookup(names):
1218  curr = m
1219  for name in names:
1220  if not hasattr(curr, name):
1221  setattr(curr, name, ScriptModule())
1222  curr = getattr(curr, name)
1223  return curr
1224  self._copy_into(module_lookup, {}, [])
1225  return m
1226 
1227  def __getstate__(self):
1228  raise pickle.PickleError(
1229  "ScriptModules cannot be saved using torch.save. " +
1230  "Mixed serialization of script and non-script modules is not supported. " +
1231  "For purely script modules use my_script_module.save(<filename>) instead.")
1232 
1233  class WeakScriptModuleProxy(ScriptModule):
1234  def __init__(self, original, stubs):
1235  # Guards behavior of __setattr__ and __getattr__ so ScriptModule
1236  # __init__ can run correctly
1237  self.__dict__['_initialized'] = False
1238  super(WeakScriptModuleProxy, self).__init__()
1239 
1240  self.__dict__["_original"] = weakref.ref(original)
1241 
1242  # Copy Parameters / Modules / Buffers
1243  for name in dir(original):
1244  item = getattr(original, name)
1245  if item is None and name in original._parameters:
1246  # XXX: treat None values simply as module attributes instead of adding them to the parameter list
1247  # TODO: need to handle this more generally when non-tensor attributes added to module
1248  object.__setattr__(self, name, item)
1249  elif isinstance(item, Parameter) or (isinstance(item, Module) and item is not self):
1250  ScriptModule.__setattr__(self, name, item)
1251  for name in original._buffers:
1252  if original._buffers[name] is None:
1253  object.__setattr__(self, name, None)
1254  else:
1255  self.register_buffer(name, original._buffers[name])
1256 
1257  # Copy constants
1258  self.__dict__["_constants_set"] = set(getattr(original, "__constants__", []))
1259 
1260  # Copy overloads
1261  self.__dict__["_overloads"] = dict(getattr(original, "__overloads__", {}))
1262 
1263  self.__dict__["_initialized"] = True
1264  _create_methods_from_stubs(self, stubs)
1265 
1266  def __getattr__(self, attr):
1267  # Try to get the attribute directly, if that fails, fall back to the
1268  # weak module itself
1269  try:
1270  return ScriptModule.__getattr__(self, attr)
1271  except AttributeError:
1272  if self.__dict__["_initialized"]:
1273  return getattr(self.__dict__["_original"](), attr)
1274  else:
1275  # Only fall back to original once __init__() is done
1276  raise AttributeError("Weak module has no attribute '{}'"
1277  .format(attr))
1278 
1279  def __setattr__(self, attr, value):
1280  # Once constructed, no new properties can be set
1281 
1282  if not self.__dict__["_initialized"]:
1283  # If constructing, don't fall back to original module
1284  return ScriptModule.__setattr__(self, attr, value)
1285 
1286  if hasattr(self, attr):
1287  return ScriptModule.__setattr__(self, attr, value)
1288  else:
1289  raise AttributeError("Cannot set new attribute '{}' on "
1290  "weak script module once it has been "
1291  "created".format(attr))
1292 
1293  class ScriptClass(ScriptModule):
1294  def __init__(self, name):
1295  super(ScriptClass, self).__init__()
1296  self._name = name
1297 
1298 else:
1299  class ScriptModule(torch.nn.Module):
1300  def __init__(self, optimize=True):
1301  super(ScriptModule, self).__init__()
1302 
1303  class ScriptClass(ScriptModule):
1304  def __init__(self, name):
1305  super(ScriptClass, self).__init__()
1306 
1307 
1308 def _get_weak_stubs(cls):
1309  """
1310  Calls script_method for each method on the type of the object passed in and
1311  returns the generated ScriptMethodStubs
1312  """
1313  stubs = []
1314  for name in dir(cls):
1315  func = get_function_from_type(cls, name)
1316  if func in _jit_internal.weak_script_methods:
1317  entry = _jit_internal.weak_script_methods[func]
1318  stub = script_method(entry["original_method"], entry["rcb"])
1319  stubs.append(stub)
1320  return stubs
1321 
1322 
1323 def _make_strong(mod):
1324  """
1325  Converts a weak module into a subclass of ScriptModule
1326  """
1327  if mod in _jit_internal.weak_modules:
1328  return _jit_internal.weak_modules[mod]
1329 
1330  stubs = _jit_internal.weak_types.get(type(mod))["method_stubs"]
1331 
1332  if stubs is None:
1332  # Generate stubs and store on weak_types in case this type is
1334  # used again
1335  stubs = _get_weak_stubs(type(mod))
1336  _jit_internal.weak_types[type(mod)]["method_stubs"] = stubs
1337 
1338  # Create proxy with stubs
1339  proxy = WeakScriptModuleProxy(mod, stubs)
1340 
1341  _jit_internal.weak_modules[mod] = proxy
1342 
1343  return proxy
1344 
1345 
1346 def _get_methods(cls):
1347  import inspect
1348  # In Python 3 unbound methods are functions, but in Python 2 they are methods
1349  return inspect.getmembers(cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x))
1350 
1351 
1352 _compiled_methods_whitelist = {
1353  'forward', 'register_buffer', 'register_parameter', 'add_module',
1354  '_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
1355  'state_dict', 'load_state_dict', '_load_from_state_dict',
1356  '_named_members', 'parameters', 'named_parameters',
1357  'buffers', 'named_buffers', 'children', 'named_children', 'modules',
1358  'named_modules', 'zero_grad', 'share_memory', '_get_name', 'extra_repr',
1359  '_slow_forward', '_tracing_name', 'eval', 'train',
1360 }
1361 
1362 
1363 def _make_fail(name):
1364  def fail(self, *args, **kwargs):
1365  raise RuntimeError(name + " is not supported on ScriptModules")
1366  return fail
1367 
1368 
1369 for name, method in _get_methods(torch.nn.Module):
1370  if name.startswith('__'):
1371  continue
1372  if name not in ScriptModule.__dict__ and name not in _compiled_methods_whitelist:
1373  setattr(ScriptModule, method.__name__, _make_fail(name))
1374 
1375 
1376 class TracedModule(ScriptModule):
1377  __frozen = False
1378 
1379  def __init__(self, orig, id_set=None, optimize=True):
1380  # XXX: orig can be a nn.Module or a function!
1381  super(TracedModule, self).__init__(optimize=optimize)
1382  if id_set is None:
1383  id_set = set()
1384 
1385  if not isinstance(orig, torch.nn.Module):
1386  self._name = orig.__name__
1387  orig = torch.nn.Module()
1388  else:
1389  self._name = 'TracedModule[' + type(orig).__name__ + ']'
1390 
1391  def check_unique(param):
1392  if param in id_set:
1393  raise ValueError("TracedModules don't support parameter sharing between modules")
1394  id_set.add(param)
1395 
1396  self.training = orig.training
1397 
1398  for name, param in orig._parameters.items():
1399  if param is not None:
1400  self._parameters[name] = param
1401  check_unique(param)
1402  for name, buf in orig._buffers.items():
1403  if buf is not None:
1404  self._buffers[name] = buf
1405  check_unique(buf)
1406 
1407  if orig._backward_hooks or orig._forward_hooks or orig._forward_pre_hooks:
1408  raise ValueError("Modules that have hooks assigned can't be compiled")
1409 
1410  for name, submodule in orig._modules.items():
1411  if isinstance(submodule, ScriptModule) and not isinstance(submodule, TracedModule):
1412  self._modules[name] = submodule.copy()
1413  else:
1414  self._modules[name] = TracedModule(submodule, id_set, optimize=optimize)
1415 
1416  self._freeze()
1417 
1418  def forward(self, *args, **kwargs):
1419  raise RuntimeError('Trace submodules cannot be called.')
1420 
1421  def _freeze(self):
1422  self.__frozen = True
1423 
1424  def _get_name(self):
1425  return self._name
1426 
1427  def __setattr__(self, attr, value):
1428  if not self.__frozen or hasattr(self, attr):
1429  return super(TracedModule, self).__setattr__(attr, value)
1430  raise RuntimeError("Cannot set new properties on a traced module.")
1431 
1432 
1433 class TopLevelTracedModule(TracedModule):
1434  def forward(self, *args, **kwargs):
1435  return self._get_method('forward')(*args, **kwargs)
1436 
1437 
1438 class _ConstModuleList(ScriptModule):
1439  def __init__(self, modules):
1440  super(_ConstModuleList, self).__init__()
1441  for i, module in enumerate(modules):
1442  if _is_weak_type(type(module)):
1443  module = _make_strong(module)
1444  self.add_module(str(i), module)
1445 
1446  def __getitem__(self, idx):
1447  if isinstance(idx, slice):
1448  return _ConstModuleList(list(self._modules.values())[idx])
1449  else:
1450  if not (-len(self) <= idx < len(self)):
1451  raise IndexError('index {} is out of range'.format(idx))
1452  if idx < 0:
1453  idx += len(self)
1454  return self._modules[str(idx)]
1455 
1456  def __len__(self):
1457  return len(self._modules)
1458 
1459  def __iter__(self):
1460  return iter(self._modules.values())
1461 
1462  def __dir__(self):
1463  keys = super(_ConstModuleList, self).__dir__()
1464  keys = [key for key in keys if not key.isdigit()]
1465  return keys
1466 
1467 
1468 class _ConstSequential(_ConstModuleList):
1469  __constants__ = ['mods']
1470 
1471  def __init__(self, mods):
1472  super(_ConstSequential, self).__init__(mods._modules.values())
1473 
1474  # we define the forward method via self.define rather than
1475  # making it a direct class member (with a @script annotation)
1476  # because, in optimized runtime environments where only .pyc files
1477  # are shipped, we can't retrieve the source code.
1478  # TODO: find a workaround for this and remove this hack
1479  self.define("""
1480  def forward(self, input):
1481  for m in self:
1482  input = m(input)
1483  return input
1484  """)
1485 
1486 
1487 _builtin_table = None
1488 
1489 _modules_containing_builtins = (torch, torch._C._nn)
1490 
1491 
1492 def _unwrap_optional(x):
1493  assert x is not None, "Unwrapping null optional"
1494  return x
1495 
1496 
1497 # lazily built to ensure the correct initialization order
1498 def _get_builtin_table():
1499  global _builtin_table
1500  if _builtin_table is not None:
1501  return _builtin_table
1502  _builtin_table = {}
1503 
1504  def register_all(mod):
1505  for name in dir(mod):
1506  v = getattr(mod, name)
1507  if callable(v):
1508  _builtin_table[id(v)] = "aten::" + name
1509  for mod in _modules_containing_builtins:
1510  register_all(mod)
1511 
1512  _builtin_table[id(warnings.warn)] = "aten::warn"
1513  _builtin_table[id(_single)] = "aten::_single"
1514  _builtin_table[id(_pair)] = "aten::_pair"
1515  _builtin_table[id(_triple)] = "aten::_triple"
1516  _builtin_table[id(_quadruple)] = "aten::_quadruple"
1517  _builtin_table[id(_list_with_default)] = "aten::list_with_default"
1518  _builtin_table[id(_unwrap_optional)] = "aten::_unwrap_optional"
1519  _builtin_table[id(cudnn.is_acceptable)] = "aten::cudnn_is_acceptable"
1520  _builtin_table[id(torch._C._infer_size)] = "aten::_infer_size"
1521  _builtin_table[id(torch.nn.functional._no_grad_embedding_renorm_)] = "aten::_no_grad_embedding_renorm_"
1522 
1523  _builtin_table[id(math.floor)] = "aten::floor"
1524  _builtin_table[id(torch.nn.functional.interpolate)] = "aten::__interpolate"
1525  _builtin_table[id(torch.nn.functional.upsample_nearest)] = "aten::__upsample_nearest"
1526  _builtin_table[id(torch.nn.functional.upsample)] = "aten::__upsample"
1527  _builtin_table[id(torch.nn.functional.upsample_bilinear)] = "aten::__upsample_bilinear"
1528  _builtin_table[id(torch.nn.functional.assert_int_or_pair)] = "aten::_assert_int_or_pair"
1529  _builtin_table[id(torch.nn.utils.rnn.get_packed_sequence)] = "aten::_pack_sequence"
1530 
1531  return _builtin_table
1532 
1533 
1534 def _register_builtin(fn, op):
1535  _get_builtin_table()[id(fn)] = op
1536 
1537 
1538 def _find_builtin(fn):
1539  return _get_builtin_table().get(id(fn))
1540 
1541 
1542 _register_builtin(len, 'aten::len')
1543 _register_builtin(_wait, 'aten::wait')
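A sketch of the lookup these calls maintain: `_find_builtin` maps a Python callable to its registered ATen name, or None if it is not a builtin:

    assert _find_builtin(len) == 'aten::len'
    assert _find_builtin(torch.add) == 'aten::add'
    assert _find_builtin(lambda x: x) is None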
1544 
1545 # torch.jit.Error
1546 Error = torch._C.JITException
1547 
1548 
1549 class _disable_tracing(object):
1550  def __enter__(self):
1551  self.state = torch._C._get_tracing_state()
1552  torch._C._set_tracing_state(None)
1553 
1554  def __exit__(self, *args):
1555  torch._C._set_tracing_state(self.state)
1556  self.state = None
1557 
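A sketch of `_disable_tracing`: code under the block runs with the tracer detached, so its operations are left out of the recorded graph (the helper called here is hypothetical):

    def fn(x):
        with _disable_tracing():
            log_tensor_stats(x)  # hypothetical side-effect helper; not traced
        return x * 2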
1558 
1559 # for use in python if using annotate
1560 def annotate(the_type, the_value):
1561  # noop in python
1562  return the_value
1563 
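Although a no-op when run eagerly, `annotate` is intercepted by the script compiler to refine a value's static type, e.g. giving an empty list a non-default element type. A sketch:

    from typing import List

    @torch.jit.script
    def make_ints():
        l = torch.jit.annotate(List[int], [])  # otherwise inferred as List[Tensor]
        l.append(1)
        return l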
1564 
1565 class Attribute(object):
1566  def __init__(self, value, the_type):
1567  self.value = value
1568  self.type = the_type
1569 
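`Attribute` is consumed by ScriptModule.__setattr__ above: wrapping a value declares a typed, mutable module attribute (as opposed to a constant or a parameter). A sketch:

    from typing import List

    class WithAttr(torch.jit.ScriptModule):
        def __init__(self):
            super(WithAttr, self).__init__()
            self.counts = torch.jit.Attribute([1, 2], List[int])

        @torch.jit.script_method
        def forward(self, x):
            return x + self.counts[0]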
1570 
1571 if not torch._C._jit_init():
1572  raise RuntimeError("JIT initialization failed")