Caffe2 - Python API
A deep learning, cross platform ML framework
hypothesis_test_util.py
1 ## @package hypothesis_test_util
2 # Module caffe2.python.hypothesis_test_util
3 """
4 The Hypothesis library uses *property-based testing* to check
5 invariants about the code under test under a variety of random inputs.
6 
7  The key idea here is to express properties of the code under test
8 (e.g. that it passes a gradient check, that it implements a reference
9 function, etc), and then generate random instances and verify they
10 satisfy these properties.
11 
12 The main functions of interest are exposed on `HypothesisTestCase`.
13 You can usually just add a short function in this to generate an
14 arbitrary number of test cases for your operator.
15 
16 The key functions are:
17 
18 - `assertDeviceChecks(devices, op, inputs, outputs)`. This asserts that the
19  operator computes the same outputs, regardless of which device it is executed
20  on.
21 - `assertGradientChecks(device, op, inputs, output_,
22  outputs_with_grads)`. This implements a standard numerical gradient checker
23  for the operator in question.
- `assertReferenceChecks(device, op, inputs, reference)`. This runs the
  reference function (effectively calling `reference(*inputs)`) and compares
  its result to the output of the operator.
27 
28 `hypothesis_test_util.py` exposes some useful pre-built samplers.
29 
30 - `hu.gcs` - a gradient checker device (`gc`) and device checker devices (`dc`)
31 
32 - `hu.gcs_cpu_only` - a CPU-only gradient checker device (`gc`) and
33  device checker devices (`dc`). Used for when your operator is only
34  implemented on the CPU.
35 """
36 
37 from __future__ import absolute_import
38 from __future__ import division
39 from __future__ import print_function
40 from __future__ import unicode_literals
41 from caffe2.proto import caffe2_pb2
42 from caffe2.python import (
43  workspace, device_checker, gradient_checker, test_util, core)
44 import contextlib
45 import copy
46 import functools
47 import hypothesis
48 import hypothesis.extra.numpy
49 import hypothesis.strategies as st
50 import logging
51 import numpy as np
52 import os
53 import six
54 
55 
def is_sandcastle():
    """Return True when the environment marks this run as a Sandcastle job.

    Either ``SANDCASTLE=1`` or ``TW_JOB_USER=sandcastle`` is sufficient.
    """
    return (
        os.getenv('SANDCASTLE') == '1'
        or os.getenv('TW_JOB_USER') == 'sandcastle'
    )
62 
63 
def is_travis():
    """Return True when a ``TRAVIS`` variable exists in the environment."""
    return os.environ.get('TRAVIS') is not None
66 
67 
# The "min_satisfying_examples" setting was deprecated in hypothesis 3.56.0
# and removed in hypothesis 4.x, so it is only passed to older releases.
if hypothesis.version.__version_info__ >= (3, 56, 0):
    _legacy_settings = dict()
else:
    _legacy_settings = dict(min_satisfying_examples=1)

# Register the three profiles; they share everything except the example
# budget, and only "sandcastle" is derandomized.
for _profile_name, _profile_settings in (
    ("sandcastle", dict(derandomize=True, max_examples=100)),
    ("dev", dict(max_examples=10)),
    ("debug", dict(max_examples=1000)),
):
    _profile_settings.update(_legacy_settings)
    hypothesis.settings.register_profile(
        _profile_name,
        hypothesis.settings(
            suppress_health_check=[hypothesis.HealthCheck.too_slow],
            database=None,
            verbosity=hypothesis.Verbosity.verbose,
            **_profile_settings))

# Do not leave the loop helpers behind in the module namespace.
del _legacy_settings, _profile_name, _profile_settings

# Sandcastle always uses its own profile; otherwise honour
# CAFFE2_HYPOTHESIS_PROFILE and fall back to "dev".
if is_sandcastle():
    hypothesis.settings.load_profile('sandcastle')
else:
    hypothesis.settings.load_profile(
        os.getenv('CAFFE2_HYPOTHESIS_PROFILE', 'dev'))
