Caffe2 - Python API
A deep-learning, cross-platform ML framework
serialized_test_util.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import inspect
import os
import shutil
import sys
import tempfile
import threading
from zipfile import ZipFile

import hypothesis as hy
import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import gradient_checker
from caffe2.python import hypothesis_test_util as hu
from caffe2.python.serialized_test import coverage
20 
# Name of the subdirectory (below the data directory) holding the per-test
# zip archives.
operator_test_type = 'operator_test'

# Directory that contains this module; the default data directory lives
# directly beneath it.
TOP_DIR = os.path.dirname(os.path.realpath(__file__))
DATA_SUFFIX = 'data'
DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)

# Thread-local holder for the command-line options that testWithArgs()
# stores (should_generate_output, output_dir, disable_serialized_check,
# disable_gen_coverage); readers fall back to defaults via getattr.
_output_context = threading.local()
26 
27 
def given(*given_args, **given_kwargs):
    """Serialized-test replacement for ``hypothesis.given``.

    The returned decorator turns a test method into one that runs twice:

    1. a single example with a fixed seed (``hy.seed(0)``,
       ``max_examples=1``) while ``self.should_serialize`` is True, so that
       ``assertSerializedOperatorChecks`` writes or compares serialized data
       for a reproducible example;
    2. the normal hypothesis search, with ``should_serialize`` False.

    Args:
        *given_args: positional strategies forwarded to ``hypothesis.given``.
        **given_kwargs: keyword strategies forwarded to ``hypothesis.given``.

    Returns:
        A decorator that takes the test function ``f`` and returns the
        combined test method.
    """
    def wrapper(f):
        hyp_func = hy.given(*given_args, **given_kwargs)(f)
        fixed_seed_func = hy.seed(0)(hy.settings(max_examples=1)(hy.given(
            *given_args, **given_kwargs)(f)))

        def func(self, *args, **kwargs):
            self.should_serialize = True
            try:
                fixed_seed_func(self, *args, **kwargs)
            finally:
                # Reset even when the fixed-seed run fails, so the flag is
                # never left set on the instance.
                self.should_serialize = False
            hyp_func(self, *args, **kwargs)

        # Report the decorated test's own name and docstring (unittest uses
        # the docstring for its verbose description) instead of "func".
        # Deliberately not functools.wraps: that would also set __wrapped__
        # and expose f's undecorated signature to introspection.
        func.__name__ = f.__name__
        func.__doc__ = f.__doc__
        return func
    return wrapper
41 
42 
def _getGradientOrNone(op_proto):
    """Return the gradient operators generated for ``op_proto``.

    Despite the name, the fallback is an empty list rather than ``None``, so
    callers can iterate over (or take ``len`` of) the result unconditionally.
    Any exception raised by the gradient lookup is treated as "no gradient".
    """
    try:
        grad_ops, _unused = gradient_checker.getGradientForOp(op_proto)
    except Exception:
        return []
    else:
        return grad_ops
49 
50 
# necessary to support converting jagged lists into numpy arrays
def _transformList(l):
    """Pack a sequence of arrays into a 1-D numpy object array.

    ``np.array(l)`` would try to stack equally-shaped arrays into a single
    higher-rank array (and, for jagged input, raises or warns depending on
    the NumPy version). Assigning element by element keeps every entry as
    its own object.

    Args:
        l: sequence of array-likes, possibly of differing shapes.

    Returns:
        numpy.ndarray with dtype ``object`` and shape ``(len(l),)``.
    """
    # ``np.object`` was merely an alias of the builtin ``object``; it was
    # deprecated in NumPy 1.20 and removed in 1.24, so use the builtin.
    ret = np.empty(len(l), dtype=object)
    for i, arr in enumerate(l):
        ret[i] = arr
    return ret
57 
58 
def _prepare_dir(path):
    """Leave ``path`` as a freshly created, empty directory.

    Any existing tree at ``path`` is deleted first; missing parent
    directories are created.
    """
    already_there = os.path.exists(path)
    if already_there:
        shutil.rmtree(path)
    os.makedirs(path)
