Caffe2 - C++ API
A deep learning, cross platform ML framework
store_ops.cc
1 
17 #include "store_ops.h"
18 
19 namespace caffe2 {
20 
// Operator argument names shared by the store operators in this file.
// "blob_name" is read by the StoreSet/StoreGet/StoreAdd/StoreWait constructors.
constexpr const char* kBlobName = "blob_name";
// "add_value" is read by the StoreAdd constructor.
constexpr const char* kAddValue = "add_value";
23 
24 StoreSetOp::StoreSetOp(const OperatorDef& operator_def, Workspace* ws)
25  : Operator<CPUContext>(operator_def, ws),
26  blobName_(
27  GetSingleArgument<std::string>(kBlobName, operator_def.input(DATA))) {
28 }
29 
30 bool StoreSetOp::RunOnDevice() {
31  // Serialize and pass to store
32  auto* handler =
33  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
34  handler->set(blobName_, InputBlob(DATA).Serialize(blobName_));
35  return true;
36 }
37 
38 REGISTER_CPU_OPERATOR(StoreSet, StoreSetOp);
39 OPERATOR_SCHEMA(StoreSet)
40  .NumInputs(2)
41  .NumOutputs(0)
42  .SetDoc(R"DOC(
43 Set a blob in a store. The key is the input blob's name and the value
44 is the data in that blob. The key can be overridden by specifying the
45 'blob_name' argument.
46 )DOC")
47  .Arg("blob_name", "alternative key for the blob (optional)")
48  .Input(0, "handler", "unique_ptr<StoreHandler>")
49  .Input(1, "data", "data blob");
50 
51 StoreGetOp::StoreGetOp(const OperatorDef& operator_def, Workspace* ws)
52  : Operator<CPUContext>(operator_def, ws),
53  blobName_(GetSingleArgument<std::string>(
54  kBlobName,
55  operator_def.output(DATA))) {}
56 
57 bool StoreGetOp::RunOnDevice() {
58  // Get from store and deserialize
59  auto* handler =
60  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
61  OperatorBase::Outputs()[DATA]->Deserialize(handler->get(blobName_));
62  return true;
63 }
64 
65 REGISTER_CPU_OPERATOR(StoreGet, StoreGetOp);
66 OPERATOR_SCHEMA(StoreGet)
67  .NumInputs(1)
68  .NumOutputs(1)
69  .SetDoc(R"DOC(
70 Get a blob from a store. The key is the output blob's name. The key
71 can be overridden by specifying the 'blob_name' argument.
72 )DOC")
73  .Arg("blob_name", "alternative key for the blob (optional)")
74  .Input(0, "handler", "unique_ptr<StoreHandler>")
75  .Output(0, "data", "data blob");
76 
77 StoreAddOp::StoreAddOp(const OperatorDef& operator_def, Workspace* ws)
78  : Operator<CPUContext>(operator_def, ws),
79  blobName_(GetSingleArgument<std::string>(kBlobName, "")),
80  addValue_(GetSingleArgument<int64_t>(kAddValue, 1)) {
81  CAFFE_ENFORCE(HasArgument(kBlobName));
82 }
83 
84 bool StoreAddOp::RunOnDevice() {
85  auto* handler =
86  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
87  Output(VALUE)->Resize(1);
88  Output(VALUE)->mutable_data<int64_t>()[0] =
89  handler->add(blobName_, addValue_);
90  return true;
91 }
92 
93 REGISTER_CPU_OPERATOR(StoreAdd, StoreAddOp);
94 OPERATOR_SCHEMA(StoreAdd)
95  .NumInputs(1)
96  .NumOutputs(1)
97  .SetDoc(R"DOC(
98 Add a value to a remote counter. If the key is not set, the store
99 initializes it to 0 and then performs the add operation. The operation
100 returns the resulting counter value.
101 )DOC")
102  .Arg("blob_name", "key of the counter (required)")
103  .Arg("add_value", "value that is added (optional, default: 1)")
104  .Input(0, "handler", "unique_ptr<StoreHandler>")
105  .Output(0, "value", "the current value of the counter");
106 
107 StoreWaitOp::StoreWaitOp(const OperatorDef& operator_def, Workspace* ws)
108  : Operator<CPUContext>(operator_def, ws),
109  blobNames_(GetRepeatedArgument<std::string>(kBlobName)) {}
110 
111 bool StoreWaitOp::RunOnDevice() {
112  auto* handler =
113  OperatorBase::Input<std::unique_ptr<StoreHandler>>(HANDLER).get();
114  if (InputSize() == 2 && Input(1).IsType<std::string>()) {
115  CAFFE_ENFORCE(
116  blobNames_.empty(), "cannot specify both argument and input blob");
117  std::vector<std::string> blobNames;
118  auto* namesPtr = Input(1).data<std::string>();
119  for (int i = 0; i < Input(1).size(); ++i) {
120  blobNames.push_back(namesPtr[i]);
121  }
122  handler->wait(blobNames);
123  } else {
124  handler->wait(blobNames_);
125  }
126  return true;
127 }
128 
129 REGISTER_CPU_OPERATOR(StoreWait, StoreWaitOp);
130 OPERATOR_SCHEMA(StoreWait)
131  .NumInputs(1, 2)
132  .NumOutputs(0)
133  .SetDoc(R"DOC(
134 Wait for the specified blob names to be set. The blob names can be passed
135 either as an input blob with blob names or as an argument.
136 )DOC")
137  .Arg("blob_names", "names of the blobs to wait for (optional)")
138  .Input(0, "handler", "unique_ptr<StoreHandler>")
139  .Input(1, "names", "names of the blobs to wait for (optional)");
140 }
Definition: types.h:88
Copyright (c) 2016-present, Facebook, Inc.