#include <torch/csrc/jit/ir.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/script/schema_matching.h>

#include <algorithm>
#include <atomic>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {

void printQuotedString(std::ostream& stmt, const std::string& str);
// Constants relating to maintaining the topological index of nodes.
//
// Lower and upper bounds of the index. Inclusive range.
static constexpr topo_position_t kLowerBound = INT64_MIN;
static constexpr topo_position_t kUpperBound = INT64_MAX;
static constexpr topo_position_t kMidPoint = 0;

// How far away to space nodes that are appended to the graph.
static constexpr topo_position_t kAppendInterval = 1099511627776ULL; // 2^40
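// Worked example (illustrative): appends start at kMidPoint = 0 and advance
// by kAppendInterval = 2^40 each time, so a block can absorb roughly
// 2^63 / 2^40 = 2^23 (about 8 million) consecutive appends before positions
// approach kUpperBound and reIndexTopology() must respace them.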
constexpr Symbol PythonOp::Kind;

static void printValueRef(std::ostream& out, const Value* n) {
  out << "%" << n->uniqueName();
}

template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& nodes) {
  out << at::ArrayRef<T>{nodes};
  return out;
}

template <typename T>
static std::ostream& printValueRefs(
    std::ostream& out,
    const at::ArrayRef<T>& nodes) {
  size_t i = 0;
  for (auto n : nodes) {
    if (i++ > 0) {
      out << ", ";
    }
    printValueRef(out, n);
  }
  return out;
}

// Can't make these two overloads directly a template, it'll be ambiguous with
// the global printer for operator<<.
std::ostream& operator<<(
    std::ostream& out,
    const at::ArrayRef<const Value*>& nodes) {
  return printValueRefs(out, nodes);
}

std::ostream& operator<<(std::ostream& out, const at::ArrayRef<Value*>& nodes) {
  return printValueRefs(out, nodes);
}
struct const_value_list_with_types {
  const at::ArrayRef<const Value*> values;
  std::string delim;
  const_value_list_with_types(
      at::ArrayRef<const Value*> values,
      std::string delim_ = ", ")
      : values(values), delim(std::move(delim_)) {}
};

std::ostream& operator<<(
    std::ostream& out,
    const const_value_list_with_types& l) {
  size_t i = 0;
  for (auto n : l.values) {
    if (i++ > 0) {
      out << l.delim;
    }
    printValueRef(out, n);
    out << " : ";
    out << *n->type();
  }
  return out;
}

template <typename T>
static void printPrimList(std::ostream& out, const std::vector<T>& items) {
  out << "[";
  int i = 0;
  for (auto& item : items) {
    if (i++ > 0) {
      out << ", ";
    }
    out << item;
  }
  out << "]";
}

static void printStrList(
    std::ostream& out,
    const std::vector<std::string>& items) {
  out << "[";
  int i = 0;
  for (auto& item : items) {
    if (i++ > 0) {
      out << ", ";
    }
    printQuotedString(out, item);
  }
  out << "]";
}
void Node::printAttrValue(std::ostream& out, const Symbol& name) const {
  switch (kindOf(name)) {
    case AttributeKind::f:
      out << f(name);
      break;
    case AttributeKind::fs:
      printPrimList(out, fs(name));
      break;
    case AttributeKind::i:
      out << i(name);
      break;
    case AttributeKind::is:
      printPrimList(out, is(name));
      break;
    case AttributeKind::s:
      printQuotedString(out, s(name));
      break;
    case AttributeKind::ss:
      printStrList(out, ss(name));
      break;
    case AttributeKind::t: {
      at::Tensor tensor = t(name);
      // 1-element tensors are usually boxed scalars, so print them like one
      if (tensor.numel() == 1) {
        auto scalar_tensor = tensor.view({}).item();
        out << "{";
        if (scalar_tensor.isFloatingPoint()) {
          out << scalar_tensor.toDouble();
        } else {
          out << scalar_tensor.toLong();
        }
        out << "}";
      } else if (tensor.numel() <= max_tensor_display_size) {
        // Flatten the tensor's printed form onto a single line
        std::ostringstream tensor_ss;
        tensor_ss << tensor;
        std::string tensor_s{tensor_ss.str()};
        std::replace(tensor_s.begin(), tensor_s.end(), '\n', ' ');
        out << tensor_s;
      } else {
        out << "<Tensor>";
      }
      break;
    }
    case AttributeKind::ts:
      out << "[<Tensors>]";
      break;
    case AttributeKind::g:
      out << "<Graph>";
      break;
    case AttributeKind::gs:
      out << "[<Graphs>]";
      break;
  }
}

void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
    const {
  out << "[";
  auto names = attributeNames();
  int i = 0;
  for (auto name : names) {
    if (ignore_subgraph && name == attr::Subgraph) {
      continue;
    }
    if (i++ > 0) {
      out << ", ";
    }
    out << name.toUnqualString() << "=";
    printAttrValue(out, name);
  }
  out << "]";
}

static std::ostream& indent(std::ostream& out, size_t level) {
  for (size_t i = 0; i < level; ++i) {
    out << "  ";
  }
  return out;
}
std::ostream& Node::print(
    std::ostream& out,
    size_t level,
    std::vector<const Node*>* groups) const {
  auto outs = outputs();
  indent(out, level) << const_value_list_with_types(outs);
  out << " = ";
  if (kind() == prim::PythonOp) {
    auto* pyOp = static_cast<const PythonOp*>(this);
    out << "^" << pyOp->name();
    pyOp->writeScalars(out);
  } else if (hasAttribute(attr::Subgraph) && groups) {
    out << kind().toQualString() << "_" << groups->size();
    if (numAttributes() > 1 && kind() != prim::DifferentiableGraph) {
      printAttributes(out, /*ignore_subgraph=*/true);
    }
    groups->push_back(this);
  } else {
    out << kind().toQualString();
    if (hasAttributes()) {
      printAttributes(out);
    }
  }

  out << "(" << inputs() << ")";
  std::string scName = scopeName();
  if (scName.empty()) {
    out << "\n";
  } else {
    out << ", ";
    out << "scope: " << scName << "\n";
  }
  for (size_t i = 0; i < blocks().size(); ++i) {
    auto b = blocks()[i];
    indent(out, level + 1) << "block" << i << "("
                           << const_value_list_with_types(b->inputs())
                           << "):\n";
    for (auto nested : b->nodes()) {
      nested->print(out, level + 2, groups);
    }
    indent(out, level + 2) << "-> (" << b->outputs() << ")\n";
  }
  return out;
}

std::ostream& operator<<(std::ostream& out, const Node& n) {
  return n.print(out, 0, nullptr);
}
std::ostream& operator<<(std::ostream& out, const Graph& g) {
  out << "graph(" << const_value_list_with_types(g.inputs(), "\n      ")
      << "):\n";
  std::vector<const Node*> groups;
  for (auto n : g.nodes()) {
    n->print(out, 1, &groups);
  }
  out << "  return (" << g.outputs() << ")\n";
  size_t i = 0;
  for (auto fg : groups) {
    out << "with " << fg->kind().toQualString() << "_" << i++ << " = "
        << *fg->g(attr::Subgraph);
  }
  return out;
}
std::ostream& Graph::prettyPrint(std::ostream& out) {
  std::vector<at::Tensor> tensor_table;
  std::vector<ClassTypePtr> class_table;
  PythonPrint(out, *this, tensor_table, class_table);
  return out;
}

void Graph::dumpPretty() {
  std::vector<at::Tensor> tensor_table;
  std::vector<ClassTypePtr> class_table;
  PythonPrint(std::cout, *this, tensor_table, class_table);
}

static void checkSameDevice(const Node* node) {
  bool has_device = false;
  c10::optional<at::Device> device = c10::nullopt;
  auto checkValue = [&](const Value* v) {
    if (CompleteTensorTypePtr type = v->type()->cast<CompleteTensorType>()) {
      if (!has_device) {
        has_device = true;
        device = type->device();
      } else {
        AT_ASSERT(device == type->device());
      }
    }
  };
  for (auto input : node->inputs()) {
    checkValue(input);
  }
  for (auto output : node->outputs()) {
    checkValue(output);
  }
}
using node_set = std::set<const Node*>;
#define ALL_OF(container) container.begin(), container.end()

// These functions purposely operate on the internal members directly, to
// force you to think about how the invariants change if you change the data
// structure (even if the external API does not change).

// NB: This assert is written to assume you don't have any unattached
// nodes. Unattached nodes can occur while manipulating the graph, and
// you should only call this when you know you have attached nodes.
void Node::lint() const {
  // Node invariants:
  // - Inputs are all marked as a use by the values they refer to
  // - Owning graph is non-null and consistent

  {
    size_t i = 0;
    for (auto input : inputs_) {
      // WARNING: O(n^2)
      AT_ASSERT(
          std::find(ALL_OF(input->uses_), Use(const_cast<Node*>(this), i)) !=
          input->uses_.end());
      AT_ASSERT(graph_->all_nodes.count(this) == 1);
      i++;
    }
  }

  for (auto o : outputs()) {
    for (auto use : o->uses()) {
      // Use invariants:
      // - Use is consistent with inputs
      // - Every user node is live (checked in Graph)
      AT_ASSERT(use.user->inputs_[use.offset] == o);
    }
  }

  // Node subclass invariants
  switch (kind()) {
    case prim::Constant:
      AT_ASSERT(inputs_.size() == 0);
      break;
    case prim::Return:
      // Return uses is zero
      AT_ASSERT(outputs().size() == 0);
      break;
    case prim::Param:
      // Param inputs is zero
      AT_ASSERT(inputs_.size() == 0);
      break;
    case prim::PythonOp: {
      // Python operator cconv is correct
      size_t n_scalars = 0, n_tensors = 0;
      auto* value = static_cast<const PythonOp*>(this);
      for (auto c : value->cconv) {
        if (c == 'c') {
          n_scalars++;
        } else if (c == 'd') {
          n_tensors++;
        } else {
          AT_ASSERT(false);
        }
        AT_ASSERT(static_cast<bool>(value->pyobj));
      }
      AT_ASSERT(n_scalars == value->scalar_args.size());
      AT_ASSERT(n_tensors == inputs_.size());
      break;
    }
    case prim::FusionGroup:
      checkSameDevice(this);
      // TODO: typecheck the parameters
      g(attr::Subgraph)->lint();
      break;
  }
}
// TODO: When lint fails, give better indication about which
// instruction triggered the failure.
void Graph::lint() const {
  // Graph invariants:
  // - every attached node/value/block is in the graph's
  //   all_nodes/all_values/all_blocks sets
  // - every value's unique number is genuinely unique
  // - every input to a node is in scope at the point of use

  struct LintScope {
    LintScope() = default;
    LintScope(std::unique_ptr<LintScope> parent) : parent(std::move(parent)) {}
    bool contains(const Value* v) {
      return values.count(v) > 0 || (parent && parent->contains(v));
    }
    bool contains(const Node* n) {
      return nodes.count(n) > 0 || (parent && parent->contains(n));
    }
    void insert(const Value* v) {
      AT_ASSERT(!contains(v));
      values.insert(v);
    }
    void insert(const Node* n) {
      AT_ASSERT(!contains(n));
      nodes.insert(n);
    }
    std::unique_ptr<LintScope> parent;

   private:
    std::unordered_set<const Value*> values;
    std::unordered_set<const Node*> nodes;
  };
  // Struct enables mutual recursion in the linting methods.
  // Putting it inside Graph::lint enables access to private Graph members.
  struct LintImpl {
    LintImpl(const Graph& g)
        : g(g),
          scope(new LintScope()),
          all_nodes_set(ALL_OF(g.all_nodes)) {} // NB: all_nodes is *unordered*
    const Graph& g;
    std::unique_ptr<LintScope> scope;
    std::unordered_set<size_t> seen_uniques;
    std::unordered_map<const Node*, int64_t> anticipated_uses;
    node_set all_nodes_set;
    node_set sum_set;

    void check_value(const Value* v) {
      scope->insert(v);
      auto b2 = seen_uniques.insert(v->unique());
      AT_ASSERT(b2.second); // insertion took place
      AT_ASSERT(v->unique() < g.next_unique_);

      for (auto use : v->uses()) {
        AT_ASSERT(!scope->contains(use.user));
        AT_ASSERT(g.all_nodes.count(use.user) == 1);
        anticipated_uses[use.user]++; // int default constructs to 0
      }
    }
    void check_node(const Node* n) {
      for (auto input : n->inputs_) {
        if (!scope->contains(input)) {
          AT_ASSERTM(0, input->unique(), " not in scope");
        }
      }
      AT_ASSERT(anticipated_uses[n] == static_cast<int64_t>(n->inputs_.size()));
      anticipated_uses[n] = -1; // we saw the anticipated user!
      scope->insert(n);
      for (auto block : n->blocks()) {
        std::unique_ptr<LintScope> new_scope(new LintScope(std::move(scope)));
        scope = std::move(new_scope);
        check_block(block);
        scope = std::move(scope->parent);
      }
      size_t i = 0;
      for (auto o : n->outputs()) {
        AT_ASSERT(o->node() == n);
        AT_ASSERT(i++ == o->offset_);
        check_value(o);
      }
      n->lint();
    }
    void check_block(const Block* b) {
      // Check topological ordering
      AT_ASSERT(b->param_node()->isBefore(*b->nodes().begin()));
      auto curNode = *b->nodes().begin();
      while (curNode != b->return_node()) {
        AT_ASSERT(curNode->isBefore(curNode->next()));
        curNode = curNode->next();
      }

      for (auto input : b->inputs()) {
        check_value(input);
        AT_ASSERT(input->node()->kind_ == prim::Param);
      }

      for (auto n : b->nodes()) {
        AT_ASSERT(n->kind_ != prim::Param);
        AT_ASSERT(n->kind_ != prim::Return);
        check_node(n);
      }

      AT_ASSERT(b->output_->kind() == prim::Return);
      check_node(b->output_);

      // all_nodes: input_, output_ and nodes_ are all included in all_nodes
      node_set nodes_set(ALL_OF(b->nodes()));
      node_set inputs_set{b->input_};
      node_set output_set{b->output_};

      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(nodes_set)));
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(inputs_set)));
      AT_ASSERT(std::includes(ALL_OF(all_nodes_set), ALL_OF(output_set)));

      sum_set.insert(ALL_OF(nodes_set));
      sum_set.insert(ALL_OF(inputs_set));
      sum_set.insert(ALL_OF(output_set));
    }
    void check_graph() {
      node_set all_nodes_set(
          ALL_OF(g.all_nodes)); // NB: all_nodes is *unordered*

      check_block(g.block_);
      for (auto kv : anticipated_uses) {
        AT_ASSERT(kv.second == -1);
      }
      AT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));
    }
  };
  LintImpl(*this).check_graph();
}
void Graph::dump() const {
  std::cout << *this << "\n";
}

void LintGraph(std::shared_ptr<Graph>& graph) {
  graph->lint();
}

Block::Block(Graph* graph_, Node* node_)
    : graph_(graph_),
      output_(initOutput(graph_->create(prim::Return, 0))),
      input_(graph_->create(prim::Param, 0)),
      owning_node_(node_) {
  graph_->all_blocks.emplace(this);
  output_->owning_block_ = this;
  output_->topo_position_ = kUpperBound;
  input_->owning_block_ = this;
  input_->topo_position_ = kLowerBound;
}

void Block::reIndexTopology() {
  auto curPos = kLowerBound;
  for (auto node : nodes()) {
    AT_ASSERT(curPos <= (kUpperBound - kAppendInterval));
    curPos += kAppendInterval;
    node->topo_position_ = curPos;
  }
}
void Block::cloneFrom(Block* src, std::function<Value*(Value*)> value_map) {
  std::unordered_map<Value*, Value*> local_map;
  auto env = [&](Value* v) {
    auto it = local_map.find(v);
    if (it != local_map.end()) {
      return it->second;
    }
    return value_map(v);
  };

  auto graph = owningGraph();
  for (auto input : src->inputs()) {
    local_map[input] = this->addInput()->copyMetadata(input);
  }

  for (auto node : src->nodes()) {
    auto new_node = this->appendNode(graph->createClone(node, env));
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      auto oo = node->outputs()[i];
      auto no = new_node->outputs()[i];
      local_map[oo] = no;
      no->copyMetadata(oo);
    }
  }
  for (auto output : src->outputs()) {
    this->registerOutput(env(output));
  }
}
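// Note on the env lambda in cloneFrom above: values defined within `src`
// resolve through local_map, while free variables captured from an enclosing
// scope fall through to the caller-supplied value_map.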
void Block::destroy() {
  // we cannot destroy the output because it is used as the sentinel
  // for the nodes() list and has to remain valid for the loop
  output_->removeAllInputs();
  for (auto it = this->nodes().reverse().begin(),
            end = this->nodes().reverse().end();
       it != end;
       ++it) {
    it.destroyCurrent();
  }
  output_->destroy();
  input_->destroy();
  graph_->freeBlock(this);
}

std::shared_ptr<Graph> Graph::copy() {
  auto new_g = std::make_shared<Graph>();
  auto env = [](Value* v) -> Value* {
    AT_ERROR(
        "Graph::copy() encountered a use of a value not in scope. Run lint!");
  };
  new_g->block()->cloneFrom(this->block(), env);
  return new_g;
}

bool Value::mustBeNone() const {
  return node_->mustBeNone();
}
std::string Value::uniqueNameBase() const {
  std::string name = uniqueName();
  std::string name_base = name;
  auto last_dot_pos = name.find_last_of('.');
  if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
    if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
        std::string::npos) {
      name_base = name.substr(0, last_dot_pos);
    }
  }
  return name_base;
}
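// For example (illustrative): a value named "weight.3" has base "weight",
// while "weight.bias" is its own base, since the text after the final dot is
// not purely numeric.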
bool Value::isValidName(const std::string& name) {
  // Empty strings are legal
  if (!name.size()) {
    return true;
  }

  // Numbers are not legal
  if (name.find_first_not_of("0123456789") == std::string::npos) {
    return false;
  }

  return true;
}
Value* Value::setUniqueName(const std::string& name) {
  if (!isValidName(name)) {
    throw std::runtime_error("Invalid name: '" + name + "'");
  }

  auto& names = node()->owningGraph()->unique_names_;

  // clear any old name from the map
  if (hasUniqueName()) {
    names.erase(unique_name_);
    unique_name_ = "";
  }

  // allow "" to clear the unique name
  if (name == "") {
    return this;
  }

  // if someone else has this name, then rename the other value
  auto old_owner_of_name = names.find(name);
  if (old_owner_of_name != names.end()) {
    size_t suffix = 1;
    std::string name_base = name;
    auto last_dot_pos = name.find_last_of('.');
    if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
      if (name.find_first_not_of("0123456789", last_dot_pos + 1) ==
          std::string::npos) {
        suffix = std::stoll(name.substr(last_dot_pos + 1));
        name_base = name.substr(0, last_dot_pos);
      }
    }
    std::string replacement_name;
    do {
      std::stringstream ss;
      ss << name_base << "." << suffix++;
      replacement_name = ss.str();
    } while (names.count(replacement_name) > 0);
    old_owner_of_name->second->setUniqueName(replacement_name);
  }

  names[name] = this;
  unique_name_ = name;
  return this;
}
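// A minimal sketch of the renaming scheme above (graph `g` and the name "x"
// are hypothetical):
//
//   Value* a = g.addInput()->setUniqueName("x");
//   Value* b = g.addInput()->setUniqueName("x");
//   // `b` now owns "x"; the previous owner `a` was moved to the first free
//   // dotted suffix, so a->uniqueName() == "x.1".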
Value* Value::copyMetadata(Value* from) {
  setType(from->type());
  if (from->hasUniqueName()) {
    setUniqueName(from->uniqueName());
  }
  return this;
}

void Value::replaceFirstUseWith(Value* newValue) {
  AT_ASSERT(owningGraph() == newValue->owningGraph());
  auto u = uses()[0];
  u.user->inputs_[u.offset] = newValue;
  newValue->uses_.push_back(u);
  uses_.erase(uses_.begin());
}
void Value::replaceAllUsesWith(Value* newValue) {
  while (!uses().empty()) {
    replaceFirstUseWith(newValue);
  }
}
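// A common rewrite pattern built on replaceAllUsesWith (a sketch; `g` and the
// single-output node `old_node` are hypothetical):
//
//   Node* zero = g.createAutogradZero()->insertBefore(old_node);
//   old_node->output()->replaceAllUsesWith(zero->output());
//   old_node->destroy();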
static size_t findArgument(const FunctionSchema& the_schema, Symbol name) {
  auto name_str = name.toUnqualString();
  for (size_t i = 0; i < the_schema.arguments().size(); ++i) {
    const Argument* arg = &the_schema.arguments()[i];
    if (arg->name() == name_str) {
      return i;
    }
  }
  throw std::runtime_error(
      std::string("Couldn't find an argument called ") + name.toQualString());
}

c10::optional<IValue> Node::get(Symbol name) const {
  return toIValue(namedInput(name));
}

Value* Node::namedInput(Symbol name) const {
  return input(findArgument(schema(), name));
}

bool Node::matches(
    const char* signature_literal,
    at::ArrayRef<Symbol> const_inputs) const {
  if (!sig(signature_literal).matches(this)) {
    return false;
  }
  for (Symbol s : const_inputs) {
    if (!is_constant(s)) {
      return false;
    }
  }
  return true;
}

bool Node::mustBeNone() const {
  return kind_ == prim::Constant && !this->hasAttributes() &&
      (output()->type()->cast<OptionalType>() ||
       output()->type() == NoneType::get());
}

void Node::dump() const {
  std::cout << *this << "\n";
}

void Node::findSchema() const {
  schema_ = &getOperatorFor(this).schema();
}

const FunctionSchema* Node::maybeSchema() const {
  if (!schema_) {
    if (auto op = findOperatorFor(this)) {
      schema_ = &op->schema();
    }
  }
  return schema_;
}
bool Node::isNondeterministic() const {
  static const OperatorSet nondeterministic_ops = {
      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
      "aten::rand_like(Tensor self) -> Tensor",
      "aten::rand_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
      "aten::randint_like(Tensor self, int high) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high) -> Tensor",
      "aten::randint_like(Tensor self, int high, *, int dtype, int layout, Device device) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, Device device) -> Tensor",
      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device) -> Tensor",
      "aten::randn_like(Tensor self) -> Tensor",
      "aten::randn_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
      "aten::randperm(int n, *, int? dtype, int? layout, Device? device) -> Tensor"};

  if (nondeterministic_ops.find(this) == nullptr) {
    return false;
  }
  // Dropout with train = false is deterministic
  if (matches("aten::dropout(Tensor input, float p, bool train) -> Tensor") &&
      is_constant(attr::train) && !get<bool>(attr::train).value()) {
    return false;
  }
  return true;
}

bool Node::hasSideEffects() const {
  switch (kind_) {
    case prim::PythonOp:
    case prim::IgnoredPythonOp:
    case prim::Print:
    case prim::RaiseException:
      return true;
  }
  return false;
}
// Assign this node a topological position, to facilitate fast isBefore() and
// isAfter() queries. Must be called right after a node is inserted into the
// node list. The scheme: every node gets a 64-bit position; appends advance a
// fixed interval past the previous node (O(1)), and insertions take the
// midpoint of the two neighbors, reindexing the whole block only when the gap
// is exhausted.
void Node::assignTopoPosition() {
  auto returnNode = owningBlock()->return_node();
  const auto prevPos = prev()->topo_position_;
  const auto nextPos = next()->topo_position_;

  // Append to the end of the graph
  if (next() == returnNode) {
    if (next() == prev()) {
      // the node list is empty, assign the first position
      topo_position_ = kMidPoint;
      return;
    }

    if (prevPos >= (kUpperBound - kAppendInterval)) {
      // we're running off the edge
      owningBlock()->reIndexTopology();
      return;
    }

    topo_position_ = prevPos + kAppendInterval;

    // Prepend to the graph
  } else if (prev() == returnNode) {
    // next() is the first element in the block list
    if (nextPos <= (kLowerBound + kAppendInterval)) {
      // we're running off the edge
      owningBlock()->reIndexTopology();
      return;
    }

    topo_position_ = nextPos - kAppendInterval;

    // Insert between two existing nodes
  } else {
    const auto posBetween = prevPos + (nextPos - prevPos) / 2;
    if (posBetween == prevPos) {
      // there was no room
      owningBlock()->reIndexTopology();
      return;
    }
    topo_position_ = posBetween;
  }
}
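// Worked example (illustrative): in a block whose nodes sit at positions 0
// and 2^40, inserting between them lands at (0 + 2^40) / 2 = 2^39. The gap
// halves on each such insertion, so about 40 inserts can target the same spot
// before posBetween == prevPos forces a reIndexTopology().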
Node::Node(Graph* graph_, NodeKind kind_)
    : kind_(kind_),
      graph_(graph_),
      owning_block_(nullptr),
      scope_(graph_->current_scope_),
      schema_(nullptr),
      topo_position_(0) {
  graph_->all_nodes.emplace(this);
}

void Node::eraseOutput(size_t i) {
  AT_ASSERT(i < outputs_.size());
  AT_ASSERT(outputs_[i]->uses().empty());
  schema_ = nullptr;
  Value* n = outputs_[i];
  outputs_.erase(outputs_.begin() + i);
  owningGraph()->freeValue(n);
  for (size_t j = i; j < outputs_.size(); j++) {
    outputs_[j]->offset_--;
  }
}

Block* Node::addBlock() {
  schema_ = nullptr;
  blocks_.push_back(new Block(owningGraph(), this));
  return blocks_.back();
}

void Node::eraseBlock(size_t i) {
  AT_ASSERT(i < blocks_.size());
  schema_ = nullptr;
  Block* n = blocks_[i];
  blocks_.erase(blocks_.begin() + i);
  n->destroy();
}

void Node::destroy() {
  while (!outputs().empty()) {
    eraseOutput(outputs().size() - 1);
  }
  while (!blocks().empty()) {
    eraseBlock(blocks().size() - 1);
  }
  removeAllInputs();
  if (inBlockList()) {
    removeFromList();
  }
  graph_->freeNode(this);
}

void Node::cloneFrom(Node* s) {
  setSourceLocation(s->getSourceLocation());
  if (s->scope_ && !s->scope_->isBlank()) {
    scope_ = s->scope_;
  }
  copyAttributes(*s);
}

void Node::replaceAllUsesWith(Node* n) {
  AT_ASSERT(outputs().size() == n->outputs().size());
  size_t nOutputs = outputs().size();
  for (size_t i = 0; i < nOutputs; i++) {
    outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
  }
}
Value* Node::insertInput(size_t i, Value* value) {
  AT_ASSERT(graph_ == value->owningGraph());
  schema_ = nullptr;
  // First we update the offsets for all existing inputs that will reside
  // after the one we're inserting. Concretely, these are the inputs at
  // indices [i, # input). Since we're inserting one input before all of
  // these inputs, increment their use offsets for this value by 1.
  for (size_t use_itr = i; use_itr < inputs_.size(); ++use_itr) {
    // See Note [User node does not uniquely identify use]
    auto use = findUseForInput(use_itr);
    use->offset += 1;
  }
  // Insert the actual input at the specified index
  inputs_.insert(inputs_.begin() + i, value);
  // Register the new use of the value we're inserting as an input
  value->uses_.emplace_back(this, i);
  return value;
}
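// Usage sketch (illustrative): if %3 = f(%1, %2), then insertInput(1, %4)
// produces %3 = f(%1, %4, %2); the recorded use of %2 moves to offset 2 so
// that Use(node, offset) still names the correct argument slot.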
Value* Node::addInput(Value* value) {
  AT_ASSERT(graph_ == value->owningGraph());
  schema_ = nullptr;
  value->uses_.emplace_back(this, inputs_.size());
  inputs_.push_back(value);
  return value;
}

Value* Node::replaceInput(size_t i, Value* newValue) {
  AT_ASSERT(newValue->owningGraph() == graph_);
  schema_ = nullptr;
  Value* old = dropInput(i);
  inputs_[i] = newValue;
  newValue->uses_.emplace_back(this, i);
  return old;
}

void Node::replaceInputWith(Value* from, Value* to) {
  AT_ASSERT(from->owningGraph() == graph_);
  AT_ASSERT(to->owningGraph() == graph_);
  schema_ = nullptr;
  size_t i = 0;
  for (auto input : inputs()) {
    if (input == from) {
      replaceInput(i, to);
    }
    i++;
  }
}

Value* Node::addOutput() {
  outputs_.push_back(new Value(this, outputs_.size()));
  schema_ = nullptr;
  return outputs_.back();
}

Value* Node::insertOutput(size_t i) {
  schema_ = nullptr;
  outputs_.insert(outputs_.begin() + i, new Value(this, i));
  for (size_t itr = i + 1; itr < outputs_.size(); ++itr) {
    outputs_[itr]->setOffset(outputs_[itr]->offset() + 1);
  }
  return outputs_.at(i);
}

bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
  if (this->owningBlock() == n->owningBlock()) {
    if (moveSide == MoveSide::BEFORE) {
      return this->topo_position_ < n->topo_position_;
    }

    if (moveSide == MoveSide::AFTER) {
      return this->topo_position_ > n->topo_position_;
    }

    AT_ASSERT(this == n);
    return false;
  }

  // These nodes don't share a common block. Traverse the blockchains upward
  // until we find the first common block.
  auto lhs = this;
  while (lhs) {
    AT_ASSERT(lhs->owningBlock());

    auto rhs = n;
    while (rhs) {
      if (!rhs->owningBlock()) {
        break;
      }

      if (lhs->owningBlock() == rhs->owningBlock()) {
        return lhs->isBeforeOrAfter(rhs, moveSide);
      }
      rhs = rhs->owningBlock()->owningNode();
    }

    lhs = lhs->owningBlock()->owningNode();
  }
  // should never reach here, since both nodes are ultimately in the same graph
  AT_ASSERT(false);
}
bool Node::isBefore(const Node* n) const {
  return isBeforeOrAfter(n, MoveSide::BEFORE);
}

bool Node::isAfter(const Node* n) const {
  return isBeforeOrAfter(n, MoveSide::AFTER);
}

Node* Node::insertBefore(Node* n) {
  AT_ASSERT(n->inBlockList());
  insertAfter(n->prev());
  return this;
}

Node* Node::insertAfter(Node* n) {
  AT_ASSERT(!inBlockList() && n->inBlockList());
  AT_ASSERT(n->owningBlock());
  this->owning_block_ = n->owningBlock();
  Node* next = n->next();
  n->next() = this;
  this->prev() = n;
  this->next() = next;
  next->prev() = this;
  assignTopoPosition();
  return this;
}

void Node::moveAfter(Node* n) {
  removeFromList();
  insertAfter(n);
}

void Node::moveBefore(Node* n) {
  removeFromList();
  insertBefore(n);
}

void Node::removeInput(size_t i) {
  schema_ = nullptr;
  dropInput(i);
  // everything after this input shifts left, so we need to update their use
  // offsets to match
  for (size_t j = i + 1; j < inputs_.size(); j++) {
    auto it = findUseForInput(j);
    it->offset--;
  }
  inputs_.erase(inputs_.begin() + i);
}

void Node::removeAllInputs() {
  schema_ = nullptr;
  for (size_t i = 0; i < inputs().size(); ++i) {
    dropInput(i);
  }
  inputs_.clear();
}

use_list::iterator Node::findUseForInput(size_t i) {
  auto& input_uses = inputs_[i]->uses_;
  // O(N) on the use list, but unless we get nodes with +100 uses
  // vector traversal is still probably faster than a linked list
  auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
  AT_ASSERT(use_it != input_uses.end());
  return use_it;
}

Value* Node::dropInput(size_t i) {
  AT_ASSERT(i < inputs_.size());
  auto input_node = inputs_[i];
  auto use_it = findUseForInput(i);
  input_node->uses_.erase(use_it);
  inputs_[i] = nullptr;
  return input_node;
}

void Node::removeFromList() {
  AT_ASSERT(inBlockList());
  this->owning_block_ = nullptr;
  Node* next = this->next();
  Node* prev = this->prev();
  prev->next() = next;
  next->prev() = prev;
  this->next() = nullptr;
  this->prev() = nullptr;
}
static const SourceRange& fakeRange() {
  static SourceRange range(
      std::make_shared<std::string>("<internally-created-node>"), 0, 1);
  return range;
}

Value* Graph::insert(
    Symbol opname,
    at::ArrayRef<NamedValue> args,
    at::ArrayRef<NamedValue> kwargs,
    const c10::optional<SourceRange>& range) {
  return script::emitBuiltinCall(
      range.value_or(fakeRange()),
      *this,
      opname,
      c10::nullopt,
      args,
      kwargs,
      /*required=*/true);
}

Node* Graph::create(NodeKind kind, size_t num_outputs) {
  // NB: the Node constructor adds the node to all_nodes
  auto n = new Node(this, kind);
  for (size_t i = 0; i < num_outputs; i++) {
    n->addOutput();
  }
  return n;
}
Node* Graph::create(
    NodeKind kind,
    ArrayRef<Value*> inputs,
    size_t num_outputs) {
  auto n = create(kind, num_outputs);
  for (auto i : inputs) {
    n->addInput(i);
  }
  return n;
}
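// A minimal graph-building sketch using the factories above (illustrative;
// assumes the default of one output per created node):
//
//   Graph g;
//   Value* x = g.addInput();
//   Value* y = g.addInput();
//   Node* n = g.insertNode(g.create(aten::mul, {x, y}));
//   g.registerOutput(n->output());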
Node* Graph::createAutogradZero() {
  return create(prim::AutogradZero);
}

Node* Graph::createNone(TypePtr typ) {
  Node* n = create(prim::Constant);
  n->output()->setType(OptionalType::create(std::move(typ)));
  return n;
}

Node* Graph::createFusionGroup() {
  auto n = create(prim::FusionGroup, 0);
  n->g_(attr::Subgraph, std::make_shared<Graph>(current_scope()));
  return n;
}

Node* Graph::createTuple(
    at::ArrayRef<Value*> values,
    c10::optional<std::vector<std::string>> field_names) {
  auto types = fmap(values, [](Value* v) { return v->type(); });
  auto tt = TupleType::create(std::move(types), std::move(field_names));
  auto n = create(prim::TupleConstruct, values);
  n->output()->setType(tt);
  return n;
}

Node* Graph::createTupleUnpack(Value* v) {
  TupleTypePtr tt = v->type()->expect<TupleType>();
  auto n = create(prim::TupleUnpack, {v}, 0);
  for (auto& element : tt->elements()) {
    n->addOutput()->setType(element);
  }
  return n;
}

Node* Graph::createTupleIndex(Value* tup, int64_t index) {
  auto n = create(prim::TupleIndex, {tup});
  n->i_(attr::index, index);
  auto tuple_type = tup->type()->expect<TupleType>();
  n->output()->setType(tuple_type->elements().at(index));
  return n;
}

Node* Graph::createTupleSlice(Value* tup, int64_t beg, int64_t end) {
  auto n = create(prim::TupleSlice, {tup});
  auto tuple_type = tup->type()->expect<TupleType>();
  n->i_(attr::beg, beg);
  n->i_(attr::end, end);
  std::vector<TypePtr> output_types;
  for (auto i = beg; i < end; ++i) {
    output_types.push_back(tuple_type->elements().at(i));
  }
  auto tt = TupleType::create(std::move(output_types));
  n->output()->setType(tt);
  return n;
}

Node* Graph::createList(const TypePtr& elem_type, at::ArrayRef<Value*> values) {
  auto n = create(prim::ListConstruct, values);
  for (const auto& v : values) {
    AT_ASSERT(v->type()->isSubtypeOf(elem_type));
  }
  n->output()->setType(ListType::create(elem_type));
  return n;
}

Node* Graph::createListUnpack(Value* v, size_t size) {
  ListTypePtr list_type = v->type()->expect<ListType>();
  TypePtr elem_type = list_type->getElementType();
  auto n = create(prim::ListUnpack, {v}, 0);
  for (size_t i = 0; i < size; ++i) {
    n->addOutput()->setType(elem_type);
  }
  return n;
}

Node* Graph::createDict(
    const TypePtr& key_type,
    const TypePtr& value_type,
    at::ArrayRef<Value*> keys,
    at::ArrayRef<Value*> values) {
  AT_ASSERT(keys.size() == values.size());
  auto n = create(prim::DictConstruct, 1);
  for (size_t i = 0; i < keys.size(); ++i) {
    AT_ASSERT(keys[i]->type()->isSubtypeOf(key_type));
    AT_ASSERT(values[i]->type()->isSubtypeOf(value_type));

    n->addInput(keys[i]);
    n->addInput(values[i]);
  }
  n->output()->setType(DictType::create(key_type, value_type));
  return n;
}

Node* Graph::createDictIndex(Value* dict, Value* index) {
  auto dict_type = dict->type()->expect<DictType>();
  AT_ASSERT(index->type()->isSubtypeOf(dict_type->getKeyType()));

  auto n = create(prim::DictIndex, {dict, index});
  n->output()->setType(dict_type->getValueType());
  return n;
}

Node* Graph::createNumToTensor(Value* value) {
  auto typ = value->type();
  Node* result = create(prim::NumToTensor, {value});
  result->output()->setType(CompleteTensorType::fromNumberType(std::move(typ)));
  return result;
}

Node* Graph::createImplicitTensorToNum(const TypePtr& type, Value* value) {
  auto* result = create(prim::ImplicitTensorToNum, {value});
  result->output()->setType(type);
  return result;
}

Node* Graph::createObject(const ClassTypePtr& type) {
  auto result = create(prim::CreateObject);
  result->output()->setType(type);
  return result;
}

Node* Graph::createSetAttr(
    Value* obj,
    const std::string& field,
    Value* newValue) {
  auto n = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
  n->s_(attr::name, field);
  return n;
}
Node* Graph::createGetAttr(Value* obj, const std::string& field) {
  const auto classType = obj->type()->expect<ClassType>();

  auto n = create(prim::GetAttr, {obj}, /*num_outputs=*/1);
  n->s_(attr::name, field);

  const auto outputType = classType->getAttribute(field);
  n->output()->setType(outputType);
  return n;
}
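// Illustrative IR pairing for the object ops above (field name hypothetical):
//
//   %obj : MyClass = prim::CreateObject()
//    = prim::SetAttr[name="weight"](%obj, %w)
//   %w2 : Tensor = prim::GetAttr[name="weight"](%obj)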
Node* Graph::createClone(
    Node* n,
    const std::function<Value*(Value*)>& value_map,
    bool copy_blocks) {
  // n can be from a different graph
  Node* r = n->allocNewInstance(this);
  for (auto o : n->outputs()) {
    r->addOutput()->copyMetadata(o);
  }
  r->cloneFrom(n);
  for (auto i : n->inputs()) {
    r->addInput(value_map(i));
  }
  if (copy_blocks) {
    for (auto b : n->blocks()) {
      r->addBlock()->cloneFrom(b, value_map);
    }
  }
  return r;
}

Value* Graph::insertConstant(
    IValue val,
    const TypePtr& result_type,
    c10::optional<SourceRange> loc,
    c10::optional<ScopePtr> scope) {
  return jit::insertConstant(
      *this, std::move(val), result_type, std::move(loc), std::move(scope));
}

std::string Graph::toString() const {
  std::ostringstream oss;
  oss << *this;
  return oss.str();
}

Graph::~Graph() {
  for (const Node* n : all_nodes) {
    delete n;
  }
  for (const Value* v : all_values) {
    delete v;
  }
  for (const Block* b : all_blocks) {
    delete b;
  }
}

void Graph::freeNode(Node* n) {
  auto it = all_nodes.find(n);
  AT_ASSERT(it != all_nodes.end());
  delete *it;
  all_nodes.erase(it);
}

void Graph::freeValue(Value* v) {
  v->setUniqueName("");
  auto it = all_values.find(v);
  AT_ASSERT(it != all_values.end());
  delete *it;
  all_values.erase(it);
}

void Graph::freeBlock(Block* b) {
  auto it = all_blocks.find(b);
  AT_ASSERT(it != all_blocks.end());
  delete *it;
  all_blocks.erase(it);
}

at::ArrayRef<Value*> createTupleUnpack(Value* v) {
  // small peephole optimization to ensure IntList attributes can still turn
  // into constants e.g. in x.expand([3, 4])
  if (v->node()->kind() == prim::TupleConstruct) {
    return v->node()->inputs();
  }
  auto& g = *v->owningGraph();
  return g.insertNode(g.createTupleUnpack(v))->outputs();
}
std::vector<Value*> inlineCallTo(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs,
    bool unpack_outputs) {
  std::unordered_map<Value*, Value*> value_map;
  auto value_map_func = [&](Value* v) { return value_map.at(v); };
  AT_ASSERT(callee.inputs().size() == inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    value_map[callee.inputs()[i]] = inputs[i];
  }
  for (auto* node : callee.nodes()) {
    auto* new_node = g.insertNode(g.createClone(node, value_map_func));
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      value_map[node->outputs()[i]] = new_node->outputs()[i];
    }
  }

  std::vector<Value*> outputs;
  for (auto* output : callee.outputs()) {
    outputs.push_back(value_map_func(output));
  }

  if (unpack_outputs && outputs.size() == 1 &&
      callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
    auto tup = outputs[0];
    outputs.clear();
    for (Value* v : createTupleUnpack(tup)) {
      outputs.emplace_back(v);
    }
    // if this was a peephole tuple unpack we can just get rid of
    // the tuple construct here and prevent needing DCE
    if (tup->node()->kind() == prim::TupleConstruct &&
        !tup->node()->hasUses()) {
      tup->node()->destroy();
    }
  }

  return outputs;
}
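// Usage sketch (illustrative; `g`, `callee`, and values `a`, `b` are
// hypothetical): splice `callee` into `g` at the current insertion point,
// binding `a` and `b` to the callee's inputs:
//
//   std::vector<Value*> outs =
//       inlineCallTo(g, callee, {a, b}, /*unpack_outputs=*/true);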
PythonOp* defaultAllocPythonOp(Graph* g) {
  throw std::runtime_error(
      "Trying to allocate a Python object without python bindings loaded");
}
std::atomic<decltype(&defaultAllocPythonOp)> alloc_python_op;

// patched in when the Python bindings are loaded
PythonOp* allocPythonOp(Graph* g) {
  return alloc_python_op.load()(g);
}

void setAllocPythonOp(PythonOp* (*v)(Graph* g)) {
  alloc_python_op.store(v);
}

} // namespace jit
} // namespace torch