63 
64 
class SerializedTestCase(hu.HypothesisTestCase):
    """Hypothesis test case whose operator checks can be serialized.

    When a test method is decorated with this module's ``given``, its first
    (fixed-seed) run has ``should_serialize`` set. ``assertReferenceChecks``
    then either writes the operator, its gradient operators and the
    input/output tensors to a zip archive (``--generate-serialized``), or
    compares the current run against the archive previously written for the
    same test.
    """

    # Set to True by ``given`` only while the fixed-seed example runs.
    should_serialize = False

    def get_output_dir(self):
        """Return the ``operator_test`` directory that holds the archives.

        Uses ``<output_dir>/operator_test`` (``output_dir`` defaults to
        ``DATA_DIR``) when that directory exists; otherwise falls back to a
        path built from the current working directory and this module's
        package path.
        """
        output_dir_arg = getattr(_output_context, 'output_dir', DATA_DIR)
        output_dir = os.path.join(output_dir_arg, operator_test_type)

        if os.path.exists(output_dir):
            return output_dir

        # fall back to pwd
        cwd = os.getcwd()
        serialized_util_module_components = __name__.split('.')
        serialized_util_module_components.pop()
        serialized_dir = '/'.join(serialized_util_module_components)
        output_dir_fallback = os.path.join(cwd, serialized_dir, DATA_SUFFIX)
        output_dir = os.path.join(output_dir_fallback, operator_test_type)

        return output_dir

    def get_output_filename(self):
        """Return ``<test file stem>.<test method name>`` (archive stem)."""
        class_path = inspect.getfile(self.__class__)
        file_name_components = os.path.basename(class_path).split('.')
        test_file = file_name_components[0]

        function_name_components = self.id().split('.')
        test_function = function_name_components[-1]

        return test_file + '.' + test_function

    def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
        """Write ``op``, its gradient ops and the tensors to a zip archive.

        The archive ``<output_dir>/<test_name>.zip`` contains ``op.pb``,
        ``inout.npz`` (``inputs``, ``outputs``, ``device_type``) and one
        ``grad_<i>.pb`` per gradient operator. Files are staged in a
        directory next to the archive, which is removed afterwards.
        """
        output_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        full_dir = os.path.join(output_dir, test_name)
        _prepare_dir(full_dir)

        try:
            inputs = _transformList(inputs)
            outputs = _transformList(outputs)
            device_type = int(device_option.device_type)

            op_path = os.path.join(full_dir, 'op.pb')
            grad_paths = []
            inout_path = os.path.join(full_dir, 'inout')

            with open(op_path, 'wb') as f:
                f.write(op.SerializeToString())
            for (i, grad) in enumerate(grad_ops):
                grad_path = os.path.join(full_dir, 'grad_{}.pb'.format(i))
                grad_paths.append(grad_path)
                with open(grad_path, 'wb') as f:
                    f.write(grad.SerializeToString())

            # savez_compressed appends '.npz' to inout_path.
            np.savez_compressed(
                inout_path,
                inputs=inputs,
                outputs=outputs,
                device_type=device_type)

            zip_path = os.path.join(output_dir, test_name + '.zip')
            with ZipFile(zip_path, 'w') as z:
                z.write(op_path, 'op.pb')
                z.write(inout_path + '.npz', 'inout.npz')
                for path in grad_paths:
                    z.write(path, os.path.basename(path))
        finally:
            # Remove the staging directory even if serialization fails.
            shutil.rmtree(full_dir)

    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Compare the current run against this test's stored archive.

        Args:
            inputs: input tensors of the current run.
            outputs: output tensors of the current run.
            grad_ops: gradient operators of the current op.
            atol: absolute tolerance for ``np.testing.assert_allclose``.
            rtol: relative tolerance for ``np.testing.assert_allclose``.

        If the stored inputs differ from ``inputs`` (different count, or any
        element not equal), the stored operator is re-run on the stored
        inputs and those outputs and gradient operators are compared instead
        of the ones passed in.

        Raises:
            AssertionError: if outputs or gradient operators do not match.
        """

        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        temp_dir = tempfile.mkdtemp()
        # try/finally so a failing assertion does not leak temp_dir.
        try:
            with ZipFile(os.path.join(source_dir, test_name + '.zip')) as z:
                z.extractall(temp_dir)

            op_path = os.path.join(temp_dir, 'op.pb')
            inout_path = os.path.join(temp_dir, 'inout.npz')

            # load serialized input and output. inputs/outputs are object
            # arrays, which NumPy >= 1.16.3 only unpickles with
            # allow_pickle=True. The archives are checked-in fixtures written
            # by serialize_test; never point this at untrusted files.
            # encoding='bytes' lets Python 3 read Python 2 pickles.
            with np.load(inout_path, encoding='bytes',
                         allow_pickle=True) as loaded:
                loaded_inputs = loaded['inputs'].tolist()
                loaded_outputs = loaded['outputs'].tolist()
                device_type = int(loaded['device_type'])

            # zip() alone would silently ignore extra or missing inputs.
            inputs_equal = len(inputs) == len(loaded_inputs) and all(
                np.array_equal(x, y) for x, y in zip(inputs, loaded_inputs))

            # if inputs are not the same, run serialized input through
            # serialized op
            if not inputs_equal:
                with open(op_path, 'rb') as f:
                    op_proto = parse_proto(f.read())
                device_option = caffe2_pb2.DeviceOption(
                    device_type=device_type)
                outputs = hu.runOpOnInput(
                    device_option, op_proto, loaded_inputs)
                grad_ops = _getGradientOrNone(op_proto)

            # assert outputs are equal
            for x, y in zip(outputs, loaded_outputs):
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

            # assert gradient ops are equal
            for i, grad_op in enumerate(grad_ops):
                grad_path = os.path.join(temp_dir, 'grad_{}.pb'.format(i))
                with open(grad_path, 'rb') as f:
                    grad_proto = parse_proto(f.read())
                self.assertEqual(grad_proto, grad_op)
        finally:
            shutil.rmtree(temp_dir)

    def assertSerializedOperatorChecks(
        self,
        inputs,
        outputs,
        gradient_operator,
        op,
        device_option,
        atol=1e-7,
        rtol=1e-7,
    ):
        """Serialize or compare this run, but only for the fixed-seed example.

        Does nothing unless ``should_serialize`` is set. With
        ``should_generate_output`` the run is written via ``serialize_test``
        (and the coverage file regenerated unless disabled); otherwise it is
        checked with ``compare_test``.
        """
        if self.should_serialize:
            if getattr(_output_context, 'should_generate_output', False):
                self.serialize_test(
                    inputs, outputs, gradient_operator, op, device_option)
                if not getattr(_output_context, 'disable_gen_coverage', False):
                    coverage.gen_serialized_test_coverage(
                        self.get_output_dir(), TOP_DIR)
            else:
                self.compare_test(
                    inputs, outputs, gradient_operator, atol, rtol)

    def assertReferenceChecks(
        self,
        device_option,
        op,
        inputs,
        reference,
        input_device_options=None,
        threshold=1e-4,
        output_to_grad=None,
        grad_reference=None,
        atol=None,
        outputs_to_check=None,
    ):
        """Run the base reference check, then the serialized check.

        The serialized check uses ``threshold`` as ``rtol`` and as ``atol``
        when ``atol`` is not given. It is skipped entirely when
        ``disable_serialized_check`` is set on ``_output_context``.
        """
        outs = super(SerializedTestCase, self).assertReferenceChecks(
            device_option,
            op,
            inputs,
            reference,
            input_device_options,
            threshold,
            output_to_grad,
            grad_reference,
            atol,
            outputs_to_check,
        )
        if not getattr(_output_context, 'disable_serialized_check', False):
            grad_ops = _getGradientOrNone(op)
            rtol = threshold
            if atol is None:
                atol = threshold
            self.assertSerializedOperatorChecks(
                inputs,
                outs,
                grad_ops,
                op,
                device_option,
                atol,
                rtol,
            )
247  )
248 
249 
def testWithArgs():
    """Run the calling module's unit tests with serialized-test options.

    Options parsed here:
      -G/--generate-serialized         write archives instead of comparing
      -O/--output DIR                  base data directory (default DATA_DIR)
      -D/--disable-serialized_check    skip the serialized check entirely
      -C/--disable-gen-coverage        skip coverage markdown generation

    The values are stored on ``_output_context``, where SerializedTestCase
    reads them. The remaining positional arguments replace ``sys.argv[1:]``
    before ``unittest.main()`` is invoked.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('-G', '--generate-serialized',
        dest='generate', action='store_true',
        help='generate output files (default=false, compares to current files)')
    add('-O', '--output',
        default=DATA_DIR,
        help='output directory (default: %(default)s)')
    add('-D', '--disable-serialized_check',
        dest='disable', action='store_true',
        help='disable checking serialized tests')
    add('-C', '--disable-gen-coverage',
        dest='disable_coverage', action='store_true',
        help='disable generating coverage markdown file')
    add('unittest_args', nargs='*')

    parsed = parser.parse_args()
    sys.argv[1:] = parsed.unittest_args

    ctx = _output_context
    ctx.should_generate_output = parsed.generate
    ctx.output_dir = parsed.output
    ctx.disable_serialized_check = parsed.disable
    ctx.disable_gen_coverage = parsed.disable_coverage

    import unittest
    unittest.main()
Module caffe2.python.serialized_test.serialized_test_util.
def assertSerializedOperatorChecks(self, inputs, outputs, gradient_operator, op, device_option, atol=1e-7, rtol=1e-7)
def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7)
def serialize_test(self, inputs, outputs, grad_ops, op, device_option)