Caffe2 - C++ API
A deep learning, cross-platform ML framework
do_op.h
1 
17 #ifndef CAFFE2_OPERATORS_DO_OP_H_
18 #define CAFFE2_OPERATORS_DO_OP_H_
19 
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "caffe2/core/context.h"
26 #include "caffe2/core/logging.h"
27 #include "caffe2/core/operator.h"
28 #include "caffe2/operators/create_scope_op.h"
29 #include "caffe2/proto/caffe2.pb.h"
30 
31 namespace caffe2 {
32 
33 template <class Context>
34 class DoOp final : public Operator<Context> {
35  public:
36  DoOp(const OperatorDef& operator_def, Workspace* ws)
37  : Operator<Context>(operator_def, ws), parent_ws_(ws) {
38  CAFFE_ENFORCE(
39  this->template HasSingleArgumentOfType<NetDef>("net"),
40  "net must be specified in Do operator");
41  net_def_ = this->template GetSingleArgument<NetDef>("net", NetDef());
42  is_gradient_op_ = operator_def.is_gradient_op();
43  copy_external_blobs_ =
44  this->template GetSingleArgument<bool>("copy_external_blobs", false);
45  reuse_workspace_ =
46  this->template GetSingleArgument<bool>("reuse_workspace", false);
47  CAFFE_ENFORCE(
48  !(is_gradient_op_ && reuse_workspace_),
49  "Gradient Do op requires use of stacked workspaces");
50  CAFFE_ENFORCE(
51  !(copy_external_blobs_ && reuse_workspace_),
52  "Reuse workspace and copy external blobs simultaneously in Do op");
53 
54  const auto& inner_blobs =
55  this->template GetRepeatedArgument<std::string>("inner_blobs");
56  const auto& outer_blobs_idx =
57  this->template GetRepeatedArgument<int>("outer_blobs_idx");
58  CAFFE_ENFORCE_EQ(
59  inner_blobs.size(),
60  outer_blobs_idx.size(),
61  "Invalid blob bindings: different inner/outer blobs lengths");
62 
63  const auto& outer_blob_names = checkAndGetOuterNames(operator_def);
64  std::unordered_set<std::string> used_outer_names;
65  for (size_t blob_idx = 0; blob_idx < inner_blobs.size(); ++blob_idx) {
66  CAFFE_ENFORCE(
67  !blob_bindings_.count(inner_blobs[blob_idx]),
68  "Invalid blob bindings: redefinition of inner blob " +
69  inner_blobs[blob_idx]);
70  CAFFE_ENFORCE(
71  outer_blobs_idx[blob_idx] >= 0 &&
72  outer_blobs_idx[blob_idx] < outer_blob_names.size(),
73  "Invalid blob bindings: outer blob index (" +
74  caffe2::to_string(outer_blobs_idx[blob_idx]) + ", inner name: " +
75  inner_blobs[blob_idx] + ") is out of bounds [0, " +
76  caffe2::to_string(outer_blob_names.size() - 1) + "]");
77  const auto& outer_name = outer_blob_names[outer_blobs_idx[blob_idx]];
78  CAFFE_ENFORCE(
79  !used_outer_names.count(outer_name),
80  "Reusage of outer name: " + outer_name);
81  used_outer_names.insert(outer_name);
82  blob_bindings_[inner_blobs[blob_idx]] = outer_name;
83  forwarded_inner_blobs_.insert(inner_blobs[blob_idx]);
84  }
85  std::unordered_set<std::string> all_outer_names(
86  outer_blob_names.begin(), outer_blob_names.end());
87  CAFFE_ENFORCE_EQ(
88  used_outer_names.size(),
89  all_outer_names.size(),
90  "Not all outer names are used in blob bindings");
91  }
92 
93  USE_OPERATOR_CONTEXT_FUNCTIONS;
94 
95  bool RunOnDevice() override {
96  auto* ws_stack =
97  OperatorBase::Output<detail::WorkspaceStack>(OutputSize() - 1);
98  std::shared_ptr<Workspace> net_workspace;
99  if (is_gradient_op_) {
100  net_workspace =
101  ws_stack->popGradientWorkspace(parent_ws_, blob_bindings_);
102  } else {
103  if (reuse_workspace_ && !ws_stack->empty()) {
104  net_workspace =
105  ws_stack->reuseLastForwardWorkspace(parent_ws_, blob_bindings_);
106  } else {
107  net_workspace =
108  ws_stack->pushForwardWorkspace(parent_ws_, blob_bindings_);
109  }
110  }
111  CAFFE_ENFORCE(net_workspace, "Failed to initialize Do op workspace");
112 
113  // TODO(iliacher): figure how to reuse existing net with a new workspace
114  auto* net = net_workspace->GetNet(net_def_.name());
115  if (!net) {
116  net = net_workspace->CreateNet(net_def_, true);
117  }
118  CAFFE_ENFORCE(net, "Failed to initialize subnet");
119  auto success = net->Run();
120  if (!is_gradient_op_ && copy_external_blobs_) {
121  net_workspace->template CopyForwardedTensors<Context>(
122  forwarded_inner_blobs_);
123  }
124  return success;
125  }
126 
127  private:
128  // returns vector of input blob names followed by output blob names in
129  // operator definition order; ensures that input (output) names are unique,
130  // checks number of input (output) blobs
131  std::vector<std::string> checkAndGetOuterNames(
132  const OperatorDef& operator_def) const {
133  auto input_names = getInputBlobNames(operator_def);
134  CAFFE_ENFORCE(!input_names.empty(), "Expected at least one input blob");
135  std::string input_ws_blob = input_names.back(); // copy
136  // removing blob that holds pointer op workspace
137  input_names.pop_back();
138 
139  std::unordered_set<std::string> all_input_names(
140  input_names.begin(), input_names.end());
141  CAFFE_ENFORCE_EQ(
142  input_names.size(), all_input_names.size(), "Duplicate input blobs");
143 
144  auto output_names = getOutputBlobNames(operator_def);
145  CAFFE_ENFORCE(!output_names.empty(), "Expected at least one output blob");
146  const auto& output_ws_blob = output_names.back();
147  CAFFE_ENFORCE_EQ(
148  input_ws_blob,
149  output_ws_blob,
150  "Expected same input/output workspace blob");
151  // remove blob that holds pointer to op workspace
152  output_names.pop_back();
153 
154  std::unordered_set<std::string> all_output_names(
155  output_names.begin(), output_names.end());
156  CAFFE_ENFORCE_EQ(
157  output_names.size(), all_output_names.size(), "Duplicate output blobs");
158 
159  std::vector<std::string> outer_blob_names;
160  outer_blob_names.reserve(input_names.size() + output_names.size());
161  outer_blob_names.insert(
162  outer_blob_names.end(), input_names.begin(), input_names.end());
163  outer_blob_names.insert(
164  outer_blob_names.end(), output_names.begin(), output_names.end());
165  return outer_blob_names;
166  }
167 
168  std::vector<std::string> getInputBlobNames(
169  const OperatorDef& operator_def) const {
170  std::vector<std::string> names;
171  names.reserve(operator_def.input_size());
172  for (auto idx = 0; idx < operator_def.input_size(); ++idx) {
173  names.push_back(operator_def.input(idx));
174  }
175  return names;
176  }
177 
178  std::vector<std::string> getOutputBlobNames(
179  const OperatorDef& operator_def) const {
180  std::vector<std::string> names;
181  names.reserve(operator_def.output_size());
182  for (auto idx = 0; idx < operator_def.output_size(); ++idx) {
183  names.push_back(operator_def.output(idx));
184  }
185  return names;
186  }
187 
188  std::unordered_map<std::string, std::string> blob_bindings_;
189  std::unordered_set<std::string> forwarded_inner_blobs_;
190  bool is_gradient_op_;
191  bool copy_external_blobs_;
192  bool reuse_workspace_;
193  NetDef net_def_;
194  Workspace* parent_ws_;
195 };
196 
197 } // namespace caffe2
198 
199 #endif // CAFFE2_OPERATORS_DO_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.