Caffe2 - C++ API
A deep-learning, cross-platform ML framework
counter_ops.cc
1 #include "counter_ops.h"
2 #include "caffe2/core/blob_serialization.h"
3 
4 namespace caffe2 {
5 
// Markdown snippet appended to the SetDoc() text of every counter operator.
// `const char* const` makes the pointer itself immutable and gives it internal
// linkage, so it can neither be reassigned at runtime nor collide with a
// same-named global defined in another translation unit.
const char* const githubLinks = R"DOC(
  Github Links:
  - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/counter_ops.cc

)DOC";
11 
// Usage example shared by the SetDoc() text of every counter operator.
// The snippet is Python, so its comments must use '#': a line starting with
// '//' is floor division in Python and makes the example fail to parse.
// `const char* const` keeps the pointer immutable and file-local.
const char* const kCountExample = R"DOC(
<details>

<summary> <b>Example</b> </summary>

**Code**

```
workspace.ResetWorkspace()

createcounter_op = core.CreateOperator(
    "CreateCounter",
    [],
    ["counter"],
    init_count=5
)

retrievecount_op = core.CreateOperator(
    "RetrieveCount",
    ["counter"],
    ["count"]
)

checkcounterdone_op = core.CreateOperator(
    "CheckCounterDone",
    ["counter"],
    ["done"]
)

countup_op = core.CreateOperator(
    "CountUp",
    ["counter"],
    ["previous_count"],
)

countdown_op = core.CreateOperator(
    "CountDown",
    ["counter"],
    ["done"],
)

resetcounter_op = core.CreateOperator(
    "ResetCounter",
    ["counter"],
    ["previous_count"],
    init_count=3
)


# Create counter
workspace.RunOperatorOnce(createcounter_op)
print("'counter' pointer:", workspace.FetchBlob("counter"))


# Retrieve initial counter value
workspace.RunOperatorOnce(retrievecount_op)
print("Initial 'count':", workspace.FetchBlob("count"))


# Check if counter is done
workspace.RunOperatorOnce(checkcounterdone_op)
print("Initial 'done' value:", workspace.FetchBlob("done"))


# Test CountUp operator
print("\nTesting CountUp operator...")
for i in range(5):
    workspace.RunOperatorOnce(countup_op)
    print("'previous_count' after CountUp:", workspace.FetchBlob("previous_count"))

workspace.RunOperatorOnce(retrievecount_op)
print("'count' value after CountUp test:", workspace.FetchBlob("count"))


# Test CountDown operator
print("\nTesting CountDown operator...")
for i in range(11):
    workspace.RunOperatorOnce(countdown_op)
    workspace.RunOperatorOnce(retrievecount_op)
    print("'count' value after CountDown: {}\t'done' value: {}".format(workspace.FetchBlob("count"), workspace.FetchBlob("done")))
```

**Result**

```
'counter' pointer: counter, a C++ native class of type std::__1::unique_ptr<caffe2::Counter<long long>, std::__1::default_delete<caffe2::Counter<long long> > >.
Initial 'count': 5
Initial 'done' value: False

Testing CountUp operator...
'previous_count' after CountUp: 5
'previous_count' after CountUp: 6
'previous_count' after CountUp: 7
'previous_count' after CountUp: 8
'previous_count' after CountUp: 9
'count' value after CountUp test: 10

Testing CountDown operator...
'count' value after CountDown: 9 'done' value: False
'count' value after CountDown: 8 'done' value: False
'count' value after CountDown: 7 'done' value: False
'count' value after CountDown: 6 'done' value: False
'count' value after CountDown: 5 'done' value: False
'count' value after CountDown: 4 'done' value: False
'count' value after CountDown: 3 'done' value: False
'count' value after CountDown: 2 'done' value: False
'count' value after CountDown: 1 'done' value: False
'count' value after CountDown: 0 'done' value: False
'count' value after CountDown: -1 'done' value: True
```

