Caffe2 - C++ API
A deep learning, cross platform ML framework
do_op.h
1 #ifndef CAFFE2_OPERATORS_DO_OP_H_
2 #define CAFFE2_OPERATORS_DO_OP_H_
3 
4 #include <string>
5 #include <unordered_map>
6 #include <unordered_set>
7 #include <vector>
8 
9 #include "caffe2/core/context.h"
10 #include "caffe2/core/logging.h"
11 #include "caffe2/core/operator.h"
12 #include "caffe2/operators/create_scope_op.h"
13 #include "caffe2/proto/caffe2_pb.h"
14 
15 namespace caffe2 {
16 
17 template <class Context>
18 class DoOp final : public Operator<Context> {
19  public:
20  explicit DoOp(const OperatorDef& operator_def, Workspace* ws)
21  : Operator<Context>(operator_def, ws), parent_ws_(ws) {
22  CAFFE_ENFORCE(
23  this->template HasSingleArgumentOfType<NetDef>("net"),
24  "net must be specified in Do operator");
25  net_def_ = this->template GetSingleArgument<NetDef>("net", NetDef());
26  is_gradient_op_ = operator_def.is_gradient_op();
27  copy_external_blobs_ =
28  this->template GetSingleArgument<bool>("copy_external_blobs", false);
29  reuse_workspace_ =
30  this->template GetSingleArgument<bool>("reuse_workspace", false);
31  CAFFE_ENFORCE(
32  !(is_gradient_op_ && reuse_workspace_),
33  "Gradient Do op requires use of stacked workspaces");
34  CAFFE_ENFORCE(
35  !(copy_external_blobs_ && reuse_workspace_),
36  "Reuse workspace and copy external blobs simultaneously in Do op");
37 
38  const auto& inner_blobs =
39  this->template GetRepeatedArgument<std::string>("inner_blobs");
40  const auto& outer_blobs_idx =
41  this->template GetRepeatedArgument<int>("outer_blobs_idx");
42  CAFFE_ENFORCE_EQ(
43  inner_blobs.size(),
44  outer_blobs_idx.size(),
45  "Invalid blob bindings: different inner/outer blobs lengths");
46 
47  const auto& outer_blob_names = checkAndGetOuterNames(operator_def);
48  std::unordered_set<std::string> used_outer_names;
49  for (size_t blob_idx = 0; blob_idx < inner_blobs.size(); ++blob_idx) {
50  CAFFE_ENFORCE(
51  !blob_bindings_.count(inner_blobs[blob_idx]),
52  "Invalid blob bindings: redefinition of inner blob " +
53  inner_blobs[blob_idx]);
54  CAFFE_ENFORCE(
55  outer_blobs_idx[blob_idx] >= 0 &&
56  outer_blobs_idx[blob_idx] < outer_blob_names.size(),
57  "Invalid blob bindings: outer blob index (" +
58  c10::to_string(outer_blobs_idx[blob_idx]) + ", inner name: " +
59  inner_blobs[blob_idx] + ") is out of bounds [0, " +
60  c10::to_string(outer_blob_names.size() - 1) + "]");
61  const auto& outer_name = outer_blob_names[outer_blobs_idx[blob_idx]];
62  CAFFE_ENFORCE(
63  !used_outer_names.count(outer_name),
64  "Reusage of outer name: " + outer_name);
65  used_outer_names.insert(outer_name);
66  blob_bindings_[inner_blobs[blob_idx]] = outer_name;
67  forwarded_inner_blobs_.insert(inner_blobs[blob_idx]);
68  }
69  std::unordered_set<std::string> all_outer_names(
70  outer_blob_names.begin(), outer_blob_names.end());
71  CAFFE_ENFORCE_EQ(
72  used_outer_names.size(),
73  all_outer_names.size(),
74  "Not all outer names are used in blob bindings");
75  }
76 
77  USE_OPERATOR_CONTEXT_FUNCTIONS;
78 
79  bool RunOnDevice() override {
80  auto* ws_stack =
81  this->template Output<detail::WorkspaceStack>(OutputSize() - 1);
82  std::shared_ptr<Workspace> net_workspace;
83  if (is_gradient_op_) {
84  net_workspace =
85  ws_stack->popGradientWorkspace(parent_ws_, blob_bindings_);
86  } else {
87  if (reuse_workspace_ && !ws_stack->empty()) {
88  net_workspace =
89  ws_stack->reuseLastForwardWorkspace(parent_ws_, blob_bindings_);
90  } else {
91  net_workspace =
92  ws_stack->pushForwardWorkspace(parent_ws_, blob_bindings_);
93  }
94  }
95  CAFFE_ENFORCE(net_workspace, "Failed to initialize Do op workspace");
96 
97  // TODO(iliacher): figure how to reuse existing net with a new workspace
98  auto* net = net_workspace->GetNet(net_def_.name());
99  if (!net) {
100  net = net_workspace->CreateNet(net_def_, true);
101  }
102  CAFFE_ENFORCE(net, "Failed to initialize subnet");
103  auto success = net->Run();
104  if (!is_gradient_op_ && copy_external_blobs_) {
105  net_workspace->template CopyForwardedTensors<Context>(
106  forwarded_inner_blobs_);
107  }
108  return success;
109  }
110 
111  private:
112  // returns vector of input blob names followed by output blob names in
113  // operator definition order; ensures that input (output) names are unique,
114  // checks number of input (output) blobs
115  std::vector<std::string> checkAndGetOuterNames(
116  const OperatorDef& operator_def) const {
117  auto input_names = getInputBlobNames(operator_def);
118  CAFFE_ENFORCE(!input_names.empty(), "Expected at least one input blob");
119  std::string input_ws_blob = input_names.back(); // copy
120  // removing blob that holds pointer op workspace
121  input_names.pop_back();
122 
123  std::unordered_set<std::string> all_input_names(
124  input_names.begin(), input_names.end());
125  CAFFE_ENFORCE_EQ(
126  input_names.size(), all_input_names.size(), "Duplicate input blobs");
127 
128  auto output_names = getOutputBlobNames(operator_def);
129  CAFFE_ENFORCE(!output_names.empty(), "Expected at least one output blob");
130  const auto& output_ws_blob = output_names.back();
131  CAFFE_ENFORCE_EQ(
132  input_ws_blob,
133  output_ws_blob,
134  "Expected same input/output workspace blob");
135  // remove blob that holds pointer to op workspace
136  output_names.pop_back();
137 
138  std::unordered_set<std::string> all_output_names(
139  output_names.begin(), output_names.end());
140  CAFFE_ENFORCE_EQ(
141  output_names.size(), all_output_names.size(), "Duplicate output blobs");
142 
143  std::vector<std::string> outer_blob_names;
144  outer_blob_names.reserve(input_names.size() + output_names.size());
145  outer_blob_names.insert(
146  outer_blob_names.end(), input_names.begin(), input_names.end());
147  outer_blob_names.insert(
148  outer_blob_names.end(), output_names.begin(), output_names.end());
149  return outer_blob_names;
150  }
151 
152  std::vector<std::string> getInputBlobNames(
153  const OperatorDef& operator_def) const {
154  std::vector<std::string> names;
155  names.reserve(operator_def.input_size());
156  for (auto idx = 0; idx < operator_def.input_size(); ++idx) {
157  names.push_back(operator_def.input(idx));
158  }
159  return names;
160  }
161 
162  std::vector<std::string> getOutputBlobNames(
163  const OperatorDef& operator_def) const {
164  std::vector<std::string> names;
165  names.reserve(operator_def.output_size());
166  for (auto idx = 0; idx < operator_def.output_size(); ++idx) {
167  names.push_back(operator_def.output(idx));
168  }
169  return names;
170  }
171 
172  std::unordered_map<std::string, std::string> blob_bindings_;
173  std::unordered_set<std::string> forwarded_inner_blobs_;
174  bool is_gradient_op_;
175  bool copy_external_blobs_;
176  bool reuse_workspace_;
177  NetDef net_def_;
178  Workspace* parent_ws_;
179 };
180 
181 } // namespace caffe2
182 
183 #endif // CAFFE2_OPERATORS_DO_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13