Caffe2 - C++ API
A deep learning, cross-platform ML framework
last_n_window_collector.cc
#include <memory>
#include <string>
#include <vector>
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {
namespace {

template <class Context>
class LastNWindowCollectorOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit LastNWindowCollectorOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        numToCollect_(
            OperatorBase::GetSingleArgument<int>("num_to_collect", -1)) {
    CAFFE_ENFORCE_GT(numToCollect_, 0);
  }

  bool RunOnDevice() override {
    if (InputSize() > MUTEX) {
      // An optional mutex input serializes concurrent runs of this operator.
      auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
      std::lock_guard<std::mutex> guard(*mutex);
      return collect();
    } else {
      return collect();
    }
  }

 private:
  const int32_t numToCollect_;

  bool collect() {
    auto* output = Output(LAST_N);
    const auto& input = Input(DATA);

    CAFFE_ENFORCE_GE(input.dim(), 1);
    // Whether the last-N buffer already holds data from a previous run.
    bool output_initialized = output->numel() > 0 &&
        (static_cast<std::shared_ptr<std::vector<TensorCPU>>*>(
             output->raw_mutable_data(input.dtype()))[0] != nullptr);
    if (output_initialized) {
      // All dimensions except the row dimension must match the buffer.
      CAFFE_ENFORCE_EQ(output->dim(), input.dim());
      for (size_t i = 1; i < input.dim(); ++i) {
        CAFFE_ENFORCE_EQ(output->size(i), input.size(i));
      }
    }

    auto num_entries = input.sizes()[0];

    if (OutputSize() > NUM_VISITED) {
      // Optionally maintain a running count of all rows seen so far.
      auto* num_visited_tensor = Output(NUM_VISITED);
      CAFFE_ENFORCE_EQ(1, num_visited_tensor->numel());
      auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
      if (!output_initialized) {
        *num_visited = 0;
      }
      CAFFE_ENFORCE_GE(*num_visited, 0);
      *num_visited += num_entries;
    }

    if (!output_initialized) {
      // First run: start with an empty buffer of the right shape and reserve
      // room for up to numToCollect_ rows.
      auto dims = input.sizes().vec();
      dims[0] = 0;
      output->Resize(dims);
      // pass meta to output
      output->raw_mutable_data(input.dtype());
      output->ReserveSpace(numToCollect_);
    }

    if (num_entries == 0) {
      if (!output_initialized) {
        // Get both shape and meta
        output->CopyFrom(input, true /*async*/);
      }
      return true;
    }

    // Number of new rows to store, and the number of rows the buffer will
    // hold after this call (never more than numToCollect_).
    auto num_to_copy = std::min<int32_t>(num_entries, numToCollect_);
    auto output_batch_size = output_initialized ? output->size(0) : 0;
    auto output_num =
        std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);

    // output_num is >= output_batch_size
    if (output_num > output_batch_size) {
      output->ExtendTo(output_num, 50);
    }

    auto* output_data =
        static_cast<char*>(output->raw_mutable_data(input.dtype()));

    // NEXT is a scalar cursor marking the slot that will be overwritten next.
    auto* next = Output(NEXT);
    CAFFE_ENFORCE_EQ(0, next->dim());
    auto* next_data = next->template mutable_data<int32_t>();
    if (!output_initialized) {
      *next_data = 0;
    }
    CAFFE_ENFORCE_LT(*next_data, output->size(0));

    auto block_size = input.size_from_dim(1);
    auto block_bytesize = block_size * input.itemsize();
    const auto* input_data = static_cast<const char*>(input.raw_data());

    if (num_entries > numToCollect_) {
      // The batch alone exceeds N rows: only its last N rows matter, so copy
      // them to the front of the buffer and reset the cursor.
      context_.CopyItemsSameDevice(
          input.dtype(),
          num_to_copy * block_size,
          input_data + (num_entries - numToCollect_) * block_bytesize,
          output_data);
      *next_data = 0;
      return true;
    }
    // Otherwise treat the buffer as a ring: copy as many rows as fit between
    // the cursor and the end of the buffer...
    auto start = *next_data;
    auto first_chunk_size =
        std::min<size_t>(num_to_copy + start, numToCollect_) - start;
    context_.CopyItemsSameDevice(
        input.dtype(),
        first_chunk_size * block_size,
        input_data,
        output_data + start * block_bytesize);

    // ...then wrap the remaining rows around to the beginning.
    context_.CopyItemsSameDevice(
        input.dtype(),
        (num_to_copy - first_chunk_size) * block_size,
        input_data + first_chunk_size * block_bytesize,
        output_data);

    // Advance the cursor, wrapping around modulo numToCollect_.
    *next_data = (start + num_to_copy) % numToCollect_;
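
    // Illustrative trace (numbers made up for this example): with
    // numToCollect_ = 6, a full buffer, *next_data = 4 and a batch of 3 rows,
    // first_chunk_size = min(3 + 4, 6) - 4 = 2, so two rows land in slots 4
    // and 5, the remaining row wraps to slot 0, and the cursor becomes
    // (4 + 3) % 6 = 1.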

    return true;
  }

  INPUT_TAGS(LAST_N_IN, NEXT_IN, DATA, MUTEX, NUM_VISITED_IN);
  OUTPUT_TAGS(LAST_N, NEXT, NUM_VISITED);
};

REGISTER_CPU_OPERATOR(LastNWindowCollector, LastNWindowCollectorOp<CPUContext>);

OPERATOR_SCHEMA(LastNWindowCollector)
    .NumInputs({3, 4, 5})
    .NumOutputs(2, 3)
    .EnforceInplace({{0, 0}, {1, 1}, {4, 2}})
    .SetDoc(R"DOC(
Collect the last N rows from input data. The purpose is to keep track of data
across batches, so, for example, suppose the LastNWindowCollector is called
successively with the following input data:

  [1, 2, 3, 4]
  [5, 6, 7]
  [8, 9, 10, 11]

If the number of items to collect is set to 6, the output after the third call
will contain the following elements:

  [6, 7, 8, 9, 10, 11]

No guarantee is made on the ordering of the elements in the output, so a valid
output could also have been:

  [11, 10, 9, 8, 7, 6]

This operator also works for tensors of any rank, treating the first dimension
as the row dimension and keeping the last N rows seen. So for instance:

  [[1, 2], [2, 3], [3, 4], [4, 5]]
  [[5, 6], [6, 7], [7, 8]]
  [[8, 9], [9, 10], [10, 11], [11, 12]]

A possible output would be:

  [[6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12]]

This operator is not thread safe unless a mutex is given.
)DOC")
    .Arg(
        "num_to_collect",
        "The number of rows to keep: the N in last-N.")
    .Input(
        0,
        "last-N buffer",
        "The buffer holding the last-N records. Should be initialized to an "
        "empty tensor.")
    .Input(
        1,
        "next cursor",
        "The cursor pointing to the next position that should be replaced. "
        "Should be initialized to 0.")
    .Input(2, "DATA", "The tensor to collect rows from.")
    .Input(3, "MUTEX", "(optional) A mutex to make this operator thread-safe.")
    .Input(4, "NUM_VISITED", "(optional) The number of records seen so far.")
    .Output(
        0,
        "last-N buffer",
        "The buffer holding the last-N rows collected so far.")
    .Output(1, "next cursor", "The updated cursor.")
    .Output(2, "NUM_VISITED", "The number of records seen so far.");
SHOULD_NOT_DO_GRADIENT(LastNWindowCollector);
} // namespace
} // namespace caffe2
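
For reference, the ring-buffer behaviour implemented above can be sketched without Caffe2 at all. The standalone C++ sketch below is illustrative only (the LastNWindow type is invented for this example and is not part of the Caffe2 API); it reproduces the schema example, ending with the elements 6 through 11 stored in wrap-around order.

#include <cstdio>
#include <vector>

// Minimal stand-in for the collector's ring-buffer logic (1-D int data only).
struct LastNWindow {
  explicit LastNWindow(int n) : n_(n) {}

  void collect(const std::vector<int>& batch) {
    if (static_cast<int>(batch.size()) > n_) {
      // The batch alone exceeds N rows: only its last N rows matter.
      buffer_.assign(batch.end() - n_, batch.end());
      next_ = 0;
      return;
    }
    for (int v : batch) {
      if (static_cast<int>(buffer_.size()) < n_) {
        buffer_.push_back(v);  // still growing towards N rows
      } else {
        buffer_[next_] = v;    // overwrite the oldest slot
      }
      next_ = (next_ + 1) % n_;  // advance the circular cursor
    }
  }

  const std::vector<int>& data() const { return buffer_; }

 private:
  int n_;
  int next_ = 0;
  std::vector<int> buffer_;
};

int main() {
  LastNWindow window(6);  // num_to_collect = 6, as in the schema example
  window.collect({1, 2, 3, 4});
  window.collect({5, 6, 7});
  window.collect({8, 9, 10, 11});
  // The buffer now holds exactly {6, ..., 11}, though not in sorted order:
  // it reads [7, 8, 9, 10, 11, 6].
  for (int v : window.data()) {
    std::printf("%d ", v);
  }
  std::printf("\n");
  return 0;
}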