Add types to pre-commit

This commit is contained in:
Anthony Sottile 2020-01-10 23:32:28 -08:00
parent fa536a8693
commit 327ed924a3
62 changed files with 911 additions and 411 deletions

View file

@ -1,8 +1,12 @@
import collections
import os.path
import re
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from aspy.yaml import ordered_dump
from aspy.yaml import ordered_load
@ -16,20 +20,23 @@ from pre_commit.clientlib import load_manifest
from pre_commit.clientlib import LOCAL
from pre_commit.clientlib import META
from pre_commit.commands.migrate_config import migrate_config
from pre_commit.store import Store
from pre_commit.util import CalledProcessError
from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b
from pre_commit.util import tmpdir
class RevInfo(collections.namedtuple('RevInfo', ('repo', 'rev', 'frozen'))):
__slots__ = ()
class RevInfo(NamedTuple):
repo: str
rev: str
frozen: Optional[str]
@classmethod
def from_config(cls, config):
def from_config(cls, config: Dict[str, Any]) -> 'RevInfo':
return cls(config['repo'], config['rev'], None)
def update(self, tags_only, freeze):
def update(self, tags_only: bool, freeze: bool) -> 'RevInfo':
if tags_only:
tag_cmd = ('git', 'describe', 'FETCH_HEAD', '--tags', '--abbrev=0')
else:
@ -57,7 +64,11 @@ class RepositoryCannotBeUpdatedError(RuntimeError):
pass
def _check_hooks_still_exist_at_rev(repo_config, info, store):
def _check_hooks_still_exist_at_rev(
repo_config: Dict[str, Any],
info: RevInfo,
store: Store,
) -> None:
try:
path = store.clone(repo_config['repo'], info.rev)
manifest = load_manifest(os.path.join(path, C.MANIFEST_FILE))
@ -78,7 +89,11 @@ REV_LINE_RE = re.compile(r'^(\s+)rev:(\s*)([^\s#]+)(.*)(\r?\n)$', re.DOTALL)
REV_LINE_FMT = '{}rev:{}{}{}{}'
def _original_lines(path, rev_infos, retry=False):
def _original_lines(
path: str,
rev_infos: List[Optional[RevInfo]],
retry: bool = False,
) -> Tuple[List[str], List[int]]:
"""detect `rev:` lines or reformat the file"""
with open(path) as f:
original = f.read()
@ -95,7 +110,7 @@ def _original_lines(path, rev_infos, retry=False):
return _original_lines(path, rev_infos, retry=True)
def _write_new_config(path, rev_infos):
def _write_new_config(path: str, rev_infos: List[Optional[RevInfo]]) -> None:
lines, idxs = _original_lines(path, rev_infos)
for idx, rev_info in zip(idxs, rev_infos):
@ -119,7 +134,13 @@ def _write_new_config(path, rev_infos):
f.write(''.join(lines))
def autoupdate(config_file, store, tags_only, freeze, repos=()):
def autoupdate(
config_file: str,
store: Store,
tags_only: bool,
freeze: bool,
repos: Sequence[str] = (),
) -> int:
"""Auto-update the pre-commit config to the latest versions of repos."""
migrate_config(config_file, quiet=True)
retv = 0

View file

@ -1,10 +1,11 @@
import os.path
from pre_commit import output
from pre_commit.store import Store
from pre_commit.util import rmtree
def clean(store):
def clean(store: Store) -> int:
legacy_path = os.path.expanduser('~/.pre-commit')
for directory in (store.directory, legacy_path):
if os.path.exists(directory):

View file

@ -1,4 +1,8 @@
import os.path
from typing import Any
from typing import Dict
from typing import Set
from typing import Tuple
import pre_commit.constants as C
from pre_commit import output
@ -8,9 +12,15 @@ from pre_commit.clientlib import load_config
from pre_commit.clientlib import load_manifest
from pre_commit.clientlib import LOCAL
from pre_commit.clientlib import META
from pre_commit.store import Store
def _mark_used_repos(store, all_repos, unused_repos, repo):
def _mark_used_repos(
store: Store,
all_repos: Dict[Tuple[str, str], str],
unused_repos: Set[Tuple[str, str]],
repo: Dict[str, Any],
) -> None:
if repo['repo'] == META:
return
elif repo['repo'] == LOCAL:
@ -47,7 +57,7 @@ def _mark_used_repos(store, all_repos, unused_repos, repo):
))
def _gc_repos(store):
def _gc_repos(store: Store) -> int:
configs = store.select_all_configs()
repos = store.select_all_repos()
@ -73,7 +83,7 @@ def _gc_repos(store):
return len(unused_repos)
def gc(store):
def gc(store: Store) -> int:
with store.exclusive_lock():
repos_removed = _gc_repos(store)
output.write_line(f'{repos_removed} repo(s) removed.')

