Caffe2 - Python API
A deep learning, cross-platform ML framework
expecttest.py
1 import re
2 import unittest
3 import traceback
4 import os
5 import string
6 
7 
# When truthy, mismatching inline expectations are rewritten in place instead
# of failing the test. Any non-empty value (even "0") counts as truthy.
ACCEPT = os.environ.get('EXPECTTEST_ACCEPT')
9 
10 
def nth_line(src, lineno):
    """
    Compute the starting index of the n-th line (where n is 1-indexed)

    If ``src`` has fewer than ``lineno`` lines, ``len(src)`` is returned,
    which is also what ``nth_eol`` returns for a missing line.

    >>> nth_line("aaa\\nbb\\nc", 2)
    4
    >>> nth_line("aaa\\nbb\\nc", 5)
    8
    """
    assert lineno >= 1
    pos = 0
    for _ in range(lineno - 1):
        newline = src.find('\n', pos)
        if newline == -1:
            # No such line: stop at the end of the text instead of letting
            # the -1 from find() wrap the position back to 0.
            return len(src)
        pos = newline + 1
    return pos
23 
24 
def nth_eol(src, lineno):
    """
    Return the index where 1-indexed line ``lineno`` of ``src`` ends, i.e.
    the position of its terminating newline. ``len(src)`` is returned when
    that line has no newline after it or does not exist at all.

    >>> nth_eol("aaa\\nbb\\nc", 2)
    6
    >>> nth_eol("aaa\\nbb\\nc", 3)
    8
    """
    assert lineno >= 1
    cursor = 0
    for _ in range(lineno):
        hit = src.find('\n', cursor)
        if hit == -1:
            return len(src)
        cursor = hit + 1
    # cursor sits just past the newline ending the requested line.
    return cursor - 1
40 
41 
def normalize_nl(t):
    r"""
    Return ``t`` with ``\r\n`` and lone ``\r`` line endings turned into ``\n``.

    >>> normalize_nl("a\r\nb\rc\n") == "a\nb\nc\n"
    True
    """
    # CRLF must be collapsed first, otherwise it would become two newlines.
    crlf_fixed = t.replace('\r\n', '\n')
    return crlf_fixed.replace('\r', '\n')
44 
45 
def escape_trailing_quote(s, quote):
    r"""
    Backslash-escape the final character of ``s`` if it equals ``quote``.

    This stops a string body that ends in the quote character from merging
    with the closing triple quote of the literal it is written into.

    >>> print(escape_trailing_quote("it's'", "'"))
    it's\'
    """
    if not s or s[-1] != quote:
        return s
    return s[:-1] + '\\' + quote
51 
52 
class EditHistory(object):
    """
    Record of the line-count changes made to source files.

    Each edited file maps to a list of ``(lineno, delta)`` pairs in the order
    the edits were recorded. A line number taken from before those edits can
    then be translated to the file's current layout.
    """

    def __init__(self):
        # filename -> [(line of the edit, lines added (+) / removed (-)), ...]
        self.state = {}

    def adjust_lineno(self, fn, lineno):
        """Shift ``lineno`` by every recorded edit of ``fn`` located above it."""
        for edit_loc, edit_diff in self.state.get(fn, []):
            if edit_loc < lineno:
                lineno += edit_diff
        return lineno

    def seen_file(self, fn):
        """Return True once at least one edit of ``fn`` has been recorded."""
        return fn in self.state

    def record_edit(self, fn, lineno, delta):
        """Remember that an edit at ``lineno`` of ``fn`` changed its length by ``delta`` lines."""
        edits = self.state.get(fn)
        if edits is None:
            edits = self.state[fn] = []
        edits.append((lineno, delta))


# Module-wide history consulted and updated by TestCase.assertExpectedInline.
EDIT_HISTORY = EditHistory()
73 
74 
def ok_for_raw_triple_quoted_string(s, quote):
    """
    Is this string representable inside a raw triple-quoted string?
    Due to the fact that backslashes are always treated literally,
    some strings are not representable.

    >>> ok_for_raw_triple_quoted_string("blah", quote="'")
    True
    >>> ok_for_raw_triple_quoted_string("'", quote="'")
    False
    >>> ok_for_raw_triple_quoted_string("a ''' b", quote="'")
    False
    """
    if quote * 3 in s:
        # A triple quote inside the body would terminate the literal early.
        return False
    if not s:
        return True
    # The last character must be neither the quote character (it would merge
    # with the closing quotes) nor a backslash (it would escape the closing
    # quote, and a raw string cannot escape the backslash itself).
    return s[-1] not in (quote, '\\')
89 
90 
# This operates on the REVERSED string (that's why suffix is first).
# Un-reversed, a match covers:  <raw><quote><body><quote><suffix>
#   raw    - an optional lowercase 'r' prefix; other prefixes ('R', 'u',
#            'b', ...) are not matched and stay outside the match
#   quote  - ''' or """, required to be the same style at both ends
#   body   - the literal's contents; re.DOTALL lets it span lines, and the
#            lazy .*? stops at the nearest opening quote of that style
#   suffix - the rest of the line after the closing quotes; it may not
#            contain a newline and is anchored at ^ of the reversed text,
#            so the closing quotes must sit on the line where the reversed
#            text starts (the literal nearest the line end is preferred)
RE_EXPECT = re.compile(r"^(?P<suffix>[^\n]*?)"
                       r"(?P<quote>'''|" r'""")'
                       r"(?P<body>.*?)"
                       r"(?P=quote)"
                       r"(?P<raw>r?)", re.DOTALL)
