Caffe2 - Python API
A deep learning, cross-platform ML framework
concat.py
## @package concat
# Module caffe2.python.layers.concat
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import schema
from caffe2.python.layers.layers import (
    ModelLayer,
)
from future.utils import viewitems
import numpy as np
from collections import defaultdict

import logging
logger = logging.getLogger(__name__)


def get_concatenated_feature_to_index(blobs_to_concat):
    concat_feature_to_index = defaultdict(list)
    start_pos = 0
    for scalar in blobs_to_concat:
        num_dims = scalar.dtype.shape[0]
        if hasattr(scalar, 'metadata') \
                and hasattr(scalar.metadata, 'feature_specs') \
                and hasattr(scalar.metadata.feature_specs, 'feature_to_index') \
                and isinstance(scalar.metadata.feature_specs.feature_to_index, dict):  # noqa B950
            for k, v in scalar.metadata.feature_specs.feature_to_index.items():
                concat_feature_to_index[k].extend([start_pos + vi for vi in v])
        start_pos += num_dims
    return dict(concat_feature_to_index) if concat_feature_to_index.keys() else None
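
# Illustrative sketch (not part of the original module): the helper above merges
# per-input feature_to_index maps while shifting each input's local indices by
# the running width of the inputs that precede it. With two hypothetical scalars
# whose dtype shapes are (2,) and (3,) and whose metadata maps are
# {'f_a': [0, 1]} and {'f_b': [0, 2]}, the result would be
#     get_concatenated_feature_to_index([scalar_a, scalar_b])
#     == {'f_a': [0, 1], 'f_b': [2, 4]}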


class Concat(ModelLayer):
    """
    Construct a Concat layer.
    Assume that the first dimension is batch.

    Example:

        embedding_dim = 64
        input_record = self.new_record(schema.Struct(
            ('input1', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input2', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input3', schema.Scalar((np.float32, (embedding_dim, )))),
        ))

        output = self.model.Concat(input_record)
        self.assertEqual(
            schema.Scalar((np.float32,
                           (len(input_record.fields) * embedding_dim, ))),
            output
        )

        # Note that in the Concat layer we assume the first dimension is batch,
        # so the input is B * embedding_dim.
        # add_axis=1 makes it B * 1 * embedding_dim.
        # Concat on axis=1 makes it B * N * embedding_dim.

        output = self.model.Concat(input_record, axis=1, add_axis=1)
        self.assertEqual(
            schema.Scalar((np.float32,
                           (len(input_record.fields), embedding_dim))),
            output
        )
    """

    def __init__(self, model, input_record, axis=1, add_axis=0,
                 name='concat', **kwargs):
        super(Concat, self).__init__(model, name, input_record, **kwargs)
        self.axis = axis
        self.add_axis = add_axis
        assert not (axis == 0 and add_axis == 1), \
            "It's not allowed to add axis=0"
        assert isinstance(input_record, schema.Struct),\
            "Incorrect input type. Expected Struct, but received: {0}".\
            format(input_record)

        shapes = []
        for field_name, field_type in viewitems(input_record.fields):
            assert isinstance(field_type, schema.Scalar),\
                "Incorrect input type for {}. Expected Scalar, but got: {}".\
                format(field_name, field_type)
            # Assume that the first dimension is batch, so the actual axis in
            # shape is axis - 1.
            shape = list(field_type.field_type().shape)
            if add_axis:
                shape.insert(axis - 1, 1)
            assert len(shape) >= axis,\
                "Concat expects its inputs to have at least {} " \
                "non-batch dimensions, got shape {}".format(axis, shape)
            shapes.append(shape)
        logger.info('Concat Layer input shapes: ' + str(shapes))

        if axis == 0:
            # Concatenation along the batch axis: the output keeps the schema
            # of the first input field, only a new output blob is created.
            self.output_schema = schema.from_blob_list(
                input_record[0],
                [self.get_next_blob_reference('output')]
            )
            return

        concat_dim = 0
        for shape in shapes:
            concat_dim += shape[axis - 1]
            shape[axis - 1] = 0
            assert shape == shapes[0],\
                "Shapes {0} and {1} are not compatible for Concat".\
                format(shape, shapes[0])
        output_dims = shapes[0]
        output_dims[axis - 1] = concat_dim

        logger.info('Concat Layer output_dims: ' + str(output_dims))
        self.output_schema = schema.Scalar(
            (np.float32, output_dims),
            self.get_next_blob_reference('output'))
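
        # Worked example (illustrative values, not taken from this file): three
        # inputs of shape [64] concatenated on axis=1 give
        # concat_dim = 64 + 64 + 64, so output_dims is [192]; with add_axis=1
        # each shape first becomes [1, 64] and output_dims is [3, 64].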

        record_to_concat = input_record.fields.values()
        concated_feature_to_index = get_concatenated_feature_to_index(
            record_to_concat
        )
        if concated_feature_to_index:
            metadata = schema.Metadata(
                feature_specs=schema.FeatureSpec(
                    feature_to_index=concated_feature_to_index
                )
            )
            self.output_schema.set_metadata(metadata)

    def add_ops(self, net):
        net.Concat(
            self.input_record.field_blobs(),
            [
                self.output_schema.field_blobs()[0],
                self.output_schema.field_blobs()[0] + "_concat_dims"
            ],
            axis=self.axis,
            add_axis=self.add_axis,
        )
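

# Shape sketch (illustrative, not part of the original module): the Concat op
# emitted by add_ops behaves like the following numpy stand-ins, where B, N and
# embedding_dim are example values rather than anything defined in this file.
#
#     >>> import numpy as np
#     >>> B, N, embedding_dim = 8, 3, 64
#     >>> inputs = [np.zeros((B, embedding_dim), dtype=np.float32)
#     ...           for _ in range(N)]
#     >>> np.concatenate(inputs, axis=1).shape    # axis=1, add_axis=0
#     (8, 192)
#     >>> np.stack(inputs, axis=1).shape          # axis=1, add_axis=1
#     (8, 3, 64)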