Caffe2 - Python API
A deep learning, cross-platform ML framework
layer_test_util.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package layer_test_util
17 # Module caffe2.python.layer_test_util
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from collections import namedtuple
24 
25 from caffe2.python import (
26  core,
27  layer_model_instantiator,
28  layer_model_helper,
29  schema,
30  test_util,
31  workspace,
32  utils,
33 )
34 from caffe2.proto import caffe2_pb2
35 import numpy as np
36 
37 
class OpSpec(namedtuple("OpSpec", "type input output arg")):
    """Expected description of a single operator in a net.

    Fields:
        type: expected operator type, compared with the operator's ``type``.
        input: list of expected input blob names, or None to skip the input
            check entirely; a None entry inside the list skips only that
            position.
        output: same as ``input``, but for output blobs.
        arg: expected arguments keyed by argument name, or None (the
            default) to skip the argument check.
    """

    def __new__(cls, op_type, op_input, op_output, op_arg=None):
        # Pass the fields by name so the mapping onto tuple slots is explicit;
        # op_arg defaults to None so specs that ignore arguments can omit it.
        return super(OpSpec, cls).__new__(
            cls,
            type=op_type,
            input=op_input,
            output=op_output,
            arg=op_arg,
        )
43 
44 
46 
47  def setUp(self):
48  super(LayersTestCase, self).setUp()
49  self.setup_example()
50 
51  def setup_example(self):
52  """
53  This is undocumented feature in hypothesis,
54  https://github.com/HypothesisWorks/hypothesis-python/issues/59
55  """
56  workspace.ResetWorkspace()
57  self.reset_model()
58 
59  def reset_model(self, input_feature_schema=None, trainer_extra_schema=None):
60  input_feature_schema = input_feature_schema or schema.Struct(
61  ('float_features', schema.Scalar((np.float32, (32,)))),
62  )
63  trainer_extra_schema = trainer_extra_schema or schema.Struct()
65  'test_model',
66  input_feature_schema=input_feature_schema,
67  trainer_extra_schema=trainer_extra_schema)
68 
69  def new_record(self, schema_obj):
70  return schema.NewRecord(self.model.net, schema_obj)
71 
72  def get_training_nets(self):
73  """
74  We don't use
75  layer_model_instantiator.generate_training_nets_forward_only()
76  here because it includes initialization of global constants, which make
77  testing tricky
78  """
79  train_net = core.Net('train_net')
80  train_init_net = core.Net('train_init_net')
81  for layer in self.model.layers:
82  layer.add_operators(train_net, train_init_net)
83  return train_init_net, train_net
84 
85  def get_eval_net(self):
86  return layer_model_instantiator.generate_eval_net(self.model)
87 
88  def get_predict_net(self):
89  return layer_model_instantiator.generate_predict_net(self.model)
90 
91  def run_train_net(self):
92  self.model.output_schema = schema.Struct()
93  train_init_net, train_net = \
94  layer_model_instantiator.generate_training_nets(self.model)
95  workspace.RunNetOnce(train_init_net)
96  workspace.RunNetOnce(train_net)
97 
98  def run_train_net_forward_only(self, num_iter=1):
99  self.model.output_schema = schema.Struct()
100  train_init_net, train_net = \
101  layer_model_instantiator.generate_training_nets_forward_only(
102  self.model)
103  workspace.RunNetOnce(train_init_net)
104  assert num_iter > 0, 'num_iter must be larger than 0'
105  workspace.CreateNet(train_net)
106  workspace.RunNet(train_net.Proto().name, num_iter=num_iter)
107 
108  def assertBlobsEqual(self, spec_blobs, op_blobs):
109  """
110  spec_blobs can either be None or a list of blob names. If it's None,
111  then no assertion is performed. The elements of the list can be None,
112  in that case, it means that position will not be checked.
113  """
114  if spec_blobs is None:
115  return
116  self.assertEqual(len(spec_blobs), len(op_blobs))
117  for spec_blob, op_blob in zip(spec_blobs, op_blobs):
118  if spec_blob is None:
119  continue
120  self.assertEqual(spec_blob, op_blob)
121 
122  def assertArgsEqual(self, spec_args, op_args):
123  self.assertEqual(len(spec_args), len(op_args))
124  keys = [a.name for a in op_args]
125 
126  def parse_args(args):
127  operator = caffe2_pb2.OperatorDef()
128  # Generate the expected value in the same order
129  for k in keys:
130  v = args[k]
131  arg = utils.MakeArgument(k, v)
132  operator.arg.add().CopyFrom(arg)
133  return operator.arg
134 
135  self.assertEqual(parse_args(spec_args), op_args)
136 
137  def assertNetContainOps(self, net, op_specs):
138  """
139  Given a net and a list of OpSpec's, check that the net match the spec
140  """
141  ops = net.Proto().op
142  self.assertEqual(len(op_specs), len(ops))
143  for op, op_spec in zip(ops, op_specs):
144  self.assertEqual(op_spec.type, op.type)
145  self.assertBlobsEqual(op_spec.input, op.input)
146  self.assertBlobsEqual(op_spec.output, op.output)
147  if op_spec.arg is not None:
148  self.assertArgsEqual(op_spec.arg, op.arg)
149  return ops
def assertBlobsEqual(self, spec_blobs, op_blobs)
def reset_model(self, input_feature_schema=None, trainer_extra_schema=None)
def assertArgsEqual(self, spec_args, op_args)