</details>

)DOC";
126 
127 namespace {
136 class CounterSerializer : public BlobSerializerBase {
137  public:
138  CounterSerializer() {}
139  ~CounterSerializer() override {}
140 
141  void Serialize(
142  const void* pointer,
143  TypeMeta typeMeta,
144  const string& name,
145  SerializationAcceptor acceptor) override {
146  CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<Counter<int64_t>>>());
147 
148  BlobProto blob_proto;
149  blob_proto.set_name(name);
150  blob_proto.set_type("std::unique_ptr<Counter<int64_t>>");
151  TensorProto& proto = *blob_proto.mutable_tensor();
152  proto.set_name(name);
153  proto.set_data_type(TensorProto_DataType_INT64);
154  proto.add_dims(1);
155  proto.add_int64_data(
156  (*static_cast<const std::unique_ptr<Counter<int64_t>>*>(pointer))
157  ->retrieve());
158  acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
159  }
160 };
161 
166 class CounterDeserializer : public BlobDeserializerBase {
167  public:
168  void Deserialize(const BlobProto& proto, Blob* blob) override {
169  auto tensorProto = proto.tensor();
170  CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1, "Unexpected size of dims");
171  CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1, "Unexpected value of dims");
172  CAFFE_ENFORCE_EQ(
173  tensorProto.data_type(),
174  TensorProto_DataType_INT64,
175  "Only int64_t counters supported");
176  CAFFE_ENFORCE_EQ(
177  tensorProto.int64_data_size(), 1, "Unexpected size of data");
178  *blob->GetMutable<std::unique_ptr<Counter<int64_t>>>() =
179  caffe2::make_unique<Counter<int64_t>>(tensorProto.int64_data(0));
180  }
181 };
182 }
183 
184 // TODO(jiayq): deprecate these ops & consolidate them with
185 // IterOp/AtomicIterOp
186 
187 REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CPUContext>);
188 REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp<int64_t, CPUContext>);
189 REGISTER_CPU_OPERATOR(CountDown, CountDownOp<int64_t, CPUContext>);
190 REGISTER_CPU_OPERATOR(
191  CheckCounterDone,
192  CheckCounterDoneOp<int64_t, CPUContext>);
193 REGISTER_CPU_OPERATOR(CountUp, CountUpOp<int64_t, CPUContext>);
194 REGISTER_CPU_OPERATOR(RetrieveCount, RetrieveCountOp<int64_t, CPUContext>);
195 
196 OPERATOR_SCHEMA(CreateCounter)
197  .NumInputs(0)
198  .NumOutputs(1)
199  .SetDoc(R"DOC(
200 Creates a count-down counter with initial value specified by the `init_count`
201 argument.
202 
203 )DOC" + (string) githubLinks + (string) kCountExample)
204  .Output(
205  0,
206  "counter",
207  "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a new counter.")
208  .Arg(
209  "init_count",
210  "*(type: int; default: 0)* Initial count for the counter, must be >= 0.");
211 
212 OPERATOR_SCHEMA(ResetCounter)
213  .NumInputs(1)
214  .NumOutputs(0, 1)
215  .SetDoc(R"DOC(
216 Resets a count-down counter with initial value specified by the `init_count`
217 argument.
218 )DOC" + (string) githubLinks + (string) kCountExample)
219  .Input(
220  0,
221  "counter",
222  "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
223  .Output(
224  0,
225  "previous_value",
226  "*(type: int)* [OPTIONAL] count value BEFORE this operation.")
227  .Arg(
228  "init_count",
229  "*(type: int; default: 0)* Resets counter to this value, must be >= 0.");
230 
231 OPERATOR_SCHEMA(CountDown)
232  .NumInputs(1)
233  .NumOutputs(1)
234  .SetDoc(R"DOC(
235 If the internal count value > 0, decreases count value by 1 and outputs False,
236 otherwise outputs True.
237 )DOC" + (string) githubLinks + (string) kCountExample)
238  .Input(
239  0,
240  "counter",
241  "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
242  .Output(
243  0,
244  "done",
245  "*(type: bool)* False unless the internal count is zero.");
246 
247 OPERATOR_SCHEMA(CheckCounterDone)
248  .NumInputs(1)
249  .NumOutputs(1)
250  .SetDoc(R"DOC(
251 If the internal count value <= 0, outputs true, otherwise outputs false.
252 )DOC" + (string) githubLinks + (string) kCountExample)
253  .Input(
254  0,
255  "counter",
256  "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
257  .Output(
258  0,
259  "done",
260  "*(type: bool)* True if the internal count is zero or negative, otherwise False.");
261 
262 OPERATOR_SCHEMA(CountUp)
263  .NumInputs(1)
264  .NumOutputs(1)
265  .SetDoc(R"DOC(
266 Increases count value by 1 and outputs the previous value atomically.
267 )DOC" + (string) githubLinks + (string) kCountExample)
268  .Input(
269  0,
270  "counter",
271  "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
272  .Output(
273  0,
274  "previous_count",
275  "*(type: int)* Count value BEFORE this operation.");
276 
277 OPERATOR_SCHEMA(RetrieveCount)
278  .NumInputs(1)
279  .NumOutputs(1)
280  .ScalarType(TensorProto::INT64)
281  .SetDoc(R"DOC(
282 Retrieve the current value from the counter as an integer.
283 )DOC" + (string) githubLinks + (string) kCountExample)
284  .Input(
285  0,
286  "counter",
287  "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
288  .Output(
289  0,
290  "count",
291  "*(type: int)* Current count value.");
292 
293 SHOULD_NOT_DO_GRADIENT(CreateCounter);
294 SHOULD_NOT_DO_GRADIENT(ResetCounter);
295 SHOULD_NOT_DO_GRADIENT(CountDown);
296 SHOULD_NOT_DO_GRADIENT(CountUp);
297 SHOULD_NOT_DO_GRADIENT(RetrieveCount);
298 
// Registers std::unique_ptr<Counter<int64_t>> with the TypeMeta system and
// hooks up CounterSerializer / CounterDeserializer for that type, so blobs
// holding a counter can be serialized and restored.
// NOTE(review): the deserializer appears to be looked up by the type string
// CounterSerializer writes via set_type("std::unique_ptr<Counter<int64_t>>").
// These macros likely stringify their type argument, so keep the spelling below
// token-for-token identical to that string (no added spaces or aliases) —
// confirm against the REGISTER_BLOB_DESERIALIZER definition before changing it.
299 CAFFE_KNOWN_TYPE(std::unique_ptr<Counter<int64_t>>);
300 REGISTER_BLOB_SERIALIZER(
301  (TypeMeta::Id<std::unique_ptr<Counter<int64_t>>>()),
302  CounterSerializer);
303 REGISTER_BLOB_DESERIALIZER(
304  std::unique_ptr<Counter<int64_t>>,
305  CounterDeserializer);
306 
307 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13