124 
125 
def dims(min_value=1, max_value=5):
    """Strategy producing one integer in ``[min_value, max_value]``.

    The other samplers in this module use it to draw a single dimension size.
    """
    size_strategy = st.integers(min_value=min_value, max_value=max_value)
    return size_strategy
128 
129 
def elements_of_type(dtype=np.float32, filter_=None):
    """Return a default element strategy for the given dtype.

    Args:
        dtype: one of ``np.float16``/``np.float32``/``np.float64`` (floats in
            [-1, 1]), ``np.int32``/``np.int64`` (non-negative integers up to
            the type's maximum), or ``bool``/``np.bool_`` (booleans).
        filter_: optional predicate; when given, the strategy is filtered
            with it.

    Raises:
        ValueError: if ``dtype`` is not one of the supported types.
    """
    if dtype in (np.float16, np.float32, np.float64):
        elems = st.floats(min_value=-1.0, max_value=1.0)
    elif dtype is np.int32:
        elems = st.integers(min_value=0, max_value=2 ** 31 - 1)
    elif dtype is np.int64:
        elems = st.integers(min_value=0, max_value=2 ** 63 - 1)
    # `np.bool` was only ever an alias (of builtin `bool` in old NumPy, of
    # `np.bool_` in NumPy 2.x) and does not exist in NumPy 1.24-1.26, where
    # touching it raises AttributeError. Compare against both real types.
    elif dtype is bool or dtype is np.bool_:
        elems = st.booleans()
    else:
        raise ValueError("Unexpected dtype without elements provided")
    return elems if filter_ is None else elems.filter(filter_)
143 
144 
def arrays(dims, dtype=np.float32, elements=None):
    """Strategy for numpy arrays of shape ``dims`` and type ``dtype``.

    When ``elements`` is None the per-dtype default from
    ``elements_of_type(dtype)`` is used.
    """
    element_strategy = elements
    if element_strategy is None:
        element_strategy = elements_of_type(dtype)
    return hypothesis.extra.numpy.arrays(
        dtype, dims, elements=element_strategy)
153 
154 
def tensor(min_dim=1, max_dim=4, dtype=np.float32, elements=None, **kwargs):
    """Strategy for a single array with a randomly drawn shape.

    The rank is between ``min_dim`` and ``max_dim``; each dimension size is
    drawn from ``dims(**kwargs)``.
    """
    shape_strategy = st.lists(
        dims(**kwargs), min_size=min_dim, max_size=max_dim)

    def _array_of_shape(shape):
        return arrays(shape, dtype, elements)

    return shape_strategy.flatmap(_array_of_shape)
163 
164 
def tensor1d(min_len=1, max_len=64, dtype=np.float32, elements=None):
    """Strategy for a rank-1 array whose length is in [min_len, max_len]."""
    return tensor(
        min_dim=1,
        max_dim=1,
        dtype=dtype,
        elements=elements,
        min_value=min_len,
        max_value=max_len,
    )
167 
168 
def segment_ids(size, is_sorted):
    """Strategy for an int32 vector of ``size`` segment ids.

    With ``is_sorted`` the ids start at 0 and each step increases by 0 or 1
    (a cumulative sum of booleans, shifted so the first id is 0). Otherwise
    each id is drawn independently from ``[0, 2 * size]``. ``size == 0``
    always yields an empty int32 array.
    """
    if size == 0:
        return st.just(np.empty(shape=[0], dtype=np.int32))

    if not is_sorted:
        return arrays(
            [size],
            dtype=np.int32,
            elements=st.integers(min_value=0, max_value=2 * size))

    def _steps_to_ids(steps):
        # Subtracting the first step makes the sequence start at 0.
        return np.cumsum(steps, dtype=np.int32) - steps[0]

    step_strategy = arrays([size], dtype=np.int32, elements=st.booleans())
    return step_strategy.map(_steps_to_ids)
