Caffe2 - C++ API
A deep learning, cross-platform ML framework
rebatching_queue.cc
#include "rebatching_queue.h"
#include "caffe2/utils/smart_tensor_printer.h"

namespace caffe2 {

namespace {

// Concatenates the input rows along a new leading dimension: each output
// tensor gains a first dimension of size numRows.
void concat(
    CPUContext& context,
    const std::vector<std::vector<TensorCPU>>& inputs,
    const std::vector<TensorCPU*>& outputs) {
  CAFFE_ENFORCE(!inputs.empty());

  const auto& inputZero = inputs[0];
  const auto numTensors = inputZero.size();
  const auto numRows = inputs.size();

  // Precompute the output sizes to avoid resizing
  std::vector<std::vector<TIndex>> outputDims(numTensors);

  for (int i = 0; i < numTensors; ++i) {
    // Debug print of the tensors in the first row
    SmartTensorPrinter::PrintTensor(inputZero.at(i));
    outputDims[i] = inputZero.at(i).dims();
    outputDims[i].insert(outputDims[i].begin(), numRows);
  }

  // Resize to the final output size
  std::vector<void*> destinations(numTensors);
  for (int i = 0; i < numTensors; ++i) {
    outputs[i]->Resize(outputDims[i]);
    destinations[i] = outputs[i]->raw_mutable_data(inputZero[i].meta());
  }

  for (int i = 0; i < numRows; ++i) {
    CAFFE_ENFORCE_EQ(inputs[i].size(), numTensors);

    for (int j = 0; j < numTensors; ++j) {
      const auto& input = inputs[i][j];

      // Every row must match the first row in type and shape
      CAFFE_ENFORCE(inputZero[j].meta() == input.meta());
      CAFFE_ENFORCE_EQ(inputZero[j].itemsize(), input.itemsize());
      CAFFE_ENFORCE_EQ(inputZero[j].ndim(), input.ndim());
      for (int k = 0; k < input.ndim(); ++k) {
        CAFFE_ENFORCE_EQ(input.dims()[k], inputZero[j].dims()[k]);
      }

      // Skip empty tensors
      if (input.size() == 0) {
        continue;
      }

      context.CopyItems<CPUContext, CPUContext>(
          input.meta(),
          input.size(),
          input.raw_data() /* src */,
          destinations[j] /* dst */);

      // Advance the write cursor past the bytes just copied
      destinations[j] =
          (char*)destinations[j] + input.size() * input.itemsize();
    }
  }
}
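
// Illustration (hypothetical shapes, not part of the original file): with two
// rows each holding one tensor of shape {3, 4}, concat produces one output of
// shape {2, 3, 4}:
//
//   std::vector<std::vector<TensorCPU>> rows(2);  // two rows, one blob each
//   // ... fill rows[0][0] and rows[1][0], both of shape {3, 4} ...
//   TensorCPU out;
//   concat(context, rows, {&out});                // out.dims() == {2, 3, 4}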

std::vector<std::vector<TensorCPU>> split(
    CPUContext& context,
    const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE(!inputs.empty());

  const auto outputSize = inputs[0]->dims().at(0);
  std::vector<std::vector<TensorCPU>> outputs(outputSize);

  for (const auto* inputPtr : inputs) {
    CAFFE_ENFORCE(inputPtr);

    const auto& input = *inputPtr;
    const auto innerSize = input.size_from_dim(1);
    const auto itemSize = input.meta().itemsize();

    // The output rows drop the leading (batch) dimension
    auto outputDims = input.dims();
    CAFFE_ENFORCE(!outputDims.empty());
    outputDims.erase(outputDims.begin());
    CAFFE_ENFORCE_EQ(input.dims().at(0), outputSize);

    for (int i = 0; i < outputSize; ++i) {
      outputs[i].push_back(TensorCPU(outputDims));
      context.CopyItems<CPUContext, CPUContext>(
          input.meta(),
          innerSize,
          (char*)input.raw_data() + i * innerSize * itemSize /* src */,
          outputs[i].back().raw_mutable_data(input.meta()) /* dst */);
    }
  }

  return outputs;
}
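
// Illustration (hypothetical shapes): split is the inverse of concat. Given a
// batch tensor of shape {2, 3, 4}, it returns two rows, each holding one
// tensor of shape {3, 4}:
//
//   TensorCPU batch(std::vector<TIndex>{2, 3, 4});
//   auto rows = split(context, {&batch});  // rows.size() == 2
//   // rows[0][0].dims() == {3, 4}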
} // anonymous namespace

RebatchingQueue::RebatchingQueue(size_t capacity, size_t numBlobs)
    : capacity_(capacity), numBlobs_(numBlobs), queue_(capacity) {}

RebatchingQueue::~RebatchingQueue() {
  close();
}

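// head_ and tail_ are monotonically increasing counters: head_ counts rows
// ever enqueued and tail_ rows ever dequeued, so a row lives in slot
// (counter % capacity()). The queue is empty when tail_ == head_ and full
// when head_ == tail_ + capacity().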
bool RebatchingQueue::canRead() const {
  return tail_ < head_;
}

bool RebatchingQueue::dequeue(
    CPUContext& context,
    size_t numElements,
    const std::vector<TensorCPU*>& outputs) {
  std::vector<std::vector<TensorCPU>> results;
  results.reserve(numElements);

  for (;;) {
    if (results.size() == numElements) {
      break;
    }

    {
      std::unique_lock<std::mutex> lock(mutex_);

      cvEmpty_.wait(lock, [this] { return canRead() || isClosed_; });

      // We only want to stop reading if the queue is empty and closed
      if (!canRead() && isClosed_) {
        break;
      }

      do {
        results.push_back(std::move(queue_[tail_++ % capacity()]));
      } while (canRead() && results.size() < numElements);
    }

    if (numElements == 1) {
      cvOverflow_.notify_one();
    } else {
      cvOverflow_.notify_all();
    }
  }

  if (results.empty()) {
    return false;
  }

  concat(context, results, outputs);

  return true;
}
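
// A note on the contract, inferred from the loop above: dequeue blocks until
// numElements rows are available, but if the queue is closed while it waits,
// it concatenates whatever it has already read. In that case it returns true
// with a smaller batch, and it returns false only when nothing was read.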

bool RebatchingQueue::canWrite() const {
  return tail_ + capacity() > head_;
}

bool RebatchingQueue::enqueueOne(
    CPUContext& /*context*/,
    const std::vector<const TensorCPU*>& inputs) {
  std::vector<std::vector<TensorCPU>> splittedInputs;
  splittedInputs.emplace_back();
  auto& tensorVector = splittedInputs.back();
  tensorVector.reserve(inputs.size());
  for (const auto* tensorPtr : inputs) {
    tensorVector.push_back(*tensorPtr);
  }

  return enqueue(std::move(splittedInputs));
}

bool RebatchingQueue::enqueueMany(
    CPUContext& context,
    const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE_EQ(numBlobs_, inputs.size());

  return enqueue(split(context, inputs));
}
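
// Hedged usage sketch (names and shapes are illustrative): enqueueMany splits
// each input along its first dimension, so one tensor of shape {8, 16}
// enqueues eight rows of shape {16}; enqueueOne instead stores its inputs as
// a single row.
//
//   RebatchingQueue queue(/*capacity=*/64, /*numBlobs=*/1);
//   TensorCPU batch(std::vector<TIndex>{8, 16});
//   queue.enqueueMany(context, {&batch});  // enqueues eight {16} rows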

bool RebatchingQueue::enqueue(
    std::vector<std::vector<TensorCPU>> splittedInputs) {
  size_t idx = 0;
  for (;;) {
    if (idx >= splittedInputs.size()) {
      break;
    }

    {
      std::unique_lock<std::mutex> lock(mutex_);

      cvOverflow_.wait(lock, [this] { return canWrite() || isClosed_; });

      if (isClosed_) {
        // Getting here means the batch was only partially applied; being
        // closed in the middle of enqueuing counts as a non-success.
        return false;
      }

      do {
        queue_[head_++ % capacity()] = std::move(splittedInputs[idx++]);
      } while (canWrite() && idx < splittedInputs.size());
    }

    cvEmpty_.notify_all();
  }

  return true;
}

size_t RebatchingQueue::capacity() const {
  return capacity_;
}

size_t RebatchingQueue::numBlobs() const {
  return numBlobs_;
}

bool RebatchingQueue::isClosed() const {
  std::lock_guard<std::mutex> g(mutex_);
  return isClosed_;
}

void RebatchingQueue::close() {
  {
    std::lock_guard<std::mutex> g(mutex_);
    isClosed_ = true;
  }

  cvEmpty_.notify_all();
  cvOverflow_.notify_all();
}
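
// End-to-end sketch (hypothetical; producerContext and consumerContext stand
// in for per-thread CPUContext objects): the producer enqueues examples one
// by one, the consumer reads them back in batches of up to eight, and close()
// wakes both sides so neither blocks forever.
//
//   RebatchingQueue queue(/*capacity=*/16, /*numBlobs=*/1);
//   std::thread producer([&] {
//     TensorCPU example(std::vector<TIndex>{4});
//     for (int i = 0; i < 32; ++i) {
//       queue.enqueueOne(producerContext, {&example});
//     }
//     queue.close();
//   });
//   TensorCPU out;
//   while (queue.dequeue(consumerContext, 8, {&out})) {
//     // out has shape {n, 4}; n < 8 only for a final batch after close()
//   }
//   producer.join();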
} // namespace caffe2