Caffe2 - Python API
A deep learning, cross platform ML framework
concat.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package concat
17 # Module caffe2.python.layers.concat
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import schema
24 from caffe2.python.layers.layers import (
25  ModelLayer,
26 )
27 from future.utils import viewitems
28 import numpy as np
29 
30 import logging
31 logger = logging.getLogger(__name__)
32 
class Concat(ModelLayer):
    """Concatenate a Struct of Scalar inputs into one output blob.

    Assumes that the first dimension of every input is the batch dimension,
    so the concat axis `axis` maps to position `axis - 1` of each field's
    (batch-less) schema shape.

    Example:

        embedding_dim = 64
        input_record = self.new_record(schema.Struct(
            ('input1', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input2', schema.Scalar((np.float32, (embedding_dim, )))),
            ('input3', schema.Scalar((np.float32, (embedding_dim, )))),
        ))

        output = self.model.Concat(input_record)
        self.assertEqual(
            schema.Scalar((np.float32,
                           ((len(input_record.fields) * embedding_dim, )))),
            output
        )

        # Note that in Concat layer we assume first dimension is batch,
        # so input is B * embedding_dim,
        # add_axis=1 makes it B * 1 * embedding_dim,
        # and Concat on axis=1 makes it B * N * embedding_dim.

        output = self.model.Concat(input_record, axis=1, add_axis=1)
        self.assertEqual(
            schema.Scalar((np.float32,
                           ((len(input_record.fields), embedding_dim)))),
            output
        )
    """

    def __init__(self, model, input_record, axis=1, add_axis=0,
                 name='concat', **kwargs):
        """
        Args:
            model: model this layer belongs to.
            input_record (schema.Struct): struct whose fields must all be
                schema.Scalar; their blobs are concatenated.
            axis (int): axis to concatenate along, counting the implicit
                batch dimension as axis 0.
            add_axis (int): if nonzero, insert a new size-1 axis at `axis`
                in each input before concatenating (see class docstring).
            name (str): base name for the layer.
        """
        super(Concat, self).__init__(model, name, input_record, **kwargs)
        self.axis = axis
        self.add_axis = add_axis
        assert not (axis == 0 and add_axis == 1), \
            "It's not allowed to add axis=0"
        assert isinstance(input_record, schema.Struct), \
            "Incorrect input type. Expected Struct, but received: {0}".\
            format(input_record)

        # Collect the (batch-less) shape of every input field, applying
        # add_axis if requested.
        shapes = []
        for field_name, field_type in viewitems(input_record.fields):
            assert isinstance(field_type, schema.Scalar), \
                "Incorrect input type for {}. Expected Scalar, but got: {}".\
                format(field_name, field_type)
            # Assume that first dimension is batch, so actual axis in shape is
            # axis - 1
            shape = list(field_type.field_type().shape)
            if add_axis:
                shape.insert(axis - 1, 1)
            assert len(shape) >= axis, \
                "Concat expects that limited dimensions of the input tensor"
            shapes.append(shape)
        logger.info('Concat Layer input shapes: ' + str(shapes))

        if axis == 0:
            # Concatenating along the batch axis: the per-example schema is
            # unchanged, so reuse the schema of the first input field.
            self.output_schema = schema.from_blob_list(
                input_record[0],
                [self.get_next_blob_reference('output')]
            )
            return

        # Sum up the extents along the concat axis while verifying that all
        # other dimensions agree across inputs (the concat dim is zeroed out
        # so the remaining dims can be compared directly).
        concat_dim = 0
        for shape in shapes:
            concat_dim += shape[axis - 1]
            shape[axis - 1] = 0
            assert shape == shapes[0], \
                "Shapes {0} and {1} are not compatible for Concat".\
                format(shape, shapes[0])
        output_dims = shapes[0]
        output_dims[axis - 1] = concat_dim

        logger.info('Concat Layer output_dims: ' + str(output_dims))
        # BUGFIX: the assignment line was missing in this listing, leaving a
        # dangling tuple expression and output_schema never set on this path
        # (add_ops reads self.output_schema, so it must be assigned here).
        self.output_schema = schema.Scalar(
            (np.float32, output_dims),
            self.get_next_blob_reference('output'))

    def add_ops(self, net):
        """Emit the Concat operator.

        The second output blob ("<output>_concat_dims") receives the split
        sizes so the concatenation can be reversed downstream.
        """
        net.Concat(
            self.input_record.field_blobs(),
            [
                self.output_schema.field_blobs()[0],
                self.output_schema.field_blobs()[0] + "_concat_dims"
            ],
            axis=self.axis,
            add_axis=self.add_axis,
        )