1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
7 from caffe2.proto
import caffe2_pb2
11 import hypothesis
as hy
19 from zipfile
import ZipFile
21 operator_test_type =
'operator_test' 22 TOP_DIR = os.path.dirname(os.path.realpath(__file__))
24 DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)
25 _output_context = threading.local()
28 def given(*given_args, **given_kwargs):
30 hyp_func = hy.given(*given_args, **given_kwargs)(f)
31 fixed_seed_func = hy.seed(0)(hy.settings(max_examples=1)(hy.given(
32 *given_args, **given_kwargs)(f)))
34 def func(self, *args, **kwargs):
35 self.should_serialize =
True 36 fixed_seed_func(self, *args, **kwargs)
37 self.should_serialize =
False 38 hyp_func(self, *args, **kwargs)
43 def _getGradientOrNone(op_proto):
45 grad_ops, _ = gradient_checker.getGradientForOp(op_proto)
52 def _transformList(l):
53 ret = np.empty(len(l), dtype=np.object)
54 for (i, arr)
in enumerate(l):
59 def _prepare_dir(path):
60 if os.path.exists(path):
67 should_serialize =
False 69 def get_output_dir(self):
70 output_dir_arg = getattr(_output_context,
'output_dir', DATA_DIR)
71 output_dir = os.path.join(
72 output_dir_arg, operator_test_type)
74 if os.path.exists(output_dir):
79 serialized_util_module_components = __name__.split(
'.')
80 serialized_util_module_components.pop()
81 serialized_dir =
'/'.join(serialized_util_module_components)
82 output_dir_fallback = os.path.join(cwd, serialized_dir, DATA_SUFFIX)
83 output_dir = os.path.join(
89 def get_output_filename(self):
90 class_path = inspect.getfile(self.__class__)
91 file_name_components = os.path.basename(class_path).
split(
'.')
92 test_file = file_name_components[0]
94 function_name_components = self.id().
split(
'.')
95 test_function = function_name_components[-1]
97 return test_file +
'.' + test_function
99 def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
102 full_dir = os.path.join(output_dir, test_name)
103 _prepare_dir(full_dir)
105 inputs = _transformList(inputs)
106 outputs = _transformList(outputs)
107 device_type = int(device_option.device_type)
109 op_path = os.path.join(full_dir,
'op.pb')
111 inout_path = os.path.join(full_dir,
'inout')
113 with open(op_path,
'wb')
as f:
114 f.write(op.SerializeToString())
115 for (i, grad)
in enumerate(grad_ops):
116 grad_path = os.path.join(full_dir,
'grad_{}.pb'.format(i))
117 grad_paths.append(grad_path)
118 with open(grad_path,
'wb')
as f:
119 f.write(grad.SerializeToString())
125 device_type=device_type)
127 with ZipFile(os.path.join(output_dir, test_name +
'.zip'),
'w')
as z:
128 z.write(op_path,
'op.pb')
129 z.write(inout_path +
'.npz',
'inout.npz')
130 for path
in grad_paths:
131 z.write(path, os.path.basename(path))
133 shutil.rmtree(full_dir)
135 def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
138 proto = caffe2_pb2.OperatorDef()
139 proto.ParseFromString(x)
144 temp_dir = tempfile.mkdtemp()
145 with ZipFile(os.path.join(source_dir, test_name +
'.zip'))
as z:
146 z.extractall(temp_dir)
148 op_path = os.path.join(temp_dir,
'op.pb')
149 inout_path = os.path.join(temp_dir,
'inout.npz')
152 loaded = np.load(inout_path, encoding=
'bytes')
153 loaded_inputs = loaded[
'inputs'].tolist()
155 for (x, y)
in zip(inputs, loaded_inputs):
156 if not np.array_equal(x, y):
158 loaded_outputs = loaded[
'outputs'].tolist()
163 with open(op_path,
'rb')
as f:
166 op_proto = parse_proto(loaded_op)
167 device_type = loaded[
'device_type']
168 device_option = caffe2_pb2.DeviceOption(
169 device_type=int(device_type))
171 outputs = hu.runOpOnInput(device_option, op_proto, loaded_inputs)
172 grad_ops = _getGradientOrNone(op_proto)
175 for (x, y)
in zip(outputs, loaded_outputs):
176 np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
179 for i
in range(len(grad_ops)):
180 grad_path = os.path.join(temp_dir,
'grad_{}.pb'.format(i))
181 with open(grad_path,
'rb')
as f:
182 loaded_grad = f.read()
183 grad_proto = parse_proto(loaded_grad)
184 self.assertTrue(grad_proto == grad_ops[i])
186 shutil.rmtree(temp_dir)
188 def assertSerializedOperatorChecks(
199 if getattr(_output_context,
'should_generate_output',
False):
201 inputs, outputs, gradient_operator, op, device_option)
202 if not getattr(_output_context,
'disable_gen_coverage',
False):
203 coverage.gen_serialized_test_coverage(
207 inputs, outputs, gradient_operator, atol, rtol)
209 def assertReferenceChecks(
215 input_device_options=
None,
220 outputs_to_check=
None,
222 outs = super(SerializedTestCase, self).assertReferenceChecks(
227 input_device_options,
234 if not getattr(_output_context,
'disable_serialized_check',
False):
235 grad_ops = _getGradientOrNone(op)
251 parser = argparse.ArgumentParser()
253 '-G',
'--generate-serialized', action=
'store_true', dest=
'generate',
254 help=
'generate output files (default=false, compares to current files)')
256 '-O',
'--output', default=DATA_DIR,
257 help=
'output directory (default: %(default)s)')
259 '-D',
'--disable-serialized_check', action=
'store_true', dest=
'disable',
260 help=
'disable checking serialized tests')
262 '-C',
'--disable-gen-coverage', action=
'store_true',
263 dest=
'disable_coverage',
264 help=
'disable generating coverage markdown file')
265 parser.add_argument(
'unittest_args', nargs=
'*')
266 args = parser.parse_args()
267 sys.argv[1:] = args.unittest_args
268 _output_context.__setattr__(
'should_generate_output', args.generate)
269 _output_context.__setattr__(
'output_dir', args.output)
270 _output_context.__setattr__(
'disable_serialized_check', args.disable)
271 _output_context.__setattr__(
'disable_gen_coverage', args.disable_coverage)
Module caffe2.python.layers.split.
def assertSerializedOperatorChecks(self, inputs, outputs, gradient_operator, op, device_option, atol=1e-7, rtol=1e-7)
def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7)
def serialize_test(self, inputs, outputs, grad_ops, op, device_option)
def get_output_filename(self)