Caffe2 - C++ API
A deep learning, cross-platform ML framework
queue_ops.h
1 #pragma once
2 
#include <algorithm>
#include <memory>
#include <numeric>

#include "blobs_queue.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename Context>
11 class CreateBlobsQueueOp final : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14 
15  CreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws)
16  : Operator<Context>(operator_def, ws),
17  ws_(ws),
18  name(operator_def.output().Get(0)) {}
19 
20  bool RunOnDevice() override {
21  const auto capacity = GetSingleArgument("capacity", 1);
22  const auto numBlobs = GetSingleArgument("num_blobs", 1);
23  const auto enforceUniqueName =
24  GetSingleArgument("enforce_unique_name", false);
25  const auto fieldNames =
26  OperatorBase::template GetRepeatedArgument<std::string>("field_names");
27  CAFFE_ENFORCE_EQ(this->OutputSize(), 1);
28  auto queuePtr = Operator<Context>::Outputs()[0]
29  ->template GetMutable<std::shared_ptr<BlobsQueue>>();
30  CAFFE_ENFORCE(queuePtr);
31  *queuePtr = std::make_shared<BlobsQueue>(
32  ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames);
33  return true;
34  }
35 
36  private:
37  Workspace* ws_{nullptr};
38  const std::string name;
39 };
40 
41 template <typename Context>
42 class EnqueueBlobsOp final : public Operator<Context> {
43  public:
44  USE_OPERATOR_CONTEXT_FUNCTIONS;
46  bool RunOnDevice() override {
47  CAFFE_ENFORCE(InputSize() > 1);
48  auto queue = Operator<Context>::Inputs()[0]
49  ->template Get<std::shared_ptr<BlobsQueue>>();
50  CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
51  return queue->blockingWrite(this->Outputs());
52  }
53 
54  private:
55 };
56 
57 template <typename Context>
58 class DequeueBlobsOp final : public Operator<Context> {
59  public:
60  USE_OPERATOR_CONTEXT_FUNCTIONS;
61 
62  DequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
63  : Operator<Context>(operator_def, ws) {
64  timeout_secs_ = OperatorBase::GetSingleArgument<float>("timeout_secs", 0);
65  }
66 
67  bool RunOnDevice() override {
68  CAFFE_ENFORCE(InputSize() == 1);
69  auto queue =
70  OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
71  CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
72  return queue->blockingRead(this->Outputs(), timeout_secs_);
73  }
74 
75  private:
76  float timeout_secs_;
77 };
78 
79 template <typename Context>
80 class CloseBlobsQueueOp final : public Operator<Context> {
81  public:
82  USE_OPERATOR_CONTEXT_FUNCTIONS;
84  bool RunOnDevice() override {
85  CAFFE_ENFORCE_EQ(InputSize(), 1);
86  auto queue =
87  OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
88  CAFFE_ENFORCE(queue);
89  queue->close();
90  return true;
91  }
92 
93  private:
94 };
95 
96 template <typename Context>
97 class SafeEnqueueBlobsOp final : public Operator<Context> {
98  public:
99  USE_OPERATOR_CONTEXT_FUNCTIONS;
101  bool RunOnDevice() override {
102  auto queue = Operator<Context>::Inputs()[0]
103  ->template Get<std::shared_ptr<BlobsQueue>>();
104  CAFFE_ENFORCE(queue);
105  auto size = queue->getNumBlobs();
106  CAFFE_ENFORCE(
107  OutputSize() == size + 1,
108  "Expected " + c10::to_string(size + 1) + ", " +
109  " got: " + c10::to_string(size));
110  bool status = queue->blockingWrite(this->Outputs());
111  Output(size)->Resize();
112  math::Set<bool, Context>(
113  1, !status, Output(size)->template mutable_data<bool>(), &context_);
114  return true;
115  }
116 };
117 
118 template <typename Context>
119 class SafeDequeueBlobsOp final : public Operator<Context> {
120  public:
121  USE_OPERATOR_CONTEXT_FUNCTIONS;
123 
124  SafeDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
125  : Operator<Context>(operator_def, ws),
126  numRecords_(OperatorBase::GetSingleArgument<int>("num_records", 1)) {
127  CAFFE_ENFORCE_GT(numRecords_, 0);
128  }
129 
130  bool dequeueMany(std::shared_ptr<BlobsQueue>& queue) {
131  auto size = queue->getNumBlobs();
132 
133  if (blobs_.size() != size) {
134  blobs_.resize(size);
135  blobPtrs_.resize(size);
136  for (int col = 0; col < size; ++col) {
137  blobPtrs_.at(col) = &blobs_.at(col);
138  }
139  }
140 
141  const int kTensorGrowthPct = 40;
142  for (int i = 0; i < numRecords_; ++i) {
143  if (!queue->blockingRead(blobPtrs_)) {
144  // if we read at least one record, status is still true
145  return i > 0;
146  }
147  for (int col = 0; col < size; ++col) {
148  auto* out = this->Output(col);
149  const auto& in = blobPtrs_.at(col)->template Get<Tensor>();
150  if (i == 0) {
151  out->CopyFrom(in);
152  } else {
153  auto oldSize = out->numel();
154 
155  CAFFE_ENFORCE(
156  in.dim() > 0,
157  "Empty tensor to dequeue at column ",
158  col,
159  " within ",
160  size,
161  " total columns");
162 
163  out->Extend(in.sizes()[0], kTensorGrowthPct);
164  auto* dst =
165  (char*)out->raw_mutable_data() + oldSize * in.dtype().itemsize();
166  context_.template CopyItems<Context, Context>(
167  in.meta(), in.numel(), in.raw_data(), dst);
168  }
169  }
170  }
171  return true;
172  }
173 
174  bool dequeueOne(std::shared_ptr<BlobsQueue>& queue) {
175  return queue->blockingRead(this->Outputs());
176  }
177 
178  bool RunOnDevice() override {
179  CAFFE_ENFORCE(InputSize() == 1);
180  auto queue = Operator<Context>::Inputs()[0]
181  ->template Get<std::shared_ptr<BlobsQueue>>();
182  CAFFE_ENFORCE(queue);
183 
184  auto size = queue->getNumBlobs();
185  CAFFE_ENFORCE_EQ(OutputSize(), size + 1);
186 
187  bool status = numRecords_ > 1 ? dequeueMany(queue) : dequeueOne(queue);
188 
189  Output(size)->Resize();
190  math::Set<bool, Context>(
191  1, !status, Output(size)->template mutable_data<bool>(), &context_);
192  return true;
193  }
194 
195  private:
196  int numRecords_;
197  std::vector<Blob> blobs_;
198  std::vector<Blob*> blobPtrs_;
199 };
200 
201 template <typename Context>
202 class WeightedSampleDequeueBlobsOp final : public Operator<Context> {
203  public:
204  USE_OPERATOR_CONTEXT_FUNCTIONS;
205 
206  WeightedSampleDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
207  : Operator<Context>(operator_def, ws),
208  table_idx_blob_(
209  OperatorBase::GetSingleArgument<int>("table_idx_blob", -1)) {
210  CAFFE_ENFORCE_LT(table_idx_blob_, OutputSize() - 1);
211  vector<float> weights = OperatorBase::GetRepeatedArgument<float>("weights");
212  if (weights.empty()) {
213  weights.resize(InputSize(), 1.0f);
214  }
215  CAFFE_ENFORCE_EQ(InputSize(), weights.size());
216 
217  float sum = accumulate(weights.begin(), weights.end(), 0.0f);
218  CAFFE_ENFORCE(sum > 0.0f, "Sum of weights must be positive");
219  cumProbs_.resize(weights.size());
220  for (int i = 0; i < weights.size(); i++) {
221  cumProbs_[i] = weights[i] / sum;
222  CAFFE_ENFORCE_GE(
223  cumProbs_[i], 0.0f, "Each probability must be non-negative");
224  }
225  std::partial_sum(cumProbs_.begin(), cumProbs_.end(), cumProbs_.begin());
226  // Put last value to be 1.0001 to avoid numerical issues.
227  cumProbs_.back() = 1.0001f;
228 
229  LOG(INFO) << "Dequeue weights: " << weights;
230  LOG(INFO) << "cumProbs: " << cumProbs_;
231  }
232 
233  bool RunOnDevice() override {
234  float r;
235  math::RandUniform<float, Context>(1, 0.0f, 1.0f, &r, &context_);
236  auto lb = lower_bound(cumProbs_.begin(), cumProbs_.end(), r);
237  CAFFE_ENFORCE(lb != cumProbs_.end(), "Cannot find ", r, " in cumProbs_.");
238  const int32_t idx = lb - cumProbs_.begin();
239  auto queue = Operator<Context>::Inputs()[idx]
240  ->template Get<std::shared_ptr<BlobsQueue>>();
241 
242  CAFFE_ENFORCE(queue);
243  auto size = queue->getNumBlobs();
244  CAFFE_ENFORCE_EQ(OutputSize(), size + 1);
245  bool status = queue->blockingRead(this->Outputs());
246  if (table_idx_blob_ >= 0) {
247  auto* table_idx_blob_out =
248  Output(table_idx_blob_, {1}, at::dtype<int32_t>());
249  int32_t* data = table_idx_blob_out->template mutable_data<int32_t>();
250  data[0] = idx;
251  }
252 
253  Output(size)->Resize();
254  math::Set<bool, Context>(
255  1, !status, Output(size)->template mutable_data<bool>(), &context_);
256  return true;
257  }
258 
259  private:
260  vector<float> cumProbs_;
261  int table_idx_blob_;
262 };
263 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13