// NOTE(review): this chunk is a partial extraction — the original file's line
// numbers are fused into the text and many lines are elided (the numbering
// jumps), so bodies below are truncated. Comments only; code bytes untouched.
//
// isSimpleMap: returns true if `node` is a simple elementwise ("map") op the
// fuser knows how to handle. The static OperatorSet enumerates the fusable
// schemas: casts, unary/binary pointwise math, comparisons, where, type_as.
1 #include <torch/csrc/jit/passes/graph_fuser.h> 3 #include <ATen/ExpandUtils.h> 4 #include <c10/util/Exception.h> 5 #include <torch/csrc/jit/autodiff.h> 6 #include <torch/csrc/jit/custom_operator.h> 7 #include <torch/csrc/jit/fuser/interface.h> 8 #include <torch/csrc/jit/operator.h> 9 #include <torch/csrc/jit/passes/alias_analysis.h> 10 #include <torch/csrc/jit/passes/common_subexpression_elimination.h> 11 #include <torch/csrc/jit/passes/constant_pooling.h> 12 #include <torch/csrc/jit/passes/dead_code_elimination.h> 13 #include <torch/csrc/jit/passes/utils/subgraph_utils.h> 14 #include <torch/csrc/jit/script/compiler.h> 15 #include <torch/csrc/jit/symbolic_variable.h> 18 #include <unordered_map> 35 bool isSimpleMap(Node* node) {
36 static OperatorSet simple_mappable{{
37 "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
39 "aten::abs(Tensor self) -> Tensor",
40 "aten::acos(Tensor self) -> Tensor",
41 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
42 "aten::asin(Tensor self) -> Tensor",
43 "aten::atan(Tensor self) -> Tensor",
44 "aten::atan2(Tensor self, Tensor other) -> Tensor",
45 "aten::ceil(Tensor self) -> Tensor",
46 "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
47 "aten::cos(Tensor self) -> Tensor",
48 "aten::cosh(Tensor self) -> Tensor",
49 "aten::div(Tensor self, Tensor other) -> Tensor",
50 "aten::exp(Tensor self) -> Tensor",
51 "aten::expm1(Tensor self) -> Tensor",
52 "aten::erf(Tensor self) -> Tensor",
53 "aten::erfc(Tensor self) -> Tensor",
54 "aten::floor(Tensor self) -> Tensor",
55 "aten::fmod(Tensor self, Tensor other) -> Tensor",
56 "aten::frac(Tensor self) -> Tensor",
57 "aten::lgamma(Tensor self) -> Tensor",
58 "aten::log(Tensor self) -> Tensor",
59 "aten::log10(Tensor self) -> Tensor",
60 "aten::log1p(Tensor self) -> Tensor",
61 "aten::log2(Tensor self) -> Tensor",
62 "aten::max(Tensor self, Tensor other) -> Tensor",
63 "aten::min(Tensor self, Tensor other) -> Tensor",
64 "aten::mul(Tensor self, Tensor other) -> Tensor",
65 "aten::neg(Tensor self) -> Tensor",
66 "aten::pow(Tensor self, Tensor exponent) -> Tensor",
67 "aten::pow(Tensor self, Scalar exponent) -> Tensor",
68 "aten::rand_like(Tensor self) -> Tensor",
69 "aten::reciprocal(Tensor self) -> Tensor",
70 "aten::relu(Tensor self) -> Tensor",
71 "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
72 "aten::remainder(Tensor self, Tensor other) -> Tensor",
73 "aten::round(Tensor self) -> Tensor",
74 "aten::rsqrt(Tensor self) -> Tensor",
75 "aten::sigmoid(Tensor self) -> Tensor",
76 "aten::sin(Tensor self) -> Tensor",
77 "aten::sinh(Tensor self) -> Tensor",
78 "aten::sqrt(Tensor self) -> Tensor",
79 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
80 "aten::tan(Tensor self) -> Tensor",
81 "aten::tanh(Tensor self) -> Tensor",
82 "aten::trunc(Tensor self) -> Tensor",
83 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
84 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
85 "aten::mul(Tensor self, Scalar other) -> Tensor",
86 "aten::div(Tensor self, Scalar other) -> Tensor",
88 "aten::eq(Tensor self, Tensor other) -> Tensor",
89 "aten::eq(Tensor self, Scalar other) -> Tensor",
90 "aten::ne(Tensor self, Tensor other) -> Tensor",
91 "aten::ne(Tensor self, Scalar other) -> Tensor",
92 "aten::ge(Tensor self, Tensor other) -> Tensor",
93 "aten::ge(Tensor self, Scalar other) -> Tensor",
94 "aten::gt(Tensor self, Tensor other) -> Tensor",
95 "aten::gt(Tensor self, Scalar other) -> Tensor",
96 "aten::le(Tensor self, Tensor other) -> Tensor",
97 "aten::le(Tensor self, Scalar other) -> Tensor",
98 "aten::lt(Tensor self, Tensor other) -> Tensor",
99 "aten::lt(Tensor self, Scalar other) -> Tensor",
101 "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
103 "aten::type_as(Tensor self, Tensor other) -> Tensor",
// Not in the whitelist -> not a simple map (early-exit branch; the return
// statement itself is among the elided lines).
105 if (!simple_mappable.find(node)) {
// Remaining checks iterate tensor-typed inputs; presumably non-tensor inputs
// must be constants for the fuser to inline them — TODO confirm against the
// full source (lines 106-118 are elided here).
109 for (Value* input : node->inputs()) {
110 if (input->type()->isSubtypeOf(TensorType::get())) {
113 if (input->node()->kind() != prim::Constant) {
// Registers a custom op aten::_ncf_unsqueeze used by the batch-norm
// decomposition below: reshapes a 1-D tensor so it broadcasts over an
// N-dim input in NCF layout. The declaration of `sizes` (original line ~126)
// is elided in this view; AT_ASSERT shows the input must be 1-D.
120 RegisterOperators reg_bn_unsqueeze({Operator(
121 "aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor",
122 [](
const Node* node) {
123 return [](Stack& stack) {
// Pops args in reverse schema order: ndim first, then self.
124 const int64_t ndim = pop(stack).toInt();
125 auto self = pop(stack).toTensor();
127 AT_ASSERT(
self.dim() == 1);
// Channel dim (index 1) gets the vector's length; other dims are
// presumably 1 — TODO confirm from the elided `sizes` initialization.
128 sizes.at(1) =
self.size(0);
129 push(stack,
self.reshape(sizes));
// Fragment of an optional-tensor helper (its signature, original line ~135,
// is elided): checks whether a Tensor? input is a defined tensor vs None.
136 if (tensor->type()->isSubtypeOf(TensorType::get())) {
139 if (tensor->node()->mustBeNone()) {
// isFusableBatchNorm: a batch_norm node is fusable only if it matches the
// full aten::batch_norm schema and both weight and bias are statically known
// to be defined-or-None (isDefined returns an optional; has_value() means
// the answer is statically determined).
145 bool isFusableBatchNorm(Node* batch_norm) {
146 if (!batch_norm->matches(
147 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
153 return isDefined(batch_norm->namedInput(attr::weight)).has_value() &&
154 isDefined(batch_norm->namedInput(attr::bias)).has_value();
// broadcastSizes fragment (signature at original line ~157 elided): emits a
// prim::BroadcastSizes node over the given size lists and returns its
// int[]-typed output.
158 AT_ASSERT(!sizes.
empty());
159 Graph* graph = sizes[0]->owningGraph();
161 graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
162 broadcast_n->output()->setType(ListType::ofInts());
163 return broadcast_n->output();
// GraphFuser state: alias analysis for legal reordering, and the owning graph.
168 std::unique_ptr<AliasDb> aliasDb_;
169 std::shared_ptr<Graph> graph_;
// Fuses within a single block; sub-blocks are handled recursively in run().
171 GraphFuser(Block* block, std::shared_ptr<Graph> graph)
172 : block_(block), graph_(
std::move(graph)) {}
// tensorInputs: the subset of a node's inputs that are tensors.
174 value_list tensorInputs(Node* node) {
175 return filter(node->inputs(), [](Value* v) {
176 return v->type()->isSubtypeOf(TensorType::get());
// containsGradSumToSize: does the group's subgraph hold any
// aten::_grad_sum_to_size node? Used to veto concat fusion below.
180 bool containsGradSumToSize(Node* fusion_group) {
181 auto nodes = getSubgraph(fusion_group).nodes();
182 return std::any_of(nodes.begin(), nodes.end(), [](Node* n) {
183 return n->kind() == aten::_grad_sum_to_size;
// isFusable: map-style ops or a decomposable batch_norm.
187 bool isFusable(Node* node) {
188 return isFusableMap(node) || isFusableBatchNorm(node);
// isFusableMap: node must live in this block; _grad_sum_to_size defers to
// the fusability of its producer; otherwise existing FusionGroups and
// simple maps qualify. (Elided lines likely hold early returns.)
191 bool isFusableMap(Node* node) {
194 if (node->owningBlock() != block_)
196 if (node->kind() == aten::_grad_sum_to_size) {
202 return isFusable(node->inputs()[0]->node());
204 return node->kind() == prim::FusionGroup || isSimpleMap(node);
// isFusableCatNode: aten::cat is fusable only with a constant dim and a
// tensor list built by a prim::ListConstruct whose output has a single use
// (so the list can be absorbed into the fusion).
207 bool isFusableCatNode(Node* node) {
208 if (node->kind() != aten::cat)
210 if (!node->is_constant(attr::dim))
212 auto tensors_node = node->namedInput(attr::tensors)->node();
213 if (tensors_node->kind() != prim::ListConstruct)
220 if (tensors_node->output()->uses().size() > 1)
// calculatesSize: is this node an aten::size query?
225 bool calculatesSize(Node* node) {
226 return node->matches(
"aten::size(Tensor self) -> int[]");
// allUsersAreThisConsumerOrCalcSizes: every use of every output of the
// producer must be either `consumer` or a size calculation — a prerequisite
// for moving a chunk past the producer.
229 bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) {
230 auto defining_node = producer->node();
231 for (
auto o : defining_node->outputs()) {
232 for (
auto u : o->uses()) {
233 if (u.user != consumer && !calculatesSize(u.user))
// getSubgraph: the Subgraph attribute of a FusionGroup node.
240 Graph& getSubgraph(Node* n) {
241 AT_ASSERT(n->kind() == prim::FusionGroup);
242 return *n->g(attr::Subgraph);
// decomposeBatchNorm: replaces an aten::batch_norm node with an inlined
// TorchScript decomposition (normalize by mean/invstd), then applies the
// optional affine weight (mul) and bias (add) with _ncf_unsqueeze reshapes.
// The scripted graph is compiled once (static + once_flag) and cached.
// NOTE(review): several lines are elided in this view, including the
// std::call_once invocation and the `new_output` binding from inlineGraph.
245 void decomposeBatchNorm(Node* batch_norm) {
246 static std::shared_ptr<Graph> bn_graph;
247 static std::once_flag flag;
250 [](std::shared_ptr<Graph>* graph_ptr) {
251 static const char* source = R
"SCRIPT( 252 def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor: 254 norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum) 256 norm_mean = torch._unwrap_optional(running_mean) 257 norm_var = torch._unwrap_optional(running_var) 258 norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim()) 259 norm_var = torch._ncf_unsqueeze(norm_var, input.dim()) 260 norm_invstd = 1 / (eps + torch.sqrt(norm_var)) 261 return ((input - norm_mean) * norm_invstd) 263 auto module = std::make_shared<script::Module>();
264 defineMethodsInModule(
265 module, source, script::nativeResolver, c10::nullopt);
266 *graph_ptr = module->get_method(
"batch_norm").graph();
270 AT_ASSERT(isFusableBatchNorm(batch_norm));
271 WithInsertPoint insert_guard{batch_norm};
272 Value* input = batch_norm->namedInput(attr::input);
273 Value* input_dim = graph_->insert(aten::dim, {input});
// Inputs passed to the scripted decomposition, in its parameter order.
274 std::vector<Value*> inputs{input,
275 batch_norm->namedInput(attr::running_mean),
276 batch_norm->namedInput(attr::running_var),
277 batch_norm->namedInput(attr::training),
278 batch_norm->namedInput(attr::momentum),
279 batch_norm->namedInput(attr::eps)};
281 SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0);
// Affine terms are applied only when statically known to be defined.
282 auto weight = batch_norm->namedInput(attr::weight);
283 auto bias = batch_norm->namedInput(attr::bias);
284 if (isDefined(weight).value()) {
285 Value* expanded_weight =
286 graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
287 new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
289 if (isDefined(bias).value()) {
290 Value* expanded_bias =
291 graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
292 new_output = graph_->insert(aten::add, {new_output, expanded_bias});
// Splice out the original node.
294 batch_norm->output()->replaceAllUsesWith(new_output);
295 batch_norm->destroy();
// mergeFusionGroups: folds the producer FusionGroup into the consumer group.
// Strategy: (1) clone the producer's subgraph nodes back out into the outer
// block (mapping inner values to outer ones), (2) destroy the now-redundant
// producer group, (3) re-merge the temporary clones into the consumer group
// in reverse topological order, re-registering any still-used outputs.
298 void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
302 std::vector<Node*> temporary_nodes;
303 auto producer_subgraph = &getSubgraph(producer_group);
// Map from values inside the producer subgraph to outer-block values.
306 std::unordered_map<Value*, Value*> inner_to_outer;
307 auto inner_inputs = producer_subgraph->inputs();
308 auto outer_inputs = producer_group->inputs();
309 for (
size_t i = 0; i < inner_inputs.size(); ++i) {
310 inner_to_outer[inner_inputs[i]] = outer_inputs[i];
// Clone every inner node just before the producer group node.
314 for (
auto inner : producer_subgraph->nodes()) {
315 Node* outer = block_->owningGraph()->createClone(
316 inner, [&](Value* k) -> Value* {
return inner_to_outer.at(k); });
317 outer->insertBefore(producer_group);
318 temporary_nodes.emplace_back(outer);
319 auto inner_outputs = inner->outputs();
320 auto outer_outputs = outer->outputs();
321 for (
size_t i = 0; i < inner_outputs.size(); ++i)
322 inner_to_outer[inner_outputs[i]] = outer_outputs[i];
// Redirect users of the producer group's outputs to the clones, then drop it.
326 auto subgraph_outputs = producer_subgraph->outputs();
327 for (
size_t i = 0; i < subgraph_outputs.size(); ++i) {
328 auto outer_output = inner_to_outer.at(subgraph_outputs[i]);
329 producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
331 producer_group->destroy();
// Absorb the temporary clones into the consumer group (reverse order keeps
// defs after uses as each node is pulled in).
336 auto consumer_subgraph = &getSubgraph(consumer_group);
337 for (
auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend();
340 Node* merged = mergeNodeIntoGroup(consumer_group, node);
342 auto outputs = node->outputs();
343 for (
size_t i = 0; i < outputs.size(); ++i) {
344 auto output = outputs[i];
// Only still-used outputs need to be exposed from the consumer group.
345 if (output->uses().size() == 0)
347 consumer_subgraph->registerOutput(merged->outputs()[i]);
348 auto new_output = consumer_group->addOutput();
349 output->replaceAllUsesWith(new_output);
350 new_output->setType(output->type());
// mergeNodeIntoGroup: clones node `n` into the FusionGroup `group`'s
// subgraph, wiring its inputs to existing subgraph inputs where possible and
// adding new group inputs otherwise. Tensor inputs are inserted at
// tensor_insert_idx to keep tensors grouped ahead of scalar inputs.
// Returns the cloned node inside the subgraph.
359 Node* mergeNodeIntoGroup(Node* group, Node* n) {
360 AT_ASSERT(n->kind() != prim::FusionGroup);
361 auto& subgraph = getSubgraph(group);
// Map from outer values to the subgraph's corresponding input values.
364 std::unordered_map<Value*, Value*> inputs_map;
// NOTE(review): the declaration of loop counter `i` (original line ~365)
// is elided in this view.
366 size_t tensor_insert_idx = 0;
367 AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
368 for (
auto input : group->inputs()) {
369 inputs_map[input] = subgraph.inputs()[i++];
370 if (input->type()->isSubtypeOf(TensorType::get()))
371 tensor_insert_idx = i;
// Insert at the top of the subgraph so cloned constants dominate all uses.
377 WithInsertPoint guard(*subgraph.nodes().begin());
378 for (
auto input : n->inputs()) {
379 if (inputs_map.count(input) == 0) {
380 if (input->type()->isSubtypeOf(TensorType::get())) {
381 auto in_group = subgraph.insertInput(tensor_insert_idx);
382 in_group->setType(input->type());
383 inputs_map[input] = in_group;
384 group->insertInput(tensor_insert_idx, input);
// Special case: the int[] size argument of _grad_sum_to_size is passed
// through as a (non-tensor) group input.
387 n->kind() == aten::_grad_sum_to_size &&
388 input->type()->isSubtypeOf(ListType::ofInts())) {
389 auto in_group = subgraph.addInput();
390 in_group->setType(input->type());
391 inputs_map[input] = in_group;
392 group->addInput(input);
// Any other non-tensor input must be a constant, which is cloned into the
// subgraph rather than passed as an input.
398 AT_ASSERT(input->node()->kind() == prim::Constant);
400 subgraph.createClone(input->node(), [](Value*) -> Value* {
401 throw std::runtime_error(
"unexpected input");
403 subgraph.insertNode(in_const);
404 inputs_map[input] = in_const->output();
// Clone n itself, remapping its inputs through inputs_map.
409 Node* in_graph = subgraph.createClone(
410 n, [&](Value* k) -> Value* {
return inputs_map[k]; });
// If an output of n was previously an input to the group, the cloned value
// replaces that input (the group now computes it internally).
420 auto inputs = group->inputs();
421 for (
size_t i = 0; i < n->outputs().size(); ++i) {
422 auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]);
423 if (it != inputs.end()) {
424 size_t p = it - inputs.begin();
425 group->removeInput(p);
426 subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]);
427 subgraph.eraseInput(p);
430 return subgraph.insertNode(in_graph);
// createSingletonFusionGroup: wraps a single node in a fresh FusionGroup,
// exposing its output and redirecting all users to the group.
435 Node* createSingletonFusionGroup(Node* n) {
436 auto group = block_->owningGraph()->createFusionGroup();
439 group->insertBefore(n);
440 Node* mergedNode = mergeNodeIntoGroup(group, n);
441 getSubgraph(group).registerOutput(mergedNode->output());
442 auto sel = group->addOutput();
443 sel->copyMetadata(n->output());
444 n->replaceAllUsesWith(group);
// insertAt: insert n after *insertion_point and advance the cursor.
450 void insertAt(Node** insertion_point, Node* n) {
451 n->insertAfter(*insertion_point);
452 *insertion_point = n;
// Body fragment of tryFuse(consumer, producer) — the signature (original
// line ~455-462) is elided in this view. Checks fusability and alias-safety
// of moving the producer before the consumer, handles the
// _grad_sum_to_size-into-FusionGroup case, then merges the producer into a
// (possibly freshly created) consumer group.
463 bool shouldFuse = isFusable(producer->node()) &&
468 aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer);
// _grad_sum_to_size can only be pulled into a group if it can be pushed to
// the group's outputs (checked via trackSingleGradSumToSizeToOutputs).
473 if (producer->node()->kind() == aten::_grad_sum_to_size &&
474 consumer->kind() == prim::FusionGroup) {
479 for (
auto& u : producer->uses()) {
480 if (u.user == consumer) {
481 auto subgraph = &getSubgraph(consumer);
482 if (!trackSingleGradSumToSizeToOutputs(
483 subgraph->inputs().at(u.offset),
nullptr)) {
// Promote a non-group consumer to a singleton group first.
490 auto group = consumer;
491 if (consumer->kind() != prim::FusionGroup) {
492 group = createSingletonFusionGroup(consumer);
// batch_norm is decomposed into map ops before fusing.
494 if (producer->node()->matches(
495 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
499 decomposeBatchNorm(producer->node());
502 if (producer->node()->kind() == prim::FusionGroup) {
503 mergeFusionGroups(group, producer->node());
506 AT_ASSERT(producer->node()->outputs().size() == 1);
507 Node* merged = mergeNodeIntoGroup(group, producer->node());
// If the producer still has external users, expose it as a group output.
512 if (producer->uses().size() != 0) {
513 getSubgraph(group).registerOutput(merged->output());
514 Value* new_producer = group->addOutput();
515 new_producer->copyMetadata(producer);
516 producer->replaceAllUsesWith(new_producer);
518 producer->node()->destroy();
// canFuseChunk: a ConstantChunk producer can be fused into a FusionGroup
// consumer only if every chunk output feeds that consumer alone; a trivial
// chunk (chunks == 1) is also rejected here.
522 bool canFuseChunk(Node* consumer, Value* producer) {
523 if (consumer->kind() != prim::FusionGroup) {
527 auto* chunk = producer->node();
528 if (chunk->kind() != prim::ConstantChunk)
531 for (
auto s : chunk->outputs()) {
532 for (
auto u : s->uses()) {
533 if (u.user != consumer) {
540 if (chunk->i(attr::chunks) == 1) {
// Fragment of findFusedChunk(group, input) — signature elided. Looks up
// whether `input` already enters the group and is consumed there by a
// ConstantChunk (which is then the only user of that subgraph input).
547 AT_ASSERT(group->kind() == prim::FusionGroup);
548 auto it = std::find(group->inputs().begin(), group->inputs().end(), input);
549 if (it == group->inputs().end()) {
552 size_t input_index = it - group->inputs().begin();
553 auto& subgraph = getSubgraph(group);
554 auto* subgraph_input = subgraph.inputs().at(input_index);
556 auto* node = subgraph_input->uses().at(0).user;
557 if (node->kind() == prim::ConstantChunk) {
558 AT_ASSERT(subgraph_input->uses().size() == 1);
// fuseChunkByReusingExistingFusedChunk: when the group already chunks the
// same tensor, redirect each outer chunk output's corresponding subgraph
// input to the existing fused chunk's outputs and drop the duplicate inputs.
// (Some parameters of the signature are elided in this view.)
564 void fuseChunkByReusingExistingFusedChunk(
567 Node* existingFusedChunk) {
// Shapes of the two chunkings must agree to be reusable.
568 if (chunk->outputs().size() != existingFusedChunk->outputs().size()) {
571 auto& subgraph = getSubgraph(group);
572 for (
size_t i = 0; i < chunk->outputs().size(); ++i) {
574 auto* replacement_val = existingFusedChunk->outputs().at(i);
575 auto* val = chunk->outputs().at(i);
576 auto it = std::find(group->inputs().begin(), group->inputs().end(), val);
577 auto input_index = it - group->inputs().begin();
580 auto group_input = subgraph.inputs().at(input_index);
581 group_input->replaceAllUsesWith(replacement_val);
584 group->removeInput(input_index);
585 subgraph.eraseInput(input_index);
// fuseChunk: pulls a ConstantChunk into a FusionGroup consumer, reusing an
// existing fused chunk of the same tensor when possible.
594 graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) {
595 auto* chunk = producer->node();
596 AT_ASSERT(consumer->kind() == prim::FusionGroup);
597 AT_ASSERT(chunk->kind() == prim::ConstantChunk);
601 auto* chunked_tensor = producer->node()->input();
602 if (
auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) {
603 fuseChunkByReusingExistingFusedChunk(
604 consumer, chunk, *existingFusedChunk);
605 return consumer->reverseIterator();
609 mergeNodeIntoGroup(consumer, chunk);
611 return consumer->reverseIterator();
// sortReverseTopological: keeps only inputs defined in this block and sorts
// them so later-defined values come first (reverse topological order).
614 value_list sortReverseTopological(ArrayRef<Value*> inputs) {
616 for (
auto i : inputs) {
617 if (i->node()->owningBlock() == block_) {
622 std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
623 return a->node()->isAfter(b->node());
// scanNodeForChunks: for a FusionGroup, try to fuse the first fusable chunk
// producer; returns the iterator to continue scanning from.
628 graph_node_list::iterator scanNodeForChunks(Node* consumer) {
629 if (consumer->kind() == prim::FusionGroup) {
630 auto inputs = sortReverseTopological(consumer->inputs());
631 for (
auto producer : inputs) {
632 if (!canFuseChunk(consumer, producer)) {
635 return fuseChunk(consumer, producer);
638 return ++consumer->reverseIterator();
// insertExplicitBroadcast: materializes broadcasting for a node by inserting
// broadcast_tensors over its tensor inputs and rewiring each tensor input to
// the broadcast result (non-tensor inputs are left alone).
641 void insertExplicitBroadcast(Node* node) {
642 WithInsertPoint insert_guard{node};
643 auto tensors = tensorInputs(node);
645 SymbolicVariable::broadcast_tensors(fmap<SymbolicVariable>(tensors));
// Walk inputs in order, consuming one broadcast tensor per tensor input.
648 auto new_tensors_it = new_tensors.begin();
649 for (
size_t i = 0; i < node->inputs().size(); ++i) {
650 if (node->inputs()[i]->type()->isSubtypeOf(TensorType::get())) {
651 AT_ASSERT(new_tensors_it != new_tensors.end());
652 node->replaceInput(i, *(new_tensors_it++));
// promoteChunkToBroadcastingChunk: replaces a ConstantChunk with a
// BroadcastingChunk carrying the same attributes and outputs; users are
// redirected to the new node's outputs.
657 Node* promoteChunkToBroadcastingChunk(Node* chunk) {
658 AT_ASSERT(chunk->kind() == prim::ConstantChunk);
660 size_t nchunks = chunk->i(attr::chunks);
662 chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
663 bchunk->addInput(chunk->input());
664 for (
size_t i = 0; i < nchunks; ++i) {
665 auto* old_output = chunk->outputs().at(i);
666 auto* new_output = bchunk->outputs().at(i);
667 new_output->copyMetadata(old_output);
668 old_output->replaceAllUsesWith(new_output);
670 bchunk->copyAttributes(*chunk);
671 bchunk->insertAfter(chunk);
// tryToMoveChunk: rewrites chunk(f(a, b)) into f(chunk(a), chunk(b)) so the
// chunk sits directly on fusable inputs. The chunk is promoted to a
// BroadcastingChunk, each tensor input of the producer gets its own set of
// chunk outputs, and a per-chunk clone of the producer op is emitted.
// Size-only uses of the producer are rewritten via broadcastSizes.
742 bool tryToMoveChunk(Node* consumer, Value* producer) {
744 auto* chunk = producer->node();
745 if (chunk->kind() != prim::ConstantChunk &&
746 chunk->kind() != prim::BroadcastingChunk)
// Find a chunk input produced by a fusable map whose only other users are
// size calculations.
751 auto it = std::find_if(
752 chunk->inputs().begin(),
753 chunk->inputs().end(),
754 [&](Value* producer_for_chunk) {
755 return isFusableMap(producer_for_chunk->node()) &&
756 allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
758 if (it == chunk->inputs().end()) {
761 Value* producer_for_chunk = *it;
762 size_t producer_index = it - chunk->inputs().begin();
// All chunk outputs must feed the consumer.
765 for (
auto s : chunk->outputs()) {
766 for (
auto u : s->uses()) {
767 if (u.user != consumer)
772 Node* producer_for_chunk_node = producer_for_chunk->node();
773 AT_ASSERT(producer_for_chunk_node->outputs().size() == 1);
777 auto* bchunk = chunk;
778 if (chunk->kind() == prim::ConstantChunk) {
779 bchunk = promoteChunkToBroadcastingChunk(chunk);
781 size_t nchunks = bchunk->i(attr::chunks);
782 WithInsertPoint guard(bchunk->next());
// BroadcastingChunk lays outputs out as [input][chunk]; collect the slice
// belonging to the producer input.
784 std::vector<Value*> producer_chunk_outputs;
785 for (
size_t i = 0; i < nchunks; i++) {
786 producer_chunk_outputs.push_back(
787 bchunk->output(nchunks * producer_index + i));
// For each tensor input of the producer, reuse its chunk outputs if it is
// already a bchunk input; otherwise add it as a new bchunk input/output set.
793 std::vector<std::vector<Value*>> chunked_inputs;
795 for (
auto input : producer_for_chunk_node->inputs()) {
799 if (!input->type()->isSubtypeOf(TensorType::get()))
803 auto bchunk_inputs = bchunk->inputs();
804 auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input);
805 if (it != bchunk_inputs.end()) {
806 chunked_inputs.emplace_back();
807 auto input_index = std::distance(bchunk_inputs.begin(), it);
808 for (
size_t chunk = 0; chunk < nchunks; ++chunk) {
809 chunked_inputs.back().push_back(
810 bchunk->outputs().at(nchunks * input_index + chunk));
821 bchunk->addInput(input);
822 chunked_inputs.emplace_back();
823 for (
auto chunk_sel : producer_chunk_outputs) {
824 Value* input_chunk_sel = bchunk->addOutput();
825 input_chunk_sel->setType(chunk_sel->type());
826 chunked_inputs.back().push_back(input_chunk_sel);
// Emit one clone of the producer op per chunk, fed by the chunked inputs;
// non-tensor inputs are passed through unchanged.
832 for (
auto chunk_sel : producer_chunk_outputs) {
833 auto original_inputs = producer_for_chunk_node->inputs();
835 block_->owningGraph()->create(producer_for_chunk_node->kind());
836 chunked_op->copyAttributes(*producer_for_chunk_node);
837 chunked_op->output()->setType(chunk_sel->type());
838 auto chunked_inputs_it = chunked_inputs.begin();
839 for (Value* original_input : original_inputs) {
840 if (original_input->type()->isSubtypeOf(TensorType::get())) {
841 AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
842 chunked_op->addInput(
// chunk_sel->offset() % nchunks picks this clone's chunk index.
843 chunked_inputs_it->at(chunk_sel->offset() % nchunks));
846 chunked_op->addInput(original_input);
849 bchunk->owningGraph()->insertNode(chunked_op);
850 chunk_sel->replaceAllUsesWith(chunked_op->output());
// Remove the producer's slot from the bchunk (its outputs are now computed
// by the per-chunk clones).
853 bchunk->removeInput(producer_index);
854 for (
size_t i = 0; i < nchunks; i++) {
855 bchunk->eraseOutput(nchunks * producer_index);
// Remaining uses of the producer are size queries; answer them from the
// broadcast of its tensor inputs' sizes instead.
861 auto size_calc_uses = producer_for_chunk_node->output()->uses();
862 if (!size_calc_uses.empty()) {
863 auto tensor_inputs = filter(
864 producer_for_chunk_node->inputs(),
865 [](Value* v) {
return v->type()->isSubtypeOf(TensorType::get()); });
866 auto tensor_sizes = fmap(tensor_inputs, [](Value* v) {
867 return v->owningGraph()->insert(aten::size, {v});
869 AT_ASSERT(!tensor_sizes.empty());
870 Value* output_size = tensor_sizes.size() == 1
872 : broadcastSizes(tensor_sizes);
873 for (Use u : size_calc_uses) {
874 u.user->output()->replaceAllUsesWith(output_size);
878 producer_for_chunk_node->destroy();
// scanNode: one step of the fusion fixed-point loop. For a fusable consumer,
// walk its inputs in reverse topological order; if a chunk can be moved or a
// producer fused, report (iterator-to-resume, changed=true); otherwise
// advance past the consumer with changed=false.
883 std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
884 if (isFusable(consumer)) {
888 auto inputs = sortReverseTopological(consumer->inputs());
889 for (
auto producer : inputs) {
890 if (tryToMoveChunk(consumer, producer)) {
// Re-scan the consumer: moving the chunk may expose new fusions.
893 return std::make_pair(consumer->reverseIterator(),
true);
895 auto fusion_group = tryFuse(consumer, producer);
899 return std::make_pair(fusion_group.value()->reverseIterator(),
true);
903 return std::make_pair(++consumer->reverseIterator(),
false);
// replaceIntermediateBroadcastingChunks: after fusion, lower each remaining
// prim::BroadcastingChunk back to explicit broadcasts plus one
// prim::ConstantChunk per input, preserving output metadata and users.
906 void replaceIntermediateBroadcastingChunks() {
907 for (
auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
910 if (node->kind() != prim::BroadcastingChunk) {
914 insertExplicitBroadcast(bchunk);
916 auto* graph = block_->owningGraph();
917 size_t nchunks = bchunk->i(attr::chunks);
918 WithInsertPoint guard(bchunk->next());
// One ConstantChunk per bchunk input; outputs are laid out [input][chunk].
921 for (
size_t input_offset = 0; input_offset < bchunk->inputs().size();
923 auto* input = bchunk->inputs().at(input_offset);
926 graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
927 new_chunk->copyAttributes(*bchunk);
928 for (
size_t output_offset = 0; output_offset < nchunks;
930 auto new_output = new_chunk->addOutput();
932 bchunk->outputs().at(input_offset * nchunks + output_offset);
933 new_output->copyMetadata(old_output);
934 old_output->replaceAllUsesWith(new_output);
// usedOnlyInSize: true if every use of v is an aten::size call — such values
// can be answered symbolically without running the fusion.
941 bool usedOnlyInSize(Value* v) {
942 const auto& uses = v->uses();
943 return std::all_of(uses.begin(), uses.end(), [](
const Use& u) {
944 return u.user->matches(
"aten::size(Tensor self) -> int[]");
// buildShapeExpressions: for each value inside the fusion subgraph, build an
// outer-graph expression computing its size: aten::size for inputs/outputs,
// prim::ChunkSizes for ConstantChunk, and broadcastSizes for pointwise ops.
// Returns the subgraph-value -> outer-size-value map.
951 std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
952 WithInsertPoint insert_guard{fusion_group->next()};
953 std::unordered_map<Value*, Value*> shape_of;
955 Graph* graph = fusion_group->owningGraph();
956 auto subgraph = fusion_group->g(attr::Subgraph);
958 auto inputs = fusion_group->inputs();
959 auto sinputs = subgraph->inputs();
960 AT_ASSERT(inputs.size() == sinputs.size());
961 for (
size_t i = 0; i < inputs.size(); ++i) {
962 if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
963 shape_of[sinputs[i]] = graph->insert(aten::size, {inputs[i]});
// Outputs not used only in size() keep their tensors alive anyway, so it is
// cheaper to query their size directly than to recompute it.
971 auto outputs = fusion_group->outputs();
972 auto soutputs = subgraph->outputs();
973 AT_ASSERT(outputs.size() == soutputs.size());
974 for (
size_t i = 0; i < outputs.size(); ++i) {
975 if (usedOnlyInSize(outputs[i]))
977 shape_of[soutputs[i]] = graph->insert(aten::size, {outputs[i]});
980 for (Node* n : subgraph->nodes()) {
// FusedConcat and Constant nodes contribute no shape expression here.
983 if (n->kind() == prim::FusedConcat) {
990 if (n->kind() == prim::Constant) {
993 if (n->kind() == prim::ConstantChunk) {
// ChunkSizes yields two int[] outputs: the regular chunk size and the
// (possibly smaller) last chunk size.
994 Node* sizes_node = graph->insertNode(
995 graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
996 sizes_node->i_(attr::dim, n->i(attr::dim));
997 sizes_node->i_(attr::chunks, n->i(attr::chunks));
998 Value* regular_size = sizes_node->outputs().at(0);
999 Value* last_size = sizes_node->outputs().at(1);
1000 regular_size->setType(ListType::ofInts());
1001 last_size->setType(ListType::ofInts());
1002 auto outputs = n->outputs();
1003 for (Value* o : outputs.slice(0, outputs.size() - 1)) {
1004 shape_of.emplace(o, regular_size);
1006 shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
// Default (pointwise) case: the output size is the broadcast of the tensor
// inputs' sizes.
1009 auto tensor_inputs = filter(n->inputs(), [](Value* v) {
1010 return v->type()->isSubtypeOf(TensorType::get());
1013 fmap(tensor_inputs, [&](Value* v) {
return shape_of.at(v); });
1014 AT_ASSERT(!shapes.empty());
1016 n->output(), shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes));
// removeOutputsUsedOnlyInSize: for each fusion-group output consumed only by
// aten::size, replace those size calls with the precomputed shape expression
// and erase the output (iterating backwards so indices stay valid).
1021 void removeOutputsUsedOnlyInSize(Node* fusion_group) {
1022 if (fusion_group->kind() != prim::FusionGroup)
1024 auto subgraph = fusion_group->g(attr::Subgraph);
1026 auto shape_of = buildShapeExpressions(fusion_group);
1027 auto outputs = fusion_group->outputs().vec();
1028 auto soutputs = subgraph->outputs().vec();
1032 for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
1033 auto output = outputs[i];
1034 auto soutput = soutputs[i];
1035 if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
1036 auto uses = output->uses();
1037 for (Use u : uses) {
1038 AT_ASSERT(u.user->matches(
"aten::size(Tensor self) -> int[]"));
1039 u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
1042 fusion_group->eraseOutput(i);
1043 subgraph->eraseOutput(i);
// refreshAliasDb: rebuild alias analysis after graph mutations.
1048 void refreshAliasDb() {
1049 aliasDb_ = torch::make_unique<AliasDb>(graph_);
// canFuseWithConcat: a producer can be pulled into a FusedConcat group only
// if it is fusable, can be legally moved before `before_check`, and — when it
// is itself a FusionGroup — its selected output is not another FusedConcat
// and the group contains no _grad_sum_to_size nodes.
1052 bool canFuseWithConcat(Value* producer, Node* before_check) {
1053 if (!isFusable(producer->node())) {
1058 if (!aliasDb_->couldMoveBeforeTopologically(
1059 producer->node(), before_check)) {
1065 if (producer->node()->kind() == prim::FusionGroup) {
1066 auto subgraph = producer->node()->g(attr::Subgraph);
1067 auto* node = subgraph->outputs().at(producer->offset())->node();
1068 return node->kind() != prim::FusedConcat &&
1069 !containsGradSumToSize(producer->node());
// createFusedConcat: turns aten::cat(+ListConstruct) into a prim::FusedConcat
// wrapped in a singleton FusionGroup; dim must be constant (asserted fusable
// earlier via isFusableCatNode).
1074 Node* createFusedConcat(Node* node) {
1075 AT_ASSERT(node->kind() == aten::cat);
1077 Graph* graph = node->owningGraph();
1078 Node* list_construct = node->namedInput(attr::tensors)->node();
1079 int64_t dim = node->get<int64_t>(attr::dim).value();
1081 Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())
1082 ->i_(attr::dim, dim);
1083 fused_cat->insertBefore(list_construct);
1084 fused_cat->output()->copyMetadata(node->output());
1087 return createSingletonFusionGroup(fused_cat);
// fuseConcats: scan the block in reverse, turn each fusable aten::cat into a
// FusedConcat group, then greedily fuse its producers into the group,
// re-sorting the input list after every successful fuse. Cleans up the
// original cat, a dead ListConstruct, and (presumably, when nothing fused)
// the fused_cat itself — the guard around line 1125 is elided in this view.
1090 void fuseConcats() {
1091 for (
auto it = block_->nodes().rbegin(); it != block_->nodes().rend();
1094 if (!isFusableCatNode(cat)) {
1097 Node* list_construct = cat->namedInput(attr::tensors)->node();
1098 Node* fused_cat = createFusedConcat(cat);
1099 Value* fused_cat_out = fused_cat->output();
1101 auto sorted_inputs = sortReverseTopological(fused_cat->inputs());
1102 size_t input_idx = 0;
1103 bool any_fused =
false;
1104 while (input_idx < sorted_inputs.size()) {
1105 Value* input = sorted_inputs[input_idx++];
1106 if (!canFuseWithConcat(input, fused_cat)) {
1110 auto maybe_group = tryFuse(fused_cat, input);
1111 AT_ASSERT(maybe_group && maybe_group == fused_cat);
// Fusing may have changed the input list; restart from a fresh ordering.
1114 sorted_inputs = sortReverseTopological(fused_cat->inputs());
1119 cat->output()->replaceAllUsesWith(fused_cat_out);
1120 it.destroyCurrent();
1121 if (list_construct->output()->uses().empty()) {
1122 list_construct->destroy();
1125 fused_cat->destroy();
// optimizeFusedGraphs: run DCE/CSE/constant pooling inside each fusion
// subgraph.
1130 void optimizeFusedGraphs() {
1131 for (Node* node : block_->nodes()) {
1132 if (node->kind() != prim::FusionGroup) {
1135 auto subgraph = node->g(attr::Subgraph);
1136 EliminateDeadCode(subgraph);
1137 EliminateCommonSubexpression(subgraph);
1138 ConstantPooling(subgraph);
// Body fragment of GraphFuser::run() (signature elided in this view).
// Phase 1: iterate scanNode to a fixed point. Then optimize fused subgraphs,
// lower BroadcastingChunks, fuse chunks into groups, strip size-only
// outputs, and finally recurse into sub-blocks.
1159 bool any_changed =
true;
1160 while (any_changed) {
1161 any_changed =
false;
1163 for (
auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
1165 std::tie(it, changed) = scanNode(*it);
1166 any_changed |= changed;
1173 optimizeFusedGraphs();
1177 replaceIntermediateBroadcastingChunks();
1180 for (
auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
1181 it = scanNodeForChunks(*it);
1185 for (Node* n : block_->nodes()) {
1186 removeOutputsUsedOnlyInSize(n);
// Recurse into nested blocks (e.g. if/loop bodies).
1189 for (Node* node : block_->nodes()) {
1190 for (Block* sub_block : node->blocks()) {
1191 GraphFuser(sub_block, graph_).run();
// PeepholeOptimizeShapeExpressions: simplifies prim::BroadcastSizes nodes —
// single input passes through, duplicate inputs are deduplicated (via a map
// keyed on Value::unique()), and a BroadcastSizes whose sole user is another
// BroadcastSizes is folded into it. Recurses into sub-blocks first.
1197 void PeepholeOptimizeShapeExpressions(Block* block) {
1198 auto nodes = block->nodes();
1199 for (
auto it = nodes.begin(); it != nodes.end(); ++it) {
1201 for (Block* subblock : node->blocks()) {
1202 PeepholeOptimizeShapeExpressions(subblock);
1204 if (node->kind() == prim::BroadcastSizes) {
// Trivial case: broadcasting one size list is the identity.
1206 if (node->inputs().size() == 1) {
1207 node->output()->replaceAllUsesWith(node->input());
1208 it.destroyCurrent();
// Deduplicate inputs; std::map keeps a deterministic order by unique id.
1213 std::map<size_t, Value*> unique_to_value;
1214 for (Value* input : node->inputs()) {
1215 unique_to_value.emplace(input->unique(), input);
1217 if (unique_to_value.size() != node->inputs().size()) {
1218 std::vector<Value*> inputs;
1219 inputs.reserve(unique_to_value.size());
1220 for (
auto& entry : unique_to_value) {
1221 inputs.push_back(entry.second);
1223 if (inputs.size() == 1) {
1224 node->output()->replaceAllUsesWith(inputs[0]);
1226 WithInsertPoint insert_guard{node};
1227 node->output()->replaceAllUsesWith(broadcastSizes(inputs));
1229 it.destroyCurrent();
// Fold into a downstream BroadcastSizes that is the only user: replace this
// node's output slot in the user with this node's inputs directly.
1234 const auto& uses = node->output()->uses();
1235 if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
1236 Node* user = uses[0].user;
1237 user->removeInput(uses[0].offset);
1240 for (Value* i : node->inputs()) {
1243 it.destroyCurrent();
// trackSingleGradSumToSizeToOutputs: follows the uses of a value fed by
// _grad_sum_to_size through ops that commute with sum-to-size (pointwise
// mul/div/neg/add/where and type_as), to verify the sum-to-size can be
// deferred until the fusion group's outputs. When `outputGradSumToSizes` is
// non-null, records which group input carries the size for each output slot.
// Returns false (presumably, in elided branches) if a non-commuting user is
// hit — TODO confirm against the full source.
1261 bool trackSingleGradSumToSizeToOutputs(
1262 Value* gradSumToSizeOutput,
1263 std::vector<int64_t>* outputGradSumToSizes) {
1264 static OperatorSet commutes_with_SumToSize{{
1265 "aten::mul(Tensor self, Tensor other) -> Tensor",
1266 "aten::div(Tensor self, Tensor other) -> Tensor",
1268 "aten::mul(Tensor self, Scalar other) -> Tensor",
1269 "aten::div(Tensor self, Scalar other) -> Tensor",
1270 "aten::neg(Tensor self) -> Tensor",
1271 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
1272 "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
// Breadth-first walk over transitive uses.
1276 std::queue<Use> uses_to_process{};
1277 auto add_to_uses = [&](
const use_list& uses) {
1278 for (
auto u : uses) {
1279 uses_to_process.push(u);
1282 add_to_uses(gradSumToSizeOutput->uses());
1283 while (!uses_to_process.empty()) {
1284 auto user = uses_to_process.front().user;
1285 auto offset = uses_to_process.front().offset;
1286 uses_to_process.pop();
1287 if (user->matches(
"aten::type_as(Tensor self, Tensor other) -> Tensor")) {
1294 add_to_uses(user->output()->uses());
1296 }
else if (commutes_with_SumToSize.find(user)) {
1297 add_to_uses(user->output()->uses());
1298 }
else if (user->kind() == prim::Return) {
// Reached a group output: record the size input (second input of the
// _grad_sum_to_size node) for this output slot, if not already claimed.
1307 if (outputGradSumToSizes && (*outputGradSumToSizes)[offset] == -1) {
1310 (*outputGradSumToSizes)[offset] =
1311 gradSumToSizeOutput->node()->inputs()[1]->offset();
1313 }
else if (user->kind() == aten::_grad_sum_to_size) {
// FuseGraph: public entry point of the pass. Runs the fuser only when a
// fusion backend is available, then cleans up with CSE, DCE, and the
// shape-expression peephole pass.
1328 void FuseGraph(std::shared_ptr<Graph>& graph) {
1329 if (canFuseOnCPU() || canFuseOnGPU()) {
1330 GraphFuser(graph->block(), graph).run();
1332 EliminateCommonSubexpression(graph);
1335 EliminateDeadCode(graph);
1337 PeepholeOptimizeShapeExpressions(graph->block());
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small...
constexpr bool empty() const
empty - Check if the array is empty.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...