12 #include <c10/util/Exception.h> 13 #include <c10/util/Optional.h> 14 #include <torch/csrc/WindowsTorchApiMacro.h> 15 #include <torch/csrc/jit/source_range.h> 21 #include <torch/csrc/jit/ir.h> 22 #include <torch/csrc/jit/testing/file_check.h> 27 void printQuotedString(std::ostream& stmt,
const std::string& str);
45 : type_(type), search_str_(std::move(str)) {
46 count_ = std::move(count);
51 const std::string search_str_;
53 friend std::ostream& operator<<(std::ostream& out,
const Check& c);
56 std::ostream& operator<<(std::ostream& out,
const Check& c) {
74 out <<
"CHECK-COUNT-" << *c.count_;
77 out <<
": " << c.search_str_;
84 const std::string& sub,
// Look for `sub` in the file starting at search_range.start(); a hit only
// counts if it also ends at or before search_range.end(). On a miss, throw
// std::runtime_error whose message quotes `sub`, highlights part of the file,
// and prints the Check that failed.
86 auto pos = search_range.file_ptr()->find(sub, search_range.start());
87 if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
// NOTE(review): the third SourceRange argument is used as an END offset
// elsewhere (assertNotFind builds SourceRange(file, pos, sub.size() + pos)),
// but here it is `sub.size()` alone. Whenever search_range.start() >
// sub.size() the highlighted range would end before it starts. Likely
// intended: search_range.start() + sub.size() -- confirm against SourceRange.
89 SourceRange(search_range.file_ptr(), search_range.start(), sub.size());
91 ss <<
"Expected to find ";
92 printQuotedString(ss, sub);
93 ss <<
" but did not find it\n";
94 found_range.highlight(ss);
95 ss <<
"From " << check <<
"\n";
96 throw std::runtime_error(ss.str());
102 const std::shared_ptr<std::string>& file,
103 const std::string& sub,
105 const Check& check) {
106 return assertFind(
SourceRange(file, start, file->size()), sub, check);
111 const std::string& sub,
112 const Check& check) {
113 auto pos = search_range.file_ptr()->find(sub, search_range.start());
114 if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
116 SourceRange(search_range.file_ptr(), pos, sub.size() + pos);
117 std::stringstream ss;
118 ss <<
"Expected to not find ";
119 printQuotedString(ss, sub);
120 ss <<
" but found it\n";
121 found_range.highlight(ss);
122 ss <<
"From " << check <<
"\n";
123 throw std::runtime_error(ss.str());
131 TORCH_API
void run(
const std::string& test_file) {
133 doChecks(std::make_shared<std::string>(test_file));
136 TORCH_API
void addCheck(
138 const std::string& s,
140 Check check(type, s, std::move(count));
143 if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) {
144 groups.push_back({check});
146 auto& last_group = groups.back();
147 if (last_group.at(0).type_ == type) {
148 last_group.push_back(check);
150 groups.push_back({check});
157 bool has_run =
false;
159 friend std::ostream& operator<<(std::ostream& out,
const FileCheckImpl& fc);
163 const std::vector<Check>& nots,
164 const std::shared_ptr<std::string>& file,
167 auto start = prev.end();
168 auto end = next.start();
172 for (
const auto& check : nots) {
173 AT_ASSERT(check.type_ == CHECK_NOT);
174 assertNotFind(
SourceRange(file, start, end), check.search_str_, check);
179 const std::vector<Check>& group,
180 const std::shared_ptr<std::string>& test_file,
// Match every check in `group` (all must share group[0]'s type), each searched
// independently starting at prev.end(), and return the range spanning from the
// earliest match start to the latest match end.
182 size_t group_beg = std::string::npos;
183 size_t group_end = 0;
// NOTE(review): this asserts on the member `groups`, not the parameter
// `group`. If `group` were empty the loop below would not run and the result
// would be SourceRange(test_file, npos, 0). The function that dispatches to
// matchDagGroup already asserts group.size() != 0, so the intended check here
// is probably AT_ASSERT(group.size() != 0).
185 AT_ASSERT(groups.size() != 0);
186 for (
const auto& check : group) {
187 AT_ASSERT(check.type_ == group[0].type_);
188 auto pos = assertFind(test_file, check.search_str_, prev.end(), check);
189 group_beg = std::min(pos, group_beg);
190 group_end = std::max(pos + check.search_str_.size(), group_end);
193 return SourceRange(test_file, group_beg, group_end);
197 const std::vector<Check>& group,
198 const std::shared_ptr<std::string>& test_file,
200 AT_ASSERT(group.size() != 0);
201 CheckType type = group[0].type_;
203 if (type == CHECK_DAG) {
204 return matchDagGroup(group, test_file, prev);
206 AT_ASSERT(type != CHECK_NOT);
207 AT_ASSERT(group.size() == 1);
209 const auto& check = group[0];
210 size_t start_range = prev.end();
211 size_t end_range = start_range;
213 switch (check.type_) {
216 assertFind(test_file, check.search_str_, start_range, check);
217 end_range = start_range + check.search_str_.size();
220 auto pos = assertFind(test_file, check.search_str_, start_range, check);
221 assertNotFind(
SourceRange(test_file, prev.end(), pos),
"\n", check);
223 end_range = pos + check.search_str_.size();
226 auto line_end = assertFind(test_file,
"\n", start_range, check);
228 assertFind(test_file, check.search_str_, line_end + 1, check);
229 assertNotFind(
SourceRange(test_file, line_end + 1, pos),
"\n", check);
231 end_range = pos + check.search_str_.size();
234 auto group_start_range = std::string::npos;
235 AT_ASSERT(check.count_ && *check.count_ != 0);
236 for (
size_t i = 0; i < *check.count_; ++i) {
238 assertFind(test_file, check.search_str_, start_range, check);
239 group_start_range = std::min(start_range, group_start_range);
240 end_range = start_range + check.search_str_.size();
241 start_range = end_range;
243 start_range = group_start_range;
252 return SourceRange(test_file, start_range, end_range);
255 void doChecks(
const std::shared_ptr<std::string>& test_file) {
257 for (
size_t i = 0; i < groups.size(); i++) {
258 const auto& curr_group = groups[i];
259 CheckType type = curr_group.at(0).type_;
260 if (type != CHECK_NOT) {
261 prev = matchGroup(curr_group, test_file, prev);
263 if (i + 1 < groups.size()) {
264 const auto& next_group = groups[i + 1];
265 AT_ASSERT(next_group.at(0).type_ != CHECK_NOT);
266 SourceRange after_not = matchGroup(next_group, test_file, prev);
267 doCheckNot(curr_group, test_file, prev, after_not);
272 test_file, test_file->size() + 1, test_file->size() + 1);
273 doCheckNot(curr_group, test_file, prev, end_of_file);
279 std::vector<Check> checks;
280 std::shared_ptr<std::string> check_file;
281 std::vector<std::vector<Check>> groups;
286 std::ostream& operator<<(std::ostream& out,
const FileCheckImpl& fc) {
287 out <<
"FileCheck checks:\n";
288 for (
const Check& c : fc.checks) {
289 out <<
"\t" << c <<
"\n";
294 FileCheck::~FileCheck() {
295 if (!fcImpl->has_run) {
296 std::cout <<
"You have not run this instance of FileCheck!\n";
297 std::cout << *fcImpl;
302 void FileCheck::run(
const std::string& test_file) {
303 fcImpl->run(test_file);
306 void FileCheck::run(
const Graph& graph) {
307 std::stringstream graph_str;
309 fcImpl->run(graph_str.str());
312 FileCheck* FileCheck::check(
const std::string& str) {
313 fcImpl->addCheck(CHECK, str);
317 FileCheck* FileCheck::check_not(
const std::string& str) {
318 fcImpl->addCheck(CHECK_NOT, str);
322 FileCheck* FileCheck::check_same(
const std::string& str) {
323 fcImpl->addCheck(CHECK_SAME, str);
327 FileCheck* FileCheck::check_next(
const std::string& str) {
328 fcImpl->addCheck(CHECK_NEXT, str);
333 const std::string& str,
336 fcImpl->addCheck(CHECK_COUNT, str, count);
338 fcImpl->addCheck(CHECK_NOT, str);
343 FileCheck* FileCheck::check_dag(
const std::string& str) {
344 fcImpl->addCheck(CHECK_DAG, str);