3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
def _filter_layers(layers, include_tags):
    """Select the layers that carry at least one of ``include_tags``.

    Args:
        layers: Sequence of layer objects; each must expose a ``tags``
            iterable.
        include_tags: Iterable of tags to keep, or ``None`` to disable
            filtering entirely.

    Returns:
        ``layers`` itself (not a copy) when ``include_tags`` is ``None``.
        Otherwise, a new list of the layers whose ``tags`` share at least one
        element with ``include_tags``, in their original order. An empty
        ``include_tags`` therefore selects no layers.
    """
    if include_tags is None:
        # No filter requested: keep every layer. Without this early return
        # the code below would evaluate set(None) and raise TypeError.
        return layers
    wanted = set(include_tags)
    return [layer for layer in layers if not wanted.isdisjoint(layer.tags)]
def shrink_output_schema(net, out_schema):
    """Drop output-schema fields whose blobs ``net`` does not define.

    Args:
        net: Net being instantiated; each field blob is checked with
            ``net.BlobIsDefined(blob)``.
        out_schema: Output schema exposing ``field_names()``,
            ``field_types()``, ``field_blobs()`` and ``field_metadata()``.

    Returns:
        ``out_schema`` unchanged when it has at most one field. Otherwise, a
        new schema built with ``schema.from_column_list`` that keeps only the
        fields whose blob is defined in ``net``, in their original order.
    """
    if len(out_schema.field_names()) <= 1:
        # Nothing worth pruning; hand back the caller's schema as-is.
        return out_schema

    blobs = out_schema.field_blobs()
    exists = [net.BlobIsDefined(blob) for blob in blobs]

    def _keep(column):
        # Filter one per-field column with the shared `exists` mask so that
        # names, types, blobs and metadata stay aligned.
        return [value for ok, value in zip(exists, column) if ok]

    return schema.from_column_list(
        _keep(out_schema.field_names()),
        _keep(out_schema.field_types()),
        _keep(blobs),
        _keep(out_schema.field_metadata()),
    )
def generate_predict_net(model, include_tags=None):
    """Build and return the prediction net for ``model``.

    The layers considered are those selected by ``include_tags`` (all layers
    when ``None``). Each one not tagged ``Tags.EXCLUDE_FROM_PREDICTION`` adds
    its operators to a fresh ``core.Net('predict_net')``, using the
    ``InstantiationContext.PREDICTION`` context.

    The net's input record is a clone of ``model.input_feature_schema``. Its
    output record is a clone of ``model.output_schema`` with fields pruned
    when the net does not define their blobs.

    Args:
        model: Layer model providing ``layers``, ``input_feature_schema`` and
            ``output_schema``.
        include_tags: Optional iterable of tags restricting which layers are
            instantiated.

    Returns:
        The populated prediction net.
    """
    predict_net = core.Net('predict_net')

    for layer in _filter_layers(model.layers, include_tags):
        if Tags.EXCLUDE_FROM_PREDICTION not in layer.tags:
            layer.add_operators(
                predict_net, context=InstantiationContext.PREDICTION)

    predict_net.set_input_record(model.input_feature_schema.clone())
    output_schema = shrink_output_schema(
        predict_net, model.output_schema.clone()
    )
    predict_net.set_output_record(output_schema)
    # Previously the net was built and then discarded; return it so callers
    # can use it.
    return predict_net
def generate_eval_net(model, include_tags=None):
    """Build and return the evaluation net for ``model``.

    The layers considered are those selected by ``include_tags`` (all layers
    when ``None``). Each one not tagged ``Tags.EXCLUDE_FROM_EVAL`` adds its
    operators to a fresh ``core.Net('eval_net')``, using the
    ``InstantiationContext.EVAL`` context.

    The input record is ``model.input_feature_schema +
    model.trainer_extra_schema``. The output record is ``model.output_schema +
    model.metrics_schema`` with fields pruned when the net does not define
    their blobs.

    Args:
        model: Layer model providing ``layers`` and the schemas listed above.
        include_tags: Optional iterable of tags restricting which layers are
            instantiated.

    Returns:
        The populated evaluation net.
    """
    eval_net = core.Net('eval_net')

    for layer in _filter_layers(model.layers, include_tags):
        if Tags.EXCLUDE_FROM_EVAL not in layer.tags:
            layer.add_operators(eval_net, context=InstantiationContext.EVAL)

    input_schema = model.input_feature_schema + model.trainer_extra_schema
    eval_net.set_input_record(input_schema)
    output_schema = shrink_output_schema(
        eval_net, model.output_schema + model.metrics_schema
    )
    eval_net.set_output_record(output_schema)
    # Previously the net was built and then discarded; return it, matching
    # the training-net generators.
    return eval_net
