Fix Python 3.8 expecttest machinery again, this time for good. (#60044)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60044
In #59709 I attempted to fix the expecttest machinery to work in Python
3.8. However, I noticed that it would fail to do substitutions in this
case:
```
self.assertExpectedInline(
foo(),
"""bar"""
)
```
This is because the triple quoted string is not on the same line as the
backtrace line number (at the very beginning), and for safety reasons
the preexisting regex refused to search beyond the first line. This
wasn't a big deal prior to Python 3.8 because the flipped version of
the regex simply required the triple quoted string to be flush with
the end of the statement (which it typically was!) But it is a big deal
now that we only have the start of the statement.
I couldn't think of a way to fix this in the current model, so I decided
to call in the big guns. Instead of trying to do the regex with only
the start xor end line number, I now require you provide BOTH line numbers,
and we will only regex within this range. The way we compute these line
numbers is by parsing the Python test file with ast, and then searching
through statements until we find one that is consistent with the line
number reported by the backtrace. If we don't find anything, we
conservatively assume that the string lies exactly in the backtrace
(and you'll probably fail the substitution in that case.)
The resulting code is quite a lot simpler (no more reversed regex) and
hopefully more robust, although I suppose we are going to have to do
some field testing.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: walterddr
Differential Revision: D29146943
Pulled By: ezyang
fbshipit-source-id: 2c24abc3acd4275c5b3a8f222d2a60cbad5e8c78
diff --git a/test/test_expecttest.py b/test/test_expecttest.py
index dbf2432..d193dbb 100644
--- a/test/test_expecttest.py
+++ b/test/test_expecttest.py
@@ -30,8 +30,8 @@
return len("\n".join(xs))
self.assertEqual(expecttest.nth_line(t, lineno), nth_line_ref(t, lineno))
- @hypothesis.given(text(string.printable), booleans(), sampled_from(['"', "'"]), booleans())
- def test_replace_string_literal_roundtrip(self, t, raw, quote, lineno_at_start):
+ @hypothesis.given(text(string.printable), booleans(), sampled_from(['"', "'"]))
+ def test_replace_string_literal_roundtrip(self, t, raw, quote):
if raw:
hypothesis.assume(expecttest.ok_for_raw_triple_quoted_string(t, quote=quote))
prog = """\
@@ -40,7 +40,7 @@
r3 = {r}{quote}placeholder3{quote}
""".format(r='r' if raw else '', quote=quote * 3)
new_prog = expecttest.replace_string_literal(
- textwrap.dedent(prog), 2, t, lineno_at_start=lineno_at_start)[0]
+ textwrap.dedent(prog), 2, 2, t)[0]
ns : Dict[str, Any] = {}
exec(new_prog, ns)
msg = "program was:\n{}".format(new_prog)
@@ -48,7 +48,7 @@
self.assertEqual(ns['r2'], expecttest.normalize_nl(t), msg=msg) # noqa: F821
self.assertEqual(ns['r3'], 'placeholder3', msg=msg) # noqa: F821
- def test_sample_lineno_at_end(self):
+ def test_sample_lineno(self):
prog = r"""
single_single('''0''')
single_multi('''1''')
@@ -65,22 +65,27 @@
multi_multi_more('''\
6
''')
+different_indent(
+ RuntimeError,
+ '''7'''
+)
"""
- # NB: These are the end of the statements, not beginning
- # TODO: Test other permutations of these edits
- edits = [(2, "a"),
- (3, "b\n"),
- (6, "c"),
- (10, "d\n"),
- (13, "e\n"),
- (16, "f\ng\n")]
+ edits = [(2, 2, "a"),
+ (3, 3, "b\n"),
+ (4, 6, "c"),
+ (7, 10, "d\n"),
+ (11, 13, "e\n"),
+ (14, 16, "f\ng\n"),
+ (17, 20, "h")]
history = expecttest.EditHistory()
fn = 'not_a_real_file.py'
- for lineno, actual in edits:
- lineno = history.adjust_lineno(fn, lineno)
+ for start_lineno, end_lineno, actual in edits:
+ start_lineno = history.adjust_lineno(fn, start_lineno)
+ end_lineno = history.adjust_lineno(fn, end_lineno)
prog, delta = expecttest.replace_string_literal(
- prog, lineno, actual, lineno_at_start=False)
- history.record_edit(fn, lineno, delta)
+ prog, start_lineno, end_lineno, actual)
+ # NB: it doesn't really matter start/end you record edit at
+ history.record_edit(fn, start_lineno, delta)
self.assertExpectedInline(prog, r"""
single_single('''a''')
single_multi('''\
@@ -97,56 +102,10 @@
f
g
''')
-""")
-
- def test_sample_lineno_at_start(self):
- prog = r"""
-single_single('''0''')
-single_multi('''1''')
-multi_single('''\
-2
-''')
-multi_multi_less('''\
-3
-4
-''')
-multi_multi_same('''\
-5
-''')
-multi_multi_more('''\
-6
-''')
-"""
- # NB: These are the beginning of the statements
- edits = [(2, "a"),
- (3, "b\n"),
- (4, "c"),
- (7, "d\n"),
- (11, "e\n"),
- (14, "f\ng\n")]
- history = expecttest.EditHistory()
- fn = 'not_a_real_file.py'
- for lineno, actual in edits:
- lineno = history.adjust_lineno(fn, lineno)
- prog, delta = expecttest.replace_string_literal(
- prog, lineno, actual, lineno_at_start=True)
- history.record_edit(fn, lineno, delta)
- self.assertExpectedInline(prog, r"""
-single_single('''a''')
-single_multi('''\
-b
-''')
-multi_single('''c''')
-multi_multi_less('''\
-d
-''')
-multi_multi_same('''\
-e
-''')
-multi_multi_more('''\
-f
-g
-''')
+different_indent(
+ RuntimeError,
+ '''h'''
+)
""")
def test_lineno_assumptions(self):
diff --git a/torch/testing/_internal/expecttest.py b/torch/testing/_internal/expecttest.py
index 7fff670..900274e 100644
--- a/torch/testing/_internal/expecttest.py
+++ b/torch/testing/_internal/expecttest.py
@@ -4,6 +4,7 @@
import os
import string
import sys
+import ast
from typing import Tuple
@@ -137,7 +138,6 @@
RE_EXPECT = re.compile(
(
- r"^(?P<prefix>[^\n]*?)"
r"(?P<raw>r?)"
r"(?P<quote>'''|" r'""")'
r"(?P<body>.*?)"
@@ -147,17 +147,8 @@
)
-# This operates on the REVERSED string (that's why suffix is first)
-RE_REVERSED_EXPECT = \
- re.compile(r"^(?P<suffix>[^\n]*?)"
- r"(?P<quote>'''|" r'""")'
- r"(?P<body>.*?)"
- r"(?P=quote)"
- r"(?P<raw>r?)", re.DOTALL)
-
-
-def replace_string_literal(src : str, lineno : int,
- new_string : str, *, lineno_at_start: bool) -> Tuple[str, int]:
+def replace_string_literal(src : str, start_lineno : int, end_lineno : int,
+ new_string : str) -> Tuple[str, int]:
r"""
Replace a triple quoted string literal with new contents.
Only handles printable ASCII correctly at the moment. This
@@ -168,9 +159,9 @@
Returns a tuple of the replaced string, as well as a delta of
number of lines added/removed.
- >>> replace_string_literal("'''arf'''", 1, "barf", lineno_at_start=False)
+ >>> replace_string_literal("'''arf'''", 1, 1, "barf")
("'''barf'''", 0)
- >>> r = replace_string_literal(" moo = '''arf'''", 1, "'a'\n\\b\n", lineno_at_start=False)
+ >>> r = replace_string_literal(" moo = '''arf'''", 1, 1, "'a'\n\\b\n")
>>> print(r[0])
moo = '''\
'a'
@@ -178,9 +169,9 @@
'''
>>> r[1]
3
- >>> replace_string_literal(" moo = '''\\\narf'''", 2, "'a'\n\\b\n", lineno_at_start=False)[1]
+ >>> replace_string_literal(" moo = '''\\\narf'''", 1, 2, "'a'\n\\b\n")[1]
2
- >>> print(replace_string_literal(" f('''\"\"\"''')", 1, "a ''' b", lineno_at_start=False)[0])
+ >>> print(replace_string_literal(" f('''\"\"\"''')", 1, 1, "a ''' b")[0])
f('''a \'\'\' b''')
"""
# Haven't implemented correct escaping for non-printable characters
@@ -192,7 +183,12 @@
if delta[0] > 0:
delta[0] += 1 # handle the extra \\\n
- def compute_raw_new_body_and_adjust_delta(m):
+ assert start_lineno <= end_lineno
+ start = nth_line(src, start_lineno)
+ end = nth_eol(src, end_lineno)
+ assert start <= end
+
+ def replace(m):
s = new_string
raw = m.group('raw') == 'r'
if not raw or not ok_for_raw_triple_quoted_string(s, quote=m.group('quote')[0]):
@@ -205,39 +201,13 @@
new_body = "\\\n" + s if "\n" in s and not raw else s
delta[0] -= m.group('body').count("\n")
- return raw, new_body
+ return ''.join(['r' if raw else '',
+ m.group('quote'),
+ new_body,
+ m.group('quote'),
+ ])
- if lineno_at_start:
- i = nth_line(src, lineno)
-
- # i points to the start of the string
- def replace(m):
- raw, new_body = compute_raw_new_body_and_adjust_delta(m)
- return ''.join([m.group('prefix'),
- 'r' if raw else '',
- m.group('quote'),
- new_body,
- m.group('quote'),
- ])
-
- return (src[:i] + RE_EXPECT.sub(replace, src[i:], count=1), delta[0])
- else:
- i = nth_eol(src, lineno)
-
- # i points to the END of the string. Do some funny
- # business with reversing the string to do the replace
- def replace(m):
- raw, new_body = compute_raw_new_body_and_adjust_delta(m)
- return ''.join([m.group('suffix'),
- m.group('quote'),
- new_body[::-1],
- m.group('quote'),
- 'r' if raw else '',
- ])
-
- # Having to do this in reverse is very irritating, but it's the
- # only way to make the non-greedy matches work correctly.
- return (RE_REVERSED_EXPECT.sub(replace, src[:i][::-1], count=1)[::-1] + src[i:], delta[0])
+ return (src[:start] + RE_EXPECT.sub(replace, src[start:end], count=1) + src[end:], delta[0])
class TestCase(unittest.TestCase):
@@ -263,15 +233,35 @@
print("Accepting new output for {} at {}:{}".format(self.id(), fn, lineno))
with open(fn, 'r+') as f:
old = f.read()
+ old_ast = ast.parse(old)
- # compute the change in lineno
+ # NB: it's only the traceback line numbers that are wrong;
+ # we reread the file every time we write to it, so AST's
+ # line numbers are correct
lineno = EDIT_HISTORY.adjust_lineno(fn, lineno)
- new, delta = replace_string_literal(
- old, lineno, actual,
- lineno_at_start=LINENO_AT_START
- )
- assert old != new, f"Failed to substitute string at {fn}:{lineno}; did you use triple quotes?"
+ # Conservative assumption to start
+ start_lineno = lineno
+ end_lineno = lineno
+ # Try to give a more accurate bounds based on AST
+ # NB: this walk is in no specified order (in practice it's
+ # breadth first)
+ for n in ast.walk(old_ast):
+ if isinstance(n, ast.Expr):
+ if hasattr(n, 'end_lineno'):
+ assert LINENO_AT_START
+ if n.lineno == start_lineno:
+ end_lineno = n.end_lineno # type: ignore[attr-defined]
+ else:
+ if n.lineno == end_lineno:
+ start_lineno = n.lineno
+
+ new, delta = replace_string_literal(old, start_lineno, end_lineno, actual)
+
+ assert old != new, f"Failed to substitute string at {fn}:{lineno}; did you use triple quotes? " \
+ "If this is unexpected, please file a bug report at " \
+ "https://github.com/pytorch/pytorch/issues/new?labels=module:%20expecttest " \
+ f"with the contents of the source file near {fn}:{lineno}"
# Only write the backup file the first time we hit the
# file