3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
14 Collect last-N samples from input record. If you have complex data, 15 use PackRecords to pack it before using this layer. 17 This layer is not thread safe. 20 def __init__(self, model, input_record, num_to_collect,
21 name=
'last_n_window_collector', **kwargs):
22 super(LastNWindowCollector, self).__init__(
23 model, name, input_record, **kwargs)
24 assert num_to_collect > 0
27 "Got {!r}".format(input_record)
31 initializer=(
'ConstantFill', {}),
32 optimizer=model.NoOptim)
37 initializer=(
'ConstantFill',
38 {
'value': 0,
'dtype': core.DataType.INT32}),
39 optimizer=model.NoOptim
45 initializer=(
'CreateMutex',),
46 optimizer=model.NoOptim,
50 param_name=
'num_visited',
52 initializer=(
'ConstantFill', {
54 'dtype': core.DataType.INT64,
56 optimizer=model.NoOptim,
62 schema.from_blob_list(input_record, [self.
last_n])
68 def add_ops(self, net):
69 net.LastNWindowCollector(
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)