Caffe2 - C++ API
A deep learning, cross platform ML framework
reservoir_sampling.cc
1 
17 #include <memory>
18 #include <string>
19 #include <vector>
20 #include "caffe2/core/operator.h"
21 #include "caffe2/core/tensor.h"
22 #include "caffe2/operators/map_ops.h"
23 
24 namespace caffe2 {
25 namespace {
26 
27 template <class Context>
28 class ReservoirSamplingOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  ReservoirSamplingOp(const OperatorDef operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  numToCollect_(
34  OperatorBase::GetSingleArgument<int>("num_to_collect", -1)) {
35  CAFFE_ENFORCE(numToCollect_ > 0);
36  }
37 
38  bool RunOnDevice() override {
39  auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
40  std::lock_guard<std::mutex> guard(*mutex);
41 
42  auto* output = Output(RESERVOIR);
43  const auto& input = Input(DATA);
44 
45  CAFFE_ENFORCE_GE(input.ndim(), 1);
46 
47  bool output_initialized = output->size() > 0 &&
48  (static_cast<std::shared_ptr<std::vector<TensorCPU>>*>(
49  output->raw_mutable_data(input.meta()))[0] != nullptr);
50 
51  if (output_initialized) {
52  CAFFE_ENFORCE_EQ(output->ndim(), input.ndim());
53  for (size_t i = 1; i < input.ndim(); ++i) {
54  CAFFE_ENFORCE_EQ(output->dim(i), input.dim(i));
55  }
56  }
57 
58  auto dims = input.dims();
59  auto num_entries = dims[0];
60 
61  dims[0] = numToCollect_;
62  // IMPORTANT: Force the output to have the right type before reserving,
63  // so that the output gets the right capacity
64  output->raw_mutable_data(input.meta());
65  output->Reserve(dims, &context_);
66 
67  auto* pos_to_object =
68  OutputSize() > POS_TO_OBJECT ? Output(POS_TO_OBJECT) : nullptr;
69  if (pos_to_object) {
70  pos_to_object->Reserve(std::vector<TIndex>{numToCollect_}, &context_);
71  }
72 
73  if (num_entries == 0) {
74  if (!output_initialized) {
75  // Get both shape and meta
76  output->CopyFrom(input, &context_);
77  }
78  return true;
79  }
80 
81  const int64_t* object_id_data = nullptr;
82  std::set<int64_t> unique_object_ids;
83  if (InputSize() > OBJECT_ID) {
84  const auto& object_id = Input(OBJECT_ID);
85  CAFFE_ENFORCE_EQ(object_id.ndim(), 1);
86  CAFFE_ENFORCE_EQ(object_id.size(), num_entries);
87  object_id_data = object_id.template data<int64_t>();
88  unique_object_ids.insert(
89  object_id_data, object_id_data + object_id.size());
90  }
91 
92  const auto num_new_entries = countNewEntries(unique_object_ids);
93  auto num_to_copy = std::min<int32_t>(num_new_entries, numToCollect_);
94  auto output_batch_size = output_initialized ? output->dim(0) : 0;
95  dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
96  if (output_batch_size < numToCollect_) {
97  output->Resize(dims);
98  if (pos_to_object) {
99  pos_to_object->Resize(dims[0]);
100  }
101  }
102  auto* output_data =
103  static_cast<char*>(output->raw_mutable_data(input.meta()));
104  auto* pos_to_object_data = pos_to_object
105  ? pos_to_object->template mutable_data<int64_t>()
106  : nullptr;
107 
108  auto block_size = input.size_from_dim(1);
109  auto block_bytesize = block_size * input.itemsize();
110  const auto* input_data = static_cast<const char*>(input.raw_data());
111 
112  auto* num_visited_tensor = Output(NUM_VISITED);
113  CAFFE_ENFORCE_EQ(1, num_visited_tensor->size());
114  auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
115  if (!output_initialized) {
116  *num_visited = 0;
117  }
118  CAFFE_ENFORCE_GE(*num_visited, 0);
119 
120  const auto start_num_visited = *num_visited;
121 
122  auto* object_to_pos_map = OutputSize() > OBJECT_TO_POS_MAP
123  ? OperatorBase::Output<MapType64To32>(OBJECT_TO_POS_MAP)
124  : nullptr;
125 
126  std::set<int64_t> eligible_object_ids;
127  if (object_to_pos_map) {
128  for (auto oid : unique_object_ids) {
129  if (!object_to_pos_map->count(oid)) {
130  eligible_object_ids.insert(oid);
131  }
132  }
133  }
134 
135  for (int i = 0; i < num_entries; ++i) {
136  if (object_id_data && object_to_pos_map &&
137  !eligible_object_ids.count(object_id_data[i])) {
138  // Already in the pool or processed
139  continue;
140  }
141  if (object_id_data) {
142  eligible_object_ids.erase(object_id_data[i]);
143  }
144  int64_t pos = -1;
145  if (*num_visited < numToCollect_) {
146  // append
147  pos = *num_visited;
148  } else {
149  auto& gen = context_.RandGenerator();
150  // uniform between [0, num_visited]
151  std::uniform_int_distribution<int64_t> uniformDist(0, *num_visited);
152  pos = uniformDist(gen);
153  if (pos >= numToCollect_) {
154  // discard
155  pos = -1;
156  }
157  }
158 
159  if (pos < 0) {
160  // discard
161  CAFFE_ENFORCE_GE(*num_visited, numToCollect_);
162  } else {
163  // replace
164  context_.template CopyItems<Context, Context>(
165  input.meta(),
166  block_size,
167  input_data + i * block_bytesize,
168  output_data + pos * block_bytesize);
169 
170  if (object_id_data && pos_to_object_data && object_to_pos_map) {
171  auto old_oid = pos_to_object_data[pos];
172  auto new_oid = object_id_data[i];
173  pos_to_object_data[pos] = new_oid;
174  object_to_pos_map->erase(old_oid);
175  object_to_pos_map->emplace(new_oid, pos);
176  }
177  }
178 
179  ++(*num_visited);
180  }
181  // Sanity check
182  CAFFE_ENFORCE_EQ(*num_visited, start_num_visited + num_new_entries);
183  return true;
184  }
185 
186  private:
187  // number of tensors to collect
188  int numToCollect_;
189 
190  INPUT_TAGS(
191  RESERVOIR_IN,
192  NUM_VISITED_IN,
193  DATA,
194  MUTEX,
195  OBJECT_ID,
196  OBJECT_TO_POS_MAP_IN,
197  POS_TO_OBJECT_IN);
198  OUTPUT_TAGS(RESERVOIR, NUM_VISITED, OBJECT_TO_POS_MAP, POS_TO_OBJECT);
199 
200  int32_t countNewEntries(const std::set<int64_t>& unique_object_ids) {
201  const auto& input = Input(DATA);
202  if (InputSize() <= OBJECT_ID) {
203  return input.dim(0);
204  }
205  const auto& object_to_pos_map =
206  OperatorBase::Input<MapType64To32>(OBJECT_TO_POS_MAP_IN);
207  return std::count_if(
208  unique_object_ids.begin(),
209  unique_object_ids.end(),
210  [&object_to_pos_map](int64_t oid) {
211  return !object_to_pos_map.count(oid);
212  });
213  }
214 };
215 
216 REGISTER_CPU_OPERATOR(ReservoirSampling, ReservoirSamplingOp<CPUContext>);
217 
218 OPERATOR_SCHEMA(ReservoirSampling)
219  .NumInputs({4, 7})
220  .NumOutputs({2, 4})
221  .NumInputsOutputs([](int in, int out) { return in / 3 == out / 2; })
222  .EnforceInplace({{0, 0}, {1, 1}, {5, 2}, {6, 3}})
223  .SetDoc(R"DOC(
224 Collect `DATA` tensor into `RESERVOIR` of size `num_to_collect`. `DATA` is
225 assumed to be a batch.
226 
227 In case where 'objects' may be repeated in data and you only want at most one
228 instance of each 'object' in the reservoir, `OBJECT_ID` can be given for
229 deduplication. If `OBJECT_ID` is given, then you also need to supply additional
230 book-keeping tensors. See input blob documentation for details.
231 
232 This operator is thread-safe.
233 )DOC")
234  .Arg(
235  "num_to_collect",
236  "The number of random samples to append for each positive samples")
237  .Input(
238  0,
239  "RESERVOIR",
240  "The reservoir; should be initialized to empty tensor")
241  .Input(
242  1,
243  "NUM_VISITED",
244  "Number of examples seen so far; should be initialized to 0")
245  .Input(
246  2,
247  "DATA",
248  "Tensor to collect from. The first dimension is assumed to be batch "
249  "size. If the object to be collected is represented by multiple "
250  "tensors, use `PackRecords` to pack them into single tensor.")
251  .Input(3, "MUTEX", "Mutex to prevent data race")
252  .Input(
253  4,
254  "OBJECT_ID",
255  "(Optional, int64) If provided, used for deduplicating object in the "
256  "reservoir")
257  .Input(
258  5,
259  "OBJECT_TO_POS_MAP_IN",
260  "(Optional) Auxillary bookkeeping map. This should be created from "
261  " `CreateMap` with keys of type int64 and values of type int32")
262  .Input(
263  6,
264  "POS_TO_OBJECT_IN",
265  "(Optional) Tensor of type int64 used for bookkeeping in deduplication")
266  .Output(0, "RESERVOIR", "Same as the input")
267  .Output(1, "NUM_VISITED", "Same as the input")
268  .Output(2, "OBJECT_TO_POS_MAP", "(Optional) Same as the input")
269  .Output(3, "POS_TO_OBJECT", "(Optional) Same as the input");
270 
271 SHOULD_NOT_DO_GRADIENT(ReservoirSampling);
272 } // namespace
273 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.