// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// last_n_window_collector.cc
17 #include <memory>
18 #include <string>
19 #include <vector>
20 #include "caffe2/core/operator.h"
21 #include "caffe2/core/tensor.h"
22 
23 namespace caffe2 {
24 namespace {
25 
// Maintains a sliding window over all rows seen across successive runs: the
// LAST_N output buffer holds up to `num_to_collect` of the most recently seen
// rows of DATA, stored as a ring buffer. NEXT is a scalar cursor pointing at
// the slot that will be overwritten next; NUM_VISITED (optional) counts every
// row ever passed in. Not thread-safe unless the optional MUTEX input is given.
template <class Context>
class LastNWindowCollectorOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  LastNWindowCollectorOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        numToCollect_(
            OperatorBase::GetSingleArgument<int>("num_to_collect", -1)) {
    // The window size must be supplied explicitly and be positive; the -1
    // default exists only so a missing argument fails this check.
    CAFFE_ENFORCE_GT(numToCollect_, 0);
  }

  bool RunOnDevice() override {
    if (InputSize() > MUTEX) {
      // A mutex input was provided: serialize collect() against other
      // operators sharing the same mutex blob.
      auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(MUTEX);
      std::lock_guard<std::mutex> guard(*mutex);
      return collect();
    } else {
      return collect();
    }
  }

 private:
  // Maximum number of rows retained in the LAST_N buffer (window size N).
  const int32_t numToCollect_;

  // Appends the rows of DATA into the ring buffer, overwriting the oldest
  // rows once the buffer is full, and updates the NEXT cursor (and
  // NUM_VISITED, if present). Always returns true.
  bool collect() {
    auto* output = Output(LAST_N);
    const auto& input = Input(DATA);

    CAFFE_ENFORCE_GE(input.ndim(), 1);
    // The buffer counts as initialized once it holds elements; the extra
    // first-element check reinterprets the raw storage as a shared_ptr and
    // requires it to be non-null — presumably to treat a buffer of null
    // pointer-typed payloads as uninitialized. NOTE(review): confirm this
    // reinterpretation is only meaningful for pointer-typed metas.
    bool output_initialized = output->size() > 0 &&
        (static_cast<std::shared_ptr<std::vector<TensorCPU>>*>(
             output->raw_mutable_data(input.meta()))[0] != nullptr);
    if (output_initialized) {
      // All trailing dimensions (everything but the row dimension 0) must
      // match between the buffer and the incoming rows.
      CAFFE_ENFORCE_EQ(output->ndim(), input.ndim());
      for (size_t i = 1; i < input.ndim(); ++i) {
        CAFFE_ENFORCE_EQ(output->dim(i), input.dim(i));
      }
    }

    auto dims = input.dims();
    auto num_entries = dims[0]; // number of incoming rows this call

    if (OutputSize() > NUM_VISITED) {
      // Optional third output: a scalar running count of all rows ever seen.
      auto* num_visited_tensor = Output(NUM_VISITED);
      CAFFE_ENFORCE_EQ(1, num_visited_tensor->size());
      auto* num_visited = num_visited_tensor->template mutable_data<int64_t>();
      if (!output_initialized) {
        // Fresh buffer: restart the visit counter.
        *num_visited = 0;
      }
      CAFFE_ENFORCE_GE(*num_visited, 0);
      *num_visited += num_entries;
    }

    // Pre-reserve capacity for a full window so later Resize calls do not
    // reallocate (and thus do not discard already-collected rows).
    dims[0] = numToCollect_;
    output->Reserve(dims, &context_);

    if (num_entries == 0) {
      if (!output_initialized) {
        // Get both shape and meta
        output->CopyFrom(input, &context_);
      }
      return true;
    }

    auto num_to_copy = std::min<int32_t>(num_entries, numToCollect_);
    auto output_batch_size = output_initialized ? output->dim(0) : 0;
    // Grow the buffer's row count until the window is full; never shrink it.
    dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
    if (output_batch_size < numToCollect_) {
      output->Resize(dims);
    }
    auto* output_data =
        static_cast<char*>(output->raw_mutable_data(input.meta()));

    // NEXT is a scalar cursor: the ring-buffer slot to overwrite next.
    auto* next = Output(NEXT);
    CAFFE_ENFORCE_EQ(0, next->ndim());
    auto* next_data = next->template mutable_data<int32_t>();
    if (!output_initialized) {
      *next_data = 0;
    }
    CAFFE_ENFORCE_LT(*next_data, output->dim(0));

    // One "row" is everything under dimension 0; copies are done bytewise.
    auto block_size = input.size_from_dim(1);
    auto block_bytesize = block_size * input.itemsize();
    const auto* input_data = static_cast<const char*>(input.raw_data());

    if (num_entries > numToCollect_) {
      // just copy the last N rows
      context_.template CopyItems<Context, Context>(
          input.meta(),
          num_to_copy * block_size,
          input_data + (num_entries - numToCollect_) * block_bytesize,
          output_data);
      // The whole window was rewritten, so the oldest slot is now slot 0.
      *next_data = 0;
      return true;
    }
    // Otherwise copy in at most two chunks: from the cursor to the end of the
    // buffer, then (on wrap-around) the remainder starting at slot 0.
    auto start = *next_data;
    auto first_chunk_size =
        std::min<size_t>(num_to_copy + start, numToCollect_) - start;
    context_.template CopyItems<Context, Context>(
        input.meta(),
        first_chunk_size * block_size,
        input_data,
        output_data + start * block_bytesize);

    // Second chunk is empty (size 0) when no wrap-around occurred.
    context_.template CopyItems<Context, Context>(
        input.meta(),
        (num_to_copy - first_chunk_size) * block_size,
        input_data + first_chunk_size * block_bytesize,
        output_data);

    // Advance the cursor modulo the window size.
    *next_data = (start + num_to_copy) % numToCollect_;

    return true;
  }

  INPUT_TAGS(LAST_N_IN, NEXT_IN, DATA, MUTEX, NUM_VISITED_IN);
  OUTPUT_TAGS(LAST_N, NEXT, NUM_VISITED);
};
144 
145 REGISTER_CPU_OPERATOR(LastNWindowCollector, LastNWindowCollectorOp<CPUContext>);
146 
147 OPERATOR_SCHEMA(LastNWindowCollector)
148  .NumInputs({3, 4, 5})
149  .NumOutputs(2, 3)
150  .EnforceInplace({{0, 0}, {1, 1}, {4, 2}})
151  .SetDoc(R"DOC(
152 Collect the last N rows from input data. The purpose is to keep track of data
153 accross batches, so for example suppose the LastNWindowCollector is called
154 successively with the following input data
155 
156  [1, 2, 3, 4]
157  [5, 6, 7]
158  [8, 9, 10, 11]
159 
160 And the number of items is set to 6, then the output after the 3rd call
161 will contain the following elements:
162 
163  [6, 7, 8, 9, 10, 11]
164 
165 No guarantee is made on the ordering of elements in input. So a valid value for
166 output could have been
167 
168  [11, 10, 9, 8, 7, 6]
169 
170 Also, this method works for any order tensor, treating the first dimension as
171 input rows and keeping the last N rows seen as input. So for instance:
172 
173  [[1, 2], [2, 3], [3, 4], [4, 5]]
174  [[5, 6], [6, 7], [7, 8]]
175  [[8, 9], [9, 10], [10, 11], [11, 12]]
176 
177 A possible output would be
178 
179  [[6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12]]
180 
181 This is not thread safe unless a mutex is given.
182 )DOC")
183  .Arg(
184  "num_to_collect",
185  "The number of random samples to append for each positive samples")
186  .Input(
187  0,
188  "last-N buffer",
189  "The buffer for last-N record. Should be initialized to empty tensor")
190  .Input(
191  1,
192  "next cursor",
193  "The cursor pointing to the next position that should be replaced. "
194  "Should be initialized to 0.")
195  .Input(2, "DATA", "tensor to collect from")
196  .Input(3, "MUTEX", "(optional) mutex to use to make this thread-safe")
197  .Input(4, "NUM_VISITED", "")
198  .Output(0, "last-N buffer", "Data stored in sessions")
199  .Output(1, "next cursor", "Updated input cursor")
200  .Output(2, "NUM_VISITED", "number of records seen so far");
201 SHOULD_NOT_DO_GRADIENT(LastNWindowCollector);
202 } // namespace
203 } // namespace caffe2
// Copyright (c) 2016-present, Facebook, Inc.