#include <torch/csrc/jit/passes/batch_mm.h>

#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/symbolic_variable.h>

#include <ATen/ATen.h>

#include <algorithm>
#include <unordered_map>
#include <unordered_set>

namespace torch {
namespace jit {

// Smallest number of mm leaves in an add/mm tree for which fusion is applied.
static constexpr size_t min_fusion_size = 4;
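// Illustrative sketch (hypothetical IR, not exact syntax) of the tree-reduce
// rewrite this file implements: a tree of adds whose leaves are mms, e.g.
//
//   %r = add(add(mm(%a1, %b1), mm(%a2, %b2)),
//            add(mm(%a3, %b3), mm(%a4, %b4)))
//
// is collapsed into two concats and a single wide-by-tall matrix multiply:
//
//   %A = cat([%a1, %a2, %a3, %a4], dim=1)
//   %B = cat([%b1, %b2, %b3, %b4], dim=0)
//   %r = mm(%A, %B)
//
// min_fusion_size above is the smallest number of mm leaves for which this
// fusion is attempted.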
// True iff every tensor in `inputs` has the same sizes as the first one.
bool have_same_shape(at::TensorList inputs) {
  auto expected_sizes = inputs[0].sizes();
  return std::all_of(
      inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
        return t.sizes() == expected_sizes;
      });
}
// Empirical heuristic for when a single fused GEMM beats the individual mms.
bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
  size_t l = lhs.size(0);
  size_t m = lhs.size(1);
  size_t r = rhs.size(1);
  return m < 512 || ((l < 256 && r < 256) || (l > 256 && r > 256));
}
RegisterOperators mm_tree_reduction_reg(
    {Operator(prim::MMTreeReduce, [](const Node* node) {
      size_t num_inputs = node->inputs().size();
      return [num_inputs](Stack& stack) {
        std::vector<at::Tensor> inputs;
        inputs.reserve(num_inputs);
        for (auto it = stack.end() - num_inputs; it != stack.end(); ++it) {
          inputs.push_back(std::move(*it).toTensor());
        }
        drop(stack, num_inputs);
        AT_ASSERT(inputs.size() > 0);
        AT_ASSERT(inputs.size() % 2 == 0);
        size_t side_num_elems = inputs.size() / 2;
        // The first half of the inputs are lhs operands, the second half rhs.
        auto lhs_inputs = at::TensorList(inputs).slice(0, side_num_elems);
        auto rhs_inputs = at::TensorList(inputs).slice(side_num_elems);
        if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
            shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
          // Fast path: a single wide-by-tall GEMM.
          auto lhs = at::cat(lhs_inputs, /*dim=*/1);
          auto rhs = at::cat(rhs_inputs, /*dim=*/0);
          push(stack, at::mm(lhs, rhs));
        } else {
          // Slow path: perform the mms separately and accumulate.
          auto acc = at::mm(inputs[0], inputs[side_num_elems]);
          for (size_t i = 1; i < side_num_elems; ++i) {
            acc.add_(at::mm(inputs[i], inputs[side_num_elems + i]));
          }
          push(stack, std::move(acc));
        }
        return 0;
      };
    })});
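// Why the fast path above is valid: with lhs blocks A_i of shape (l, m_i) and
// rhs blocks B_i of shape (m_i, r), block matrix multiplication gives
//
//   cat([A_1, ..., A_k], dim=1) @ cat([B_1, ..., B_k], dim=0)
//       == A_1 @ B_1 + ... + A_k @ B_k
//
// so one large GEMM replaces k small GEMMs and k - 1 adds.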
// TreeTokens label nodes of the graph that fit the mm/add tree pattern. This
// is dynamic programming over a DAG: when a node N with inputs A and B is
// visited, A and B have already been processed, so their tokens (if any) can
// be merged into a larger tree.
struct TreeToken {
  uint64_t tree_size = 0; // Number of mm leaves in this tree.
  Node* node = nullptr;
  bool is_root = false;

  static TreeToken mm(Node* mm) {
    TreeToken token;
    token.tree_size = 1;
    token.node = mm;
    token.is_root = true;
    return token;
  }
  // NB: the returned token might be invalid, so check its boolean value!
  static TreeToken transpose(Node* t, TreeToken& inp_token) {
    TreeToken token;
    if (!inp_token.node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      return token;
    }
    token.tree_size = 1;
    token.node = t;
    token.is_root = true;
    inp_token.is_root = false; // The transpose subsumes its input tree.
    return token;
  }
  // NB: the returned token might be invalid, so check its boolean value!
  static TreeToken add(Node* add, TreeToken& l, TreeToken& r) {
    TreeToken token;
    if (&l == &r || !l.is_root || !r.is_root)
      return token;
    token.tree_size = l.tree_size + r.tree_size;
    token.node = add;
    token.is_root = true;
    l.is_root = r.is_root = false; // Consume the subtrees; no reuse allowed.
    return token;
  }

  explicit operator bool() {
    return is_root;
  }
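  // Worked example of the merging rules above (hypothetical nodes): two mm
  // leaves start as roots with tree_size == 1; add()-ing their tokens yields
  // a root with tree_size == 2 and demotes both children, so every mm leaf
  // is counted exactly once and only whole trees are ever fused.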
  std::vector<Node*> removeTransposesAndGatherMatmuls() {
    std::vector<Node*> matmuls;
    std::vector<Node*> queue{node};
    Graph* graph = node->owningGraph();
    // DFS over the add/t/mm tree rooted at this token.
    while (!queue.empty()) {
      auto n = queue.back();
      queue.pop_back();
      if (n->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
        matmuls.push_back(n);
      } else if (n->matches("aten::t(Tensor self) -> Tensor")) {
        Node* input_node = n->input()->node();
        AT_ASSERT(input_node->matches(
            "aten::mm(Tensor self, Tensor mat2) -> Tensor"));
        // Rewrite the transpose using (A @ B)^T == B^T @ A^T.
        WithInsertPoint insert_guard{input_node};
        Value* A = input_node->inputs()[0];
        Value* B = input_node->inputs()[1];
        Value* AT = graph->insert(aten::t, {A});
        Value* BT = graph->insert(aten::t, {B});
        Value* BTAT = graph->insert(aten::mm, {BT, AT});
        n->output()->replaceAllUsesWith(BTAT);
        matmuls.push_back(BTAT->node());
      } else if (n->matches(
                     "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
        queue.push_back(n->inputs()[0]->node());
        queue.push_back(n->inputs()[1]->node());
      } else {
        AT_ASSERTM(false, "Unsupported node found in a BatchMM tree!");
      }
    }
    return matmuls;
  }
};
enum class Side { LHS, RHS };

void BatchMMTreeReduce(Block* block) {
  auto graph = block->owningGraph();

  // Look for trees in the block.
  std::unordered_map<Node*, TreeToken> tokens;
  for (auto node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      tokens[node] = TreeToken::mm(node);
    } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
      auto input_it = tokens.find(node->input()->node());
      if (input_it != tokens.end()) {
        tokens[node] = TreeToken::transpose(node, input_it->second);
      }
    } else if (node->matches(
                   "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
      Node* lhs = node->inputs()[0]->node();
      Node* rhs = node->inputs()[1]->node();
      auto lhs_it = tokens.find(lhs);
      auto rhs_it = tokens.find(rhs);
      // Requiring single-use outputs keeps subtrees from overlapping and
      // guarantees neither operand depends on the result of the other.
      if (lhs_it != tokens.end() && rhs_it != tokens.end() &&
          lhs->output()->uses().size() == 1 &&
          rhs->output()->uses().size() == 1) {
        if (auto token = TreeToken::add(node, lhs_it->second, rhs_it->second)) {
          tokens[node] = token;
        }
      }
    } else {
      for (auto block : node->blocks()) {
        BatchMMTreeReduce(block);
      }
    }
  }
  // Fuse the trees we've found.
  for (auto& item : tokens) {
    auto& root = item.second;
    if (!root || root.tree_size < min_fusion_size)
      continue;
    auto matmuls = root.removeTransposesAndGatherMatmuls();
    WithInsertPoint insert_guard{root.node};
    Node* tree_reduce =
        graph->insertNode(graph->create(Symbol::prim("MMTreeReduce")));
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(0));
    }
    for (Node* matmul : matmuls) {
      tree_reduce->addInput(matmul->inputs().at(1));
    }
    root.node->output()->replaceAllUsesWith(tree_reduce->output());
    // The now-dead tree nodes are cleaned up later by DCE.
  }
}
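// After the loop above, a fused tree appears in the IR roughly as (sketch,
// not exact IR syntax):
//
//   %r = prim::MMTreeReduce(%a1, ..., %ak, %b1, ..., %bk)
//
// with all lhs operands first and all rhs operands second, which is exactly
// the layout the MMTreeReduce kernel registered above expects.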
// Empirical cutoff for batching many mms that share one operand.
bool shape_is_fast_for_side(const at::Tensor& other_side_input) {
  return other_side_input.numel() <= 1024 * 2048;
}
RegisterOperators mm_batch_side_reg(
    {Operator(prim::MMBatchSide, [](const Node* node) {
      size_t num_other_side_inputs = node->inputs().size() - 1;
      Side single_side = static_cast<Side>(node->i(Symbol::attr("side")));
      return [num_other_side_inputs, single_side](Stack& stack) {
        at::Tensor side_input;
        std::vector<at::Tensor> other_side_inputs;
        other_side_inputs.reserve(num_other_side_inputs);
        for (auto it = stack.end() - num_other_side_inputs; it != stack.end();
             ++it) {
          other_side_inputs.push_back(std::move(*it).toTensor());
        }
        drop(stack, num_other_side_inputs);
        pop(stack, side_input);
        auto any_other_input = other_side_inputs[0];
        if (have_same_shape(other_side_inputs) &&
            shape_is_fast_for_side(any_other_input)) {
          // Fast path: concatenate the varying side, do one GEMM, then chunk.
          auto other_side_input =
              at::cat(other_side_inputs, single_side == Side::LHS ? 1 : 0);
          auto mm_out = single_side == Side::LHS
              ? side_input.mm(other_side_input)
              : other_side_input.mm(side_input);
          auto outputs = at::chunk(
              mm_out,
              num_other_side_inputs,
              /*dim=*/single_side == Side::LHS ? 1 : 0);
          stack.insert(
              stack.end(),
              std::make_move_iterator(outputs.begin()),
              std::make_move_iterator(outputs.end()));
        } else {
          // Slow path: one mm per input.
          if (single_side == Side::LHS) {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(side_input.mm(other));
            }
          } else {
            for (at::Tensor& other : other_side_inputs) {
              stack.emplace_back(other.mm(side_input));
            }
          }
        }
        return 0;
      };
    })});
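// Why the MMBatchSide fast path is valid: for a shared lhs X and varying rhs
// inputs B_1..B_k,
//
//   X @ cat([B_1, ..., B_k], dim=1) == cat([X @ B_1, ..., X @ B_k], dim=1)
//
// so the k individual products are recovered by chunking the single wide
// result along dim 1 (and symmetrically along dim 0 for a shared rhs).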
std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
    Value* value,
    AliasDb& alias_db) {
  const auto postprocess = [&](std::vector<Node*> mms) {
    if (mms.size() == 0) {
      return mms;
    }
    std::sort(mms.begin(), mms.end(), [](Node* n, Node* m) {
      return n->isBefore(m);
    });
    // Filter out mms that depend on each other, keeping the earliest ones.
    for (size_t i = 0; i < mms.size(); ++i) {
      if (mms[i] == nullptr)
        continue;
      for (size_t j = i + 1; j < mms.size(); ++j) {
        if (mms[j] == nullptr)
          continue;
        if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
          mms[j] = nullptr;
        }
      }
    }
    return c10::filter(mms, [](Node* n) { return n != nullptr; });
  };
  Block* block = value->node()->owningBlock();
  std::vector<Node*> lhses; // mms that use value as their lhs
  std::vector<Node*> rhses; // mms that use value as their rhs
  for (Use u : value->uses()) {
    if (u.user->owningBlock() == block &&
        u.user->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      // Skip self-products mm(v, v); they have no single shared side.
      if (u.offset == 0 && u.user->inputs()[1] != value) {
        lhses.push_back(u.user);
      } else if (u.offset == 1 && u.user->inputs()[0] != value) {
        rhses.push_back(u.user);
      }
    }
  }
  return std::make_pair(postprocess(lhses), postprocess(rhses));
}
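// Example of the independence filter above (hypothetical nodes): if mms
// arrives as [m1, m2, m3] in topological order and m3 consumes the output of
// m1, then couldMoveBeforeTopologically(m3, m1) fails, m3 is nulled out, and
// only the independent pair {m1, m2} survives.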
void BatchMMSide(Block* block, AliasDb& alias_db) {
  // Threshold for how many shared-side mms make batching worthwhile.
  static constexpr size_t how_many_is_many = 8;
  const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
    AT_ASSERT(!mms.empty());
    // Gather the mms next to each other; the alias DB validates every move.
    for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
      bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
      AT_ASSERT(move_ok);
    }
    WithInsertPoint insert_guard{mms[0]};
    Graph* graph = mms[0]->owningGraph();
    Node* batch_mm = graph->create(
        prim::MMBatchSide,
        /*inputs=*/{},
        /*num_outputs=*/mms.size());
    graph->insertNode(batch_mm);
    batch_mm->i_(Symbol::attr("side"), static_cast<int>(side));
    Value* const_side = mms[0]->inputs().at(side == Side::LHS ? 0 : 1);
    batch_mm->addInput(const_side);
    for (size_t i = 0; i < mms.size(); ++i) {
      batch_mm->addInput(mms[i]->inputs().at(side == Side::LHS ? 1 : 0));
      mms[i]->output()->replaceAllUsesWith(batch_mm->outputs().at(i));
    }
  };
  std::unordered_set<Value*> considered_values;
  for (Node* node : block->nodes()) {
    if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
      for (Value* input : node->inputs()) {
        // Consider each candidate shared operand only once.
        if (!considered_values.emplace(input).second) {
          continue;
        }
        auto uses_with_many = gatherIndependentMMUses(input, alias_db);
        if (uses_with_many.first.size() >= how_many_is_many) {
          batch_side(uses_with_many.first, Side::LHS);
        }
        if (uses_with_many.second.size() >= how_many_is_many) {
          batch_side(uses_with_many.second, Side::RHS);
        }
      }
    } else {
      for (Block* subblock : node->blocks()) {
        BatchMMSide(subblock, alias_db);
      }
    }
  }
}
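// Illustrative result of batch_side (hypothetical IR): k independent uses
// mm(%X, %b1) ... mm(%X, %bk) of a shared %X become a single multi-output
// node
//
//   %o1, ..., %ok = prim::MMBatchSide[side=LHS](%X, %b1, ..., %bk)
//
// whose outputs replace the original mm results one-for-one.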
bool hasMutableOperators(Block* block) {
  for (auto n : block->nodes()) {
    if (n->kind().is_aten() && n->schema().is_mutable())
      return true;
    for (auto b : n->blocks()) {
      if (hasMutableOperators(b))
        return true;
    }
  }
  return false;
}
void BatchMM(std::shared_ptr<Graph>& graph) {
  if (hasMutableOperators(graph->block())) {
    // The pass is not mutability-safe, so bail out on graphs with mutation.
    return;
  }
  AliasDb alias_db(graph);
  BatchMMTreeReduce(graph->block());
  BatchMMSide(graph->block(), alias_db);
  EliminateDeadCode(graph);
  // Transpose rearrangements may have created chains of consecutive
  // transposes that didn't exist before; let the peephole pass fold them.
  PeepholeOptimize(graph);
}

} // namespace jit
} // namespace torch
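// Minimal usage sketch (the surrounding setup is assumed, not shown here):
//
//   std::shared_ptr<Graph> graph = ...; // e.g. from a traced module
//   BatchMM(graph); // rewrites eligible mm nodes in place
//
// The pass bails out early on graphs that contain mutable aten operators.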