183 
184 
def lengths(size, min_segments=None, max_segments=None, **kwargs):
    """Strategy for segment lengths that sum to ``size``.

    Approach: draw the number of borders between segments, draw that many
    border positions in ``[0, size]``, add the end points 0 and ``size``,
    sort, and take consecutive differences. Segments may have length 0.

    ``min_segments`` defaults to 0 and ``max_segments`` to ``size``. Extra
    keyword arguments are accepted and ignored, so this can be used where a
    ``segment_generator`` is called with ``is_sorted=...``.
    """
    min_segments = 0 if min_segments is None else min_segments
    max_segments = size if max_segments is None else max_segments
    assert min_segments >= 0
    assert min_segments <= max_segments
    if size == 0 and max_segments == 0:
        return st.just(np.empty(shape=[0], dtype=np.int32))
    assert max_segments > 0, "size is not 0, need at least one segment"

    def _border_positions(num_borders):
        return hypothesis.extra.numpy.arrays(
            np.int32,
            num_borders,
            elements=st.integers(min_value=0, max_value=size))

    def _borders_to_lengths(borders):
        edges = np.append(borders, np.array([0, size], dtype=np.int32))
        return np.diff(sorted(edges))

    # k borders (plus the two end points) delimit k + 1 segments.
    border_counts = st.integers(
        min_value=max(min_segments - 1, 0), max_value=max_segments - 1)
    return border_counts.flatmap(_border_positions).map(_borders_to_lengths)
211 
212 
def segmented_tensor(
    min_dim=1,
    max_dim=4,
    dtype=np.float32,
    is_sorted=True,
    elements=None,
    segment_generator=segment_ids,
    allow_empty=False,
    **kwargs
):
    """Strategy for a ``(data, segments)`` pair.

    ``data`` is an array with a random shape (rank in [min_dim, max_dim],
    sizes from ``dims(**kwargs)``). ``segments`` comes from
    ``segment_generator(data.shape[0], is_sorted=is_sorted)``. With
    ``allow_empty`` a leading dimension of size 0 may be prepended to the
    drawn shape.
    """
    if allow_empty:
        gen_empty = st.booleans()
    else:
        gen_empty = st.just(False)
    inner_shape = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)

    def _maybe_prepend_zero(pair):
        make_empty, shape = pair
        return ([0] if make_empty else []) + shape

    def _data_with_segments(shape):
        return st.tuples(
            arrays(shape, dtype, elements),
            segment_generator(shape[0], is_sorted=is_sorted),
        )

    full_shape = st.tuples(gen_empty, inner_shape).map(_maybe_prepend_zero)
    return full_shape.flatmap(_data_with_segments)
232 
233 
def lengths_tensor(min_segments=None, max_segments=None, *args, **kwargs):
    """Like ``segmented_tensor`` but with ``lengths`` as segment generator.

    ``min_segments`` / ``max_segments`` are bound into ``lengths``; all other
    arguments are forwarded to ``segmented_tensor``.
    """
    bounded_lengths = functools.partial(
        lengths, min_segments=min_segments, max_segments=max_segments)
    return segmented_tensor(
        *args, segment_generator=bounded_lengths, **kwargs)
238 
239 
def sparse_segmented_tensor(min_dim=1, max_dim=4, dtype=np.float32,
                            is_sorted=True, elements=None, allow_empty=False,
                            segment_generator=segment_ids, itype=np.int64,
                            **kwargs):
    """Strategy for a ``(data, indices, segments)`` triple.

    ``data`` has a random shape. ``indices`` is an ``itype`` vector of valid
    row indices into ``data`` (values in ``[0, data.shape[0] - 1]``).
    ``segments`` comes from ``segment_generator(len(indices),
    is_sorted=is_sorted)``. With ``allow_empty`` the index vector may have
    length 0; otherwise its length is in ``[1, data.shape[0]]``.
    """
    if allow_empty:
        gen_empty = st.booleans()
    else:
        gen_empty = st.just(False)
    data_shape = st.lists(dims(**kwargs), min_size=min_dim, max_size=max_dim)

    def _shape_and_index_count(pair):
        make_empty, shape = pair
        if make_empty:
            num_indices = st.just(0)
        else:
            num_indices = st.integers(min_value=1, max_value=shape[0])
        return st.tuples(st.just(shape), num_indices)

    def _build_triple(shape_and_count):
        shape, num_indices = shape_and_count
        index_values = st.integers(min_value=0, max_value=shape[0] - 1)
        return st.tuples(
            arrays(shape, dtype, elements),
            arrays(num_indices, dtype=itype, elements=index_values),
            segment_generator(num_indices, is_sorted=is_sorted),
        )

    all_dims = st.tuples(gen_empty, data_shape).flatmap(
        _shape_and_index_count)
    return all_dims.flatmap(_build_triple)
