From 364e6d77f051b40d22ac9071ef64bc12f3e6a1fe Mon Sep 17 00:00:00 2001 From: Anthony Sottile Date: Mon, 16 Sep 2024 20:05:29 -0400 Subject: [PATCH] change migrate-config to use yaml parse tree instead --- pre_commit/commands/migrate_config.py | 58 ++++++++++++++++++++++----- pre_commit/yaml.py | 1 + pre_commit/yaml_rewrite.py | 52 ++++++++++++++++++++++++ tests/commands/migrate_config_test.py | 46 +++++++++++++++++++++ tests/yaml_rewrite_test.py | 47 ++++++++++++++++++++++ 5 files changed, 194 insertions(+), 10 deletions(-) create mode 100644 pre_commit/yaml_rewrite.py create mode 100644 tests/yaml_rewrite_test.py diff --git a/pre_commit/commands/migrate_config.py b/pre_commit/commands/migrate_config.py index 842fb3a7..cdce83f5 100644 --- a/pre_commit/commands/migrate_config.py +++ b/pre_commit/commands/migrate_config.py @@ -1,13 +1,20 @@ from __future__ import annotations -import re +import functools import textwrap +from typing import Callable import cfgv import yaml +from yaml.nodes import ScalarNode from pre_commit.clientlib import InvalidConfigError +from pre_commit.yaml import yaml_compose from pre_commit.yaml import yaml_load +from pre_commit.yaml_rewrite import MappingKey +from pre_commit.yaml_rewrite import MappingValue +from pre_commit.yaml_rewrite import match +from pre_commit.yaml_rewrite import SequenceItem def _is_header_line(line: str) -> bool: @@ -38,16 +45,48 @@ def _migrate_map(contents: str) -> str: return contents -def _migrate_sha_to_rev(contents: str) -> str: - return re.sub(r'(\n\s+)sha:', r'\1rev:', contents) +def _preserve_style(n: ScalarNode, *, s: str) -> str: + return f'{n.style}{s}{n.style}' -def _migrate_python_venv(contents: str) -> str: - return re.sub( - r'(\n\s+)language: python_venv\b', - r'\1language: python', - contents, +def _migrate_composed(contents: str) -> str: + tree = yaml_compose(contents) + rewrites: list[tuple[ScalarNode, Callable[[ScalarNode], str]]] = [] + + # sha -> rev + sha_to_rev_replace = functools.partial(_preserve_style, s='rev') + sha_to_rev_matcher = ( + MappingValue('repos'), + SequenceItem(), + MappingKey('sha'), ) + for node in match(tree, sha_to_rev_matcher): + rewrites.append((node, sha_to_rev_replace)) + + # python_venv -> python + language_matcher = ( + MappingValue('repos'), + SequenceItem(), + MappingValue('hooks'), + SequenceItem(), + MappingValue('language'), + ) + python_venv_replace = functools.partial(_preserve_style, s='python') + for node in match(tree, language_matcher): + if node.value == 'python_venv': + rewrites.append((node, python_venv_replace)) + + rewrites.sort(reverse=True, key=lambda nf: nf[0].start_mark.index) + + src_parts = [] + end: int | None = None + for node, func in rewrites: + src_parts.append(contents[node.end_mark.index:end]) + src_parts.append(func(node)) + end = node.start_mark.index + src_parts.append(contents[:end]) + src_parts.reverse() + return ''.join(src_parts) def migrate_config(config_file: str, quiet: bool = False) -> int: @@ -62,8 +101,7 @@ def migrate_config(config_file: str, quiet: bool = False) -> int: raise cfgv.ValidationError(str(e)) contents = _migrate_map(contents) - contents = _migrate_sha_to_rev(contents) - contents = _migrate_python_venv(contents) + contents = _migrate_composed(contents) if contents != orig_contents: with open(config_file, 'w') as f: diff --git a/pre_commit/yaml.py b/pre_commit/yaml.py index bdf4ec47..a5bbbc99 100644 --- a/pre_commit/yaml.py +++ b/pre_commit/yaml.py @@ -6,6 +6,7 @@ from typing import Any import yaml Loader = getattr(yaml, 'CSafeLoader', yaml.SafeLoader) +yaml_compose = functools.partial(yaml.compose, Loader=Loader) yaml_load = functools.partial(yaml.load, Loader=Loader) Dumper = getattr(yaml, 'CSafeDumper', yaml.SafeDumper) diff --git a/pre_commit/yaml_rewrite.py b/pre_commit/yaml_rewrite.py new file mode 100644 index 00000000..8d0e8fdb --- /dev/null +++ b/pre_commit/yaml_rewrite.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from collections.abc import Generator +from collections.abc import Iterable +from typing import NamedTuple +from typing import Protocol + +from yaml.nodes import MappingNode +from yaml.nodes import Node +from yaml.nodes import ScalarNode +from yaml.nodes import SequenceNode + + +class _Matcher(Protocol): + def match(self, n: Node) -> Generator[Node]: ... + + +class MappingKey(NamedTuple): + k: str + + def match(self, n: Node) -> Generator[Node]: + if isinstance(n, MappingNode): + for k, _ in n.value: + if k.value == self.k: + yield k + + +class MappingValue(NamedTuple): + k: str + + def match(self, n: Node) -> Generator[Node]: + if isinstance(n, MappingNode): + for k, v in n.value: + if k.value == self.k: + yield v + + +class SequenceItem(NamedTuple): + def match(self, n: Node) -> Generator[Node]: + if isinstance(n, SequenceNode): + yield from n.value + + +def _match(gen: Iterable[Node], m: _Matcher) -> Iterable[Node]: + return (n for src in gen for n in m.match(src)) + + +def match(n: Node, matcher: tuple[_Matcher, ...]) -> Generator[ScalarNode]: + gen: Iterable[Node] = (n,) + for m in matcher: + gen = _match(gen, m) + return (n for n in gen if isinstance(n, ScalarNode)) diff --git a/tests/commands/migrate_config_test.py b/tests/commands/migrate_config_test.py index ba184636..c563866d 100644 --- a/tests/commands/migrate_config_test.py +++ b/tests/commands/migrate_config_test.py @@ -134,6 +134,27 @@ def test_migrate_config_sha_to_rev(tmpdir): ) +def test_migrate_config_sha_to_rev_json(tmp_path): + contents = """\ +{"repos": [{ + "repo": "https://github.com/pre-commit/pre-commit-hooks", + "sha": "v1.2.0", + "hooks": [] +}]} +""" + expected = """\ +{"repos": [{ + "repo": "https://github.com/pre-commit/pre-commit-hooks", + "rev": "v1.2.0", + "hooks": [] +}]} +""" + cfg = tmp_path.joinpath('cfg.yaml') + cfg.write_text(contents) + assert not migrate_config(str(cfg)) + assert cfg.read_text() == expected + + def test_migrate_config_language_python_venv(tmp_path): src = '''\ repos: @@ -167,6 +188,31 @@ repos: assert cfg.read_text() == expected +def test_migrate_config_quoted_python_venv(tmp_path): + src = '''\ +repos: +- repo: local + hooks: + - id: example + name: example + entry: example + language: "python_venv" +''' + expected = '''\ +repos: +- repo: local + hooks: + - id: example + name: example + entry: example + language: "python" +''' + cfg = tmp_path.joinpath('cfg.yaml') + cfg.write_text(src) + assert migrate_config(str(cfg)) == 0 + assert cfg.read_text() == expected + + def test_migrate_config_invalid_yaml(tmpdir): contents = '[' cfg = tmpdir.join(C.CONFIG_FILE) diff --git a/tests/yaml_rewrite_test.py b/tests/yaml_rewrite_test.py new file mode 100644 index 00000000..d0f6841c --- /dev/null +++ b/tests/yaml_rewrite_test.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import pytest + +from pre_commit.yaml import yaml_compose +from pre_commit.yaml_rewrite import MappingKey +from pre_commit.yaml_rewrite import MappingValue +from pre_commit.yaml_rewrite import match +from pre_commit.yaml_rewrite import SequenceItem + + +def test_match_produces_scalar_values_only(): + src = '''\ +- name: foo +- name: [not, foo] # not a scalar: should be skipped! +- name: bar +''' + matcher = (SequenceItem(), MappingValue('name')) + ret = [n.value for n in match(yaml_compose(src), matcher)] + assert ret == ['foo', 'bar'] + + +@pytest.mark.parametrize('cls', (MappingKey, MappingValue)) +def test_mapping_not_a_map(cls): + m = cls('s') + assert list(m.match(yaml_compose('[foo]'))) == [] + + +def test_sequence_item_not_a_sequence(): + assert list(SequenceItem().match(yaml_compose('s: val'))) == [] + + +def test_mapping_key(): + m = MappingKey('s') + ret = [n.value for n in m.match(yaml_compose('s: val\nt: val2'))] + assert ret == ['s'] + + +def test_mapping_value(): + m = MappingValue('s') + ret = [n.value for n in m.match(yaml_compose('s: val\nt: val2'))] + assert ret == ['val'] + + +def test_sequence_item(): + ret = [n.value for n in SequenceItem().match(yaml_compose('[a, b, c]'))] + assert ret == ['a', 'b', 'c']