#include <torch/csrc/jit/passes/peephole.h>

#include <torch/csrc/jit/symbolic_variable.h>

#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch {
namespace jit {

// Applies small, local ("peephole") rewrites to every node in `block`,
// recursing into sub-blocks. `addmm_fusion_enabled` gates the add+mm ->
// addmm rewrite, which is mainly useful for ONNX export.
void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    auto* node = *it;

    // Optimize the node's sub-blocks first.
    for (Block* sub_block : node->blocks()) {
      PeepholeOptimizeImpl(sub_block, addmm_fusion_enabled);
    }
37 "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
40 if (
auto input_type = node->namedInput(attr::self)
42 ->cast<CompleteTensorType>()) {
43 auto expanded_sizes = node->get<std::vector<int64_t>>(attr::size);
44 if (expanded_sizes == input_type->sizes()) {
45 node->output()->replaceAllUsesWith(node->namedInput(attr::self));
48 }
    } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
      // x.t().t() == x
      Node* input_node = node->input()->node();
      if (input_node->matches("aten::t(Tensor self) -> Tensor")) {
        node->output()->replaceAllUsesWith(input_node->input());
      }
    } else if (node->matches(
                   "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
      auto self_type = node->input(0)->type()->cast<DimensionedTensorType>();
      auto other_type = node->input(1)->type()->cast<DimensionedTensorType>();
      if (self_type && other_type &&
          self_type->scalarType() == other_type->scalarType() &&
          self_type->device() == other_type->device()) {
        node->output()->replaceAllUsesWith(node->input(0));
      }
66 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
80 if (addmm_fusion_enabled &&
81 node->get<
at::Scalar>(attr::alpha).value().toDouble() == 1.) {
83 for (
size_t mm_side = 0; mm_side < 2; mm_side++) {
91 node->input(1 - mm_side)->type()->cast<DimensionedTensorType>();
95 if (node->input(mm_side)->node()->matches(
96 "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
97 WithInsertPoint guard(node);
99 auto mm_node = node->input(mm_side)->node();
100 SymbolicVariable add_mat(node->input(1 - mm_side));
101 SymbolicVariable mat1(mm_node->input(0));
102 SymbolicVariable mat2(mm_node->input(1));
104 auto mat_type = mat1.value()->type()->cast<DimensionedTensorType>();
106 mat_type = mat2.value()->type()->cast<DimensionedTensorType>();
112 if (add_mat_type->dim() == 0 &&
114 add_mat_type->scalarType() != mat_type->scalarType())) {
115 add_mat = add_mat.type_as(mat1);
118 SymbolicVariable addmm_value = add_mat.addmm(mat1, mat2);
121 ((Value*)addmm_value)->copyMetadata(node->output());
122 node->output()->replaceAllUsesWith(addmm_value);
130 "aten::mul(Tensor self, Scalar other) -> Tensor",
133 "aten::div(Tensor self, Scalar other) -> Tensor",
136 if (node->get<
at::Scalar>(attr::other)->toDouble() == 1) {
137 node->output()->replaceAllUsesWith(node->input(0));
141 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
142 {attr::alpha, attr::other}) ||
144 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
145 {attr::alpha, attr::other})) {
147 if (node->get<
at::Scalar>(attr::alpha)->toDouble() == 1 &&
148 node->get<
at::Scalar>(attr::other)->toDouble() == 0) {
149 node->output()->replaceAllUsesWith(node->input(0));
    } else if (
        node->kind() == prim::Float || node->kind() == prim::Int ||
        node->kind() == prim::ImplicitTensorToNum) {
      Node* input_node = node->input()->node();
      if (input_node->kind() == prim::NumToTensor) {
        node->output()->replaceAllUsesWith(input_node->input());
      }
160 "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
161 auto uses = node->output()->uses();
164 "aten::_grad_sum_to_size(Tensor(a) self, int[] size) -> Tensor(a)")) {
165 u.user->replaceInput(0, node->inputs().at(0));

void PeepholeOptimize(Block* block, bool addmm_fusion_enabled) {
  PeepholeOptimizeImpl(block, addmm_fusion_enabled);
  // Eliminate dead code created by the rewrites above.
  EliminateDeadCode(block);
}

void PeepholeOptimize(
    const std::shared_ptr<Graph>& graph,
    bool addmm_fusion_enabled) {
  PeepholeOptimize(graph->block(), addmm_fusion_enabled);
}

} // namespace jit
} // namespace torch
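
For context, a minimal usage sketch (not part of the original file): assuming the era-appropriate headers and the Graph builder API (Graph::addInput, Graph::insert, Graph::registerOutput), it constructs an x.t().t() graph and runs PeepholeOptimize so the double transpose folds back to x and the dead aten::t nodes are removed by the EliminateDeadCode call inside the pass.

#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/passes/peephole.h>

int main() {
  using namespace torch::jit;
  auto graph = std::make_shared<Graph>();
  Value* x = graph->addInput();

  // Build x.t().t(); Graph::insert performs schema matching for the aten op.
  Value* t1 = graph->insert(c10::Symbol::fromQualString("aten::t"), {x});
  Value* t2 = graph->insert(c10::Symbol::fromQualString("aten::t"), {t1});
  graph->registerOutput(t2);

  // After the pass, the graph should simply return its input.
  PeepholeOptimize(graph, /*addmm_fusion_enabled=*/false);
  graph->dump();
  return 0;
}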
Scalar represents a 0-dimensional tensor which contains a single element.
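
A small illustration of that note, using only public ATen calls (at::ones, Tensor::item): a 0-dimensional tensor has no dimensions but exactly one element, and item() extracts that element as the at::Scalar type the pass reads for constant arguments such as alpha and other.

#include <ATen/ATen.h>

int main() {
  // A 0-dimensional tensor: dim() == 0, but it still holds one element.
  at::Tensor t = at::ones({});
  AT_ASSERT(t.dim() == 0);
  AT_ASSERT(t.numel() == 1);

  // item() pulls that single element out as an at::Scalar.
  at::Scalar s = t.item();
  AT_ASSERT(s.toDouble() == 1.0);
  return 0;
}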