Caffe2 - C++ API
A deep learning, cross platform ML framework
counter_ops.h
1 
17 #ifndef CAFFE2_OPERATORS_COUNTER_OPS_H
18 #define CAFFE2_OPERATORS_COUNTER_OPS_H
19 
20 #include <atomic>
21 
22 #include "caffe2/core/context.h"
23 #include "caffe2/core/logging.h"
24 #include "caffe2/core/operator.h"
25 
26 namespace caffe2 {
/**
 * Integer counter stored in a std::atomic<T>.
 *
 * Every member performs a single atomic operation on the underlying value,
 * so individual calls may be issued concurrently on the same instance.
 */
template <typename T>
class Counter {
 public:
  /// Creates a counter whose current value is `count`.
  explicit Counter(T count) : count_(count) {}

  /**
   * Atomically decrements the counter by one.
   *
   * @return true if the value *before* the decrement was already <= 0,
   *         false otherwise. The decrement is applied unconditionally, so
   *         the stored value keeps going down on every call, including
   *         below zero.
   */
  bool countDown() {
    return count_-- <= 0;
  }

  /// Atomically increments the counter and returns the value it held
  /// before the increment.
  T countUp() {
    return count_++;
  }

  /// Returns the current value without modifying it.
  T retrieve() const {
    return count_.load();
  }

  /// Returns true when the current value is <= 0; does not modify the
  /// counter.
  bool checkIfDone() const {
    return count_.load() <= 0;
  }

  /// Atomically replaces the value with `init_count` and returns the value
  /// that was stored before the replacement.
  T reset(T init_count) {
    return count_.exchange(init_count);
  }

 private:
  std::atomic<T> count_;
};
57 
58 // TODO(jiayq): deprecate these ops & consolidate them with IterOp/AtomicIterOp
59 
60 template <typename T, class Context>
61 class CreateCounterOp final : public Operator<Context> {
62  public:
63  USE_OPERATOR_CONTEXT_FUNCTIONS;
64  CreateCounterOp(const OperatorDef& operator_def, Workspace* ws)
65  : Operator<Context>(operator_def, ws),
66  init_count_(OperatorBase::GetSingleArgument<T>("init_count", 0)) {
67  CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
68  }
69 
70  bool RunOnDevice() override {
71  *OperatorBase::Output<std::unique_ptr<Counter<T>>>(0) =
72  std::unique_ptr<Counter<T>>(new Counter<T>(init_count_));
73  return true;
74  }
75 
76  private:
77  T init_count_ = 0;
78 };
79 
80 template <typename T, class Context>
81 class ResetCounterOp final : public Operator<Context> {
82  public:
83  USE_OPERATOR_CONTEXT_FUNCTIONS;
84  ResetCounterOp(const OperatorDef& operator_def, Workspace* ws)
85  : Operator<Context>(operator_def, ws),
86  init_count_(OperatorBase::GetSingleArgument<T>("init_count", 0)) {
87  CAFFE_ENFORCE_LE(0, init_count_, "negative init_count is not permitted.");
88  }
89 
90  bool RunOnDevice() override {
91  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
92  auto previous = counterPtr->reset(init_count_);
93  if (OutputSize() == 1) {
94  auto* output = OperatorBase::Output<TensorCPU>(0);
95  output->Resize();
96  *output->template mutable_data<T>() = previous;
97  }
98  return true;
99  }
100 
101  private:
102  T init_count_;
103 };
104 
105 // Will always use TensorCPU regardless the Context
106 template <typename T, class Context>
107 class CountDownOp final : public Operator<Context> {
108  public:
109  USE_OPERATOR_CONTEXT_FUNCTIONS;
110  CountDownOp(const OperatorDef& operator_def, Workspace* ws)
111  : Operator<Context>(operator_def, ws) {}
112 
113  bool RunOnDevice() override {
114  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
115  auto* output = OperatorBase::Output<TensorCPU>(0);
116  output->Resize(std::vector<int>{});
117  *output->template mutable_data<bool>() = counterPtr->countDown();
118  return true;
119  }
120 };
121 
122 // Will always use TensorCPU regardless the Context
123 template <typename T, class Context>
124 class CheckCounterDoneOp final : public Operator<Context> {
125  public:
126  USE_OPERATOR_CONTEXT_FUNCTIONS;
127  CheckCounterDoneOp(const OperatorDef& operator_def, Workspace* ws)
128  : Operator<Context>(operator_def, ws) {}
129 
130  bool RunOnDevice() override {
131  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
132  auto* output = OperatorBase::Output<TensorCPU>(0);
133  output->Resize(std::vector<int>{});
134  *output->template mutable_data<bool>() = counterPtr->checkIfDone();
135  return true;
136  }
137 };
138 
139 // Will always use TensorCPU regardless the Context
140 template <typename T, class Context>
141 class CountUpOp final : public Operator<Context> {
142  public:
143  USE_OPERATOR_CONTEXT_FUNCTIONS;
144  CountUpOp(const OperatorDef& operator_def, Workspace* ws)
145  : Operator<Context>(operator_def, ws) {}
146 
147  bool RunOnDevice() override {
148  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
149  auto* output = OperatorBase::Output<TensorCPU>(0);
150  output->Resize(std::vector<int>{});
151  *output->template mutable_data<T>() = counterPtr->countUp();
152  return true;
153  }
154 };
155 
156 // Will always use TensorCPU regardless the Context
157 template <typename T, class Context>
158 class RetrieveCountOp final : public Operator<Context> {
159  public:
160  USE_OPERATOR_CONTEXT_FUNCTIONS;
161  RetrieveCountOp(const OperatorDef& operator_def, Workspace* ws)
162  : Operator<Context>(operator_def, ws) {}
163 
164  bool RunOnDevice() override {
165  auto& counterPtr = OperatorBase::Input<std::unique_ptr<Counter<T>>>(0);
166  auto* output = OperatorBase::Output<TensorCPU>(0);
167  output->Resize(std::vector<int>{});
168  *output->template mutable_data<T>() = counterPtr->retrieve();
169  return true;
170  }
171 };
172 
173 } // namespace caffe2
174 #endif // CAFFE2_OPERATORS_COUNTER_OPS_H
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.