Caffe2 - C++ API
A deep learning, cross-platform ML framework
create_scope_op.h
1 #ifndef CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
2 #define CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
3 
4 #include <string>
5 #include <unordered_map>
6 #include <unordered_set>
7 #include <vector>
8 
9 #include "caffe2/core/context.h"
10 #include "caffe2/core/logging.h"
11 #include "caffe2/core/operator.h"
12 #include "caffe2/proto/caffe2_pb.h"
13 
14 C10_DECLARE_bool(caffe2_workspace_stack_debug);
15 
16 namespace caffe2 {
17 namespace detail {
18 
19 /*
20  * Keeps track of forward and backward gradient workspaces in stack,
21  * reuses previously created workspaces, non-thread safe
22  */
23 class CAFFE2_API WorkspaceStack {
24  public:
25  explicit WorkspaceStack() : parent_ws_(nullptr), top_(-1) {}
26 
27  std::shared_ptr<Workspace> pushForwardWorkspace(Workspace* parent_ws) {
28  return pushForwardWorkspace(
29  parent_ws, std::unordered_map<std::string, std::string>());
30  }
31 
32  std::shared_ptr<Workspace> pushForwardWorkspace(
33  Workspace* parent_ws,
34  const std::unordered_map<std::string, std::string>& blob_bindings) {
35  checkStack();
36  if (FLAGS_caffe2_workspace_stack_debug) {
37  if (parent_ws_) {
38  CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch");
39  } else {
40  parent_ws_ = parent_ws;
41  }
42  if (!blob_bindings_.empty()) {
43  checkBindingsMatch(blob_bindings_, blob_bindings);
44  } else {
45  blob_bindings_ = blob_bindings;
46  }
47  }
48 
49  if (top_ == workspaces_.size() - 1) {
50  workspaces_.push_back(
51  std::make_shared<Workspace>(parent_ws, blob_bindings));
52  } else {
53  // when reusing workspace, make sure copies of external blobs are
54  // removed and blob bindings are set
55  auto& workspace = workspaces_[top_ + 1];
56  const auto& local_blobs = workspace->LocalBlobs();
57  std::unordered_set<std::string> local_blobs_set;
58  local_blobs_set.insert(local_blobs.begin(), local_blobs.end());
59  bool found_local_copy = false;
60  for (const auto& blob_pair : blob_bindings) {
61  if (local_blobs_set.count(blob_pair.first)) {
62  workspace->RemoveBlob(blob_pair.first);
63  found_local_copy = true;
64  }
65  }
66  if (found_local_copy) {
67  workspace->AddBlobMapping(parent_ws, blob_bindings);
68  }
69  }
70 
71  return workspaces_[++top_];
72  }
73 
74  std::shared_ptr<Workspace> popGradientWorkspace(
75  Workspace* parent_ws,
76  const std::unordered_map<std::string, std::string>& grad_blob_bindings) {
77  checkStack();
78  if (FLAGS_caffe2_workspace_stack_debug) {
79  if (parent_ws_) {
80  CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch");
81  } else {
82  parent_ws_ = parent_ws;
83  }
84  if (!grad_blob_bindings_.empty()) {
85  checkBindingsMatch(grad_blob_bindings_, grad_blob_bindings);
86  } else {
87  grad_blob_bindings_ = grad_blob_bindings;
88  }
89  }
90 
91  if (top_ < 0) {
92  return nullptr;
93  }
94  auto& grad_workspace = workspaces_[top_];
95  grad_workspace->AddBlobMapping(parent_ws, grad_blob_bindings, true);
96  --top_;
97  return grad_workspace;
98  }
99 
100  std::shared_ptr<Workspace> reuseLastForwardWorkspace(Workspace* parent_ws) {
101  return reuseLastForwardWorkspace(
102  parent_ws, std::unordered_map<std::string, std::string>());
103  }
104 
105  std::shared_ptr<Workspace> reuseLastForwardWorkspace(
106  Workspace* parent_ws,
107  const std::unordered_map<std::string, std::string>& blob_bindings) {
108  checkStack();
109  if (top_ < 0) {
110  return nullptr;
111  }
112  workspaces_[top_]->AddBlobMapping(parent_ws, blob_bindings);
113  return workspaces_[top_];
114  }
115 
116  void clear() {
117  checkStack();
118  top_ = -1;
119  }
120 
121  bool empty() const {
122  return top_ < 0;
123  }
124 
125  private:
126  void checkStack() const {
127  CAFFE_ENFORCE_GT(
128  (int)workspaces_.size(), top_, "Corrupted workspaces stack");
129  }
130 
131  void checkBindingsMatch(
132  const std::unordered_map<std::string, std::string>& bindings,
133  const std::unordered_map<std::string, std::string>& test_bindings) const {
134  CAFFE_ENFORCE_EQ(
135  bindings.size(), test_bindings.size(), "Blob bindings mismatch");
136  for (const auto& blob_binding : bindings) {
137  CAFFE_ENFORCE(
138  test_bindings.count(blob_binding.first), "Blob bindings mismatch");
139  CAFFE_ENFORCE_EQ(
140  test_bindings.at(blob_binding.first),
141  blob_binding.second,
142  "Blob bindings mismatch");
143  }
144  }
145 
146  std::unordered_map<std::string, std::string> blob_bindings_;
147  std::unordered_map<std::string, std::string> grad_blob_bindings_;
148  Workspace* parent_ws_;
149  int top_;
150  std::vector<std::shared_ptr<Workspace>> workspaces_;
151 };
152 }
153 
154 template <class Context>
155 class CreateScopeOp final : public Operator<Context> {
156  public:
157  template <class... Args>
158  explicit CreateScopeOp(Args&&... args)
159  : Operator<Context>(std::forward<Args>(args)...) {}
160 
161  USE_OPERATOR_CONTEXT_FUNCTIONS;
162  bool RunOnDevice() override;
163 };
164 
165 template <class Context>
166 class HasScopeOp final : public Operator<Context> {
167  public:
168  template <class... Args>
169  explicit HasScopeOp(Args&&... args)
170  : Operator<Context>(std::forward<Args>(args)...) {}
171 
172  USE_OPERATOR_CONTEXT_FUNCTIONS;
173  bool RunOnDevice() override;
174 };
175 
176 } // namespace caffe2
177 
178 #endif // CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13