Caffe2 - Python API
A deep learning, cross-platform ML framework
gather_record.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package gather_record
17 # Module caffe2.python.layers.gather_record
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core, schema
24 from caffe2.python.layers.layers import ModelLayer
25 
26 
28  """
29  Given a 1-D `indices` tensor, gather the elements at every position `i` listed
30  in `indices` from all the blobs in `record`. If a blob is the values blob of a
31  list, all the elements covered by the list's lengths blob are gathered. For example:
32 
33  Input:
34  indices = [0, 2]
35  record:a = [[0, 1], [2, 3], [4, 5], [6, 7]]
36  record:b:lengths = [0, 1, 2, 3]
37  record:b:items = [0, 1, 2, 3, 4, 5]
38 
39  Output:
40  a = [[0, 1], [4, 5]]
41  b:lengths = [0, 2]
42  b:items = [1, 2]
43 
44  Nested lists are supported as well.
45  """
46 
def __init__(self, model, input_record, name='gather_record', **kwargs):
    """Check the input record and declare the layer's output schema.

    `input_record` must contain an `indices` field and a `record` field.
    The output schema is a new record created on `model.net` from a clone
    of the schema of `input_record.record`.
    """
    super(GatherRecord, self).__init__(model, name, input_record, **kwargs)

    # Both fields are mandatory; checked in the same order as before.
    for required_field in ('indices', 'record'):
        assert required_field in input_record

    cloned_schema = input_record.record.clone_schema()
    self.output_schema = schema.NewRecord(model.net, cloned_schema)

    # Result of calling the `indices` field; passed as the index input to
    # every Gather / LengthsGather op this layer emits.
    self._indices = self.input_record.indices()
57 
def _gather_scalar(self, net, record, lengths_blob, output_record):
    """Emit the gather op for a single scalar field.

    When `lengths_blob` is None, a `Gather` op is added with inputs
    `[record(), self._indices]`. Otherwise a `LengthsGather` op is added
    with inputs `[record(), lengths_blob, self._indices]`. In both cases
    the op writes to `output_record()`.
    """
    if lengths_blob is not None:
        # A lengths blob was supplied by an enclosing list.
        net.LengthsGather(
            [record(), lengths_blob, self._indices],
            output_record(),
        )
        return

    net.Gather([record(), self._indices], output_record())
64 
def _gather_struct(self, net, record, lengths_blob, output_record):
    """Gather every child field of a struct.

    Each `(name, field)` pair from `record.get_children()` is dispatched
    with the same `lengths_blob`, writing to the output field that has the
    same name.
    """
    for child_name, child_field in record.get_children():
        child_output = output_record[child_name]
        self._dispatch(net, child_field, lengths_blob, child_output)
68 
def _gather_list(self, net, record, lengths_blob, output_record):
    """Gather a list field: first its lengths, then its items.

    The list's `lengths` field is gathered like any scalar. Its `_items`
    are then dispatched with a lengths blob chosen as follows:

    * `lengths_blob` is None: `record.lengths()` is used directly.
    * otherwise (the list sits inside another list): this list's lengths
      are passed through `LengthsSum` together with the incoming
      `lengths_blob`, and the result is used.
    """
    # The list's own lengths are gathered exactly like a scalar field.
    self._gather_scalar(
        net, record.lengths, lengths_blob, output_record.lengths)

    if lengths_blob is not None:
        # TODO(kittipat): This is a hacky solution until LengthsSum for int
        # is implemented
        # Hence the round trip: cast to float, LengthsSum, cast to int32.
        inner_lengths_float = net.Cast(
            record.lengths(),
            net.NextScopedBlob(str(record.lengths()) + '_float'),
            to=core.DataType.FLOAT,
        )
        summed_lengths_float = net.LengthsSum(
            [inner_lengths_float, lengths_blob],
            net.NextScopedBlob(str(record.lengths()) + "_nested_float")
        )
        items_lengths = net.Cast(
            summed_lengths_float,
            net.NextScopedBlob(str(record.lengths()) + "_nested"),
            to=core.DataType.INT32,
        )
    else:
        items_lengths = record.lengths()

    self._dispatch(net, record._items, items_lengths, output_record._items)
92 
def _dispatch(self, net, record, lengths_blob, output_record):
    """Route `record` to the gather routine matching its schema type.

    `schema.Scalar` goes to `_gather_scalar`, `schema.Struct` to
    `_gather_struct` and `schema.List` to `_gather_list`; all arguments
    are forwarded unchanged.

    Raises:
        NotImplementedError: if `record` is none of the three supported
            schema types. The message names the offending type.
    """
    if isinstance(record, schema.Scalar):
        self._gather_scalar(net, record, lengths_blob, output_record)
    elif isinstance(record, schema.Struct):
        self._gather_struct(net, record, lengths_blob, output_record)
    elif isinstance(record, schema.List):
        self._gather_list(net, record, lengths_blob, output_record)
    else:
        # Same exception class as before, but now with a diagnostic so the
        # failing field type can be identified from the traceback.
        raise NotImplementedError(
            'GatherRecord does not support records of type {}'.format(
                type(record).__name__))
102 
def add_ops(self, net):
    """Add the gather ops for the whole input record to `net`.

    Dispatch starts at `self.input_record.record` with no lengths blob
    (None) and writes into `self.output_schema`.
    """
    source_record = self.input_record.record
    self._dispatch(net, source_record, None, self.output_schema)
def _gather_scalar(self, net, record, lengths_blob, output_record)
def _gather_list(self, net, record, lengths_blob, output_record)
def _gather_struct(self, net, record, lengths_blob, output_record)
def _dispatch(self, net, record, lengths_blob, output_record)