Caffe2 - C++ API
A deep learning, cross-platform ML framework
store_ops.cc
1 #include "store_ops.h"
2 
3 #include "caffe2/core/blob_serialization.h"
4 
5 namespace caffe2 {
6 
// Name of the operator argument that carries the store key(s).
constexpr const char* kBlobName = "blob_name";
// Name of the operator argument that carries the increment used by StoreAdd.
constexpr const char* kAddValue = "add_value";
9 
10 StoreSetOp::StoreSetOp(const OperatorDef& operator_def, Workspace* ws)
11  : Operator<CPUContext>(operator_def, ws),
12  blobName_(
13  GetSingleArgument<std::string>(kBlobName, operator_def.input(DATA))) {
14 }
15 
16 bool StoreSetOp::RunOnDevice() {
17  // Serialize and pass to store
18  auto* handler =
19  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
20  handler->set(blobName_, SerializeBlob(InputBlob(DATA), blobName_));
21  return true;
22 }
23 
24 REGISTER_CPU_OPERATOR(StoreSet, StoreSetOp);
25 OPERATOR_SCHEMA(StoreSet)
26  .NumInputs(2)
27  .NumOutputs(0)
28  .SetDoc(R"DOC(
29 Set a blob in a store. The key is the input blob's name and the value
30 is the data in that blob. The key can be overridden by specifying the
31 'blob_name' argument.
32 )DOC")
33  .Arg("blob_name", "alternative key for the blob (optional)")
34  .Input(0, "handler", "unique_ptr<StoreHandler>")
35  .Input(1, "data", "data blob");
36 
37 StoreGetOp::StoreGetOp(const OperatorDef& operator_def, Workspace* ws)
38  : Operator<CPUContext>(operator_def, ws),
39  blobName_(GetSingleArgument<std::string>(
40  kBlobName,
41  operator_def.output(DATA))) {}
42 
43 bool StoreGetOp::RunOnDevice() {
44  // Get from store and deserialize
45  auto* handler =
46  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
47  DeserializeBlob(handler->get(blobName_), OperatorBase::Outputs()[DATA]);
48  return true;
49 }
50 
51 REGISTER_CPU_OPERATOR(StoreGet, StoreGetOp);
52 OPERATOR_SCHEMA(StoreGet)
53  .NumInputs(1)
54  .NumOutputs(1)
55  .SetDoc(R"DOC(
56 Get a blob from a store. The key is the output blob's name. The key
57 can be overridden by specifying the 'blob_name' argument.
58 )DOC")
59  .Arg("blob_name", "alternative key for the blob (optional)")
60  .Input(0, "handler", "unique_ptr<StoreHandler>")
61  .Output(0, "data", "data blob");
62 
63 StoreAddOp::StoreAddOp(const OperatorDef& operator_def, Workspace* ws)
64  : Operator<CPUContext>(operator_def, ws),
65  blobName_(GetSingleArgument<std::string>(kBlobName, "")),
66  addValue_(GetSingleArgument<int64_t>(kAddValue, 1)) {
67  CAFFE_ENFORCE(HasArgument(kBlobName));
68 }
69 
70 bool StoreAddOp::RunOnDevice() {
71  auto* handler =
72  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
73  Output(VALUE)->Resize(1);
74  Output(VALUE)->mutable_data<int64_t>()[0] =
75  handler->add(blobName_, addValue_);
76  return true;
77 }
78 
79 REGISTER_CPU_OPERATOR(StoreAdd, StoreAddOp);
80 OPERATOR_SCHEMA(StoreAdd)
81  .NumInputs(1)
82  .NumOutputs(1)
83  .SetDoc(R"DOC(
84 Add a value to a remote counter. If the key is not set, the store
85 initializes it to 0 and then performs the add operation. The operation
86 returns the resulting counter value.
87 )DOC")
88  .Arg("blob_name", "key of the counter (required)")
89  .Arg("add_value", "value that is added (optional, default: 1)")
90  .Input(0, "handler", "unique_ptr<StoreHandler>")
91  .Output(0, "value", "the current value of the counter");
92 
93 StoreWaitOp::StoreWaitOp(const OperatorDef& operator_def, Workspace* ws)
94  : Operator<CPUContext>(operator_def, ws),
95  blobNames_(GetRepeatedArgument<std::string>(kBlobName)) {}
96 
97 bool StoreWaitOp::RunOnDevice() {
98  auto* handler =
99  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
100  if (InputSize() == 2 && Input(1).IsType<std::string>()) {
101  CAFFE_ENFORCE(
102  blobNames_.empty(), "cannot specify both argument and input blob");
103  std::vector<std::string> blobNames;
104  auto* namesPtr = Input(1).data<std::string>();
105  for (int i = 0; i < Input(1).size(); ++i) {
106  blobNames.push_back(namesPtr[i]);
107  }
108  handler->wait(blobNames);
109  } else {
110  handler->wait(blobNames_);
111  }
112  return true;
113 }
114 
115 REGISTER_CPU_OPERATOR(StoreWait, StoreWaitOp);
116 OPERATOR_SCHEMA(StoreWait)
117  .NumInputs(1, 2)
118  .NumOutputs(0)
119  .SetDoc(R"DOC(
120 Wait for the specified blob names to be set. The blob names can be passed
121 either as an input blob with blob names or as an argument.
122 )DOC")
123  .Arg("blob_names", "names of the blobs to wait for (optional)")
124  .Input(0, "handler", "unique_ptr<StoreHandler>")
125  .Input(1, "names", "names of the blobs to wait for (optional)");
126 }
void DeserializeBlob(const string &content, Blob *result)
Deserializes from a string containing either BlobProto or TensorProto.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
void SerializeBlob(const Blob &blob, const string &name, BlobSerializerBase::SerializationAcceptor acceptor, int chunk_size)
Serializes the given blob, if possible.