Caffe2 - C++ API
A deep learning, cross platform ML framework
queue_ops.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 #include <caffe2/queue/blobs_queue.h>
3 
4 namespace caffe2 {
5 
6 class IDEEPCreateBlobsQueueOp final : public IDEEPOperator {
7  public:
8  USE_IDEEP_DEF_ALIASES();
9  USE_IDEEP_OPERATOR_FUNCTIONS();
10 
11  IDEEPCreateBlobsQueueOp(const OperatorDef& operator_def, Workspace* ws)
12  : IDEEPOperator(operator_def, ws),
13  ws_(ws),
14  name(operator_def.output().Get(0)) {}
15 
16  bool RunOnDevice() override {
17  const auto capacity = GetSingleArgument("capacity", 1);
18  const auto numBlobs = GetSingleArgument("num_blobs", 1);
19  const auto enforceUniqueName =
20  GetSingleArgument("enforce_unique_name", false);
21  const auto fieldNames =
22  OperatorBase::template GetRepeatedArgument<std::string>("field_names");
23  CAFFE_ENFORCE_EQ(this->OutputSize(), 1);
24  auto queuePtr = OperatorBase::Outputs()[0]
25  ->template GetMutable<std::shared_ptr<BlobsQueue>>();
26 
27  CAFFE_ENFORCE(queuePtr);
28  *queuePtr = std::make_shared<BlobsQueue>(
29  ws_, name, capacity, numBlobs, enforceUniqueName, fieldNames);
30  return true;
31  }
32 
33  private:
34  Workspace* ws_{nullptr};
35  const std::string name;
36 };
37 
38 class IDEEPSafeEnqueueBlobsOp final : public IDEEPOperator {
39  public:
40  USE_IDEEP_DEF_ALIASES();
41  USE_IDEEP_OPERATOR_FUNCTIONS();
42 
43  IDEEPSafeEnqueueBlobsOp(const OperatorDef& operator_def, Workspace* ws)
44  : IDEEPOperator(operator_def, ws) {}
45 
46  bool RunOnDevice() override {
47  auto queue =
48  OperatorBase::Inputs()[0]->template Get<std::shared_ptr<BlobsQueue>>();
49  CAFFE_ENFORCE(queue);
50  auto size = queue->getNumBlobs();
51  CAFFE_ENFORCE(
52  OutputSize() == size + 1,
53  "Expected " + caffe2::to_string(size + 1) + ", " +
54  " got: " + caffe2::to_string(size));
55  bool status = queue->blockingWrite(OperatorBase::Outputs());
56 
57  auto st = OperatorBase::Output<TensorCPU>(1, CPU);
58  st->Resize();
59  auto stat = st->template mutable_data<bool>();
60  stat[0] = !status;
61  return true;
62  }
63 };
64 
65 REGISTER_IDEEP_OPERATOR(CreateBlobsQueue, IDEEPCreateBlobsQueueOp);
66 SHOULD_NOT_DO_GRADIENT(IDEEPCreateBlobsQueueOp);
67 
68 REGISTER_IDEEP_OPERATOR(SafeEnqueueBlobs, IDEEPSafeEnqueueBlobsOp);
69 SHOULD_NOT_DO_GRADIENT(IDEEPSafeEnqueueBlobsOp);
70 
71 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13