Caffe2 - C++ API
A deep learning, cross platform ML framework
canonicalize_ops.cpp
1 #include <torch/csrc/jit/passes/canonicalize_ops.h>
2 #include <torch/csrc/jit/passes/dead_code_elimination.h>
3 #include <torch/csrc/jit/symbolic_variable.h>
4 
5 namespace torch {
6 namespace jit {
7 
8 struct ChunkOutput {
9  ChunkOutput(Value* v, size_t o) : val(v), offset(o){};
10  Value* val;
11  size_t offset;
12 };
13 
14 static c10::optional<std::vector<ChunkOutput>> getChunkOutputs(Node* chunk) {
15  std::vector<ChunkOutput> outputs;
16  for (auto list_use : chunk->output()->uses()) {
17  if (list_use.user->matches(
18  "aten::select(Tensor[] list, int idx) -> Tensor", attr::idx)) {
19  outputs.emplace_back(
20  list_use.user->output(),
21  list_use.user->get<int64_t>(attr::idx).value());
22  } else if (list_use.user->kind() == prim::ListUnpack) {
23  // This sometimes happens if the sizes can't be evenly divided by the
24  // number of chunks
25  if (static_cast<int64_t>(list_use.user->outputs().size()) !=
26  chunk->get<int64_t>(attr::chunks).value()) {
27  return c10::nullopt;
28  }
29  auto unpack_outputs = list_use.user->outputs();
30  for (size_t i = 0; i < unpack_outputs.size(); ++i) {
31  outputs.emplace_back(unpack_outputs[i], i);
32  }
33  } else {
34  return c10::nullopt;
35  }
36  }
37  return outputs;
38 }
39 
40 static void CanonicalizeOps(Block* block) {
41  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
42  ++it) {
43  for (auto sub : it->blocks())
44  CanonicalizeOps(sub);
45  // For the case where we have an addmm where alpha and beta are Attributes
46  // and both of those scalars are equal to 1.0, decompose this into an mm
47  // followed by an add so that it can go through the existing optimization,
48  // shape analysis and differentiation passes for those two individual ops.
49  // Later, we will fuse together those two ops into a single addmm.
50  if (it->matches(
51  "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
52  /*const_inputs=*/{attr::beta, attr::alpha})) {
53  if (it->get<at::Scalar>(attr::alpha)->toDouble() != 1.0 ||
54  it->get<at::Scalar>(attr::beta)->toDouble() != 1.0) {
55  continue;
56  }
57 
58  WithInsertPoint guard(*it);
59 
60  SymbolicVariable mat(it->inputs()[0]);
61  SymbolicVariable mat1(it->inputs()[1]);
62  SymbolicVariable mat2(it->inputs()[2]);
63 
64  auto mm_result = mat1.mm(mat2);
65  // Set this intermediate aten::mm node to have the same output type as the
66  // original aten::addmm otherwise the canonicalized graph will have
67  // DynamicType as the output of this node which is incorrect
68  (static_cast<Value*>(mm_result))->setType(it->output()->type());
69  auto result = mat + mm_result;
70  (static_cast<Value*>(result))->setType(it->output()->type());
71 
72  it->output()->replaceAllUsesWith(result);
73  it.destroyCurrent();
74  } else if (
75  it->matches(
76  "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
77  it->matches(
78  "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
79  it->matches("aten::mul(Tensor self, Tensor other) -> Tensor") ||
80  it->matches("aten::div(Tensor self, Tensor other) -> Tensor")) {
81  if (auto other = it->get<at::Tensor>(attr::other)) {
82  if (other->dim() == 0) {
83  WithInsertPoint insert_guard{*it};
84  auto graph = it->owningGraph();
85  auto new_other = graph->insertConstant(other->item());
86  std::vector<Value*> inputs = it->inputs().vec();
87  inputs.at(1) = new_other;
88  Value* new_output =
89  graph->insertNode(graph->create(it->kind(), inputs))->output();
90  it->output()->replaceAllUsesWith(new_output);
91  }
92  }
93  } else if (it->matches(
94  "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
95  /*const_inputs=*/{attr::chunks, attr::dim})) {
96  if (auto orig_outputs = getChunkOutputs(*it)) {
97  WithInsertPoint guard(*it);
98  SymbolicVariable self{it->namedInput(attr::self)};
99  auto outputs = self.chunk(
100  it->get<int64_t>(attr::chunks).value(),
101  it->get<int64_t>(attr::dim).value());
102  for (ChunkOutput orig_out : *orig_outputs) {
103  orig_out.val->replaceAllUsesWith(outputs.at(orig_out.offset));
104  outputs[orig_out.offset].value()->setType(orig_out.val->type());
105  }
106  }
107  }
108  }
109 }
110 
111 void CanonicalizeOps(const std::shared_ptr<Graph>& graph) {
112  CanonicalizeOps(graph->block());
113  EliminateDeadCode(graph);
114 }
115 
116 } // namespace jit
117 } // namespace torch
Scalar represents a 0-dimensional tensor which contains a single element.
Definition: Scalar.h:22
Definition: jit_type.h:17
A utility class for setting temporary insertion points.
Definition: ir.h:1174