1 #include <torch/csrc/jit/passes/dead_code_elimination.h> 3 #include <torch/csrc/jit/ir_views.h> 4 #include <torch/csrc/jit/passes/alias_analysis.h> 5 #include <torch/csrc/utils/memory.h> 7 #include <unordered_map> 13 using namespace ::c10::prim;
// Member-initializer list of a constructor whose signature is not visible in
// this excerpt: aliasDb_ takes ownership of an AliasDb built from the moved
// `graph`.
19 : aliasDb_(torch::make_unique<AliasDb>(std::move(graph))) {}
// Entry point of the pass for `block`. The visible steps, in order: mark the
// block's return node, pass liveValues_ to the delete callback, then sweep
// the block with `recurse` forwarded unchanged.
// NOTE(review): the excerpt skips lines between these statements, so further
// steps may exist in the full file — confirm before relying on this summary.
25 void run(
Block* block,
bool recurse) {
// Start from the block's return node.
27 mark(block->return_node());
// The callback observes the live-value set before sweep() is invoked.
31 deleteCallback_(liveValues_);
33 sweep(block, recurse);
// Replaces deleteCallback_ with the supplied callable (moved in, not copied).
// The callable receives a const reference to a set of `const Value*`; run()
// calls it with liveValues_.
36 void setDeleteCallback(
37 std::function<
void(
const std::unordered_set<const Value*>&)>
39 deleteCallback_ = std::move(deleteCallback);
// Liveness handling for a block's return node. `node` must be the return node
// of its owning block (asserted below). Inputs of the return node are inserted
// into liveValues_ selectively, depending on the node that owns the block.
// NOTE(review): several lines (branch bodies, the declaration of `loop`, the
// closing braces) are not visible in this excerpt; the comments below describe
// only what the visible statements do.
59 void markReturnNode(
Node* node) {
// Already-processed check against marked_ (the branch body is not shown).
60 if (marked_.count(node)) {
64 AT_ASSERT(node->owningBlock()->return_node() == node);
65 auto outerNode = node->owningBlock()->owningNode();
// Special case: no owning node at all, or an owning node of kind prim::Reverse
// (the branch body is not shown).
66 if (outerNode ==
nullptr || outerNode->kind() == prim::Reverse) {
// Loop case, for both prim::Loop and the ONNX Loop kind.
74 if (outerNode->kind() == prim::Loop ||
75 outerNode->kind() == c10::onnx::Loop) {
// For each carried value i, the body's carried output becomes live when the
// loop's matching outer output is already live or the body's matching carried
// input has uses.
78 for (
size_t i = 0; i < loop.carriedOutputs().size(); i++) {
79 auto innerInput = loop.bodyCarriedInputs().at(i);
80 auto innerOutput = loop.bodyCarriedOutputs().at(i);
81 auto outerOutput = loop.carriedOutputs().at(i);
82 if (liveValues_.count(outerOutput) || innerInput->hasUses()) {
83 liveValues_.insert(innerOutput);
// The loop's next-iteration condition is inserted into the live set here
// without a visible guard.
89 liveValues_.insert(loop.nextCond());
// Presumably the non-loop path (the enclosing `else` is not visible): outputs
// of the owning node and inputs of the return node are paired by index, so
// their counts must match.
91 AT_ASSERT(outerNode->outputs().size() == node->inputs().size());
92 for (
size_t i = 0; i < outerNode->outputs().size(); i++) {
93 auto innerOutput = node->inputs()[i];
94 auto outerOutput = outerNode->outputs()[i];
// An inner value is live only if the outer output at the same index is live.
95 if (liveValues_.count(outerOutput)) {
96 liveValues_.insert(innerOutput);
// Record the return node as processed.
101 marked_.insert(node);
// Block-level marking. Visible structure: a forward scan over the block's
// nodes testing hasSideEffects(), a call to markReturnNode() on the block's
// return node, then a reverse scan over the nodes that also visits each
// node's sub-blocks.
// NOTE(review): the loop bodies are not visible in this excerpt, so what is
// done with side-effecting nodes and sub-blocks must be confirmed in the
// full file.
104 void mark(
Block* block) {
106 for (
auto node : block->nodes()) {
107 if (hasSideEffects(node)) {
113 markReturnNode(block->return_node());
// Nodes are walked from last to first here.
115 for (
auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
117 for (
auto subBlock : node->blocks()) {
// Tests two liveness conditions for `node`: whether any of its outputs is in
// liveValues_, and whether aliasDb_ reports that the node writes to an alias
// of a live value.
// NOTE(review): the action taken in each branch and the return type's meaning
// are not visible in this excerpt; presumably the node gets marked — confirm.
125 void markIfLive(
Node* node) {
126 for (
const auto output : node->outputs()) {
127 if (liveValues_.count(output)) {
// Third argument is passed as `false`; its meaning is defined by
// AliasDb::writesToAlias and is not visible here.
133 if (aliasDb_->writesToAlias(node, liveValues_,
false)) {
// Node-level marking: records `node` in marked_ and inserts its inputs into
// liveValues_.
// NOTE(review): the declaration of `curNode`, the enclosing loop and the
// branch bodies are not visible in this excerpt.
141 void mark(
Node* node) {
// Already-marked check (branch body not shown).
142 if (marked_.count(node)) {
146 marked_.insert(node);
// Walk outward through enclosing nodes: each step moves curNode to the node
// owning its block, with a check for a node that has no owning block.
152 if (!curNode->owningBlock()) {
157 curNode = curNode->owningBlock()->owningNode();
160 for (
const auto input : node->inputs()) {
// Inputs already in the live set are tested first (branch body not shown).
161 if (liveValues_.count(input)) {
164 liveValues_.insert(input);
// Sweep phase over `block`, iterating its nodes in reverse order. For each
// node it calls removeDeadBlockOutputs() and removeDeadLoopOutputs(), visits
// the node's sub-blocks, and finally tests whether the node is neither in
// marked_ nor has uses.
// NOTE(review): how `recurse` is consulted and what happens to an unmarked,
// unused node are not visible in this excerpt (presumably it is destroyed) —
// confirm against the full file.
169 void sweep(
Block* block,
bool recurse) {
170 auto nodes = block->nodes().reverse();
171 for (
auto it = nodes.begin(); it != nodes.end(); it++) {
175 removeDeadBlockOutputs(node);
176 removeDeadLoopOutputs(node);
// Note: this loop variable is also named `block` and shadows the parameter
// inside the loop body.
178 for (
Block* block : node->blocks()) {
// True when the node is unmarked and none of its outputs are used.
186 if (!(marked_.count(node) || node->hasUses())) {
// Reports whether `node` may mutate state in a way this pass has to treat
// conservatively. Two visible return paths: one answers from the node's
// schema (present and is_mutable()), the other delegates to
// aliasDb_->hasUntrackedEffects(node).
// NOTE(review): the condition choosing between the two paths is not visible
// in this excerpt (presumably whether aliasDb_ is set) — confirm.
192 bool hasUntrackedMutation(
Node* node) {
// Nodes whose kind is in neither the aten nor the prim namespace are handled
// separately (branch body not shown).
196 if (!node->kind().is_aten() && !node->kind().is_prim()) {
// maybeSchema() may yield a null schema, hence the `schema &&` guard.
202 auto schema = node->maybeSchema();
203 return schema && schema->is_mutable();
205 return aliasDb_->hasUntrackedEffects(node);
// Memoized side-effect query for `node`. The result is an OR of three parts:
// the node's own hasSideEffects(), a recursive check over the nodes of its
// sub-blocks, and hasUntrackedMutation(node). The computed value is stored in
// memo_ keyed by the node and returned.
209 bool hasSideEffects(
Node* node) {
// Memo lookup; the statement executed on a hit is not visible in this excerpt
// (presumably returns the cached value).
210 auto it = memo_.find(node);
211 if (it != memo_.end())
213 bool has_side_effects = node->hasSideEffects() ||
214 std::any_of(node->blocks().begin(),
215 node->blocks().end(),
// Inner predicate: recurse into every node `n` of sub-block `b`.
218 b->nodes().begin(), b->nodes().end(), [&](
Node* n) {
219 return hasSideEffects(n);
222 hasUntrackedMutation(node);
// Cache the answer for later queries on the same node.
224 memo_.emplace(node, has_side_effects);
225 return has_side_effects;
// Erases outputs of `node` that have no uses. The kind check below singles
// out nodes that are neither prim::If nor prim::GradOf (branch body not
// shown; presumably an early return).
// NOTE(review): the declaration of `i` and the body of the per-block loop are
// not visible in this excerpt; `i` is presumably derived from `i_1`.
228 void removeDeadBlockOutputs(
Node* node) {
229 if (node->kind() != prim::If && node->kind() != prim::GradOf) {
// Counts i_1 down from outputs().size() to 1, so the unsigned counter never
// passes below zero.
233 for (
size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
235 if (!node->outputs().at(i)->hasUses()) {
236 node->eraseOutput(i);
// After erasing the node output, each of the node's blocks is visited.
237 for (
Block* b : node->blocks()) {
// Removes dead loop-carried values from a prim::Loop node. A carried value is
// dropped when the node's output has no uses and the matching body input has
// no uses either; its four matching slots are then erased together.
// NOTE(review): the declaration of `i`, the value of loop_body_offset and the
// statement taken for non-Loop nodes are not visible in this excerpt.
244 void removeDeadLoopOutputs(
Node* node) {
245 if (node->kind() != prim::Loop)
// The first block of the node is used as the loop body.
247 auto loop_body = node->blocks().at(0);
// Node inputs are indexed with an offset of 2 relative to node outputs.
248 auto loop_input_offset = 2;
249 auto loop_body_offset =
// Counts i_1 down from outputs().size() to 1.
252 for (
size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
254 if (!node->outputs().at(i)->hasUses() &&
255 !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
// Erase the node output, the node input, and the body's input and output at
// the matching offsets.
256 node->eraseOutput(i);
257 node->removeInput(loop_input_offset + i);
258 loop_body->eraseInput(loop_body_offset + i);
259 loop_body->eraseOutput(loop_body_offset + i);
// Owned alias database; defaults to null and is set by the constructor that
// takes a graph.
264 std::unique_ptr<AliasDb> aliasDb_ =
nullptr;
// Cache of hasSideEffects() results, keyed by node.
265 std::unordered_map<Node*, bool> memo_;
// Nodes that have already been marked.
266 std::unordered_set<Node*> marked_;
// Values currently considered live.
267 std::unordered_set<const Value*> liveValues_;
// Callback invoked by run() with liveValues_; the default is a no-op lambda.
268 std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
269 [](
const std::unordered_set<const Value*>&) {};
// Overload taking a whole graph by shared pointer.
// NOTE(review): the body is not visible in this excerpt.
272 void EliminateDeadCode(
const std::shared_ptr<Graph>& graph) {
// Overload taking a single block plus a `recurse` flag.
// NOTE(review): the body is not visible in this excerpt.
276 void EliminateDeadCode(
Block* block,
bool recurse) {
// Overload taking a callback `cb` that receives a set of `const Value*`. It
// installs `cb` as the eliminator's delete callback, then runs the eliminator
// on `block` with recursion enabled.
// NOTE(review): the first parameter and the construction of `eliminator` are
// not visible in this excerpt.
280 void EliminateDeadCode(
282 std::function<
void(
const std::unordered_set<const Value*>&)> cb) {
284 eliminator.setDeleteCallback(std::move(cb));
285 eliminator.run(block,
true);