4 #include "blobs_queue.h"     5 #include "caffe2/core/operator.h"     6 #include "caffe2/utils/math.h"    10 template <
typename Context>
    13   USE_OPERATOR_CONTEXT_FUNCTIONS;
    18         name(operator_def.output().Get(0)) {}
    20   bool RunOnDevice()
 override {
    21     const auto capacity = GetSingleArgument(
"capacity", 1);
    22     const auto numBlobs = GetSingleArgument(
"num_blobs", 1);
    23     const auto enforceUniqueName =
    24         GetSingleArgument(
"enforce_unique_name", 
false);
    25     const auto fieldNames =
    26         OperatorBase::template GetRepeatedArgument<std::string>(
"field_names");
    27     CAFFE_ENFORCE_EQ(this->OutputSize(), 1);
    29                         ->template GetMutable<std::shared_ptr<BlobsQueue>>();
    30     CAFFE_ENFORCE(queuePtr);
    31     *queuePtr = std::make_shared<BlobsQueue>(
    32         ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames);
    38   const std::string name;
    41 template <
typename Context>
    44   USE_OPERATOR_CONTEXT_FUNCTIONS;
    46   bool RunOnDevice()
 override {
    47     CAFFE_ENFORCE(InputSize() > 1);
    49                      ->template Get<std::shared_ptr<BlobsQueue>>();
    50     CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
    51     return queue->blockingWrite(this->Outputs());
    57 template <
typename Context>
    60   USE_OPERATOR_CONTEXT_FUNCTIONS;
    64     timeout_secs_ = OperatorBase::GetSingleArgument<float>(
"timeout_secs", 0);
    67   bool RunOnDevice()
 override {
    68     CAFFE_ENFORCE(InputSize() == 1);
    70         OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
    71     CAFFE_ENFORCE(queue && OutputSize() == queue->getNumBlobs());
    72     return queue->blockingRead(this->Outputs(), timeout_secs_);
    79 template <
typename Context>
    82   USE_OPERATOR_CONTEXT_FUNCTIONS;
    84   bool RunOnDevice()
 override {
    85     CAFFE_ENFORCE_EQ(InputSize(), 1);
    87         OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
    96 template <
typename Context>
    99   USE_OPERATOR_CONTEXT_FUNCTIONS;
   101   bool RunOnDevice()
 override {
   103                      ->template Get<std::shared_ptr<BlobsQueue>>();
   104     CAFFE_ENFORCE(queue);
   105     auto size = queue->getNumBlobs();
   107         OutputSize() == size + 1,
   108         "Expected " + c10::to_string(size + 1) + 
", " +
   109             " got: " + c10::to_string(size));
   110     bool status = queue->blockingWrite(this->Outputs());
   111     Output(size)->Resize();
   112     math::Set<bool, Context>(
   113         1, !status, Output(size)->template mutable_data<bool>(), &context_);
   118 template <
typename Context>
   121   USE_OPERATOR_CONTEXT_FUNCTIONS;
   126         numRecords_(OperatorBase::GetSingleArgument<int>(
"num_records", 1)) {
   127     CAFFE_ENFORCE_GT(numRecords_, 0);
   130   bool dequeueMany(std::shared_ptr<BlobsQueue>& queue) {
   131     auto size = queue->getNumBlobs();
   133     if (blobs_.size() != size) {
   135       blobPtrs_.resize(size);
   136       for (
int col = 0; col < size; ++col) {
   137         blobPtrs_.at(col) = &blobs_.at(col);
   141     const int kTensorGrowthPct = 40;
   142     for (
int i = 0; i < numRecords_; ++i) {
   143       if (!queue->blockingRead(blobPtrs_)) {
   147       for (
int col = 0; col < size; ++col) {
   148         auto* out = this->Output(col);
   149         const auto& in = blobPtrs_.at(col)->template Get<Tensor>();
   153           auto oldSize = out->numel();
   157               "Empty tensor to dequeue at column ",
   163           out->Extend(in.sizes()[0], kTensorGrowthPct);
   165               (
char*)out->raw_mutable_data() + oldSize * in.dtype().itemsize();
   166           context_.template CopyItems<Context, Context>(
   167               in.meta(), in.numel(), in.raw_data(), dst);
   174   bool dequeueOne(std::shared_ptr<BlobsQueue>& queue) {
   175     return queue->blockingRead(this->Outputs());
   178   bool RunOnDevice()
 override {
   179     CAFFE_ENFORCE(InputSize() == 1);
   181                      ->template Get<std::shared_ptr<BlobsQueue>>();
   182     CAFFE_ENFORCE(queue);
   184     auto size = queue->getNumBlobs();
   185     CAFFE_ENFORCE_EQ(OutputSize(), size + 1);
   187     bool status = numRecords_ > 1 ? dequeueMany(queue) : dequeueOne(queue);
   189     Output(size)->Resize();
   190     math::Set<bool, Context>(
   191         1, !status, Output(size)->template mutable_data<bool>(), &context_);
   197   std::vector<Blob> blobs_;
   198   std::vector<Blob*> blobPtrs_;
   201 template <
typename Context>
   204   USE_OPERATOR_CONTEXT_FUNCTIONS;
   209             OperatorBase::GetSingleArgument<int>(
"table_idx_blob", -1)) {
   210     CAFFE_ENFORCE_LT(table_idx_blob_, OutputSize() - 1);
   211     vector<float> weights = OperatorBase::GetRepeatedArgument<float>(
"weights");
   212     if (weights.empty()) {
   213       weights.resize(InputSize(), 1.0f);
   215     CAFFE_ENFORCE_EQ(InputSize(), weights.size());
   217     float sum = accumulate(weights.begin(), weights.end(), 0.0f);
   218     CAFFE_ENFORCE(sum > 0.0f, 
"Sum of weights must be positive");
   219     cumProbs_.resize(weights.size());
   220     for (
int i = 0; i < weights.size(); i++) {
   221       cumProbs_[i] = weights[i] / sum;
   223           cumProbs_[i], 0.0f, 
"Each probability must be non-negative");
   225     std::partial_sum(cumProbs_.begin(), cumProbs_.end(), cumProbs_.begin());
   227     cumProbs_.back() = 1.0001f;
   229     LOG(INFO) << 
"Dequeue weights: " << weights;
   230     LOG(INFO) << 
"cumProbs: " << cumProbs_;
   233   bool RunOnDevice()
 override {
   235     math::RandUniform<float, Context>(1, 0.0f, 1.0f, &r, &context_);
   236     auto lb = lower_bound(cumProbs_.begin(), cumProbs_.end(), r);
   237     CAFFE_ENFORCE(lb != cumProbs_.end(), 
"Cannot find ", r, 
" in cumProbs_.");
   238     const int32_t idx = lb - cumProbs_.begin();
   240                      ->template Get<std::shared_ptr<BlobsQueue>>();
   242     CAFFE_ENFORCE(queue);
   243     auto size = queue->getNumBlobs();
   244     CAFFE_ENFORCE_EQ(OutputSize(), size + 1);
   245     bool status = queue->blockingRead(this->Outputs());
   246     if (table_idx_blob_ >= 0) {
   247       auto* table_idx_blob_out =
   248           Output(table_idx_blob_, {1}, at::dtype<int32_t>());
   249       int32_t* data = table_idx_blob_out->template mutable_data<int32_t>();
   253     Output(size)->Resize();
   254     math::Set<bool, Context>(
   255         1, !status, Output(size)->template mutable_data<bool>(), &context_);
   260   vector<float> cumProbs_;
 
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
 
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...