mirror of
https://github.com/pre-commit/pre-commit.git
synced 2026-02-17 00:04:42 +04:00
change migrate-config to use yaml parse tree instead
This commit is contained in:
parent
504149d2ca
commit
364e6d77f0
5 changed files with 194 additions and 10 deletions
|
|
@ -1,13 +1,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import functools
|
||||
import textwrap
|
||||
from typing import Callable
|
||||
|
||||
import cfgv
|
||||
import yaml
|
||||
from yaml.nodes import ScalarNode
|
||||
|
||||
from pre_commit.clientlib import InvalidConfigError
|
||||
from pre_commit.yaml import yaml_compose
|
||||
from pre_commit.yaml import yaml_load
|
||||
from pre_commit.yaml_rewrite import MappingKey
|
||||
from pre_commit.yaml_rewrite import MappingValue
|
||||
from pre_commit.yaml_rewrite import match
|
||||
from pre_commit.yaml_rewrite import SequenceItem
|
||||
|
||||
|
||||
def _is_header_line(line: str) -> bool:
|
||||
|
|
@ -38,16 +45,48 @@ def _migrate_map(contents: str) -> str:
|
|||
return contents
|
||||
|
||||
|
||||
def _migrate_sha_to_rev(contents: str) -> str:
|
||||
return re.sub(r'(\n\s+)sha:', r'\1rev:', contents)
|
||||
def _preserve_style(n: ScalarNode, *, s: str) -> str:
|
||||
return f'{n.style}{s}{n.style}'
|
||||
|
||||
|
||||
def _migrate_python_venv(contents: str) -> str:
|
||||
return re.sub(
|
||||
r'(\n\s+)language: python_venv\b',
|
||||
r'\1language: python',
|
||||
contents,
|
||||
def _migrate_composed(contents: str) -> str:
|
||||
tree = yaml_compose(contents)
|
||||
rewrites: list[tuple[ScalarNode, Callable[[ScalarNode], str]]] = []
|
||||
|
||||
# sha -> rev
|
||||
sha_to_rev_replace = functools.partial(_preserve_style, s='rev')
|
||||
sha_to_rev_matcher = (
|
||||
MappingValue('repos'),
|
||||
SequenceItem(),
|
||||
MappingKey('sha'),
|
||||
)
|
||||
for node in match(tree, sha_to_rev_matcher):
|
||||
rewrites.append((node, sha_to_rev_replace))
|
||||
|
||||
# python_venv -> python
|
||||
language_matcher = (
|
||||
MappingValue('repos'),
|
||||
SequenceItem(),
|
||||
MappingValue('hooks'),
|
||||
SequenceItem(),
|
||||
MappingValue('language'),
|
||||
)
|
||||
python_venv_replace = functools.partial(_preserve_style, s='python')
|
||||
for node in match(tree, language_matcher):
|
||||
if node.value == 'python_venv':
|
||||
rewrites.append((node, python_venv_replace))
|
||||
|
||||
rewrites.sort(reverse=True, key=lambda nf: nf[0].start_mark.index)
|
||||
|
||||
src_parts = []
|
||||
end: int | None = None
|
||||
for node, func in rewrites:
|
||||
src_parts.append(contents[node.end_mark.index:end])
|
||||
src_parts.append(func(node))
|
||||
end = node.start_mark.index
|
||||
src_parts.append(contents[:end])
|
||||
src_parts.reverse()
|
||||
return ''.join(src_parts)
|
||||
|
||||
|
||||
def migrate_config(config_file: str, quiet: bool = False) -> int:
|
||||
|
|
@ -62,8 +101,7 @@ def migrate_config(config_file: str, quiet: bool = False) -> int:
|
|||
raise cfgv.ValidationError(str(e))
|
||||
|
||||
contents = _migrate_map(contents)
|
||||
contents = _migrate_sha_to_rev(contents)
|
||||
contents = _migrate_python_venv(contents)
|
||||
contents = _migrate_composed(contents)
|
||||
|
||||
if contents != orig_contents:
|
||||
with open(config_file, 'w') as f:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any
|
|||
import yaml
|
||||
|
||||
Loader = getattr(yaml, 'CSafeLoader', yaml.SafeLoader)
|
||||
yaml_compose = functools.partial(yaml.compose, Loader=Loader)
|
||||
yaml_load = functools.partial(yaml.load, Loader=Loader)
|
||||
Dumper = getattr(yaml, 'CSafeDumper', yaml.SafeDumper)
|
||||
|
||||
|
|
|
|||
52
pre_commit/yaml_rewrite.py
Normal file
52
pre_commit/yaml_rewrite.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import NamedTuple
|
||||
from typing import Protocol
|
||||
|
||||
from yaml.nodes import MappingNode
|
||||
from yaml.nodes import Node
|
||||
from yaml.nodes import ScalarNode
|
||||
from yaml.nodes import SequenceNode
|
||||
|
||||
|
||||
class _Matcher(Protocol):
|
||||
def match(self, n: Node) -> Generator[Node]: ...
|
||||
|
||||
|
||||
class MappingKey(NamedTuple):
|
||||
k: str
|
||||
|
||||
def match(self, n: Node) -> Generator[Node]:
|
||||
if isinstance(n, MappingNode):
|
||||
for k, _ in n.value:
|
||||
if k.value == self.k:
|
||||
yield k
|
||||
|
||||
|
||||
class MappingValue(NamedTuple):
|
||||
k: str
|
||||
|
||||
def match(self, n: Node) -> Generator[Node]:
|
||||
if isinstance(n, MappingNode):
|
||||
for k, v in n.value:
|
||||
if k.value == self.k:
|
||||
yield v
|
||||
|
||||
|
||||
class SequenceItem(NamedTuple):
|
||||
def match(self, n: Node) -> Generator[Node]:
|
||||
if isinstance(n, SequenceNode):
|
||||
yield from n.value
|
||||
|
||||
|
||||
def _match(gen: Iterable[Node], m: _Matcher) -> Iterable[Node]:
|
||||
return (n for src in gen for n in m.match(src))
|
||||
|
||||
|
||||
def match(n: Node, matcher: tuple[_Matcher, ...]) -> Generator[ScalarNode]:
|
||||
gen: Iterable[Node] = (n,)
|
||||
for m in matcher:
|
||||
gen = _match(gen, m)
|
||||
return (n for n in gen if isinstance(n, ScalarNode))
|
||||
|
|
@ -134,6 +134,27 @@ def test_migrate_config_sha_to_rev(tmpdir):
|
|||
)
|
||||
|
||||
|
||||
def test_migrate_config_sha_to_rev_json(tmp_path):
|
||||
contents = """\
|
||||
{"repos": [{
|
||||
"repo": "https://github.com/pre-commit/pre-commit-hooks",
|
||||
"sha": "v1.2.0",
|
||||
"hooks": []
|
||||
}]}
|
||||
"""
|
||||
expected = """\
|
||||
{"repos": [{
|
||||
"repo": "https://github.com/pre-commit/pre-commit-hooks",
|
||||
"rev": "v1.2.0",
|
||||
"hooks": []
|
||||
}]}
|
||||
"""
|
||||
cfg = tmp_path.joinpath('cfg.yaml')
|
||||
cfg.write_text(contents)
|
||||
assert not migrate_config(str(cfg))
|
||||
assert cfg.read_text() == expected
|
||||
|
||||
|
||||
def test_migrate_config_language_python_venv(tmp_path):
|
||||
src = '''\
|
||||
repos:
|
||||
|
|
@ -167,6 +188,31 @@ repos:
|
|||
assert cfg.read_text() == expected
|
||||
|
||||
|
||||
def test_migrate_config_quoted_python_venv(tmp_path):
|
||||
src = '''\
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: example
|
||||
name: example
|
||||
entry: example
|
||||
language: "python_venv"
|
||||
'''
|
||||
expected = '''\
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: example
|
||||
name: example
|
||||
entry: example
|
||||
language: "python"
|
||||
'''
|
||||
cfg = tmp_path.joinpath('cfg.yaml')
|
||||
cfg.write_text(src)
|
||||
assert migrate_config(str(cfg)) == 0
|
||||
assert cfg.read_text() == expected
|
||||
|
||||
|
||||
def test_migrate_config_invalid_yaml(tmpdir):
|
||||
contents = '['
|
||||
cfg = tmpdir.join(C.CONFIG_FILE)
|
||||
|
|
|
|||
47
tests/yaml_rewrite_test.py
Normal file
47
tests/yaml_rewrite_test.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from pre_commit.yaml import yaml_compose
|
||||
from pre_commit.yaml_rewrite import MappingKey
|
||||
from pre_commit.yaml_rewrite import MappingValue
|
||||
from pre_commit.yaml_rewrite import match
|
||||
from pre_commit.yaml_rewrite import SequenceItem
|
||||
|
||||
|
||||
def test_match_produces_scalar_values_only():
|
||||
src = '''\
|
||||
- name: foo
|
||||
- name: [not, foo] # not a scalar: should be skipped!
|
||||
- name: bar
|
||||
'''
|
||||
matcher = (SequenceItem(), MappingValue('name'))
|
||||
ret = [n.value for n in match(yaml_compose(src), matcher)]
|
||||
assert ret == ['foo', 'bar']
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cls', (MappingKey, MappingValue))
|
||||
def test_mapping_not_a_map(cls):
|
||||
m = cls('s')
|
||||
assert list(m.match(yaml_compose('[foo]'))) == []
|
||||
|
||||
|
||||
def test_sequence_item_not_a_sequence():
|
||||
assert list(SequenceItem().match(yaml_compose('s: val'))) == []
|
||||
|
||||
|
||||
def test_mapping_key():
|
||||
m = MappingKey('s')
|
||||
ret = [n.value for n in m.match(yaml_compose('s: val\nt: val2'))]
|
||||
assert ret == ['s']
|
||||
|
||||
|
||||
def test_mapping_value():
|
||||
m = MappingValue('s')
|
||||
ret = [n.value for n in m.match(yaml_compose('s: val\nt: val2'))]
|
||||
assert ret == ['val']
|
||||
|
||||
|
||||
def test_sequence_item():
|
||||
ret = [n.value for n in SequenceItem().match(yaml_compose('[a, b, c]'))]
|
||||
assert ret == ['a', 'b', 'c']
|
||||
Loading…
Add table
Add a link
Reference in a new issue