Caffe2 - Python API
A deep learning, cross-platform ML framework
layer_model_instantiator.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package layer_model_instantiator
17 # Module caffe2.python.layer_model_instantiator
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core, schema
24 from caffe2.python.layers.layers import InstantiationContext
25 from caffe2.python.layers.tags import Tags
26 
27 
def _filter_layers(layers, include_tags):
    """Select the layers that carry at least one of the requested tags.

    Args:
        layers: iterable of objects exposing a ``tags`` attribute.
        include_tags: iterable of tags to keep, or ``None`` to disable
            filtering.

    Returns:
        ``layers`` itself (the same object) when ``include_tags`` is ``None``;
        otherwise a new list, in the original order, of the layers whose
        ``tags`` share at least one element with ``include_tags``.
    """
    if include_tags is None:
        return layers

    # Build the set once so each layer is checked with a single set operation.
    wanted_tags = set(include_tags)
    return [
        layer for layer in layers
        if not wanted_tags.isdisjoint(layer.tags)
    ]
33 
34 
def shrink_output_schema(net, out_schema):
    """Drop the columns of ``out_schema`` whose blobs are not defined in ``net``.

    Args:
        net: object providing ``BlobIsDefined(blob)``.
        out_schema: schema object providing ``field_names()``,
            ``field_types()``, ``field_blobs()`` and ``field_metadata()``.

    Returns:
        ``out_schema`` itself when it has at most one field; otherwise the
        result of ``schema.from_column_list`` called with the names, types,
        blobs and metadata of only those columns for which
        ``net.BlobIsDefined`` returned a truthy value.
    """
    field_count = len(out_schema.field_names())
    if field_count <= 1:
        return out_schema

    # One flag per column, aligned with the order of field_blobs().
    is_defined = [
        net.BlobIsDefined(blob) for blob in out_schema.field_blobs()
    ]

    def _surviving(column_values):
        # Keep only the entries whose matching blob exists in the net.
        return [
            value
            for keep, value in zip(is_defined, column_values)
            if keep
        ]

    return schema.from_column_list(
        _surviving(out_schema.field_names()),
        _surviving(out_schema.field_types()),
        _surviving(out_schema.field_blobs()),
        _surviving(out_schema.field_metadata())
    )
57 
58 
def generate_predict_net(model, include_tags=None):
    """Build the net named ``'predict_net'`` from the model's layers.

    Every layer selected by ``_filter_layers(model.layers, include_tags)``
    that is not tagged ``Tags.EXCLUDE_FROM_PREDICTION`` adds its operators to
    the net with ``InstantiationContext.PREDICTION``.

    Args:
        model: object exposing ``layers``, ``input_feature_schema`` and
            ``output_schema``.
        include_tags: optional iterable of tags restricting which layers are
            considered; ``None`` considers all of them.

    Returns:
        The populated ``core.Net``. Its input record is a clone of
        ``model.input_feature_schema`` and its output record is a clone of
        ``model.output_schema`` passed through ``shrink_output_schema``.
    """
    net = core.Net('predict_net')

    for candidate in _filter_layers(model.layers, include_tags):
        if Tags.EXCLUDE_FROM_PREDICTION in candidate.tags:
            continue
        candidate.add_operators(
            net, context=InstantiationContext.PREDICTION)

    # Schemas are cloned before being attached to the net.
    net.set_input_record(model.input_feature_schema.clone())
    net.set_output_record(
        shrink_output_schema(net, model.output_schema.clone())
    )
    return net
