Caffe2 - Python API
A deep learning, cross-platform ML framework
split.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package split
17 # Module caffe2.python.layers.split
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import schema
24 from caffe2.python.layers.layers import (
25  ModelLayer,
26 )
27 
28 
class Split(ModelLayer):
    """Split a Scalar input record into `num_splits` equal parts along `axis`.

    The output schema is a ``schema.Tuple`` of ``num_splits`` Scalars. Each
    has the input's base dtype and the input's field shape, with the split
    dimension divided by ``num_splits``. At graph-construction time a single
    ``Split`` operator produces all of the output blobs.
    """

    def __init__(self, model, input_record, num_splits, axis=1,
                 name='split', **kwargs):
        """Build the output schema for splitting `input_record`.

        Args:
            model: model the layer belongs to (forwarded to ModelLayer).
            input_record: the ``schema.Scalar`` record to split.
            num_splits: number of equal-sized outputs. Must be positive and
                evenly divide the size of the split dimension.
            axis: axis to split along, counted *including* the leading batch
                dimension. It must therefore be >= 1.
            name: layer name.
            **kwargs: forwarded to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if any of the preconditions above is violated.
        """
        super(Split, self).__init__(model, name, input_record, **kwargs)
        # The operator emitted in add_ops receives this axis unchanged.
        self.axis = axis
        # Assume that first dimension is batch, so the index of the split
        # dimension inside the record's per-example shape is axis - 1.
        shape_axis = axis - 1
        assert shape_axis >= 0, (
            "axis must be >= 1 (axis 0 is assumed to be the batch "
            "dimension), got {0}".format(axis)
        )

        assert isinstance(input_record, schema.Scalar),\
            "Incorrect input type. Expected Scalar, but received: {0}".\
            format(input_record)

        input_shape = input_record.field_type().shape
        # Strict '>' is required: input_shape[shape_axis] must be a valid index.
        assert len(input_shape) > shape_axis, (
            "axis {0} is out of range for input shape {1}".format(
                axis, input_shape)
        )
        assert num_splits > 0, (
            "num_splits must be positive, got {0}".format(num_splits)
        )
        assert input_shape[shape_axis] % num_splits == 0, (
            "dimension of size {0} cannot be split evenly into {1} "
            "parts".format(input_shape[shape_axis], num_splits)
        )

        output_shape = list(input_shape)
        # Exact integer division; divisibility was checked above.
        output_shape[shape_axis] = input_shape[shape_axis] // num_splits

        data_type = input_record.field_type().base

        output_scalars = [
            schema.Scalar(
                (data_type, output_shape),
                self.get_next_blob_reference('output_{}'.format(i)),
            )
            for i in range(num_splits)
        ]
        self.output_schema = schema.Tuple(*output_scalars)

    def add_ops(self, net):
        """Add one Split op that writes every output blob of this layer."""
        net.Split(
            self.input_record.field_blobs(),
            self.output_schema.field_blobs(),
            axis=self.axis,
        )