1 #ifndef CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_ 2 #define CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_ 5 #include <unordered_map> 6 #include <unordered_set> 9 #include "caffe2/core/context.h" 10 #include "caffe2/core/logging.h" 11 #include "caffe2/core/operator.h" 12 #include "caffe2/proto/caffe2_pb.h" 14 C10_DECLARE_bool(caffe2_workspace_stack_debug);
27 std::shared_ptr<Workspace> pushForwardWorkspace(
Workspace* parent_ws) {
28 return pushForwardWorkspace(
29 parent_ws, std::unordered_map<std::string, std::string>());
// Pushes a forward workspace bound to `parent_ws` / `blob_bindings` onto the
// stack and returns the workspace at the new top (`workspaces_[++top_]`).
// A new Workspace is appended when `top_` is at the last index of
// `workspaces_`; otherwise the already-allocated slot above `top_` is reused.
// NOTE(review): the fused source numbering skips lines 33, 35, 37, 39, 41,
// 44, 46-48, 52-54, 64-65 and 68-70 in this listing, so parts of this body
// are not visible here -- including the declaration of the `parent_ws`
// parameter used below. Confirm against the full file before relying on
// these comments.
32 std::shared_ptr<Workspace> pushForwardWorkspace(
34 const std::unordered_map<std::string, std::string>& blob_bindings) {
// Debug-only consistency checks: the parent workspace and blob bindings are
// compared with the ones recorded on earlier calls.
36 if (FLAGS_caffe2_workspace_stack_debug) {
// NOTE(review): as listed, this enforce runs before `parent_ws_` is ever
// assigned; the missing lines 37/39 presumably guard it (e.g. only when
// `parent_ws_` is already set) -- confirm.
38 CAFFE_ENFORCE_EQ(parent_ws_, parent_ws,
"Parent workspace mismatch");
40 parent_ws_ = parent_ws;
42 if (!blob_bindings_.empty()) {
43 checkBindingsMatch(blob_bindings_, blob_bindings);
45 blob_bindings_ = blob_bindings;
// NOTE(review): `top_` appears to be a signed int (checkStack casts
// workspaces_.size() to int before comparing with it), while size() is
// unsigned, so `top_` is converted to unsigned here. If `top_` can be -1 for
// an empty stack, this only matches because -1 and `0 - 1` wrap to the same
// unsigned value; `top_ + 1 == static_cast<int>(workspaces_.size())` would
// state the intent without relying on the conversion.
49 if (top_ == workspaces_.size() - 1) {
// No spare workspace above the top: allocate a fresh one.
50 workspaces_.push_back(
51 std::make_shared<Workspace>(parent_ws, blob_bindings));
// Reuse the workspace slot just above the current top.
55 auto& workspace = workspaces_[top_ + 1];
56 const auto& local_blobs = workspace->LocalBlobs();
57 std::unordered_set<std::string> local_blobs_set;
58 local_blobs_set.insert(local_blobs.begin(), local_blobs.end());
59 bool found_local_copy =
false;
// Remove any local blob whose name is one of the bound names -- presumably
// so a stale local copy does not shadow the binding; confirm.
60 for (
const auto& blob_pair : blob_bindings) {
61 if (local_blobs_set.count(blob_pair.first)) {
62 workspace->RemoveBlob(blob_pair.first);
63 found_local_copy =
true;
// Re-apply the bindings only if a local copy was removed above.
66 if (found_local_copy) {
67 workspace->AddBlobMapping(parent_ws, blob_bindings);
71 return workspaces_[++top_];
// Binds the workspace at the current stack top (`workspaces_[top_]`) to
// `parent_ws` using `grad_blob_bindings` and returns it.
// NOTE(review): the fused source numbering skips lines 75, 77, 79, 81, 83,
// 86, 88-93 and 96 in this listing, so the `parent_ws` parameter declaration
// and any bounds check on, or update of, `top_` are not visible here. As
// listed, nothing decrements `top_` despite the "pop" name -- presumably
// that happens on a missing line; confirm against the full file.
74 std::shared_ptr<Workspace> popGradientWorkspace(
76 const std::unordered_map<std::string, std::string>& grad_blob_bindings) {
// Debug-only consistency checks, parallel to pushForwardWorkspace: the parent
// workspace and gradient bindings are compared with those recorded earlier.
78 if (FLAGS_caffe2_workspace_stack_debug) {
80 CAFFE_ENFORCE_EQ(parent_ws_, parent_ws,
"Parent workspace mismatch");
82 parent_ws_ = parent_ws;
84 if (!grad_blob_bindings_.empty()) {
85 checkBindingsMatch(grad_blob_bindings_, grad_blob_bindings);
87 grad_blob_bindings_ = grad_blob_bindings;
// NOTE(review): indexes `workspaces_` with `top_`; no emptiness check is
// visible here (lines 88-93 are missing from this listing).
94 auto& grad_workspace = workspaces_[top_];
// The meaning of the trailing `true` argument is defined by
// Workspace::AddBlobMapping, which is not shown here.
95 grad_workspace->AddBlobMapping(parent_ws, grad_blob_bindings,
true);
97 return grad_workspace;
100 std::shared_ptr<Workspace> reuseLastForwardWorkspace(
Workspace* parent_ws) {
101 return reuseLastForwardWorkspace(
102 parent_ws, std::unordered_map<std::string, std::string>());
// Binds the workspace at the current stack top to `parent_ws` with
// `blob_bindings` and returns it. The lines shown do not modify `top_`.
// NOTE(review): source lines 106 and 108-111 are missing from this listing,
// including the `parent_ws` parameter declaration and any emptiness check on
// `top_`. Confirm that `workspaces_[top_]` is guaranteed valid here.
105 std::shared_ptr<Workspace> reuseLastForwardWorkspace(
107 const std::unordered_map<std::string, std::string>& blob_bindings) {
112 workspaces_[top_]->AddBlobMapping(parent_ws, blob_bindings);
113 return workspaces_[top_];
// Sanity check relating the number of allocated workspaces to `top_`,
// reported with the message "Corrupted workspaces stack".
// NOTE(review): source line 127, which holds the enforcement macro and
// therefore the exact comparison, is missing from this listing. It
// presumably requires workspaces_.size() > top_ -- confirm.
// The C-style `(int)` cast keeps the comparison signed, which suggests
// `top_` is a signed int (its declaration is not visible here);
// static_cast<int> would be the idiomatic spelling.
126 void checkStack()
const {
128 (
int)workspaces_.size(), top_,
"Corrupted workspaces stack");
// Checks that `test_bindings` agrees with `bindings`. It is used by the
// debug-mode checks in pushForwardWorkspace and popGradientWorkspace.
// The checks are:
//   - the two maps have the same number of entries;
//   - every key of `bindings` is present in `test_bindings`;
//   - the mapped values are compared.
// Each check is reported with "Blob bindings mismatch".
// NOTE(review): source lines 134, 137, 139 and 141 are missing from this
// listing. They presumably hold the enforcement macros and the second operand
// of the value comparison (likely `blob_binding.second`) -- confirm.
131 void checkBindingsMatch(
132 const std::unordered_map<std::string, std::string>& bindings,
133 const std::unordered_map<std::string, std::string>& test_bindings)
const {
// Entry counts must be equal.
135 bindings.size(), test_bindings.size(),
"Blob bindings mismatch");
136 for (
const auto& blob_binding : bindings) {
// Key must exist in `test_bindings`.
138 test_bindings.count(blob_binding.first),
"Blob bindings mismatch");
// Value stored under that key is compared (other operand not visible).
140 test_bindings.at(blob_binding.first),
142 "Blob bindings mismatch");
// Forward blob bindings stored by pushForwardWorkspace when
// FLAGS_caffe2_workspace_stack_debug is set. Later pushes are compared
// against them via checkBindingsMatch.
146 std::unordered_map<std::string, std::string> blob_bindings_;
// Gradient-side counterpart, stored and checked by popGradientWorkspace.
147 std::unordered_map<std::string, std::string> grad_blob_bindings_;
// Backing storage for the workspace stack. `top_` indexes into it, and the
// slot above `top_` is reused by pushForwardWorkspace when it exists.
// NOTE(review): the declarations of `parent_ws_` and `top_`, which the
// methods above use, fall on lines missing from this listing (143-145,
// 148-149).
150 std::vector<std::shared_ptr<Workspace>> workspaces_;
// Operator class template declared in this header. Its `class ... :` line
// (155-156) is missing from this listing; presumably it is the create-scope
// operator, given the header guard name -- confirm.
154 template <
class Context>
// Variadic constructor template. Its declaration and body (lines 158-160)
// are missing here; presumably it forwards `args` to the base operator.
157 template <
class... Args>
161 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared here only; the definition is not part of this listing.
162 bool RunOnDevice()
override;
// Second operator class template declared in this header. Its `class ... :`
// line (166-167) is missing from this listing, so its name is not visible
// here -- confirm against the full file.
165 template <
class Context>
// Variadic constructor template. Its declaration and body (lines 169-171)
// are missing here; presumably it forwards `args` to the base operator.
168 template <
class... Args>
172 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared here only; the definition is not part of this listing.
173 bool RunOnDevice()
override;
178 #endif // CAFFE2_OPERATORS_CREATE_SCOPE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...