1 #include <torch/csrc/jit/passes/canonicalize.h> 11 std::shared_ptr<Graph> Canonicalize(
12 const std::shared_ptr<Graph>& graph,
13 bool keep_unique_names) {
14 auto r = std::make_shared<Graph>(graph->current_scope());
15 std::unordered_map<Value*, Value*> rn_env;
16 auto rn_fn = [&](Value* v) {
return rn_env.at(v); };
17 for (
auto* input : graph->inputs()) {
18 auto* r_input = r->addInput();
19 r_input->copyMetadata(input);
20 if (!keep_unique_names)
21 r_input->setUniqueName(
"");
22 rn_env[input] = r_input;
24 for (
auto* node : graph->nodes()) {
25 auto* r_node = r->createClone(node, rn_fn);
26 if (!keep_unique_names) {
27 for (
auto* output : r_node->outputs()) {
28 output->setUniqueName(
"");
31 r->appendNode(r_node);
32 auto outputs = node->outputs();
33 auto r_outputs = r_node->outputs();
34 for (
size_t i = 0; i < outputs.size(); i++) {
35 rn_env[outputs.at(i)] = r_outputs.at(i);
37 if (node->hasAttribute(attr::Subgraph)) {
40 Canonicalize(node->g(attr::Subgraph), keep_unique_names));
43 for (
auto* output : graph->outputs()) {
44 r->registerOutput(rn_fn(output));