Caffe2 - Python API
A deep-learning, cross-platform ML framework
last_n_window_collector.py
1 ## @package last_n_window_collector
2 # Module caffe2.python.layers.last_n_window_collector
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import ModelLayer
10 
11 
class LastNWindowCollector(ModelLayer):
    """
    Collect the last N samples from the input record.

    If you have complex data, use PackRecords to pack it before using
    this layer.

    This layer is not thread safe.

    Args:
        model: the owning model; ``model.NoOptim`` is used as the optimizer
            for every parameter blob this layer creates.
        input_record: a ``schema.Scalar`` whose samples are collected.
        num_to_collect: window size N passed to the operator; must be > 0.
        name: layer name, ``'last_n_window_collector'`` by default.
        **kwargs: forwarded unchanged to ``ModelLayer.__init__``.

    Output schema (a ``schema.Struct``):
        last_n: the ``last_n`` buffer blob, exposed with the same schema as
            ``input_record`` (via ``schema.from_blob_list``).
        num_visited: ``schema.Scalar`` wrapping the ``num_visited`` blob.
        mutex: ``schema.Scalar`` wrapping the mutex blob.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='last_n_window_collector', **kwargs):
        super(LastNWindowCollector, self).__init__(
            model, name, input_record, **kwargs)
        # Validation stays as ``assert`` so callers keep seeing
        # AssertionError as before; note these checks are skipped when
        # Python runs with -O.
        assert num_to_collect > 0, \
            "num_to_collect must be positive, got {!r}".format(num_to_collect)
        self.num_to_collect = num_to_collect
        assert isinstance(input_record, schema.Scalar), \
            "Got {!r}".format(input_record)

        # Window buffer, created with shape [0] (i.e. initially empty).
        self.last_n = self.create_param(param_name='last_n',
                                        shape=[0],
                                        initializer=('ConstantFill', {}),
                                        optimizer=model.NoOptim)

        # INT32 scalar initialised to 0; presumably the op's next write
        # position inside the buffer (TODO confirm against operator docs).
        self.next_blob = self.create_param(
            param_name='next',
            shape=[],
            initializer=('ConstantFill',
                         {'value': 0, 'dtype': core.DataType.INT32}),
            optimizer=model.NoOptim,
        )

        # Mutex blob built by CreateMutex and fed to the op as an input.
        # NOTE(review): the class docstring says this layer is not thread
        # safe even though a mutex is supplied to the op; confirm which
        # statement is current.
        self.mutex = self.create_param(
            param_name='mutex',
            shape=None,
            initializer=('CreateMutex',),
            optimizer=model.NoOptim,
        )

        # INT64 scalar initialised to 0; presumably the number of samples
        # the op has seen so far (TODO confirm).
        self.num_visited_blob = self.create_param(
            param_name='num_visited',
            shape=[],
            initializer=('ConstantFill', {
                'value': 0,
                'dtype': core.DataType.INT64,
            }),
            optimizer=model.NoOptim,
        )

        self.output_schema = schema.Struct(
            (
                'last_n',
                schema.from_blob_list(input_record, [self.last_n])
            ),
            ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
            ('mutex', schema.Scalar(blob=self.mutex)),
        )

    def add_ops(self, net):
        """Add one LastNWindowCollector operator to ``net``.

        Inputs are the ``last_n`` buffer, the ``next`` blob, the current
        input record, the mutex and the ``num_visited`` blob. The outputs
        are the very same ``last_n``, ``next`` and ``num_visited`` blobs,
        so the operator updates them in place.

        Args:
            net: the net the operator is appended to.
        """
        net.LastNWindowCollector(
            [self.last_n, self.next_blob, self.input_record(), self.mutex,
             self.num_visited_blob],
            [self.last_n, self.next_blob, self.num_visited_blob],
            num_to_collect=self.num_to_collect,
        )
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)
Definition: layers.py:334