// LastNWindowCollectorOp keeps the most recent numToCollect_ rows (entries
// along the first dimension) of the DATA input in the LAST_N output, using the
// NEXT output as a write cursor that wraps around modulo numToCollect_.
// NOTE(review): this view of the file is missing lines (for example the member
// name in the constructor's initializer list, several closing braces, return
// statements and some call arguments). The comments below describe only what
// is visible here.
4 #include "caffe2/core/operator.h" 5 #include "caffe2/core/tensor.h" 10 template <
class Context>
11 class LastNWindowCollectorOp :
public Operator<Context> {
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor: forwards all arguments to Operator<Context>.
14 template <
class... Args>
15 explicit LastNWindowCollectorOp(Args&&... args)
16 : Operator<Context>(
std::forward<Args>(args)...),
// Reads the "num_to_collect" argument; -1 is passed as the second argument
// (presumably the fallback when the argument is absent — confirm).
18 OperatorBase::GetSingleArgument<int>(
"num_to_collect", -1)) {
// The window size must be strictly positive.
19 CAFFE_ENFORCE_GT(numToCollect_, 0);
22 bool RunOnDevice()
override {
// When enough inputs are given to include MUTEX, lock it with a lock_guard
// for the remainder of this scope.
23 if (InputSize() > MUTEX) {
24 auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
25 std::lock_guard<std::mutex> guard(*mutex);
// Window size N; const, so fixed at construction.
33 const int32_t numToCollect_;
36 auto* output = Output(LAST_N);
37 const auto& input = Input(DATA);
// The input needs at least one dimension (the row dimension).
39 CAFFE_ENFORCE_GE(input.dim(), 1);
// LAST_N counts as initialized only when it has elements and the start of its
// buffer, viewed as a shared_ptr, is non-null.
// NOTE(review): the reason for the shared_ptr reinterpretation is not visible
// in this view — confirm.
40 bool output_initialized = output->numel() > 0 &&
41 (
static_cast<std::shared_ptr<std::vector<TensorCPU>
>*>(
42 output->raw_mutable_data(input.dtype()))[0] !=
nullptr);
// An initialized buffer must have the same rank as the input and match it in
// every dimension except the first.
43 if (output_initialized) {
44 CAFFE_ENFORCE_EQ(output->dim(), input.dim());
45 for (
size_t i = 1; i < input.dim(); ++i) {
46 CAFFE_ENFORCE_EQ(output->size(i), input.size(i));
// Number of rows in this batch (size of the first dimension).
50 auto num_entries = input.sizes()[0];
// Optional NUM_VISITED output: a single-element int64 counter that is
// increased by the number of rows in this batch.
52 if (OutputSize() > NUM_VISITED) {
53 auto* num_visited_tensor = Output(NUM_VISITED);
54 CAFFE_ENFORCE_EQ(1, num_visited_tensor->numel());
55 auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
// The body of this branch is not visible in this view.
56 if (!output_initialized) {
59 CAFFE_ENFORCE_GE(*num_visited, 0);
60 *num_visited += num_entries;
// First use: calls raw_mutable_data with the input's dtype and reserves space
// for numToCollect_ (what is done with `dims` is not visible in this view).
63 if (!output_initialized) {
64 auto dims = input.sizes().vec();
68 output->raw_mutable_data(input.dtype());
69 output->ReserveSpace(numToCollect_);
// Empty batch: if the buffer was not initialized, the input is copied into the
// output with `true` as the second CopyFrom argument.
72 if (num_entries == 0) {
73 if (!output_initialized) {
75 output->CopyFrom(input,
true );
// At most N rows are taken from this batch.
// NOTE(review): std::min<int32_t> converts num_entries to int32_t — confirm
// that batch sizes always fit.
80 auto num_to_copy = std::min<int32_t>(num_entries, numToCollect_);
81 auto output_batch_size = output_initialized ? output->size(0) : 0;
// Target row count: current rows plus the rows to copy, capped at N.
83 std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
// Grow the buffer only when the target exceeds the current row count.
// NOTE(review): the meaning of the second ExtendTo argument (50) is not shown
// here — presumably a growth percentage; confirm against Tensor::ExtendTo.
86 if (output_num > output_batch_size) {
87 output->ExtendTo(output_num, 50);
// Byte-addressed view of the output buffer.
91 static_cast<char*
>(output->raw_mutable_data(input.dtype()));
// NEXT must be a 0-dimensional tensor holding an int32 cursor, and the cursor
// must be smaller than the buffer's current row count.
93 auto* next = Output(NEXT);
94 CAFFE_ENFORCE_EQ(0, next->dim());
95 auto* next_data = next->template mutable_data<int32_t>();
// The body of this branch is not visible in this view.
96 if (!output_initialized) {
99 CAFFE_ENFORCE_LT(*next_data, output->size(0));
// block_size comes from input.size_from_dim(1); block_bytesize is that count
// multiplied by the input's item size.
101 auto block_size = input.size_from_dim(1);
102 auto block_bytesize = block_size * input.itemsize();
103 const auto* input_data =
static_cast<const char*
>(input.raw_data());
// Batch larger than the window: the copy source starts
// (num_entries - numToCollect_) rows into the input, i.e. only the last N rows
// are copied. The destination argument is not visible in this view.
105 if (num_entries > numToCollect_) {
107 context_.CopyItemsSameDevice(
109 num_to_copy * block_size,
110 input_data + (num_entries - numToCollect_) * block_bytesize,
// Otherwise the copy is split in two. The first chunk is written at the cursor
// position and is limited so that it does not run past row N of the buffer.
115 auto start = *next_data;
116 auto first_chunk_size =
117 std::min<size_t>(num_to_copy + start, numToCollect_) - start;
118 context_.CopyItemsSameDevice(
120 first_chunk_size * block_size,
122 output_data + start * block_bytesize);
// Second chunk: the remaining rows, read from the input just after the first
// chunk. Its destination argument is not visible in this view.
124 context_.CopyItemsSameDevice(
126 (num_to_copy - first_chunk_size) * block_size,
127 input_data + first_chunk_size * block_bytesize,
// Advance the cursor by the number of rows copied, wrapping modulo N.
130 *next_data = (start + num_to_copy) % numToCollect_;
// Names for the input and output slots referenced above.
135 INPUT_TAGS(LAST_N_IN, NEXT_IN, DATA, MUTEX, NUM_VISITED_IN);
136 OUTPUT_TAGS(LAST_N, NEXT, NUM_VISITED);
// Registers LastNWindowCollectorOp<CPUContext> under the operator name
// LastNWindowCollector.
139 REGISTER_CPU_OPERATOR(LastNWindowCollector, LastNWindowCollectorOp<CPUContext>);
// Operator schema: accepts 3, 4 or 5 inputs.
141 OPERATOR_SCHEMA(LastNWindowCollector)
142 .NumInputs({3, 4, 5})
// In-place pairs {0,0}, {1,1}, {4,2} — presumably {input index, output index},
// which by the op's tag order would tie LAST_N_IN/LAST_N, NEXT_IN/NEXT and
// NUM_VISITED_IN/NUM_VISITED together; confirm against the schema API.
144 .EnforceInplace({{0, 0}, {1, 1}, {4, 2}})
146 Collect the last N rows from input data. The purpose is to keep track of data 147 across batches, so for example suppose the LastNWindowCollector is called 148 successively with the following input data 154 And the number of items is set to 6, then the output after the 3rd call 155 will contain the following elements: 159 No guarantee is made on the ordering of elements in input. So a valid value for 160 output could have been 164 Also, this method works for any order tensor, treating the first dimension as 165 input rows and keeping the last N rows seen as input. So for instance: 167 [[1, 2], [2, 3], [3, 4], [4, 5]] 168 [[5, 6], [6, 7], [7, 8]] 169 [[8, 9], [9, 10], [10, 11], [11, 12]] 171 A possible output would be 173 [[6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12]] 175 This is not thread safe unless a mutex is given. 179 "The number of random samples to append for each positive samples")
// Description strings for two earlier schema entries; the calls that open them
// are not visible in this view.
183 "The buffer for last-N record. Should be initialized to empty tensor")
187 "The cursor pointing to the next position that should be replaced. " 188 "Should be initialized to 0.")
// Input 2: the tensor whose rows are collected.
189 .Input(2,
"DATA",
"tensor to collect from")
// Input 3: optional mutex.
190 .Input(3,
"MUTEX",
"(optional) mutex to use to make this thread-safe")
// NOTE(review): input 4 is declared with an empty description string, while
// output 2 of the same name is described as "number of records seen so far".
191 .Input(4,
"NUM_VISITED",
"")
192 .Output(0,
"last-N buffer",
"Data stored in sessions")
193 .Output(1,
"next cursor",
"Updated input cursor")
194 .Output(2,
"NUM_VISITED",
"number of records seen so far");
// NOTE(review): presumably declares that this operator has no gradient —
// confirm against the SHOULD_NOT_DO_GRADIENT macro definition.
195 SHOULD_NOT_DO_GRADIENT(LastNWindowCollector);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...