Skip to content

Commit

Permalink
Improve --update-data handler
Browse files Browse the repository at this point in the history
  • Loading branch information
ikonst committed May 22, 2023
1 parent 6c7e480 commit ea590fb
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 36 deletions.
30 changes: 30 additions & 0 deletions mypy/test/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import tempfile
from abc import abstractmethod
from pathlib import Path
from typing import Any, Iterator, NamedTuple, Pattern, Union
from typing_extensions import Final, TypeAlias as _TypeAlias

Expand Down Expand Up @@ -698,6 +699,12 @@ def collect(self) -> Iterator[DataFileCollector]:
yield DataFileCollector.from_parent(parent=self, name=data_file)


class DataFileFix(NamedTuple):
lineno: int # 1-offset, inclusive
end_lineno: int # 1-offset, exclusive
lines: list[str]


class DataFileCollector(pytest.Collector):
"""Represents a single `.test` data driven test file.
Expand All @@ -706,6 +713,8 @@ class DataFileCollector(pytest.Collector):

parent: DataSuiteCollector

_fixes: list[DataFileFix]

@classmethod # We have to fight with pytest here:
def from_parent(
cls, parent: DataSuiteCollector, *, name: str # type: ignore[override]
Expand All @@ -721,6 +730,27 @@ def collect(self) -> Iterator[DataDrivenTestCase]:
file=os.path.join(self.parent.obj.data_prefix, self.name),
)

def setup(self) -> None:
super().setup()
self._fixes = []

def teardown(self) -> None:
super().teardown()
self._apply_fixes()

def enqueue_fix(self, fix: DataFileFix) -> None:
self._fixes.append(fix)

def _apply_fixes(self) -> None:
if not self._fixes:
return
data_path = Path(self.parent.obj.data_prefix) / self.name
lines = data_path.read_text().split("\n")
# start from end to prevent line offsets from shifting as we update
for fix in sorted(self._fixes, reverse=True):
lines[fix.lineno - 1 : fix.end_lineno - 1] = fix.lines
data_path.write_text("\n".join(lines))


def add_test_name_suffix(name: str, suffix: str) -> str:
# Find magic suffix of form "-foobar" (used for things like "-skip").
Expand Down
33 changes: 0 additions & 33 deletions mypy/test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,39 +141,6 @@ def assert_target_equivalence(name: str, expected: list[str], actual: list[str])
)


def update_testcase_output(testcase: DataDrivenTestCase, output: list[str]) -> None:
assert testcase.old_cwd is not None, "test was not properly set up"
testcase_path = os.path.join(testcase.old_cwd, testcase.file)
with open(testcase_path, encoding="utf8") as f:
data_lines = f.read().splitlines()
test = "\n".join(data_lines[testcase.line : testcase.last_line])

mapping: dict[str, list[str]] = {}
for old, new in zip(testcase.output, output):
PREFIX = "error:"
ind = old.find(PREFIX)
if ind != -1 and old[:ind] == new[:ind]:
old, new = old[ind + len(PREFIX) :], new[ind + len(PREFIX) :]
mapping.setdefault(old, []).append(new)

for old in mapping:
if test.count(old) == len(mapping[old]):
betweens = test.split(old)

# Interleave betweens and mapping[old]
from itertools import chain

interleaved = [betweens[0]] + list(
chain.from_iterable(zip(mapping[old], betweens[1:]))
)
test = "".join(interleaved)

data_lines[testcase.line : testcase.last_line] = [test]
data = "\n".join(data_lines)
with open(testcase_path, "w", encoding="utf8") as f:
print(data, file=f)


def show_align_message(s1: str, s2: str) -> None:
"""Align s1 and s2 so that the their first difference is highlighted.
Expand Down
68 changes: 65 additions & 3 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,24 @@
import os
import re
import sys
from collections import defaultdict
from typing import Iterator

from mypy import build
from mypy.build import Graph
from mypy.errors import CompileError
from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths
from mypy.options import TYPE_VAR_TUPLE, UNPACK
from mypy.test.config import test_data_prefix, test_temp_dir
from mypy.test.data import DataDrivenTestCase, DataSuite, FileOperation, module_from_path
from mypy.test.data import (
DataDrivenTestCase,
DataFileCollector,
DataFileFix,
DataSuite,
FileOperation,
module_from_path,
parse_test_data,
)
from mypy.test.helpers import (
assert_module_equivalence,
assert_string_arrays_equal,
Expand All @@ -22,7 +32,6 @@
normalize_error_messages,
parse_options,
perform_file_operations,
update_testcase_output,
)

