Caffe2 - C++ API
A deep-learning, cross-platform ML framework
constant_pooling.cpp
1 #include <ATen/core/interned_strings.h>
2 #include <torch/csrc/jit/ir.h>
3 #include <torch/csrc/jit/node_hashing.h>
4 #include <torch/csrc/jit/passes/constant_pooling.h>
5 #include <unordered_set>
6 
7 namespace torch {
8 namespace jit {
9 
10 namespace {
11 
12 // Very similar to the common subexpression elimination pass
13 // Move all constants to the beginning of the graph, and deduplicate
14 void ConstantPooling(
15  Block* block,
16  std::unordered_set<Node*, HashNode, EqualNode>& constants) {
17  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
18  auto node = *it;
19  // node may be moved to a different block so advance iterator now
20  ++it;
21  if (!node->blocks().empty()) {
22  // Traverse sub-blocks.
23  for (auto block : node->blocks()) {
24  ConstantPooling(block, constants);
25  }
26  continue;
27  }
28 
29  if (node->kind() != prim::Constant) {
30  continue;
31  }
32 
33  // Check whether the same constant already exists.
34  auto subit = constants.insert(node);
35  if (!subit.second) {
36  // constant exists, replace the uses of node, and destroy it.
37  auto existing = *subit.first;
38  node->replaceAllUsesWith(existing);
39  node->destroy();
40  continue;
41  }
42 
43  // Move the constant definition to the beginning of the graph.
44  auto first_node = node->owningGraph()->block()->nodes().front();
45  if (node != first_node)
46  node->moveBefore(first_node);
47  }
48 }
49 
50 } // anonymous namespace
51 
52 void ConstantPooling(const std::shared_ptr<Graph>& graph) {
53  std::unordered_set<Node*, HashNode, EqualNode> constants;
54  ConstantPooling(graph->block(), constants);
55 }
56 
57 } // namespace jit
58 } // namespace torch
Definition: jit_type.h:17