change migrate-config to use yaml parse tree instead

This commit is contained in:
Anthony Sottile 2024-09-16 20:05:29 -04:00
parent 504149d2ca
commit 364e6d77f0
5 changed files with 194 additions and 10 deletions

View file

@ -1,13 +1,20 @@
from __future__ import annotations
import re
import functools
import textwrap
from typing import Callable
import cfgv
import yaml
from yaml.nodes import ScalarNode
from pre_commit.clientlib import InvalidConfigError
from pre_commit.yaml import yaml_compose
from pre_commit.yaml import yaml_load
from pre_commit.yaml_rewrite import MappingKey
from pre_commit.yaml_rewrite import MappingValue
from pre_commit.yaml_rewrite import match
from pre_commit.yaml_rewrite import SequenceItem
def _is_header_line(line: str) -> bool:
@ -38,16 +45,48 @@ def _migrate_map(contents: str) -> str:
return contents
def _migrate_sha_to_rev(contents: str) -> str:
return re.sub(r'(\n\s+)sha:', r'\1rev:', contents)
def _preserve_style(n: ScalarNode, *, s: str) -> str:
return f'{n.style}{s}{n.style}'
def _migrate_python_venv(contents: str) -> str:
return re.sub(
r'(\n\s+)language: python_venv\b',
r'\1language: python',
contents,
def _migrate_composed(contents: str) -> str:
tree = yaml_compose(contents)
rewrites: list[tuple[ScalarNode, Callable[[ScalarNode], str]]] = []
# sha -> rev
sha_to_rev_replace = functools.partial(_preserve_style, s='rev')
sha_to_rev_matcher = (
MappingValue('repos'),
SequenceItem(),
MappingKey('sha'),
)
for node in match(tree, sha_to_rev_matcher):
rewrites.append((node, sha_to_rev_replace))
# python_venv -> python
language_matcher = (
MappingValue('repos'),
SequenceItem(),
MappingValue('hooks'),
SequenceItem(),
MappingValue('language'),
)
python_venv_replace = functools.partial(_preserve_style, s='python')
for node in match(tree, language_matcher):
if node.value == 'python_venv':
rewrites.append((node, python_venv_replace))
rewrites.sort(reverse=True, key=lambda nf: nf[0].start_mark.index)
src_parts = []
end: int | None = None
for node, func in rewrites:
src_parts.append(contents[node.end_mark.index:end])
src_parts.append(func(node))
end = node.start_mark.index
src_parts.append(contents[:end])
src_parts.reverse()
return ''.join(src_parts)
def migrate_config(config_file: str, quiet: bool = False) -> int:
@ -62,8 +101,7 @@ def migrate_config(config_file: str, quiet: bool = False) -> int:
raise cfgv.ValidationError(str(e))
contents = _migrate_map(contents)
contents = _migrate_sha_to_rev(contents)
contents = _migrate_python_venv(contents)
contents = _migrate_composed(contents)
if contents != orig_contents:
with open(config_file, 'w') as f:

View file

@ -6,6 +6,7 @@ from typing import Any
import yaml
Loader = getattr(yaml, 'CSafeLoader', yaml.SafeLoader)
yaml_compose = functools.partial(yaml.compose, Loader=Loader)
yaml_load = functools.partial(yaml.load, Loader=Loader)
Dumper = getattr(yaml, 'CSafeDumper', yaml.SafeDumper)

View file

@ -0,0 +1,52 @@
from __future__ import annotations
from collections.abc import Generator
from collections.abc import Iterable
from typing import NamedTuple
from typing import Protocol
from yaml.nodes import MappingNode
from yaml.nodes import Node
from yaml.nodes import ScalarNode
from yaml.nodes import SequenceNode
class _Matcher(Protocol):
def match(self, n: Node) -> Generator[Node]: ...
class MappingKey(NamedTuple):
k: str
def match(self, n: Node) -> Generator[Node]:
if isinstance(n, MappingNode):
for k, _ in n.value:
if k.value == self.k:
yield k
class MappingValue(NamedTuple):
k: str
def match(self, n: Node) -> Generator[Node]:
if isinstance(n, MappingNode):
for k, v in n.value:
if k.value == self.k:
yield v
class SequenceItem(NamedTuple):
def match(self, n: Node) -> Generator[Node]:
if isinstance(n, SequenceNode):
yield from n.value
def _match(gen: Iterable[Node], m: _Matcher) -> Iterable[Node]:
return (n for src in gen for n in m.match(src))
def match(n: Node, matcher: tuple[_Matcher, ...]) -> Generator[ScalarNode]:
gen: Iterable[Node] = (n,)
for m in matcher:
gen = _match(gen, m)
return (n for n in gen if isinstance(n, ScalarNode))

View file

@ -134,6 +134,27 @@ def test_migrate_config_sha_to_rev(tmpdir):
)
def test_migrate_config_sha_to_rev_json(tmp_path):
contents = """\
{"repos": [{
"repo": "https://github.com/pre-commit/pre-commit-hooks",
"sha": "v1.2.0",
"hooks": []
}]}
"""
expected = """\
{"repos": [{
"repo": "https://github.com/pre-commit/pre-commit-hooks",
"rev": "v1.2.0",
"hooks": []
}]}
"""
cfg = tmp_path.joinpath('cfg.yaml')
cfg.write_text(contents)
assert not migrate_config(str(cfg))
assert cfg.read_text() == expected
def test_migrate_config_language_python_venv(tmp_path):
src = '''\
repos:
@ -167,6 +188,31 @@ repos:
assert cfg.read_text() == expected
def test_migrate_config_quoted_python_venv(tmp_path):
src = '''\
repos:
- repo: local
hooks:
- id: example
name: example
entry: example
language: "python_venv"
'''
expected = '''\
repos:
- repo: local
hooks:
- id: example
name: example
entry: example
language: "python"
'''
cfg = tmp_path.joinpath('cfg.yaml')
cfg.write_text(src)
assert migrate_config(str(cfg)) == 0
assert cfg.read_text() == expected
def test_migrate_config_invalid_yaml(tmpdir):
contents = '['
cfg = tmpdir.join(C.CONFIG_FILE)

View file

@ -0,0 +1,47 @@
from __future__ import annotations
import pytest
from pre_commit.yaml import yaml_compose
from pre_commit.yaml_rewrite import MappingKey
from pre_commit.yaml_rewrite import MappingValue
from pre_commit.yaml_rewrite import match
from pre_commit.yaml_rewrite import SequenceItem
def test_match_produces_scalar_values_only():
src = '''\
- name: foo
- name: [not, foo] # not a scalar: should be skipped!
- name: bar
'''
matcher = (SequenceItem(), MappingValue('name'))
ret = [n.value for n in match(yaml_compose(src), matcher)]
assert ret == ['foo', 'bar']
@pytest.mark.parametrize('cls', (MappingKey, MappingValue))
def test_mapping_not_a_map(cls):
m = cls('s')
assert list(m.match(yaml_compose('[foo]'))) == []
def test_sequence_item_not_a_sequence():
assert list(SequenceItem().match(yaml_compose('s: val'))) == []
def test_mapping_key():
m = MappingKey('s')
ret = [n.value for n in m.match(yaml_compose('s: val\nt: val2'))]
assert ret == ['s']
def test_mapping_value():
m = MappingValue('s')
ret = [n.value for n in m.match(yaml_compose('s: val\nt: val2'))]
assert ret == ['val']
def test_sequence_item():
ret = [n.value for n in SequenceItem().match(yaml_compose('[a, b, c]'))]
assert ret == ['a', 'b', 'c']