Caffe2 - C++ API
A deep learning, cross platform ML framework
counter_ops.h
1 #ifndef CAFFE2_OPERATORS_COUNTER_OPS_H
2 #define CAFFE2_OPERATORS_COUNTER_OPS_H
3 
4 #include <atomic>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/logging.h"
8 #include "caffe2/core/operator.h"
9 
10 namespace caffe2 {
// A counter whose value is held in a std::atomic<T>; every member function
// below performs a single atomic operation on that value.
template <typename T>
class CAFFE2_API Counter {
 public:
  // Creates a counter whose initial value is `count`.
  explicit Counter(T count) : count_(count) {}

  // Decrements the counter by one and reports whether the value observed
  // *before* the decrement was already <= 0.
  // Note: the decrement happens unconditionally, so the stored value keeps
  // going below zero if this is called after the counter has run out.
  bool countDown() {
    const T before = count_--;
    return !(before > 0);
  }

  // Increments the counter by one and returns the value held before the
  // increment.
  T countUp() {
    return count_++;
  }

  // Returns the current value without modifying it.
  T retrieve() const {
    return count_.load();
  }

  // Returns true when the current value is <= 0.
  // The result is a predicate, so it is returned as bool (it still converts
  // implicitly to T for callers that stored it in a T).
  bool checkIfDone() const {
    return count_.load() <= 0;
  }

  // Replaces the current value with `init_count` and returns the value that
  // was stored before the replacement.
  T reset(T init_count) {
    return count_.exchange(init_count);
  }

 private:
  std::atomic<T> count_;
};
41 
42 // TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp
43 
44 template <typename T, class Context>
45 class CreateCounterOp final : public Operator<Context> {
46  public:
47  USE_OPERATOR_CONTEXT_FUNCTIONS;
48  template <class... Args>
49  explicit CreateCounterOp(Args&&... args)
50  : Operator<Context>(std::forward<Args>(args)...),
51  init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
52  CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
53  }
54 
55  bool RunOnDevice() override {
56  *this->template Output<std::unique_ptr<Counter<T>>>(0) =
57  std::unique_ptr<Counter<T>>(new Counter<T>(init_count_));
58  return true;
59  }
60 
61  private:
62  T init_count_ = 0;
63 };
64 
65 template <typename T, class Context>
66 class ResetCounterOp final : public Operator<Context> {
67  public:
68  USE_OPERATOR_CONTEXT_FUNCTIONS;
69  template <class... Args>
70  explicit ResetCounterOp(Args&&... args)
71  : Operator<Context>(std::forward<Args>(args)...),
72  init_count_(this->template GetSingleArgument<T>("init_count", 0)) {
73  CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
74  }
75 
76  bool RunOnDevice() override {
77  auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
78  auto previous = counterPtr->reset(init_count_);
79  if (OutputSize() == 1) {
80  auto* output = Output(0);
81  output->Resize();
82  *output->template mutable_data<T>() = previous;
83  }
84  return true;
85  }
86 
87  private:
88  T init_count_;
89 };
90 
91 // Will always use TensorCPU regardless the Context
92 template <typename T, class Context>
93 class CountDownOp final : public Operator<Context> {
94  public:
95  USE_OPERATOR_CONTEXT_FUNCTIONS;
96  template <class... Args>
97  explicit CountDownOp(Args&&... args)
98  : Operator<Context>(std::forward<Args>(args)...) {}
99 
100  bool RunOnDevice() override {
101  auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
102  auto* output = Output(0);
103  output->Resize(std::vector<int>{});
104  *output->template mutable_data<bool>() = counterPtr->countDown();
105  return true;
106  }
107 };
108 
109 // Will always use TensorCPU regardless the Context
110 template <typename T, class Context>
111 class CheckCounterDoneOp final : public Operator<Context> {
112  public:
113  USE_OPERATOR_CONTEXT_FUNCTIONS;
114  template <class... Args>
115  explicit CheckCounterDoneOp(Args&&... args)
116  : Operator<Context>(std::forward<Args>(args)...) {}
117 
118  bool RunOnDevice() override {
119  auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
120  auto* output = Output(0);
121  output->Resize(std::vector<int>{});
122  *output->template mutable_data<bool>() = counterPtr->checkIfDone();
123  return true;
124  }
125 };
126 
127 // Will always use TensorCPU regardless the Context
128 template <typename T, class Context>
129 class CountUpOp final : public Operator<Context> {
130  public:
131  USE_OPERATOR_CONTEXT_FUNCTIONS;
132  template <class... Args>
133  explicit CountUpOp(Args&&... args)
134  : Operator<Context>(std::forward<Args>(args)...) {}
135 
136  bool RunOnDevice() override {
137  auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
138  auto* output = Output(0);
139  output->Resize(std::vector<int>{});
140  *output->template mutable_data<T>() = counterPtr->countUp();
141  return true;
142  }
143 };
144 
145 // Will always use TensorCPU regardless the Context
146 template <typename T, class Context>
147 class RetrieveCountOp final : public Operator<Context> {
148  public:
149  USE_OPERATOR_CONTEXT_FUNCTIONS;
150  template <class... Args>
151  explicit RetrieveCountOp(Args&&... args)
152  : Operator<Context>(std::forward<Args>(args)...) {}
153 
154  bool RunOnDevice() override {
155  auto& counterPtr = this->template Input<std::unique_ptr<Counter<T>>>(0);
156  auto* output = Output(0);
157  output->Resize(std::vector<int>{});
158  *output->template mutable_data<T>() = counterPtr->retrieve();
159  return true;
160  }
161 };
162 
163 } // namespace caffe2
164 #endif // CAFFE2_OPERATORS_COUNTER_OPS_H
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13