Caffe2 - C++ API
A deep learning, cross platform ML framework
blobs_queue.cc
1 #include "caffe2/queue/blobs_queue.h"
2 
3 #include <atomic>
4 #include <condition_variable>
5 #include <memory>
6 #include <mutex>
7 #include <queue>
8 
9 #include "caffe2/core/blob_stats.h"
10 #include "caffe2/core/logging.h"
11 #include "caffe2/core/stats.h"
12 #include "caffe2/core/tensor.h"
13 #include "caffe2/core/timer.h"
14 #include "caffe2/core/workspace.h"
15 
16 namespace caffe2 {
17 
// Values passed as the last argument of the CAFFE_SDT user tracepoints below.
//
// Kind of operation reported on a *_start tracepoint.
static constexpr int SDT_NONBLOCKING_OP = 0;
static constexpr int SDT_BLOCKING_OP = 1;
// Sentinel results reported on a *_end tracepoint. They occupy the top of the
// uint64_t range (-1, -2, -3 converted to unsigned).
static constexpr uint64_t SDT_TIMEOUT = static_cast<uint64_t>(-1);
static constexpr uint64_t SDT_ABORT = static_cast<uint64_t>(-2);
static constexpr uint64_t SDT_CANCEL = static_cast<uint64_t>(-3);
24 
25 BlobsQueue::BlobsQueue(
26  Workspace* ws,
27  const std::string& queueName,
28  size_t capacity,
29  size_t numBlobs,
30  bool enforceUniqueName,
31  const std::vector<std::string>& fieldNames)
32  : numBlobs_(numBlobs), name_(queueName), stats_(queueName) {
33  if (!fieldNames.empty()) {
34  CAFFE_ENFORCE_EQ(
35  fieldNames.size(), numBlobs, "Wrong number of fieldNames provided.");
36  stats_.queue_dequeued_bytes.setDetails(fieldNames);
37  }
38  queue_.reserve(capacity);
39  for (size_t i = 0; i < capacity; ++i) {
40  std::vector<Blob*> blobs;
41  blobs.reserve(numBlobs);
42  for (size_t j = 0; j < numBlobs; ++j) {
43  const auto blobName = queueName + "_" + to_string(i) + "_" + to_string(j);
44  if (enforceUniqueName) {
45  CAFFE_ENFORCE(
46  !ws->GetBlob(blobName),
47  "Queue internal blob already exists: ",
48  blobName);
49  }
50  blobs.push_back(ws->CreateBlob(blobName));
51  }
52  queue_.push_back(blobs);
53  }
54  DCHECK_EQ(queue_.size(), capacity);
55 }
56 
57 bool BlobsQueue::blockingRead(
58  const std::vector<Blob*>& inputs,
59  float timeout_secs) {
60  Timer readTimer;
61  auto keeper = this->shared_from_this();
62  const auto& name = name_.c_str();
63  CAFFE_SDT(queue_read_start, name, (void*)this, SDT_BLOCKING_OP);
64  std::unique_lock<std::mutex> g(mutex_);
65  auto canRead = [this]() {
66  CAFFE_ENFORCE_LE(reader_, writer_);
67  return reader_ != writer_;
68  };
69  // Decrease queue balance before reading to indicate queue read pressure
70  // is being increased (-ve queue balance indicates more reads than writes)
71  CAFFE_EVENT(stats_, queue_balance, -1);
72  if (timeout_secs > 0) {
73  std::chrono::milliseconds timeout_ms(int(timeout_secs * 1000));
74  cv_.wait_for(
75  g, timeout_ms, [this, canRead]() { return closing_ || canRead(); });
76  } else {
77  cv_.wait(g, [this, canRead]() { return closing_ || canRead(); });
78  }
79  if (!canRead()) {
80  if (timeout_secs > 0 && !closing_) {
81  LOG(ERROR) << "DequeueBlobs timed out in " << timeout_secs << " secs";
82  CAFFE_SDT(queue_read_end, name, (void*)this, SDT_TIMEOUT);
83  } else {
84  CAFFE_SDT(queue_read_end, name, (void*)this, SDT_CANCEL);
85  }
86  return false;
87  }
88  DCHECK(canRead());
89  auto& result = queue_[reader_ % queue_.size()];
90  CAFFE_ENFORCE(inputs.size() >= result.size());
91  for (auto i = 0; i < result.size(); ++i) {
92  auto bytes = BlobStat::sizeBytes(*result[i]);
93  CAFFE_EVENT(stats_, queue_dequeued_bytes, bytes, i);
94  using std::swap;
95  swap(*(inputs[i]), *(result[i]));
96  }
97  CAFFE_SDT(queue_read_end, name, (void*)this, writer_ - reader_);
98  CAFFE_EVENT(stats_, queue_dequeued_records);
99  ++reader_;
100  cv_.notify_all();
101  CAFFE_EVENT(stats_, read_time_ns, readTimer.NanoSeconds());
102  return true;
103 }
104 
105 bool BlobsQueue::tryWrite(const std::vector<Blob*>& inputs) {
106  Timer writeTimer;
107  auto keeper = this->shared_from_this();
108  const auto& name = name_.c_str();
109  CAFFE_SDT(queue_write_start, name, (void*)this, SDT_NONBLOCKING_OP);
110  std::unique_lock<std::mutex> g(mutex_);
111  if (!canWrite()) {
112  CAFFE_SDT(queue_write_end, name, (void*)this, SDT_ABORT);
113  return false;
114  }
115  // Increase queue balance before writing to indicate queue write pressure is
116  // being increased (+ve queue balance indicates more writes than reads)
117  CAFFE_EVENT(stats_, queue_balance, 1);
118  DCHECK(canWrite());
119  doWrite(inputs);
120  CAFFE_EVENT(stats_, write_time_ns, writeTimer.NanoSeconds());
121  return true;
122 }
123 
124 bool BlobsQueue::blockingWrite(const std::vector<Blob*>& inputs) {
125  Timer writeTimer;
126  auto keeper = this->shared_from_this();
127  const auto& name = name_.c_str();
128  CAFFE_SDT(queue_write_start, name, (void*)this, SDT_BLOCKING_OP);
129  std::unique_lock<std::mutex> g(mutex_);
130  // Increase queue balance before writing to indicate queue write pressure is
131  // being increased (+ve queue balance indicates more writes than reads)
132  CAFFE_EVENT(stats_, queue_balance, 1);
133  cv_.wait(g, [this]() { return closing_ || canWrite(); });
134  if (!canWrite()) {
135  CAFFE_SDT(queue_write_end, name, (void*)this, SDT_ABORT);
136  return false;
137  }
138  DCHECK(canWrite());
139  doWrite(inputs);
140  CAFFE_EVENT(stats_, write_time_ns, writeTimer.NanoSeconds());
141  return true;
142 }
143 
144 void BlobsQueue::close() {
145  closing_ = true;
146 
147  std::lock_guard<std::mutex> g(mutex_);
148  cv_.notify_all();
149 }
150 
151 bool BlobsQueue::canWrite() {
152  // writer is always within [reader, reader + size)
153  // we can write if reader is within [reader, reader + size)
154  CAFFE_ENFORCE_LE(reader_, writer_);
155  CAFFE_ENFORCE_LE(writer_, reader_ + queue_.size());
156  return writer_ != reader_ + queue_.size();
157 }
158 
159 void BlobsQueue::doWrite(const std::vector<Blob*>& inputs) {
160  auto& result = queue_[writer_ % queue_.size()];
161  CAFFE_ENFORCE(inputs.size() >= result.size());
162  const auto& name = name_.c_str();
163  for (auto i = 0; i < result.size(); ++i) {
164  using std::swap;
165  swap(*(inputs[i]), *(result[i]));
166  }
167  CAFFE_SDT(
168  queue_write_end, name, (void*)this, reader_ + queue_.size() - writer_);
169  ++writer_;
170  cv_.notify_all();
171 }
172 
173 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13