feature: add monorepo recursive configuration

Finds all ".pre-commit-config.yaml" files recursively, depending on the
staged files.

Only tested on Linux.
This commit is contained in:
zp4rand0miz31 2023-04-01 11:43:34 +02:00
parent bb49560dc9
commit d8295a246b

View file

@ -9,7 +9,7 @@ import re
import subprocess import subprocess
import time import time
import unicodedata import unicodedata
from typing import Any from typing import Any, List
from typing import Collection from typing import Collection
from typing import MutableMapping from typing import MutableMapping
from typing import Sequence from typing import Sequence
@ -317,7 +317,6 @@ def _run_hooks(
'git', '--no-pager', 'diff', '--no-ext-diff', 'git', '--no-pager', 'diff', '--no-ext-diff',
f'--color={git_color_opt}', f'--color={git_color_opt}',
)) ))
return retval return retval
@ -334,6 +333,36 @@ def _has_unstaged_config(config_file: str) -> bool:
return retcode == 1 return retcode == 1
def _dive_into_file_hierarchy(dir_path: str):
""" Iterator for crawling recursively a path.
:note!: Not tested on windows (hardcoded '/' for path separator).
"""
accumulated = "" # start from empty string
_split = dir_path.split('/')
while len(_split) > 0:
accumulated = os.path.join(accumulated, _split.pop(0))
yield accumulated
def _find_all_config_files(modified_files: List[str], config_file_name: str) -> List[str]:
""" Finds all the config files relative to modified files.
Every modified file can have a :config_file_name: in its parent directory. If found, get it.
"""
ret_set = set()
for file_name in modified_files:
directory_base = os.path.dirname(file_name)
for subdir_seq_item in _dive_into_file_hierarchy(directory_base):
candidate_fileconfig = os.path.join(subdir_seq_item, config_file_name)
if os.path.exists(candidate_fileconfig):
# logging.debug(f"Will use config file: {candidate_fileconfig}")
ret_set.add(candidate_fileconfig)
# do not forget the root dir !
if os.path.exists(config_file_name):
# logging.debug(f"Will use config file: {config_file_name}")
ret_set.add(config_file_name)
ret = sorted(list(ret_set))
return ret
def run( def run(
config_file: str, config_file: str,
store: Store, store: Store,
@ -341,7 +370,6 @@ def run(
environ: MutableMapping[str, str] = os.environ, environ: MutableMapping[str, str] = os.environ,
) -> int: ) -> int:
stash = not args.all_files and not args.files stash = not args.all_files and not args.files
# Check if we have unresolved merge conflict files and fail fast. # Check if we have unresolved merge conflict files and fail fast.
if stash and _has_unmerged_paths(): if stash and _has_unmerged_paths():
logger.error('Unmerged files. Resolve before committing.') logger.error('Unmerged files. Resolve before committing.')
@ -419,29 +447,36 @@ def run(
if stash: if stash:
exit_stack.enter_context(staged_files_only(store.directory)) exit_stack.enter_context(staged_files_only(store.directory))
config = load_config(config_file) ret_int = 0
hooks = [ # find all applicable config files
hook all_config_files = _find_all_config_files(_all_filenames(args), config_file)
for hook in all_hooks(config, store) for _config_file in all_config_files:
if not args.hook or hook.id == args.hook or hook.alias == args.hook config = load_config(_config_file)
if args.hook_stage in hook.stages if len(all_config_files) > 0 :
] print(f"Hooks from {_config_file}:")
hooks = [
hook
for hook in all_hooks(config, store)
if not args.hook or hook.id == args.hook or hook.alias == args.hook
if args.hook_stage in hook.stages
]
if args.hook and not hooks: if args.hook and not hooks:
output.write_line( output.write_line(
f'No hook with id `{args.hook}` in stage `{args.hook_stage}`', f'No hook with id `{args.hook}` in stage `{args.hook_stage}` of file {_config_file}',
) )
return 1 return 1
skips = _get_skips(environ) skips = _get_skips(environ)
to_install = [ to_install = [
hook hook
for hook in hooks for hook in hooks
if hook.id not in skips and hook.alias not in skips if hook.id not in skips and hook.alias not in skips
] ]
install_hook_envs(to_install, store) install_hook_envs(to_install, store)
return _run_hooks(config, hooks, skips, args) ret_int += _run_hooks(config, hooks, skips, args) # accumulate errors
return ret_int
# https://github.com/python/mypy/issues/7726 # https://github.com/python/mypy/issues/7726
raise AssertionError('unreachable') raise AssertionError('unreachable')