15 def colonize(msg, sep=":
"): 24 An error-collecting object which supports error recovery. 26 It is intended to be used like a context manager: 28 >>> with Errors("Top-level error message") as errs: 32 def __init__(self, msg, rtol=1e-3, atol=1e-5):
41 class ShortCircuit(Exception):
43 self.exc_class = ShortCircuit
def almostEqualAndThen(self, x, y, msg, k):
    """
    Helper for implementing 'requireAlmostEqual' and 'checkAlmostEqual'.
    Upon failure, invokes continuation 'k' with the error message.

    At the moment, only tests on 'numpy.ndarray' are supported.
    """
    # NOTE(review): reconstructed from garbled source; the 'try'/'else'
    # connectives were inferred from the visible 'except' clause and the
    # trailing RuntimeError — confirm against the original file.
    if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
        try:
            # Tolerances come from the Errors instance (rtol/atol set in
            # __init__); NaNs are NOT treated as equal (equal_nan=False).
            np.testing.assert_allclose(x, y, rtol=self.rtol, atol=self.atol,
                                       equal_nan=False, verbose=True)
        except AssertionError as e:
            # Hand the formatted failure message to the continuation,
            # which decides whether to abort (fail) or record (addErr).
            k("{}{}".format(colonize(msg), str(e).lstrip()))
    else:
        raise RuntimeError("Unsupported almost equal test")
80 Test that x and y are equal; aborts execution if they are not. 86 Test that x and y are equal, but continue execution even if they are not equal. 88 To prevent error cascades, you should remember to call 'failIfErrs' 89 at some later point in time. 96 Helper for implementing 'requireEqual' and 'checkEqual'. Upon failure, 97 invokes continuation 'k' with the error message. 99 if isinstance(x, onnx.TensorProto)
and isinstance(y, onnx.TensorProto):
102 t1 = onnx.numpy_helper.to_array(x)
103 t2 = onnx.numpy_helper.to_array(y)
104 new_msg =
"{}In embedded parameter '{}'".format(colonize(msg), x.name)
106 elif isinstance(x, np.ndarray)
and isinstance(y, np.ndarray):
108 np.testing.assert_equal(x, y)
109 except AssertionError
as e:
111 k(
"{}{}".format(colonize(msg,
": "), str(e).lstrip()))
117 if len(sx) > 40
or len(sy) > 40
or '\n' in sx
or '\n' in sy:
120 k(
"\n{}The value\n{}\n{}\n{}\n\ndoes not equal\n\n{}\n{}\n{}" 121 .format(colonize(msg,
":\n"), l, sx, l, l, sy, l))
123 k(
"{}{} != {}".format(colonize(msg), sx, sy))
def multiLineEqualAndThen(self, x, y, msg, k):
    """
    Helper for implementing 'requireMultiLineEqual'. Upon failure,
    invokes continuation 'k' with the error message.

    Note: the caller is expected to have already established x != y;
    this helper unconditionally builds a diff and invokes 'k'.
    """
    # NOTE(review): reconstructed from garbled source; the 'if msg is None'
    # guard is inferred from the dangling default assignment — confirm
    # against the original file.
    if msg is None:
        msg = "Strings are not equal"
    # keepends=True so the ndiff output reflows with original newlines.
    diff = difflib.ndiff(x.splitlines(True), y.splitlines(True))
    k("{}{}".format(colonize(msg, ":\n\n"), "".join(diff)))
145 Add an error to the error context, but continue executing. 150 for c
in reversed(self.
context):
151 msg +=
"\n\n * " +
"\n ".join(c.splitlines())
152 self.errors.append(msg)
156 Immediately fail and short-circuit to the next recovery context. 158 NB: It is an error to 'fail' without having added any errors to 165 Add an error to the error context, and then short-circuit. 172 If there are any errors in the error context, short-circuit. 174 This is used to prevent error cascades. 181 Returns a context manager which can be used to recover in case of 182 an error. Example usage: 184 >>> with errs.recover(): 187 class Recover(object):
191 def __exit__(self, exc_type, exc_value, traceback):
192 if exc_type == parent_self.exc_class:
198 Returns a context manager which encloses a fragment of code with 199 an extra contextual message, e.g., where an error occurred, or a hint 200 applicable to all errors in the area. Example usage: 202 >>> with errs.addErrCtx("Some text"): 205 class AddContext(object):
207 parent_self.context.append(msg)
209 def __exit__(self, exc_type, exc_value, traceback):
210 parent_self.context.pop()
216 def __exit__(self, exc_type, exc_value, traceback):
218 errors_msg =
"\n\n".join(map(
lambda x:
"ERROR: " + x, self.
errors))
219 final_msg =
"{}\n{}\n{}".format(self.
msg,
'-' * 70, errors_msg)
220 raise AssertionError(final_msg)
222 raise RuntimeError(
"ShortCircuit was raised, but no errors were recorded")
@contextlib.contextmanager
def set_training(model, mode):
    """
    A context manager to temporarily set the training mode of 'model'
    to 'mode', resetting it when we exit the with-block.
    """
    old_mode = model.training
    # NOTE(review): interior lines were lost in extraction; the
    # try/yield/finally skeleton is mandated by @contextlib.contextmanager
    # and the docstring's reset-on-exit contract, and the '!=' guards avoid
    # redundant train() calls — confirm guards against the original file.
    if old_mode != mode:
        model.train(mode)
    try:
        yield
    finally:
        # Always restore the caller's original training mode, even if the
        # body of the with-block raised.
        if old_mode != mode:
            model.train(old_mode)
241 def verify(model, args, backend, verbose=False, training=False, rtol=1e-3, atol=1e-7, test_args=2):
243 Export a model into ONNX, import it into a specified ONNX backend, and then 244 on a few random inputs verify that PyTorch and the backend produced the same 245 results. Requires onnx to be installed. 247 This function may spuriously fail: some operators are implemented with 248 different numerical precision in an ONNX backend, in which case an unstable 249 network (e.g., Inception) may blow up these numerical instabilities. This 250 situation is less likely to happen if your model has been trained. However, 251 if this is not the case, you may have found a bug! Please report it to the 252 PyTorch developers. You can also debug the issue yourself by removing 253 suffixes of operators from your model until verification passes. 255 For reproducibility, we recommend explicitly setting PyTorch's seed before 256 invoking this function. 259 model (torch.nn.Module): the model to be exported and verified 260 args (tuple of arguments): the inputs to 261 the model, e.g., such that ``model(*args)`` is a valid 262 invocation of the model. Any non-Variable arguments will 263 be hard-coded into the exported model; any Variable arguments 264 will become inputs of the exported model, in the order they 265 occur in args. If args is a Variable, this is equivalent 266 to having called it with a 1-ary tuple of that Variable. 267 (Note: passing keyword arguments to the model is not currently 268 supported. Give us a shout if you need it.) 269 backend (onnx.backend module): ONNX backend to verify with 270 verbose (bool, default False): if specified, we will print out a debug 271 description of the trace being exported. 272 training (bool, default False): export the model in training mode. At 273 the moment, ONNX is oriented towards exporting models for inference 274 only, so you will generally not need to set this to True. 
275 rtol (float, default 1e-3): relative precision required 276 test_args (int or iterable of args, default 2): 277 either an integer specifying the number 278 of random arguments to generate, or an iterable producing arguments 281 def _nested_map(condition, fn, condition_msg=None):
287 elif isinstance(obj, (list, tuple)):
288 return type(obj)(_map(x)
for x
in obj)
290 raise ValueError(
"Auto nesting doesn't know how to process " 292 (
". Accepted types: " + condition_msg +
293 ", or lists/tuples of them" 294 if condition_msg
else ""))
298 def _iter_filter(condition, allow_unknown=False, condition_msg=None):
304 elif isinstance(obj, (list, tuple)):
311 raise ValueError(
"Auto nesting doesn't know how to process " 313 (
". Accepted types: " + condition_msg +
314 ", or lists/tuples of them" 315 if condition_msg
else ""))
320 return isinstance(o, torch.Tensor)
322 _iter_tensors = _iter_filter(is_tensor, condition_msg=
"Tensors")
324 def randomize_arg(arg):
325 new_data = arg.data.clone()
330 if arg.is_floating_point():
332 return torch.autograd.Variable(new_data, requires_grad=arg.requires_grad)
334 randomize_args = _nested_map(is_tensor, randomize_arg)
def backend_args(args):
    """
    Flatten 'args' and convert every contained tensor to a CPU numpy
    array, returning them as a tuple suitable for feeding to an ONNX
    backend's run() method.
    """
    # _iter_tensors walks nested lists/tuples and yields only tensors
    # (defined in the enclosing scope via _iter_filter).
    return tuple(v.data.cpu().numpy() for v in _iter_tensors(args))
345 onnx.helper.strip_doc_string(x)
349 if isinstance(args, torch.Tensor):
352 with set_training(model, training):
353 proto_bytes = io.BytesIO()
355 proto = load_bytes(proto_bytes)
356 prepared = backend.prepare(proto)
359 alt_proto_bytes = io.BytesIO()
361 alt_proto = load_bytes(alt_proto_bytes)
362 if proto.SerializeToString() != alt_proto.SerializeToString():
364 msg =
"When I exported your model with different inputs, the result was different." 366 msg +=
"\n(To get more information, run torch.onnx.verify(..., verbose=True))" 367 with
Errors(msg, rtol=rtol, atol=atol)
as errs:
370 initializer_order_hint = (
"This is really strange! The second time I exported your model,\n" 371 "it had a different set of parameters. Are you assigning Parameters\n" 372 "in the forward() of your model definition?")
373 with errs.addErrCtxt(initializer_order_hint):
374 errs.requireEqual(list(map(
lambda x: x.name, proto.graph.initializer)),
375 list(map(
lambda x: x.name, alt_proto.graph.initializer)),
376 msg=
"Parameters list differs")
379 initializer_hint = (
"A difference in embedded parameters usually means that\n" 380 "your model is updating parameters/buffers even in inference\n" 381 "mode. Look for a buggy nn.Module which isn't respecting train().\n")
382 with errs.recover(), errs.addErrCtxt(initializer_hint):
383 for x, y
in zip(proto.graph.initializer, alt_proto.graph.initializer):
384 errs.checkEqual(x, y)
387 structure_hint = (
"A difference in model structure usually means that\n" 388 "your model has dynamic control flow. These models are not\n" 389 "currently supported by the exporter.")
390 with errs.recover(), errs.addErrCtxt(structure_hint):
392 stripped_proto = onnx.ModelProto()
393 stripped_proto.CopyFrom(proto)
394 del stripped_proto.graph.initializer[:]
396 stripped_alt_proto = onnx.ModelProto()
397 stripped_alt_proto.CopyFrom(alt_proto)
398 del stripped_alt_proto.graph.initializer[:]
401 errs.requireMultiLineEqual(onnx.helper.printable_graph(stripped_proto.graph),
402 onnx.helper.printable_graph(stripped_alt_proto.graph))
406 errs.requireMultiLineEqual(str(stripped_proto), str(stripped_alt_proto))
410 errs.requireEqual(stripped_proto, stripped_alt_proto)
420 errs.requireEqual(proto, alt_proto)
421 errs.requireEqual(proto_bytes.getvalue(), alt_proto_bytes.getvalue())
425 run_helper(torch_out, args)
428 def run_helper(torch_out, args):
429 backend_out = prepared.run(backend_args(args))
430 if isinstance(torch_out, torch.Tensor):
431 torch_out = (torch_out,)
433 msg =
"ONNX backend returned different results from PyTorch" 434 result_hint = (
"If you are not using trained parameters, a difference in results\n" 435 "could mean that your network is numerically unstable. Otherwise\n" 436 "it indicates a bug in PyTorch/ONNX; please file a bug report.")
437 with
Errors(msg, rtol=rtol, atol=atol)
as errs, errs.addErrCtxt(result_hint):
438 for i, (x, y)
in enumerate(zip(torch_out, backend_out)):
439 errs.checkAlmostEqual(x.data.cpu().numpy(), y,
"In output {}".format(i))
441 run_helper(torch_out, args)
443 if isinstance(test_args, int):
444 for i
in range(test_args):
445 run(randomize_args(args))
447 for test_arg
in test_args:
def checkEqual(self, x, y, msg=None)
def checkAlmostEqual(self, x, y, msg=None)
def equalAndThen(self, x, y, msg, k)
def requireMultiLineEqual(self, x, y, msg=None)
def requireEqual(self, x, y, msg=None)
def requireAlmostEqual(self, x, y, msg=None)
def addErrCtxt(parent_self, msg)
def _export(args, kwargs)
def multiLineEqualAndThen(self, x, y, msg, k)
def typename(o)
Define basic utilities.
def almostEqualAndThen(self, x, y, msg, k)