Caffe2 - Python API
A deep learning, cross platform ML framework
last_n_window_collector.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package last_n_window_collector
17 # Module caffe2.python.layers.last_n_window_collector
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core, schema
24 from caffe2.python.layers.layers import ModelLayer
25 
26 
class LastNWindowCollector(ModelLayer):
    """
    Collect last-N samples from input record. If you have complex data,
    use PackRecords to pack it before using this layer.

    Output schema (``self.output_schema``):
        last_n: the collected samples, laid out like ``input_record`` but
            backed by this layer's ``last_n`` parameter blob.
        num_visited: INT64 blob (initialized to 0) that the collector op
            reads and writes back; presumably the number of samples seen
            so far -- confirm against the op.
        mutex: the mutex blob that is passed to the collector op.

    NOTE(review): earlier documentation stated "This layer is not thread
    safe", yet a mutex blob is created here and fed to the
    LastNWindowCollector op. Confirm the op's locking behavior before
    relying on either statement.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='last_n_window_collector', **kwargs):
        """
        Create the state blobs and the output schema of the collector.

        Args:
            model: model that owns the layer. ``model.NoOptim`` is used for
                every blob created here, so none of them are optimized.
            input_record: ``schema.Scalar`` whose samples are collected.
            num_to_collect: window size N; must be strictly positive.
            name: layer name.
            **kwargs: forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``num_to_collect`` is not positive or if
                ``input_record`` is not a ``schema.Scalar``.
        """
        super(LastNWindowCollector, self).__init__(
            model, name, input_record, **kwargs)
        assert num_to_collect > 0, \
            "num_to_collect must be positive, got {!r}".format(num_to_collect)
        self.num_to_collect = num_to_collect
        assert isinstance(input_record, schema.Scalar), \
            "Got {!r}".format(input_record)

        # Sample buffer. It is created with shape [0]; presumably the op
        # grows it up to num_to_collect entries -- confirm against the op.
        self.last_n = self.create_param(param_name='last_n',
                                        shape=[0],
                                        initializer=('ConstantFill', {}),
                                        optimizer=model.NoOptim)

        # INT32 scalar starting at 0 that the op reads and writes back;
        # by its name, likely the next write position in the buffer.
        self.next_blob = self.create_param(
            param_name='next',
            shape=[],
            initializer=('ConstantFill',
                         {'value': 0, 'dtype': core.DataType.INT32}),
            optimizer=model.NoOptim
        )

        # Mutex blob created via CreateMutex and handed to the op.
        self.mutex = self.create_param(
            param_name='mutex',
            shape=None,
            initializer=('CreateMutex',),
            optimizer=model.NoOptim,
        )

        # INT64 scalar starting at 0, read and written back by the op.
        self.num_visited_blob = self.create_param(
            param_name='num_visited',
            shape=[],
            initializer=('ConstantFill', {
                'value': 0,
                'dtype': core.DataType.INT64,
            }),
            optimizer=model.NoOptim,
        )

        # Expose the collected window (shaped like the input record), the
        # visit counter and the mutex to downstream layers.
        self.output_schema = schema.Struct(
            (
                'last_n',
                schema.from_blob_list(input_record, [self.last_n])
            ),
            ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
            ('mutex', schema.Scalar(blob=self.mutex)),
        )

    def add_ops(self, net):
        """
        Add a single LastNWindowCollector op to ``net``.

        The op consumes the current buffer, the ``next`` blob, the input
        record blob, the mutex and the ``num_visited`` blob, and writes
        the buffer, ``next`` and ``num_visited`` back in place.
        """
        net.LastNWindowCollector(
            [self.last_n, self.next_blob, self.input_record(), self.mutex,
             self.num_visited_blob],
            [self.last_n, self.next_blob, self.num_visited_blob],
            num_to_collect=self.num_to_collect,
        )
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)
Definition: layers.py:337