Caffe2 - C++ API
A deep learning, cross platform ML framework
node_hashing.cpp
1 #include <torch/csrc/jit/ir.h>
2 
3 #include <algorithm>
4 #include <unordered_map>
5 
6 #include <ATen/core/functional.h>
7 #include <ATen/core/interned_strings.h>
8 #include <c10/util/Exception.h>
9 #include <torch/csrc/jit/node_hashing.h>
10 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
11 #include <torch/csrc/utils/hash.h>
12 
13 namespace torch {
14 namespace jit {
15 
16 namespace {
17 
18 bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
19  return &lhs.type() == &rhs.type() && lhs.equal(rhs);
20 }
21 
22 bool tensorListEqual(
23  const std::vector<at::Tensor>& lhs,
24  const std::vector<at::Tensor>& rhs) {
25  if (lhs.size() != rhs.size())
26  return false;
27  return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
28 }
29 
30 // Check whether two nodes have the same attributes in CSE.
31 // This function may be too conservative for general use.
32 // Do NOT support g/gs attributes.
33 bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
34  AT_ASSERT(lhs != nullptr);
35  AT_ASSERT(rhs != nullptr);
36  // One has attributes, the other does not.
37  if (lhs->hasAttributes() != rhs->hasAttributes())
38  return false;
39  // Neither has attributes.
40  if (!lhs->hasAttributes() && !rhs->hasAttributes())
41  return true;
42 
43  auto lnames = lhs->attributeNames();
44  auto rnames = rhs->attributeNames();
45  std::sort(lnames.begin(), lnames.end());
46  std::sort(rnames.begin(), rnames.end());
47  if (lnames != rnames)
48  return false;
49 
50  for (auto name : lnames) {
51  if (lhs->kindOf(name) != rhs->kindOf(name))
52  return false;
53 
54 #define COMPARE_ATTRIBUTEVALUE(type) \
55  case AttributeKind::type: { \
56  if (lhs->type(name) != rhs->type(name)) \
57  return false; \
58  } break;
59 
60  switch (lhs->kindOf(name)) {
61  COMPARE_ATTRIBUTEVALUE(f)
62  COMPARE_ATTRIBUTEVALUE(fs)
63  COMPARE_ATTRIBUTEVALUE(i)
64  COMPARE_ATTRIBUTEVALUE(is)
65  COMPARE_ATTRIBUTEVALUE(s)
66  COMPARE_ATTRIBUTEVALUE(ss)
67  case AttributeKind::t: {
68  if (!tensorEqual(lhs->t(name), rhs->t(name)))
69  return false;
70  break;
71  }
72  case AttributeKind::ts: {
73  if (!tensorListEqual(lhs->ts(name), rhs->ts(name)))
74  return false;
75  break;
76  }
77  case AttributeKind::g:
78  case AttributeKind::gs:
79  return false;
80  }
81 
82 #undef COMPARE_ATTRIBUTEVALUE
83  }
84 
85  return true;
86 }
87 
88 } // anonymous namespace
89 
90 size_t HashNode::operator()(const Node* k) const {
91  AT_ASSERT(k != nullptr);
92  return get_hash(
93  k->kind(),
94  fmap(k->outputs(), [](const Value* v) { return v->type()->kind(); }),
95  fmap(k->inputs(), [](const Value* v) { return v->unique(); }));
96 };
97 
98 bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
99  if (lhs == nullptr && rhs == nullptr)
100  return true;
101  if (lhs == nullptr || rhs == nullptr)
102  return false;
103 
104  if (lhs->kind() != rhs->kind())
105  return false;
106 
107  // Check whether the output types are the same.
108  auto lhs_outputs = lhs->outputs();
109  auto rhs_outputs = rhs->outputs();
110  if (lhs_outputs.size() != rhs_outputs.size())
111  return false;
112  for (size_t i = 0; i < lhs_outputs.size(); ++i) {
113  if (*lhs_outputs[i]->type() != *rhs_outputs[i]->type())
114  return false;
115  }
116 
117  // Check whether the inputs are the same.
118  auto lhs_inputs = lhs->inputs();
119  auto rhs_inputs = rhs->inputs();
120  if (lhs_inputs.size() != rhs_inputs.size())
121  return false;
122  if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin()))
123  return false;
124 
125  if (!attributesEqualCSE(lhs, rhs))
126  return false;
127 
128  return true;
129 };
130 
131 } // namespace jit
132 } // namespace torch
Definition: jit_type.h:17