1 #include "counter_ops.h" 2 #include "caffe2/core/blob_serialization.h" 6 const char* githubLinks = R
"DOC( 8 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/counter_ops.cc 12 const char* kCountExample = R
"DOC( 15 <summary> <b>Example</b> </summary> 20 workspace.ResetWorkspace() 22 createcounter_op = core.CreateOperator( 29 retrievecount_op = core.CreateOperator( 35 checkcounterdone_op = core.CreateOperator( 41 countup_op = core.CreateOperator( 47 countdown_op = core.CreateOperator( 53 resetcounter_op = core.CreateOperator( 62 workspace.RunOperatorOnce(createcounter_op) 63 print("'counter' pointer:", workspace.FetchBlob("counter")) 66 // Retrieve initial counter value 67 workspace.RunOperatorOnce(retrievecount_op) 68 print("Initial 'count':", workspace.FetchBlob("count")) 71 // Check if counter is done 72 workspace.RunOperatorOnce(checkcounterdone_op) 73 print("Initial 'done' value:", workspace.FetchBlob("done")) 76 // Test CountUp operator 77 print("\nTesting CountUp operator...") 79 workspace.RunOperatorOnce(countup_op) 80 print("'previous_count' after CountUp:", workspace.FetchBlob("previous_count")) 82 workspace.RunOperatorOnce(retrievecount_op) 83 print("'count' value after CountUp test:", workspace.FetchBlob("count")) 86 // Test CountDown operator 87 print("\nTesting CountDown operator...") 89 workspace.RunOperatorOnce(countdown_op) 90 workspace.RunOperatorOnce(retrievecount_op) 91 print("'count' value after CountDown: {}\t'done' value: {}".format(workspace.FetchBlob("count"), workspace.FetchBlob("done"))) 97 'counter' pointer: counter, a C++ native class of type std::__1::unique_ptr<caffe2::Counter<long long>, std::__1::default_delete<caffe2::Counter<long long> > >. 99 Initial 'done' value: False 101 Testing CountUp operator... 102 'previous_count' after CountUp: 5 103 'previous_count' after CountUp: 6 104 'previous_count' after CountUp: 7 105 'previous_count' after CountUp: 8 106 'previous_count' after CountUp: 9 107 'count' value after CountUp test: 10 109 Testing CountDown operator... 
110 'count' value after CountDown: 9 'done' value: False 111 'count' value after CountDown: 8 'done' value: False 112 'count' value after CountDown: 7 'done' value: False 113 'count' value after CountDown: 6 'done' value: False 114 'count' value after CountDown: 5 'done' value: False 115 'count' value after CountDown: 4 'done' value: False 116 'count' value after CountDown: 3 'done' value: False 117 'count' value after CountDown: 2 'done' value: False 118 'count' value after CountDown: 1 'done' value: False 119 'count' value after CountDown: 0 'done' value: False 120 'count' value after CountDown: -1 'done' value: True 136 class CounterSerializer :
public BlobSerializerBase {
138 CounterSerializer() {}
139 ~CounterSerializer()
override {}
145 SerializationAcceptor acceptor)
override {
146 CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<Counter<int64_t>>>());
148 BlobProto blob_proto;
149 blob_proto.set_name(name);
150 blob_proto.set_type(
"std::unique_ptr<Counter<int64_t>>");
151 TensorProto& proto = *blob_proto.mutable_tensor();
152 proto.set_name(name);
153 proto.set_data_type(TensorProto_DataType_INT64);
155 proto.add_int64_data(
156 (*
static_cast<const std::unique_ptr<Counter<int64_t>
>*>(pointer))
158 acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
166 class CounterDeserializer :
public BlobDeserializerBase {
168 void Deserialize(
const BlobProto& proto, Blob* blob)
override {
169 auto tensorProto = proto.tensor();
170 CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1,
"Unexpected size of dims");
171 CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1,
"Unexpected value of dims");
173 tensorProto.data_type(),
174 TensorProto_DataType_INT64,
175 "Only int64_t counters supported");
177 tensorProto.int64_data_size(), 1,
"Unexpected size of data");
178 *blob->GetMutable<std::unique_ptr<Counter<int64_t>>>() =
179 caffe2::make_unique<Counter<int64_t>>(tensorProto.int64_data(0));
// CPU kernels for the counter operators. Every kernel is instantiated with an
// int64_t count type, matching the std::unique_ptr<Counter<int64_t>> blob type
// that this file's serializer/deserializer handle.
187 REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CPUContext>);
188 REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp<int64_t, CPUContext>);
189 REGISTER_CPU_OPERATOR(CountDown, CountDownOp<int64_t, CPUContext>);
190 REGISTER_CPU_OPERATOR(
192 CheckCounterDoneOp<int64_t, CPUContext>);
// Remaining int64_t CPU kernels. Per their schemas in this file, CountUp
// outputs the count value from before the increment, and RetrieveCount
// outputs the current count value.
193 REGISTER_CPU_OPERATOR(CountUp, CountUpOp<int64_t, CPUContext>);
194 REGISTER_CPU_OPERATOR(RetrieveCount, RetrieveCountOp<int64_t, CPUContext>);
196 OPERATOR_SCHEMA(CreateCounter)
200 Creates a count-down counter with initial value specified by the `init_count` 203 )DOC" + (string) githubLinks + (
string) kCountExample)
207 "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a new counter.")
210 "*(type: int; default: 0)* Initial count for the counter, must be >= 0.");
212 OPERATOR_SCHEMA(ResetCounter)
216 Resets a count-down counter with initial value specified by the `init_count` 218 )DOC" + (string) githubLinks + (
string) kCountExample)
222 "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
226 "*(type: int)* [OPTIONAL] count value BEFORE this operation.")
229 "*(type: int; default: 0)* Resets counter to this value, must be >= 0.");
231 OPERATOR_SCHEMA(CountDown)
235 If the internal count value > 0, decreases count value by 1 and outputs False, 236 otherwise outputs True. 237 )DOC" + (string) githubLinks + (
string) kCountExample)
241 "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
245 "*(type: bool)* False unless the internal count is zero.");
247 OPERATOR_SCHEMA(CheckCounterDone)
251 If the internal count value <= 0, outputs true, otherwise outputs false. 252 )DOC" + (string) githubLinks + (
string) kCountExample)
256 "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
260 "*(type: bool)* True if the internal count is zero or negative, otherwise False.");
262 OPERATOR_SCHEMA(CountUp)
266 Increases count value by 1 and outputs the previous value atomically. 267 )DOC" + (string) githubLinks + (
string) kCountExample)
271 "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
275 "*(type: int)* Count value BEFORE this operation.");
277 OPERATOR_SCHEMA(RetrieveCount)
280 .ScalarType(TensorProto::INT64)
282 Retrieve the current value from the counter as an integer. 283 )DOC" + (string) githubLinks + (
string) kCountExample)
287 "*(type: Tensor`<ptr>`)* A blob pointing to an instance of a counter.")
291 "*(type: int)* Current count value.");
293 SHOULD_NOT_DO_GRADIENT(CreateCounter);
294 SHOULD_NOT_DO_GRADIENT(ResetCounter);
295 SHOULD_NOT_DO_GRADIENT(CountDown);
296 SHOULD_NOT_DO_GRADIENT(CountUp);
297 SHOULD_NOT_DO_GRADIENT(RetrieveCount);
// Declares std::unique_ptr<Counter<int64_t>> as a known caffe2 type.
// NOTE(review): the TypeMeta::Match / TypeMeta::Id calls on this type in the
// serializer code presumably rely on this declaration — confirm against
// caffe2's typeid header.
299 CAFFE_KNOWN_TYPE(std::unique_ptr<Counter<int64_t>>);
300 REGISTER_BLOB_SERIALIZER(
301 (TypeMeta::Id<std::unique_ptr<Counter<int64_t>>>()),
// Registers CounterDeserializer for blobs of type
// std::unique_ptr<Counter<int64_t>>.
// NOTE(review): the lookup presumably matches on the type string
// "std::unique_ptr<Counter<int64_t>>" that CounterSerializer writes via
// blob_proto.set_type(...). Keep the two spellings identical.
303 REGISTER_BLOB_DESERIALIZER(
304 std::unique_ptr<Counter<int64_t>>,
305 CounterDeserializer);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...