// RebatchingQueuePtr is the owning handle type stored in workspace blobs;
// the operators below read/write it via Get<RebatchingQueuePtr>() / Output<>().
// NOTE(review): extraction artifact — original source line numbers ("3", "7")
// are fused into the text.
3 #include "rebatching_queue.h" 7 using RebatchingQueuePtr = std::unique_ptr<RebatchingQueue>;
// CreateRebatchingQueueOp::RunOnDevice — constructs a RebatchingQueue sized by
// the integer operator arguments "capacity" and "num_blobs" (both default 1)
// and stores the owning pointer into output blob 0.
// NOTE(review): this listing is garbled by extraction — the line between
// "Output<RebatchingQueuePtr>(0) =" and the first GetSingleArgument call
// (presumably a make_unique<RebatchingQueue>(...) factory call, given the
// trailing "));" below) is missing; restore from the original source before
// compiling.
14 bool RunOnDevice()
override {
15 *OperatorBase::Output<RebatchingQueuePtr>(0) =
17 OperatorBase::GetSingleArgument<int>(
"capacity", 1),
18 OperatorBase::GetSingleArgument<int>(
"num_blobs", 1)));
// Tail of the EnqueueRebatchingQueueOp constructor: caches the boolean
// "enqueue_batch" operator argument (default false) into enqueueBatch_,
// which selects between enqueueMany and enqueueOne at run time.
// NOTE(review): the constructor's opening (member-init list head) is outside
// this fragment.
28 OperatorBase::GetSingleArgument<bool>(
"enqueue_batch",
false)) {}
// EnqueueRebatchingQueueOp::RunOnDevice — pushes input tensors 1..InputSize()-1
// into the RebatchingQueue held by input blob 0.
29 bool RunOnDevice()
override {
// Input 0 is the queue handle (created by CreateRebatchingQueueOp).
30 auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
// Exactly one tensor input per queue blob, plus the queue handle itself.
32 CAFFE_ENFORCE_EQ(InputSize(), queue->numBlobs() + 1);
33 std::vector<const Tensor*> inputTensors;
34 inputTensors.reserve(InputSize() - 1);
// Collect non-owning pointers to the tensor inputs (skipping the queue at 0).
35 for (
int i = 1; i < InputSize(); ++i) {
36 inputTensors.push_back(&
Input(i));
// enqueueBatch_ selects batch vs single-element enqueue; the exact
// enqueueMany/enqueueOne semantics are inferred from the names — confirm
// against rebatching_queue.h. The bool result is propagated as the op's
// success status.
39 return enqueueBatch_ ? queue->enqueueMany(context_, inputTensors)
40 : queue->enqueueOne(context_, inputTensors);
// Cached "enqueue_batch" operator argument (set in the constructor above).
44 const bool enqueueBatch_;
// Tail of the DequeueRebatchingQueueOp constructor: caches the integer
// "num_elements" operator argument (default 1) into numElements_, the count
// passed to queue->dequeue() at run time.
51 numElements_(OperatorBase::GetSingleArgument<int>(
"num_elements", 1)) {}
// DequeueRebatchingQueueOp::RunOnDevice — pops numElements_ elements from the
// RebatchingQueue in input blob 0 and writes them into this op's outputs.
53 bool RunOnDevice()
override {
54 auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
// One output tensor per queue blob; collect mutable pointers for dequeue.
57 std::vector<Tensor*> outputTensors;
58 outputTensors.reserve(OutputSize());
59 for (
int i = 0; i < OutputSize(); ++i) {
60 outputTensors.push_back(Output(i));
// The bool result is propagated as the op's success status — presumably
// false once the queue is closed and drained; confirm against
// RebatchingQueue::dequeue in rebatching_queue.h.
63 return queue->dequeue(context_, numElements_, outputTensors);
// CloseRebatchingQueueOp::RunOnDevice — takes exactly one input (the queue
// handle) and fetches it; the rest of the method (presumably queue->close()
// and the return) runs past the end of this fragment.
75 bool RunOnDevice()
override {
76 CAFFE_ENFORCE_EQ(InputSize(), 1);
77 auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.
const Tensor & Input(int idx, DeviceType type=CPUContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator; ownership of the underlying tensor remains with the workspace blob.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, together with utility functions to load the modules.