Caffe2 - Python API
A deep learning, cross-platform ML framework
layer_test_util.py
1 ## @package layer_test_util
2 # Module caffe2.python.layer_test_util
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from collections import namedtuple
9 
10 from caffe2.python import (
11  core,
12  layer_model_instantiator,
13  layer_model_helper,
14  schema,
15  test_util,
16  workspace,
17  utils,
18 )
19 from caffe2.proto import caffe2_pb2
20 import numpy as np
21 
22 
class OpSpec(namedtuple("OpSpec", "type input output arg")):
    """Expected description of one operator, as checked by assertNetContainOps.

    Fields:
        type: expected operator type name.
        input / output: expected blob-name lists; None skips the check, and a
            None entry inside a list skips that single position.
        arg: optional mapping of argument name to expected value; None skips
            the argument check.
    """

    def __new__(cls, op_type, op_input, op_output, op_arg=None):
        # Pass each value by field name so the slot mapping is explicit.
        return super(OpSpec, cls).__new__(
            cls,
            type=op_type,
            input=op_input,
            output=op_output,
            arg=op_arg,
        )
28 
29 
31 
32  def setUp(self):
33  super(LayersTestCase, self).setUp()
34  self.setup_example()
35 
36  def setup_example(self):
37  """
38  This is undocumented feature in hypothesis,
39  https://github.com/HypothesisWorks/hypothesis-python/issues/59
40  """
41  workspace.ResetWorkspace()
42  self.reset_model()
43 
44  def reset_model(self, input_feature_schema=None, trainer_extra_schema=None):
45  input_feature_schema = input_feature_schema or schema.Struct(
46  ('float_features', schema.Scalar((np.float32, (32,)))),
47  )
48  trainer_extra_schema = trainer_extra_schema or schema.Struct()
50  'test_model',
51  input_feature_schema=input_feature_schema,
52  trainer_extra_schema=trainer_extra_schema)
53 
54  def new_record(self, schema_obj):
55  return schema.NewRecord(self.model.net, schema_obj)
56 
57  def get_training_nets(self, add_constants=False):
58  """
59  We don't use
60  layer_model_instantiator.generate_training_nets_forward_only()
61  here because it includes initialization of global constants, which make
62  testing tricky
63  """
64  train_net = core.Net('train_net')
65  if add_constants:
66  train_init_net = self.model.create_init_net('train_init_net')
67  else:
68  train_init_net = core.Net('train_init_net')
69  for layer in self.model.layers:
70  layer.add_operators(train_net, train_init_net)
71  return train_init_net, train_net
72 
73  def get_eval_net(self):
74  return layer_model_instantiator.generate_eval_net(self.model)
75 
76  def get_predict_net(self):
77  return layer_model_instantiator.generate_predict_net(self.model)
78 
79  def run_train_net(self):
80  self.model.output_schema = schema.Struct()
81  train_init_net, train_net = \
82  layer_model_instantiator.generate_training_nets(self.model)
83  workspace.RunNetOnce(train_init_net)
84  workspace.RunNetOnce(train_net)
85 
86  def run_train_net_forward_only(self, num_iter=1):
87  self.model.output_schema = schema.Struct()
88  train_init_net, train_net = \
89  layer_model_instantiator.generate_training_nets_forward_only(
90  self.model)
91  workspace.RunNetOnce(train_init_net)
92  assert num_iter > 0, 'num_iter must be larger than 0'
93  workspace.CreateNet(train_net)
94  workspace.RunNet(train_net.Proto().name, num_iter=num_iter)
95 
96  def assertBlobsEqual(self, spec_blobs, op_blobs):
97  """
98  spec_blobs can either be None or a list of blob names. If it's None,
99  then no assertion is performed. The elements of the list can be None,
100  in that case, it means that position will not be checked.
101  """
102  if spec_blobs is None:
103  return
104  self.assertEqual(len(spec_blobs), len(op_blobs))
105  for spec_blob, op_blob in zip(spec_blobs, op_blobs):
106  if spec_blob is None:
107  continue
108  self.assertEqual(spec_blob, op_blob)
109 
110  def assertArgsEqual(self, spec_args, op_args):
111  self.assertEqual(len(spec_args), len(op_args))
112  keys = [a.name for a in op_args]
113 
114  def parse_args(args):
115  operator = caffe2_pb2.OperatorDef()
116  # Generate the expected value in the same order
117  for k in keys:
118  v = args[k]
119  arg = utils.MakeArgument(k, v)
120  operator.arg.add().CopyFrom(arg)
121  return operator.arg
122 
123  self.assertEqual(parse_args(spec_args), op_args)
124 
125  def assertNetContainOps(self, net, op_specs):
126  """
127  Given a net and a list of OpSpec's, check that the net match the spec
128  """
129  ops = net.Proto().op
130  self.assertEqual(len(op_specs), len(ops))
131  for op, op_spec in zip(ops, op_specs):
132  self.assertEqual(op_spec.type, op.type)
133  self.assertBlobsEqual(op_spec.input, op.input)
134  self.assertBlobsEqual(op_spec.output, op.output)
135  if op_spec.arg is not None:
136  self.assertArgsEqual(op_spec.arg, op.arg)
137  return ops
def assertBlobsEqual(self, spec_blobs, op_blobs)
def reset_model(self, input_feature_schema=None, trainer_extra_schema=None)
def get_training_nets(self, add_constants=False)
def assertArgsEqual(self, spec_args, op_args)