Add types to pre-commit

This commit is contained in:
Anthony Sottile 2020-01-10 23:32:28 -08:00
parent fa536a8693
commit 327ed924a3
62 changed files with 911 additions and 411 deletions

View file

@ -25,6 +25,10 @@ exclude_lines =
^\s*return NotImplemented\b
^\s*raise$
# Ignore typing-related things
^if (False|TYPE_CHECKING):
: \.\.\.$
# Don't complain if non-runnable code isn't run:
^if __name__ == ['"]__main__['"]:$

View file

@ -3,6 +3,10 @@ import functools
import logging
import pipes
import sys
from typing import Any
from typing import Dict
from typing import Optional
from typing import Sequence
import cfgv
from aspy.yaml import ordered_load
@ -18,7 +22,7 @@ logger = logging.getLogger('pre_commit')
check_string_regex = cfgv.check_and(cfgv.check_string, cfgv.check_regex)
def check_type_tag(tag):
def check_type_tag(tag: str) -> None:
if tag not in ALL_TAGS:
raise cfgv.ValidationError(
'Type tag {!r} is not recognized. '
@ -26,7 +30,7 @@ def check_type_tag(tag):
)
def check_min_version(version):
def check_min_version(version: str) -> None:
if parse_version(version) > parse_version(C.VERSION):
raise cfgv.ValidationError(
'pre-commit version {} is required but version {} is installed. '
@ -36,7 +40,7 @@ def check_min_version(version):
)
def _make_argparser(filenames_help):
def _make_argparser(filenames_help: str) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help=filenames_help)
parser.add_argument('-V', '--version', action='version', version=C.VERSION)
@ -86,7 +90,7 @@ load_manifest = functools.partial(
)
def validate_manifest_main(argv=None):
def validate_manifest_main(argv: Optional[Sequence[str]] = None) -> int:
parser = _make_argparser('Manifest filenames.')
args = parser.parse_args(argv)
ret = 0
@ -107,7 +111,7 @@ class MigrateShaToRev:
key = 'rev'
@staticmethod
def _cond(key):
def _cond(key: str) -> cfgv.Conditional:
return cfgv.Conditional(
key, cfgv.check_string,
condition_key='repo',
@ -115,7 +119,7 @@ class MigrateShaToRev:
ensure_absent=True,
)
def check(self, dct):
def check(self, dct: Dict[str, Any]) -> None:
if dct.get('repo') in {LOCAL, META}:
self._cond('rev').check(dct)
self._cond('sha').check(dct)
@ -126,14 +130,14 @@ class MigrateShaToRev:
else:
self._cond('rev').check(dct)
def apply_default(self, dct):
def apply_default(self, dct: Dict[str, Any]) -> None:
if 'sha' in dct:
dct['rev'] = dct.pop('sha')
remove_default = cfgv.Required.remove_default
def _entry(modname):
def _entry(modname: str) -> str:
"""the hook `entry` is passed through `shlex.split()` by the command
runner, so to prevent issues with spaces and backslashes (on Windows)
it must be quoted here.
@ -143,13 +147,21 @@ def _entry(modname):
)
def warn_unknown_keys_root(extra, orig_keys, dct):
def warn_unknown_keys_root(
extra: Sequence[str],
orig_keys: Sequence[str],
dct: Dict[str, str],
) -> None:
logger.warning(
'Unexpected key(s) present at root: {}'.format(', '.join(extra)),
)
def warn_unknown_keys_repo(extra, orig_keys, dct):
def warn_unknown_keys_repo(
extra: Sequence[str],
orig_keys: Sequence[str],
dct: Dict[str, str],
) -> None:
logger.warning(
'Unexpected key(s) present on {}: {}'.format(
dct['repo'], ', '.join(extra),
@ -281,7 +293,7 @@ class InvalidConfigError(FatalError):
pass
def ordered_load_normalize_legacy_config(contents):
def ordered_load_normalize_legacy_config(contents: str) -> Dict[str, Any]:
data = ordered_load(contents)
if isinstance(data, list):
# TODO: Once happy, issue a deprecation warning and instructions
@ -298,7 +310,7 @@ load_config = functools.partial(
)
def validate_config_main(argv=None):
def validate_config_main(argv: Optional[Sequence[str]] = None) -> int:
parser = _make_argparser('Config filenames.')
args = parser.parse_args(argv)
ret = 0

View file

@ -21,7 +21,7 @@ class InvalidColorSetting(ValueError):
pass
def format_color(text, color, use_color_setting):
def format_color(text: str, color: str, use_color_setting: bool) -> str:
"""Format text with color.
Args:
@ -38,7 +38,7 @@ def format_color(text, color, use_color_setting):
COLOR_CHOICES = ('auto', 'always', 'never')
def use_color(setting):
def use_color(setting: str) -> bool:
"""Choose whether to use color based on the command argument.
Args:

View file

@ -1,8 +1,12 @@
import collections
import os.path
import re
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from aspy.yaml import ordered_dump
from aspy.yaml import ordered_load
@ -16,20 +20,23 @@ from pre_commit.clientlib import load_manifest
from pre_commit.clientlib import LOCAL
from pre_commit.clientlib import META
from pre_commit.commands.migrate_config import migrate_config
from pre_commit.store import Store
from pre_commit.util import CalledProcessError
from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b
from pre_commit.util import tmpdir
class RevInfo(collections.namedtuple('RevInfo', ('repo', 'rev', 'frozen'))):
__slots__ = ()
class RevInfo(NamedTuple):
repo: str
rev: str
frozen: Optional[str]
@classmethod
def from_config(cls, config):
def from_config(cls, config: Dict[str, Any]) -> 'RevInfo':
return cls(config['repo'], config['rev'], None)
def update(self, tags_only, freeze):
def update(self, tags_only: bool, freeze: bool) -> 'RevInfo':
if tags_only:
tag_cmd = ('git', 'describe', 'FETCH_HEAD', '--tags', '--abbrev=0')
else:
@ -57,7 +64,11 @@ class RepositoryCannotBeUpdatedError(RuntimeError):
pass
def _check_hooks_still_exist_at_rev(repo_config, info, store):
def _check_hooks_still_exist_at_rev(
repo_config: Dict[str, Any],
info: RevInfo,
store: Store,
) -> None:
try:
path = store.clone(repo_config['repo'], info.rev)
manifest = load_manifest(os.path.join(path, C.MANIFEST_FILE))
@ -78,7 +89,11 @@ REV_LINE_RE = re.compile(r'^(\s+)rev:(\s*)([^\s#]+)(.*)(\r?\n)$', re.DOTALL)
REV_LINE_FMT = '{}rev:{}{}{}{}'
def _original_lines(path, rev_infos, retry=False):
def _original_lines(
path: str,
rev_infos: List[Optional[RevInfo]],
retry: bool = False,
) -> Tuple[List[str], List[int]]:
"""detect `rev:` lines or reformat the file"""
with open(path) as f:
original = f.read()
@ -95,7 +110,7 @@ def _original_lines(path, rev_infos, retry=False):
return _original_lines(path, rev_infos, retry=True)
def _write_new_config(path, rev_infos):
def _write_new_config(path: str, rev_infos: List[Optional[RevInfo]]) -> None:
lines, idxs = _original_lines(path, rev_infos)
for idx, rev_info in zip(idxs, rev_infos):
@ -119,7 +134,13 @@ def _write_new_config(path, rev_infos):
f.write(''.join(lines))
def autoupdate(config_file, store, tags_only, freeze, repos=()):
def autoupdate(
config_file: str,
store: Store,
tags_only: bool,
freeze: bool,
repos: Sequence[str] = (),
) -> int:
"""Auto-update the pre-commit config to the latest versions of repos."""
migrate_config(config_file, quiet=True)
retv = 0

View file

@ -1,10 +1,11 @@
import os.path
from pre_commit import output
from pre_commit.store import Store
from pre_commit.util import rmtree
def clean(store):
def clean(store: Store) -> int:
legacy_path = os.path.expanduser('~/.pre-commit')
for directory in (store.directory, legacy_path):
if os.path.exists(directory):

View file

@ -1,4 +1,8 @@
import os.path
from typing import Any
from typing import Dict
from typing import Set
from typing import Tuple
import pre_commit.constants as C
from pre_commit import output
@ -8,9 +12,15 @@ from pre_commit.clientlib import load_config
from pre_commit.clientlib import load_manifest
from pre_commit.clientlib import LOCAL
from pre_commit.clientlib import META
from pre_commit.store import Store
def _mark_used_repos(store, all_repos, unused_repos, repo):
def _mark_used_repos(
store: Store,
all_repos: Dict[Tuple[str, str], str],
unused_repos: Set[Tuple[str, str]],
repo: Dict[str, Any],
) -> None:
if repo['repo'] == META:
return
elif repo['repo'] == LOCAL:
@ -47,7 +57,7 @@ def _mark_used_repos(store, all_repos, unused_repos, repo):
))
def _gc_repos(store):
def _gc_repos(store: Store) -> int:
configs = store.select_all_configs()
repos = store.select_all_repos()
@ -73,7 +83,7 @@ def _gc_repos(store):
return len(unused_repos)
def gc(store):
def gc(store: Store) -> int:
with store.exclusive_lock():
repos_removed = _gc_repos(store)
output.write_line(f'{repos_removed} repo(s) removed.')

View file

@ -1,14 +1,21 @@
import logging
import os.path
from typing import Sequence
from pre_commit.commands.install_uninstall import install
from pre_commit.store import Store
from pre_commit.util import CalledProcessError
from pre_commit.util import cmd_output
logger = logging.getLogger('pre_commit')
def init_templatedir(config_file, store, directory, hook_types):
def init_templatedir(
config_file: str,
store: Store,
directory: str,
hook_types: Sequence[str],
) -> int:
install(
config_file, store, hook_types=hook_types,
overwrite=True, skip_on_missing_config=True, git_dir=directory,
@ -25,3 +32,4 @@ def init_templatedir(config_file, store, directory, hook_types):
logger.warning(
f'maybe `git config --global init.templateDir {dest}`?',
)
return 0

View file

@ -3,12 +3,16 @@ import logging
import os.path
import shutil
import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from pre_commit import git
from pre_commit import output
from pre_commit.clientlib import load_config
from pre_commit.repository import all_hooks
from pre_commit.repository import install_hook_envs
from pre_commit.store import Store
from pre_commit.util import make_executable
from pre_commit.util import mkdirp
from pre_commit.util import resource_text
@ -29,13 +33,16 @@ TEMPLATE_START = '# start templated\n'
TEMPLATE_END = '# end templated\n'
def _hook_paths(hook_type, git_dir=None):
def _hook_paths(
hook_type: str,
git_dir: Optional[str] = None,
) -> Tuple[str, str]:
git_dir = git_dir if git_dir is not None else git.get_git_dir()
pth = os.path.join(git_dir, 'hooks', hook_type)
return pth, f'{pth}.legacy'
def is_our_script(filename):
def is_our_script(filename: str) -> bool:
if not os.path.exists(filename): # pragma: windows no cover (symlink)
return False
with open(filename) as f:
@ -43,7 +50,7 @@ def is_our_script(filename):
return any(h in contents for h in (CURRENT_HASH,) + PRIOR_HASHES)
def shebang():
def shebang() -> str:
if sys.platform == 'win32':
py = 'python'
else:
@ -63,9 +70,12 @@ def shebang():
def _install_hook_script(
config_file, hook_type,
overwrite=False, skip_on_missing_config=False, git_dir=None,
):
config_file: str,
hook_type: str,
overwrite: bool = False,
skip_on_missing_config: bool = False,
git_dir: Optional[str] = None,
) -> None:
hook_path, legacy_path = _hook_paths(hook_type, git_dir=git_dir)
mkdirp(os.path.dirname(hook_path))
@ -108,10 +118,14 @@ def _install_hook_script(
def install(
config_file, store, hook_types,
overwrite=False, hooks=False,
skip_on_missing_config=False, git_dir=None,
):
config_file: str,
store: Store,
hook_types: Sequence[str],
overwrite: bool = False,
hooks: bool = False,
skip_on_missing_config: bool = False,
git_dir: Optional[str] = None,
) -> int:
if git.has_core_hookpaths_set():
logger.error(
'Cowardly refusing to install hooks with `core.hooksPath` set.\n'
@ -133,11 +147,12 @@ def install(
return 0
def install_hooks(config_file, store):
def install_hooks(config_file: str, store: Store) -> int:
install_hook_envs(all_hooks(load_config(config_file), store), store)
return 0
def _uninstall_hook_script(hook_type): # type: (str) -> None
def _uninstall_hook_script(hook_type: str) -> None:
hook_path, legacy_path = _hook_paths(hook_type)
# If our file doesn't exist or it isn't ours, gtfo.
@ -152,7 +167,7 @@ def _uninstall_hook_script(hook_type): # type: (str) -> None
output.write_line(f'Restored previous hooks to {hook_path}')
def uninstall(hook_types):
def uninstall(hook_types: Sequence[str]) -> int:
for hook_type in hook_types:
_uninstall_hook_script(hook_type)
return 0

View file

@ -4,16 +4,16 @@ import yaml
from aspy.yaml import ordered_load
def _indent(s):
def _indent(s: str) -> str:
lines = s.splitlines(True)
return ''.join(' ' * 4 + line if line.strip() else line for line in lines)
def _is_header_line(line):
return (line.startswith(('#', '---')) or not line.strip())
def _is_header_line(line: str) -> bool:
return line.startswith(('#', '---')) or not line.strip()
def _migrate_map(contents):
def _migrate_map(contents: str) -> str:
# Find the first non-header line
lines = contents.splitlines(True)
i = 0
@ -37,12 +37,12 @@ def _migrate_map(contents):
return contents
def _migrate_sha_to_rev(contents):
def _migrate_sha_to_rev(contents: str) -> str:
reg = re.compile(r'(\n\s+)sha:')
return reg.sub(r'\1rev:', contents)
def migrate_config(config_file, quiet=False):
def migrate_config(config_file: str, quiet: bool = False) -> int:
with open(config_file) as f:
orig_contents = contents = f.read()
@ -56,3 +56,4 @@ def migrate_config(config_file, quiet=False):
print('Configuration has been migrated.')
elif not quiet:
print('Configuration is already migrated.')
return 0

View file

@ -1,8 +1,17 @@
import argparse
import functools
import logging
import os
import re
import subprocess
import time
from typing import Any
from typing import Collection
from typing import Dict
from typing import List
from typing import Sequence
from typing import Set
from typing import Tuple
from identify.identify import tags_from_path
@ -12,16 +21,23 @@ from pre_commit import output
from pre_commit.clientlib import load_config
from pre_commit.output import get_hook_message
from pre_commit.repository import all_hooks
from pre_commit.repository import Hook
from pre_commit.repository import install_hook_envs
from pre_commit.staged_files_only import staged_files_only
from pre_commit.store import Store
from pre_commit.util import cmd_output_b
from pre_commit.util import EnvironT
from pre_commit.util import noop_context
logger = logging.getLogger('pre_commit')
def filter_by_include_exclude(names, include, exclude):
def filter_by_include_exclude(
names: Collection[str],
include: str,
exclude: str,
) -> List[str]:
include_re, exclude_re = re.compile(include), re.compile(exclude)
return [
filename for filename in names
@ -31,24 +47,25 @@ def filter_by_include_exclude(names, include, exclude):
class Classifier:
def __init__(self, filenames):
def __init__(self, filenames: Sequence[str]) -> None:
# on windows we normalize all filenames to use forward slashes
# this makes it easier to filter using the `files:` regex
# this also makes improperly quoted shell-based hooks work better
# see #1173
if os.altsep == '/' and os.sep == '\\':
filenames = (f.replace(os.sep, os.altsep) for f in filenames)
filenames = [f.replace(os.sep, os.altsep) for f in filenames]
self.filenames = [f for f in filenames if os.path.lexists(f)]
self._types_cache = {}
def _types_for_file(self, filename):
try:
return self._types_cache[filename]
except KeyError:
ret = self._types_cache[filename] = tags_from_path(filename)
return ret
@functools.lru_cache(maxsize=None)
def _types_for_file(self, filename: str) -> Set[str]:
return tags_from_path(filename)
def by_types(self, names, types, exclude_types):
def by_types(
self,
names: Sequence[str],
types: Collection[str],
exclude_types: Collection[str],
) -> List[str]:
types, exclude_types = frozenset(types), frozenset(exclude_types)
ret = []
for filename in names:
@ -57,14 +74,14 @@ class Classifier:
ret.append(filename)
return ret
def filenames_for_hook(self, hook):
def filenames_for_hook(self, hook: Hook) -> Tuple[str, ...]:
names = self.filenames
names = filter_by_include_exclude(names, hook.files, hook.exclude)
names = self.by_types(names, hook.types, hook.exclude_types)
return names
return tuple(names)
def _get_skips(environ):
def _get_skips(environ: EnvironT) -> Set[str]:
skips = environ.get('SKIP', '')
return {skip.strip() for skip in skips.split(',') if skip.strip()}
@ -73,11 +90,18 @@ SKIPPED = 'Skipped'
NO_FILES = '(no files to check)'
def _subtle_line(s, use_color):
def _subtle_line(s: str, use_color: bool) -> None:
output.write_line(color.format_color(s, color.SUBTLE, use_color))
def _run_single_hook(classifier, hook, skips, cols, verbose, use_color):
def _run_single_hook(
classifier: Classifier,
hook: Hook,
skips: Set[str],
cols: int,
verbose: bool,
use_color: bool,
) -> bool:
filenames = classifier.filenames_for_hook(hook)
if hook.id in skips or hook.alias in skips:
@ -115,7 +139,8 @@ def _run_single_hook(classifier, hook, skips, cols, verbose, use_color):
diff_cmd = ('git', 'diff', '--no-ext-diff')
diff_before = cmd_output_b(*diff_cmd, retcode=None)
filenames = tuple(filenames) if hook.pass_filenames else ()
if not hook.pass_filenames:
filenames = ()
time_before = time.time()
retcode, out = hook.run(filenames, use_color)
duration = round(time.time() - time_before, 2) or 0
@ -154,7 +179,7 @@ def _run_single_hook(classifier, hook, skips, cols, verbose, use_color):
return files_modified or bool(retcode)
def _compute_cols(hooks):
def _compute_cols(hooks: Sequence[Hook]) -> int:
"""Compute the number of columns to display hook messages. The widest
that will be displayed is in the no files skipped case:
@ -169,7 +194,7 @@ def _compute_cols(hooks):
return max(cols, 80)
def _all_filenames(args):
def _all_filenames(args: argparse.Namespace) -> Collection[str]:
if args.origin and args.source:
return git.get_changed_files(args.origin, args.source)
elif args.hook_stage in {'prepare-commit-msg', 'commit-msg'}:
@ -184,7 +209,12 @@ def _all_filenames(args):
return git.get_staged_files()
def _run_hooks(config, hooks, args, environ):
def _run_hooks(
config: Dict[str, Any],
hooks: Sequence[Hook],
args: argparse.Namespace,
environ: EnvironT,
) -> int:
"""Actually run the hooks."""
skips = _get_skips(environ)
cols = _compute_cols(hooks)
@ -221,12 +251,12 @@ def _run_hooks(config, hooks, args, environ):
return retval
def _has_unmerged_paths():
def _has_unmerged_paths() -> bool:
_, stdout, _ = cmd_output_b('git', 'ls-files', '--unmerged')
return bool(stdout.strip())
def _has_unstaged_config(config_file):
def _has_unstaged_config(config_file: str) -> bool:
retcode, _, _ = cmd_output_b(
'git', 'diff', '--no-ext-diff', '--exit-code', config_file,
retcode=None,
@ -235,7 +265,12 @@ def _has_unstaged_config(config_file):
return retcode == 1
def run(config_file, store, args, environ=os.environ):
def run(
config_file: str,
store: Store,
args: argparse.Namespace,
environ: EnvironT = os.environ,
) -> int:
no_stash = args.all_files or bool(args.files)
# Check if we have unresolved merge conflict files and fail fast.

View file

@ -16,6 +16,6 @@ repos:
'''
def sample_config():
def sample_config() -> int:
print(SAMPLE_CONFIG, end='')
return 0

View file

@ -1,6 +1,8 @@
import argparse
import collections
import logging
import os.path
from typing import Tuple
from aspy.yaml import ordered_dump
@ -17,7 +19,7 @@ from pre_commit.xargs import xargs
logger = logging.getLogger(__name__)
def _repo_ref(tmpdir, repo, ref):
def _repo_ref(tmpdir: str, repo: str, ref: str) -> Tuple[str, str]:
# if `ref` is explicitly passed, use it
if ref:
return repo, ref
@ -47,7 +49,7 @@ def _repo_ref(tmpdir, repo, ref):
return repo, ref
def try_repo(args):
def try_repo(args: argparse.Namespace) -> int:
with tmpdir() as tempdir:
repo, ref = _repo_ref(tempdir, args.repo, args.ref)

View file

@ -1,10 +1,14 @@
import contextlib
import enum
import os
from typing import Generator
from typing import NamedTuple
from typing import Optional
from typing import Tuple
from typing import Union
from pre_commit.util import EnvironT
class _Unset(enum.Enum):
UNSET = 1
@ -23,7 +27,7 @@ ValueT = Union[str, _Unset, SubstitutionT]
PatchesT = Tuple[Tuple[str, ValueT], ...]
def format_env(parts, env):
def format_env(parts: SubstitutionT, env: EnvironT) -> str:
return ''.join(
env.get(part.name, part.default) if isinstance(part, Var) else part
for part in parts
@ -31,7 +35,10 @@ def format_env(parts, env):
@contextlib.contextmanager
def envcontext(patch, _env=None):
def envcontext(
patch: PatchesT,
_env: Optional[EnvironT] = None,
) -> Generator[None, None, None]:
"""In this context, `os.environ` is modified according to `patch`.
`patch` is an iterable of 2-tuples (key, value):

View file

@ -2,6 +2,7 @@ import contextlib
import os.path
import sys
import traceback
from typing import Generator
from typing import Union
import pre_commit.constants as C
@ -14,14 +15,11 @@ class FatalError(RuntimeError):
pass
def _to_bytes(exc):
try:
return bytes(exc)
except Exception:
def _to_bytes(exc: BaseException) -> bytes:
return str(exc).encode('UTF-8')
def _log_and_exit(msg, exc, formatted):
def _log_and_exit(msg: str, exc: BaseException, formatted: str) -> None:
error_msg = b''.join((
five.to_bytes(msg), b': ',
five.to_bytes(type(exc).__name__), b': ',
@ -62,7 +60,7 @@ def _log_and_exit(msg, exc, formatted):
@contextlib.contextmanager
def error_handler():
def error_handler() -> Generator[None, None, None]:
try:
yield
except (Exception, KeyboardInterrupt) as e:

View file

@ -1,6 +1,8 @@
import contextlib
import errno
import os
from typing import Callable
from typing import Generator
if os.name == 'nt': # pragma: no cover (windows)
@ -13,7 +15,10 @@ if os.name == 'nt': # pragma: no cover (windows)
_region = 0xffff
@contextlib.contextmanager
def _locked(fileno, blocked_cb):
def _locked(
fileno: int,
blocked_cb: Callable[[], None],
) -> Generator[None, None, None]:
try:
# TODO: https://github.com/python/typeshed/pull/3607
msvcrt.locking(fileno, msvcrt.LK_NBLCK, _region) # type: ignore
@ -42,11 +47,14 @@ if os.name == 'nt': # pragma: no cover (windows)
# before closing a file or exiting the program."
# TODO: https://github.com/python/typeshed/pull/3607
msvcrt.locking(fileno, msvcrt.LK_UNLCK, _region) # type: ignore
else: # pramga: windows no cover
else: # pragma: windows no cover
import fcntl
@contextlib.contextmanager
def _locked(fileno, blocked_cb):
def _locked(
fileno: int,
blocked_cb: Callable[[], None],
) -> Generator[None, None, None]:
try:
fcntl.flock(fileno, fcntl.LOCK_EX | fcntl.LOCK_NB)
except OSError: # pragma: no cover (tests are single-threaded)
@ -59,7 +67,10 @@ else: # pramga: windows no cover
@contextlib.contextmanager
def lock(path, blocked_cb):
def lock(
path: str,
blocked_cb: Callable[[], None],
) -> Generator[None, None, None]:
with open(path, 'a+') as f:
with _locked(f.fileno(), blocked_cb):
yield

View file

@ -1,8 +1,11 @@
def to_text(s):
from typing import Union
def to_text(s: Union[str, bytes]) -> str:
return s if isinstance(s, str) else s.decode('UTF-8')
def to_bytes(s):
def to_bytes(s: Union[str, bytes]) -> bytes:
return s if isinstance(s, bytes) else s.encode('UTF-8')

View file

@ -1,15 +1,20 @@
import logging
import os.path
import sys
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b
from pre_commit.util import EnvironT
logger = logging.getLogger(__name__)
def zsplit(s):
def zsplit(s: str) -> List[str]:
s = s.strip('\0')
if s:
return s.split('\0')
@ -17,7 +22,7 @@ def zsplit(s):
return []
def no_git_env(_env=None):
def no_git_env(_env: Optional[EnvironT] = None) -> Dict[str, str]:
# Too many bugs dealing with environment variables and GIT:
# https://github.com/pre-commit/pre-commit/issues/300
# In git 2.6.3 (maybe others), git exports GIT_WORK_TREE while running
@ -34,11 +39,11 @@ def no_git_env(_env=None):
}
def get_root():
def get_root() -> str:
return cmd_output('git', 'rev-parse', '--show-toplevel')[1].strip()
def get_git_dir(git_root='.'):
def get_git_dir(git_root: str = '.') -> str:
opts = ('--git-common-dir', '--git-dir')
_, out, _ = cmd_output('git', 'rev-parse', *opts, cwd=git_root)
for line, opt in zip(out.splitlines(), opts):
@ -48,12 +53,12 @@ def get_git_dir(git_root='.'):
raise AssertionError('unreachable: no git dir')
def get_remote_url(git_root):
def get_remote_url(git_root: str) -> str:
_, out, _ = cmd_output('git', 'config', 'remote.origin.url', cwd=git_root)
return out.strip()
def is_in_merge_conflict():
def is_in_merge_conflict() -> bool:
git_dir = get_git_dir('.')
return (
os.path.exists(os.path.join(git_dir, 'MERGE_MSG')) and
@ -61,7 +66,7 @@ def is_in_merge_conflict():
)
def parse_merge_msg_for_conflicts(merge_msg):
def parse_merge_msg_for_conflicts(merge_msg: bytes) -> List[str]:
# Conflicted files start with tabs
return [
line.lstrip(b'#').strip().decode('UTF-8')
@ -71,7 +76,7 @@ def parse_merge_msg_for_conflicts(merge_msg):
]
def get_conflicted_files():
def get_conflicted_files() -> Set[str]:
logger.info('Checking merge-conflict files only.')
# Need to get the conflicted files from the MERGE_MSG because they could
# have resolved the conflict by choosing one side or the other
@ -92,7 +97,7 @@ def get_conflicted_files():
return set(merge_conflict_filenames) | set(merge_diff_filenames)
def get_staged_files(cwd=None):
def get_staged_files(cwd: Optional[str] = None) -> List[str]:
return zsplit(
cmd_output(
'git', 'diff', '--staged', '--name-only', '--no-ext-diff', '-z',
@ -103,7 +108,7 @@ def get_staged_files(cwd=None):
)
def intent_to_add_files():
def intent_to_add_files() -> List[str]:
_, stdout, _ = cmd_output('git', 'status', '--porcelain', '-z')
parts = list(reversed(zsplit(stdout)))
intent_to_add = []
@ -117,11 +122,11 @@ def intent_to_add_files():
return intent_to_add
def get_all_files():
def get_all_files() -> List[str]:
return zsplit(cmd_output('git', 'ls-files', '-z')[1])
def get_changed_files(new, old):
def get_changed_files(new: str, old: str) -> List[str]:
return zsplit(
cmd_output(
'git', 'diff', '--name-only', '--no-ext-diff', '-z',
@ -130,24 +135,22 @@ def get_changed_files(new, old):
)
def head_rev(remote):
def head_rev(remote: str) -> str:
_, out, _ = cmd_output('git', 'ls-remote', '--exit-code', remote, 'HEAD')
return out.split()[0]
def has_diff(*args, **kwargs):
repo = kwargs.pop('repo', '.')
assert not kwargs, kwargs
def has_diff(*args: str, repo: str = '.') -> bool:
cmd = ('git', 'diff', '--quiet', '--no-ext-diff') + args
return cmd_output_b(*cmd, cwd=repo, retcode=None)[0] == 1
def has_core_hookpaths_set():
def has_core_hookpaths_set() -> bool:
_, out, _ = cmd_output_b('git', 'config', 'core.hooksPath', retcode=None)
return bool(out.strip())
def init_repo(path, remote):
def init_repo(path: str, remote: str) -> None:
if os.path.isdir(remote):
remote = os.path.abspath(remote)
@ -156,7 +159,7 @@ def init_repo(path, remote):
cmd_output_b('git', 'remote', 'add', 'origin', remote, cwd=path, env=env)
def commit(repo='.'):
def commit(repo: str = '.') -> None:
env = no_git_env()
name, email = 'pre-commit', 'asottile+pre-commit@umich.edu'
env['GIT_AUTHOR_NAME'] = env['GIT_COMMITTER_NAME'] = name
@ -165,12 +168,12 @@ def commit(repo='.'):
cmd_output_b(*cmd, cwd=repo, env=env)
def git_path(name, repo='.'):
def git_path(name: str, repo: str = '.') -> str:
_, out, _ = cmd_output('git', 'rev-parse', '--git-path', name, cwd=repo)
return os.path.join(repo, out.strip())
def check_for_cygwin_mismatch():
def check_for_cygwin_mismatch() -> None:
"""See https://github.com/pre-commit/pre-commit/issues/354"""
if sys.platform in ('cygwin', 'win32'): # pragma: no cover (windows)
is_cygwin_python = sys.platform == 'cygwin'

View file

@ -1,20 +1,29 @@
import contextlib
import os
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import SubstitutionT
from pre_commit.envcontext import UNSET
from pre_commit.envcontext import Var
from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'conda'
get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy
def get_env_patch(env):
def get_env_patch(env: str) -> PatchesT:
# On non-windows systems executable live in $CONDA_PREFIX/bin, on Windows
# they can be in $CONDA_PREFIX/bin, $CONDA_PREFIX/Library/bin,
# $CONDA_PREFIX/Scripts and $CONDA_PREFIX. Whereas the latter only
@ -34,14 +43,21 @@ def get_env_patch(env):
@contextlib.contextmanager
def in_env(prefix, language_version):
def in_env(
prefix: Prefix,
language_version: str,
) -> Generator[None, None, None]:
directory = helpers.environment_dir(ENVIRONMENT_DIR, language_version)
envdir = prefix.path(directory)
with envcontext(get_env_patch(envdir)):
yield
def install_environment(prefix, version, additional_dependencies):
def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> None:
helpers.assert_version_default('conda', version)
directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
@ -58,7 +74,11 @@ def install_environment(prefix, version, additional_dependencies):
)
def run_hook(hook, file_args, color):
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
# TODO: Some rare commands need to be run using `conda run` but mostly we
# can run them withot which is much quicker and produces a better
# output.

View file

@ -1,14 +1,18 @@
import hashlib
import os
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C
from pre_commit import five
from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import CalledProcessError
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'docker'
PRE_COMMIT_LABEL = 'PRE_COMMIT'
@ -16,16 +20,16 @@ get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy
def md5(s): # pragma: windows no cover
return hashlib.md5(five.to_bytes(s)).hexdigest()
def md5(s: str) -> str: # pragma: windows no cover
return hashlib.md5(s.encode()).hexdigest()
def docker_tag(prefix): # pragma: windows no cover
def docker_tag(prefix: Prefix) -> str: # pragma: windows no cover
md5sum = md5(os.path.basename(prefix.prefix_dir)).lower()
return f'pre-commit-{md5sum}'
def docker_is_running(): # pragma: windows no cover
def docker_is_running() -> bool: # pragma: windows no cover
try:
cmd_output_b('docker', 'ps')
except CalledProcessError:
@ -34,15 +38,17 @@ def docker_is_running(): # pragma: windows no cover
return True
def assert_docker_available(): # pragma: windows no cover
def assert_docker_available() -> None: # pragma: windows no cover
assert docker_is_running(), (
'Docker is either not running or not configured in this environment'
)
def build_docker_image(prefix, **kwargs): # pragma: windows no cover
pull = kwargs.pop('pull')
assert not kwargs, kwargs
def build_docker_image(
prefix: Prefix,
*,
pull: bool,
) -> None: # pragma: windows no cover
cmd: Tuple[str, ...] = (
'docker', 'build',
'--tag', docker_tag(prefix),
@ -56,8 +62,8 @@ def build_docker_image(prefix, **kwargs): # pragma: windows no cover
def install_environment(
prefix, version, additional_dependencies,
): # pragma: windows no cover
prefix: Prefix, version: str, additional_dependencies: Sequence[str],
) -> None: # pragma: windows no cover
helpers.assert_version_default('docker', version)
helpers.assert_no_additional_deps('docker', additional_dependencies)
assert_docker_available()
@ -73,14 +79,14 @@ def install_environment(
os.mkdir(directory)
def get_docker_user(): # pragma: windows no cover
def get_docker_user() -> str: # pragma: windows no cover
try:
return '{}:{}'.format(os.getuid(), os.getgid())
except AttributeError:
return '1000:1000'
def docker_cmd(): # pragma: windows no cover
def docker_cmd() -> Tuple[str, ...]: # pragma: windows no cover
return (
'docker', 'run',
'--rm',
@ -93,7 +99,11 @@ def docker_cmd(): # pragma: windows no cover
)
def run_hook(hook, file_args, color): # pragma: windows no cover
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
assert_docker_available()
# Rebuild the docker image in case it has gone missing, as many people do
# automated cleanup of docker images.

View file

@ -1,7 +1,13 @@
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.languages import helpers
from pre_commit.languages.docker import assert_docker_available
from pre_commit.languages.docker import docker_cmd
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version
@ -9,7 +15,11 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install
def run_hook(hook, file_args, color): # pragma: windows no cover
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
assert_docker_available()
cmd = docker_cmd() + hook.cmd
return helpers.run_xargs(hook, cmd, file_args, color=color)

View file

@ -1,5 +1,11 @@
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.languages import helpers
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version
@ -7,7 +13,11 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install
def run_hook(hook, file_args, color):
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
out = hook.entry.encode('UTF-8') + b'\n\n'
out += b'\n'.join(f.encode('UTF-8') for f in file_args) + b'\n'
return 1, out

View file

@ -1,31 +1,39 @@
import contextlib
import os.path
import sys
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C
from pre_commit import git
from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var
from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b
from pre_commit.util import rmtree
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'golangenv'
get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy
def get_env_patch(venv):
def get_env_patch(venv: str) -> PatchesT:
return (
('PATH', (os.path.join(venv, 'bin'), os.pathsep, Var('PATH'))),
)
@contextlib.contextmanager
def in_env(prefix):
def in_env(prefix: Prefix) -> Generator[None, None, None]:
envdir = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
)
@ -33,7 +41,7 @@ def in_env(prefix):
yield
def guess_go_dir(remote_url):
def guess_go_dir(remote_url: str) -> str:
if remote_url.endswith('.git'):
remote_url = remote_url[:-1 * len('.git')]
looks_like_url = (
@ -49,7 +57,11 @@ def guess_go_dir(remote_url):
return 'unknown_src_dir'
def install_environment(prefix, version, additional_dependencies):
def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> None:
helpers.assert_version_default('golang', version)
directory = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
@ -79,6 +91,10 @@ def install_environment(prefix, version, additional_dependencies):
rmtree(pkgdir)
def run_hook(hook, file_args, color):
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
with in_env(hook.prefix):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,33 +1,54 @@
import multiprocessing
import os
import random
from typing import Any
from typing import List
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C
from pre_commit.prefix import Prefix
from pre_commit.util import cmd_output_b
from pre_commit.xargs import xargs
if TYPE_CHECKING:
from pre_commit.repository import Hook
FIXED_RANDOM_SEED = 1542676186
def run_setup_cmd(prefix, cmd):
def run_setup_cmd(prefix: Prefix, cmd: Tuple[str, ...]) -> None:
cmd_output_b(*cmd, cwd=prefix.prefix_dir)
def environment_dir(ENVIRONMENT_DIR, language_version):
if ENVIRONMENT_DIR is None:
@overload
def environment_dir(d: None, language_version: str) -> None: ...
@overload
def environment_dir(d: str, language_version: str) -> str: ...
def environment_dir(d: Optional[str], language_version: str) -> Optional[str]:
if d is None:
return None
else:
return f'{ENVIRONMENT_DIR}-{language_version}'
return f'{d}-{language_version}'
def assert_version_default(binary, version):
def assert_version_default(binary: str, version: str) -> None:
if version != C.DEFAULT:
raise AssertionError(
f'For now, pre-commit requires system-installed {binary}',
)
def assert_no_additional_deps(lang, additional_deps):
def assert_no_additional_deps(
lang: str,
additional_deps: Sequence[str],
) -> None:
if additional_deps:
raise AssertionError(
'For now, pre-commit does not support '
@ -35,19 +56,23 @@ def assert_no_additional_deps(lang, additional_deps):
)
def basic_get_default_version():
def basic_get_default_version() -> str:
return C.DEFAULT
def basic_healthy(prefix, language_version):
def basic_healthy(prefix: Prefix, language_version: str) -> bool:
return True
def no_install(prefix, version, additional_dependencies):
def no_install(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> NoReturn:
raise AssertionError('This type is not installable')
def target_concurrency(hook):
def target_concurrency(hook: 'Hook') -> int:
if hook.require_serial or 'PRE_COMMIT_NO_CONCURRENCY' in os.environ:
return 1
else:
@ -61,8 +86,8 @@ def target_concurrency(hook):
return 1
def _shuffled(seq):
"""Deterministically shuffle identically under both py2 + py3."""
def _shuffled(seq: Sequence[str]) -> List[str]:
"""Deterministically shuffle"""
fixed_random = random.Random()
fixed_random.seed(FIXED_RANDOM_SEED, version=1)
@ -71,7 +96,12 @@ def _shuffled(seq):
return seq
def run_xargs(hook, cmd, file_args, **kwargs):
def run_xargs(
hook: 'Hook',
cmd: Tuple[str, ...],
file_args: Sequence[str],
**kwargs: Any,
) -> Tuple[int, bytes]:
# Shuffle the files so that they more evenly fill out the xargs partitions,
# but do it deterministically in case a hook cares about ordering.
file_args = _shuffled(file_args)

View file

@ -1,28 +1,36 @@
import contextlib
import os
import sys
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C
from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var
from pre_commit.languages import helpers
from pre_commit.languages.python import bin_dir
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'node_env'
get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy
def _envdir(prefix, version):
def _envdir(prefix: Prefix, version: str) -> str:
directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
return prefix.path(directory)
def get_env_patch(venv): # pragma: windows no cover
def get_env_patch(venv: str) -> PatchesT: # pragma: windows no cover
if sys.platform == 'cygwin': # pragma: no cover
_, win_venv, _ = cmd_output('cygpath', '-w', venv)
install_prefix = r'{}\bin'.format(win_venv.strip())
@ -43,14 +51,17 @@ def get_env_patch(venv): # pragma: windows no cover
@contextlib.contextmanager
def in_env(prefix, language_version): # pragma: windows no cover
def in_env(
prefix: Prefix,
language_version: str,
) -> Generator[None, None, None]: # pragma: windows no cover
with envcontext(get_env_patch(_envdir(prefix, language_version))):
yield
def install_environment(
prefix, version, additional_dependencies,
): # pragma: windows no cover
prefix: Prefix, version: str, additional_dependencies: Sequence[str],
) -> None: # pragma: windows no cover
additional_dependencies = tuple(additional_dependencies)
assert prefix.exists('package.json')
envdir = _envdir(prefix, version)
@ -76,6 +87,10 @@ def install_environment(
)
def run_hook(hook, file_args, color): # pragma: windows no cover
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
with in_env(hook.prefix, hook.language_version):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,11 +1,18 @@
import argparse
import re
import sys
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit import output
from pre_commit.languages import helpers
from pre_commit.xargs import xargs
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version
@ -13,7 +20,7 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install
def _process_filename_by_line(pattern, filename):
def _process_filename_by_line(pattern: Pattern[bytes], filename: str) -> int:
retv = 0
with open(filename, 'rb') as f:
for line_no, line in enumerate(f, start=1):
@ -24,7 +31,7 @@ def _process_filename_by_line(pattern, filename):
return retv
def _process_filename_at_once(pattern, filename):
def _process_filename_at_once(pattern: Pattern[bytes], filename: str) -> int:
retv = 0
with open(filename, 'rb') as f:
contents = f.read()
@ -41,12 +48,16 @@ def _process_filename_at_once(pattern, filename):
return retv
def run_hook(hook, file_args, color):
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
exe = (sys.executable, '-m', __name__) + tuple(hook.args) + (hook.entry,)
return xargs(exe, file_args, color=color)
def main(argv=None):
def main(argv: Optional[Sequence[str]] = None) -> int:
parser = argparse.ArgumentParser(
description=(
'grep-like finder using python regexes. Unlike grep, this tool '

View file

@ -2,29 +2,40 @@ import contextlib
import functools
import os
import sys
from typing import Callable
from typing import ContextManager
from typing import Generator
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C
from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import UNSET
from pre_commit.envcontext import Var
from pre_commit.languages import helpers
from pre_commit.parse_shebang import find_executable
from pre_commit.prefix import Prefix
from pre_commit.util import CalledProcessError
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'py_env'
def bin_dir(venv):
def bin_dir(venv: str) -> str:
"""On windows there's a different directory for the virtualenv"""
bin_part = 'Scripts' if os.name == 'nt' else 'bin'
return os.path.join(venv, bin_part)
def get_env_patch(venv):
def get_env_patch(venv: str) -> PatchesT:
return (
('PYTHONHOME', UNSET),
('VIRTUAL_ENV', venv),
@ -32,7 +43,9 @@ def get_env_patch(venv):
)
def _find_by_py_launcher(version): # pragma: no cover (windows only)
def _find_by_py_launcher(
version: str,
) -> Optional[str]: # pragma: no cover (windows only)
if version.startswith('python'):
try:
return cmd_output(
@ -41,14 +54,16 @@ def _find_by_py_launcher(version): # pragma: no cover (windows only)
)[1].strip()
except CalledProcessError:
pass
return None
def _find_by_sys_executable():
def _norm(path):
def _find_by_sys_executable() -> Optional[str]:
def _norm(path: str) -> Optional[str]:
_, exe = os.path.split(path.lower())
exe, _, _ = exe.partition('.exe')
if find_executable(exe) and exe not in {'python', 'pythonw'}:
return exe
return None
# On linux, I see these common sys.executables:
#
@ -66,7 +81,7 @@ def _find_by_sys_executable():
@functools.lru_cache(maxsize=1)
def get_default_version(): # pragma: no cover (platform dependent)
def get_default_version() -> str: # pragma: no cover (platform dependent)
# First attempt from `sys.executable` (or the realpath)
exe = _find_by_sys_executable()
if exe:
@ -88,7 +103,7 @@ def get_default_version(): # pragma: no cover (platform dependent)
return C.DEFAULT
def _sys_executable_matches(version):
def _sys_executable_matches(version: str) -> bool:
if version == 'python':
return True
elif not version.startswith('python'):
@ -102,7 +117,7 @@ def _sys_executable_matches(version):
return sys.version_info[:len(info)] == info
def norm_version(version):
def norm_version(version: str) -> str:
# first see if our current executable is appropriate
if _sys_executable_matches(version):
return sys.executable
@ -126,14 +141,25 @@ def norm_version(version):
return os.path.expanduser(version)
def py_interface(_dir, _make_venv):
def py_interface(
_dir: str,
_make_venv: Callable[[str, str], None],
) -> Tuple[
Callable[[Prefix, str], ContextManager[None]],
Callable[[Prefix, str], bool],
Callable[['Hook', Sequence[str], bool], Tuple[int, bytes]],
Callable[[Prefix, str, Sequence[str]], None],
]:
@contextlib.contextmanager
def in_env(prefix, language_version):
def in_env(
prefix: Prefix,
language_version: str,
) -> Generator[None, None, None]:
envdir = prefix.path(helpers.environment_dir(_dir, language_version))
with envcontext(get_env_patch(envdir)):
yield
def healthy(prefix, language_version):
def healthy(prefix: Prefix, language_version: str) -> bool:
with in_env(prefix, language_version):
retcode, _, _ = cmd_output_b(
'python', '-c',
@ -143,11 +169,19 @@ def py_interface(_dir, _make_venv):
)
return retcode == 0
def run_hook(hook, file_args, color):
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
with in_env(hook.prefix, hook.language_version):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color)
def install_environment(prefix, version, additional_dependencies):
def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> None:
additional_dependencies = tuple(additional_dependencies)
directory = helpers.environment_dir(_dir, version)
@ -166,7 +200,7 @@ def py_interface(_dir, _make_venv):
return in_env, healthy, run_hook, install_environment
def make_venv(envdir, python):
def make_venv(envdir: str, python: str) -> None:
env = dict(os.environ, VIRTUALENV_NO_DOWNLOAD='1')
cmd = (sys.executable, '-mvirtualenv', envdir, '-p', python)
cmd_output_b(*cmd, env=env, cwd='/')

View file

@ -5,15 +5,11 @@ from pre_commit.util import CalledProcessError
from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b
ENVIRONMENT_DIR = 'py_venv'
get_default_version = python.get_default_version
def get_default_version(): # pragma: no cover (version specific)
return python.get_default_version()
def orig_py_exe(exe): # pragma: no cover (platform specific)
def orig_py_exe(exe: str) -> str: # pragma: no cover (platform specific)
"""A -mvenv virtualenv made from a -mvirtualenv virtualenv installs
packages to the incorrect location. Attempt to find the _original_ exe
and invoke `-mvenv` from there.
@ -42,7 +38,7 @@ def orig_py_exe(exe): # pragma: no cover (platform specific)
return exe
def make_venv(envdir, python):
def make_venv(envdir: str, python: str) -> None:
cmd_output_b(orig_py_exe(python), '-mvenv', envdir, cwd='/')

View file

@ -2,23 +2,33 @@ import contextlib
import os.path
import shutil
import tarfile
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C
from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var
from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import CalledProcessError
from pre_commit.util import clean_path_on_failure
from pre_commit.util import resource_bytesio
if TYPE_CHECKING:
from pre_comit.repository import Hook
ENVIRONMENT_DIR = 'rbenv'
get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy
def get_env_patch(venv, language_version): # pragma: windows no cover
def get_env_patch(
venv: str,
language_version: str,
) -> PatchesT: # pragma: windows no cover
patches: PatchesT = (
('GEM_HOME', os.path.join(venv, 'gems')),
('RBENV_ROOT', venv),
@ -36,8 +46,11 @@ def get_env_patch(venv, language_version): # pragma: windows no cover
return patches
@contextlib.contextmanager
def in_env(prefix, language_version): # pragma: windows no cover
@contextlib.contextmanager # pragma: windows no cover
def in_env(
prefix: Prefix,
language_version: str,
) -> Generator[None, None, None]:
envdir = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, language_version),
)
@ -45,13 +58,16 @@ def in_env(prefix, language_version): # pragma: windows no cover
yield
def _extract_resource(filename, dest):
def _extract_resource(filename: str, dest: str) -> None:
with resource_bytesio(filename) as bio:
with tarfile.open(fileobj=bio) as tf:
tf.extractall(dest)
def _install_rbenv(prefix, version=C.DEFAULT): # pragma: windows no cover
def _install_rbenv(
prefix: Prefix,
version: str = C.DEFAULT,
) -> None: # pragma: windows no cover
directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
_extract_resource('rbenv.tar.gz', prefix.path('.'))
@ -87,7 +103,10 @@ def _install_rbenv(prefix, version=C.DEFAULT): # pragma: windows no cover
activate_file.write(f'export RBENV_VERSION="{version}"\n')
def _install_ruby(prefix, version): # pragma: windows no cover
def _install_ruby(
prefix: Prefix,
version: str,
) -> None: # pragma: windows no cover
try:
helpers.run_setup_cmd(prefix, ('rbenv', 'download', version))
except CalledProcessError: # pragma: no cover (usually find with download)
@ -96,8 +115,8 @@ def _install_ruby(prefix, version): # pragma: windows no cover
def install_environment(
prefix, version, additional_dependencies,
): # pragma: windows no cover
prefix: Prefix, version: str, additional_dependencies: Sequence[str],
) -> None: # pragma: windows no cover
additional_dependencies = tuple(additional_dependencies)
directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
with clean_path_on_failure(prefix.path(directory)):
@ -122,6 +141,10 @@ def install_environment(
)
def run_hook(hook, file_args, color): # pragma: windows no cover
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
with in_env(hook.prefix, hook.language_version):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,24 +1,31 @@
import contextlib
import os.path
from typing import Generator
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
import toml
import pre_commit.constants as C
from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var
from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'rustenv'
get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy
def get_env_patch(target_dir):
def get_env_patch(target_dir: str) -> PatchesT:
return (
(
'PATH',
@ -28,7 +35,7 @@ def get_env_patch(target_dir):
@contextlib.contextmanager
def in_env(prefix):
def in_env(prefix: Prefix) -> Generator[None, None, None]:
target_dir = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
)
@ -36,7 +43,10 @@ def in_env(prefix):
yield
def _add_dependencies(cargo_toml_path, additional_dependencies):
def _add_dependencies(
cargo_toml_path: str,
additional_dependencies: Set[str],
) -> None:
with open(cargo_toml_path, 'r+') as f:
cargo_toml = toml.load(f)
cargo_toml.setdefault('dependencies', {})
@ -48,7 +58,11 @@ def _add_dependencies(cargo_toml_path, additional_dependencies):
f.truncate()
def install_environment(prefix, version, additional_dependencies):
def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> None:
helpers.assert_version_default('rust', version)
directory = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
@ -82,13 +96,17 @@ def install_environment(prefix, version, additional_dependencies):
else:
packages_to_install.add((package,))
for package in packages_to_install:
for args in packages_to_install:
cmd_output_b(
'cargo', 'install', '--bins', '--root', directory, *package,
'cargo', 'install', '--bins', '--root', directory, *args,
cwd=prefix.prefix_dir,
)
def run_hook(hook, file_args, color):
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
with in_env(hook.prefix):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,5 +1,11 @@
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.languages import helpers
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version
@ -7,7 +13,11 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install
def run_hook(hook, file_args, color):
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
cmd = hook.cmd
cmd = (hook.prefix.path(cmd[0]),) + cmd[1:]
return helpers.run_xargs(hook, cmd, file_args, color=color)

View file

@ -1,13 +1,22 @@
import contextlib
import os
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C
from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var
from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'swift_env'
get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy
@ -15,13 +24,13 @@ BUILD_DIR = '.build'
BUILD_CONFIG = 'release'
def get_env_patch(venv): # pragma: windows no cover
def get_env_patch(venv: str) -> PatchesT: # pragma: windows no cover
bin_path = os.path.join(venv, BUILD_DIR, BUILD_CONFIG)
return (('PATH', (bin_path, os.pathsep, Var('PATH'))),)
@contextlib.contextmanager
def in_env(prefix): # pragma: windows no cover
@contextlib.contextmanager # pragma: windows no cover
def in_env(prefix: Prefix) -> Generator[None, None, None]:
envdir = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
)
@ -30,8 +39,8 @@ def in_env(prefix): # pragma: windows no cover
def install_environment(
prefix, version, additional_dependencies,
): # pragma: windows no cover
prefix: Prefix, version: str, additional_dependencies: Sequence[str],
) -> None: # pragma: windows no cover
helpers.assert_version_default('swift', version)
helpers.assert_no_additional_deps('swift', additional_dependencies)
directory = prefix.path(
@ -49,6 +58,10 @@ def install_environment(
)
def run_hook(hook, file_args, color): # pragma: windows no cover
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
with in_env(hook.prefix):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,5 +1,12 @@
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.languages import helpers
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version
@ -7,5 +14,9 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install
def run_hook(hook, file_args, color):
def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,10 +1,10 @@
import contextlib
import logging
from typing import Generator
from pre_commit import color
from pre_commit import output
logger = logging.getLogger('pre_commit')
LOG_LEVEL_COLORS = {
@ -16,11 +16,11 @@ LOG_LEVEL_COLORS = {
class LoggingHandler(logging.Handler):
def __init__(self, use_color):
def __init__(self, use_color: bool) -> None:
super().__init__()
self.use_color = use_color
def emit(self, record):
def emit(self, record: logging.LogRecord) -> None:
output.write_line(
'{} {}'.format(
color.format_color(
@ -34,8 +34,8 @@ class LoggingHandler(logging.Handler):
@contextlib.contextmanager
def logging_handler(*args, **kwargs):
handler = LoggingHandler(*args, **kwargs)
def logging_handler(use_color: bool) -> Generator[None, None, None]:
handler = LoggingHandler(use_color)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
try:

View file

@ -2,6 +2,10 @@ import argparse
import logging
import os
import sys
from typing import Any
from typing import Optional
from typing import Sequence
from typing import Union
import pre_commit.constants as C
from pre_commit import color
@ -37,7 +41,7 @@ os.environ.pop('__PYVENV_LAUNCHER__', None)
COMMANDS_NO_GIT = {'clean', 'gc', 'init-templatedir', 'sample-config'}
def _add_color_option(parser):
def _add_color_option(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
'--color', default=os.environ.get('PRE_COMMIT_COLOR', 'auto'),
type=color.use_color,
@ -46,7 +50,7 @@ def _add_color_option(parser):
)
def _add_config_option(parser):
def _add_config_option(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
'-c', '--config', default=C.CONFIG_FILE,
help='Path to alternate config file',
@ -54,18 +58,24 @@ def _add_config_option(parser):
class AppendReplaceDefault(argparse.Action):
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.appended = False
def __call__(self, parser, namespace, values, option_string=None):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Union[str, Sequence[str], None],
option_string: Optional[str] = None,
) -> None:
if not self.appended:
setattr(namespace, self.dest, [])
self.appended = True
getattr(namespace, self.dest).append(values)
def _add_hook_type_option(parser):
def _add_hook_type_option(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
'-t', '--hook-type', choices=(
'pre-commit', 'pre-merge-commit', 'pre-push',
@ -77,7 +87,7 @@ def _add_hook_type_option(parser):
)
def _add_run_options(parser):
def _add_run_options(parser: argparse.ArgumentParser) -> None:
parser.add_argument('hook', nargs='?', help='A single hook-id to run')
parser.add_argument('--verbose', '-v', action='store_true', default=False)
parser.add_argument(
@ -111,7 +121,7 @@ def _add_run_options(parser):
)
def _adjust_args_and_chdir(args):
def _adjust_args_and_chdir(args: argparse.Namespace) -> None:
# `--config` was specified relative to the non-root working directory
if os.path.exists(args.config):
args.config = os.path.abspath(args.config)
@ -143,7 +153,7 @@ def _adjust_args_and_chdir(args):
args.repo = os.path.relpath(args.repo)
def main(argv=None):
def main(argv: Optional[Sequence[str]] = None) -> int:
argv = argv if argv is not None else sys.argv[1:]
argv = [five.to_text(arg) for arg in argv]
parser = argparse.ArgumentParser(prog='pre-commit')

View file

@ -1,6 +1,8 @@
import argparse
import os.path
import tarfile
from typing import Optional
from typing import Sequence
from pre_commit import output
from pre_commit.util import cmd_output_b
@ -23,7 +25,7 @@ REPOS = (
)
def make_archive(name, repo, ref, destdir):
def make_archive(name: str, repo: str, ref: str, destdir: str) -> str:
"""Makes an archive of a repository in the given destdir.
:param text name: Name to give the archive. For instance foo. The file
@ -49,7 +51,7 @@ def make_archive(name, repo, ref, destdir):
return output_path
def main(argv=None):
def main(argv: Optional[Sequence[str]] = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('--dest', default='pre_commit/resources')
args = parser.parse_args(argv)
@ -58,6 +60,7 @@ def main(argv=None):
f'Making {archive_name}.tar.gz for {repo}@{ref}',
)
make_archive(archive_name, repo, ref, args.dest)
return 0
if __name__ == '__main__':

View file

@ -1,4 +1,6 @@
import argparse
from typing import Optional
from typing import Sequence
import pre_commit.constants as C
from pre_commit import git
@ -8,7 +10,7 @@ from pre_commit.repository import all_hooks
from pre_commit.store import Store
def check_all_hooks_match_files(config_file):
def check_all_hooks_match_files(config_file: str) -> int:
classifier = Classifier(git.get_all_files())
retv = 0
@ -22,7 +24,7 @@ def check_all_hooks_match_files(config_file):
return retv
def main(argv=None):
def main(argv: Optional[Sequence[str]] = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', default=[C.CONFIG_FILE])
args = parser.parse_args(argv)

View file

@ -1,5 +1,7 @@
import argparse
import re
from typing import Optional
from typing import Sequence
from cfgv import apply_defaults
@ -10,7 +12,11 @@ from pre_commit.clientlib import MANIFEST_HOOK_DICT
from pre_commit.commands.run import Classifier
def exclude_matches_any(filenames, include, exclude):
def exclude_matches_any(
filenames: Sequence[str],
include: str,
exclude: str,
) -> bool:
if exclude == '^$':
return True
include_re, exclude_re = re.compile(include), re.compile(exclude)
@ -20,7 +26,7 @@ def exclude_matches_any(filenames, include, exclude):
return False
def check_useless_excludes(config_file):
def check_useless_excludes(config_file: str) -> int:
config = load_config(config_file)
classifier = Classifier(git.get_all_files())
retv = 0
@ -52,7 +58,7 @@ def check_useless_excludes(config_file):
return retv
def main(argv=None):
def main(argv: Optional[Sequence[str]] = None) -> int:
parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', default=[C.CONFIG_FILE])
args = parser.parse_args(argv)

View file

@ -1,12 +1,15 @@
import sys
from typing import Optional
from typing import Sequence
from pre_commit import output
def main(argv=None):
def main(argv: Optional[Sequence[str]] = None) -> int:
argv = argv if argv is not None else sys.argv[1:]
for arg in argv:
output.write_line(arg)
return 0
if __name__ == '__main__':

View file

@ -1,19 +1,22 @@
import contextlib
import sys
from typing import IO
from typing import Optional
from typing import Union
from pre_commit import color
from pre_commit import five
def get_hook_message(
start,
postfix='',
end_msg=None,
end_len=0,
end_color=None,
use_color=None,
cols=80,
):
start: str,
postfix: str = '',
end_msg: Optional[str] = None,
end_len: int = 0,
end_color: Optional[str] = None,
use_color: Optional[bool] = None,
cols: int = 80,
) -> str:
"""Prints a message for running a hook.
This currently supports three approaches:
@ -44,16 +47,13 @@ def get_hook_message(
)
start...........................................................postfix end
"""
if bool(end_msg) == bool(end_len):
raise ValueError('Expected one of (`end_msg`, `end_len`)')
if end_msg is not None and (end_color is None or use_color is None):
raise ValueError(
'`end_color` and `use_color` are required with `end_msg`',
)
if end_len:
assert end_msg is None, end_msg
return start + '.' * (cols - len(start) - end_len - 1)
else:
assert end_msg is not None
assert end_color is not None
assert use_color is not None
return '{}{}{}{}\n'.format(
start,
'.' * (cols - len(start) - len(postfix) - len(end_msg) - 1),
@ -62,15 +62,16 @@ def get_hook_message(
)
stdout_byte_stream = getattr(sys.stdout, 'buffer', sys.stdout)
def write(s, stream=stdout_byte_stream):
def write(s: str, stream: IO[bytes] = sys.stdout.buffer) -> None:
stream.write(five.to_bytes(s))
stream.flush()
def write_line(s=None, stream=stdout_byte_stream, logfile_name=None):
def write_line(
s: Union[None, str, bytes] = None,
stream: IO[bytes] = sys.stdout.buffer,
logfile_name: Optional[str] = None,
) -> None:
with contextlib.ExitStack() as exit_stack:
output_streams = [stream]
if logfile_name:

View file

@ -1,21 +1,28 @@
import os.path
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import Tuple
from identify.identify import parse_shebang_from_file
class ExecutableNotFoundError(OSError):
def to_output(self):
return (1, self.args[0].encode('UTF-8'), b'')
def to_output(self) -> Tuple[int, bytes, None]:
return (1, self.args[0].encode('UTF-8'), None)
def parse_filename(filename):
def parse_filename(filename: str) -> Tuple[str, ...]:
if not os.path.exists(filename):
return ()
else:
return parse_shebang_from_file(filename)
def find_executable(exe, _environ=None):
def find_executable(
exe: str,
_environ: Optional[Mapping[str, str]] = None,
) -> Optional[str]:
exe = os.path.normpath(exe)
if os.sep in exe:
return exe
@ -39,8 +46,8 @@ def find_executable(exe, _environ=None):
return None
def normexe(orig):
def _error(msg):
def normexe(orig: str) -> str:
def _error(msg: str) -> NoReturn:
raise ExecutableNotFoundError(f'Executable `{orig}` {msg}')
if os.sep not in orig and (not os.altsep or os.altsep not in orig):
@ -58,7 +65,7 @@ def normexe(orig):
return orig
def normalize_cmd(cmd):
def normalize_cmd(cmd: Tuple[str, ...]) -> Tuple[str, ...]:
"""Fixes for the following issues on windows
- https://bugs.python.org/issue8557
- windows does not parse shebangs

View file

@ -1,16 +1,17 @@
import collections
import os.path
from typing import NamedTuple
from typing import Tuple
class Prefix(collections.namedtuple('Prefix', ('prefix_dir',))):
__slots__ = ()
class Prefix(NamedTuple):
prefix_dir: str
def path(self, *parts):
def path(self, *parts: str) -> str:
return os.path.normpath(os.path.join(self.prefix_dir, *parts))
def exists(self, *parts):
def exists(self, *parts: str) -> bool:
return os.path.exists(self.path(*parts))
def star(self, end):
def star(self, end: str) -> Tuple[str, ...]:
paths = os.listdir(self.prefix_dir)
return tuple(path for path in paths if path.endswith(end))

View file

@ -2,9 +2,14 @@ import json
import logging
import os
import shlex
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
import pre_commit.constants as C
from pre_commit import five
@ -15,6 +20,7 @@ from pre_commit.clientlib import META
from pre_commit.languages.all import languages
from pre_commit.languages.helpers import environment_dir
from pre_commit.prefix import Prefix
from pre_commit.store import Store
from pre_commit.util import parse_version
from pre_commit.util import rmtree
@ -22,15 +28,15 @@ from pre_commit.util import rmtree
logger = logging.getLogger('pre_commit')
def _state(additional_deps):
def _state(additional_deps: Sequence[str]) -> object:
return {'additional_dependencies': sorted(additional_deps)}
def _state_filename(prefix, venv):
def _state_filename(prefix: Prefix, venv: str) -> str:
return prefix.path(venv, '.install_state_v' + C.INSTALLED_STATE_VERSION)
def _read_state(prefix, venv):
def _read_state(prefix: Prefix, venv: str) -> Optional[object]:
filename = _state_filename(prefix, venv)
if not os.path.exists(filename):
return None
@ -39,7 +45,7 @@ def _read_state(prefix, venv):
return json.load(f)
def _write_state(prefix, venv, state):
def _write_state(prefix: Prefix, venv: str, state: object) -> None:
state_filename = _state_filename(prefix, venv)
staging = state_filename + 'staging'
with open(staging, 'w') as state_file:
@ -76,11 +82,11 @@ class Hook(NamedTuple):
verbose: bool
@property
def cmd(self):
def cmd(self) -> Tuple[str, ...]:
return tuple(shlex.split(self.entry)) + tuple(self.args)
@property
def install_key(self):
def install_key(self) -> Tuple[Prefix, str, str, Tuple[str, ...]]:
return (
self.prefix,
self.language,
@ -88,7 +94,7 @@ class Hook(NamedTuple):
tuple(self.additional_dependencies),
)
def installed(self):
def installed(self) -> bool:
lang = languages[self.language]
venv = environment_dir(lang.ENVIRONMENT_DIR, self.language_version)
return (
@ -101,7 +107,7 @@ class Hook(NamedTuple):
)
)
def install(self):
def install(self) -> None:
logger.info(f'Installing environment for {self.src}.')
logger.info('Once installed this environment will be reused.')
logger.info('This may take a few minutes...')
@ -120,12 +126,12 @@ class Hook(NamedTuple):
# Write our state to indicate we're installed
_write_state(self.prefix, venv, _state(self.additional_dependencies))
def run(self, file_args, color):
def run(self, file_args: Sequence[str], color: bool) -> Tuple[int, bytes]:
lang = languages[self.language]
return lang.run_hook(self, file_args, color)
@classmethod
def create(cls, src, prefix, dct):
def create(cls, src: str, prefix: Prefix, dct: Dict[str, Any]) -> 'Hook':
# TODO: have cfgv do this (?)
extra_keys = set(dct) - set(_KEYS)
if extra_keys:
@ -136,9 +142,10 @@ class Hook(NamedTuple):
return cls(src=src, prefix=prefix, **{k: dct[k] for k in _KEYS})
def _hook(*hook_dicts, **kwargs):
root_config = kwargs.pop('root_config')
assert not kwargs, kwargs
def _hook(
*hook_dicts: Dict[str, Any],
root_config: Dict[str, Any],
) -> Dict[str, Any]:
ret, rest = dict(hook_dicts[0]), hook_dicts[1:]
for dct in rest:
ret.update(dct)
@ -166,8 +173,12 @@ def _hook(*hook_dicts, **kwargs):
return ret
def _non_cloned_repository_hooks(repo_config, store, root_config):
def _prefix(language_name, deps):
def _non_cloned_repository_hooks(
repo_config: Dict[str, Any],
store: Store,
root_config: Dict[str, Any],
) -> Tuple[Hook, ...]:
def _prefix(language_name: str, deps: Sequence[str]) -> Prefix:
language = languages[language_name]
# pygrep / script / system / docker_image do not have
# environments so they work out of the current directory
@ -186,7 +197,11 @@ def _non_cloned_repository_hooks(repo_config, store, root_config):
)
def _cloned_repository_hooks(repo_config, store, root_config):
def _cloned_repository_hooks(
repo_config: Dict[str, Any],
store: Store,
root_config: Dict[str, Any],
) -> Tuple[Hook, ...]:
repo, rev = repo_config['repo'], repo_config['rev']
manifest_path = os.path.join(store.clone(repo, rev), C.MANIFEST_FILE)
by_id = {hook['id']: hook for hook in load_manifest(manifest_path)}
@ -215,16 +230,20 @@ def _cloned_repository_hooks(repo_config, store, root_config):
)
def _repository_hooks(repo_config, store, root_config):
def _repository_hooks(
repo_config: Dict[str, Any],
store: Store,
root_config: Dict[str, Any],
) -> Tuple[Hook, ...]:
if repo_config['repo'] in {LOCAL, META}:
return _non_cloned_repository_hooks(repo_config, store, root_config)
else:
return _cloned_repository_hooks(repo_config, store, root_config)
def install_hook_envs(hooks, store):
def _need_installed():
seen: Set[Hook] = set()
def install_hook_envs(hooks: Sequence[Hook], store: Store) -> None:
def _need_installed() -> List[Hook]:
seen: Set[Tuple[Prefix, str, str, Tuple[str, ...]]] = set()
ret = []
for hook in hooks:
if hook.install_key not in seen and not hook.installed():
@ -240,7 +259,7 @@ def install_hook_envs(hooks, store):
hook.install()
def all_hooks(root_config, store):
def all_hooks(root_config: Dict[str, Any], store: Store) -> Tuple[Hook, ...]:
return tuple(
hook
for repo in root_config['repos']

View file

@ -4,6 +4,8 @@ import distutils.spawn
import os
import subprocess
import sys
from typing import Callable
from typing import Dict
from typing import Tuple
# work around https://github.com/Homebrew/homebrew-core/issues/30445
@ -28,7 +30,7 @@ class FatalError(RuntimeError):
pass
def _norm_exe(exe):
def _norm_exe(exe: str) -> Tuple[str, ...]:
"""Necessary for shebang support on windows.
roughly lifted from `identify.identify.parse_shebang`
@ -47,7 +49,7 @@ def _norm_exe(exe):
return tuple(cmd)
def _run_legacy():
def _run_legacy() -> Tuple[int, bytes]:
if __file__.endswith('.legacy'):
raise SystemExit(
"bug: pre-commit's script is installed in migration mode\n"
@ -59,9 +61,9 @@ def _run_legacy():
)
if HOOK_TYPE == 'pre-push':
stdin = getattr(sys.stdin, 'buffer', sys.stdin).read()
stdin = sys.stdin.buffer.read()
else:
stdin = None
stdin = b''
legacy_hook = os.path.join(HERE, f'{HOOK_TYPE}.legacy')
if os.access(legacy_hook, os.X_OK):
@ -73,7 +75,7 @@ def _run_legacy():
return 0, stdin
def _validate_config():
def _validate_config() -> None:
cmd = ('git', 'rev-parse', '--show-toplevel')
top_level = subprocess.check_output(cmd).decode('UTF-8').strip()
cfg = os.path.join(top_level, CONFIG)
@ -97,7 +99,7 @@ def _validate_config():
)
def _exe():
def _exe() -> Tuple[str, ...]:
with open(os.devnull, 'wb') as devnull:
for exe in (INSTALL_PYTHON, sys.executable):
try:
@ -117,11 +119,11 @@ def _exe():
)
def _rev_exists(rev):
def _rev_exists(rev: str) -> bool:
return not subprocess.call(('git', 'rev-list', '--quiet', rev))
def _pre_push(stdin):
def _pre_push(stdin: bytes) -> Tuple[str, ...]:
remote = sys.argv[1]
opts: Tuple[str, ...] = ()
@ -158,8 +160,8 @@ def _pre_push(stdin):
raise EarlyExit()
def _opts(stdin):
fns = {
def _opts(stdin: bytes) -> Tuple[str, ...]:
fns: Dict[str, Callable[[bytes], Tuple[str, ...]]] = {
'prepare-commit-msg': lambda _: ('--commit-msg-filename', sys.argv[1]),
'commit-msg': lambda _: ('--commit-msg-filename', sys.argv[1]),
'pre-merge-commit': lambda _: (),
@ -171,13 +173,14 @@ def _opts(stdin):
if sys.version_info < (3, 7): # https://bugs.python.org/issue25942
def _subprocess_call(cmd): # this is the python 2.7 implementation
# this is the python 2.7 implementation
def _subprocess_call(cmd: Tuple[str, ...]) -> int:
return subprocess.Popen(cmd).wait()
else:
_subprocess_call = subprocess.call
def main():
def main() -> int:
retv, stdin = _run_legacy()
try:
_validate_config()

View file

@ -2,6 +2,7 @@ import contextlib
import logging
import os.path
import time
from typing import Generator
from pre_commit import git
from pre_commit.util import CalledProcessError
@ -14,7 +15,7 @@ from pre_commit.xargs import xargs
logger = logging.getLogger('pre_commit')
def _git_apply(patch):
def _git_apply(patch: str) -> None:
args = ('apply', '--whitespace=nowarn', patch)
try:
cmd_output_b('git', *args)
@ -24,7 +25,7 @@ def _git_apply(patch):
@contextlib.contextmanager
def _intent_to_add_cleared():
def _intent_to_add_cleared() -> Generator[None, None, None]:
intent_to_add = git.intent_to_add_files()
if intent_to_add:
logger.warning('Unstaged intent-to-add files detected.')
@ -39,7 +40,7 @@ def _intent_to_add_cleared():
@contextlib.contextmanager
def _unstaged_changes_cleared(patch_dir):
def _unstaged_changes_cleared(patch_dir: str) -> Generator[None, None, None]:
tree = cmd_output('git', 'write-tree')[1].strip()
retcode, diff_stdout_binary, _ = cmd_output_b(
'git', 'diff-index', '--ignore-submodules', '--binary',
@ -84,7 +85,7 @@ def _unstaged_changes_cleared(patch_dir):
@contextlib.contextmanager
def staged_files_only(patch_dir):
def staged_files_only(patch_dir: str) -> Generator[None, None, None]:
"""Clear any unstaged changes from the git working directory inside this
context.
"""

View file

@ -3,6 +3,12 @@ import logging
import os.path
import sqlite3
import tempfile
from typing import Callable
from typing import Generator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
import pre_commit.constants as C
from pre_commit import file_lock
@ -18,7 +24,7 @@ from pre_commit.util import rmtree
logger = logging.getLogger('pre_commit')
def _get_default_directory():
def _get_default_directory() -> str:
"""Returns the default directory for the Store. This is intentionally
underscored to indicate that `Store.get_default_directory` is the intended
way to get this information. This is also done so
@ -34,7 +40,7 @@ def _get_default_directory():
class Store:
get_default_directory = staticmethod(_get_default_directory)
def __init__(self, directory=None):
def __init__(self, directory: Optional[str] = None) -> None:
self.directory = directory or Store.get_default_directory()
self.db_path = os.path.join(self.directory, 'db.db')
@ -66,21 +72,24 @@ class Store:
' PRIMARY KEY (repo, ref)'
');',
)
self._create_config_table_if_not_exists(db)
self._create_config_table(db)
# Atomic file move
os.rename(tmpfile, self.db_path)
@contextlib.contextmanager
def exclusive_lock(self):
def blocked_cb(): # pragma: no cover (tests are single-process)
def exclusive_lock(self) -> Generator[None, None, None]:
def blocked_cb() -> None: # pragma: no cover (tests are in-process)
logger.info('Locking pre-commit directory')
with file_lock.lock(os.path.join(self.directory, '.lock'), blocked_cb):
yield
@contextlib.contextmanager
def connect(self, db_path=None):
def connect(
self,
db_path: Optional[str] = None,
) -> Generator[sqlite3.Connection, None, None]:
db_path = db_path or self.db_path
# sqlite doesn't close its fd with its contextmanager >.<
# contextlib.closing fixes this.
@ -91,24 +100,29 @@ class Store:
yield db
@classmethod
def db_repo_name(cls, repo, deps):
def db_repo_name(cls, repo: str, deps: Sequence[str]) -> str:
if deps:
return '{}:{}'.format(repo, ','.join(sorted(deps)))
else:
return repo
def _new_repo(self, repo, ref, deps, make_strategy):
def _new_repo(
self,
repo: str,
ref: str,
deps: Sequence[str],
make_strategy: Callable[[str], None],
) -> str:
repo = self.db_repo_name(repo, deps)
def _get_result():
def _get_result() -> Optional[str]:
# Check if we already exist
with self.connect() as db:
result = db.execute(
'SELECT path FROM repos WHERE repo = ? AND ref = ?',
(repo, ref),
).fetchone()
if result:
return result[0]
return result[0] if result else None
result = _get_result()
if result:
@ -133,14 +147,14 @@ class Store:
)
return directory
def _complete_clone(self, ref, git_cmd):
def _complete_clone(self, ref: str, git_cmd: Callable[..., None]) -> None:
"""Perform a complete clone of a repository and its submodules """
git_cmd('fetch', 'origin', '--tags')
git_cmd('checkout', ref)
git_cmd('submodule', 'update', '--init', '--recursive')
def _shallow_clone(self, ref, git_cmd):
def _shallow_clone(self, ref: str, git_cmd: Callable[..., None]) -> None:
"""Perform a shallow clone of a repository and its submodules """
git_config = 'protocol.version=2'
@ -151,14 +165,14 @@ class Store:
'--depth=1',
)
def clone(self, repo, ref, deps=()):
def clone(self, repo: str, ref: str, deps: Sequence[str] = ()) -> str:
"""Clone the given url and checkout the specific ref."""
def clone_strategy(directory):
def clone_strategy(directory: str) -> None:
git.init_repo(directory, repo)
env = git.no_git_env()
def _git_cmd(*args):
def _git_cmd(*args: str) -> None:
cmd_output_b('git', *args, cwd=directory, env=env)
try:
@ -173,8 +187,8 @@ class Store:
'pre_commit_dummy_package.gemspec', 'setup.py', 'environment.yml',
)
def make_local(self, deps):
def make_local_strategy(directory):
def make_local(self, deps: Sequence[str]) -> str:
def make_local_strategy(directory: str) -> None:
for resource in self.LOCAL_RESOURCES:
contents = resource_text(f'empty_template_{resource}')
with open(os.path.join(directory, resource), 'w') as f:
@ -183,7 +197,7 @@ class Store:
env = git.no_git_env()
# initialize the git repository so it looks more like cloned repos
def _git_cmd(*args):
def _git_cmd(*args: str) -> None:
cmd_output_b('git', *args, cwd=directory, env=env)
git.init_repo(directory, '<<unknown>>')
@ -194,7 +208,7 @@ class Store:
'local', C.LOCAL_REPO_VERSION, deps, make_local_strategy,
)
def _create_config_table_if_not_exists(self, db):
def _create_config_table(self, db: sqlite3.Connection) -> None:
db.executescript(
'CREATE TABLE IF NOT EXISTS configs ('
' path TEXT NOT NULL,'
@ -202,32 +216,32 @@ class Store:
');',
)
def mark_config_used(self, path):
def mark_config_used(self, path: str) -> None:
path = os.path.realpath(path)
# don't insert config files that do not exist
if not os.path.exists(path):
return
with self.connect() as db:
# TODO: eventually remove this and only create in _create
self._create_config_table_if_not_exists(db)
self._create_config_table(db)
db.execute('INSERT OR IGNORE INTO configs VALUES (?)', (path,))
def select_all_configs(self):
def select_all_configs(self) -> List[str]:
with self.connect() as db:
self._create_config_table_if_not_exists(db)
self._create_config_table(db)
rows = db.execute('SELECT path FROM configs').fetchall()
return [path for path, in rows]
def delete_configs(self, configs):
def delete_configs(self, configs: List[str]) -> None:
with self.connect() as db:
rows = [(path,) for path in configs]
db.executemany('DELETE FROM configs WHERE path = ?', rows)
def select_all_repos(self):
def select_all_repos(self) -> List[Tuple[str, str, str]]:
with self.connect() as db:
return db.execute('SELECT repo, ref, path from repos').fetchall()
def delete_repo(self, db_repo_name, ref, path):
def delete_repo(self, db_repo_name: str, ref: str, path: str) -> None:
with self.connect() as db:
db.execute(
'DELETE FROM repos WHERE repo = ? and ref = ?',

View file

@ -6,6 +6,16 @@ import stat
import subprocess
import sys
import tempfile
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import IO
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union
from pre_commit import five
from pre_commit import parse_shebang
@ -17,8 +27,10 @@ else: # pragma: no cover (<PY37)
from importlib_resources import open_binary
from importlib_resources import read_text
EnvironT = Union[Dict[str, str], 'os._Environ']
def mkdirp(path):
def mkdirp(path: str) -> None:
try:
os.makedirs(path)
except OSError:
@ -27,7 +39,7 @@ def mkdirp(path):
@contextlib.contextmanager
def clean_path_on_failure(path):
def clean_path_on_failure(path: str) -> Generator[None, None, None]:
"""Cleans up the directory on an exceptional failure."""
try:
yield
@ -38,12 +50,12 @@ def clean_path_on_failure(path):
@contextlib.contextmanager
def noop_context():
def noop_context() -> Generator[None, None, None]:
yield
@contextlib.contextmanager
def tmpdir():
def tmpdir() -> Generator[str, None, None]:
"""Contextmanager to create a temporary directory. It will be cleaned up
afterwards.
"""
@ -54,15 +66,15 @@ def tmpdir():
rmtree(tempdir)
def resource_bytesio(filename):
def resource_bytesio(filename: str) -> IO[bytes]:
return open_binary('pre_commit.resources', filename)
def resource_text(filename):
def resource_text(filename: str) -> str:
return read_text('pre_commit.resources', filename)
def make_executable(filename):
def make_executable(filename: str) -> None:
original_mode = os.stat(filename).st_mode
os.chmod(
filename, original_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH,
@ -70,18 +82,23 @@ def make_executable(filename):
class CalledProcessError(RuntimeError):
def __init__(self, returncode, cmd, expected_returncode, stdout, stderr):
super().__init__(
returncode, cmd, expected_returncode, stdout, stderr,
)
def __init__(
self,
returncode: int,
cmd: Tuple[str, ...],
expected_returncode: int,
stdout: bytes,
stderr: Optional[bytes],
) -> None:
super().__init__(returncode, cmd, expected_returncode, stdout, stderr)
self.returncode = returncode
self.cmd = cmd
self.expected_returncode = expected_returncode
self.stdout = stdout
self.stderr = stderr
def __bytes__(self):
def _indent_or_none(part):
def __bytes__(self) -> bytes:
def _indent_or_none(part: Optional[bytes]) -> bytes:
if part:
return b'\n ' + part.replace(b'\n', b'\n ')
else:
@ -97,11 +114,14 @@ class CalledProcessError(RuntimeError):
b'stderr:', _indent_or_none(self.stderr),
))
def __str__(self):
def __str__(self) -> str:
return self.__bytes__().decode('UTF-8')
def _cmd_kwargs(*cmd, **kwargs):
def _cmd_kwargs(
*cmd: str,
**kwargs: Any,
) -> Tuple[Tuple[str, ...], Dict[str, Any]]:
# py2/py3 on windows are more strict about the types here
cmd = tuple(five.n(arg) for arg in cmd)
kwargs['env'] = {
@ -113,7 +133,10 @@ def _cmd_kwargs(*cmd, **kwargs):
return cmd, kwargs
def cmd_output_b(*cmd, **kwargs):
def cmd_output_b(
*cmd: str,
**kwargs: Any,
) -> Tuple[int, bytes, Optional[bytes]]:
retcode = kwargs.pop('retcode', 0)
cmd, kwargs = _cmd_kwargs(*cmd, **kwargs)
@ -132,7 +155,7 @@ def cmd_output_b(*cmd, **kwargs):
return returncode, stdout_b, stderr_b
def cmd_output(*cmd, **kwargs):
def cmd_output(*cmd: str, **kwargs: Any) -> Tuple[int, str, Optional[str]]:
returncode, stdout_b, stderr_b = cmd_output_b(*cmd, **kwargs)
stdout = stdout_b.decode('UTF-8') if stdout_b is not None else None
stderr = stderr_b.decode('UTF-8') if stderr_b is not None else None
@ -144,10 +167,11 @@ if os.name != 'nt': # pragma: windows no cover
import termios
class Pty:
def __init__(self):
self.r = self.w = None
def __init__(self) -> None:
self.r: Optional[int] = None
self.w: Optional[int] = None
def __enter__(self):
def __enter__(self) -> 'Pty':
self.r, self.w = openpty()
# tty flags normally change \n to \r\n
@ -158,21 +182,29 @@ if os.name != 'nt': # pragma: windows no cover
return self
def close_w(self):
def close_w(self) -> None:
if self.w is not None:
os.close(self.w)
self.w = None
def close_r(self):
def close_r(self) -> None:
assert self.r is not None
os.close(self.r)
self.r = None
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.close_w()
self.close_r()
def cmd_output_p(*cmd, **kwargs):
def cmd_output_p(
*cmd: str,
**kwargs: Any,
) -> Tuple[int, bytes, Optional[bytes]]:
assert kwargs.pop('retcode') is None
assert kwargs['stderr'] == subprocess.STDOUT, kwargs['stderr']
cmd, kwargs = _cmd_kwargs(*cmd, **kwargs)
@ -183,6 +215,7 @@ if os.name != 'nt': # pragma: windows no cover
return e.to_output()
with open(os.devnull) as devnull, Pty() as pty:
assert pty.r is not None
kwargs.update({'stdin': devnull, 'stdout': pty.w, 'stderr': pty.w})
proc = subprocess.Popen(cmd, **kwargs)
pty.close_w()
@ -206,9 +239,13 @@ else: # pragma: no cover
cmd_output_p = cmd_output_b
def rmtree(path):
def rmtree(path: str) -> None:
"""On windows, rmtree fails for readonly dirs."""
def handle_remove_readonly(func, path, exc):
def handle_remove_readonly(
func: Callable[..., Any],
path: str,
exc: Tuple[Type[OSError], OSError, TracebackType],
) -> None:
excvalue = exc[1]
if (
func in (os.rmdir, os.remove, os.unlink) and
@ -222,6 +259,6 @@ def rmtree(path):
shutil.rmtree(path, ignore_errors=False, onerror=handle_remove_readonly)
def parse_version(s):
def parse_version(s: str) -> Tuple[int, ...]:
"""poor man's version comparison"""
return tuple(int(p) for p in s.split('.'))

View file

@ -4,14 +4,26 @@ import math
import os
import subprocess
import sys
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
from pre_commit import parse_shebang
from pre_commit.util import cmd_output_b
from pre_commit.util import cmd_output_p
from pre_commit.util import EnvironT
TArg = TypeVar('TArg')
TRet = TypeVar('TRet')
def _environ_size(_env=None):
def _environ_size(_env: Optional[EnvironT] = None) -> int:
environ = _env if _env is not None else getattr(os, 'environb', os.environ)
size = 8 * len(environ) # number of pointers in `envp`
for k, v in environ.items():
@ -19,7 +31,7 @@ def _environ_size(_env=None):
return size
def _get_platform_max_length(): # pragma: no cover (platform specific)
def _get_platform_max_length() -> int: # pragma: no cover (platform specific)
if os.name == 'posix':
maximum = os.sysconf('SC_ARG_MAX') - 2048 - _environ_size()
maximum = max(min(maximum, 2 ** 17), 2 ** 12)
@ -31,7 +43,7 @@ def _get_platform_max_length(): # pragma: no cover (platform specific)
return 2 ** 12
def _command_length(*cmd):
def _command_length(*cmd: str) -> int:
full_cmd = ' '.join(cmd)
# win32 uses the amount of characters, more details at:
@ -47,7 +59,12 @@ class ArgumentTooLongError(RuntimeError):
pass
def partition(cmd, varargs, target_concurrency, _max_length=None):
def partition(
cmd: Sequence[str],
varargs: Sequence[str],
target_concurrency: int,
_max_length: Optional[int] = None,
) -> Tuple[Tuple[str, ...], ...]:
_max_length = _max_length or _get_platform_max_length()
# Generally, we try to partition evenly into at least `target_concurrency`
@ -87,7 +104,10 @@ def partition(cmd, varargs, target_concurrency, _max_length=None):
@contextlib.contextmanager
def _thread_mapper(maxsize):
def _thread_mapper(maxsize: int) -> Generator[
Callable[[Callable[[TArg], TRet], Iterable[TArg]], Iterable[TRet]],
None, None,
]:
if maxsize == 1:
yield map
else:
@ -95,7 +115,11 @@ def _thread_mapper(maxsize):
yield ex.map
def xargs(cmd, varargs, **kwargs):
def xargs(
cmd: Tuple[str, ...],
varargs: Sequence[str],
**kwargs: Any,
) -> Tuple[int, bytes]:
"""A simplified implementation of xargs.
color: Make a pty if on a platform that supports it
@ -115,7 +139,9 @@ def xargs(cmd, varargs, **kwargs):
partitions = partition(cmd, varargs, target_concurrency, max_length)
def run_cmd_partition(run_cmd):
def run_cmd_partition(
run_cmd: Tuple[str, ...],
) -> Tuple[int, bytes, Optional[bytes]]:
return cmd_fn(
*run_cmd, retcode=None, stderr=subprocess.STDOUT, **kwargs,
)

View file

@ -57,6 +57,7 @@ universal = True
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
no_implicit_optional = true
[mypy-testing.*]

View file

@ -37,21 +37,21 @@ def test_use_color_no_tty():
def test_use_color_tty_with_color_support():
with mock.patch.object(sys.stdout, 'isatty', return_value=True):
with mock.patch('pre_commit.color.terminal_supports_color', True):
with envcontext.envcontext([('TERM', envcontext.UNSET)]):
with envcontext.envcontext((('TERM', envcontext.UNSET),)):
assert use_color('auto') is True
def test_use_color_tty_without_color_support():
with mock.patch.object(sys.stdout, 'isatty', return_value=True):
with mock.patch('pre_commit.color.terminal_supports_color', False):
with envcontext.envcontext([('TERM', envcontext.UNSET)]):
with envcontext.envcontext((('TERM', envcontext.UNSET),)):
assert use_color('auto') is False
def test_use_color_dumb_term():
with mock.patch.object(sys.stdout, 'isatty', return_value=True):
with mock.patch('pre_commit.color.terminal_supports_color', True):
with envcontext.envcontext([('TERM', 'dumb')]):
with envcontext.envcontext((('TERM', 'dumb'),)):
assert use_color('auto') is False

View file

@ -24,7 +24,7 @@ def test_init_templatedir(tmpdir, tempdir_factory, store, cap_out):
'[WARNING] maybe `git config --global init.templateDir',
)
with envcontext([('GIT_TEMPLATE_DIR', target)]):
with envcontext((('GIT_TEMPLATE_DIR', target),)):
path = make_consuming_repo(tempdir_factory, 'script_hooks_repo')
with cwd(path):
@ -52,7 +52,7 @@ def test_init_templatedir_already_set(tmpdir, tempdir_factory, store, cap_out):
def test_init_templatedir_not_set(tmpdir, store, cap_out):
# set HOME to ignore the current `.gitconfig`
with envcontext([('HOME', str(tmpdir))]):
with envcontext((('HOME', str(tmpdir)),)):
with tmpdir.join('tmpl').ensure_dir().as_cwd():
# we have not set init.templateDir so this should produce a warning
init_templatedir(

View file

@ -274,5 +274,5 @@ def fake_log_handler():
@pytest.fixture(scope='session', autouse=True)
def set_git_templatedir(tmpdir_factory):
tdir = str(tmpdir_factory.mktemp('git_template_dir'))
with envcontext([('GIT_TEMPLATE_DIR', tdir)]):
with envcontext((('GIT_TEMPLATE_DIR', tdir),)):
yield

View file

@ -93,7 +93,7 @@ def test_exception_safety():
env = {'hello': 'world'}
with pytest.raises(MyError):
with envcontext([('foo', 'bar')], _env=env):
with envcontext((('foo', 'bar'),), _env=env):
raise MyError()
assert env == {'hello': 'world'}
@ -101,6 +101,6 @@ def test_exception_safety():
def test_integration_os_environ():
with mock.patch.dict(os.environ, {'FOO': 'bar'}, clear=True):
assert os.environ == {'FOO': 'bar'}
with envcontext([('HERP', 'derp')]):
with envcontext((('HERP', 'derp'),)):
assert os.environ == {'FOO': 'bar', 'HERP': 'derp'}
assert os.environ == {'FOO': 'bar'}

View file

@ -1,23 +1,31 @@
import functools
import inspect
from typing import Sequence
from typing import Tuple
import pytest
from pre_commit.languages.all import all_languages
from pre_commit.languages.all import languages
from pre_commit.prefix import Prefix
ArgSpec = functools.partial(
inspect.FullArgSpec, varargs=None, varkw=None, defaults=None,
kwonlyargs=[], kwonlydefaults=None, annotations={},
def _argspec(annotations):
args = [k for k in annotations if k != 'return']
return inspect.FullArgSpec(
args=args, annotations=annotations,
varargs=None, varkw=None, defaults=None,
kwonlyargs=[], kwonlydefaults=None,
)
@pytest.mark.parametrize('language', all_languages)
def test_install_environment_argspec(language):
expected_argspec = ArgSpec(
args=['prefix', 'version', 'additional_dependencies'],
)
expected_argspec = _argspec({
'return': None,
'prefix': Prefix,
'version': str,
'additional_dependencies': Sequence[str],
})
argspec = inspect.getfullargspec(languages[language].install_environment)
assert argspec == expected_argspec
@ -29,20 +37,26 @@ def test_ENVIRONMENT_DIR(language):
@pytest.mark.parametrize('language', all_languages)
def test_run_hook_argspec(language):
expected_argspec = ArgSpec(args=['hook', 'file_args', 'color'])
expected_argspec = _argspec({
'return': Tuple[int, bytes],
'hook': 'Hook', 'file_args': Sequence[str], 'color': bool,
})
argspec = inspect.getfullargspec(languages[language].run_hook)
assert argspec == expected_argspec
@pytest.mark.parametrize('language', all_languages)
def test_get_default_version_argspec(language):
expected_argspec = ArgSpec(args=[])
expected_argspec = _argspec({'return': str})
argspec = inspect.getfullargspec(languages[language].get_default_version)
assert argspec == expected_argspec
@pytest.mark.parametrize('language', all_languages)
def test_healthy_argspec(language):
expected_argspec = ArgSpec(args=['prefix', 'language_version'])
expected_argspec = _argspec({
'return': bool,
'prefix': Prefix, 'language_version': str,
})
argspec = inspect.getfullargspec(languages[language].healthy)
assert argspec == expected_argspec

View file

@ -7,7 +7,7 @@ from pre_commit.util import CalledProcessError
def test_docker_is_running_process_error():
with mock.patch(
'pre_commit.languages.docker.cmd_output_b',
side_effect=CalledProcessError(None, None, None, None, None),
side_effect=CalledProcessError(1, (), 0, b'', None),
):
assert docker.docker_is_running() is False

View file

@ -17,7 +17,7 @@ def test_basic_get_default_version():
def test_basic_healthy():
assert helpers.basic_healthy(None, None) is True
assert helpers.basic_healthy(Prefix('.'), 'default') is True
def test_failed_setup_command_does_not_unicode_error():
@ -77,4 +77,6 @@ def test_target_concurrency_cpu_count_not_implemented():
def test_shuffled_is_deterministic():
assert helpers._shuffled(range(10)) == [3, 7, 8, 2, 4, 6, 5, 1, 0, 9]
seq = [str(i) for i in range(10)]
expected = ['3', '7', '8', '2', '4', '6', '5', '1', '0', '9']
assert helpers._shuffled(seq) == expected

View file

@ -1,25 +1,21 @@
import logging
from pre_commit import color
from pre_commit.logging_handler import LoggingHandler
class FakeLogRecord:
def __init__(self, message, levelname, levelno):
self.message = message
self.levelname = levelname
self.levelno = levelno
def getMessage(self):
return self.message
def _log_record(message, level):
return logging.LogRecord('name', level, '', 1, message, {}, None)
def test_logging_handler_color(cap_out):
handler = LoggingHandler(True)
handler.emit(FakeLogRecord('hi', 'WARNING', 30))
handler.emit(_log_record('hi', logging.WARNING))
ret = cap_out.get()
assert ret == color.YELLOW + '[WARNING]' + color.NORMAL + ' hi\n'
def test_logging_handler_no_color(cap_out):
handler = LoggingHandler(False)
handler.emit(FakeLogRecord('hi', 'WARNING', 30))
handler.emit(_log_record('hi', logging.WARNING))
assert cap_out.get() == '[WARNING] hi\n'

View file

@ -1,8 +1,5 @@
import argparse
import os.path
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from unittest import mock
import pytest
@ -27,25 +24,24 @@ def test_append_replace_default(argv, expected):
assert parser.parse_args(argv).f == expected
class Args(NamedTuple):
command: str = 'help'
config: str = C.CONFIG_FILE
files: Sequence[str] = []
repo: Optional[str] = None
def _args(**kwargs):
kwargs.setdefault('command', 'help')
kwargs.setdefault('config', C.CONFIG_FILE)
return argparse.Namespace(**kwargs)
def test_adjust_args_and_chdir_not_in_git_dir(in_tmpdir):
with pytest.raises(FatalError):
main._adjust_args_and_chdir(Args())
main._adjust_args_and_chdir(_args())
def test_adjust_args_and_chdir_in_dot_git_dir(in_git_dir):
with in_git_dir.join('.git').as_cwd(), pytest.raises(FatalError):
main._adjust_args_and_chdir(Args())
main._adjust_args_and_chdir(_args())
def test_adjust_args_and_chdir_noop(in_git_dir):
args = Args(command='run', files=['f1', 'f2'])
args = _args(command='run', files=['f1', 'f2'])
main._adjust_args_and_chdir(args)
assert os.getcwd() == in_git_dir
assert args.config == C.CONFIG_FILE
@ -56,7 +52,7 @@ def test_adjust_args_and_chdir_relative_things(in_git_dir):
in_git_dir.join('foo/cfg.yaml').ensure()
in_git_dir.join('foo').chdir()
args = Args(command='run', files=['f1', 'f2'], config='cfg.yaml')
args = _args(command='run', files=['f1', 'f2'], config='cfg.yaml')
main._adjust_args_and_chdir(args)
assert os.getcwd() == in_git_dir
assert args.config == os.path.join('foo', 'cfg.yaml')
@ -66,7 +62,7 @@ def test_adjust_args_and_chdir_relative_things(in_git_dir):
def test_adjust_args_and_chdir_non_relative_config(in_git_dir):
in_git_dir.join('foo').ensure_dir().chdir()
args = Args()
args = _args()
main._adjust_args_and_chdir(args)
assert os.getcwd() == in_git_dir
assert args.config == C.CONFIG_FILE
@ -75,7 +71,7 @@ def test_adjust_args_and_chdir_non_relative_config(in_git_dir):
def test_adjust_args_try_repo_repo_relative(in_git_dir):
in_git_dir.join('foo').ensure_dir().chdir()
args = Args(command='try-repo', repo='../foo', files=[])
args = _args(command='try-repo', repo='../foo', files=[])
assert args.repo is not None
assert os.path.exists(args.repo)
main._adjust_args_and_chdir(args)

View file

@ -22,7 +22,7 @@ from pre_commit import output
),
)
def test_get_hook_message_raises(kwargs):
with pytest.raises(ValueError):
with pytest.raises(AssertionError):
output.get_hook_message('start', **kwargs)

View file

@ -311,7 +311,7 @@ def test_golang_hook(tempdir_factory, store):
def test_golang_hook_still_works_when_gobin_is_set(tempdir_factory, store):
gobin_dir = tempdir_factory.get()
with envcontext([('GOBIN', gobin_dir)]):
with envcontext((('GOBIN', gobin_dir),)):
test_golang_hook(tempdir_factory, store)
assert os.listdir(gobin_dir) == []

View file

@ -120,7 +120,7 @@ def test_clone_shallow_failure_fallback_to_complete(
# Force shallow clone failure
def fake_shallow_clone(self, *args, **kwargs):
raise CalledProcessError(None, None, None, None, None)
raise CalledProcessError(1, (), 0, b'', None)
store._shallow_clone = fake_shallow_clone
ret = store.clone(path, rev)

View file

@ -15,9 +15,9 @@ from pre_commit.util import tmpdir
def test_CalledProcessError_str():
error = CalledProcessError(1, ['exe'], 0, b'output', b'errors')
error = CalledProcessError(1, ('exe',), 0, b'output', b'errors')
assert str(error) == (
"command: ['exe']\n"
"command: ('exe',)\n"
'return code: 1\n'
'expected return code: 0\n'
'stdout:\n'
@ -28,9 +28,9 @@ def test_CalledProcessError_str():
def test_CalledProcessError_str_nooutput():
error = CalledProcessError(1, ['exe'], 0, b'', b'')
error = CalledProcessError(1, ('exe',), 0, b'', b'')
assert str(error) == (
"command: ['exe']\n"
"command: ('exe',)\n"
'return code: 1\n'
'expected return code: 0\n'
'stdout: (none)\n'

View file

@ -2,6 +2,7 @@ import concurrent.futures
import os
import sys
import time
from typing import Tuple
from unittest import mock
import pytest
@ -166,9 +167,8 @@ def test_xargs_concurrency():
def test_thread_mapper_concurrency_uses_threadpoolexecutor_map():
with xargs._thread_mapper(10) as thread_map:
assert isinstance(
thread_map.__self__, concurrent.futures.ThreadPoolExecutor,
) is True
_self = thread_map.__self__ # type: ignore
assert isinstance(_self, concurrent.futures.ThreadPoolExecutor)
def test_thread_mapper_concurrency_uses_regular_map():
@ -178,7 +178,7 @@ def test_thread_mapper_concurrency_uses_regular_map():
def test_xargs_propagate_kwargs_to_cmd():
env = {'PRE_COMMIT_TEST_VAR': 'Pre commit is awesome'}
cmd = ('bash', '-c', 'echo $PRE_COMMIT_TEST_VAR', '--')
cmd: Tuple[str, ...] = ('bash', '-c', 'echo $PRE_COMMIT_TEST_VAR', '--')
cmd = parse_shebang.normalize_cmd(cmd)
ret, stdout = xargs.xargs(cmd, ('1',), env=env)