Caffe2 - Python API
A deep-learning, cross-platform ML framework
select_record_by_context.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20 
21 import logging
22 
23 from caffe2.python import schema
24 from caffe2.python.layers.layers import (
25  InstantiationContext,
26  ModelLayer,
27 )
28 
29 
# Module-level logger, named after this module so its output can be
# filtered/configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
31 
32 
class SelectRecordByContext(ModelLayer):
    """
    Allow the model to follow a different path for each instantiation
    context and join the paths again later at some point.

    ``input_record`` is a ``schema.Struct`` with more than one field; every
    field must have the same schema. When the layer is instantiated for a
    given context (prediction, eval, training, accumulate-pred), the field of
    ``input_record`` named by that ``InstantiationContext`` value is selected
    and its blobs are aliased onto this layer's output blobs.

    The implementation uses ``Alias`` because schema sometimes clones fields
    internally, so the output needs static blob names.
    """

    def __init__(self, model, input_record, name='select_record_by_context',
                 check_field_metas=True, **kwargs):
        """
        Args:
            model: the model this layer belongs to; output blobs are created
                in ``model.net``.
            input_record: ``schema.Struct`` with at least two fields, one per
                instantiation context the layer will be used in, all sharing
                the same schema.
            name: layer name.
            check_field_metas: forwarded to ``schema.equal_schemas``; when
                True, field metadata must match as well as structure.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``input_record`` is not a ``schema.Struct``,
                has fewer than two fields, or its fields' schemas differ.
        """
        super(SelectRecordByContext, self).__init__(model, name, input_record,
                                                    **kwargs)

        assert isinstance(input_record, schema.Struct), (
            "input_record must be a schema.Struct, got {}".format(
                type(input_record))
        )
        assert len(input_record) > 1, (
            "input_record must have more than one field, got {}".format(
                len(input_record))
        )

        ref_record = input_record[0]
        # Every candidate must match the first field's schema so that any of
        # them can be aliased onto the same set of output blobs.
        for index, record in enumerate(input_record):
            assert schema.equal_schemas(record, ref_record,
                                        check_field_metas=check_field_metas), (
                "field {} of input_record does not match the schema of "
                "field 0".format(index)
            )

        # Output blobs are new blobs in model.net shaped like ref_record; the
        # per-context add_*_ops methods alias the selected field onto them.
        self.output_schema = schema.NewRecord(model.net, ref_record)

    def _set_output_blobs(self, net, context):
        """
        Alias the blobs of the field selected by ``context`` onto the outputs.

        Args:
            net: net that receives the ``Alias`` ops.
            context: an ``InstantiationContext`` value; it is used as the
                field name to look up in ``input_record``.

        Raises:
            AssertionError: if ``input_record`` has no field for ``context``.
        """
        assert context in self.input_record, (
            "{} context is not in input record".format(context)
        )
        record = self.input_record[context]

        # Schemas were checked equal in __init__, so the flattened blob lists
        # are expected to correspond one-to-one.
        for in_blob, out_blob in zip(
            record.field_blobs(), self.output_schema.field_blobs()
        ):
            net.Alias(in_blob, out_blob)

    def add_ops(self, net):
        """Select the PREDICTION field when instantiated for prediction."""
        self._set_output_blobs(net, InstantiationContext.PREDICTION)

    def add_eval_ops(self, net):
        """Select the EVAL field when instantiated for evaluation."""
        self._set_output_blobs(net, InstantiationContext.EVAL)

    def add_train_ops(self, net):
        """Select the TRAINING field when instantiated for training."""
        self._set_output_blobs(net, InstantiationContext.TRAINING)

    def add_ops_to_accumulate_pred(self, net):
        """Select the ACCUMULATE_PRED field for prediction accumulation."""
        self._set_output_blobs(net, InstantiationContext.ACCUMULATE_PRED)