73 
74 
def generate_eval_net(model, include_tags=None):
    """Build the net named ``'eval_net'`` from the model's layers.

    Every layer selected by ``_filter_layers(model.layers, include_tags)``
    that is not tagged ``Tags.EXCLUDE_FROM_EVAL`` adds its operators to the
    net with ``InstantiationContext.EVAL``.

    Args:
        model: object exposing ``layers``, ``input_feature_schema``,
            ``trainer_extra_schema``, ``output_schema`` and
            ``metrics_schema``.
        include_tags: optional iterable of tags restricting which layers are
            considered; ``None`` considers all of them.

    Returns:
        The populated ``core.Net``. Its input record is
        ``input_feature_schema + trainer_extra_schema`` and its output record
        is ``output_schema + metrics_schema`` passed through
        ``shrink_output_schema``.
    """
    net = core.Net('eval_net')

    for candidate in _filter_layers(model.layers, include_tags):
        if Tags.EXCLUDE_FROM_EVAL in candidate.tags:
            continue
        candidate.add_operators(net, context=InstantiationContext.EVAL)

    net.set_input_record(
        model.input_feature_schema + model.trainer_extra_schema
    )
    combined_outputs = model.output_schema + model.metrics_schema
    net.set_output_record(shrink_output_schema(net, combined_outputs))
    return net
89 
90 
def _generate_training_net_only(model, include_tags=None):
    """Build the ``'train_net'`` / ``'train_init_net'`` pair from the layers.

    Every layer selected by ``_filter_layers(model.layers, include_tags)``
    that is not tagged ``Tags.EXCLUDE_FROM_TRAIN`` adds its operators to both
    nets. No gradient or optimizer handling happens in this function.

    Args:
        model: object exposing ``layers``, ``create_init_net``,
            ``input_feature_schema``, ``trainer_extra_schema``,
            ``output_schema`` and ``metrics_schema``.
        include_tags: optional iterable of tags restricting which layers are
            considered; ``None`` considers all of them.

    Returns:
        A ``(train_init_net, train_net)`` tuple. ``train_net`` has
        ``input_feature_schema + trainer_extra_schema`` as input record and
        ``output_schema + metrics_schema``, passed through
        ``shrink_output_schema``, as output record.
    """
    net = core.Net('train_net')
    init_net = model.create_init_net('train_init_net')

    for candidate in _filter_layers(model.layers, include_tags):
        if Tags.EXCLUDE_FROM_TRAIN in candidate.tags:
            continue
        # No explicit context is passed, so add_operators uses its default.
        candidate.add_operators(net, init_net)

    net.set_input_record(
        model.input_feature_schema + model.trainer_extra_schema
    )
    combined_outputs = model.output_schema + model.metrics_schema
    net.set_output_record(shrink_output_schema(net, combined_outputs))
    return init_net, net
106 
107 
def generate_training_nets_forward_only(model, include_tags=None):
    """Return the training nets exactly as the layers define them.

    Delegates to ``_generate_training_net_only`` and adds nothing on top of
    its result.

    Args:
        model: model object forwarded to ``_generate_training_net_only``.
        include_tags: optional iterable of tags restricting which layers are
            instantiated; ``None`` uses all of them.

    Returns:
        A ``(train_init_net, train_net)`` tuple.
    """
    return _generate_training_net_only(model, include_tags)
111 
112 
def generate_training_nets(model, include_tags=None):
    """Build the training nets and attach gradients, optimizers, regularizers.

    Steps, in order:
      1. build both nets with ``_generate_training_net_only``;
      2. call ``model.apply_regularizers_on_loss``;
      3. add gradient operators for the blobs of ``model.loss``;
      4. call ``model.apply_optimizers`` with the resulting gradient map;
      5. call ``model.apply_regularizers_after_optimizer`` with the same map.

    Args:
        model: object exposing ``loss``, ``apply_regularizers_on_loss``,
            ``apply_optimizers`` and ``apply_regularizers_after_optimizer``,
            plus everything ``_generate_training_net_only`` needs.
        include_tags: optional iterable of tags restricting which layers are
            instantiated; ``None`` uses all of them.

    Returns:
        A ``(train_init_net, train_net)`` tuple.
    """
    init_net, net = _generate_training_net_only(model, include_tags)

    model.apply_regularizers_on_loss(net, init_net)

    # model.loss is deliberately read only after the call above, keeping the
    # original ordering (presumably that call can change the loss -- confirm).
    gradients = net.AddGradientOperators(model.loss.field_blobs())

    model.apply_optimizers(net, init_net, gradients)
    model.apply_regularizers_after_optimizer(net, init_net, gradients)
    return init_net, net