1 #include <ATen/core/interned_strings.h> 2 #include <torch/csrc/jit/ir.h> 3 #include <torch/csrc/jit/node_hashing.h> 4 #include <torch/csrc/jit/passes/constant_pooling.h> 5 #include <unordered_set> 16 std::unordered_set<Node*, HashNode, EqualNode>& constants) {
17 for (
auto it = block->nodes().begin(); it != block->nodes().end();) {
21 if (!node->blocks().empty()) {
23 for (
auto block : node->blocks()) {
24 ConstantPooling(block, constants);
29 if (node->kind() != prim::Constant) {
34 auto subit = constants.insert(node);
37 auto existing = *subit.first;
38 node->replaceAllUsesWith(existing);
44 auto first_node = node->owningGraph()->block()->nodes().front();
45 if (node != first_node)
46 node->moveBefore(first_node);
52 void ConstantPooling(
const std::shared_ptr<Graph>& graph) {
53 std::unordered_set<Node*, HashNode, EqualNode> constants;
54 ConstantPooling(graph->block(), constants);