mirror of
https://github.com/pre-commit/pre-commit.git
synced 2026-02-17 00:04:42 +04:00
change migrate-config to use yaml parse tree instead
This commit is contained in:
parent
504149d2ca
commit
364e6d77f0
5 changed files with 194 additions and 10 deletions
|
|
@ -1,13 +1,20 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import functools
|
||||
import textwrap
|
||||
from typing import Callable
|
||||
|
||||
import cfgv
|
||||
import yaml
|
||||
from yaml.nodes import ScalarNode
|
||||
|
||||
from pre_commit.clientlib import InvalidConfigError
|
||||
from pre_commit.yaml import yaml_compose
|
||||
from pre_commit.yaml import yaml_load
|
||||
from pre_commit.yaml_rewrite import MappingKey
|
||||
from pre_commit.yaml_rewrite import MappingValue
|
||||
from pre_commit.yaml_rewrite import match
|
||||
from pre_commit.yaml_rewrite import SequenceItem
|
||||
|
||||
|
||||
def _is_header_line(line: str) -> bool:
|
||||
|
|
@ -38,16 +45,48 @@ def _migrate_map(contents: str) -> str:
|
|||
return contents
|
||||
|
||||
|
||||
def _migrate_sha_to_rev(contents: str) -> str:
|
||||
return re.sub(r'(\n\s+)sha:', r'\1rev:', contents)
|
||||
def _preserve_style(n: ScalarNode, *, s: str) -> str:
|
||||
return f'{n.style}{s}{n.style}'
|
||||
|
||||
|
||||
def _migrate_python_venv(contents: str) -> str:
|
||||
return re.sub(
|
||||
r'(\n\s+)language: python_venv\b',
|
||||
r'\1language: python',
|
||||
contents,
|
||||
def _migrate_composed(contents: str) -> str:
|
||||
tree = yaml_compose(contents)
|
||||
rewrites: list[tuple[ScalarNode, Callable[[ScalarNode], str]]] = []
|
||||
|
||||
# sha -> rev
|
||||
sha_to_rev_replace = functools.partial(_preserve_style, s='rev')
|
||||
sha_to_rev_matcher = (
|
||||
MappingValue('repos'),
|
||||
SequenceItem(),
|
||||
MappingKey('sha'),
|
||||
)
|
||||
for node in match(tree, sha_to_rev_matcher):
|
||||
rewrites.append((node, sha_to_rev_replace))
|
||||
|
||||
# python_venv -> python
|
||||
language_matcher = (
|
||||
MappingValue('repos'),
|
||||
SequenceItem(),
|
||||
MappingValue('hooks'),
|
||||
SequenceItem(),
|
||||
MappingValue('language'),
|
||||
)
|
||||
python_venv_replace = functools.partial(_preserve_style, s='python')
|
||||
for node in match(tree, language_matcher):
|
||||
if node.value == 'python_venv':
|
||||
rewrites.append((node, python_venv_replace))
|
||||
|
||||
rewrites.sort(reverse=True, key=lambda nf: nf[0].start_mark.index)
|
||||
|
||||
src_parts = []
|
||||
end: int | None = None
|
||||
for node, func in rewrites:
|
||||
src_parts.append(contents[node.end_mark.index:end])
|
||||
src_parts.append(func(node))
|
||||
end = node.start_mark.index
|
||||
src_parts.append(contents[:end])
|
||||
src_parts.reverse()
|
||||
return ''.join(src_parts)
|
||||
|
||||
|
||||
def migrate_config(config_file: str, quiet: bool = False) -> int:
|
||||
|
|
@ -62,8 +101,7 @@ def migrate_config(config_file: str, quiet: bool = False) -> int:
|
|||
raise cfgv.ValidationError(str(e))
|
||||
|
||||
contents = _migrate_map(contents)
|
||||
contents = _migrate_sha_to_rev(contents)
|
||||
contents = _migrate_python_venv(contents)
|
||||
contents = _migrate_composed(contents)
|
||||
|
||||
if contents != orig_contents:
|
||||
with open(config_file, 'w') as f:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any
|
|||
import yaml
|
||||
|
||||
Loader = getattr(yaml, 'CSafeLoader', yaml.SafeLoader)
|
||||
yaml_compose = functools.partial(yaml.compose, Loader=Loader)
|
||||
yaml_load = functools.partial(yaml.load, Loader=Loader)
|
||||
Dumper = getattr(yaml, 'CSafeDumper', yaml.SafeDumper)
|
||||
|
||||
|
|
|
|||
52
pre_commit/yaml_rewrite.py
Normal file
52
pre_commit/yaml_rewrite.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import NamedTuple
|
||||
from typing import Protocol
|
||||
|
||||
from yaml.nodes import MappingNode
|
||||
from yaml.nodes import Node
|
||||
from yaml.nodes import ScalarNode
|
||||
from yaml.nodes import SequenceNode
|
||||
|
||||
|
||||
class _Matcher(Protocol):
|
||||
def match(self, n: Node) -> Generator[Node]: ...
|
||||
|
||||
|
||||
class MappingKey(NamedTuple):
|
||||
k: str
|
||||
|
||||
def match(self, n: Node) -> Generator[Node]:
|
||||
if isinstance(n, MappingNode):
|
||||
for k, _ in n.value:
|
||||
if k.value == self.k:
|
||||
yield k
|
||||
|
||||
|
||||
class MappingValue(NamedTuple):
|
||||
k: str
|
||||
|
||||
def match(self, n: Node) -> Generator[Node]:
|
||||
if isinstance(n, MappingNode):
|
||||
for k, v in n.value:
|
||||
if k.value == self.k:
|
||||
yield v
|
||||
|
||||
|
||||
class SequenceItem(NamedTuple):
|
||||
def match(self, n: Node) -> Generator[Node]:
|
||||
if isinstance(n, SequenceNode):
|
||||
yield from n.value
|
||||
|
||||
|
||||
def _match(gen: Iterable[Node], m: _Matcher) -> Iterable[Node]:
|
||||
return (n for src in gen for n in m.match(src))
|
||||
|
||||
|
||||
def match(n: Node, matcher: tuple[_Matcher, ...]) -> Generator[ScalarNode]:
|
||||
gen: Iterable[Node] = (n,)
|
||||
for m in matcher:
|
||||
gen = _match(gen, m)
|
||||
return (n for n in gen if isinstance(n, ScalarNode))
|
||||
Loading…
Add table
Add a link
Reference in a new issue