6 #include <unordered_set> 9 #include <ATen/core/functional.h> 10 #include <torch/csrc/utils/memory.h> 42 using vertex_list = std::vector<Vertex<T>*>;
44 using unique_vertex = std::unique_ptr<Vertex<T>>;
46 enum class DFSDirection { forward, backward };
57 using iterator =
typename vertex_list<T>::iterator;
58 using reverse_iterator =
typename vertex_list<T>::reverse_iterator;
70 data_.erase(std::find(data_.begin(), data_.end(), v));
73 return std::find(data_.begin(), data_.end(), v) != data_.end();
77 return a->ord < b->ord;
89 reverse_iterator rbegin() {
90 return data_.rbegin();
92 reverse_iterator rend() {
97 std::vector<Vertex<T>*> data_;
100 template <
typename T>
109 template <
typename T>
112 for (
auto* v : data_) {
118 AT_ASSERT(!elt->visited_);
119 elt->visited_ =
true;
120 data_.push_back(elt);
125 return a->ord < b->ord;
129 const vertex_list<T>& vector() {
134 vertex_list<T> data_;
137 template <
typename T>
139 Vertex(
size_t ord,
T datum) : ord(ord), visited_(
false) {
140 data.push_back(datum);
146 std::string toString();
148 return edges_.in_edges;
151 return edges_.out_edges;
154 return std::move(edges_);
168 template <
typename T>
179 size_t max_size()
const {
180 return vertices_.size();
184 std::string toString();
187 size_t debugNumVertices()
const;
188 void debugCheckInvariants();
196 DFSDirection direction,
205 std::vector<unique_vertex<T>> vertices_;
209 template <
typename T>
211 return std::count_if(
212 vertices_.begin(), vertices_.end(), [](
const unique_vertex<T>& v) {
219 template <
typename T>
221 vertices_.push_back(torch::make_unique<
Vertex<T>>(vertices_.size(), datum));
222 return vertices_.back().get();
225 template <
typename T>
227 AT_ASSERT(producer != consumer);
228 AT_ASSERT(producer->out_edges().contains(consumer));
229 AT_ASSERT(consumer->in_edges().contains(producer));
230 producer->out_edges().erase(consumer);
231 consumer->in_edges().erase(producer);
234 template <
typename T>
236 for (
size_t ord = 0; ord < vertices_.size(); ++ord) {
237 const auto& vertex = vertices_.at(ord);
241 AT_ASSERTM(vertex->ord == ord, toString());
242 for (
auto* v : vertex->in_edges()) {
243 AT_ASSERTM(v->ord < ord, toString());
245 for (
auto* v : vertex->out_edges()) {
246 AT_ASSERTM(v->ord > ord, toString());
251 template <
typename T>
253 const auto& vertex = vertices_.at(ord);
261 template <
typename T>
263 for (
auto* parent : v->in_edges()) {
264 parent->out_edges().erase(v);
266 for (
auto* child : v->out_edges()) {
267 child->in_edges().erase(v);
269 auto edges = v->move_edges();
270 vertices_[v->ord] =
nullptr;
348 template <
typename T>
350 AT_ASSERT(producer != consumer);
354 bool is_distinct = producer->out_edges().insert(consumer);
357 is_distinct = consumer->in_edges().insert(producer);
358 AT_ASSERT(is_distinct);
360 if (producer->ord <= consumer->ord) {
371 DFSDirection::forward,
377 AT_ERROR(
"Cycle detected while trying to add edge.");
382 AT_ASSERT(!dfsSearch(
383 DFSDirection::backward,
391 reorder(std::move(deltaF), std::move(deltaB));
401 template <
typename T>
403 AT_ASSERT(producer != consumer);
404 if (contractionProducesCycle(producer, consumer)) {
408 removeEdge(producer, consumer);
411 if (producer->out_edges().size() > consumer->in_edges().size()) {
412 mergeConsumerIntoProducer(producer, consumer);
414 mergeProducerIntoConsumer(producer, consumer);
420 template <
typename T>
426 producer->data.insert(
427 producer->data.end(), consumer->data.begin(), consumer->data.end());
428 std::swap(consumer->data, producer->data);
430 auto edges = removeVertex(producer);
434 for (
auto* parent : edges.in_edges) {
435 addEdge(parent, consumer);
440 for (
auto* child : edges.out_edges) {
441 addEdge(consumer, child);
445 template <
typename T>
449 producer->data.insert(
450 producer->data.end(), consumer->data.begin(), consumer->data.end());
452 auto edges = removeVertex(consumer);
456 for (
auto* child : edges.out_edges) {
457 addEdge(producer, child);
462 for (
auto* parent : edges.in_edges) {
463 addEdge(parent, producer);
467 template <
typename T>
478 size_t upper_bound = consumer->ord;
479 for (
auto* child : producer->out_edges()) {
480 if (child == consumer)
482 if (child->visited())
485 DFSDirection::forward, child, consumer, upper_bound, visited)) {
492 static bool is_within_bound(
493 DFSDirection direction,
496 if (direction == DFSDirection::forward) {
497 return value < bound;
499 return value > bound;
506 template <
typename T>
508 DFSDirection direction,
513 vertex_list<T> stack;
516 visited.push_back(v);
522 while (!stack.empty()) {
523 auto* vertex = stack.back();
526 auto& next_edges = (direction == DFSDirection::forward)
527 ? vertex->out_edges()
528 : vertex->in_edges();
536 if (!next->visited() && is_within_bound(direction, next->ord, bound)) {
545 template <
typename T>
550 const auto& deltaB_ = deltaB.vector();
551 const auto& deltaF_ = deltaF.vector();
553 size_t num_affected = deltaB_.size() + deltaF_.size();
556 std::vector<unique_vertex<T>> desired_vertex_ordering;
557 desired_vertex_ordering.reserve(num_affected);
558 for (
auto it = deltaB_.begin(); it != deltaB_.end(); ++it) {
559 desired_vertex_ordering.push_back(std::move(vertices_.at((*it)->ord)));
561 for (
auto it = deltaF_.begin(); it != deltaF_.end(); ++it) {
562 desired_vertex_ordering.push_back(std::move(vertices_.at((*it)->ord)));
569 std::vector<size_t> gathered_ords;
570 gathered_ords.reserve(num_affected);
571 for (
const auto* v : deltaB_) {
572 gathered_ords.push_back(v->ord);
574 auto middle = gathered_ords.size();
575 for (
const auto* v : deltaF_) {
576 gathered_ords.push_back(v->ord);
579 gathered_ords.begin(),
580 gathered_ords.begin() + middle,
581 gathered_ords.end());
584 for (
size_t i = 0; i < num_affected; ++i) {
585 desired_vertex_ordering[i]->ord = gathered_ords[i];
586 vertices_[gathered_ords[i]] = std::move(desired_vertex_ordering[i]);
590 template <
typename T>
592 std::stringstream ss;
593 for (
auto& v : vertices_) {
595 ss << v->toString() <<
"\n";
601 template <
typename T>
603 std::stringstream ss;
604 ss <<
"node(" << ord <<
")\n";
606 for (
auto* c : in_edges()) {
610 for (
auto& d : data) {
611 if (std::is_pointer<T>::value) {
617 ss <<
"} (" << ord <<
") -> [";
618 for (
auto* c : out_edges()) {
Flush-To-Zero and Denormals-Are-Zero mode.