Caffe2 - C++ API
A deep learning, cross platform ML framework
queue_ops.h
1 
17 #pragma once
18 
#include <algorithm>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "blobs_queue.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename Context>
27 class CreateBlobsQueueOp final : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30 
31  CreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  ws_(ws),
34  name(operator_def.output().Get(0)) {}
35 
36  bool RunOnDevice() override {
37  const auto capacity = GetSingleArgument("capacity", 1);
38  const auto numBlobs = GetSingleArgument("num_blobs", 1);
39  const auto enforceUniqueName =
40  GetSingleArgument("enforce_unique_name", false);
41  const auto fieldNames =
42  OperatorBase::template GetRepeatedArgument<std::string>("field_names");
43  CAFFE_ENFORCE_EQ(this->OutputSize(), 1);
44  auto queuePtr = Operator<Context>::Outputs()[0]
45  ->template GetMutable<std::shared_ptr<BlobsQueue>>();
46  CAFFE_ENFORCE(queuePtr);
47  *queuePtr = std::make_shared<BlobsQueue>(
48  ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames);
49  return true;
50  }
51 
52  private:
53  Workspace* ws_{nullptr};
54  const std::string name;
55 };
56 
57 template <typename Context>
58 class EnqueueBlobsOp final : public Operator<Context> {
59  public:
60  USE_OPERATOR_CONTEXT_FUNCTIONS;
62  bool RunOnDevice() override {
63  CAFFE_ENFORCE(InputSize() > 1);
64  auto queue = Operator<Context>::Inputs()[0]
65  ->template Get<std::shared_ptr<BlobsQueue>>();
66  CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
67  return queue->blockingWrite(this->Outputs());
68  }
69 
70  private:
71 };
72 
73 template <typename Context>
74 class DequeueBlobsOp final : public Operator<Context> {
75  public:
76  USE_OPERATOR_CONTEXT_FUNCTIONS;
77 
78  DequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
79  : Operator<Context>(operator_def, ws) {
80  timeout_secs_ = OperatorBase::GetSingleArgument<float>("timeout_secs", 0);
81  }
82 
83  bool RunOnDevice() override {
84  CAFFE_ENFORCE(InputSize() == 1);
85  auto queue =
86  OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
87  CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
88  return queue->blockingRead(this->Outputs(), timeout_secs_);
89  }
90 
91  private:
92  float timeout_secs_;
93 };
94 
95 template <typename Context>
96 class CloseBlobsQueueOp final : public Operator<Context> {
97  public:
98  USE_OPERATOR_CONTEXT_FUNCTIONS;
100  bool RunOnDevice() override {
101  CAFFE_ENFORCE_EQ(InputSize(), 1);
102  auto queue =
103  OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
104  CAFFE_ENFORCE(queue);
105  queue->close();
106  return true;
107  }
108 
109  private:
110 };
111 
112 template <typename Context>
113 class SafeEnqueueBlobsOp final : public Operator<Context> {
114  public:
115  USE_OPERATOR_CONTEXT_FUNCTIONS;
117  bool RunOnDevice() override {
118  auto queue = Operator<Context>::Inputs()[0]
119  ->template Get<std::shared_ptr<BlobsQueue>>();
120  CAFFE_ENFORCE(queue);
121  auto size = queue->getNumBlobs();
122  CAFFE_ENFORCE(
123  OutputSize() == size + 1,
124  "Expected " + caffe2::to_string(size + 1) + ", " +
125  " got: " + caffe2::to_string(size));
126  bool status = queue->blockingWrite(this->Outputs());
127  Output(size)->Resize();
128  math::Set<bool, Context>(
129  1, !status, Output(size)->template mutable_data<bool>(), &context_);
130  return true;
131  }
132 };
133 
134 template <typename Context>
135 class SafeDequeueBlobsOp final : public Operator<Context> {
136  public:
137  USE_OPERATOR_CONTEXT_FUNCTIONS;
139 
140  SafeDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
141  : Operator<Context>(operator_def, ws),
142  numRecords_(OperatorBase::GetSingleArgument<int>("num_records", 1)) {
143  CAFFE_ENFORCE_GT(numRecords_, 0);
144  }
145 
146  bool dequeueMany(std::shared_ptr<BlobsQueue>& queue) {
147  auto size = queue->getNumBlobs();
148 
149  if (blobs_.size() != size) {
150  blobs_.resize(size);
151  blobPtrs_.resize(size);
152  for (int col = 0; col < size; ++col) {
153  blobPtrs_.at(col) = &blobs_.at(col);
154  }
155  }
156 
157  const int kTensorGrowthPct = 40;
158  for (int i = 0; i < numRecords_; ++i) {
159  if (!queue->blockingRead(blobPtrs_)) {
160  // if we read at least one record, status is still true
161  return i > 0;
162  }
163  for (int col = 0; col < size; ++col) {
164  auto* out = this->Output(col);
165  const auto& in = blobPtrs_.at(col)->template Get<Tensor<Context>>();
166  if (i == 0) {
167  out->CopyFrom(in);
168  } else {
169  auto oldSize = out->size();
170 
171  CAFFE_ENFORCE(
172  in.ndim() > 0,
173  "Empty tensor to dequeue at column ",
174  col,
175  " within ",
176  size,
177  " total columns");
178 
179  out->Extend(in.dims()[0], kTensorGrowthPct, &context_);
180  auto* dst =
181  (char*)out->raw_mutable_data() + oldSize * in.meta().itemsize();
182  context_.template CopyItems<Context, Context>(
183  in.meta(), in.size(), in.raw_data(), dst);
184  }
185  }
186  }
187  return true;
188  }
189 
190  bool dequeueOne(std::shared_ptr<BlobsQueue>& queue) {
191  return queue->blockingRead(this->Outputs());
192  }
193 
194  bool RunOnDevice() override {
195  CAFFE_ENFORCE(InputSize() == 1);
196  auto queue = Operator<Context>::Inputs()[0]
197  ->template Get<std::shared_ptr<BlobsQueue>>();
198  CAFFE_ENFORCE(queue);
199 
200  auto size = queue->getNumBlobs();
201  CAFFE_ENFORCE_EQ(OutputSize(), size + 1);
202 
203  bool status = numRecords_ > 1 ? dequeueMany(queue) : dequeueOne(queue);
204 
205  Output(size)->Resize();
206  math::Set<bool, Context>(
207  1, !status, Output(size)->template mutable_data<bool>(), &context_);
208  return true;
209  }
210 
211  private:
212  int numRecords_;
213  std::vector<Blob> blobs_;
214  std::vector<Blob*> blobPtrs_;
215 };
216 
217 template <typename Context>
218 class WeightedSampleDequeueBlobsOp final : public Operator<Context> {
219  public:
220  USE_OPERATOR_CONTEXT_FUNCTIONS;
221 
222  WeightedSampleDequeueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
223  : Operator<Context>(operator_def, ws),
224  table_idx_blob_(
225  OperatorBase::GetSingleArgument<int>("table_idx_blob", -1)) {
226  CAFFE_ENFORCE_LT(table_idx_blob_, OutputSize() - 1);
227  vector<float> weights = OperatorBase::GetRepeatedArgument<float>("weights");
228  if (weights.empty()) {
229  weights.resize(InputSize(), 1.0f);
230  }
231  CAFFE_ENFORCE_EQ(InputSize(), weights.size());
232 
233  float sum = accumulate(weights.begin(), weights.end(), 0.0f);
234  CAFFE_ENFORCE(sum > 0.0f, "Sum of weights must be positive");
235  cumProbs_.resize(weights.size());
236  for (int i = 0; i < weights.size(); i++) {
237  cumProbs_[i] = weights[i] / sum;
238  CAFFE_ENFORCE_GE(
239  cumProbs_[i], 0.0f, "Each probability must be non-negative");
240  }
241  std::partial_sum(cumProbs_.begin(), cumProbs_.end(), cumProbs_.begin());
242  // Put last value to be 1.0001 to avoid numerical issues.
243  cumProbs_.back() = 1.0001f;
244 
245  LOG(INFO) << "Dequeue weights: " << weights;
246  LOG(INFO) << "cumProbs: " << cumProbs_;
247  }
248 
249  bool RunOnDevice() override {
250  float r;
251  math::RandUniform<float, Context>(1, 0.0f, 1.0f, &r, &context_);
252  auto lb = lower_bound(cumProbs_.begin(), cumProbs_.end(), r);
253  CAFFE_ENFORCE(lb != cumProbs_.end(), "Cannot find ", r, " in cumProbs_.");
254  const int32_t idx = lb - cumProbs_.begin();
255  auto queue = Operator<Context>::Inputs()[idx]
256  ->template Get<std::shared_ptr<BlobsQueue>>();
257 
258  CAFFE_ENFORCE(queue);
259  auto size = queue->getNumBlobs();
260  CAFFE_ENFORCE_EQ(OutputSize(), size + 1);
261  bool status = queue->blockingRead(this->Outputs());
262  if (table_idx_blob_ >= 0) {
263  auto* table_idx_blob_out = Output(table_idx_blob_);
264  table_idx_blob_out->Resize(1);
265  int32_t* data = table_idx_blob_out->template mutable_data<int32_t>();
266  data[0] = idx;
267  }
268 
269  Output(size)->Resize();
270  math::Set<bool, Context>(
271  1, !status, Output(size)->template mutable_data<bool>(), &context_);
272  return true;
273  }
274 
275  private:
276  vector<float> cumProbs_;
277  int table_idx_blob_;
278 };
279 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.