258 
259 
def sparse_lengths_tensor(**kwargs):
    """``sparse_segmented_tensor`` with ``lengths`` as segment generator."""
    strategy = sparse_segmented_tensor(segment_generator=lengths, **kwargs)
    return strategy
262 
263 
def tensors(n, min_dim=1, max_dim=4, dtype=np.float32, elements=None,
            **kwargs):
    """Strategy for a list of exactly ``n`` arrays that share one shape.

    The common shape has rank in [min_dim, max_dim] with sizes drawn from
    ``dims(**kwargs)``.
    """
    shape_strategy = st.lists(
        dims(**kwargs), min_size=min_dim, max_size=max_dim)

    def _n_arrays_of_shape(shape):
        return st.lists(
            arrays(shape, dtype, elements), min_size=n, max_size=n)

    return shape_strategy.flatmap(_n_arrays_of_shape)
271 
272 
def tensors1d(n, min_len=1, max_len=64, dtype=np.float32, elements=None):
    """Strategy for ``n`` rank-1 arrays sharing a length in [min_len, max_len]."""
    return tensors(
        n,
        min_dim=1,
        max_dim=1,
        dtype=dtype,
        elements=elements,
        min_value=min_len,
        max_value=max_len,
    )
277 
278 
# Device options used by the samplers below.
cpu_do = caffe2_pb2.DeviceOption()
cuda_do = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CUDA)
hip_do = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.HIP)
# CUDA or ROCm, whichever workspace.GpuDeviceType reports.
gpu_do = caffe2_pb2.DeviceOption(device_type=workspace.GpuDeviceType)

# (bddppq) Do not rely on this no_hip option! It's just used to
# temporarily skip some flaky tests on ROCM before it's getting more mature.
_device_options_no_hip = [cpu_do]
if workspace.has_cuda_support:
    _device_options_no_hip.append(cuda_do)

# Everything above, plus HIP when it is supported.
device_options = list(_device_options_no_hip)
if workspace.has_hip_support:
    device_options.append(hip_do)

# CPU plus one device option for each individual GPU.
expanded_device_options = [cpu_do] + [
    caffe2_pb2.DeviceOption(device_type=workspace.GpuDeviceType, device_id=i)
    for i in range(workspace.NumGpuDevices())]
292 
293 
def device_checker_device_options():
    """Strategy that always yields the full ``device_options`` list."""
    all_devices = st.just(device_options)
    return all_devices
296 
297 
def gradient_checker_device_option():
    """Strategy that picks one entry of ``device_options``."""
    one_device = st.sampled_from(device_options)
    return one_device
300 
301 
# Keyword bundles for `@given(..., **hu.gcs)`: `gc` is one device option for
# the gradient checker, `dc` is the list of device options for the device
# checker.
gcs = dict(
    dc=device_checker_device_options(),
    gc=gradient_checker_device_option(),
)

