Caffe2 - Python API
A deep learning, cross platform ML framework
reservoir_sampling.py
1 ## @package reservoir_sampling
2 # Module caffe2.python.layers.reservoir_sampling
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import ModelLayer
10 
11 
13  """
14  Collect samples from input record w/ reservoir sampling. If you have complex
15  data, use PackRecords to pack it before using this layer.
16 
17  This layer is not thread safe.
18  """
19 
def __init__(self, model, input_record, num_to_collect,
             name='reservoir_sampling', **kwargs):
    """Declare the blobs that hold the reservoir and its bookkeeping state.

    Args:
        model: Layer model helper. It is forwarded to the base constructor
            and supplies the ``NoOptim`` marker used for every parameter
            created here.
        input_record: Input schema record. Its ``data`` field is the one
            fed to the sampling operator; an optional ``object_id`` field
            switches on two extra bookkeeping blobs.
        num_to_collect: Passed to the operator as its ``num_to_collect``
            argument. Must be greater than zero.
        name: Name given to the layer.
        **kwargs: Forwarded unchanged to the base constructor.

    Raises:
        AssertionError: If ``num_to_collect`` is not positive.
    """
    super(ReservoirSampling, self).__init__(
        model, name, input_record, **kwargs)
    assert num_to_collect > 0, 'num_to_collect must be positive'
    self.num_to_collect = num_to_collect

    # Blob holding the collected samples. It is declared with zero length
    # and is both an input and an output of the operator added in add_ops.
    self.reservoir = self.create_param(
        param_name='reservoir',
        shape=[0],
        initializer=('ConstantFill',),
        optimizer=model.NoOptim,
    )
    # Scalar int64 counter, initialised to 0 and updated by the operator.
    self.num_visited_blob = self.create_param(
        param_name='num_visited',
        shape=[],
        initializer=('ConstantFill', {
            'value': 0,
            'dtype': core.DataType.INT64,
        }),
        optimizer=model.NoOptim,
    )
    # Mutex blob created by the CreateMutex operator; it is passed to the
    # sampling operator as an input only.
    self.mutex = self.create_param(
        param_name='mutex',
        shape=None,
        initializer=('CreateMutex',),
        optimizer=model.NoOptim,
    )

    # Blobs appended to the operator's inputs/outputs only when the record
    # carries an 'object_id' field.
    self.extra_input_blobs = []
    self.extra_output_blobs = []
    if 'object_id' in input_record:
        # NOTE(review): presumably these two blobs let the operator track
        # which object occupies which reservoir slot -- confirm against
        # the ReservoirSampling operator documentation.
        object_to_pos = self.create_param(
            param_name='object_to_pos',
            shape=None,
            initializer=('CreateMap', {
                'key_dtype': core.DataType.INT64,
                # The CreateMap argument is spelled 'value_dtype'; the
                # previous 'valued_dtype' key was not a recognised
                # argument name.
                'value_dtype': core.DataType.INT32,
            }),
            optimizer=model.NoOptim,
        )
        pos_to_object = self.create_param(
            param_name='pos_to_object',
            shape=[0],
            initializer=('ConstantFill', {
                'value': 0,
                'dtype': core.DataType.INT64,
            }),
            optimizer=model.NoOptim,
        )
        self.extra_input_blobs.append(input_record.object_id())
        self.extra_input_blobs.extend([object_to_pos, pos_to_object])
        self.extra_output_blobs.extend([object_to_pos, pos_to_object])

    # NOTE(review): the statement opener below was absent from the listing
    # (the tuple arguments were dangling); restored as the layer's output
    # schema -- confirm against the upstream file.
    self.output_schema = schema.Struct(
        (
            'reservoir',
            schema.from_blob_list(input_record.data, [self.reservoir])
        ),
        ('num_visited', schema.Scalar(blob=self.num_visited_blob)),
        ('mutex', schema.Scalar(blob=self.mutex)),
    )
82 
def add_ops(self, net):
    """Add a single ReservoirSampling operator to ``net``.

    The reservoir and visit-counter blobs appear on both the input and the
    output side of the operator; the record's ``data`` blob and the mutex
    are inputs only. Any extra blobs collected by the constructor are
    appended to the matching side.

    Args:
        net: Net object exposing a ``ReservoirSampling`` operator builder.
    """
    sampling_op = net.ReservoirSampling

    op_inputs = [
        self.reservoir,
        self.num_visited_blob,
        self.input_record.data(),
        self.mutex,
    ]
    op_inputs.extend(self.extra_input_blobs)

    op_outputs = [self.reservoir, self.num_visited_blob]
    op_outputs.extend(self.extra_output_blobs)

    sampling_op(op_inputs, op_outputs, num_to_collect=self.num_to_collect)
def create_param(self, param_name, shape, initializer, optimizer, ps_param=None, regularizer=None)
Definition: layers.py:334