1 #include <torch/csrc/jit/ir.h> 4 #include <unordered_map> 6 #include <ATen/core/functional.h> 7 #include <ATen/core/interned_strings.h> 8 #include <c10/util/Exception.h> 9 #include <torch/csrc/jit/node_hashing.h> 10 #include <torch/csrc/jit/passes/common_subexpression_elimination.h> 11 #include <torch/csrc/utils/hash.h> 19 return &lhs.type() == &rhs.type() && lhs.equal(rhs);
23 const std::vector<at::Tensor>& lhs,
24 const std::vector<at::Tensor>& rhs) {
25 if (lhs.size() != rhs.size())
27 return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
33 bool attributesEqualCSE(
const Node* lhs,
const Node* rhs) {
34 AT_ASSERT(lhs !=
nullptr);
35 AT_ASSERT(rhs !=
nullptr);
37 if (lhs->hasAttributes() != rhs->hasAttributes())
40 if (!lhs->hasAttributes() && !rhs->hasAttributes())
43 auto lnames = lhs->attributeNames();
44 auto rnames = rhs->attributeNames();
45 std::sort(lnames.begin(), lnames.end());
46 std::sort(rnames.begin(), rnames.end());
50 for (
auto name : lnames) {
51 if (lhs->kindOf(name) != rhs->kindOf(name))
54 #define COMPARE_ATTRIBUTEVALUE(type) \ 55 case AttributeKind::type: { \ 56 if (lhs->type(name) != rhs->type(name)) \ 60 switch (lhs->kindOf(name)) {
61 COMPARE_ATTRIBUTEVALUE(f)
62 COMPARE_ATTRIBUTEVALUE(fs)
63 COMPARE_ATTRIBUTEVALUE(i)
64 COMPARE_ATTRIBUTEVALUE(is)
65 COMPARE_ATTRIBUTEVALUE(s)
66 COMPARE_ATTRIBUTEVALUE(ss)
67 case AttributeKind::t: {
68 if (!tensorEqual(lhs->t(name), rhs->t(name)))
72 case AttributeKind::ts: {
73 if (!tensorListEqual(lhs->ts(name), rhs->ts(name)))
77 case AttributeKind::g:
78 case AttributeKind::gs:
82 #undef COMPARE_ATTRIBUTEVALUE 90 size_t HashNode::operator()(
const Node* k)
const {
91 AT_ASSERT(k !=
nullptr);
94 fmap(k->outputs(), [](
const Value* v) {
return v->type()->kind(); }),
95 fmap(k->inputs(), [](
const Value* v) {
return v->unique(); }));
98 bool EqualNode::operator()(
const Node* lhs,
const Node* rhs)
const {
99 if (lhs ==
nullptr && rhs ==
nullptr)
101 if (lhs ==
nullptr || rhs ==
nullptr)
104 if (lhs->kind() != rhs->kind())
108 auto lhs_outputs = lhs->outputs();
109 auto rhs_outputs = rhs->outputs();
110 if (lhs_outputs.size() != rhs_outputs.size())
112 for (
size_t i = 0; i < lhs_outputs.size(); ++i) {
113 if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
118 auto lhs_inputs = lhs->inputs();
119 auto rhs_inputs = rhs->inputs();
120 if (lhs_inputs.size() != rhs_inputs.size())
122 if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin()))
125 if (!attributesEqualCSE(lhs, rhs))