# Variants restricted to a fixed set of devices.
gcs_cpu_only = dict(
    gc=st.sampled_from([cpu_do]),
    dc=st.just([cpu_do]),
)
gcs_cuda_only = dict(
    gc=st.sampled_from([cuda_do]),
    dc=st.just([cuda_do]),
)
# CUDA or ROCm
gcs_gpu_only = dict(
    gc=st.sampled_from([gpu_do]),
    dc=st.just([gpu_do]),
)
gcs_no_hip = dict(
    gc=st.sampled_from(_device_options_no_hip),
    dc=st.just(_device_options_no_hip),
)
311 
312 
@contextlib.contextmanager
def temp_workspace(name=b"temp_ws"):
    """Run the ``with`` body inside a temporary workspace.

    Switches to the workspace ``name`` (the second argument ``True`` is
    presumably "create if missing" -- confirm against
    ``workspace.SwitchWorkspace``), and on exit resets it and switches back
    to the workspace that was current on entry.

    The cleanup runs in a ``finally`` clause so that an exception raised in
    the body (for example a failed test assertion) does not leave the
    temporary workspace active and populated for subsequent code.
    """
    old_ws_name = workspace.CurrentWorkspace()
    workspace.SwitchWorkspace(name, True)
    try:
        yield
    finally:
        workspace.ResetWorkspace()
        workspace.SwitchWorkspace(old_ws_name)
320 
321 
def runOpBenchmark(
    device_option,
    op,
    inputs,
    input_device_options=None,
    iterations=10,
):
    """Benchmark a single operator in a temporary workspace.

    A copy of ``op`` is placed on ``device_option`` and wrapped in a one-op
    net (named after the op, or "test" if the op has no name). ``inputs`` are
    fed to the op's input blobs, using ``input_device_options`` -- or the
    devices inferred by ``core.InferOpBlobDevicesAsDict`` -- to choose each
    blob's device, falling back to ``device_option``.

    Returns whatever ``workspace.BenchmarkNet`` returns.
    """
    op = copy.deepcopy(op)
    op.device_option.CopyFrom(device_option)

    net = caffe2_pb2.NetDef()
    net.op.extend([op])
    net.name = op.name or "test"

    with temp_workspace():
        blob_devices = input_device_options
        if not blob_devices:
            blob_devices = core.InferOpBlobDevicesAsDict(op)[0]
        for blob_name, value in zip(op.input, inputs):
            workspace.FeedBlob(
                blob_name,
                value,
                device_option=blob_devices.get(blob_name, device_option),
            )
        workspace.CreateNet(net)
        # NOTE(review): positional args are presumably warmup runs (1) and
        # per-op timing (True) -- confirm against workspace.BenchmarkNet.
        return workspace.BenchmarkNet(net.name, 1, iterations, True)
347 
348 
def runOpOnInput(
    device_option,
    op,
    inputs,
    input_device_options=None,
):
    """Run a single operator once and return all of its outputs.

    A copy of ``op`` is placed on ``device_option``; ``inputs`` are fed to
    its input blobs inside a temporary workspace, and every output blob is
    fetched after the run.

    Returns:
        list with one fetched blob per entry of ``op.output``, in order.

    Raises:
        ValueError: if the op declares more inputs than were supplied.
    """
    op = copy.deepcopy(op)
    op.device_option.CopyFrom(device_option)

    with temp_workspace():
        if len(inputs) < len(op.input):
            raise ValueError(
                'must supply an input for each input on the op: %s vs %s' %
                (op.input, inputs))
        blob_devices = input_device_options
        if not blob_devices:
            blob_devices = core.InferOpBlobDevicesAsDict(op)[0]
        for blob_name, value in zip(op.input, inputs):
            workspace.FeedBlob(
                blob_name,
                value,
                device_option=blob_devices.get(blob_name, device_option),
            )
        workspace.RunOperatorOnce(op)
        return [workspace.FetchBlob(blob_name) for blob_name in op.output]
379 
380 
class HypothesisTestCase(test_util.TestCase):
    """
    A unittest.TestCase subclass with some helper functions for
    utilizing the `hypothesis` (hypothesis.readthedocs.io) library.
    """
