Caffe2 - Python API
A deep learning, cross-platform ML framework
gather_record.py
1 ## @package gather_record
2 # Module caffe2.python.layers.gather_record
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import ModelLayer
10 
11 
13  """
14  Given a 1-D `indices` tensor, gather the elements at each `i` in `indices` from all the
15  blobs in `record`. If a blob is a values blob of a list, all the elements
16  included by the list's lengths blob are gathered. For example,
17 
18  Input:
19  indices = [0, 2]
20  record:a = [[0, 1], [2, 3], [4, 5], [6, 7]]
21  record:b:lengths = [0, 1, 2, 3]
22  record:b:items = [0, 1, 2, 3, 4, 5]
23 
24  Output:
25  a = [[0, 1], [4, 5]]
26  b:lengths = [0, 2]
27  b:items = [1, 2]
28 
29  This supports nested lists.
30  """
31 
def __init__(self, model, input_record, name='gather_record', **kwargs):
    """Validate the input record, build the output schema, cache indices.

    Args:
        model: the layer model; ``model.net`` is passed to
            ``schema.NewRecord`` to instantiate the output record.
        input_record: record that must contain an ``indices`` field and a
            ``record`` field; ``record`` is the (possibly nested) record
            whose entries are gathered.
        name: layer name, forwarded to the parent layer's ``__init__``.
        **kwargs: forwarded to the parent layer's ``__init__``.

    Raises:
        AssertionError: if ``input_record`` lacks ``indices`` or ``record``.
    """
    super(GatherRecord, self).__init__(model, name, input_record, **kwargs)

    # Kept as asserts (not ValueError) so the exception type seen by callers
    # is unchanged; the messages state which required field is missing.
    assert 'indices' in input_record, (
        "GatherRecord input_record must contain an 'indices' field")
    assert 'record' in input_record, (
        "GatherRecord input_record must contain a 'record' field")

    # The output has the same schema as the input ``record`` field,
    # instantiated on model.net.
    self.output_schema = schema.NewRecord(
        model.net, input_record.record.clone_schema())

    # Indices blob shared by every Gather / LengthsGather op this layer adds.
    self._indices = self.input_record.indices()
42 
def _gather_scalar(self, net, record, lengths_blob, output_record):
    """Add the op that gathers one scalar field into ``output_record``.

    Args:
        net: net the op is added to.
        record: scalar field to read from; ``record()`` yields its blob.
        lengths_blob: ``None`` when the field is not inside a list; otherwise
            the lengths blob splitting ``record`` into per-entry segments.
        output_record: scalar field whose blob receives the gathered data.

    At top level a plain ``Gather`` selects rows at ``self._indices``.
    Inside a list, ``LengthsGather`` is used instead, so each selected index
    pulls its whole segment as described by ``lengths_blob``.
    """
    if lengths_blob is not None:
        net.LengthsGather([record(), lengths_blob, self._indices],
                          output_record())
        return
    net.Gather([record(), self._indices], output_record())
49 
def _gather_struct(self, net, record, lengths_blob, output_record):
    """Gather every child of a Struct field into the matching output child.

    Each child field is dispatched with the same ``lengths_blob`` it was
    given (a struct adds no extra nesting level), and written into the
    same-named child of ``output_record``.
    """
    children = record.get_children()
    for child_name, child_field in children:
        child_output = output_record[child_name]
        self._dispatch(net, child_field, lengths_blob, child_output)
53 
def _gather_list(self, net, record, lengths_blob, output_record):
    """Gather a List field: first its lengths, then its items.

    The list's own ``lengths`` blob is gathered like a scalar (segment-wise
    when ``lengths_blob`` is set). The items are then gathered using a
    lengths blob that gives the number of items owned by each entry being
    indexed:

    * top level (``lengths_blob is None``): ``record.lengths()`` itself;
    * nested: for each segment of ``lengths_blob``, the sum of this list's
      lengths over that segment, computed with ``LengthsSum``.
    """
    self._gather_scalar(
        net, record.lengths, lengths_blob, output_record.lengths)

    if lengths_blob is not None:
        lengths_name = str(record.lengths())
        # TODO(kittipat): This is a hacky solution until LengthsSum for int
        # is implemented -- lengths are cast to FLOAT, summed, and cast back
        # to INT32.
        # NOTE(review): if FLOAT is float32, integers above 2**24 are not
        # exactly representable, so very large summed lengths could round;
        # confirm this bound is acceptable.
        float_lengths = net.Cast(
            record.lengths(),
            net.NextScopedBlob(lengths_name + '_float'),
            to=core.DataType.FLOAT,
        )
        summed_float_lengths = net.LengthsSum(
            [float_lengths, lengths_blob],
            net.NextScopedBlob(lengths_name + "_nested_float")
        )
        items_lengths = net.Cast(
            summed_float_lengths,
            net.NextScopedBlob(lengths_name + "_nested"),
            to=core.DataType.INT32,
        )
    else:
        items_lengths = record.lengths()

    self._dispatch(net, record._items, items_lengths, output_record._items)
77 
def _dispatch(self, net, record, lengths_blob, output_record):
    """Route ``record`` to the gather helper matching its schema type.

    Args:
        net: net the gather ops are added to.
        record: schema field to gather (``Scalar``, ``Struct`` or ``List``).
        lengths_blob: ``None`` at top level; inside a list, the lengths blob
            describing how many elements of ``record`` belong to each entry.
        output_record: output field with the same schema as ``record``.

    Raises:
        NotImplementedError: if ``record`` is any other kind of field.
    """
    if isinstance(record, schema.Scalar):
        self._gather_scalar(net, record, lengths_blob, output_record)
    elif isinstance(record, schema.Struct):
        self._gather_struct(net, record, lengths_blob, output_record)
    elif isinstance(record, schema.List):
        self._gather_list(net, record, lengths_blob, output_record)
    else:
        # Name the offending type so unsupported schemas are easy to find.
        raise NotImplementedError(
            "GatherRecord does not support schema field of type {}".format(
                type(record).__name__))
87 
def add_ops(self, net):
    """Add the gather ops for the whole input ``record`` to ``net``."""
    # The top-level record is not nested inside any list, so no lengths blob.
    top_level_lengths = None
    self._dispatch(
        net, self.input_record.record, top_level_lengths, self.output_schema)
def _gather_scalar(self, net, record, lengths_blob, output_record)
def _gather_list(self, net, record, lengths_blob, output_record)
def _gather_struct(self, net, record, lengths_blob, output_record)
def _dispatch(self, net, record, lengths_blob, output_record)