7 from future.utils
import viewitems
11 """A device checker in Python to check consistency across multiple devices. 13 This is not the most efficient way to check devices, as the Python interface 14 will involve a lot of copies back and forth operations. Use at your own risk. 17 def __init__(self, threshold, device_options):
22 input_device_options=
None):
23 """Checks the operator with different device implementations. 26 op: the operator to be checked. 27 inputs: the input data in numpy arrays. 28 outputs_to_check: the outputs to check between devices. 29 input_device_options: a mapping from input name to a device to use 30 (instead of self._device_options) 32 boolean: True if it passes, False if it does not pass. 34 op = copy.deepcopy(op)
36 old_ws_name = workspace.CurrentWorkspace()
38 workspace.SwitchWorkspace(
"_device_check_",
True)
40 op.device_option.CopyFrom(device_option)
41 _input_device_options = input_device_options
or \
42 InferOpBlobDevicesAsDict(op)[0]
43 print(_input_device_options)
44 for i, arr
in enumerate(inputs):
46 op.input[i], np.array(arr),
47 _input_device_options.get(op.input[i], device_option)
49 workspace.RunOperatorOnce(op)
51 [workspace.FetchBlob(op.output[idx])
52 for idx
in outputs_to_check])
54 workspace.ResetWorkspace()
58 for j
in range(len(outputs_to_check)):
61 if not np.allclose(x, y,
63 print(
'Failure in checking device option {}' 64 ' and output {}. The outputs are:' 65 .format(i, op.output[outputs_to_check[j]]))
68 print(np.max(np.abs(x - y)))
73 workspace.SwitchWorkspace(old_ws_name)
76 def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None):
77 """Checks a network by inspecting all of its intermediate results, and 84 old_ws_name = workspace.CurrentWorkspace()
86 if blobs_to_check
is None:
87 blobs_to_check = sum([list(op.output)
for op
in net.op], [])
88 blobs_to_check = [b
for b
in blobs_to_check
if b
not in ignore]
89 workspace.SwitchWorkspace(
"_device_check_",
True)
91 for name, arr
in viewitems(inputs):
93 workspace.FeedBlob(name, arr, device_option)
95 op.device_option.CopyFrom(device_option)
96 workspace.RunNetOnce(net)
98 [workspace.FetchBlob(name)
for name
in blobs_to_check]
102 for i
in range(1, len(results)):
103 for j
in range(len(blobs_to_check)):
106 if not np.allclose(x, y,
108 print(
'Failure in checking device option {}' 109 ' and output {}. The outputs are:' 110 .format(i, blobs_to_check[j]))
113 print(np.max(np.abs(x - y)))
119 workspace.SwitchWorkspace(old_ws_name)
def CheckSimple(self, op, inputs, outputs_to_check, input_device_options=None)
def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None)