Caffe2 - C++ API
A deep learning, cross-platform ML framework
rebatching_queue_ops.h
#pragma once

#include "rebatching_queue.h"

namespace caffe2 {

using RebatchingQueuePtr = std::unique_ptr<RebatchingQueue>;

class CreateRebatchingQueueOp : public Operator<CPUContext> {
 public:
  CreateRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws) {}

  bool RunOnDevice() override {
    *OperatorBase::Output<RebatchingQueuePtr>(0) =
        RebatchingQueuePtr(new RebatchingQueue(
            OperatorBase::GetSingleArgument<int>("capacity", 1),
            OperatorBase::GetSingleArgument<int>("num_blobs", 1)));
    return true;
  }
};

class EnqueueRebatchingQueueOp : public Operator<CPUContext> {
 public:
  EnqueueRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws),
        enqueueBatch_(
            OperatorBase::GetSingleArgument<bool>("enqueue_batch", false)) {}

  bool RunOnDevice() override {
    auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
    CHECK(queue);
    CAFFE_ENFORCE_EQ(InputSize(), queue->numBlobs() + 1);
    std::vector<const Tensor*> inputTensors;
    inputTensors.reserve(InputSize() - 1);
    for (int i = 1; i < InputSize(); ++i) {
      inputTensors.push_back(&Input(i));
    }

    return enqueueBatch_ ? queue->enqueueMany(context_, inputTensors)
                         : queue->enqueueOne(context_, inputTensors);
  }

 private:
  const bool enqueueBatch_;
};

class DequeueRebatchingQueueOp : public Operator<CPUContext> {
 public:
  DequeueRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws),
        numElements_(OperatorBase::GetSingleArgument<int>("num_elements", 1)) {}

  bool RunOnDevice() override {
    auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
    CHECK(queue);

    std::vector<Tensor*> outputTensors;
    outputTensors.reserve(OutputSize());
    for (int i = 0; i < OutputSize(); ++i) {
      outputTensors.push_back(Output(i));
    }

    return queue->dequeue(context_, numElements_, outputTensors);
  }

 private:
  int numElements_;
};

class CloseRebatchingQueueOp : public Operator<CPUContext> {
 public:
  CloseRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator(operator_def, ws) {}

  bool RunOnDevice() override {
    CAFFE_ENFORCE_EQ(InputSize(), 1);
    auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
    CAFFE_ENFORCE(queue);
    queue->close();
    return true;
  }
};
} // caffe2
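As a rough usage sketch, the snippet below builds OperatorDef protos by hand and runs them through a Workspace, exercising the create/enqueue/dequeue/close sequence these classes implement. The operator type strings ("CreateRebatchingQueue", "EnqueueRebatchingQueue", "DequeueRebatchingQueue", "CloseRebatchingQueue") are assumed from the class names above and may differ from the names actually registered in the corresponding .cc file; the blob names are likewise illustrative. The argument names ("capacity", "num_blobs", "enqueue_batch", "num_elements") are taken directly from the header.

// Hypothetical usage sketch; operator type strings and blob names are assumptions.
#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

namespace caffe2 {

void RebatchingQueueExample() {
  Workspace ws;

  // Create a queue blob with capacity 4, holding one tensor per element.
  OperatorDef create_def;
  create_def.set_type("CreateRebatchingQueue");  // assumed registered name
  create_def.add_output("queue");
  {
    auto* arg = create_def.add_arg();
    arg->set_name("capacity");
    arg->set_i(4);
  }
  {
    auto* arg = create_def.add_arg();
    arg->set_name("num_blobs");
    arg->set_i(1);
  }
  ws.RunOperatorOnce(create_def);

  // Enqueue a whole batch at once (enqueue_batch = true). Input 0 is the
  // queue; inputs 1..N are the tensors to enqueue, so "data" must already
  // exist in the workspace.
  OperatorDef enqueue_def;
  enqueue_def.set_type("EnqueueRebatchingQueue");  // assumed registered name
  enqueue_def.add_input("queue");
  enqueue_def.add_input("data");
  {
    auto* arg = enqueue_def.add_arg();
    arg->set_name("enqueue_batch");
    arg->set_i(1);
  }
  ws.RunOperatorOnce(enqueue_def);

  // Dequeue two elements into the output blob.
  OperatorDef dequeue_def;
  dequeue_def.set_type("DequeueRebatchingQueue");  // assumed registered name
  dequeue_def.add_input("queue");
  dequeue_def.add_output("dequeued_data");
  {
    auto* arg = dequeue_def.add_arg();
    arg->set_name("num_elements");
    arg->set_i(2);
  }
  ws.RunOperatorOnce(dequeue_def);

  // Close the queue so any pending dequeues can terminate.
  OperatorDef close_def;
  close_def.set_type("CloseRebatchingQueue");  // assumed registered name
  close_def.add_input("queue");
  ws.RunOperatorOnce(close_def);
}

} // caffe2

Note that enqueue and dequeue both return the queue's status as the operator's success value, so a closed or cancelled queue surfaces as a failed operator run rather than an exception.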