1 #include <torch/csrc/jit/passes/dead_code_elimination.h> 2 #include <torch/csrc/jit/passes/remove_expands.h> 7 static void RemoveExpands(Block* block) {
8 for (
auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
10 for (
auto sub : it->blocks())
13 if (it->kind() == aten::expand && it->get<
bool>(attr::implicit) ==
true) {
14 it->output()->replaceAllUsesWith(it->namedInput(attr::self));
20 void RemoveExpands(
const std::shared_ptr<Graph>& graph) {
21 RemoveExpands(graph->block());