def _generate_training_net_only(model, include_tags=None):
    """Instantiate the forward part of training: ``(init_net, train_net)``.

    A fresh ``core.Net('train_net')`` is created, followed by the init net
    returned by ``model.create_init_net('train_init_net')``. Every layer
    selected by ``include_tags`` that is not tagged
    ``Tags.EXCLUDE_FROM_TRAIN`` then adds operators to both nets.

    The train net's input record is the feature schema plus the trainer-extra
    schema. Its output record is the output schema plus the metrics schema,
    with fields pruned when the net does not define their blobs. No gradient
    or optimizer operators are added here.

    Args:
        model: Layer model providing ``layers``, ``create_init_net`` and the
            schemas listed above.
        include_tags: Optional iterable of tags restricting which layers are
            instantiated.

    Returns:
        Tuple ``(train_init_net, train_net)``.
    """
    net = core.Net('train_net')
    init_net = model.create_init_net('train_init_net')

    trainable = (
        lyr for lyr in _filter_layers(model.layers, include_tags)
        if Tags.EXCLUDE_FROM_TRAIN not in lyr.tags
    )
    for lyr in trainable:
        lyr.add_operators(net, init_net)

    net.set_input_record(
        model.input_feature_schema + model.trainer_extra_schema)
    net.set_output_record(shrink_output_schema(
        net, model.output_schema + model.metrics_schema))
    return init_net, net
def generate_training_nets_forward_only(model, include_tags=None):
    """Return ``(train_init_net, train_net)`` with forward operators only.

    This is a thin public wrapper around the private forward-only builder.
    Unlike ``generate_training_nets``, it adds no regularizer, gradient or
    optimizer operators.
    """
    init_net, net = _generate_training_net_only(model, include_tags=include_tags)
    return (init_net, net)
def generate_training_nets(model, include_tags=None):
    """Build the full training nets: forward ops, then backward/update ops.

    Steps:

    1. Build the forward nets via the private forward-only builder.
    2. Apply loss regularizers.
    3. If the model has a loss, add gradient operators for the loss blobs.
    4. Run, in order: the post-gradient net modifiers, the optimizers, the
       post-optimizer regularizers and the final net modifiers.

    Args:
        model: Layer model providing the ``apply_*`` hooks, ``has_loss()`` and
            the loss record.
        include_tags: Optional iterable of tags restricting which layers are
            instantiated.

    Returns:
        Tuple ``(train_init_net, train_net)``. When the model has no loss,
        these are returned right after loss regularization, with no gradient
        or optimizer operators.
    """
    train_init_net, train_net = _generate_training_net_only(model, include_tags)

    model.apply_regularizers_on_loss(train_net, train_init_net)
    if not model.has_loss():
        return train_init_net, train_net

    # `loss` was previously referenced without being bound (NameError).
    # NOTE(review): assumes the model exposes its loss record as `model.loss`,
    # paired with `has_loss()` — confirm against the model class.
    loss = model.loss
    grad_map = train_net.AddGradientOperators(loss.field_blobs())
    model.apply_post_grad_net_modifiers(train_net, train_init_net, grad_map,
                                        modify_output_record=True)
    model.apply_optimizers(train_net, train_init_net, grad_map)
    model.apply_regularizers_after_optimizer(train_net, train_init_net, grad_map)
    model.apply_final_net_modifiers(train_net, train_init_net, grad_map,
                                    modify_output_record=True)

    return train_init_net, train_net