Caffe2 - Python API
A deep learning, cross platform ML framework
verify.py
1 import torch
2 import torch.jit
3 import torch.onnx
4 
5 import onnx
6 import onnx.helper
7 
8 import numpy as np
9 
10 import difflib
11 import contextlib
12 import io
13 
14 
def colonize(msg, sep=": "):
    """
    Return 'msg' followed by 'sep', or the empty string when 'msg' is
    falsy (None or ""), so it can be used as an optional message prefix.
    """
    return msg + sep if msg else ""
20 
21 
class Errors(object):
    """
    An error-collecting object which supports error recovery.

    Errors are accumulated as strings in 'self.errors'.  The 'check*'
    methods record a failure and keep going; the 'require*' methods record
    a failure and then short-circuit by raising a per-instance exception,
    which can be caught with 'recover()'.  When the outermost with-block
    exits and any errors were recorded, a single AssertionError carrying
    all of them is raised.

    It is intended to be used like a context manager:

    >>> with Errors("Top-level error message") as errs:
    >>>     ...

    Arguments:
        msg (str): headline printed above the collected errors
        rtol (float, default 1e-3): relative tolerance for the
            'almost equal' tests
        atol (float, default 1e-5): absolute tolerance for the
            'almost equal' tests
    """

    def __init__(self, msg, rtol=1e-3, atol=1e-5):
        self.msg = msg
        # Collected error messages, in the order they were reported.
        self.errors = []
        # Stack of contextual hints pushed by 'addErrCtxt'.
        self.context = []
        self.rtol = rtol
        self.atol = atol

        # Allocated upon instance creation so that multiple Errors
        # can be used: each instance only recovers from its own
        # short-circuit exception.
        class ShortCircuit(Exception):
            pass
        self.exc_class = ShortCircuit

    def requireAlmostEqual(self, x, y, msg=None):
        """
        Test that x and y are nearly equal (equal within self.rtol
        precision); aborts execution if they are not.
        """
        self.almostEqualAndThen(x, y, msg, self.failWith)

    def checkAlmostEqual(self, x, y, msg=None):
        """
        Test that x and y are nearly equal (equal within self.rtol
        precision), but continue execution even if they are not equal.

        To prevent error cascades, you should remember to call 'failIfErrs'
        at some later point in time.
        """
        self.almostEqualAndThen(x, y, msg, self.addErr)

    def almostEqualAndThen(self, x, y, msg, k):
        """
        Helper for implementing 'requireAlmostEqual' and 'checkAlmostEqual'.
        Upon failure, invokes continuation 'k' with the error message.

        At the moment, only tests on 'numpy.ndarray' are supported; any
        other combination of types raises RuntimeError.  NaNs are not
        considered equal to each other (equal_nan=False).
        """
        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            try:
                np.testing.assert_allclose(x, y, rtol=self.rtol, atol=self.atol, equal_nan=False, verbose=True)
            except AssertionError as e:
                # Hand the failure to the continuation rather than letting
                # the raw numpy AssertionError escape; this is what lets
                # 'check*' keep going and 'require*' record context first.
                k("{}{}".format(colonize(msg), str(e).lstrip()))
        else:
            raise RuntimeError("Unsupported almost equal test")

    def requireEqual(self, x, y, msg=None):
        """
        Test that x and y are equal; aborts execution if they are not.
        """
        self.equalAndThen(x, y, msg, self.failWith)

    def checkEqual(self, x, y, msg=None):
        """
        Test that x and y are equal, but continue execution even if they are not equal.

        To prevent error cascades, you should remember to call 'failIfErrs'
        at some later point in time.
        """
        self.equalAndThen(x, y, msg, self.addErr)

    # Bit-for-bit accuracy test
    def equalAndThen(self, x, y, msg, k):
        """
        Helper for implementing 'requireEqual' and 'checkEqual'. Upon failure,
        invokes continuation 'k' with the error message.

        Two onnx.TensorProto values are compared by name and then by their
        numpy contents; two numpy arrays are compared with
        np.testing.assert_equal; anything else is compared with '!='.
        """
        if isinstance(x, onnx.TensorProto) and isinstance(y, onnx.TensorProto):
            # Imported explicitly: 'import onnx' alone is not guaranteed to
            # expose the numpy_helper submodule as an attribute.
            from onnx import numpy_helper

            self.equalAndThen(x.name, y.name, msg, k)
            # Use numpy for the comparison
            t1 = numpy_helper.to_array(x)
            t2 = numpy_helper.to_array(y)
            new_msg = "{}In embedded parameter '{}'".format(colonize(msg), x.name)
            self.equalAndThen(t1, t2, new_msg, k)
        elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            try:
                np.testing.assert_equal(x, y)
            except AssertionError as e:
                # Report through the continuation instead of re-raising the
                # numpy AssertionError directly.
                k("{}{}".format(colonize(msg, ": "), str(e).lstrip()))
        else:
            if x != y:
                # TODO: Better algorithm for lists
                sx = str(x)
                sy = str(y)
                if len(sx) > 40 or len(sy) > 40 or '\n' in sx or '\n' in sy:
                    # long form
                    rule = "=" * 50
                    k("\n{}The value\n{}\n{}\n{}\n\ndoes not equal\n\n{}\n{}\n{}"
                      .format(colonize(msg, ":\n"), rule, sx, rule, rule, sy, rule))
                else:
                    k("{}{} != {}".format(colonize(msg), sx, sy))

    def requireMultiLineEqual(self, x, y, msg=None):
        """
        Test that long, multi-line strings x and y are equal;
        aborts execution if they are not.
        """
        self.multiLineEqualAndThen(x, y, msg, self.failWith)

    def multiLineEqualAndThen(self, x, y, msg, k):
        """
        Helper for implementing 'requireMultiLineEqual'. Upon failure,
        invokes continuation 'k' with the error message, which contains a
        line-by-line difflib.ndiff of the two strings.
        """
        if msg is None:
            msg = "Strings are not equal"
        if x != y:
            diff = difflib.ndiff(x.splitlines(True), y.splitlines(True))
            k("{}{}".format(colonize(msg, ":\n\n"), "".join(diff)))

    def addErr(self, msg):
        """
        Add an error to the error context, but continue executing.

        Every active context hint (innermost first) is appended to the
        message as a bullet before it is stored.
        """
        # TODO: instead of immediately concatenating the context in the msg,
        # attach it as metadata and make a decision how to format it later.
        msg_w_ctx = msg
        for c in reversed(self.context):
            msg_w_ctx += "\n\n * " + "\n ".join(c.splitlines())
        self.errors.append(msg_w_ctx)

    def fail(self):
        """
        Immediately fail and short-circuit to the next recovery context.

        NB: It is an error to 'fail' without having added any errors to
        the error context.
        """
        raise self.exc_class()

    def failWith(self, msg):
        """
        Add an error to the error context, and then short-circuit.
        """
        self.addErr(msg)
        self.fail()

    def failIfErrs(self):
        """
        If there are any errors in the error context, short-circuit.

        This is used to prevent error cascades.
        """
        if self.errors:
            self.fail()

    def recover(parent_self):
        """
        Returns a context manager which can be used to recover in case of
        an error. Only this instance's short-circuit exception is swallowed;
        any other exception propagates. Example usage:

        >>> with errs.recover():
        >>>     ...
        """
        class Recover(object):
            def __enter__(self):
                pass

            def __exit__(self, exc_type, exc_value, traceback):
                # Returning True suppresses the short-circuit exception.
                if exc_type is parent_self.exc_class:
                    return True
        return Recover()

    def addErrCtxt(parent_self, msg):
        """
        Returns a context manager which encloses a fragment of code with
        an extra contextual message, e.g., where an error occurred, or a hint
        applicable to all errors in the area. Example usage:

        >>> with errs.addErrCtxt("Some text"):
        >>>     ...
        """
        class AddContext(object):
            def __enter__(self):
                parent_self.context.append(msg)

            def __exit__(self, exc_type, exc_value, traceback):
                parent_self.context.pop()
        return AddContext()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Raise a single AssertionError listing every recorded error, if any.

        A short-circuit that reaches this point without any recorded error
        is reported as a RuntimeError, since 'fail' must always be preceded
        by an error.
        """
        if self.errors:
            errors_msg = "\n\n".join("ERROR: " + err for err in self.errors)
            final_msg = "{}\n{}\n{}".format(self.msg, '-' * 70, errors_msg)
            raise AssertionError(final_msg)
        if exc_type is self.exc_class:
            raise RuntimeError("ShortCircuit was raised, but no errors were recorded")
223 
224 
@contextlib.contextmanager
def set_training(model, mode):
    """
    Context manager that switches 'model' into training mode 'mode' for
    the duration of the with-block and restores the previous mode on exit,
    including when the block raises.  'model.train' is only called when
    the requested mode differs from the current one.
    """
    previous = model.training
    needs_switch = previous != mode
    if needs_switch:
        model.train(mode)
    try:
        yield
    finally:
        # Undo the switch only if we actually performed one.
        if needs_switch:
            model.train(previous)
239 
240 
def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7, test_args=2):
    """
    Export a model into ONNX, import it into a specified ONNX backend, and then
    on a few random inputs verify that PyTorch and the backend produced the same
    results. Requires onnx to be installed.

    This function may spuriously fail: some operators are implemented with
    different numerical precision in an ONNX backend, in which case an unstable
    network (e.g., Inception) may blow up these numerical instabilities. This
    situation is less likely to happen if your model has been trained. However,
    if this is not the case, you may have found a bug! Please report it to the
    PyTorch developers. You can also debug the issue yourself by removing
    suffixes of operators from your model until verification passes.

    For reproducibility, we recommend explicitly setting PyTorch's seed before
    invoking this function.

    Arguments:
        model (torch.nn.Module): the model to be exported and verified
        args (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model. Any non-Variable arguments will
            be hard-coded into the exported model; any Variable arguments
            will become inputs of the exported model, in the order they
            occur in args. If args is a Variable, this is equivalent
            to having called it with a 1-ary tuple of that Variable.
            (Note: passing keyword arguments to the model is not currently
            supported. Give us a shout if you need it.)
        backend (onnx.backend module): ONNX backend to verify with
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.
        training (bool, default False): export the model in training mode. At
            the moment, ONNX is oriented towards exporting models for inference
            only, so you will generally not need to set this to True.
        rtol (float, default 1e-3): relative precision required
        atol (float, default 1e-7): absolute precision required
        test_args (int or iterable of args, default 2):
            either an integer specifying the number
            of random arguments to generate, or an iterable producing arguments
            to test under.

    Raises:
        AssertionError: if re-exporting with different inputs yields a
            different model, or if the backend's outputs differ from
            PyTorch's beyond the given tolerances.
    """
    def _nested_map(condition, fn, condition_msg=None):
        # Build a function that applies 'fn' to every leaf satisfying
        # 'condition', preserving the list/tuple nesting and passing None
        # through unchanged.
        def _map(obj):
            if condition(obj):
                return fn(obj)
            elif obj is None:
                return None
            elif isinstance(obj, (list, tuple)):
                return type(obj)(_map(x) for x in obj)
            else:
                raise ValueError("Auto nesting doesn't know how to process "
                                 "an input object of type " + torch.typename(obj) +
                                 (". Accepted types: " + condition_msg +
                                  ", or lists/tuples of them"
                                  if condition_msg else ""))

        return _map

    def _iter_filter(condition, allow_unknown=False, condition_msg=None):
        # Build a generator function that flattens nested lists/tuples and
        # yields the leaves satisfying 'condition' (None is skipped).
        def _iter(obj):
            if condition(obj):
                yield obj
            elif obj is None:
                return
            elif isinstance(obj, (list, tuple)):
                for o in obj:
                    for var in _iter(o):
                        yield var
            elif allow_unknown:
                yield obj
            else:
                raise ValueError("Auto nesting doesn't know how to process "
                                 "an input object of type " + torch.typename(obj) +
                                 (". Accepted types: " + condition_msg +
                                  ", or lists/tuples of them"
                                  if condition_msg else ""))

        return _iter

    def is_tensor(o):
        return isinstance(o, torch.Tensor)

    _iter_tensors = _iter_filter(is_tensor, condition_msg="Tensors")

    def randomize_arg(arg):
        new_data = arg.data.clone()
        # For now, don't try randomizing non-float tensors; these
        # are likely to be things like indices, where just randomly
        # spattering some longs is unlikely to work. One way we could
        # make this work is to apply a random permutation or something.
        if arg.is_floating_point():
            new_data.uniform_()
        return torch.autograd.Variable(new_data, requires_grad=arg.requires_grad)

    randomize_args = _nested_map(is_tensor, randomize_arg)

    def backend_args(args):
        # TODO: onnx should accept iterables
        return tuple(v.data.cpu().numpy() for v in _iter_tensors(args))

    def load_bytes(b):
        b.seek(0)
        x = onnx.load(b)
        # doc_string has stack traces - let's remove them to make comparison
        # sane
        onnx.helper.strip_doc_string(x)
        return x

    # Special case for common case of passing a single Tensor
    if isinstance(args, torch.Tensor):
        args = (args,)

    with set_training(model, training):
        proto_bytes = io.BytesIO()
        torch_out = torch.onnx._export(model, args, proto_bytes, verbose=verbose)
        proto = load_bytes(proto_bytes)
        prepared = backend.prepare(proto)

        def run(args):
            # Re-export with the new inputs; the exported model must be
            # identical to the first export, otherwise we try to explain why.
            alt_proto_bytes = io.BytesIO()
            alt_torch_out = torch.onnx._export(model, args, alt_proto_bytes, verbose=verbose)
            alt_proto = load_bytes(alt_proto_bytes)
            if proto.SerializeToString() != alt_proto.SerializeToString():
                # OK, let's try to figure out what happened.
                msg = "When I exported your model with different inputs, the result was different."
                if not verbose:
                    msg += "\n(To get more information, run torch.onnx.verify(..., verbose=True))"
                with Errors(msg, rtol=rtol, atol=atol) as errs:
                    # First, check if we have the same number of parameters, and
                    # that they're the same order. If they don't, something has *really* gone wrong.
                    initializer_order_hint = ("This is really strange! The second time I exported your model,\n"
                                              "it had a different set of parameters. Are you assigning Parameters\n"
                                              "in the forward() of your model definition?")
                    with errs.addErrCtxt(initializer_order_hint):
                        errs.requireEqual(list(map(lambda x: x.name, proto.graph.initializer)),
                                          list(map(lambda x: x.name, alt_proto.graph.initializer)),
                                          msg="Parameters list differs")

                    # Now check if the embedded parameters are actually the same
                    initializer_hint = ("A difference in embedded parameters usually means that\n"
                                        "your model is updating parameters/buffers even in inference\n"
                                        "mode. Look for a buggy nn.Module which isn't respecting train().\n")
                    with errs.recover(), errs.addErrCtxt(initializer_hint):
                        for x, y in zip(proto.graph.initializer, alt_proto.graph.initializer):
                            errs.checkEqual(x, y)

                    # Next, check if the model structure lines up.
                    structure_hint = ("A difference in model structure usually means that\n"
                                      "your model has dynamic control flow. These models are not\n"
                                      "currently supported by the exporter.")
                    with errs.recover(), errs.addErrCtxt(structure_hint):
                        # Delete initializers since we already tested them
                        stripped_proto = onnx.ModelProto()
                        stripped_proto.CopyFrom(proto)
                        del stripped_proto.graph.initializer[:]

                        stripped_alt_proto = onnx.ModelProto()
                        stripped_alt_proto.CopyFrom(alt_proto)
                        del stripped_alt_proto.graph.initializer[:]

                        # Compare the printable graph representations first
                        errs.requireMultiLineEqual(onnx.helper.printable_graph(stripped_proto.graph),
                                                   onnx.helper.printable_graph(stripped_alt_proto.graph))

                        # Compare the actual protobuf text formats now (not
                        # very user-friendly!)
                        errs.requireMultiLineEqual(str(stripped_proto), str(stripped_alt_proto))

                        # One last ditch effort, using built-in equality on
                        # protobufs
                        errs.requireEqual(stripped_proto, stripped_alt_proto)

                    errs.failIfErrs()

                    # At this point, we should have figured out why the binary
                    # protobufs differed, and short-circuited out of this code
                    # with a helpful error message. But what if we didn't?
                    # We better still try to give a good error message in this
                    # case. We EXPECT these requires to fail. If they don't,
                    # that is a bug in verify
                    errs.requireEqual(proto, alt_proto)
                    errs.requireEqual(proto_bytes.getvalue(), alt_proto_bytes.getvalue())
                    # Explicit raise rather than 'assert False': asserts are
                    # stripped under 'python -O', which would let a detected
                    # mismatch fall through as if the exports had matched.
                    raise AssertionError(
                        "The exported models differ, but verify could not determine how; "
                        "this is a bug in verify")

            # TODO: test that the traced model also returns the same thing...
            run_helper(alt_torch_out, args)

        # Factored out so we can avoid one run of the model
        def run_helper(torch_out, args):
            backend_out = prepared.run(backend_args(args))
            if isinstance(torch_out, torch.Tensor):
                torch_out = (torch_out,)
            # NB: onnx backend NEVER returns bare numpy array
            msg = "ONNX backend returned different results from PyTorch"
            result_hint = ("If you are not using trained parameters, a difference in results\n"
                           "could mean that your network is numerically unstable. Otherwise\n"
                           "it indicates a bug in PyTorch/ONNX; please file a bug report.")
            with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt(result_hint):
                for i, (x, y) in enumerate(zip(torch_out, backend_out)):
                    errs.checkAlmostEqual(x.data.cpu().numpy(), y, "In output {}".format(i))

        # The original inputs reuse the output of the first export.
        run_helper(torch_out, args)

        if isinstance(test_args, int):
            for _ in range(test_args):
                run(randomize_args(args))
        else:
            for test_arg in test_args:
                run(test_arg)
def checkEqual(self, x, y, msg=None)
Definition: verify.py:84
def checkAlmostEqual(self, x, y, msg=None)
Definition: verify.py:52
def failWith(self, msg)
Definition: verify.py:163
def equalAndThen(self, x, y, msg, k)
Definition: verify.py:94
def fail(self)
Definition: verify.py:154
def requireMultiLineEqual(self, x, y, msg=None)
Definition: verify.py:125
def requireEqual(self, x, y, msg=None)
Definition: verify.py:78
def failIfErrs(self)
Definition: verify.py:170
def requireAlmostEqual(self, x, y, msg=None)
Definition: verify.py:45
def addErrCtxt(parent_self, msg)
Definition: verify.py:196
def _export(args, kwargs)
Definition: __init__.py:20
def addErr(self, msg)
Definition: verify.py:143
Definition: verify.py:1
def multiLineEqualAndThen(self, x, y, msg, k)
Definition: verify.py:132
def recover(parent_self)
Definition: verify.py:179
def typename(o)
Define basic utilities.
Definition: __init__.py:94
def almostEqualAndThen(self, x, y, msg, k)
Definition: verify.py:62