3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
14 Given 1-D `indices` tensor, gather elements at `i` in `indices` from all the 15 blobs in `record`. If a blob is a values blob of a list, all the elements 16 included by the list's lengths blob are gathered. For example, 20 record:a = [[0, 1], [2, 3], [4, 5], [6, 7]] 21 record:b:lengths = [0, 1, 2, 3] 22 record:b:items = [0, 1, 2, 3, 4, 5] 29 This supports nested list. 32 def __init__(self, model, input_record, name='gather_record', **kwargs):
33 super(GatherRecord, self).__init__(model, name, input_record, **kwargs)
35 assert 'indices' in input_record
36 assert 'record' in input_record
39 model.net, input_record.record.clone_schema())
41 self.
_indices = self.input_record.indices()
43 def _gather_scalar(self, net, record, lengths_blob, output_record):
44 if lengths_blob
is None:
45 net.Gather([record(), self.
_indices], output_record())
47 net.LengthsGather([record(), lengths_blob, self.
_indices],
50 def _gather_struct(self, net, record, lengths_blob, output_record):
51 for name, field
in record.get_children():
52 self.
_dispatch(net, field, lengths_blob, output_record[name])
54 def _gather_list(self, net, record, lengths_blob, output_record):
56 net, record.lengths, lengths_blob, output_record.lengths)
57 if lengths_blob
is None:
58 lengths_blob = record.lengths()
62 lengths_float = net.Cast(
64 net.NextScopedBlob(str(record.lengths()) +
'_float'),
65 to=core.DataType.FLOAT,
67 lengths_blob_float = net.LengthsSum(
68 [lengths_float, lengths_blob],
69 net.NextScopedBlob(str(record.lengths()) +
"_nested_float")
71 lengths_blob = net.Cast(
73 net.NextScopedBlob(str(record.lengths()) +
"_nested"),
74 to=core.DataType.INT32,
76 self.
_dispatch(net, record._items, lengths_blob, output_record._items)
78 def _dispatch(self, net, record, lengths_blob, output_record):
84 self.
_gather_list(net, record, lengths_blob, output_record)
86 raise NotImplementedError
88 def add_ops(self, net):
def _gather_scalar(self, net, record, lengths_blob, output_record)
def _gather_list(self, net, record, lengths_blob, output_record)
def _gather_struct(self, net, record, lengths_blob, output_record)
def _dispatch(self, net, record, lengths_blob, output_record)