try:
Expand Down Expand Up @@ -180,7 +189,11 @@ def run_case_once(
raise AssertionError()

if output != a and testcase.config.getoption("--update-data", False):
update_testcase_output(testcase, a)
collector = testcase.parent
assert isinstance(collector, DataFileCollector)
for fix in self.iter_data_file_fixes(a, testcase):
collector.enqueue_fix(fix)

assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line))

if res:
Expand Down Expand Up @@ -214,6 +227,55 @@ def run_case_once(
if testcase.output_files:
check_test_output_files(testcase, incremental_step, strip_prefix="tmp/")

def iter_data_file_fixes(
self, actual: list[str], testcase: DataDrivenTestCase
) -> Iterator[DataFileFix]:
reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list)
for error_line in actual:
comment_match = re.match(
r"^(?P<filename>[^:]+):(?P<lineno>\d+): (?P<severity>error|note|warning): (?P<msg>.+)$",
error_line,
)
if comment_match:
filename = comment_match.group("filename")
lineno = int(comment_match.group("lineno")) - 1
severity = comment_match.group("severity")
msg = comment_match.group("msg")
reports_by_line[filename, lineno].append((severity, msg))

for item in parse_test_data(testcase.data, testcase.name):
if item.id == "case":
source_lines = item.data
file_path = "main"
elif item.id == "file":
source_lines = item.data
file_path = f"tmp/{item.arg}"
else:
continue # other sections we don't touch

fix_lines = []
for lineno, source_line in enumerate(source_lines):
reports = reports_by_line.get((file_path, lineno))
comment_match = re.search(r"(?P<indent>\s+)(?P<comment># [EWN]: .+)$", source_line)
if comment_match:
source_line = source_line[: comment_match.start("indent")] # strip old comment
if reports:
indent = comment_match.group("indent") if comment_match else " "
# multiline comments are on the first line and then on subsequent lines emtpy lines
# with a continuation backslash
for j, (severity, msg) in enumerate(reports):
out_l = source_line if j == 0 else " " * len(source_line)
is_last = j == len(reports) - 1
severity_char = severity[0].upper()
continuation = "" if is_last else " \\"
fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}")
else:
fix_lines.append(source_line)

lineno = testcase.line + item.line - 1 # both testcase and item are 1-offset
end_lineno = lineno + len(item.data)
yield DataFileFix(lineno, end_lineno, fix_lines)

def verify_cache(
self,
module_data: list[tuple[str, str, str]],
Expand Down
82 changes: 82 additions & 0 deletions mypy/test/testupdatedata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import subprocess
import textwrap
from os.path import basename
from tempfile import NamedTemporaryFile

from mypy.test.config import test_data_prefix
from mypy.test.helpers import Suite


class UpdateDataSuite(Suite):
def _update_test(self, testcase: str) -> str:
with NamedTemporaryFile(
mode="w+", dir=test_data_prefix, prefix="check-update-data-", suffix=".test"
) as tmp_file:
tmp_file.write(textwrap.dedent(testcase))
tmp_file.flush()
test_nodeid = f"mypy/test/testcheck.py::TypeCheckSuite::{basename(tmp_file.name)}"
res = subprocess.run(
["pytest", "-n", "0", "-s", "--update-data", test_nodeid], capture_output=True
)
assert res.returncode == 1
tmp_file.seek(0)
return tmp_file.read()

def test_update_data(self) -> None:
actual = self._update_test(
"""
[case testCorrect]
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testWrong]
s: str = 42 # E: wrong error
[case testMissingMultiline]
s: str = 42; i: int = 'foo'
[case testExtraneous]
s: str = 'foo' # E: wrong error
[case testExtraneousMultiline]
s: str = 'foo' # E: foo \
# E: bar
[case testExtraneousMultilineNonError]
s: str = 'foo' # W: foo \
# N: bar
[case testWrongMultipleFiles]
import a, b
s: str = 42 # E: foo
[file a.py]
s1: str = 42 # E: bar
[file b.py]
s2: str = 43 # E: baz
[out]
make sure we're not touching this
[builtins fixtures/list.pyi]
"""
)

# Assert
assert actual == textwrap.dedent(
"""
[case testCorrect]
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testWrong]
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testMissingMultiline]
s: str = 42; i: int = 'foo' # E: Incompatible types in assignment (expression has type "int", variable has type "str") \\
# E: Incompatible types in assignment (expression has type "str", variable has type "int")
[case testExtraneous]
s: str = 'foo'
[case testExtraneousMultiline]
s: str = 'foo'
[case testExtraneousMultilineNonError]
s: str = 'foo'
[case testWrongMultipleFiles]
import a, b
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[file a.py]
s1: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[file b.py]
s2: str = 43 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[out]
make sure we're not touching this
[builtins fixtures/list.pyi]
"""
)

0 comments on commit ea590fb

Please sign in to comment.