Caffe2 - C++ API
A deep learning, cross platform ML framework
canonicalize.cpp
1 #include <torch/csrc/jit/passes/canonicalize.h>
2 
3 namespace torch {
4 namespace jit {
5 
6 // Canonicalize a graph, renumbering it so that all structurally equivalent
7 // graphs have same numbers.
8 // keep_unique_names: If false, canonicalizes unique names by removing them
9 // and replacing them with normal value names.
10 // Otherwise, ignores values with unique names.
11 std::shared_ptr<Graph> Canonicalize(
12  const std::shared_ptr<Graph>& graph,
13  bool keep_unique_names) {
14  auto r = std::make_shared<Graph>(graph->current_scope());
15  std::unordered_map<Value*, Value*> rn_env;
16  auto rn_fn = [&](Value* v) { return rn_env.at(v); };
17  for (auto* input : graph->inputs()) {
18  auto* r_input = r->addInput();
19  r_input->copyMetadata(input);
20  if (!keep_unique_names)
21  r_input->setUniqueName("");
22  rn_env[input] = r_input;
23  }
24  for (auto* node : graph->nodes()) {
25  auto* r_node = r->createClone(node, rn_fn);
26  if (!keep_unique_names) {
27  for (auto* output : r_node->outputs()) {
28  output->setUniqueName("");
29  }
30  }
31  r->appendNode(r_node);
32  auto outputs = node->outputs();
33  auto r_outputs = r_node->outputs();
34  for (size_t i = 0; i < outputs.size(); i++) {
35  rn_env[outputs.at(i)] = r_outputs.at(i);
36  }
37  if (node->hasAttribute(attr::Subgraph)) {
38  r_node->g_(
39  attr::Subgraph,
40  Canonicalize(node->g(attr::Subgraph), keep_unique_names));
41  }
42  }
43  for (auto* output : graph->outputs()) {
44  r->registerOutput(rn_fn(output));
45  }
46 
47  return r;
48 }
49 
50 } // namespace jit
51 } // namespace torch
Definition: jit_type.h:17