Caffe2 - Python API
A deep-learning, cross-platform ML framework
device_checker.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package device_checker
17 # Module caffe2.python.device_checker
18 import numpy as np
19 import copy
20 from caffe2.python import workspace
21 from future.utils import viewitems
22 
23 
class DeviceChecker(object):
    """A device checker in Python to check consistency across multiple devices.

    Every check runs the same computation once per entry of
    ``device_options`` inside a scratch workspace named ``_device_check_``.
    It then compares each device's results against those produced with the
    first device option, which serves as the reference.

    This is not the most efficient way to check devices, as the Python interface
    will involve a lot of copies back and forth operations. Use at your own risk.
    """

    def __init__(self, threshold, device_options):
        """Creates a checker.

        Inputs:
          threshold: tolerance used as both ``atol`` and ``rtol`` for
            ``np.allclose`` when comparing results across devices.
          device_options: sequence of device options to run on; the first one
            is the reference all others are compared against.
        """
        self._threshold = threshold
        self._device_options = device_options

    def CheckSimple(self, op, inputs, outputs_to_check,
                    input_device_options=None):
        """Checks the operator with different device implementations.

        Inputs:
          op: the operator to be checked. It is deep-copied, so the caller's
            operator is left unmodified.
          inputs: the input data in numpy arrays, fed positionally to
            ``op.input``.
          outputs_to_check: indices into ``op.output`` of the outputs to
            check between devices.
          input_device_options: a mapping from input name to a device to use
            (instead of self._device_options)
        Outputs:
          boolean: True if it passes, False if it does not pass.
        """
        op = copy.deepcopy(op)
        input_device_options = input_device_options or {}
        output_names = [op.output[idx] for idx in outputs_to_check]
        old_ws_name = workspace.CurrentWorkspace()
        results = []
        # Entering the checker workspace; the finally clause guarantees the
        # caller's workspace is restored even if an operator run fails.
        workspace.SwitchWorkspace("_device_check_", True)
        try:
            for device_option in self._device_options:
                # Start every device from an empty workspace so no blob from
                # a previous run (or a previous call) can leak into this one.
                workspace.ResetWorkspace()
                for input_idx, arr in enumerate(inputs):
                    input_name = op.input[input_idx]
                    workspace.FeedBlob(
                        input_name, np.array(arr),
                        input_device_options.get(input_name, device_option))
                op.device_option.CopyFrom(device_option)
                workspace.RunOperatorOnce(op)
                results.append(
                    [workspace.FetchBlob(name) for name in output_names])
        finally:
            # Leave the scratch workspace empty and go back to the caller's.
            workspace.ResetWorkspace()
            workspace.SwitchWorkspace(old_ws_name)
        # After running on all devices, check correctness
        return self._CompareResults(results, output_names)

    def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None):
        """Checks a network by inspecting all of its intermediate results, and
        see if things match.

        Inputs:
          net: the network to be checked (its ``op`` list is read). It is
            deep-copied, so the caller's net keeps its original device
            options.
          inputs: optional mapping from blob name to value, fed on every
            device before the net runs.
          blobs_to_check: names of blobs to compare; defaults to every output
            of every op in the net.
          ignore: optional collection of blob names to exclude from checking.
        Outputs:
          boolean: True if it passes, False if it does not pass.
        """
        # Per-device option rewrites below must not leak into the caller's
        # net (consistent with CheckSimple, which copies its op).
        net = copy.deepcopy(net)
        if inputs is None:
            inputs = {}
        if ignore is None:
            ignore = set()
        if blobs_to_check is None:
            blobs_to_check = [name for op in net.op for name in op.output]
        blobs_to_check = [b for b in blobs_to_check if b not in ignore]
        old_ws_name = workspace.CurrentWorkspace()
        results = []
        workspace.SwitchWorkspace("_device_check_", True)
        try:
            for device_option in self._device_options:
                # Fresh workspace per device: stale blobs from another device
                # would otherwise be visible to this run.
                workspace.ResetWorkspace()
                for name, arr in viewitems(inputs):
                    workspace.FeedBlob(name, arr, device_option)
                for op in net.op:
                    op.device_option.CopyFrom(device_option)
                workspace.RunNetOnce(net)
                results.append(
                    [workspace.FetchBlob(name) for name in blobs_to_check])
        finally:
            workspace.ResetWorkspace()
            workspace.SwitchWorkspace(old_ws_name)
        # After running on all devices, check correctness
        return self._CompareResults(results, blobs_to_check)

    def _CompareResults(self, results, names):
        """Compares every device's results against the first device's.

        Inputs:
          results: one list per device option, each holding the fetched values
            in the same order as ``names``.
          names: blob names matching each result position, used in failure
            messages.
        Outputs:
          boolean: True if every value has the reference's shape and matches
            it within the threshold; False otherwise (details are printed).
        """
        success = True
        if not results:
            return success
        reference = results[0]
        for device_idx, device_results in enumerate(results[1:], start=1):
            for name, x, y in zip(names, device_results, reference):
                x = np.asarray(x)
                y = np.asarray(y)
                # np.allclose broadcasts, so mismatched shapes could otherwise
                # pass silently (or crash); report them as failures instead.
                if x.shape != y.shape:
                    print('Failure in checking device option {}'
                          ' and output {}. The output shapes differ: {} vs {}'
                          .format(device_idx, name, x.shape, y.shape))
                    success = False
                    continue
                if not np.allclose(x, y,
                                   atol=self._threshold, rtol=self._threshold):
                    print('Failure in checking device option {}'
                          ' and output {}. The outputs are:'
                          .format(device_idx, name))
                    print(x.flatten())
                    print(y.flatten())
                    print(np.max(np.abs(x - y)))
                    success = False
        return success
def CheckSimple(self, op, inputs, outputs_to_check, input_device_options=None)
def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None)