1 #include "caffe2/core/memonger.h" 4 #include <unordered_set> 6 #include "caffe2/utils/proto_utils.h" 11 NetDef optimize_inference_net(
13 const std::set<string>& static_blobs) {
14 if (net.type() !=
"" && net.type() !=
"simple") {
15 LOG(INFO) <<
"Cannot optimize memory for nets of type: " << net.type();
19 std::vector<OperatorDef> ops;
20 for (
auto& op : net.op()) {
21 if (op.type() ==
"RecurrentNetwork") {
24 LOG(INFO) <<
"Memonger does not support RecurrentNetwork yet";
  // Step 1: record the first and last op index at which each blob is live.
  std::unordered_set<std::string> all_blobs;
  std::unordered_map<std::string, std::pair<int, int>> ranges;
  for (size_t i = 0; i < ops.size(); i++) {
    for (auto& inp : ops[i].input()) {
      if (ranges.find(inp) != ranges.end()) {
        ranges[inp].second = i;
      }
      all_blobs.insert(inp);
    }
    for (auto& outp : ops[i].output()) {
      all_blobs.insert(outp);
      if (static_blobs.find(outp) != static_blobs.end()) {
        // Static blobs keep their identity and are never recycled.
        continue;
      }
      if (ranges.find(outp) == ranges.end()) {
        ranges[outp] = std::make_pair(i, i);
      }
    }
  }
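  // Worked example of the ranges just computed (illustrative, not from the
  // original source): for the chain
  //   x = Op0(in); y = Op1(x); z = Op2(y)
  // we get ranges[x] = (0, 1) and ranges[y] = (1, 2). After op 1, x's live
  // range ends, so x enters the free pool and op 2 can write z into x's
  // storage.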
  // Step 2: pass over the ops in order and recycle blobs whose live range
  // has ended.
  std::vector<std::string> free_blobs;
  std::unordered_map<std::string, std::string> renaming;
  std::unordered_map<std::string, std::string> mapping;

  for (int i = 0; i < (int)ops.size(); i++) {
    auto& op = ops[i];
    std::unordered_set<std::string> new_free_blobs;

    // Check if an input is used for the last time, and release it.
    for (auto& inp : op.input()) {
      auto rit = ranges.find(inp);
      if (rit != ranges.end() && rit->second.second == i) {
        if (mapping.find(inp) == mapping.end()) {
          new_free_blobs.insert(inp);

          // Safety check to prevent double-memongering nets.
          std::string shared_blob =
              "__m" + c10::to_string(renaming.size()) + "_shared";
          if (all_blobs.find(shared_blob) != all_blobs.end()) {
            LOG(INFO) << "Net was already memongered!";
            return net;
          }
          renaming[inp] = shared_blob;
        } else {
          new_free_blobs.insert(mapping[inp]);
        }
      }
    }
    // Check if an output appears for the first time, and if so, see whether
    // a recycled blob can be used instead.
    for (auto& outp : op.output()) {
      if (!free_blobs.empty()) {
        auto rit = ranges.find(outp);
        if (rit != ranges.end() && rit->second.first == i) {
          std::string recycled = free_blobs.back();
          free_blobs.pop_back();
          mapping[outp] = recycled;
        }
      }
    }

    // Add the blobs released by this op to the pool.
    for (auto& b : new_free_blobs) {
      free_blobs.push_back(b);
    }
  }
  // Step 3: rename inputs and outputs to their shared blobs and emit the
  // optimized net.
  NetDef optim_net = net;
  optim_net.mutable_op()->Clear();
  for (auto op : ops) {
    for (int i = 0; i < op.input_size(); i++) {
      auto& inp = op.input(i);
      if (mapping.find(inp) != mapping.end()) {
        op.set_input(i, renaming[mapping[inp]]);
      }
    }
    for (int i = 0; i < op.output_size(); i++) {
      auto& outp = op.output(i);
      if (mapping.find(outp) != mapping.end()) {
        op.set_output(i, renaming[mapping[outp]]);
      }
    }
    auto* ao = optim_net.add_op();
    ao->CopyFrom(op);
  }

  VLOG(1) << "optimized net using " << renaming.size() << " shared blobs";
  return optim_net;
}
class ComputeBlobRecyclingForDag {
 public:
  explicit ComputeBlobRecyclingForDag(const int size)
      : op_inputs_(size),
        op_visited_count_(size),
        op_token_deposit_(size),
        op_visited_(size, false) {}
  NetDef OptimizeNet(
      const NetDef& net,
      const std::vector<string>& heads,
      const std::vector<int>& op_indices,
      const std::unordered_set<string>& shareable_blob_names,
      const string& namescope,
      const std::unordered_set<string>& dont_share_blob_names,
      const std::unordered_map<string, vector<int>>& blob_shapes) {
    // Construct the set of input blobs.
    std::unordered_set<string> heads_blobs_set(heads.begin(), heads.end());

    // Construct the set of output blobs we want to optimize.
    for (const int op_index : op_indices) {
      for (const auto& output : net.op(op_index).output()) {
        optim_op_outputs_.insert(output);
      }
    }
    // Compute each op's in-degree (op_inputs_) and seed share counts for
    // blobs that are pure inputs. Note that output blobs may be written by
    // more than one op, so in-degrees are accumulated via blob_seen.
    std::unordered_map<string, int> blob_seen;
    for (const int op_index : op_indices) {
      for (const auto& input : net.op(op_index).input()) {
        if (has_key(shareable_blob_names, input) ||
            has_key(heads_blobs_set, input)) {
          if (has_key(optim_op_outputs_, input)) {
            CAFFE_ENFORCE(
                blob_seen.find(input) != blob_seen.end(),
                "Input ",
                input,
                " was not output by an op before");
            op_inputs_[op_index] += blob_seen[input];
          } else {
            share_counts_[input] = 1;
          }
          blob_to_ops_[input].push_back(op_index);
        }
      }
      for (const auto& output : net.op(op_index).output()) {
        blob_seen[output] += 1;
        blob_device_[output] = net.op(op_index).device_option();
        // Exception: CopyGPUToCPU carries a CUDA device option but its
        // output lives on the CPU.
        if (net.op(op_index).type() == "CopyGPUToCPU") {
          blob_device_[output].set_device_type(0);
          blob_device_[output].set_device_id(0);
        }
      }
    }
    // Start a DFS over the operator graph from each head blob, handing each
    // root a fresh token.
    for (const auto& input_blob : heads) {
      for (const int op_index : blob_to_ops_[input_blob]) {
        if (!op_visited_[op_index]) {
          vector<std::pair<int, string>> free_blobs;
          std::unordered_set<int> tokens{tokens_counter_++};
          process_op(
              net,
              shareable_blob_names,
              namescope,
              dont_share_blob_names,
              blob_shapes,
              op_index,
              &free_blobs,
              &tokens);
        }
      }
    }
    // Rename mapped blobs to the shared "__mN_shared" names.
    std::unordered_map<string, string> renamed;
    int name_idx = 0;
    std::unordered_set<string> mapped_blobs_set;
    for (const auto& mapped_blob : mapping_) {
      mapped_blobs_set.insert(mapped_blob.second);
      if (has_key(optim_op_outputs_, mapped_blob.second)) {
        if (renamed.find(mapped_blob.second) == renamed.end()) {
          renamed.insert(
              {mapped_blob.second,
               namescope + "__m" + c10::to_string(name_idx++) + "_shared"});
        }
      } else {
        renamed.insert({mapped_blob.second, mapped_blob.second});
      }
    }
    // Recursively rename mapped blobs until the mapping reaches a fixpoint.
    mapping_.insert(renamed.begin(), renamed.end());
    bool had_changes = true;
    while (had_changes) {
      had_changes = false;
      for (const auto mapped_blob : mapping_) {
        if (has_key(renamed, mapped_blob.second) &&
            renamed[mapped_blob.second] != mapped_blob.second) {
          renamed[mapped_blob.first] = renamed[mapped_blob.second];
          mapping_[mapped_blob.first] = renamed[mapped_blob.first];
          had_changes = true;
        }
      }
    }
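    // Why a fixpoint is needed (explanatory note): mapping_ can contain
    // chains such as a -> b and b -> <namescope>__m0_shared; the loop above
    // collapses them so that a maps directly to the shared name.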
    NetDef optimized_net = apply_assignments(net);
    LOG(INFO) << "Remapping " << mapping_.size() << " using "
              << mapped_blobs_set.size() << " shared blobs.";
    if (floats_saved_ > 0) {
      LOG(INFO) << "Memonger saved approximately : "
                << (floats_saved_ * 4.0 / 1024.0 / 1024.0) << " MB.";
    }

    return optimized_net;
  }
 private:
  NetDef apply_assignments(const NetDef& net) {
    NetDef optimized_net = net;
    // Rename the blobs of every op in the net.
    for (int i = 0; i < optimized_net.op_size(); ++i) {
      // Special handling for RNNs, which have internal nets that can have
      // blobs shared inside them.
      if (optimized_net.op(i).type().find("RecurrentNetwork") == 0) {
        apply_recurrent_blob_assignments(optimized_net.mutable_op(i));
      }

      for (int j = 0; j < optimized_net.op(i).input_size(); ++j) {
        const string& input_name =
            get_blob_or_mapped_blob(optimized_net.op(i).input(j));
        optimized_net.mutable_op(i)->set_input(j, input_name);
      }

      for (int j = 0; j < optimized_net.op(i).output_size(); ++j) {
        const string& output_name =
            get_blob_or_mapped_blob(optimized_net.op(i).output(j));
        optimized_net.mutable_op(i)->set_output(j, output_name);
      }
    }
    return optimized_net;
  }
  void apply_recurrent_blob_assignments(OperatorDef* op) {
    // Recursively map the step nets of RecurrentNetwork ops, which may be
    // stored either as an embedded NetDef or as a serialized string.
    for (int i = 0; i < op->arg_size(); i++) {
      Argument* arg = op->mutable_arg(i);
      const string& name = arg->name();
      if (name == "step_net" || name == "backward_step_net") {
        if (arg->has_n()) {
          NetDef* step_net_ref = arg->mutable_n();
          CAFFE_ENFORCE(
              !arg->has_s(),
              "Invalid definition for ",
              name,
              ". Only one of NetDef and string should be present");
          NetDef optimized_net = apply_assignments(*step_net_ref);
          step_net_ref->CopyFrom(optimized_net);
        } else {
          NetDef step_net;
          CAFFE_ENFORCE(
              TextFormat::ParseFromString(arg->s(), &step_net),
              "Could not parse step net:",
              name);
          step_net = apply_assignments(step_net);
          arg->set_s(ProtoDebugString(step_net));
        }
      }
    }
    // Store the renamings of the op's own inputs/outputs as ".rename" args.
    vector<string> inputs_outputs(op->input().begin(), op->input().end());
    inputs_outputs.insert(
        inputs_outputs.end(), op->output().begin(), op->output().end());

    for (auto& b : inputs_outputs) {
      string mapped = get_blob_or_mapped_blob(b);
      if (b != mapped) {
        Argument* map_arg = op->add_arg();
        map_arg->set_name(b + ".rename");
        map_arg->set_s(mapped);
      }
    }
  }
  template <typename K, typename V>
  inline bool has_key(const std::unordered_map<K, V>& in_map, const K& key) {
    return in_map.find(key) != in_map.end();
  }

  template <typename K>
  inline bool has_key(const std::unordered_set<K>& in_set, const K& key) {
    return in_set.find(key) != in_set.end();
  }
  void process_op(
      const NetDef& net,
      const std::unordered_set<string>& shareable_blob_names,
      const string& namescope,
      const std::unordered_set<string>& dont_share_blob_names,
      const std::unordered_map<string, vector<int>>& blob_shapes,
      int op_index,
      std::vector<std::pair<int, string>>* free_blobs,
      std::unordered_set<int>* tokens) {
    // The tokens we hold now are the union of the tokens this op already
    // holds and the tokens deposited by its parents.
    tokens->insert(
        op_token_deposit_[op_index].begin(), op_token_deposit_[op_index].end());
    op_token_deposit_[op_index].clear();
    CAFFE_ENFORCE(!op_visited_[op_index]);
    op_visited_[op_index] = true;
    const OperatorDef& current_op = net.op(op_index);

    // The set of input blobs freed by processing the current op.
    std::vector<std::pair<int, string>> new_free_blobs;
    std::unordered_set<string> new_free_blobs_set;

    // Update blob tokens: every blob touched here requires all of the tokens
    // currently held.
    for (const auto& input : current_op.input()) {
      const auto& actual_blob = get_blob_or_mapped_blob(input);
      req_tokens_[actual_blob].insert(tokens->begin(), tokens->end());
      if (actual_blob != input) {
        req_tokens_[input].insert(tokens->begin(), tokens->end());
      }
    }
    for (const auto& output : current_op.output()) {
      const auto& actual_blob = get_blob_or_mapped_blob(output);
      req_tokens_[actual_blob].insert(tokens->begin(), tokens->end());
      if (actual_blob != output) {
        req_tokens_[output].insert(tokens->begin(), tokens->end());
      }
    }
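    // Token intuition (explanatory note): each DFS root, and each branch of
    // an op with fan-out, mints a fresh token. Every blob touched records
    // the tokens live at that moment in req_tokens_, and can_use_blob()
    // only permits reuse when the current traversal holds all of a blob's
    // tokens. This stops two parallel branches from recycling the same
    // buffer.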
    // Increment per-blob input counts and free inputs whose consumers have
    // all been seen.
    for (const auto& input : current_op.input()) {
      if (has_key(shareable_blob_names, input)) {
        blob_input_count_[input]++;
        if (blob_input_count_[input] == (int)blob_to_ops_[input].size()) {
          const string& actual_blob = get_blob_or_mapped_blob(input);
          if (!has_key(dont_share_blob_names, actual_blob)) {
            new_free_blobs.emplace_back(
                -share_counts_[actual_blob], actual_blob);
            new_free_blobs_set.insert(actual_blob);
          }
        }
      }
    }
    // Check if a free blob can be recycled as an output of this op.
    for (const auto& output : current_op.output()) {
      if (has_key(shareable_blob_names, output) &&
          !has_key(processed_output_blobs_, output) &&
          !has_key(new_free_blobs_set, output)) {
        const string freed_blob = get_free_blob(
            output, blob_shapes, tokens, free_blobs, blob_device_[output]);
        if (freed_blob != "") {
          req_tokens_[freed_blob].insert(tokens->begin(), tokens->end());
          share_counts_[freed_blob]++;
          mapping_[output] = freed_blob;
        }
        processed_output_blobs_.insert(output);
      }
    }
    // Insert the newly freed blobs into the free-blob heap.
    std::unordered_set<string> free_blob_set;
    for (const auto& free_blob : *free_blobs) {
      free_blob_set.insert(free_blob.second);
    }
    for (const auto& new_free_blob : new_free_blobs) {
      if (!has_key(free_blob_set, new_free_blob.second)) {
        free_blobs->push_back(new_free_blob);
        if (blob_shapes.size() > 0) {
          if (!has_key(blob_sizes_, new_free_blob.second)) {
            blob_sizes_.insert(
                {new_free_blob.second,
                 infer_blob_size(new_free_blob.second, blob_shapes)});
          }
        }
        std::push_heap(
            free_blobs->begin(),
            free_blobs->end(),
            std::greater<std::pair<int, string>>());
      }
    }
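    // Design note: free_blobs is a min-heap over std::pair<int, string>
    // (via std::greater), and entries carry a negated share count, so the
    // blob that has been shared most often sorts first and reuse
    // concentrates on a small set of shared buffers.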
    // Count the total fan-out of this op's outputs.
    int num_branches = 0;
    for (const auto& output : current_op.output()) {
      num_branches += blob_to_ops_[output].size();
    }

    for (const auto& output : current_op.output()) {
      for (const auto& input_op_index : blob_to_ops_[output]) {
        op_visited_count_[input_op_index]++;
        if (op_visited_count_[input_op_index] == op_inputs_[input_op_index]) {
          std::unordered_set<int> new_tokens;
          new_tokens.insert(tokens->begin(), tokens->end());
          if (num_branches > 1) {
            new_tokens.insert(tokens_counter_++);
          }
          process_op(
              net,
              shareable_blob_names,
              namescope,
              dont_share_blob_names,
              blob_shapes,
              input_op_index,
              free_blobs,
              &new_tokens);
        } else {
          if (!op_visited_[input_op_index]) {
            op_token_deposit_[input_op_index].insert(
                tokens->begin(), tokens->end());
          }
        }
      }
    }
  }
  inline int infer_blob_size(
      const string& blob_name,
      const std::unordered_map<string, vector<int>>& blob_shapes) {
    const auto& blob_shapes_iter = blob_shapes.find(blob_name);
    if (blob_shapes_iter == blob_shapes.end()) {
      return 0;
    }
    int size = 1;
    for (size_t i = 0; i < blob_shapes_iter->second.size(); ++i) {
      size *= blob_shapes_iter->second[i];
    }
    return size;
  }
  inline string get_blob_or_mapped_blob(const string& blob_name) {
    auto mapped_blob = mapping_.find(blob_name);
    if (mapped_blob == mapping_.end()) {
      return blob_name;
    } else {
      return mapped_blob->second;
    }
  }
  // Returns true if the blob is on the right device and the current
  // traversal holds all tokens the blob requires.
  inline bool can_use_blob(
      const string& blob_name,
      std::unordered_set<int>* tokens,
      const DeviceOption& device_option) {
    const DeviceOption& blob_device = blob_device_[blob_name];
    if (device_option.device_type() != blob_device.device_type() ||
        device_option.device_id() != blob_device.device_id()) {
      return false;
    }
    for (const int token : req_tokens_[blob_name]) {
      if (tokens->find(token) == tokens->end()) {
        return false;
      }
    }
    return true;
  }
  // Returns the name of the blob that blob_name will be mapped onto, or ""
  // if no free blob can be used.
  inline string get_free_blob(
      const string& blob_name,
      const std::unordered_map<string, vector<int>>& blob_shapes,
      std::unordered_set<int>* tokens,
      std::vector<std::pair<int, string>>* free_blobs,
      const DeviceOption& device) {
    string freed_blob = "";
    if (blob_shapes.size() == 0) {
      // No shape information: pop candidates off the heap until one is
      // usable, then restore the rejected ones.
      std::vector<std::pair<int, string>> cant_use_blobs;
      while (free_blobs->size() > 0) {
        std::pop_heap(
            free_blobs->begin(),
            free_blobs->end(),
            std::greater<std::pair<int, string>>());
        const auto cand_free_blob = free_blobs->back();
        free_blobs->pop_back();
        if (can_use_blob(cand_free_blob.second, tokens, device)) {
          freed_blob = cand_free_blob.second;
          break;
        } else {
          cant_use_blobs.push_back(cand_free_blob);
        }
      }
      for (const auto& cant_use_blob : cant_use_blobs) {
        free_blobs->push_back(cant_use_blob);
        std::push_heap(
            free_blobs->begin(),
            free_blobs->end(),
            std::greater<std::pair<int, string>>());
      }
    } else {
      // Heuristic: grow the candidate until it covers blob_size, then stop
      // upgrading, so a blob close to the requested size is chosen.
      const int blob_size = infer_blob_size(blob_name, blob_shapes);
      int best_size = -1;
      int free_blob_index = -1;
      for (size_t i = 0; i < free_blobs->size(); ++i) {
        const string& cb_name = (*free_blobs)[i].second;
        if (can_use_blob(cb_name, tokens, device)) {
          CAFFE_ENFORCE(blob_sizes_.find(cb_name) != blob_sizes_.end());
          const int cand_bz = blob_sizes_[cb_name];
          if (cand_bz >= best_size) {
            if (best_size < blob_size || best_size >= cand_bz) {
              best_size = cand_bz;
              free_blob_index = i;
            }
          }
        }
      }
      if (free_blob_index != -1) {
        floats_saved_ += best_size;
        freed_blob = (*free_blobs)[free_blob_index].second;
        free_blobs->erase(free_blobs->begin() + free_blob_index);
      }
    }
    return freed_blob;
  }
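  // Worked example of the best-fit scan above (illustrative numbers): with
  // blob_size = 100 and usable candidates scanned in the order 80, 120, 150,
  // the scan takes 80 (not yet covering the request), upgrades to 120, and
  // then rejects 150 because 120 already covers blob_size; only a candidate
  // of equal size could replace it from then on.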
  int tokens_counter_ = 1;
  int floats_saved_ = 0;
  // blob_name -> op edges.
  std::unordered_map<string, std::vector<int>> blob_to_ops_;
  // Current input count per blob.
  std::unordered_map<string, int> blob_input_count_;
  // Op in-degrees.
  std::vector<int> op_inputs_;
  // Current op visit counts.
  std::vector<int> op_visited_count_;
  std::unordered_map<string, int> share_counts_;
  std::unordered_map<string, int> blob_sizes_;
  std::unordered_map<string, std::unordered_set<int>> req_tokens_;
  std::vector<std::unordered_set<int>> op_token_deposit_;
  std::unordered_set<string> optim_op_outputs_;
  std::unordered_map<string, string> mapping_;
  std::unordered_map<string, DeviceOption> blob_device_;
  // The set of output blobs that have already been processed.
  std::unordered_set<string> processed_output_blobs_;
  std::vector<bool> op_visited_;
};
NetDef compute_blob_recycling_for_dag(
    const NetDef& net,
    const std::vector<string>& heads,
    const std::vector<int>& op_indices,
    const std::unordered_set<string>& shareable_blob_names,
    const string& namescope,
    const std::unordered_set<string>& dont_share_blob_names,
    const std::unordered_map<string, vector<int>>& blob_shapes) {
  ComputeBlobRecyclingForDag memonger(net.op_size());
  return memonger.OptimizeNet(
      net,
      heads,
      op_indices,
      shareable_blob_names,
      namescope,
      dont_share_blob_names,
      blob_shapes);
}

} // namespace memonger
} // namespace caffe2
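// Hypothetical call-site sketch for the DAG memonger (blob names and the
// shareable/dont_share/blob_shapes containers are invented for illustration;
// in practice this entry point is typically reached through the Python
// memonger wrapper rather than called directly):
//
//   std::vector<string> heads = {"data"};
//   std::vector<int> op_indices(net.op_size());
//   std::iota(op_indices.begin(), op_indices.end(), 0);
//   NetDef optimized = caffe2::memonger::compute_blob_recycling_for_dag(
//       net, heads, op_indices, shareable, "", dont_share, blob_shapes);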