97 
98 
def replace_string_literal(src, lineno, new_string):
    r"""
    Replace a triple quoted string literal with new contents.
    Only handles printable ASCII correctly at the moment. This
    will preserve the quote style of the original string, and
    makes a best effort to preserve raw-ness (unless it is impossible
    to do so.)

    The literal replaced is the one whose closing quotes lie on
    1-indexed line ``lineno`` of ``src`` (nearest the end of that line).

    Returns a tuple of the replaced string, as well as a delta of
    number of lines added/removed. The delta counts only the newlines
    actually written: it is 0 when no literal matched, and raw strings
    (which get no leading backslash-newline) are not over-counted.

    >>> replace_string_literal("'''arf'''", 1, "barf")
    ("'''barf'''", 0)
    >>> r = replace_string_literal(" moo = '''arf'''", 1, "'a'\n\\b\n")
    >>> print(r[0])
     moo = '''\
    'a'
    \\b
    '''
    >>> r[1]
    3
    >>> replace_string_literal(" moo = '''\\\narf'''", 2, "'a'\n\\b\n")[1]
    2
    >>> print(replace_string_literal(" f('''\"\"\"''')", 1, "a ''' b")[0])
     f('''a \'\'\' b''')
    >>> replace_string_literal("r'''arf'''", 1, "a\nb")
    ("r'''a\nb'''", 1)
    """
    # Haven't implemented correct escaping for non-printable characters
    assert all(c in string.printable for c in new_string)
    i = nth_eol(src, lineno)
    new_string = normalize_nl(new_string)

    # Net change in line count, filled in by replace() if a literal matches.
    # A one-element list lets the nested function update it without
    # ``nonlocal``, which is Python 3-only.
    delta = [0]

    def replace(m):
        quote = m.group('quote')
        s = new_string
        raw = m.group('raw') == 'r'
        if not raw or not ok_for_raw_triple_quoted_string(s, quote=quote[0]):
            # Write a regular (non-raw) string: escape backslashes and any
            # quote sequence that would end the literal early.
            raw = False
            s = s.replace('\\', '\\\\')
            if quote == "'''":
                s = escape_trailing_quote(s, "'").replace("'''", r"\'\'\'")
            else:
                s = escape_trailing_quote(s, '"').replace('"""', r'\"\"\"')

        # A multi-line non-raw body is prefixed with backslash-newline (a
        # line continuation that adds nothing to the value) so its text
        # starts on its own source line. In a raw string the backslash
        # would be kept, so raw bodies get no prefix.
        new_body = "\\\n" + s if "\n" in s and not raw else s
        # Count the newlines actually written, prefix included only if present.
        delta[0] = new_body.count("\n") - m.group('body').count("\n")

        # The match is on reversed text, so the pieces are emitted in
        # reverse order and the new body is reversed to match.
        return ''.join([m.group('suffix'),
                        quote,
                        new_body[::-1],
                        quote,
                        'r' if raw else '',
                        ])

    # Having to do this in reverse is very irritating, but it's the
    # only way to make the non-greedy matches work correctly.
    return (RE_EXPECT.sub(replace, src[:i][::-1], count=1)[::-1] + src[i:], delta[0])
158 
159 
class TestCase(unittest.TestCase):
    """
    ``unittest.TestCase`` with support for inline expected-output literals.

    ``assertExpectedInline`` compares a value against a triple-quoted string
    written directly in the test source. When the EXPECTTEST_ACCEPT
    environment variable is set to a non-empty value, it rewrites that
    literal instead of failing.
    """
    longMessage = True

    def assertExpectedInline(self, actual, expect, skip=0):
        """
        Assert that ``actual`` equals the inline expected string ``expect``.

        In accept mode (module-level ``ACCEPT`` truthy) a mismatch does not
        fail. Instead, the triple-quoted literal ending on the calling source
        line is replaced with ``actual``. The first rewrite of a given file in
        this process also saves its previous contents to ``<file>.bak``.

        ``skip`` lets wrappers around this method retarget the rewrite at a
        caller further up the stack: each increment moves one frame outward.
        """
        if not ACCEPT:
            help_text = ("To accept the new output, re-run test with "
                         "envvar EXPECTTEST_ACCEPT=1 (we recommend "
                         "staging/committing your changes before doing this)")
            if hasattr(self, "assertMultiLineEqual"):
                self.assertMultiLineEqual(expect, actual, msg=help_text)
            else:
                self.assertEqual(expect, actual, msg=help_text)
            return

        if actual == expect:
            return

        # The last (2 + skip) stack entries, oldest first: entry 0 is the
        # frame whose source line holds the literal to rewrite.
        frames = traceback.extract_stack(limit=2 + skip)
        fn, lineno, _, _ = frames[0]
        print("Accepting new output for {} at {}:{}".format(self.id(), fn, lineno))
        with open(fn, 'r+') as f:
            old = f.read()

            # Earlier rewrites of this file may have shifted its lines.
            lineno = EDIT_HISTORY.adjust_lineno(fn, lineno)
            new, delta = replace_string_literal(old, lineno, actual)

            assert old != new, "Failed to substitute string at {}:{}".format(fn, lineno)

            # Back up the file only the first time we edit it.
            if not EDIT_HISTORY.seen_file(fn):
                with open(fn + ".bak", 'w') as f_bak:
                    f_bak.write(old)
            # Overwrite the whole file in place with the new text.
            f.seek(0)
            f.truncate(0)
            f.write(new)

        EDIT_HISTORY.record_edit(fn, lineno, delta)
198 
199 
if __name__ == "__main__":
    # Run directly: execute the doctest examples in this module's docstrings.
    from doctest import testmod
    testmod()