4 #include "caffe2/core/operator.h" 5 #include "caffe2/core/tensor.h" 6 #include "caffe2/operators/map_ops.h" 11 template <
class Context>
12 class ReservoirSamplingOp final :
public Operator<Context> {
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
15 ReservoirSamplingOp(
const OperatorDef operator_def, Workspace* ws)
16 : Operator<Context>(operator_def, ws),
18 OperatorBase::GetSingleArgument<int>(
"num_to_collect", -1)) {
19 CAFFE_ENFORCE(numToCollect_ > 0);
22 bool RunOnDevice()
override {
23 auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
24 std::lock_guard<std::mutex> guard(*mutex);
26 auto* output = Output(RESERVOIR);
27 const auto& input = Input(DATA);
29 CAFFE_ENFORCE_GE(input.dim(), 1);
31 bool output_initialized = output->numel() > 0 &&
32 (
static_cast<std::shared_ptr<std::vector<TensorCPU>
>*>(
33 output->raw_mutable_data(input.dtype()))[0] !=
nullptr);
35 if (output_initialized) {
36 CAFFE_ENFORCE_EQ(output->dim(), input.dim());
37 for (
size_t i = 1; i < input.dim(); ++i) {
38 CAFFE_ENFORCE_EQ(output->size(i), input.size(i));
42 auto num_entries = input.sizes()[0];
44 if (!output_initialized) {
47 auto dims = input.sizes().vec();
50 output->raw_mutable_data(input.dtype());
51 output->ReserveSpace(numToCollect_);
55 OutputSize() > POS_TO_OBJECT ? Output(POS_TO_OBJECT) : nullptr;
57 if (!output_initialized) {
59 pos_to_object->Resize(0);
60 pos_to_object->template mutable_data<int64_t>();
61 pos_to_object->ReserveSpace(numToCollect_);
65 auto* object_to_pos_map = OutputSize() > OBJECT_TO_POS_MAP
66 ? OperatorBase::Output<MapType64To32>(OBJECT_TO_POS_MAP)
69 if (object_to_pos_map && !output_initialized) {
70 object_to_pos_map->clear();
73 auto* num_visited_tensor = Output(NUM_VISITED);
74 CAFFE_ENFORCE_EQ(1, num_visited_tensor->numel());
75 auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
76 if (!output_initialized) {
79 CAFFE_ENFORCE_GE(*num_visited, 0);
81 if (num_entries == 0) {
82 if (!output_initialized) {
84 output->CopyFrom(input,
true);
89 const int64_t* object_id_data =
nullptr;
90 std::set<int64_t> unique_object_ids;
91 if (InputSize() > OBJECT_ID) {
92 const auto& object_id = Input(OBJECT_ID);
93 CAFFE_ENFORCE_EQ(object_id.dim(), 1);
94 CAFFE_ENFORCE_EQ(object_id.numel(), num_entries);
95 object_id_data = object_id.template data<int64_t>();
96 unique_object_ids.insert(
97 object_id_data, object_id_data + object_id.numel());
100 const auto num_new_entries = countNewEntries(unique_object_ids);
101 auto num_to_copy = std::min<int32_t>(num_new_entries, numToCollect_);
102 auto output_batch_size = output_initialized ? output->size(0) : 0;
104 std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
106 output->ExtendTo(output_num, 50);
108 pos_to_object->ExtendTo(output_num, 50);
112 pos_to_object->template mutable_data<int64_t>() +
113 output_batch_size *
sizeof(int64_t),
115 (output_num - output_batch_size) *
sizeof(int64_t));
119 static_cast<char*
>(output->raw_mutable_data(input.dtype()));
120 auto* pos_to_object_data = pos_to_object
121 ? pos_to_object->template mutable_data<int64_t>()
124 auto block_size = input.size_from_dim(1);
125 auto block_bytesize = block_size * input.itemsize();
126 const auto* input_data =
static_cast<const char*
>(input.raw_data());
128 const auto start_num_visited = *num_visited;
130 std::set<int64_t> eligible_object_ids;
131 if (object_to_pos_map) {
132 for (
auto oid : unique_object_ids) {
133 if (!object_to_pos_map->count(oid)) {
134 eligible_object_ids.insert(oid);
139 for (
int i = 0; i < num_entries; ++i) {
140 if (object_id_data && object_to_pos_map &&
141 !eligible_object_ids.count(object_id_data[i])) {
145 if (object_id_data) {
146 eligible_object_ids.erase(object_id_data[i]);
149 if (*num_visited < numToCollect_) {
153 auto& gen = context_.RandGenerator();
155 std::uniform_int_distribution<int64_t> uniformDist(0, *num_visited);
156 pos = uniformDist(gen);
157 if (pos >= numToCollect_) {
165 CAFFE_ENFORCE_GE(*num_visited, numToCollect_);
168 context_.CopyItemsSameDevice(
171 input_data + i * block_bytesize,
172 output_data + pos * block_bytesize);
174 if (object_id_data && pos_to_object_data && object_to_pos_map) {
175 auto old_oid = pos_to_object_data[pos];
176 auto new_oid = object_id_data[i];
177 pos_to_object_data[pos] = new_oid;
178 object_to_pos_map->erase(old_oid);
179 object_to_pos_map->emplace(new_oid, pos);
186 CAFFE_ENFORCE_EQ(*num_visited, start_num_visited + num_new_entries);
200 OBJECT_TO_POS_MAP_IN,
202 OUTPUT_TAGS(RESERVOIR, NUM_VISITED, OBJECT_TO_POS_MAP, POS_TO_OBJECT);
204 int32_t countNewEntries(
const std::set<int64_t>& unique_object_ids) {
205 const auto& input = Input(DATA);
206 if (InputSize() <= OBJECT_ID) {
207 return input.size(0);
209 const auto& object_to_pos_map =
210 OperatorBase::Input<MapType64To32>(OBJECT_TO_POS_MAP_IN);
211 return std::count_if(
212 unique_object_ids.begin(),
213 unique_object_ids.end(),
214 [&object_to_pos_map](int64_t oid) {
215 return !object_to_pos_map.count(oid);
220 REGISTER_CPU_OPERATOR(ReservoirSampling, ReservoirSamplingOp<CPUContext>);
222 OPERATOR_SCHEMA(ReservoirSampling)
225 .NumInputsOutputs([](
int in,
int out) {
return in / 3 == out / 2; })
226 .EnforceInplace({{0, 0}, {1, 1}, {5, 2}, {6, 3}})
228 Collect `DATA` tensor into `RESERVOIR` of size `num_to_collect`. `DATA` is 229 assumed to be a batch. 231 In case where 'objects' may be repeated in data and you only want at most one 232 instance of each 'object' in the reservoir, `OBJECT_ID` can be given for 233 deduplication. If `OBJECT_ID` is given, then you also need to supply additional 234 book-keeping tensors. See input blob documentation for details. 236 This operator is thread-safe. 240 "The number of random samples to append for each positive samples")
244 "The reservoir; should be initialized to empty tensor")
248 "Number of examples seen so far; should be initialized to 0")
252 "Tensor to collect from. The first dimension is assumed to be batch " 253 "size. If the object to be collected is represented by multiple " 254 "tensors, use `PackRecords` to pack them into single tensor.")
255 .Input(3,
"MUTEX",
"Mutex to prevent data race")
259 "(Optional, int64) If provided, used for deduplicating object in the " 263 "OBJECT_TO_POS_MAP_IN",
264 "(Optional) Auxillary bookkeeping map. This should be created from " 265 " `CreateMap` with keys of type int64 and values of type int32")
269 "(Optional) Tensor of type int64 used for bookkeeping in deduplication")
270 .Output(0,
"RESERVOIR",
"Same as the input")
271 .Output(1,
"NUM_VISITED",
"Same as the input")
272 .Output(2,
"OBJECT_TO_POS_MAP",
"(Optional) Same as the input")
273 .Output(3,
"POS_TO_OBJECT",
"(Optional) Same as the input");
275 SHOULD_NOT_DO_GRADIENT(ReservoirSampling);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...