from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging
from collections import defaultdict

import numpy as np
from future.utils import viewitems

from caffe2.python import schema
from caffe2.python.layers.layers import ModelLayer

logger = logging.getLogger(__name__)


def get_concatenated_feature_to_index(blobs_to_concat):
    """Merge the feature_to_index maps of the scalars being concatenated,
    shifting each blob's indices by the number of dimensions that precede
    it in the concatenated output."""
    concat_feature_to_index = defaultdict(list)
    start_pos = 0
    for scalar in blobs_to_concat:
        num_dims = scalar.dtype.shape[0]
        if (hasattr(scalar, 'metadata')
                and hasattr(scalar.metadata, 'feature_specs')
                and hasattr(scalar.metadata.feature_specs, 'feature_to_index')
                and isinstance(
                    scalar.metadata.feature_specs.feature_to_index, dict)):
            for k, v in scalar.metadata.feature_specs.feature_to_index.items():
                concat_feature_to_index[k].extend(
                    [start_pos + vi for vi in v]
                )
        start_pos += num_dims
    return (dict(concat_feature_to_index)
            if concat_feature_to_index.keys() else None)
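
# A worked example of the merge above (hypothetical feature names, for
# illustration): concatenating a 2-dim blob with feature_to_index
# {'f1': [0, 1]} and a 3-dim blob with feature_to_index {'f2': [0, 2]}
# yields {'f1': [0, 1], 'f2': [2, 4]}, since the second blob's indices
# are offset by the 2 dimensions that precede it.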


class Concat(ModelLayer):
    """
    Construct a Concat layer.
    Assume that the first dimension is batch.

    Example:

        embedding_dim = 64
        input_record = self.new_record(schema.Struct(
            ('input1', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input2', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input3', schema.Scalar((np.float32, (embedding_dim, )))),
        ))

        output = self.model.Concat(input_record)
        self.assertEqual(
            schema.Scalar((np.float32,
                           (len(input_record.fields) * embedding_dim, ))),
            output
        )

        # Note that in the Concat layer we assume the first dimension is
        # batch, so the input is B * embedding_dim;
        # add_axis=1 makes it B * 1 * embedding_dim, and
        # Concat on axis=1 then makes it B * N * embedding_dim.
        output = self.model.Concat(input_record, axis=1, add_axis=1)
        self.assertEqual(
            schema.Scalar((np.float32,
                           (len(input_record.fields), embedding_dim))),
            output
        )
    """

    def __init__(self, model, input_record, axis=1, add_axis=0,
                 name='concat', **kwargs):
        super(Concat, self).__init__(model, name, input_record, **kwargs)
        self.axis = axis
        self.add_axis = add_axis
        assert not (axis == 0 and add_axis == 1), \
            "It's not allowed to add axis=0"
        assert isinstance(input_record, schema.Struct), \
            "Incorrect input type. Expected Struct, but received: {0}".\
            format(input_record)
        shapes = []
        for field_name, field_type in viewitems(input_record.fields):
            assert isinstance(field_type, schema.Scalar), \
                "Incorrect input type for {}. Expected Scalar, but got: {}".\
                format(field_name, field_type)
            # The schema shape excludes the batch dimension, so the concat
            # axis in the shape is axis - 1.
            shape = list(field_type.field_type().shape)
            if add_axis:
                shape.insert(axis - 1, 1)
            assert len(shape) >= axis, \
                "Concat expects the input tensor to have at least {} " \
                "dimensions beyond the batch dimension".format(axis)
            shapes.append(shape)
        logger.info('Concat Layer input shapes: ' + str(shapes))
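        # For example (mirroring the docstring's setup): with three
        # (embedding_dim, ) inputs and axis=1, add_axis=1, each shape
        # becomes [1, embedding_dim], so shapes is
        # [[1, embedding_dim]] * 3 before concatenation.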
        if axis == 0:
            # Concat along the batch axis: the output keeps the schema of
            # the first input and just gets a fresh blob reference.
            self.output_schema = schema.from_blob_list(
                input_record[0],
                [self.get_next_blob_reference('output')]
            )
            return
        concat_dim = 0
        for shape in shapes:
            concat_dim += shape[axis - 1]
            # Zero out the concat axis so the remaining dimensions can be
            # compared across inputs.
            shape[axis - 1] = 0
            assert shape == shapes[0], \
                "Shapes {0} and {1} are not compatible for Concat".\
                format(shape, shapes[0])
        output_dims = shapes[0]
        output_dims[axis - 1] = concat_dim
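        # Continuing the example above: three [1, embedding_dim] shapes
        # give concat_dim = 3, so output_dims becomes [3, embedding_dim],
        # i.e. a B x 3 x embedding_dim output.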
        logger.info('Concat Layer output_dims: ' + str(output_dims))
        self.output_schema = schema.Scalar(
            (np.float32, output_dims),
            self.get_next_blob_reference('output'))
        record_to_concat = input_record.fields.values()
        concated_feature_to_index = get_concatenated_feature_to_index(
            record_to_concat
        )
        # Propagate the merged feature-to-index mapping so that downstream
        # consumers can still locate individual features inside the
        # concatenated blob.
        if concated_feature_to_index:
            metadata = schema.Metadata(
                feature_specs=schema.FeatureSpec(
                    feature_to_index=concated_feature_to_index
                )
            )
            self.output_schema.set_metadata(metadata)

    def add_ops(self, net):
        net.Concat(
            self.input_record.field_blobs(),
            [
                self.output_schema.field_blobs()[0],
                # The second output records the size of each input along
                # the concat axis (split information).
                self.output_schema.field_blobs()[0] + "_concat_dims"
            ],
            axis=self.axis,
            add_axis=self.add_axis,
        )