Caffe2 - C++ API
A deep learning, cross platform ML framework
remove_expands.cpp
1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2 #include <torch/csrc/jit/passes/remove_expands.h>
3 
4 namespace torch {
5 namespace jit {
6 
7 static void RemoveExpands(Block* block) {
8  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
9  ++it) {
10  for (auto sub : it->blocks())
11  RemoveExpands(sub);
12 
13  if (it->kind() == aten::expand && it->get<bool>(attr::implicit) == true) {
14  it->output()->replaceAllUsesWith(it->namedInput(attr::self));
15  it.destroyCurrent();
16  }
17  }
18 }
19 
20 void RemoveExpands(const std::shared_ptr<Graph>& graph) {
21  RemoveExpands(graph->block());
22 }
23 
24 } // namespace jit
25 } // namespace torch
Definition: jit_type.h:17