Caffe2 - C++ API
A deep learning, cross-platform ML framework
rebatching_queue_ops.h
1 
17 #pragma once
18 
19 #include "rebatching_queue.h"
20 
21 namespace caffe2 {
22 
23 using RebatchingQueuePtr = std::unique_ptr<RebatchingQueue>;
24 
25 class CreateRebatchingQueueOp : public Operator<CPUContext> {
26  public:
27  CreateRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
28  : Operator(operator_def, ws) {}
29 
30  bool RunOnDevice() override {
31  *OperatorBase::Output<RebatchingQueuePtr>(0) =
32  RebatchingQueuePtr(new RebatchingQueue(
33  OperatorBase::GetSingleArgument<int>("capacity", 1),
34  OperatorBase::GetSingleArgument<int>("num_blobs", 1)));
35  return true;
36  }
37 };
38 
39 class EnqueueRebatchingQueueOp : public Operator<CPUContext> {
40  public:
41  EnqueueRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
42  : Operator(operator_def, ws),
43  enqueueBatch_(
44  OperatorBase::GetSingleArgument<bool>("enqueue_batch", false)) {}
45  bool RunOnDevice() override {
46  auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
47  CHECK(queue);
48  CAFFE_ENFORCE_EQ(InputSize(), queue->numBlobs() + 1);
49  std::vector<const TensorCPU*> inputTensors;
50  inputTensors.reserve(InputSize() - 1);
51  for (int i = 1; i < InputSize(); ++i) {
52  inputTensors.push_back(&Input(i));
53  }
54 
55  return enqueueBatch_ ? queue->enqueueMany(context_, inputTensors)
56  : queue->enqueueOne(context_, inputTensors);
57  }
58 
59  private:
60  const bool enqueueBatch_;
61 };
62 
63 class DequeueRebatchingQueueOp : public Operator<CPUContext> {
64  public:
65  DequeueRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
66  : Operator(operator_def, ws),
67  numElements_(OperatorBase::GetSingleArgument<int>("num_elements", 1)) {}
68 
69  bool RunOnDevice() override {
70  auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
71  CHECK(queue);
72 
73  std::vector<TensorCPU*> outputTensors;
74  outputTensors.reserve(OutputSize());
75  for (int i = 0; i < OutputSize(); ++i) {
76  outputTensors.push_back(Output(i));
77  }
78 
79  return queue->dequeue(context_, numElements_, outputTensors);
80  }
81 
82  private:
83  int numElements_;
84 };
85 
86 class CloseRebatchingQueueOp : public Operator<CPUContext> {
87  public:
88  CloseRebatchingQueueOp(const OperatorDef& operator_def, Workspace* ws)
89  : Operator(operator_def, ws) {}
90 
91  bool RunOnDevice() override {
92  CAFFE_ENFORCE_EQ(InputSize(), 1);
93  auto& queue = Inputs()[0]->template Get<RebatchingQueuePtr>();
94  CAFFE_ENFORCE(queue);
95  queue->close();
96  return true;
97  }
98 };
99 } // caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.