// NOTE(review): this chunk is garbled by extraction -- the leading numerals
// on many lines are original-file line numbers fused into the text, and
// several interior lines (the function header, closing braces, and the
// CopyItemsToCPU argument list) are missing. Comments describe only what
// is visible; confirm details against the original source.
//
// concat(...): packs `inputs` (numRows rows, each a vector of numTensors
// tensors) into `outputs` -- one output tensor per column, with a new
// leading dimension of size numRows prepended.
1 #include "rebatching_queue.h" 2 #include "caffe2/utils/smart_tensor_printer.h" 11 const std::vector<std::vector<TensorCPU>>& inputs,
12 const std::vector<TensorCPU*>& outputs) {
13 CAFFE_ENFORCE(!inputs.empty());
// The first row serves as the shape/dtype reference for all other rows.
15 const auto& inputZero = inputs[0];
16 const auto numTensors = inputZero.size();
17 const auto numRows = inputs.size();
// Per-column output dims = reference dims with numRows prepended.
20 std::vector<std::vector<int64_t>> outputDims(numTensors);
22 for (
size_t i = 0; i < numTensors; ++i) {
23 SmartTensorPrinter::PrintTensor(inputZero.at(i));
24 outputDims[i] = inputZero.at(i).sizes().vec();
25 outputDims[i].insert(outputDims[i].begin(), numRows);
// Resize each output and grab its raw destination buffer, typed after the
// reference tensor's meta.
29 std::vector<void*> destinations(numTensors);
30 for (
size_t i = 0; i < numTensors; ++i) {
31 outputs[i]->Resize(outputDims[i]);
32 destinations[i] = outputs[i]->raw_mutable_data(inputZero[i].meta());
// Validate every row against the reference (dtype, itemsize, rank, each
// dimension), then copy the row's data into the per-column destination.
35 for (
size_t i = 0; i < numRows; ++i) {
36 CAFFE_ENFORCE_EQ(inputs[i].size(), numTensors);
38 for (
int j = 0; j < numTensors; ++j) {
39 const auto& input = inputs[i][j];
41 CAFFE_ENFORCE(inputZero[j].meta() == input.dtype());
42 CAFFE_ENFORCE_EQ(inputZero[j].itemsize(), input.itemsize());
43 CAFFE_ENFORCE_EQ(inputZero[j].ndim(), input.dim());
44 for (
int k = 0; k < input.dim(); ++k) {
45 CAFFE_ENFORCE_EQ(input.sizes()[k], inputZero[j].size(k));
// Empty tensors contribute no bytes; the copy is skipped.
49 if (input.numel() == 0) {
53 context.CopyItemsToCPU(
// Advance the destination pointer past the bytes this row occupies.
61 (
char*)destinations[j] + input.numel() * input.itemsize();
// split(...): inverse of concat -- slices each input tensor along its first
// dimension into `outputSize` rows; returns one vector of tensors per row.
// NOTE(review): several interior lines are missing from this extraction
// (e.g. the CPUContext parameter on the signature, the CopyItemsToCPU size
// arguments, and the closing braces); comments describe only what is visible.
66 std::vector<std::vector<TensorCPU>> split(
68 const std::vector<const TensorCPU*>& inputs) {
69 CAFFE_ENFORCE(!inputs.empty());
// Every input must share the same leading dimension; that dimension becomes
// the number of output rows.
71 const auto outputSize = inputs[0]->sizes().at(0);
72 std::vector<std::vector<TensorCPU>> outputs(outputSize);
74 for (
const auto* inputPtr : inputs) {
75 CAFFE_ENFORCE(inputPtr);
77 const auto& input = *inputPtr;
// Bytes per row = product of trailing dims * element size.
78 const auto innerSize = input.size_from_dim(1);
79 const auto itemSize = input.dtype().itemsize();
// Per-row shape drops the leading (batch) dimension.
81 auto outputDims = input.sizes().vec();
82 CAFFE_ENFORCE(!outputDims.empty());
83 outputDims.erase(outputDims.begin());
84 CAFFE_ENFORCE_EQ(input.sizes().at(0), outputSize);
// Copy row i's bytes out of the source buffer into a fresh CPU tensor
// appended to that row's output vector.
86 for (
int i = 0; i < outputSize; ++i) {
87 outputs[i].push_back(
Tensor(outputDims, CPU));
88 context.CopyItemsToCPU(
91 (
char*)input.raw_data() + i * innerSize * itemSize ,
92 outputs[i].back().raw_mutable_data(input.dtype()) );
// Constructs a queue holding at most `capacity` elements, each element being
// a vector of `numBlobs` tensors; the ring buffer is pre-sized to capacity.
// NOTE(review): the leading numerals are line-number artifacts from a bad
// extraction -- confirm against the original source.
100 RebatchingQueue::RebatchingQueue(
size_t capacity,
size_t numBlobs)
101 : capacity_(capacity), numBlobs_(numBlobs), queue_(capacity) {}
// Destructor. NOTE(review): the body is missing from this extraction --
// presumably it closes the queue so waiters are released; confirm against
// the original source.
103 RebatchingQueue::~RebatchingQueue() {
// True when there are unconsumed elements (tail_ lags head_). Evaluated
// under mutex_ as dequeue's condition-variable wait predicate; head_/tail_
// appear to be monotonically increasing cursors reduced modulo capacity()
// only at access time. NOTE(review): closing brace missing in this extraction.
107 bool RebatchingQueue::canRead()
const {
108 return tail_ < head_;
// Pops up to `numElements` queued rows and concatenates them into `outputs`.
// Blocks on cvEmpty_ until data is available or the queue is closed.
// NOTE(review): many interior lines (the surrounding loop header, early
// returns, and the final return) are missing from this extraction; comments
// describe only what is visible.
111 bool RebatchingQueue::dequeue(
114 const std::vector<TensorCPU*>& outputs) {
115 std::vector<std::vector<TensorCPU>> results;
116 results.reserve(numElements);
// Stop collecting once the requested number of rows has been gathered.
119 if (results.size() == numElements) {
124 std::unique_lock<std::mutex> lock(mutex_);
// Wait until something is readable or the queue has been closed.
126 cvEmpty_.wait(lock, [
this] {
return canRead() || isClosed_; });
// Closed and fully drained: nothing more will ever arrive.
129 if (!canRead() && isClosed_) {
// Drain rows while readable, without exceeding numElements.
134 results.push_back(std::move(queue_[tail_++ % capacity()]));
135 }
while (canRead() && results.size() < numElements);
// Wake blocked writer(s); a single-element dequeue frees exactly one slot,
// so notify_one suffices in that case.
138 if (numElements == 1) {
139 cvOverflow_.notify_one();
141 cvOverflow_.notify_all();
145 if (results.empty()) {
// Batch the collected rows into the caller-provided output tensors.
149 concat(context, results, outputs);
// True while fewer than capacity() elements are outstanding, i.e.
// head_ - tail_ < capacity(). Evaluated under mutex_ as enqueue's
// condition-variable wait predicate. NOTE(review): closing brace missing
// in this extraction.
154 bool RebatchingQueue::canWrite()
const {
155 return tail_ + capacity() > head_;
// Enqueues a single row: clones each input tensor into a fresh vector and
// forwards it to enqueue() as a one-element batch. NOTE(review): closing
// braces are missing from this extraction.
158 bool RebatchingQueue::enqueueOne(
160 const std::vector<const TensorCPU*>& inputs) {
161 std::vector<std::vector<TensorCPU>> splittedInputs;
162 splittedInputs.emplace_back();
163 auto& tensorVector = splittedInputs.back();
164 tensorVector.reserve(inputs.size());
165 for (
const auto* tensorPtr : inputs) {
// Clone so the queue owns its data independently of the caller.
166 tensorVector.push_back(tensorPtr->Clone());
169 return enqueue(std::move(splittedInputs));
// Enqueues a batch: checks the blob count, splits each input tensor along
// its first dimension into rows (one queue element per row), and forwards
// them to enqueue(). NOTE(review): the CPUContext parameter and closing
// brace are missing from this extraction.
172 bool RebatchingQueue::enqueueMany(
174 const std::vector<const TensorCPU*>& inputs) {
175 CAFFE_ENFORCE_EQ(numBlobs_, inputs.size());
177 std::vector<std::vector<TensorCPU>> splittedInputs;
178 splittedInputs = split(context, inputs);
179 return enqueue(std::move(splittedInputs));
// Moves pre-split rows into the ring buffer, blocking on cvOverflow_ while
// the queue is full and not closed; wakes readers once rows are stored.
// NOTE(review): the surrounding loop header, the declaration of `idx`, and
// early-return/closed-handling lines are missing from this extraction.
182 bool RebatchingQueue::enqueue(
183 std::vector<std::vector<TensorCPU>> splittedInputs) {
// All rows stored: done.
186 if (idx >= splittedInputs.size()) {
191 std::unique_lock<std::mutex> lock(mutex_);
// Wait for a free slot or for the queue to be closed.
193 cvOverflow_.wait(lock, [
this] {
return canWrite() || isClosed_; });
// Store rows while slots remain; head_ is a monotonic cursor reduced
// modulo capacity() at access time.
202 queue_[head_++ % capacity()] = std::move(splittedInputs[idx++]);
203 }
while (canWrite() && idx < splittedInputs.size());
// Let all blocked readers re-check for newly available data.
206 cvEmpty_.notify_all();
// Fixed maximum number of queued elements. NOTE(review): the return
// statement is missing from this extraction -- presumably
// `return capacity_;`; confirm against the original source.
212 size_t RebatchingQueue::capacity()
const {
// Number of tensors per queue element. NOTE(review): the body is missing
// from this extraction -- presumably `return numBlobs_;`; confirm against
// the original source.
216 size_t RebatchingQueue::numBlobs()
const {
// Reads the closed flag under mutex_. NOTE(review): the return statement
// is missing from this extraction -- presumably `return isClosed_;`;
// confirm against the original source.
220 bool RebatchingQueue::isClosed()
const {
221 std::lock_guard<std::mutex> g(mutex_);
// Marks the queue closed and wakes every waiter (both readers blocked in
// dequeue and writers blocked in enqueue) so they can observe the closed
// state. NOTE(review): the line that actually sets the closed flag is
// missing from this extraction; confirm against the original source.
225 void RebatchingQueue::close() {
227 std::lock_guard<std::mutex> g(mutex_);
231 cvEmpty_.notify_all();
232 cvOverflow_.notify_all();
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...