// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// reservoir_sampling.cc
#include <algorithm>
#include <cstring>
#include <memory>
#include <mutex>
#include <random>
#include <set>
#include <string>
#include <vector>

#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/map_ops.h"
7 
8 namespace caffe2 {
9 namespace {
10 
11 template <class Context>
12 class ReservoirSamplingOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  ReservoirSamplingOp(const OperatorDef operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  numToCollect_(
18  OperatorBase::GetSingleArgument<int>("num_to_collect", -1)) {
19  CAFFE_ENFORCE(numToCollect_ > 0);
20  }
21 
22  bool RunOnDevice() override {
23  auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
24  std::lock_guard<std::mutex> guard(*mutex);
25 
26  auto* output = Output(RESERVOIR);
27  const auto& input = Input(DATA);
28 
29  CAFFE_ENFORCE_GE(input.dim(), 1);
30 
31  bool output_initialized = output->numel() > 0 &&
32  (static_cast<std::shared_ptr<std::vector<TensorCPU>>*>(
33  output->raw_mutable_data(input.dtype()))[0] != nullptr);
34 
35  if (output_initialized) {
36  CAFFE_ENFORCE_EQ(output->dim(), input.dim());
37  for (size_t i = 1; i < input.dim(); ++i) {
38  CAFFE_ENFORCE_EQ(output->size(i), input.size(i));
39  }
40  }
41 
42  auto num_entries = input.sizes()[0];
43 
44  if (!output_initialized) {
45  // IMPORTANT: Force the output to have the right type before reserving,
46  // so that the output gets the right capacity
47  auto dims = input.sizes().vec();
48  dims[0] = 0;
49  output->Resize(dims);
50  output->raw_mutable_data(input.dtype());
51  output->ReserveSpace(numToCollect_);
52  }
53 
54  auto* pos_to_object =
55  OutputSize() > POS_TO_OBJECT ? Output(POS_TO_OBJECT) : nullptr;
56  if (pos_to_object) {
57  if (!output_initialized) {
58  // Cleaning up in case the reservoir got reset.
59  pos_to_object->Resize(0);
60  pos_to_object->template mutable_data<int64_t>();
61  pos_to_object->ReserveSpace(numToCollect_);
62  }
63  }
64 
65  auto* object_to_pos_map = OutputSize() > OBJECT_TO_POS_MAP
66  ? OperatorBase::Output<MapType64To32>(OBJECT_TO_POS_MAP)
67  : nullptr;
68 
69  if (object_to_pos_map && !output_initialized) {
70  object_to_pos_map->clear();
71  }
72 
73  auto* num_visited_tensor = Output(NUM_VISITED);
74  CAFFE_ENFORCE_EQ(1, num_visited_tensor->numel());
75  auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
76  if (!output_initialized) {
77  *num_visited = 0;
78  }
79  CAFFE_ENFORCE_GE(*num_visited, 0);
80 
81  if (num_entries == 0) {
82  if (!output_initialized) {
83  // Get both shape and meta
84  output->CopyFrom(input, /* async */ true);
85  }
86  return true;
87  }
88 
89  const int64_t* object_id_data = nullptr;
90  std::set<int64_t> unique_object_ids;
91  if (InputSize() > OBJECT_ID) {
92  const auto& object_id = Input(OBJECT_ID);
93  CAFFE_ENFORCE_EQ(object_id.dim(), 1);
94  CAFFE_ENFORCE_EQ(object_id.numel(), num_entries);
95  object_id_data = object_id.template data<int64_t>();
96  unique_object_ids.insert(
97  object_id_data, object_id_data + object_id.numel());
98  }
99 
100  const auto num_new_entries = countNewEntries(unique_object_ids);
101  auto num_to_copy = std::min<int32_t>(num_new_entries, numToCollect_);
102  auto output_batch_size = output_initialized ? output->size(0) : 0;
103  auto output_num =
104  std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
105  // output_num is >= output_batch_size
106  output->ExtendTo(output_num, 50);
107  if (pos_to_object) {
108  pos_to_object->ExtendTo(output_num, 50);
109  // ExtendTo doesn't zero-initialize tensors any more, explicitly clear
110  // the memory
111  memset(
112  pos_to_object->template mutable_data<int64_t>() +
113  output_batch_size * sizeof(int64_t),
114  0,
115  (output_num - output_batch_size) * sizeof(int64_t));
116  }
117 
118  auto* output_data =
119  static_cast<char*>(output->raw_mutable_data(input.dtype()));
120  auto* pos_to_object_data = pos_to_object
121  ? pos_to_object->template mutable_data<int64_t>()
122  : nullptr;
123 
124  auto block_size = input.size_from_dim(1);
125  auto block_bytesize = block_size * input.itemsize();
126  const auto* input_data = static_cast<const char*>(input.raw_data());
127 
128  const auto start_num_visited = *num_visited;
129 
130  std::set<int64_t> eligible_object_ids;
131  if (object_to_pos_map) {
132  for (auto oid : unique_object_ids) {
133  if (!object_to_pos_map->count(oid)) {
134  eligible_object_ids.insert(oid);
135  }
136  }
137  }
138 
139  for (int i = 0; i < num_entries; ++i) {
140  if (object_id_data && object_to_pos_map &&
141  !eligible_object_ids.count(object_id_data[i])) {
142  // Already in the pool or processed
143  continue;
144  }
145  if (object_id_data) {
146  eligible_object_ids.erase(object_id_data[i]);
147  }
148  int64_t pos = -1;
149  if (*num_visited < numToCollect_) {
150  // append
151  pos = *num_visited;
152  } else {
153  auto& gen = context_.RandGenerator();
154  // uniform between [0, num_visited]
155  std::uniform_int_distribution<int64_t> uniformDist(0, *num_visited);
156  pos = uniformDist(gen);
157  if (pos >= numToCollect_) {
158  // discard
159  pos = -1;
160  }
161  }
162 
163  if (pos < 0) {
164  // discard
165  CAFFE_ENFORCE_GE(*num_visited, numToCollect_);
166  } else {
167  // replace
168  context_.CopyItemsSameDevice(
169  input.dtype(),
170  block_size,
171  input_data + i * block_bytesize,
172  output_data + pos * block_bytesize);
173 
174  if (object_id_data && pos_to_object_data && object_to_pos_map) {
175  auto old_oid = pos_to_object_data[pos];
176  auto new_oid = object_id_data[i];
177  pos_to_object_data[pos] = new_oid;
178  object_to_pos_map->erase(old_oid);
179  object_to_pos_map->emplace(new_oid, pos);
180  }
181  }
182 
183  ++(*num_visited);
184  }
185  // Sanity check
186  CAFFE_ENFORCE_EQ(*num_visited, start_num_visited + num_new_entries);
187  return true;
188  }
189 
190  private:
191  // number of tensors to collect
192  int numToCollect_;
193 
194  INPUT_TAGS(
195  RESERVOIR_IN,
196  NUM_VISITED_IN,
197  DATA,
198  MUTEX,
199  OBJECT_ID,
200  OBJECT_TO_POS_MAP_IN,
201  POS_TO_OBJECT_IN);
202  OUTPUT_TAGS(RESERVOIR, NUM_VISITED, OBJECT_TO_POS_MAP, POS_TO_OBJECT);
203 
204  int32_t countNewEntries(const std::set<int64_t>& unique_object_ids) {
205  const auto& input = Input(DATA);
206  if (InputSize() <= OBJECT_ID) {
207  return input.size(0);
208  }
209  const auto& object_to_pos_map =
210  OperatorBase::Input<MapType64To32>(OBJECT_TO_POS_MAP_IN);
211  return std::count_if(
212  unique_object_ids.begin(),
213  unique_object_ids.end(),
214  [&object_to_pos_map](int64_t oid) {
215  return !object_to_pos_map.count(oid);
216  });
217  }
218 };
219 
// Only the CPU variant is registered, even though the op is templated on
// Context.
REGISTER_CPU_OPERATOR(ReservoirSampling, ReservoirSamplingOp<CPUContext>);

OPERATOR_SCHEMA(ReservoirSampling)
    // Two valid arities: the basic form (4 in / 2 out) and the deduplicating
    // form with OBJECT_ID plus bookkeeping blobs (7 in / 4 out).
    .NumInputs({4, 7})
    .NumOutputs({2, 4})
    // Pairs the arities: 4/3 == 2/2 (== 1) and 7/3 == 4/2 (== 2), so mixed
    // combinations like 7 inputs with 2 outputs are rejected.
    .NumInputsOutputs([](int in, int out) { return in / 3 == out / 2; })
    // State blobs are updated in place: reservoir, counter, and the two
    // optional dedup structures map input slots onto output slots.
    .EnforceInplace({{0, 0}, {1, 1}, {5, 2}, {6, 3}})
    .SetDoc(R"DOC(
Collect `DATA` tensor into `RESERVOIR` of size `num_to_collect`. `DATA` is
assumed to be a batch.

In case where 'objects' may be repeated in data and you only want at most one
instance of each 'object' in the reservoir, `OBJECT_ID` can be given for
deduplication. If `OBJECT_ID` is given, then you also need to supply additional
book-keeping tensors. See input blob documentation for details.

This operator is thread-safe.
)DOC")
    .Arg(
        "num_to_collect",
        "The number of random samples to append for each positive samples")
    .Input(
        0,
        "RESERVOIR",
        "The reservoir; should be initialized to empty tensor")
    .Input(
        1,
        "NUM_VISITED",
        "Number of examples seen so far; should be initialized to 0")
    .Input(
        2,
        "DATA",
        "Tensor to collect from. The first dimension is assumed to be batch "
        "size. If the object to be collected is represented by multiple "
        "tensors, use `PackRecords` to pack them into single tensor.")
    .Input(3, "MUTEX", "Mutex to prevent data race")
    .Input(
        4,
        "OBJECT_ID",
        "(Optional, int64) If provided, used for deduplicating object in the "
        "reservoir")
    .Input(
        5,
        "OBJECT_TO_POS_MAP_IN",
        "(Optional) Auxillary bookkeeping map. This should be created from "
        " `CreateMap` with keys of type int64 and values of type int32")
    .Input(
        6,
        "POS_TO_OBJECT_IN",
        "(Optional) Tensor of type int64 used for bookkeeping in deduplication")
    .Output(0, "RESERVOIR", "Same as the input")
    .Output(1, "NUM_VISITED", "Same as the input")
    .Output(2, "OBJECT_TO_POS_MAP", "(Optional) Same as the input")
    .Output(3, "POS_TO_OBJECT", "(Optional) Same as the input");

// Random sampling is not differentiable; no gradient should be requested.
SHOULD_NOT_DO_GRADIENT(ReservoirSampling);
276 } // namespace
277 } // namespace caffe2
// (Doxygen extraction residue, preserved as a comment — not part of this
// translation unit:)
// A global dictionary that holds information about what Caffe2 modules have
// been loaded in the current ...
// Definition: blob.h:13