Caffe2 - C++ API
A deep learning, cross platform ML framework
file_check.cpp
1 //==-- llvm/Support/FileCheck.h ---------------------------*- C++ -*-==//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 // API modified from llvm::FileCheck
11 
12 #include <c10/util/Exception.h>
13 #include <c10/util/Optional.h>
14 #include <torch/csrc/WindowsTorchApiMacro.h>
15 #include <torch/csrc/jit/source_range.h>
16 #include <algorithm>
17 #include <iostream>
18 #include <sstream>
19 #include <string>
20 
21 #include <torch/csrc/jit/ir.h>
22 #include <torch/csrc/jit/testing/file_check.h>
23 
24 namespace torch {
25 namespace jit {
26 
27 void printQuotedString(std::ostream& stmt, const std::string& str);
28 
29 namespace testing {
30 
// Kinds of check directive. Their printed names (see operator<< for Check)
// follow LLVM FileCheck: CHECK, CHECK-NEXT, CHECK-SAME, CHECK-NOT,
// CHECK-COUNT-<n>, CHECK-DAG.
enum CheckType {
  CHECK = 0,
  CHECK_NEXT = 1,
  CHECK_SAME = 2,
  CHECK_NOT = 3,
  CHECK_COUNT = 4,
  CHECK_DAG = 5
};
39 
40 struct Check {
41  Check(
42  CheckType type,
43  std::string str,
44  c10::optional<size_t> count = c10::nullopt)
45  : type_(type), search_str_(std::move(str)) {
46  count_ = std::move(count);
47  };
48 
49  CheckType type_;
50  c10::optional<size_t> count_;
51  const std::string search_str_;
52 
53  friend std::ostream& operator<<(std::ostream& out, const Check& c);
54 };
55 
56 std::ostream& operator<<(std::ostream& out, const Check& c) {
57  switch (c.type_) {
58  case CHECK:
59  out << "CHECK";
60  break;
61  case CHECK_NEXT:
62  out << "CHECK-NEXT";
63  break;
64  case CHECK_SAME:
65  out << "CHECK-SAME";
66  break;
67  case CHECK_NOT:
68  out << "CHECK-NOT";
69  break;
70  case CHECK_DAG:
71  out << "CHECK-DAG";
72  break;
73  case CHECK_COUNT:
74  out << "CHECK-COUNT-" << *c.count_;
75  break;
76  }
77  out << ": " << c.search_str_;
78  return out;
79 };
80 
81 namespace {
82 size_t assertFind(
83  const SourceRange& search_range,
84  const std::string& sub,
85  const Check& check) {
86  auto pos = search_range.file_ptr()->find(sub, search_range.start());
87  if (pos == std::string::npos || (pos + sub.size()) > search_range.end()) {
88  auto found_range =
89  SourceRange(search_range.file_ptr(), search_range.start(), sub.size());
90  std::stringstream ss;
91  ss << "Expected to find ";
92  printQuotedString(ss, sub);
93  ss << " but did not find it\n";
94  found_range.highlight(ss);
95  ss << "From " << check << "\n";
96  throw std::runtime_error(ss.str());
97  }
98  return pos;
99 }
100 
101 size_t assertFind(
102  const std::shared_ptr<std::string>& file,
103  const std::string& sub,
104  size_t start,
105  const Check& check) {
106  return assertFind(SourceRange(file, start, file->size()), sub, check);
107 }
108 
109 void assertNotFind(
110  const SourceRange& search_range,
111  const std::string& sub,
112  const Check& check) {
113  auto pos = search_range.file_ptr()->find(sub, search_range.start());
114  if (pos != std::string::npos && (pos + sub.size()) <= search_range.end()) {
115  auto found_range =
116  SourceRange(search_range.file_ptr(), pos, sub.size() + pos);
117  std::stringstream ss;
118  ss << "Expected to not find ";
119  printQuotedString(ss, sub);
120  ss << " but found it\n";
121  found_range.highlight(ss);
122  ss << "From " << check << "\n";
123  throw std::runtime_error(ss.str());
124  }
125 }
126 } // namespace
127 
129  TORCH_API explicit FileCheckImpl() = default;
130 
131  TORCH_API void run(const std::string& test_file) {
132  has_run = true;
133  doChecks(std::make_shared<std::string>(test_file));
134  }
135 
136  TORCH_API void addCheck(
137  CheckType type,
138  const std::string& s,
139  c10::optional<size_t> count = c10::nullopt) {
140  Check check(type, s, std::move(count));
141 
142  // consecutive CHECK_DAGs & CHECK_NOTs need to be evaluated as a group
143  if (groups.size() == 0 || (type != CHECK_NOT && type != CHECK_DAG)) {
144  groups.push_back({check});
145  } else {
146  auto& last_group = groups.back();
147  if (last_group.at(0).type_ == type) {
148  last_group.push_back(check);
149  } else {
150  groups.push_back({check});
151  }
152  }
153 
154  has_run = false;
155  }
156 
157  bool has_run = false;
158 
159  friend std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc);
160 
161  private:
162  void doCheckNot(
163  const std::vector<Check>& nots,
164  const std::shared_ptr<std::string>& file,
165  const SourceRange& prev,
166  const SourceRange& next) {
167  auto start = prev.end(); // inclusive
168  auto end = next.start(); // exclusive
169  if (end < start) {
170  return;
171  }
172  for (const auto& check : nots) {
173  AT_ASSERT(check.type_ == CHECK_NOT);
174  assertNotFind(SourceRange(file, start, end), check.search_str_, check);
175  }
176  }
177 
178  SourceRange matchDagGroup(
179  const std::vector<Check>& group,
180  const std::shared_ptr<std::string>& test_file,
181  const SourceRange& prev) {
182  size_t group_beg = std::string::npos;
183  size_t group_end = 0;
184 
185  AT_ASSERT(groups.size() != 0);
186  for (const auto& check : group) {
187  AT_ASSERT(check.type_ == group[0].type_);
188  auto pos = assertFind(test_file, check.search_str_, prev.end(), check);
189  group_beg = std::min(pos, group_beg);
190  group_end = std::max(pos + check.search_str_.size(), group_end);
191  }
192 
193  return SourceRange(test_file, group_beg, group_end);
194  }
195 
196  SourceRange matchGroup(
197  const std::vector<Check>& group,
198  const std::shared_ptr<std::string>& test_file,
199  const SourceRange& prev) {
200  AT_ASSERT(group.size() != 0);
201  CheckType type = group[0].type_;
202 
203  if (type == CHECK_DAG) {
204  return matchDagGroup(group, test_file, prev);
205  }
206  AT_ASSERT(type != CHECK_NOT);
207  AT_ASSERT(group.size() == 1);
208 
209  const auto& check = group[0];
210  size_t start_range = prev.end();
211  size_t end_range = start_range;
212 
213  switch (check.type_) {
214  case CHECK: {
215  start_range =
216  assertFind(test_file, check.search_str_, start_range, check);
217  end_range = start_range + check.search_str_.size();
218  } break;
219  case CHECK_SAME: {
220  auto pos = assertFind(test_file, check.search_str_, start_range, check);
221  assertNotFind(SourceRange(test_file, prev.end(), pos), "\n", check);
222  start_range = pos;
223  end_range = pos + check.search_str_.size();
224  } break;
225  case CHECK_NEXT: {
226  auto line_end = assertFind(test_file, "\n", start_range, check);
227  auto pos =
228  assertFind(test_file, check.search_str_, line_end + 1, check);
229  assertNotFind(SourceRange(test_file, line_end + 1, pos), "\n", check);
230  start_range = pos;
231  end_range = pos + check.search_str_.size();
232  } break;
233  case CHECK_COUNT: {
234  auto group_start_range = std::string::npos;
235  AT_ASSERT(check.count_ && *check.count_ != 0);
236  for (size_t i = 0; i < *check.count_; ++i) {
237  start_range =
238  assertFind(test_file, check.search_str_, start_range, check);
239  group_start_range = std::min(start_range, group_start_range);
240  end_range = start_range + check.search_str_.size();
241  start_range = end_range;
242  }
243  start_range = group_start_range;
244  } break;
245  case CHECK_DAG: {
246  AT_ERROR();
247  } break;
248  case CHECK_NOT: {
249  AT_ERROR();
250  } break;
251  }
252  return SourceRange(test_file, start_range, end_range);
253  }
254 
255  void doChecks(const std::shared_ptr<std::string>& test_file) {
256  SourceRange prev(test_file, 0, 0);
257  for (size_t i = 0; i < groups.size(); i++) {
258  const auto& curr_group = groups[i];
259  CheckType type = curr_group.at(0).type_;
260  if (type != CHECK_NOT) {
261  prev = matchGroup(curr_group, test_file, prev);
262  } else {
263  if (i + 1 < groups.size()) {
264  const auto& next_group = groups[i + 1];
265  AT_ASSERT(next_group.at(0).type_ != CHECK_NOT);
266  SourceRange after_not = matchGroup(next_group, test_file, prev);
267  doCheckNot(curr_group, test_file, prev, after_not);
268  prev = after_not;
269  ++i; // already checked the group after
270  } else {
271  SourceRange end_of_file(
272  test_file, test_file->size() + 1, test_file->size() + 1);
273  doCheckNot(curr_group, test_file, prev, end_of_file);
274  }
275  }
276  }
277  }
278 
279  std::vector<Check> checks;
280  std::shared_ptr<std::string> check_file;
281  std::vector<std::vector<Check>> groups;
282 };
283 
284 FileCheck::FileCheck() : fcImpl(new FileCheckImpl()){};
285 
286 std::ostream& operator<<(std::ostream& out, const FileCheckImpl& fc) {
287  out << "FileCheck checks:\n";
288  for (const Check& c : fc.checks) {
289  out << "\t" << c << "\n";
290  }
291  return out;
292 };
293 
294 FileCheck::~FileCheck() {
295  if (!fcImpl->has_run) {
296  std::cout << "You have not run this instance of FileCheck!\n";
297  std::cout << *fcImpl;
298  }
299  fcImpl.reset();
300 };
301 
302 void FileCheck::run(const std::string& test_file) {
303  fcImpl->run(test_file);
304 };
305 
306 void FileCheck::run(const Graph& graph) {
307  std::stringstream graph_str;
308  graph_str << graph;
309  fcImpl->run(graph_str.str());
310 };
311 
312 FileCheck* FileCheck::check(const std::string& str) {
313  fcImpl->addCheck(CHECK, str);
314  return this;
315 }
316 
317 FileCheck* FileCheck::check_not(const std::string& str) {
318  fcImpl->addCheck(CHECK_NOT, str);
319  return this;
320 }
321 
322 FileCheck* FileCheck::check_same(const std::string& str) {
323  fcImpl->addCheck(CHECK_SAME, str);
324  return this;
325 }
326 
327 FileCheck* FileCheck::check_next(const std::string& str) {
328  fcImpl->addCheck(CHECK_NEXT, str);
329  return this;
330 }
331 
332 FileCheck* FileCheck::check_count(
333  const std::string& str,
334  size_t count,
335  bool exactly) {
336  fcImpl->addCheck(CHECK_COUNT, str, count);
337  if (exactly) {
338  fcImpl->addCheck(CHECK_NOT, str);
339  }
340  return this;
341 }
342 
343 FileCheck* FileCheck::check_dag(const std::string& str) {
344  fcImpl->addCheck(CHECK_DAG, str);
345  return this;
346 }
347 } // namespace testing
348 } // namespace jit
349 } // namespace torch
Definition: jit_type.h:17