Caffe2 - C++ API
A deep learning, cross platform ML framework
create_scope_op.h
1 
17 #ifndef CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
18 #define CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
19 
20 #include <string>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "caffe2/core/context.h"
26 #include "caffe2/core/logging.h"
27 #include "caffe2/core/operator.h"
28 #include "caffe2/proto/caffe2.pb.h"
29 
30 CAFFE2_DECLARE_bool(caffe2_workspace_stack_debug);
31 
32 namespace caffe2 {
33 namespace detail {
34 
/*
 * Keeps track of forward and backward (gradient) workspaces in a stack and
 * reuses previously created workspaces. Not thread-safe.
 */
40  public:
41  explicit WorkspaceStack() : parent_ws_(nullptr), top_(-1) {}
42 
43  std::shared_ptr<Workspace> pushForwardWorkspace(
44  Workspace* parent_ws,
45  const std::unordered_map<std::string, std::string>& blob_bindings) {
46  checkStack();
47  if (FLAGS_caffe2_workspace_stack_debug) {
48  if (parent_ws_) {
49  CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch");
50  } else {
51  parent_ws_ = parent_ws;
52  }
53  if (!blob_bindings_.empty()) {
54  checkBindingsMatch(blob_bindings_, blob_bindings);
55  } else {
56  blob_bindings_ = blob_bindings;
57  }
58  }
59 
60  if (top_ == workspaces_.size() - 1) {
61  workspaces_.push_back(
62  std::make_shared<Workspace>(parent_ws, blob_bindings));
63  } else {
64  // when reusing workspace, make sure copies of external blobs are
65  // removed and blob bindings are set
66  auto& workspace = workspaces_[top_ + 1];
67  const auto& local_blobs = workspace->LocalBlobs();
68  std::unordered_set<std::string> local_blobs_set;
69  local_blobs_set.insert(local_blobs.begin(), local_blobs.end());
70  bool found_local_copy = false;
71  for (const auto& blob_pair : blob_bindings) {
72  if (local_blobs_set.count(blob_pair.first)) {
73  workspace->RemoveBlob(blob_pair.first);
74  found_local_copy = true;
75  }
76  }
77  if (found_local_copy) {
78  workspace->AddBlobMapping(parent_ws, blob_bindings);
79  }
80  }
81 
82  return workspaces_[++top_];
83  }
84 
85  std::shared_ptr<Workspace> popGradientWorkspace(
86  Workspace* parent_ws,
87  const std::unordered_map<std::string, std::string>& grad_blob_bindings) {
88  checkStack();
89  if (FLAGS_caffe2_workspace_stack_debug) {
90  if (parent_ws_) {
91  CAFFE_ENFORCE_EQ(parent_ws_, parent_ws, "Parent workspace mismatch");
92  } else {
93  parent_ws_ = parent_ws;
94  }
95  if (!grad_blob_bindings_.empty()) {
96  checkBindingsMatch(grad_blob_bindings_, grad_blob_bindings);
97  } else {
98  grad_blob_bindings_ = grad_blob_bindings;
99  }
100  }
101 
102  if (top_ < 0) {
103  return nullptr;
104  }
105  auto& grad_workspace = workspaces_[top_];
106  grad_workspace->AddBlobMapping(parent_ws, grad_blob_bindings, true);
107  --top_;
108  return grad_workspace;
109  }
110 
111  std::shared_ptr<Workspace> reuseLastForwardWorkspace(
112  Workspace* parent_ws,
113  const std::unordered_map<std::string, std::string>& blob_bindings) {
114  checkStack();
115  if (top_ < 0) {
116  return nullptr;
117  }
118  workspaces_[top_]->AddBlobMapping(parent_ws, blob_bindings);
119  return workspaces_[top_];
120  }
121 
122  void clear() {
123  checkStack();
124  top_ = -1;
125  }
126 
127  bool empty() const {
128  return top_ < 0;
129  }
130 
131  private:
132  void checkStack() const {
133  CAFFE_ENFORCE_GT(
134  (int)workspaces_.size(), top_, "Corrupted workspaces stack");
135  }
136 
137  void checkBindingsMatch(
138  const std::unordered_map<std::string, std::string>& bindings,
139  const std::unordered_map<std::string, std::string>& test_bindings) const {
140  CAFFE_ENFORCE_EQ(
141  bindings.size(), test_bindings.size(), "Blob bindings mismatch");
142  for (const auto& blob_binding : bindings) {
143  CAFFE_ENFORCE(
144  test_bindings.count(blob_binding.first), "Blob bindings mismatch");
145  CAFFE_ENFORCE_EQ(
146  test_bindings.at(blob_binding.first),
147  blob_binding.second,
148  "Blob bindings mismatch");
149  }
150  }
151 
152  std::unordered_map<std::string, std::string> blob_bindings_;
153  std::unordered_map<std::string, std::string> grad_blob_bindings_;
154  Workspace* parent_ws_;
155  int top_;
156  std::vector<std::shared_ptr<Workspace>> workspaces_;
157 };
158 }
159 
160 template <class Context>
161 class CreateScopeOp final : public Operator<Context> {
162  public:
163  CreateScopeOp(const OperatorDef& operator_def, Workspace* ws)
164  : Operator<Context>(operator_def, ws) {}
165 
166  USE_OPERATOR_CONTEXT_FUNCTIONS;
167  bool RunOnDevice() override;
168 };
169 
170 template <class Context>
171 class HasScopeOp final : public Operator<Context> {
172  public:
173  HasScopeOp(const OperatorDef& operator_def, Workspace* ws)
174  : Operator<Context>(operator_def, ws) {}
175 
176  USE_OPERATOR_CONTEXT_FUNCTIONS;
177  bool RunOnDevice() override;
178 };
179 
180 } // namespace caffe2
181 
182 #endif // CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.