View file

@ -1,14 +1,21 @@
import logging
import os.path
from typing import Sequence
from pre_commit.commands.install_uninstall import install
from pre_commit.store import Store
from pre_commit.util import CalledProcessError
from pre_commit.util import cmd_output
logger = logging.getLogger('pre_commit')
def init_templatedir(config_file, store, directory, hook_types):
def init_templatedir(
config_file: str,
store: Store,
directory: str,
hook_types: Sequence[str],
) -> int:
install(
config_file, store, hook_types=hook_types,
overwrite=True, skip_on_missing_config=True, git_dir=directory,
@ -25,3 +32,4 @@ def init_templatedir(config_file, store, directory, hook_types):
logger.warning(
f'maybe `git config --global init.templateDir {dest}`?',
)
return 0

View file

@ -3,12 +3,16 @@ import logging
import os.path
import shutil
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from pre_commit import git
from pre_commit import output
from pre_commit.clientlib import load_config
from pre_commit.repository import all_hooks
from pre_commit.repository import install_hook_envs
from pre_commit.store import Store
from pre_commit.util import make_executable
from pre_commit.util import mkdirp
from pre_commit.util import resource_text
@ -29,13 +33,16 @@ TEMPLATE_START = '# start templated\n'
TEMPLATE_END = '# end templated\n'
def _hook_paths(hook_type, git_dir=None):
def _hook_paths(
hook_type: str,
git_dir: Optional[str] = None,
) -> Tuple[str, str]:
git_dir = git_dir if git_dir is not None else git.get_git_dir()
pth = os.path.join(git_dir, 'hooks', hook_type)
return pth, f'{pth}.legacy'
def is_our_script(filename):
def is_our_script(filename: str) -> bool:
if not os.path.exists(filename): # pragma: windows no cover (symlink)
return False
with open(filename) as f:
@ -43,7 +50,7 @@ def is_our_script(filename):
return any(h in contents for h in (CURRENT_HASH,) + PRIOR_HASHES)
def shebang():
def shebang() -> str:
if sys.platform == 'win32':
py = 'python'
else:
@ -63,9 +70,12 @@ def shebang():
def _install_hook_script(
config_file, hook_type,
overwrite=False, skip_on_missing_config=False, git_dir=None,
):
config_file: str,
hook_type: str,
overwrite: bool = False,
skip_on_missing_config: bool = False,
git_dir: Optional[str] = None,
) -> None:
hook_path, legacy_path = _hook_paths(hook_type, git_dir=git_dir)
mkdirp(os.path.dirname(hook_path))
@ -108,10 +118,14 @@ def _install_hook_script(
def install(
config_file, store, hook_types,
overwrite=False, hooks=False,
skip_on_missing_config=False, git_dir=None,
):
config_file: str,
store: Store,
hook_types: Sequence[str],
overwrite: bool = False,
hooks: bool = False,
skip_on_missing_config: bool = False,
git_dir: Optional[str] = None,
) -> int:
if git.has_core_hookpaths_set():
logger.error(
'Cowardly refusing to install hooks with `core.hooksPath` set.\n'
@ -133,11 +147,12 @@ def install(
return 0
def install_hooks(config_file, store):
def install_hooks(config_file: str, store: Store) -> int:
install_hook_envs(all_hooks(load_config(config_file), store), store)
return 0
def _uninstall_hook_script(hook_type): # type: (str) -> None
def _uninstall_hook_script(hook_type: str) -> None:
hook_path, legacy_path = _hook_paths(hook_type)
# If our file doesn't exist or it isn't ours, gtfo.
@ -152,7 +167,7 @@ def _uninstall_hook_script(hook_type): # type: (str) -> None
output.write_line(f'Restored previous hooks to {hook_path}')
def uninstall(hook_types):
def uninstall(hook_types: Sequence[str]) -> int:
for hook_type in hook_types:
_uninstall_hook_script(hook_type)
return 0

View file

@ -4,16 +4,16 @@ import yaml
from aspy.yaml import ordered_load
def _indent(s):
def _indent(s: str) -> str:
lines = s.splitlines(True)
return ''.join(' ' * 4 + line if line.strip() else line for line in lines)
def _is_header_line(line):
return (line.startswith(('#', '---')) or not line.strip())
def _is_header_line(line: str) -> bool:
return line.startswith(('#', '---')) or not line.strip()
def _migrate_map(contents):
def _migrate_map(contents: str) -> str:
# Find the first non-header line
lines = contents.splitlines(True)
i = 0
@ -37,12 +37,12 @@ def _migrate_map(contents):
return contents
def _migrate_sha_to_rev(contents):
def _migrate_sha_to_rev(contents: str) -> str:
reg = re.compile(r'(\n\s+)sha:')
return reg.sub(r'\1rev:', contents)
def migrate_config(config_file, quiet=False):
def migrate_config(config_file: str, quiet: bool = False) -> int:
with open(config_file) as f:
orig_contents = contents = f.read()
@ -56,3 +56,4 @@ def migrate_config(config_file, quiet=False):
print('Configuration has been migrated.')
elif not quiet:
print('Configuration is already migrated.')
return 0

View file

@ -1,8 +1,17 @@
import argparse
import functools
import logging
import os
import re
import subprocess
import time
from typing import Any
from typing import Collection
from typing import Dict
from typing import List
from typing import Sequence
from typing import Set
from typing import Tuple
from identify.identify import tags_from_path
@ -12,16 +21,23 @@ from pre_commit import output
from pre_commit.clientlib import load_config
from pre_commit.output import get_hook_message
from pre_commit.repository import all_hooks
from pre_commit.repository import Hook
from pre_commit.repository import install_hook_envs
from pre_commit.staged_files_only import staged_files_only
from pre_commit.store import Store
from pre_commit.util import cmd_output_b
from pre_commit.util import EnvironT
from pre_commit.util import noop_context
logger = logging.getLogger('pre_commit')
def filter_by_include_exclude(names, include, exclude):
def filter_by_include_exclude(
names: Collection[str],
include: str,
exclude: str,
) -> List[str]:
include_re, exclude_re = re.compile(include), re.compile(exclude)
return [
filename for filename in names
@ -31,24 +47,25 @@ def filter_by_include_exclude(names, include, exclude):
class Classifier:
def __init__(self, filenames):
def __init__(self, filenames: Sequence[str]) -> None:
# on windows we normalize all filenames to use forward slashes
# this makes it easier to filter using the `files:` regex
# this also makes improperly quoted shell-based hooks work better
# see #1173
if os.altsep == '/' and os.sep == '\\':
filenames = (f.replace(os.sep, os.altsep) for f in filenames)
filenames = [f.replace(os.sep, os.altsep) for f in filenames]
self.filenames = [f for f in filenames if os.path.lexists(f)]
self._types_cache = {}
def _types_for_file(self, filename):
try:
return self._types_cache[filename]
except KeyError:
ret = self._types_cache[filename] = tags_from_path(filename)
return ret
@functools.lru_cache(maxsize=None)
def _types_for_file(self, filename: str) -> Set[str]:
return tags_from_path(filename)
def by_types(self, names, types, exclude_types):
def by_types(
self,
names: Sequence[str],
types: Collection[str],
exclude_types: Collection[str],
) -> List[str]:
types, exclude_types = frozenset(types), frozenset(exclude_types)
ret = []
for filename in names:
@ -57,14 +74,14 @@ class Classifier:
ret.append(filename)
return ret
def filenames_for_hook(self, hook):
def filenames_for_hook(self, hook: Hook) -> Tuple[str, ...]:
names = self.filenames
names = filter_by_include_exclude(names, hook.files, hook.exclude)
names = self.by_types(names, hook.types, hook.exclude_types)
return names
return tuple(names)
def _get_skips(environ):
def _get_skips(environ: EnvironT) -> Set[str]:
skips = environ.get('SKIP', '')
return {skip.strip() for skip in skips.split(',') if skip.strip()}
@ -73,11 +90,18 @@ SKIPPED = 'Skipped'
NO_FILES = '(no files to check)'
def _subtle_line(s, use_color):
def _subtle_line(s: str, use_color: bool) -> None:
output.write_line(color.format_color(s, color.SUBTLE, use_color))
def _run_single_hook(classifier, hook, skips, cols, verbose, use_color):
def _run_single_hook(
classifier: Classifier,
hook: Hook,
skips: Set[str],
cols: int,
verbose: bool,
use_color: bool,
) -> bool:
filenames = classifier.filenames_for_hook(hook)
if hook.id in skips or hook.alias in skips:
@ -115,7 +139,8 @@ def _run_single_hook(classifier, hook, skips, cols, verbose, use_color):
diff_cmd = ('git', 'diff', '--no-ext-diff')
diff_before = cmd_output_b(*diff_cmd, retcode=None)
filenames = tuple(filenames) if hook.pass_filenames else ()
if not hook.pass_filenames:
filenames = ()
time_before = time.time()
retcode, out = hook.run(filenames, use_color)
duration = round(time.time() - time_before, 2) or 0
@ -154,7 +179,7 @@ def _run_single_hook(classifier, hook, skips, cols, verbose, use_color):
return files_modified or bool(retcode)
def _compute_cols(hooks):
def _compute_cols(hooks: Sequence[Hook]) -> int:
"""Compute the number of columns to display hook messages. The widest
that will be displayed is in the no files skipped case:
@ -169,7 +194,7 @@ def _compute_cols(hooks):
return max(cols, 80)
def _all_filenames(args):
def _all_filenames(args: argparse.Namespace) -> Collection[str]:
if args.origin and args.source:
return git.get_changed_files(args.origin, args.source)
elif args.hook_stage in {'prepare-commit-msg', 'commit-msg'}:
@ -184,7 +209,12 @@ def _all_filenames(args):
return git.get_staged_files()
def _run_hooks(config, hooks, args, environ):
def _run_hooks(
config: Dict[str, Any],
hooks: Sequence[Hook],
args: argparse.Namespace,
environ: EnvironT,
) -> int:
"""Actually run the hooks."""
skips = _get_skips(environ)
cols = _compute_cols(hooks)
@ -221,12 +251,12 @@ def _run_hooks(config, hooks, args, environ):
return retval
def _has_unmerged_paths():
def _has_unmerged_paths() -> bool:
_, stdout, _ = cmd_output_b('git', 'ls-files', '--unmerged')
return bool(stdout.strip())
def _has_unstaged_config(config_file):
def _has_unstaged_config(config_file: str) -> bool:
retcode, _, _ = cmd_output_b(
'git', 'diff', '--no-ext-diff', '--exit-code', config_file,
retcode=None,
@ -235,7 +265,12 @@ def _has_unstaged_config(config_file):
return retcode == 1
def run(config_file, store, args, environ=os.environ):
def run(
config_file: str,
store: Store,
args: argparse.Namespace,
environ: EnvironT = os.environ,
) -> int:
no_stash = args.all_files or bool(args.files)
# Check if we have unresolved merge conflict files and fail fast.

View file

@ -16,6 +16,6 @@ repos:
'''
def sample_config():
def sample_config() -> int:
print(SAMPLE_CONFIG, end='')
return 0

View file

@ -1,6 +1,8 @@
import argparse
import collections
import logging
import os.path
from typing import Tuple
from aspy.yaml import ordered_dump
@ -17,7 +19,7 @@ from pre_commit.xargs import xargs
logger = logging.getLogger(__name__)
def _repo_ref(tmpdir, repo, ref):
def _repo_ref(tmpdir: str, repo: str, ref: str) -> Tuple[str, str]:
# if `ref` is explicitly passed, use it
if ref:
return repo, ref
@ -47,7 +49,7 @@ def _repo_ref(tmpdir, repo, ref):
return repo, ref
def try_repo(args):
def try_repo(args: argparse.Namespace) -> int:
with tmpdir() as tempdir:
repo, ref = _repo_ref(tempdir, args.repo, args.ref)