Caffe2 - Python API
A deep-learning, cross-platform ML framework
layer_model_instantiator.py
1 ## @package layer_model_instantiator
2 # Module caffe2.python.layer_model_instantiator
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import InstantiationContext
10 from caffe2.python.layers.tags import Tags
11 
12 
13 def _filter_layers(layers, include_tags):
14  if include_tags is None:
15  return layers
16  include_tags = set(include_tags)
17  return [l for l in layers if not include_tags.isdisjoint(l.tags)]
18 
19 
def shrink_output_schema(net, out_schema):
    """Drop output fields whose blobs are not defined in ``net``.

    A schema with at most one field is returned unchanged. Otherwise every
    field blob is tested with ``net.BlobIsDefined``. A new schema is then built
    with ``schema.from_column_list`` from the names, types, blobs and metadata
    of only the fields that passed, in their original order.

    Args:
        net: net whose ``BlobIsDefined`` decides which fields survive.
        out_schema: schema exposing ``field_names``/``field_types``/
            ``field_blobs``/``field_metadata``.

    Returns:
        ``out_schema`` itself, or the filtered schema.
    """
    names = out_schema.field_names()
    if len(names) <= 1:
        return out_schema

    blobs = out_schema.field_blobs()
    keep = [net.BlobIsDefined(blob) for blob in blobs]

    # Filter every parallel column with the same keep-mask so they stay aligned.
    columns = (
        names,
        out_schema.field_types(),
        blobs,
        out_schema.field_metadata(),
    )
    filtered = [
        [value for ok, value in zip(keep, column) if ok]
        for column in columns
    ]
    return schema.from_column_list(*filtered)
42 
43 
def generate_predict_net(model, include_tags=None):
    """Build the prediction net from ``model``'s layers.

    Each layer selected by ``include_tags`` (see ``_filter_layers``) that is
    not tagged ``Tags.EXCLUDE_FROM_PREDICTION`` adds its operators to a fresh
    ``core.Net('predict_net')`` under ``InstantiationContext.PREDICTION``.
    The input record is a clone of ``model.input_feature_schema``. The output
    record is a clone of ``model.output_schema``, passed through
    ``shrink_output_schema`` against the new net.

    Returns:
        The populated ``core.Net``.
    """
    net = core.Net('predict_net')

    for layer in _filter_layers(model.layers, include_tags):
        if Tags.EXCLUDE_FROM_PREDICTION in layer.tags:
            continue
        layer.add_operators(net, context=InstantiationContext.PREDICTION)

    net.set_input_record(model.input_feature_schema.clone())
    shrunk = shrink_output_schema(net, model.output_schema.clone())
    net.set_output_record(shrunk)
    return net
58 
59 
def generate_eval_net(model, include_tags=None):
    """Build the evaluation net from ``model``'s layers.

    Each layer selected by ``include_tags`` that is not tagged
    ``Tags.EXCLUDE_FROM_EVAL`` adds its operators to a fresh
    ``core.Net('eval_net')`` under ``InstantiationContext.EVAL``.
    The input record is ``model.input_feature_schema + model.trainer_extra_schema``.
    The output record is ``model.output_schema + model.metrics_schema``,
    passed through ``shrink_output_schema`` against the new net.

    Returns:
        The populated ``core.Net``.
    """
    net = core.Net('eval_net')

    for layer in _filter_layers(model.layers, include_tags):
        if Tags.EXCLUDE_FROM_EVAL in layer.tags:
            continue
        layer.add_operators(net, context=InstantiationContext.EVAL)

    net.set_input_record(
        model.input_feature_schema + model.trainer_extra_schema)
    net.set_output_record(
        shrink_output_schema(net, model.output_schema + model.metrics_schema))
    return net
74 
75 
def _generate_training_net_only(model, include_tags=None):
    """Build the forward-only training net and its init net.

    Creates ``core.Net('train_net')`` and the init net returned by
    ``model.create_init_net('train_init_net')``. Each layer selected by
    ``include_tags`` that is not tagged ``Tags.EXCLUDE_FROM_TRAIN`` is then
    called as ``layer.add_operators(train_net, train_init_net)``; no explicit
    context is passed. The input record is
    ``model.input_feature_schema + model.trainer_extra_schema``. The output
    record is ``model.output_schema + model.metrics_schema``, passed through
    ``shrink_output_schema``.

    No gradient or optimizer operators are added here.

    Returns:
        Tuple ``(train_init_net, train_net)``.
    """
    net = core.Net('train_net')
    init_net = model.create_init_net('train_init_net')

    for layer in _filter_layers(model.layers, include_tags):
        if Tags.EXCLUDE_FROM_TRAIN in layer.tags:
            continue
        layer.add_operators(net, init_net)

    net.set_input_record(
        model.input_feature_schema + model.trainer_extra_schema)
    net.set_output_record(
        shrink_output_schema(net, model.output_schema + model.metrics_schema))
    return init_net, net
91 
92 
def generate_training_nets_forward_only(model, include_tags=None):
    """Return ``(train_init_net, train_net)`` with forward operators only.

    Public wrapper over ``_generate_training_net_only``. No loss regularizers,
    gradients, optimizers or net modifiers are applied.
    """
    return _generate_training_net_only(model, include_tags)
96 
97 
def generate_training_nets(model, include_tags=None):
    """Build the full training net pair: forward, gradients and optimizers.

    Steps, in order:
      1. Build the forward nets with ``_generate_training_net_only``.
      2. Call ``model.apply_regularizers_on_loss``. This always runs, even
         when the model has no loss.
      3. Only if ``model.has_loss()`` is true:
         a. add gradient operators for ``model.loss.field_blobs()``;
         b. pass the returned gradient mapping to
            ``apply_post_grad_net_modifiers``, ``apply_optimizers``,
            ``apply_regularizers_after_optimizer`` and
            ``apply_final_net_modifiers``, in that order.
         Both net-modifier calls use ``modify_output_record=True``.

    Returns:
        Tuple ``(train_init_net, train_net)``.
    """
    init_net, net = _generate_training_net_only(model, include_tags)
    model.apply_regularizers_on_loss(net, init_net)

    if model.has_loss():
        grads = net.AddGradientOperators(model.loss.field_blobs())
        model.apply_post_grad_net_modifiers(
            net, init_net, grads, modify_output_record=True)
        model.apply_optimizers(net, init_net, grads)
        model.apply_regularizers_after_optimizer(net, init_net, grads)
        model.apply_final_net_modifiers(
            net, init_net, grads, modify_output_record=True)

    return init_net, net