3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
14 Collect samples from input record with reservoir sampling. If you have complex data, use PackRecords to pack it before using this layer. This layer is not thread-safe. 20 def __init__(self, model, input_record, num_to_collect,
21 name=
'reservoir_sampling', **kwargs):
22 super(ReservoirSampling, self).__init__(
23 model, name, input_record, **kwargs)
24 assert num_to_collect > 0
28 param_name=
'reservoir',
30 initializer=(
'ConstantFill',),
31 optimizer=model.NoOptim,
34 param_name=
'num_visited',
36 initializer=(
'ConstantFill', {
38 'dtype': core.DataType.INT64,
40 optimizer=model.NoOptim,
45 initializer=(
'CreateMutex',),
46 optimizer=model.NoOptim,
51 if 'object_id' in input_record:
53 param_name=
'object_to_pos',
55 initializer=(
'CreateMap', {
56 'key_dtype': core.DataType.INT64,
57 'valued_dtype': core.DataType.INT32,
59 optimizer=model.NoOptim,
62 param_name=
'pos_to_object',
64 initializer=(
'ConstantFill', {
66 'dtype': core.DataType.INT64,
68 optimizer=model.NoOptim,
70 self.extra_input_blobs.append(input_record.object_id())
71 self.extra_input_blobs.extend([object_to_pos, pos_to_object])
72 self.extra_output_blobs.extend([object_to_pos, pos_to_object])
77 schema.from_blob_list(input_record.data, [self.
reservoir])
83 def add_ops(self, net):
84 net.ReservoirSampling(
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)