1 #include <torch/csrc/jit/passes/canonicalize_ops.h> 2 #include <torch/csrc/jit/passes/dead_code_elimination.h> 3 #include <torch/csrc/jit/symbolic_variable.h> 15 std::vector<ChunkOutput> outputs;
16 for (
auto list_use : chunk->output()->uses()) {
17 if (list_use.user->matches(
18 "aten::select(Tensor[] list, int idx) -> Tensor", attr::idx)) {
20 list_use.user->output(),
21 list_use.user->get<int64_t>(attr::idx).value());
22 }
else if (list_use.user->kind() == prim::ListUnpack) {
25 if (static_cast<int64_t>(list_use.user->outputs().size()) !=
26 chunk->get<int64_t>(attr::chunks).value()) {
29 auto unpack_outputs = list_use.user->outputs();
30 for (
size_t i = 0; i < unpack_outputs.size(); ++i) {
31 outputs.emplace_back(unpack_outputs[i], i);
// Iterates over the nodes of `block` and rewrites certain aten ops in place:
// addmm, add/sub/mul/div with a tensor `other`, and chunk.
// NOTE(review): this listing omits some of the original lines (the embedded
// line numbers skip), so the comments below describe only the visible code.
40 static void CanonicalizeOps(
Block* block) {
// Walk the block's node list from begin() to end().
41 for (
auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
// Loop over the current node's sub-blocks; the loop body is not visible
// here -- presumably it recurses into each sub-block (TODO confirm).
43 for (
auto sub : it->blocks())
// aten::addmm schema, given together with the beta and alpha attribute names.
51 "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
52 {attr::beta, attr::alpha})) {
// Condition holds when alpha or beta is not 1.0; what the branch does is not
// visible in this listing.
53 if (it->get<
at::Scalar>(attr::alpha)->toDouble() != 1.0 ||
54 it->get<
at::Scalar>(attr::beta)->toDouble() != 1.0) {
// Build the product mat1.mm(mat2).
64 auto mm_result = mat1.mm(mat2);
// Give the intermediate mm value the same type as the original node's output.
68 (
static_cast<Value*
>(mm_result))->setType(it->output()->type());
// Add `mat` to the product; this sum replaces the original node's output.
69 auto result = mat + mm_result;
// The sum likewise gets the type of the original node's output.
70 (
static_cast<Value*
>(result))->setType(it->output()->type());
// Redirect every use of the original output to the new result.
72 it->output()->replaceAllUsesWith(result);
// Schemas for add/sub (with Scalar alpha) and mul/div whose second argument
// is a Tensor named `other`.
76 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
78 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
79 it->matches(
"aten::mul(Tensor self, Tensor other) -> Tensor") ||
80 it->matches(
"aten::div(Tensor self, Tensor other) -> Tensor")) {
// Continue only if get<at::Tensor>(attr::other) yields a value.
81 if (
auto other = it->get<
at::Tensor>(attr::other)) {
// Only a 0-dimensional `other` is handled here.
82 if (other->dim() == 0) {
84 auto graph = it->owningGraph();
// Insert a constant into the graph built from the tensor's single item.
85 auto new_other = graph->insertConstant(other->item());
// Copy the node's inputs and substitute the new constant at index 1.
86 std::vector<Value*> inputs = it->inputs().vec();
87 inputs.at(1) = new_other;
// Create and insert a node of the same kind using the modified inputs; its
// output is taken here (the declaration receiving it, `new_output`, is on a
// line missing from this listing).
89 graph->insertNode(graph->create(it->kind(), inputs))->output();
// Redirect every use of the original output to the new node's output.
90 it->output()->replaceAllUsesWith(new_output);
93 }
// aten::chunk schema, given together with the chunks and dim attribute names.
else if (it->matches(
94 "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
95 {attr::chunks, attr::dim})) {
// Continue only if getChunkOutputs produced a value for this node.
96 if (
auto orig_outputs = getChunkOutputs(*it)) {
// Call self.chunk with the node's own chunks and dim attribute values.
99 auto outputs =
self.chunk(
100 it->get<int64_t>(attr::chunks).value(),
101 it->get<int64_t>(attr::dim).value());
// Replace the uses of each original output with the new output at the same
// offset, and give that new output the original value's type.
103 orig_out.val->replaceAllUsesWith(outputs.at(orig_out.offset));
104 outputs[orig_out.offset].value()->setType(orig_out.val->type());
// Graph-level entry point: runs the Block* overload of CanonicalizeOps on the
// block returned by graph->block(), then calls EliminateDeadCode on the graph.
111 void CanonicalizeOps(
const std::shared_ptr<Graph>& graph) {
112 CanonicalizeOps(graph->block());
113 EliminateDeadCode(graph);
Scalar represents a 0-dimensional tensor which contains a single element.
A utility class for setting temporary insertion points.