Caffe2 - C++ API
A deep learning, cross-platform ML framework
counter_ops.cc
1 
17 #include "counter_ops.h"
18 
19 #include "caffe2/core/blob_serialization.h"
20 
21 namespace caffe2 {
22 namespace {
31 class CounterSerializer : public BlobSerializerBase {
32  public:
33  CounterSerializer() {}
34  ~CounterSerializer() {}
35 
36  void Serialize(
37  const Blob& blob,
38  const string& name,
39  SerializationAcceptor acceptor) override {
40  CAFFE_ENFORCE(blob.IsType<std::unique_ptr<Counter<int64_t>>>());
41 
42  BlobProto blob_proto;
43  blob_proto.set_name(name);
44  blob_proto.set_type("std::unique_ptr<Counter<int64_t>>");
45  TensorProto& proto = *blob_proto.mutable_tensor();
46  proto.set_name(name);
47  proto.set_data_type(TensorProto_DataType_INT64);
48  proto.add_dims(1);
49  proto.add_int64_data(
50  blob.template Get<std::unique_ptr<Counter<int64_t>>>()->retrieve());
51  acceptor(name, blob_proto.SerializeAsString());
52  }
53 };
54 
59 class CounterDeserializer : public BlobDeserializerBase {
60  public:
61  void Deserialize(const BlobProto& proto, Blob* blob) override {
62  auto tensorProto = proto.tensor();
63  CAFFE_ENFORCE_EQ(tensorProto.dims_size(), 1, "Unexpected size of dims");
64  CAFFE_ENFORCE_EQ(tensorProto.dims(0), 1, "Unexpected value of dims");
65  CAFFE_ENFORCE_EQ(
66  tensorProto.data_type(),
67  TensorProto_DataType_INT64,
68  "Only int64_t counters supported");
69  CAFFE_ENFORCE_EQ(
70  tensorProto.int64_data_size(), 1, "Unexpected size of data");
71  *blob->GetMutable<std::unique_ptr<Counter<int64_t>>>() =
72  caffe2::make_unique<Counter<int64_t>>(tensorProto.int64_data(0));
73  }
74 };
75 }
76 
77 // TODO(jiayq): deprecate these ops & consolidate them with
78 // IterOp/AtomicIterOp
79 
80 REGISTER_CPU_OPERATOR(CreateCounter, CreateCounterOp<int64_t, CPUContext>);
81 REGISTER_CPU_OPERATOR(ResetCounter, ResetCounterOp<int64_t, CPUContext>);
82 REGISTER_CPU_OPERATOR(CountDown, CountDownOp<int64_t, CPUContext>);
83 REGISTER_CPU_OPERATOR(
84  CheckCounterDone,
85  CheckCounterDoneOp<int64_t, CPUContext>);
86 REGISTER_CPU_OPERATOR(CountUp, CountUpOp<int64_t, CPUContext>);
87 REGISTER_CPU_OPERATOR(RetrieveCount, RetrieveCountOp<int64_t, CPUContext>);
88 
89 OPERATOR_SCHEMA(CreateCounter)
90  .NumInputs(0)
91  .NumOutputs(1)
92  .SetDoc(R"DOC(
93 Creates a count-down counter with initial value specified by the 'init_count'
94 argument.
95 )DOC")
96  .Output(0, "counter", "A blob pointing to an instance of a new counter.")
97  .Arg("init_count", "Initial count for the counter, must be >= 0.");
98 
99 OPERATOR_SCHEMA(ResetCounter)
100  .NumInputs(1)
101  .NumOutputs(0, 1)
102  .SetDoc(R"DOC(
103 Resets a count-down counter with initial value specified by the 'init_count'
104 argument.
105 )DOC")
106  .Input(0, "counter", "A blob pointing to an instance of a new counter.")
107  .Output(0, "previous_value", "(optional) Previous value of the counter.")
108  .Arg("init_count", "Resets counter to this value, must be >= 0.");
109 
110 OPERATOR_SCHEMA(CountDown)
111  .NumInputs(1)
112  .NumOutputs(1)
113  .SetDoc(R"DOC(
114 If the internal count value > 0, decreases count value by 1 and outputs false,
115 otherwise outputs true.
116 )DOC")
117  .Input(0, "counter", "A blob pointing to an instance of a counter.")
118  .Output(0, "done", "false unless the internal count is zero.");
119 
120 OPERATOR_SCHEMA(CheckCounterDone)
121  .NumInputs(1)
122  .NumOutputs(1)
123  .SetDoc(R"DOC(
124 If the internal count value <= 0, outputs true, otherwise outputs false,
125 )DOC")
126  .Input(0, "counter", "A blob pointing to an instance of a counter.")
127  .Output(0, "done", "true if the internal count is zero or negative.");
128 
129 OPERATOR_SCHEMA(CountUp)
130  .NumInputs(1)
131  .NumOutputs(1)
132  .SetDoc(R"DOC(
133 Increases count value by 1 and outputs the previous value atomically
134 )DOC")
135  .Input(0, "counter", "A blob pointing to an instance of a counter.")
136  .Output(0, "previous_count", "count value BEFORE this operation");
137 
138 OPERATOR_SCHEMA(RetrieveCount)
139  .NumInputs(1)
140  .NumOutputs(1)
141  .ScalarType(TensorProto::INT64)
142  .SetDoc(R"DOC(
143 Retrieve the current value from the counter.
144 )DOC")
145  .Input(0, "counter", "A blob pointing to an instance of a counter.")
146  .Output(0, "count", "current count value.");
147 
148 SHOULD_NOT_DO_GRADIENT(CreateCounter);
149 SHOULD_NOT_DO_GRADIENT(ResetCounter);
150 SHOULD_NOT_DO_GRADIENT(CountDown);
151 SHOULD_NOT_DO_GRADIENT(CountUp);
152 SHOULD_NOT_DO_GRADIENT(RetrieveCount);
153 
154 CAFFE_KNOWN_TYPE(std::unique_ptr<Counter<int64_t>>);
155 REGISTER_BLOB_SERIALIZER(
156  (TypeMeta::Id<std::unique_ptr<Counter<int64_t>>>()),
157  CounterSerializer);
158 REGISTER_BLOB_DESERIALIZER(
159  std::unique_ptr<Counter<int64_t>>,
160  CounterDeserializer);
161 
162 } // namespace caffe2
static CAFFE2_API CaffeTypeId Id()
Returns the unique id for the given type T.
Copyright (c) 2016-present, Facebook, Inc.