1 #include <torch/csrc/jit/passes/common_subexpression_elimination.h> 3 #include <torch/csrc/jit/ir.h> 4 #include <torch/csrc/jit/node_hashing.h> 5 #include <torch/csrc/jit/passes/alias_analysis.h> 7 #include <unordered_map> 14 void EliminateCommonSubexpression(
16 const AliasDb& aliasDb,
17 std::function<Node*(Node*)> parent_lookup_fn) {
18 std::unordered_set<Node*, HashNode, EqualNode> subexprs;
19 for (
auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
21 if (node->hasSideEffects() || node->isNondeterministic() ||
22 aliasDb.hasWriters(node)) {
27 if (!node->blocks().empty()) {
29 for (
auto block : node->blocks()) {
30 EliminateCommonSubexpression(block, aliasDb, [&](Node* n) {
31 auto existing = subexprs.find(n);
32 if (existing != subexprs.end()) {
36 return parent_lookup_fn(n);
44 auto parent_lookup = parent_lookup_fn(node);
46 node->replaceAllUsesWith(parent_lookup);
52 auto subit = subexprs.insert(node);
55 auto existing = *subit.first;
56 node->replaceAllUsesWith(existing);
64 void EliminateCommonSubexpression(std::shared_ptr<Graph>& graph) {
65 AliasDb aliasDb(graph);
66 EliminateCommonSubexpression(
67 graph->block(), aliasDb, [](Node*) {
return nullptr; });