Caffe2 - Python API
A deep learning, cross-platform ML framework
select_record_by_context.py
1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4 from __future__ import unicode_literals
5 
6 import logging
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import (
10  InstantiationContext,
11  ModelLayer,
12 )
13 
14 
# Logger scoped to this module's import name.
# NOTE(review): nothing in the code below logs through it; presumably kept for
# consistency with sibling layer modules.
logger = logging.getLogger(name=__name__)
16 
17 
class SelectRecordByContext(ModelLayer):
    """
    Allow the model to follow a different path for each instantiation context
    and join the paths again at this layer.

    Each field of ``input_record`` is a candidate record. When ops are added
    for an instantiation context, the candidate returned by
    ``input_record.get(context, default_output_record)`` is routed into this
    layer's output record. That lookup presumably picks the field whose name
    equals the context value.

    All candidates must share one schema. The output record is a new record
    (``schema.NewRecord``) with that schema.

    The implementation uses ``Alias`` (or ``Copy`` when ``use_copy`` is set)
    because schema sometimes clones fields internally, so the output needs
    static blob names.

    Args:
        model: model the layer belongs to; the output record's blobs are
            created in ``model.net``.
        input_record: ``schema.Struct`` with at least two fields, one
            candidate record per field.
        name: layer name.
        check_field_metas: forwarded to ``schema.equal_schemas`` when checking
            that the candidates share one schema.
        use_copy: emit ``Copy`` ops instead of ``Alias`` ops.
        default_output_record_field: key into ``input_record`` selecting the
            record used for any context without a candidate of its own. When
            ``None``, adding ops for such a context fails an assertion.
        **kwargs: forwarded to ``ModelLayer.__init__``.
    """

    def __init__(
        self,
        model,
        input_record,
        name='select_record_by_context',
        check_field_metas=True,
        use_copy=False,
        default_output_record_field=None,
        **kwargs
    ):
        super(SelectRecordByContext, self).__init__(model, name, input_record,
                                                    **kwargs)

        assert isinstance(input_record, schema.Struct), (
            "input_record must be a schema.Struct, got {}".format(
                type(input_record))
        )
        # A single candidate would leave nothing to select between.
        assert len(input_record) > 1, (
            "input_record must have at least 2 fields, got {}".format(
                len(input_record))
        )

        self.use_copy = use_copy
        self.default_output_record = (
            input_record[default_output_record_field]
            if default_output_record_field is not None else None
        )

        # Every candidate must match the first one so that any of them can
        # be routed into the same output blobs.
        ref_record = input_record[0]
        for index, record in enumerate(input_record):
            assert schema.equal_schemas(record, ref_record,
                                        check_field_metas=check_field_metas), (
                "field #{} of input_record does not match the schema of "
                "field #0".format(index)
            )

        # The default record is looked up by key, so it need not be one of the
        # top-level candidates checked above. Validate it explicitly: a
        # mismatched default would otherwise be zipped against the output
        # blobs in _set_output_blobs and silently truncated or misrouted.
        if self.default_output_record is not None:
            assert schema.equal_schemas(self.default_output_record,
                                        ref_record,
                                        check_field_metas=check_field_metas), (
                "default output record field {!r} does not match the schema "
                "of the input_record fields".format(
                    default_output_record_field)
            )

        self.output_schema = schema.NewRecord(model.net, ref_record)

    def _set_output_blobs(self, net, context):
        """
        Add ops to ``net`` that route the candidate for ``context`` into this
        layer's output blobs, pairing blobs positionally.

        Args:
            net: net that receives one ``Alias`` (or ``Copy``) op per blob.
            context: ``InstantiationContext`` value used to look up the
                candidate via ``self.input_record.get``; falls back to
                ``self.default_output_record``.

        Raises:
            AssertionError: if there is no candidate for ``context`` and no
                default output record was configured.
        """
        record = self.input_record.get(context, self.default_output_record)
        assert record is not None, (
            "{} context is not in input record without providing default"
            " output".format(context)
        )
        # The op type is the same for every blob, so choose it once.
        add_op = net.Copy if self.use_copy else net.Alias
        for in_blob, out_blob in zip(
            record.field_blobs(), self.output_schema.field_blobs()
        ):
            add_op(in_blob, out_blob)

    def add_ops(self, net):
        """Route the PREDICTION-context candidate into the outputs."""
        self._set_output_blobs(net, InstantiationContext.PREDICTION)

    def add_eval_ops(self, net):
        """Route the EVAL-context candidate into the outputs."""
        self._set_output_blobs(net, InstantiationContext.EVAL)

    def add_train_ops(self, net):
        """Route the TRAINING-context candidate into the outputs."""
        self._set_output_blobs(net, InstantiationContext.TRAINING)

    def add_ops_to_accumulate_pred(self, net):
        """Route the ACCUMULATE_PRED-context candidate into the outputs."""
        self._set_output_blobs(net, InstantiationContext.ACCUMULATE_PRED)