386 
387  def assertDeviceChecks(
388  self,
389  device_options,
390  op,
391  inputs,
392  outputs_to_check,
393  input_device_options=None,
394  threshold=0.01
395  ):
396  """
397  Asserts that the operator computes the same outputs, regardless of
398  which device it is executed on.
399 
400  Useful for checking the consistency of GPU and CPU
401  implementations of operators.
402 
403  Usage example:
404 
405  @given(inputs=hu.tensors(n=2), in_place=st.booleans(), **hu.gcs)
406  def test_sum(self, inputs, in_place, gc, dc):
407  op = core.CreateOperator("Sum", ["X1", "X2"],
408  ["Y" if not in_place else "X1"])
409  X1, X2 = inputs
410  self.assertDeviceChecks(dc, op, [X1, X2], [0])
411  """
413  threshold,
414  device_options=device_options
415  )
416  self.assertTrue(
417  dc.CheckSimple(op, inputs, outputs_to_check, input_device_options)
418  )
419 
421  self,
422  device_option,
423  op,
424  inputs,
425  outputs_to_check,
426  outputs_with_grads,
427  grad_ops=None,
428  threshold=0.005,
429  stepsize=0.05,
430  input_device_options=None,
431  ):
432  """
433  Implements a standard numerical gradient checker for the operator
434  in question.
435 
436  Useful for checking the consistency of the forward and
437  backward implementations of operators.
438 
439  Usage example:
440 
441  @given(inputs=hu.tensors(n=2), in_place=st.booleans(), **hu.gcs)
442  def test_sum(self, inputs, in_place, gc, dc):
443  op = core.CreateOperator("Sum", ["X1", "X2"],
444  ["Y" if not in_place else "X1"])
445  X1, X2 = inputs
446  self.assertGradientChecks(gc, op, [X1, X2], 0, [0])
447  """
449  stepsize=stepsize,
450  threshold=threshold,
451  device_option=device_option,
452  workspace_name=str(device_option),
453  input_device_options=input_device_options,
454  )
455  res, grad, grad_estimated = gc.CheckSimple(
456  op, inputs, outputs_to_check, outputs_with_grads,
457  grad_ops=grad_ops,
458  input_device_options=input_device_options
459  )
460  self.assertEqual(grad.shape, grad_estimated.shape)
461  self.assertTrue(
462  res,
463  "Gradient check failed for input " + str(op.input[outputs_to_check])
464  )
465 
466  def _assertGradReferenceChecks(
467  self,
468  op,
469  inputs,
470  ref_outputs,
471  output_to_grad,
472  grad_reference,
473  threshold=1e-4,
474  ):
475  grad_blob_name = output_to_grad + '_grad'
476  grad_ops, grad_map = core.GradientRegistry.GetBackwardPass(
477  [op], {output_to_grad: grad_blob_name})
478  output_grad = workspace.FetchBlob(output_to_grad)
479  grad_ref_outputs = grad_reference(output_grad, ref_outputs, inputs)
480  workspace.FeedBlob(grad_blob_name, workspace.FetchBlob(output_to_grad))
481  workspace.RunOperatorsOnce(grad_ops)
482 
483  self.assertEqual(len(grad_ref_outputs), len(inputs))
484  for (n, ref) in zip(op.input, grad_ref_outputs):
485  grad_names = grad_map.get(n)
486  if not grad_names:
487  # no grad for this input
488  self.assertIsNone(ref)
489  else:
490  if isinstance(grad_names, core.BlobReference):
491  # dense gradient
492  ref_vals = ref
493  ref_indices = None
494  val_name = grad_names
495  else:
496  # sparse gradient
497  ref_vals, ref_indices = ref
498  val_name = grad_names.values
499  vals = workspace.FetchBlob(str(val_name))
500  np.testing.assert_allclose(
501  vals,
502  ref_vals,
503  atol=threshold,
504  rtol=threshold,
505  err_msg='Gradient {0} (x) is not matching the reference (y)'
506  .format(val_name),
507  )
508  if ref_indices is not None:
509  indices = workspace.FetchBlob(str(grad_names.indices))
510  np.testing.assert_allclose(indices, ref_indices,
511  atol=1e-4, rtol=1e-4)
512 
513  def _assertInferTensorChecks(self, name, shapes, types, output):
514  if name not in shapes:
515  # No inferred shape or type available
516  return
517  output = workspace.FetchBlob(name)
518  if type(output) is np.ndarray:
519  if output.dtype == np.dtype('float64'):
520  correct_type = caffe2_pb2.TensorProto.DOUBLE
521  elif output.dtype == np.dtype('float32'):
522  correct_type = caffe2_pb2.TensorProto.FLOAT
523  elif output.dtype == np.dtype('int32'):
524  correct_type = caffe2_pb2.TensorProto.INT32
525  elif output.dtype == np.dtype('int64'):
526  correct_type = caffe2_pb2.TensorProto.INT64
527  else:
528  correct_type = "unknown {}".format(np.dtype)
529  else:
530  correct_type = str(type(output))
531  try:
532  np.testing.assert_array_equal(
533  np.array(shapes[name]).astype(np.int32),
534  np.array(output.shape).astype(np.int32),
535  err_msg='Shape {} mismatch: {} vs. {}'.format(
536  name,
537  shapes[name],
538  output.shape))
539  # BUG: Workspace blob type not being set correctly T16121392
540  if correct_type != caffe2_pb2.TensorProto.INT32:
541  return
542  np.testing.assert_equal(
543  types[name],
544  correct_type,
545  err_msg='Type {} mismatch: {} vs. {}'.format(
546  name, types[name], correct_type,
547  )
548  )
549  except AssertionError as e:
550  # Temporarily catch these assertion errors when validating
551  # inferred shape and type info
552  logging.warning(str(e))
553  if os.getenv('CAFFE2_ASSERT_SHAPEINFERENCE') == '1':
554  raise e
555 
557  self,
558  device_option,
559  op,
560  inputs,
561  reference,
562  input_device_options=None,
563  threshold=1e-4,
564  output_to_grad=None,
565  grad_reference=None,
566  atol=None,
567  outputs_to_check=None,
568  ):
569  """
570  This runs the reference Python function implementation
571  (effectively calling `reference(*inputs)`, and compares that
572  to the output of output, with an absolute/relative tolerance
573  given by the `threshold` parameter.
574 
575  Useful for checking the implementation matches the Python
576  (typically NumPy) implementation of the same functionality.
577 
578  Usage example:
579 
580  @given(X=hu.tensor(), inplace=st.booleans(), **hu.gcs)
581  def test_softsign(self, X, inplace, gc, dc):
582  op = core.CreateOperator(
583  "Softsign", ["X"], ["X" if inplace else "Y"])
584 
585  def softsign(X):
586  return (X / (1 + np.abs(X)),)
587 
588  self.assertReferenceChecks(gc, op, [X], softsign)
589  """
590  op = copy.deepcopy(op)
591  op.device_option.CopyFrom(device_option)
592 
593  with temp_workspace():
594  if (len(op.input) > len(inputs)):
595  raise ValueError(
596  'must supply an input for each input on the op: %s vs %s' %
597  (op.input, inputs))
598  _input_device_options = input_device_options or \
599  core.InferOpBlobDevicesAsDict(op)[0]
600  for (n, b) in zip(op.input, inputs):
601  workspace.FeedBlob(
602  n,
603  b,
604  device_option=_input_device_options.get(n, device_option)
605  )
606  net = core.Net("opnet")
607  net.Proto().op.extend([op])
608  test_shape_inference = False
609  try:
610  (shapes, types) = workspace.InferShapesAndTypes([net])
611  test_shape_inference = True
612  except RuntimeError as e:
613  # Temporarily catch runtime errors when inferring shape
614  # and type info
615  logging.warning(str(e))
616  if os.getenv('CAFFE2_ASSERT_SHAPEINFERENCE') == '1':
617  raise e
618  workspace.RunNetOnce(net)
619  reference_outputs = reference(*inputs)
620  if not (isinstance(reference_outputs, tuple) or
621  isinstance(reference_outputs, list)):
622  raise RuntimeError(
623  "You are providing a wrong reference implementation. A "
624  "proper one should return a tuple/list of numpy arrays.")
625  if not outputs_to_check:
626  self.assertEqual(len(reference_outputs), len(op.output))
627  outputs_to_check = list(range(len(op.output)))
628  outs = []
629  for (output_index, ref) in zip(outputs_to_check, reference_outputs):
630  output_blob_name = op.output[output_index]
631  output = workspace.FetchBlob(output_blob_name)
632  if output.dtype.kind in ('S', 'O'):
633  np.testing.assert_array_equal(output, ref)
634  else:
635  if atol is None:
636  atol = threshold
637  np.testing.assert_allclose(
638  output, ref, atol=atol, rtol=threshold,
639  err_msg=(
640  'Output {0} is not matching the reference'.format(
641  output_blob_name,
642  )),
643  )
644  if test_shape_inference:
646  output_blob_name, shapes, types, output)
647  outs.append(output)
648  if grad_reference is not None:
649  assert output_to_grad is not None, \
650  "If grad_reference is set," \
651  "output_to_grad has to be set as well"
652 
653  with core.DeviceScope(device_option):
655  op, inputs, reference_outputs,
656  output_to_grad, grad_reference,
657  threshold=threshold)
658 
659  return outs
660 
661  def assertValidationChecks(
662  self,
663  device_option,
664  op,
665  inputs,
666  validator,
667  input_device_options=None,
668  as_kwargs=True,
669  init_net=None,
670  ):
671  if as_kwargs:
672  assert len(set(list(op.input) + list(op.output))) == \
673  len(op.input) + len(op.output), \
674  "in-place ops are not supported in as_kwargs mode"
675  op = copy.deepcopy(op)
676  op.device_option.CopyFrom(device_option)
677 
678  with temp_workspace():
679  _input_device_options = input_device_options or \
680  core.InferOpBlobDevicesAsDict(op)[0]
681  for (n, b) in zip(op.input, inputs):
682  workspace.FeedBlob(
683  n,
684  b,
685  device_option=_input_device_options.get(n, device_option)
686  )
687  if init_net:
688  workspace.RunNetOnce(init_net)
689  workspace.RunOperatorOnce(op)
690  outputs = [workspace.FetchBlob(n) for n in op.output]
691  if as_kwargs:
692  validator(**dict(zip(
693  list(op.input) + list(op.output), inputs + outputs)))
694  else:
695  validator(inputs=inputs, outputs=outputs)
696 
697  def assertRunOpRaises(
698  self,
699  device_option,
700  op,
701  inputs,
702  input_device_options=None,
703  exception=(Exception,),
704  regexp=None,
705  ):
706  op = copy.deepcopy(op)
707  op.device_option.CopyFrom(device_option)
708 
709  with temp_workspace():
710  _input_device_options = input_device_options or \
711  core.InferOpBlobDevicesAsDict(op)[0]
712  for (n, b) in zip(op.input, inputs):
713  workspace.FeedBlob(
714  n,
715  b,
716  device_option=_input_device_options.get(n, device_option)
717  )
718  if regexp is None:
719  self.assertRaises(exception, workspace.RunOperatorOnce, op)
720  else:
721  six.assertRaisesRegex(
722  self, exception, regexp, workspace.RunOperatorOnce, op)
def _assertInferTensorChecks(self, name, shapes, types, output)
def _assertGradReferenceChecks(self, op, inputs, ref_outputs, output_to_grad, grad_reference, threshold=1e-4)
def assertReferenceChecks(self, device_option, op, inputs, reference, input_device_options=None, threshold=1e-4, output_to_grad=None, grad_reference=None, atol=None, outputs_to_check=None)
def assertDeviceChecks(self, device_options, op, inputs, outputs_to_check, input_device_options=None, threshold=0.01)
def assertGradientChecks(self, device_option, op, inputs, outputs_to_check, outputs_with_grads, grad_ops=None, threshold=0.005, stepsize=0.05, input_device_options=None)