3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from collections
import namedtuple
12 layer_model_instantiator,
19 from caffe2.proto
import caffe2_pb2
23 class OpSpec(namedtuple(
"OpSpec",
"type input output arg")):
25 def __new__(cls, op_type, op_input, op_output, op_arg=None):
26 return super(OpSpec, cls).__new__(cls, op_type, op_input,
33 super(LayersTestCase, self).setUp()
38 This is undocumented feature in hypothesis, 39 https://github.com/HypothesisWorks/hypothesis-python/issues/59 41 workspace.ResetWorkspace()
44 def reset_model(self, input_feature_schema=None, trainer_extra_schema=None):
48 trainer_extra_schema = trainer_extra_schema
or schema.Struct()
51 input_feature_schema=input_feature_schema,
52 trainer_extra_schema=trainer_extra_schema)
def new_record(self, schema_obj):
    """Return ``schema.NewRecord`` applied to this model's net and *schema_obj*.

    Args:
        schema_obj: the schema object forwarded unchanged to
            ``schema.NewRecord``.
    """
    target_net = self.model.net
    return schema.NewRecord(target_net, schema_obj)
60 layer_model_instantiator.generate_training_nets_forward_only() 61 here because it includes initialization of global constants, which make 66 train_init_net = self.model.create_init_net(
'train_init_net')
68 train_init_net =
core.Net(
'train_init_net')
69 for layer
in self.model.layers:
70 layer.add_operators(train_net, train_init_net)
71 return train_init_net, train_net
def get_eval_net(self):
    """Return the eval net that ``layer_model_instantiator`` generates for ``self.model``.

    Thin wrapper: the result of ``generate_eval_net`` is returned unchanged.
    """
    model = self.model
    return layer_model_instantiator.generate_eval_net(model)
def get_predict_net(self):
    """Return the predict net that ``layer_model_instantiator`` generates for ``self.model``.

    Thin wrapper: the result of ``generate_predict_net`` is returned unchanged.
    """
    model = self.model
    return layer_model_instantiator.generate_predict_net(model)
79 def run_train_net(self):
81 train_init_net, train_net = \
82 layer_model_instantiator.generate_training_nets(self.
model)
83 workspace.RunNetOnce(train_init_net)
84 workspace.RunNetOnce(train_net)
86 def run_train_net_forward_only(self, num_iter=1):
88 train_init_net, train_net = \
89 layer_model_instantiator.generate_training_nets_forward_only(
91 workspace.RunNetOnce(train_init_net)
92 assert num_iter > 0,
'num_iter must be larger than 0' 93 workspace.CreateNet(train_net)
94 workspace.RunNet(train_net.Proto().name, num_iter=num_iter)
98 spec_blobs can either be None or a list of blob names. If it's None, 99 then no assertion is performed. The elements of the list can be None, 100 in that case, it means that position will not be checked. 102 if spec_blobs
is None:
104 self.assertEqual(len(spec_blobs), len(op_blobs))
105 for spec_blob, op_blob
in zip(spec_blobs, op_blobs):
106 if spec_blob
is None:
108 self.assertEqual(spec_blob, op_blob)
110 def assertArgsEqual(self, spec_args, op_args):
111 self.assertEqual(len(spec_args), len(op_args))
112 keys = [a.name
for a
in op_args]
114 def parse_args(args):
115 operator = caffe2_pb2.OperatorDef()
119 arg = utils.MakeArgument(k, v)
120 operator.arg.add().CopyFrom(arg)
123 self.assertEqual(parse_args(spec_args), op_args)
127 Given a net and a list of OpSpec's, check that the net match the spec 130 self.assertEqual(len(op_specs), len(ops))
131 for op, op_spec
in zip(ops, op_specs):
132 self.assertEqual(op_spec.type, op.type)
135 if op_spec.arg
is not None:
def assertBlobsEqual(self, spec_blobs, op_blobs)
def assertNetContainOps(self, net, op_specs)
def reset_model(self, input_feature_schema=None, trainer_extra_schema=None)
def get_training_nets(self, add_constants=False)
def assertArgsEqual(self, spec_args, op_args)