1 from __future__
import absolute_import
2 from __future__
import division
3 from __future__
import print_function
4 from __future__
import unicode_literals
15 logger = logging.getLogger(__name__)
20 Allowing model to follow different paths for each instatiation context and 21 join later at some point. The implementation use `Alias` because schema 22 sometimes clone fields internally so we need static blob name for output 29 name=
'select_record_by_context',
30 check_field_metas=
True,
32 default_output_record_field=
None,
35 super(SelectRecordByContext, self).__init__(model, name, input_record,
39 assert len(input_record) > 1
43 input_record[default_output_record_field]
44 if (default_output_record_field
is not None)
else None 46 ref_record = input_record[0]
47 for record
in input_record:
48 assert schema.equal_schemas(record, ref_record,
49 check_field_metas=check_field_metas)
53 def _set_output_blobs(self, net, context):
55 assert record
is not None, (
56 "{} context is not in input record without providing default" 57 " output".format(context)
59 for in_blob, out_blob
in zip(
60 record.field_blobs(), self.output_schema.field_blobs()
63 net.Copy(in_blob, out_blob)
65 net.Alias(in_blob, out_blob)
67 def add_ops(self, net):
70 def add_eval_ops(self, net):
73 def add_train_ops(self, net):
76 def add_ops_to_accumulate_pred(self, net):
def _set_output_blobs(self, net, context)