// The operator template that starts at the end of the next line builds a
// BlobsQueue from its operator arguments and stores it, held through a
// std::shared_ptr, into a blob.
// NOTE(review): parts of this operator (class header, constructor signature,
// the receiver of GetMutable, the return statement) are not visible in this
// excerpt — confirm against the full file.
4 #include "blobs_queue.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
typename Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// The queue name is taken from the operator definition's first output.
18 name(operator_def.output().Get(0)) {}
20 bool RunOnDevice()
override {
// Queue configuration. Each argument falls back to the literal default shown
// when it is absent from the operator definition.
21 const auto capacity = GetSingleArgument(
"capacity", 1);
22 const auto numBlobs = GetSingleArgument(
"num_blobs", 1);
23 const auto enforceUniqueName =
24 GetSingleArgument(
"enforce_unique_name",
false);
25 const auto fieldNames =
26 OperatorBase::template GetRepeatedArgument<std::string>(
"field_names");
// Exactly one output: the blob that will hold the queue handle.
27 CAFFE_ENFORCE_EQ(this->OutputSize(), 1);
// NOTE(review): the start of this statement is not visible here; presumably it
// obtains output 0 — confirm.
29 ->template GetMutable<std::shared_ptr<BlobsQueue>>();
30 CAFFE_ENFORCE(queuePtr);
// Replace the stored handle with a new BlobsQueue. It is built from ws_
// (declared outside this excerpt), the queue name and the arguments above.
31 *queuePtr = std::make_shared<BlobsQueue>(
32 ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames);
// Queue name, initialized from the first output name of the operator def.
38 const std::string name;
// Operator that writes a set of blobs into a BlobsQueue via blockingWrite and
// returns the write result as its own success flag.
41 template <
typename Context>
44 USE_OPERATOR_CONTEXT_FUNCTIONS;
46 bool RunOnDevice()
override {
// Requires more than one input (presumably the queue handle plus data blobs).
47 CAFFE_ENFORCE(InputSize() > 1);
// NOTE(review): the start of this statement is not visible here; presumably it
// reads the queue handle from input 0 — confirm.
49 ->template Get<std::shared_ptr<BlobsQueue>>();
// The output count must equal the queue's per-record blob count.
50 CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
// NOTE(review): the blobs handed to blockingWrite are this->Outputs(), not the
// data inputs. Presumably the outputs alias the data inputs in place — confirm
// against the operator schema.
51 return queue->blockingWrite(this->Outputs());
// Operator that reads one record from a BlobsQueue (its single input) into its
// outputs. It returns blockingRead's result as its own success flag.
57 template <
typename Context>
60 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Optional "timeout_secs" argument (float, default 0), forwarded to
// blockingRead. What 0 means (e.g. "no timeout") is defined by BlobsQueue,
// not here — confirm there.
64 timeout_secs_ = OperatorBase::GetSingleArgument<float>(
"timeout_secs", 0);
67 bool RunOnDevice()
override {
// The only input is the queue handle.
68 CAFFE_ENFORCE(InputSize() == 1);
70 OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
// One output per blob in a queue record.
71 CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
72 return queue->blockingRead(this->Outputs(), timeout_secs_);
// Operator whose single input is a BlobsQueue handle.
// NOTE(review): the rest of its body is not visible in this excerpt;
// presumably it closes the queue — confirm against the full file.
79 template <
typename Context>
82 USE_OPERATOR_CONTEXT_FUNCTIONS;
84 bool RunOnDevice()
override {
85 CAFFE_ENFORCE_EQ(InputSize(), 1);
87 OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
// Enqueue variant that reports the write outcome through an extra boolean
// output. Outputs 0..size-1 correspond to the queue's blobs. Output `size`
// (the last one) receives !status from blockingWrite.
96 template <
typename Context>
99 USE_OPERATOR_CONTEXT_FUNCTIONS;
101 bool RunOnDevice()
override {
// NOTE(review): the start of this statement is not visible here; presumably it
// reads the queue handle from input 0 — confirm.
103 ->template Get<std::shared_ptr<BlobsQueue>>();
104 CAFFE_ENFORCE(queue);
105 auto size = queue->getNumBlobs();
// NOTE(review): the enforce call these arguments belong to is not visible.
// The condition checks OutputSize(), but the message reports `size` as the
// "got" value, so it always prints one less than expected regardless of the
// real output count. The ", " + " got: " concatenation also produces a double
// space. Consider reporting c10::to_string(OutputSize()) instead.
107 OutputSize() == size + 1,
108 "Expected " + c10::to_string(size + 1) +
", " +
109 " got: " + c10::to_string(size));
// NOTE(review): all size + 1 outputs, including the trailing status blob, are
// passed to blockingWrite — confirm BlobsQueue tolerates the extra entry.
110 bool status = queue->blockingWrite(this->Outputs());
// Status output: a single bool element, set to true when the write did not
// succeed.
111 Output(size)->Resize();
112 math::Set<bool, Context>(
113 1, !status, Output(size)->template mutable_data<bool>(), &context_);
// Dequeue variant that reports the read outcome through an extra boolean
// output (the last one). When num_records > 1 it reads several records and
// appends each column's data to the matching output along the first dimension.
118 template <
typename Context>
121 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Records to read per run ("num_records", default 1); must be positive.
126 numRecords_(OperatorBase::GetSingleArgument<int>(
"num_records", 1)) {
127 CAFFE_ENFORCE_GT(numRecords_, 0);
// Reads numRecords_ records through the scratch blobs and appends each column
// to the corresponding output.
// NOTE(review): several lines of this function are not visible here (the
// blobs_ resize, the handling of a failed read, the per-column checks around
// the "Empty tensor" message, and the return) — confirm against the full file.
130 bool dequeueMany(std::shared_ptr<BlobsQueue>& queue) {
131 auto size = queue->getNumBlobs();
// Rebuild the pointer view over blobs_ only when the column count changes.
133 if (blobs_.size() != size) {
135 blobPtrs_.resize(size);
136 for (
int col = 0; col < size; ++col) {
137 blobPtrs_.at(col) = &blobs_.at(col);
// Growth percentage passed to Extend when an output is enlarged.
141 const int kTensorGrowthPct = 40;
142 for (
int i = 0; i < numRecords_; ++i) {
143 if (!queue->blockingRead(blobPtrs_)) {
147 for (
int col = 0; col < size; ++col) {
148 auto* out = this->Output(col);
149 const auto& in = blobPtrs_.at(col)->template Get<Tensor>();
153 auto oldSize = out->numel();
157 "Empty tensor to dequeue at column ",
// Grow the output by the record's first-dimension length. Then copy all of
// the record's elements into the region starting right after the previous
// contents (oldSize elements in).
163 out->Extend(in.sizes()[0], kTensorGrowthPct);
165 (
char*)out->raw_mutable_data() + oldSize * in.dtype().itemsize();
166 context_.template CopyItems<Context, Context>(
167 in.meta(), in.numel(), in.raw_data(), dst);
// Reads a single record directly into the outputs.
// NOTE(review): this passes all OutputSize() == size + 1 outputs, including the
// trailing status blob, to blockingRead — confirm that is tolerated.
174 bool dequeueOne(std::shared_ptr<BlobsQueue>& queue) {
175 return queue->blockingRead(this->Outputs());
178 bool RunOnDevice()
override {
179 CAFFE_ENFORCE(InputSize() == 1);
// NOTE(review): the start of this statement is not visible here; presumably it
// reads the queue handle from input 0 — confirm.
181 ->template Get<std::shared_ptr<BlobsQueue>>();
182 CAFFE_ENFORCE(queue);
184 auto size = queue->getNumBlobs();
// Outputs are the queue's blobs plus one trailing status blob.
185 CAFFE_ENFORCE_EQ(OutputSize(), size + 1);
187 bool status = numRecords_ > 1 ? dequeueMany(queue) : dequeueOne(queue);
// Status output: a single bool element, set to true when the read did not
// succeed.
189 Output(size)->Resize();
190 math::Set<bool, Context>(
191 1, !status, Output(size)->template mutable_data<bool>(), &context_);
// Scratch blobs reused across dequeueMany calls, and pointers into them that
// are handed to blockingRead.
197 std::vector<Blob> blobs_;
198 std::vector<Blob*> blobPtrs_;
// Dequeue operator that picks an index at random, weighted by the "weights"
// argument (one weight per input), and reads one record from a BlobsQueue.
// The last output receives a failure flag. Optionally, output table_idx_blob
// receives the chosen index.
201 template <
typename Context>
204 USE_OPERATOR_CONTEXT_FUNCTIONS;
// "table_idx_blob" (default -1, meaning not used) selects an output to
// receive the chosen index. It must lie before the trailing status output.
209 OperatorBase::GetSingleArgument<int>(
"table_idx_blob", -1)) {
210 CAFFE_ENFORCE_LT(table_idx_blob_, OutputSize() - 1);
// One weight per input; all weights default to 1 (uniform) when omitted.
211 vector<float> weights = OperatorBase::GetRepeatedArgument<float>(
"weights");
212 if (weights.empty()) {
213 weights.resize(InputSize(), 1.0f);
215 CAFFE_ENFORCE_EQ(InputSize(), weights.size());
// Normalize the weights into probabilities, then turn them into a running
// (cumulative) sum.
217 float sum = accumulate(weights.begin(), weights.end(), 0.0f);
218 CAFFE_ENFORCE(sum > 0.0f,
"Sum of weights must be positive");
219 cumProbs_.resize(weights.size());
220 for (
int i = 0; i < weights.size(); i++) {
221 cumProbs_[i] = weights[i] / sum;
223 cumProbs_[i], 0.0f,
"Each probability must be non-negative");
225 std::partial_sum(cumProbs_.begin(), cumProbs_.end(), cumProbs_.begin());
// Force the last cumulative value above 1. Presumably this guarantees that any
// r drawn in [0, 1] below is found by lower_bound despite float rounding in
// partial_sum.
227 cumProbs_.back() = 1.0001f;
229 LOG(INFO) <<
"Dequeue weights: " << weights;
230 LOG(INFO) <<
"cumProbs: " << cumProbs_;
233 bool RunOnDevice()
override {
// Draw r with bounds 0 and 1, then choose the first index whose cumulative
// probability is not less than r.
// NOTE(review): r's declaration is not visible. RandUniform writes through &r
// using this operator's Context; presumably r is a host float, which only
// works for a host-side Context — confirm.
235 math::RandUniform<float, Context>(1, 0.0f, 1.0f, &r, &context_);
236 auto lb = lower_bound(cumProbs_.begin(), cumProbs_.end(), r);
237 CAFFE_ENFORCE(lb != cumProbs_.end(),
"Cannot find ", r,
" in cumProbs_.");
238 const int32_t idx = lb - cumProbs_.begin();
// NOTE(review): the start of this statement is not visible here; presumably it
// reads the queue handle from input idx — confirm.
240 ->template Get<std::shared_ptr<BlobsQueue>>();
242 CAFFE_ENFORCE(queue);
243 auto size = queue->getNumBlobs();
// Outputs are the queue's blobs plus one trailing status blob.
244 CAFFE_ENFORCE_EQ(OutputSize(), size + 1);
245 bool status = queue->blockingRead(this->Outputs());
// Optionally expose the chosen index as a 1-element int32 tensor.
// NOTE(review): the write through `data` is not visible in this excerpt. If
// table_idx_blob_ names one of the queue outputs, that output is overwritten
// after the read.
246 if (table_idx_blob_ >= 0) {
247 auto* table_idx_blob_out =
248 Output(table_idx_blob_, {1}, at::dtype<int32_t>());
249 int32_t* data = table_idx_blob_out->template mutable_data<int32_t>();
// Status output: a single bool element, set to true when the read did not
// succeed.
253 Output(size)->Resize();
254 math::Set<bool, Context>(
255 1, !status, Output(size)->template mutable_data<bool>(), &context_);
// Cumulative selection probabilities, one entry per input (last forced > 1).
260 vector<float> cumProbs_;
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...