Caffe2 - C++ API
A deep learning, cross-platform ML framework
rebatching_queue.cc
#include "rebatching_queue.h"
#include "caffe2/utils/smart_tensor_printer.h"

namespace caffe2 {

namespace {

// This concat function always creates a new leading dimension to concatenate
// along.
void concat(
    CPUContext& context,
    const std::vector<std::vector<TensorCPU>>& inputs,
    const std::vector<TensorCPU*>& outputs) {
  CAFFE_ENFORCE(!inputs.empty());

  const auto& inputZero = inputs[0];
  const auto numTensors = inputZero.size();
  const auto numRows = inputs.size();

  // Precompute the output sizes to avoid resizing
  std::vector<std::vector<int64_t>> outputDims(numTensors);

  for (size_t i = 0; i < numTensors; ++i) {
    SmartTensorPrinter::PrintTensor(inputZero.at(i));
    outputDims[i] = inputZero.at(i).sizes().vec();
    outputDims[i].insert(outputDims[i].begin(), numRows);
  }

  // Resize to the final output size
  std::vector<void*> destinations(numTensors);
  for (size_t i = 0; i < numTensors; ++i) {
    outputs[i]->Resize(outputDims[i]);
    destinations[i] = outputs[i]->raw_mutable_data(inputZero[i].dtype());
  }

  for (size_t i = 0; i < numRows; ++i) {
    CAFFE_ENFORCE_EQ(inputs[i].size(), numTensors);

    for (size_t j = 0; j < numTensors; ++j) {
      const auto& input = inputs[i][j];

      // Every tensor must match the corresponding tensor of the first row
      // in type and shape.
      CAFFE_ENFORCE(inputZero[j].dtype() == input.dtype());
      CAFFE_ENFORCE_EQ(inputZero[j].itemsize(), input.itemsize());
      CAFFE_ENFORCE_EQ(inputZero[j].dim(), input.dim());
      for (int k = 0; k < input.dim(); ++k) {
        CAFFE_ENFORCE_EQ(input.sizes()[k], inputZero[j].size(k));
      }

      // Skip empty tensors
      if (input.numel() == 0) {
        continue;
      }

      context.CopyItemsToCPU(
          input.dtype(),
          input.numel(),
          input.raw_data() /* src */,
          destinations[j] /* dst */);

      // Advance the write cursor past the row just copied.
      destinations[j] =
          (char*)destinations[j] + input.numel() * input.itemsize();
    }
  }
}
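// Illustrative sketch, not part of the original file: concat stacks one
// tensor per blob from each row along a new leading dimension. With two rows
// each carrying a single blob of shape [3], the output has shape [2, 3].
// `row0`, `row1`, and `out` are hypothetical names.
//
//   CPUContext context;
//   Tensor row0(std::vector<int64_t>{3}, CPU);
//   row0.mutable_data<float>(); // materialize storage before raw_data()
//   Tensor row1(std::vector<int64_t>{3}, CPU);
//   row1.mutable_data<float>();
//   Tensor out(CPU);
//   std::vector<std::vector<TensorCPU>> rows(2);
//   rows[0].push_back(row0.Clone());
//   rows[1].push_back(row1.Clone());
//   concat(context, rows, {&out}); // out now has shape [2, 3]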

std::vector<std::vector<TensorCPU>> split(
    CPUContext& context,
    const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE(!inputs.empty());

  const auto outputSize = inputs[0]->sizes().at(0);
  std::vector<std::vector<TensorCPU>> outputs(outputSize);

  for (const auto* inputPtr : inputs) {
    CAFFE_ENFORCE(inputPtr);

    const auto& input = *inputPtr;
    const auto innerSize = input.size_from_dim(1);
    const auto itemSize = input.dtype().itemsize();

    auto outputDims = input.sizes().vec();
    CAFFE_ENFORCE(!outputDims.empty());
    outputDims.erase(outputDims.begin());
    // Every input must agree on the size of the leading (batch) dimension.
    CAFFE_ENFORCE_EQ(input.sizes().at(0), outputSize);

    for (int64_t i = 0; i < outputSize; ++i) {
      outputs[i].push_back(Tensor(outputDims, CPU));
      context.CopyItemsToCPU(
          input.dtype(),
          innerSize,
          (char*)input.raw_data() + i * innerSize * itemSize /* src */,
          outputs[i].back().raw_mutable_data(input.dtype()) /* dst */);
    }
  }

  return outputs;
}
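// Illustrative sketch, not part of the original file: split is the inverse of
// concat above; it slices a batched blob along its leading dimension into one
// row per example. `batch` is a hypothetical name.
//
//   CPUContext context;
//   Tensor batch(std::vector<int64_t>{2, 3}, CPU);
//   batch.mutable_data<float>(); // materialize storage before raw_data()
//   auto rows = split(context, {&batch});
//   // rows.size() == 2; rows[0][0] and rows[1][0] each have shape [3]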
} // anonymous namespace

RebatchingQueue::RebatchingQueue(size_t capacity, size_t numBlobs)
    : capacity_(capacity), numBlobs_(numBlobs), queue_(capacity) {}

RebatchingQueue::~RebatchingQueue() {
  close();
}

// head_ and tail_ are monotonically increasing cursors; the slots at
// queue_[i % capacity()] for tail_ <= i < head_ hold unread elements.
bool RebatchingQueue::canRead() const {
  return tail_ < head_;
}
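// Worked example, added for clarity: the cursors only ever grow, and reads
// consume queue_[tail_ % capacity()]. With capacity() == 4, tail_ == 3, and
// head_ == 6, canRead() holds and the three unread slots are, in order,
// queue_[3], queue_[0], and queue_[1].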

bool RebatchingQueue::dequeue(
    CPUContext& context,
    size_t numElements,
    const std::vector<TensorCPU*>& outputs) {
  std::vector<std::vector<TensorCPU>> results;
  results.reserve(numElements);

  for (;;) {
    if (results.size() == numElements) {
      break;
    }

    {
      std::unique_lock<std::mutex> lock(mutex_);

      cvEmpty_.wait(lock, [this] { return canRead() || isClosed_; });

      // We only want to stop reading if the queue is empty and closed.
      if (!canRead() && isClosed_) {
        break;
      }

      do {
        results.push_back(std::move(queue_[tail_++ % capacity()]));
      } while (canRead() && results.size() < numElements);
    }

    // A single-element read frees exactly one slot, so waking one writer
    // suffices; otherwise wake them all.
    if (numElements == 1) {
      cvOverflow_.notify_one();
    } else {
      cvOverflow_.notify_all();
    }
  }

  if (results.empty()) {
    return false;
  }

  concat(context, results, outputs);

  return true;
}
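// Usage sketch, not part of the original file: a consumer draining the queue
// in mini-batches. Assumes a queue constructed as
// RebatchingQueue(/*capacity=*/64, /*numBlobs=*/1); `queue` and `out` are
// hypothetical names.
//
//   CPUContext context;
//   Tensor out(CPU);
//   while (queue.dequeue(context, /*numElements=*/16, {&out})) {
//     // `out` has shape [16, ...]; the final batch may be smaller if the
//     // queue is closed before 16 elements arrive.
//   }
//   // dequeue returns false once the queue is both closed and drained.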

// There is room to write while fewer than capacity() elements are
// outstanding, i.e. head_ - tail_ < capacity().
bool RebatchingQueue::canWrite() const {
  return tail_ + capacity() > head_;
}
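// Worked example, added for clarity: with capacity() == 4, tail_ == 3, and
// head_ == 7, tail_ + capacity() == 7 is not greater than head_, so the
// queue is full and canWrite() is false; once a reader advances tail_ to 4,
// writers may proceed.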

bool RebatchingQueue::enqueueOne(
    CPUContext& /*context*/,
    const std::vector<const TensorCPU*>& inputs) {
  // Clone the input tensors into a single row.
  std::vector<std::vector<TensorCPU>> splittedInputs;
  splittedInputs.emplace_back();
  auto& tensorVector = splittedInputs.back();
  tensorVector.reserve(inputs.size());
  for (const auto* tensorPtr : inputs) {
    tensorVector.push_back(tensorPtr->Clone());
  }

  return enqueue(std::move(splittedInputs));
}

bool RebatchingQueue::enqueueMany(
    CPUContext& context,
    const std::vector<const TensorCPU*>& inputs) {
  CAFFE_ENFORCE_EQ(numBlobs_, inputs.size());

  // Split the batched inputs into one row per example.
  auto splittedInputs = split(context, inputs);
  return enqueue(std::move(splittedInputs));
}
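// Round-trip sketch, not part of the original file: enqueueMany splits a
// batched tensor into per-example rows, and dequeue re-concatenates them, so
// the pair rebatches data. Assumes RebatchingQueue(/*capacity=*/8,
// /*numBlobs=*/1); `queue` and `batch` are hypothetical names.
//
//   CPUContext context;
//   Tensor batch(std::vector<int64_t>{4, 3}, CPU); // 4 examples of shape [3]
//   batch.mutable_data<float>(); // materialize storage before raw_data()
//   queue.enqueueMany(context, {&batch}); // enqueues 4 rows of shape [3]
//
//   Tensor out(CPU);
//   queue.dequeue(context, /*numElements=*/2, {&out}); // out: shape [2, 3]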

bool RebatchingQueue::enqueue(
    std::vector<std::vector<TensorCPU>> splittedInputs) {
  size_t idx = 0;
  for (;;) {
    if (idx >= splittedInputs.size()) {
      break;
    }

    {
      std::unique_lock<std::mutex> lock(mutex_);

      cvOverflow_.wait(lock, [this] { return canWrite() || isClosed_; });

      if (isClosed_) {
        // Getting here means the entire batch was not yet applied; being
        // closed in the middle of enqueuing is treated as a failure.
        return false;
      }

      do {
        queue_[head_++ % capacity()] = std::move(splittedInputs[idx++]);
      } while (canWrite() && idx < splittedInputs.size());
    }

    cvEmpty_.notify_all();
  }

  return true;
}

size_t RebatchingQueue::capacity() const {
  return capacity_;
}

size_t RebatchingQueue::numBlobs() const {
  return numBlobs_;
}

bool RebatchingQueue::isClosed() const {
  std::lock_guard<std::mutex> g(mutex_);
  return isClosed_;
}

void RebatchingQueue::close() {
  {
    std::lock_guard<std::mutex> g(mutex_);
    isClosed_ = true;
  }

  cvEmpty_.notify_all();
  cvOverflow_.notify_all();
}
} // namespace caffe2
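// End-to-end sketch, not part of the original file: one producer thread and
// one consumer rebatching through the queue. close() unblocks both sides:
// blocked writers give up (enqueue returns false) and readers drain whatever
// remains. All names are hypothetical; requires <thread>.
//
//   caffe2::RebatchingQueue queue(/*capacity=*/64, /*numBlobs=*/1);
//
//   std::thread producer([&] {
//     caffe2::CPUContext context;
//     for (int step = 0; step < 100; ++step) {
//       caffe2::Tensor batch(std::vector<int64_t>{8, 3}, caffe2::CPU);
//       batch.mutable_data<float>();
//       queue.enqueueMany(context, {&batch}); // blocks while the queue is full
//     }
//     queue.close();
//   });
//
//   caffe2::CPUContext context;
//   caffe2::Tensor out(caffe2::CPU);
//   while (queue.dequeue(context, /*numElements=*/32, {&out})) {
//     // `out` has shape [32, 3]; the final batch may be smaller.
//   }
//   producer.join();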