Caffe2 - Python API
A deep learning, cross-platform ML framework
device_checker.py
1 ## @package device_checker
2 # Module caffe2.python.device_checker
3 import numpy as np
4 import copy
5 from caffe2.python import workspace
6 from caffe2.python.core import InferOpBlobDevicesAsDict
7 from future.utils import viewitems
8 
9 
class DeviceChecker(object):
    """A device checker in Python to check consistency across multiple devices.

    The same operator or net is run once per entry of ``device_options`` and
    the fetched outputs of every device are compared against those of the
    first device using ``np.allclose`` with ``threshold`` as both the
    absolute and the relative tolerance.

    This is not the most efficient way to check devices, as the Python
    interface will involve a lot of copies back and forth operations. Use at
    your own risk.
    """

    def __init__(self, threshold, device_options):
        """Stores the comparison tolerance and the devices to compare.

        Inputs:
          threshold: tolerance passed as both atol and rtol to np.allclose.
          device_options: a sequence of device options; the first one is the
              reference that every other one is compared against.
        """
        self._threshold = threshold
        self._device_options = device_options

    def _outputs_match(self, results, names):
        """Compares every device's outputs against the first device's.

        Inputs:
          results: one list of fetched arrays per device, in device order.
          names: the blob name of each output, used only in failure messages.
        Outputs:
          boolean: True if all outputs match within the threshold (or there
              is nothing to compare), False otherwise. Every mismatch is
              printed together with both arrays and the max absolute error.
        """
        success = True
        if not results:
            return success
        reference = results[0]
        for device_idx, device_outputs in enumerate(results[1:], 1):
            for name, x, y in zip(names, device_outputs, reference):
                if np.allclose(x, y,
                               atol=self._threshold, rtol=self._threshold):
                    continue
                print('Failure in checking device option {}'
                      ' and output {}. The outputs are:'
                      .format(device_idx, name))
                print(x.flatten())
                print(y.flatten())
                print(np.max(np.abs(x - y)))
                success = False
        return success

    def CheckSimple(self, op, inputs, outputs_to_check,
                    input_device_options=None):
        """Checks the operator with different device implementations.

        Inputs:
          op: the operator to be checked. It is deep-copied, so the caller's
              operator is not modified.
          inputs: the input data in numpy arrays.
          outputs_to_check: the outputs to check between devices, given as
              indices into op.output.
          input_device_options: a mapping from input name to a device to use
              (instead of self._device_options). When it is None or empty,
              the first element returned by InferOpBlobDevicesAsDict(op) is
              used as the mapping.
        Outputs:
          boolean: True if it passes, False if it does not pass.
        """
        op = copy.deepcopy(op)
        output_names = [op.output[idx] for idx in outputs_to_check]
        results = []
        # Entering the checker workspace
        old_ws_name = workspace.CurrentWorkspace()
        workspace.SwitchWorkspace("_device_check_", True)
        try:
            for device_option in self._device_options:
                op.device_option.CopyFrom(device_option)
                # Looked up per device: the inferred mapping is computed
                # from the op after its device option has been replaced.
                _input_device_options = input_device_options or \
                    InferOpBlobDevicesAsDict(op)[0]
                for input_idx, arr in enumerate(inputs):
                    input_name = op.input[input_idx]
                    workspace.FeedBlob(
                        input_name, np.array(arr),
                        _input_device_options.get(input_name, device_option)
                    )
                workspace.RunOperatorOnce(op)
                results.append(
                    [workspace.FetchBlob(name) for name in output_names])
                # Everything is done, reset the workspace.
                workspace.ResetWorkspace()
        finally:
            # Always hand the caller's workspace back, even when feeding,
            # running or fetching raised.
            workspace.SwitchWorkspace(old_ws_name)
        # After running on all devices, check correctness
        return self._outputs_match(results, output_names)

    def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None):
        """Checks a network by inspecting all of its intermediate results, and
        see if things match.

        Inputs:
          net: the network to be checked; its ops are read from net.op. It
              is deep-copied, so the caller's net is not modified.
          inputs: a mapping from blob name to the data to feed before each
              run. Defaults to no inputs.
          blobs_to_check: the names of the blobs to compare between devices.
              Defaults to every output of every op in the net.
          ignore: blob names to leave out of the comparison.
        Outputs:
          boolean: True if it passes, False if it does not pass.
        """
        if inputs is None:
            inputs = {}
        if ignore is None:
            ignore = set()
        # Work on a copy so that overriding each op's device option does not
        # leak into the caller's net (CheckSimple does the same for its op).
        net = copy.deepcopy(net)
        if blobs_to_check is None:
            blobs_to_check = [
                blob for op in net.op for blob in op.output]
        blobs_to_check = [b for b in blobs_to_check if b not in ignore]
        results = []
        old_ws_name = workspace.CurrentWorkspace()
        workspace.SwitchWorkspace("_device_check_", True)
        try:
            # Note: unlike CheckSimple, the workspace is not reset between
            # devices here.
            for device_option in self._device_options:
                for name, arr in viewitems(inputs):
                    workspace.FeedBlob(name, arr, device_option)
                for op in net.op:
                    op.device_option.CopyFrom(device_option)
                workspace.RunNetOnce(net)
                results.append(
                    [workspace.FetchBlob(name) for name in blobs_to_check]
                )
        finally:
            # Always hand the caller's workspace back, even on failure.
            workspace.SwitchWorkspace(old_ws_name)
        # After running on all devices, check correctness
        return self._outputs_match(results, blobs_to_check)
def CheckSimple(self, op, inputs, outputs_to_check, input_device_options=None)
def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None)