Add types to pre-commit

This commit is contained in:
Anthony Sottile 2020-01-10 23:32:28 -08:00
parent fa536a8693
commit 327ed924a3
62 changed files with 911 additions and 411 deletions

View file

@ -25,6 +25,10 @@ exclude_lines =
^\s*return NotImplemented\b ^\s*return NotImplemented\b
^\s*raise$ ^\s*raise$
# Ignore typing-related things
^if (False|TYPE_CHECKING):
: \.\.\.$
# Don't complain if non-runnable code isn't run: # Don't complain if non-runnable code isn't run:
^if __name__ == ['"]__main__['"]:$ ^if __name__ == ['"]__main__['"]:$

View file

@ -3,6 +3,10 @@ import functools
import logging import logging
import pipes import pipes
import sys import sys
from typing import Any
from typing import Dict
from typing import Optional
from typing import Sequence
import cfgv import cfgv
from aspy.yaml import ordered_load from aspy.yaml import ordered_load
@ -18,7 +22,7 @@ logger = logging.getLogger('pre_commit')
check_string_regex = cfgv.check_and(cfgv.check_string, cfgv.check_regex) check_string_regex = cfgv.check_and(cfgv.check_string, cfgv.check_regex)
def check_type_tag(tag): def check_type_tag(tag: str) -> None:
if tag not in ALL_TAGS: if tag not in ALL_TAGS:
raise cfgv.ValidationError( raise cfgv.ValidationError(
'Type tag {!r} is not recognized. ' 'Type tag {!r} is not recognized. '
@ -26,7 +30,7 @@ def check_type_tag(tag):
) )
def check_min_version(version): def check_min_version(version: str) -> None:
if parse_version(version) > parse_version(C.VERSION): if parse_version(version) > parse_version(C.VERSION):
raise cfgv.ValidationError( raise cfgv.ValidationError(
'pre-commit version {} is required but version {} is installed. ' 'pre-commit version {} is required but version {} is installed. '
@ -36,7 +40,7 @@ def check_min_version(version):
) )
def _make_argparser(filenames_help): def _make_argparser(filenames_help: str) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help=filenames_help) parser.add_argument('filenames', nargs='*', help=filenames_help)
parser.add_argument('-V', '--version', action='version', version=C.VERSION) parser.add_argument('-V', '--version', action='version', version=C.VERSION)
@ -86,7 +90,7 @@ load_manifest = functools.partial(
) )
def validate_manifest_main(argv=None): def validate_manifest_main(argv: Optional[Sequence[str]] = None) -> int:
parser = _make_argparser('Manifest filenames.') parser = _make_argparser('Manifest filenames.')
args = parser.parse_args(argv) args = parser.parse_args(argv)
ret = 0 ret = 0
@ -107,7 +111,7 @@ class MigrateShaToRev:
key = 'rev' key = 'rev'
@staticmethod @staticmethod
def _cond(key): def _cond(key: str) -> cfgv.Conditional:
return cfgv.Conditional( return cfgv.Conditional(
key, cfgv.check_string, key, cfgv.check_string,
condition_key='repo', condition_key='repo',
@ -115,7 +119,7 @@ class MigrateShaToRev:
ensure_absent=True, ensure_absent=True,
) )
def check(self, dct): def check(self, dct: Dict[str, Any]) -> None:
if dct.get('repo') in {LOCAL, META}: if dct.get('repo') in {LOCAL, META}:
self._cond('rev').check(dct) self._cond('rev').check(dct)
self._cond('sha').check(dct) self._cond('sha').check(dct)
@ -126,14 +130,14 @@ class MigrateShaToRev:
else: else:
self._cond('rev').check(dct) self._cond('rev').check(dct)
def apply_default(self, dct): def apply_default(self, dct: Dict[str, Any]) -> None:
if 'sha' in dct: if 'sha' in dct:
dct['rev'] = dct.pop('sha') dct['rev'] = dct.pop('sha')
remove_default = cfgv.Required.remove_default remove_default = cfgv.Required.remove_default
def _entry(modname): def _entry(modname: str) -> str:
"""the hook `entry` is passed through `shlex.split()` by the command """the hook `entry` is passed through `shlex.split()` by the command
runner, so to prevent issues with spaces and backslashes (on Windows) runner, so to prevent issues with spaces and backslashes (on Windows)
it must be quoted here. it must be quoted here.
@ -143,13 +147,21 @@ def _entry(modname):
) )
def warn_unknown_keys_root(extra, orig_keys, dct): def warn_unknown_keys_root(
extra: Sequence[str],
orig_keys: Sequence[str],
dct: Dict[str, str],
) -> None:
logger.warning( logger.warning(
'Unexpected key(s) present at root: {}'.format(', '.join(extra)), 'Unexpected key(s) present at root: {}'.format(', '.join(extra)),
) )
def warn_unknown_keys_repo(extra, orig_keys, dct): def warn_unknown_keys_repo(
extra: Sequence[str],
orig_keys: Sequence[str],
dct: Dict[str, str],
) -> None:
logger.warning( logger.warning(
'Unexpected key(s) present on {}: {}'.format( 'Unexpected key(s) present on {}: {}'.format(
dct['repo'], ', '.join(extra), dct['repo'], ', '.join(extra),
@ -281,7 +293,7 @@ class InvalidConfigError(FatalError):
pass pass
def ordered_load_normalize_legacy_config(contents): def ordered_load_normalize_legacy_config(contents: str) -> Dict[str, Any]:
data = ordered_load(contents) data = ordered_load(contents)
if isinstance(data, list): if isinstance(data, list):
# TODO: Once happy, issue a deprecation warning and instructions # TODO: Once happy, issue a deprecation warning and instructions
@ -298,7 +310,7 @@ load_config = functools.partial(
) )
def validate_config_main(argv=None): def validate_config_main(argv: Optional[Sequence[str]] = None) -> int:
parser = _make_argparser('Config filenames.') parser = _make_argparser('Config filenames.')
args = parser.parse_args(argv) args = parser.parse_args(argv)
ret = 0 ret = 0

View file

@ -21,7 +21,7 @@ class InvalidColorSetting(ValueError):
pass pass
def format_color(text, color, use_color_setting): def format_color(text: str, color: str, use_color_setting: bool) -> str:
"""Format text with color. """Format text with color.
Args: Args:
@ -38,7 +38,7 @@ def format_color(text, color, use_color_setting):
COLOR_CHOICES = ('auto', 'always', 'never') COLOR_CHOICES = ('auto', 'always', 'never')
def use_color(setting): def use_color(setting: str) -> bool:
"""Choose whether to use color based on the command argument. """Choose whether to use color based on the command argument.
Args: Args:

View file

@ -1,8 +1,12 @@
import collections
import os.path import os.path
import re import re
from typing import Any
from typing import Dict
from typing import List from typing import List
from typing import NamedTuple
from typing import Optional from typing import Optional
from typing import Sequence
from typing import Tuple
from aspy.yaml import ordered_dump from aspy.yaml import ordered_dump
from aspy.yaml import ordered_load from aspy.yaml import ordered_load
@ -16,20 +20,23 @@ from pre_commit.clientlib import load_manifest
from pre_commit.clientlib import LOCAL from pre_commit.clientlib import LOCAL
from pre_commit.clientlib import META from pre_commit.clientlib import META
from pre_commit.commands.migrate_config import migrate_config from pre_commit.commands.migrate_config import migrate_config
from pre_commit.store import Store
from pre_commit.util import CalledProcessError from pre_commit.util import CalledProcessError
from pre_commit.util import cmd_output from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
from pre_commit.util import tmpdir from pre_commit.util import tmpdir
class RevInfo(collections.namedtuple('RevInfo', ('repo', 'rev', 'frozen'))): class RevInfo(NamedTuple):
__slots__ = () repo: str
rev: str
frozen: Optional[str]
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config: Dict[str, Any]) -> 'RevInfo':
return cls(config['repo'], config['rev'], None) return cls(config['repo'], config['rev'], None)
def update(self, tags_only, freeze): def update(self, tags_only: bool, freeze: bool) -> 'RevInfo':
if tags_only: if tags_only:
tag_cmd = ('git', 'describe', 'FETCH_HEAD', '--tags', '--abbrev=0') tag_cmd = ('git', 'describe', 'FETCH_HEAD', '--tags', '--abbrev=0')
else: else:
@ -57,7 +64,11 @@ class RepositoryCannotBeUpdatedError(RuntimeError):
pass pass
def _check_hooks_still_exist_at_rev(repo_config, info, store): def _check_hooks_still_exist_at_rev(
repo_config: Dict[str, Any],
info: RevInfo,
store: Store,
) -> None:
try: try:
path = store.clone(repo_config['repo'], info.rev) path = store.clone(repo_config['repo'], info.rev)
manifest = load_manifest(os.path.join(path, C.MANIFEST_FILE)) manifest = load_manifest(os.path.join(path, C.MANIFEST_FILE))
@ -78,7 +89,11 @@ REV_LINE_RE = re.compile(r'^(\s+)rev:(\s*)([^\s#]+)(.*)(\r?\n)$', re.DOTALL)
REV_LINE_FMT = '{}rev:{}{}{}{}' REV_LINE_FMT = '{}rev:{}{}{}{}'
def _original_lines(path, rev_infos, retry=False): def _original_lines(
path: str,
rev_infos: List[Optional[RevInfo]],
retry: bool = False,
) -> Tuple[List[str], List[int]]:
"""detect `rev:` lines or reformat the file""" """detect `rev:` lines or reformat the file"""
with open(path) as f: with open(path) as f:
original = f.read() original = f.read()
@ -95,7 +110,7 @@ def _original_lines(path, rev_infos, retry=False):
return _original_lines(path, rev_infos, retry=True) return _original_lines(path, rev_infos, retry=True)
def _write_new_config(path, rev_infos): def _write_new_config(path: str, rev_infos: List[Optional[RevInfo]]) -> None:
lines, idxs = _original_lines(path, rev_infos) lines, idxs = _original_lines(path, rev_infos)
for idx, rev_info in zip(idxs, rev_infos): for idx, rev_info in zip(idxs, rev_infos):
@ -119,7 +134,13 @@ def _write_new_config(path, rev_infos):
f.write(''.join(lines)) f.write(''.join(lines))
def autoupdate(config_file, store, tags_only, freeze, repos=()): def autoupdate(
config_file: str,
store: Store,
tags_only: bool,
freeze: bool,
repos: Sequence[str] = (),
) -> int:
"""Auto-update the pre-commit config to the latest versions of repos.""" """Auto-update the pre-commit config to the latest versions of repos."""
migrate_config(config_file, quiet=True) migrate_config(config_file, quiet=True)
retv = 0 retv = 0

View file

@ -1,10 +1,11 @@
import os.path import os.path
from pre_commit import output from pre_commit import output
from pre_commit.store import Store
from pre_commit.util import rmtree from pre_commit.util import rmtree
def clean(store): def clean(store: Store) -> int:
legacy_path = os.path.expanduser('~/.pre-commit') legacy_path = os.path.expanduser('~/.pre-commit')
for directory in (store.directory, legacy_path): for directory in (store.directory, legacy_path):
if os.path.exists(directory): if os.path.exists(directory):

View file

@ -1,4 +1,8 @@
import os.path import os.path
from typing import Any
from typing import Dict
from typing import Set
from typing import Tuple
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit import output from pre_commit import output
@ -8,9 +12,15 @@ from pre_commit.clientlib import load_config
from pre_commit.clientlib import load_manifest from pre_commit.clientlib import load_manifest
from pre_commit.clientlib import LOCAL from pre_commit.clientlib import LOCAL
from pre_commit.clientlib import META from pre_commit.clientlib import META
from pre_commit.store import Store
def _mark_used_repos(store, all_repos, unused_repos, repo): def _mark_used_repos(
store: Store,
all_repos: Dict[Tuple[str, str], str],
unused_repos: Set[Tuple[str, str]],
repo: Dict[str, Any],
) -> None:
if repo['repo'] == META: if repo['repo'] == META:
return return
elif repo['repo'] == LOCAL: elif repo['repo'] == LOCAL:
@ -47,7 +57,7 @@ def _mark_used_repos(store, all_repos, unused_repos, repo):
)) ))
def _gc_repos(store): def _gc_repos(store: Store) -> int:
configs = store.select_all_configs() configs = store.select_all_configs()
repos = store.select_all_repos() repos = store.select_all_repos()
@ -73,7 +83,7 @@ def _gc_repos(store):
return len(unused_repos) return len(unused_repos)
def gc(store): def gc(store: Store) -> int:
with store.exclusive_lock(): with store.exclusive_lock():
repos_removed = _gc_repos(store) repos_removed = _gc_repos(store)
output.write_line(f'{repos_removed} repo(s) removed.') output.write_line(f'{repos_removed} repo(s) removed.')

View file

@ -1,14 +1,21 @@
import logging import logging
import os.path import os.path
from typing import Sequence
from pre_commit.commands.install_uninstall import install from pre_commit.commands.install_uninstall import install
from pre_commit.store import Store
from pre_commit.util import CalledProcessError from pre_commit.util import CalledProcessError
from pre_commit.util import cmd_output from pre_commit.util import cmd_output
logger = logging.getLogger('pre_commit') logger = logging.getLogger('pre_commit')
def init_templatedir(config_file, store, directory, hook_types): def init_templatedir(
config_file: str,
store: Store,
directory: str,
hook_types: Sequence[str],
) -> int:
install( install(
config_file, store, hook_types=hook_types, config_file, store, hook_types=hook_types,
overwrite=True, skip_on_missing_config=True, git_dir=directory, overwrite=True, skip_on_missing_config=True, git_dir=directory,
@ -25,3 +32,4 @@ def init_templatedir(config_file, store, directory, hook_types):
logger.warning( logger.warning(
f'maybe `git config --global init.templateDir {dest}`?', f'maybe `git config --global init.templateDir {dest}`?',
) )
return 0

View file

@ -3,12 +3,16 @@ import logging
import os.path import os.path
import shutil import shutil
import sys import sys
from typing import Optional
from typing import Sequence
from typing import Tuple
from pre_commit import git from pre_commit import git
from pre_commit import output from pre_commit import output
from pre_commit.clientlib import load_config from pre_commit.clientlib import load_config
from pre_commit.repository import all_hooks from pre_commit.repository import all_hooks
from pre_commit.repository import install_hook_envs from pre_commit.repository import install_hook_envs
from pre_commit.store import Store
from pre_commit.util import make_executable from pre_commit.util import make_executable
from pre_commit.util import mkdirp from pre_commit.util import mkdirp
from pre_commit.util import resource_text from pre_commit.util import resource_text
@ -29,13 +33,16 @@ TEMPLATE_START = '# start templated\n'
TEMPLATE_END = '# end templated\n' TEMPLATE_END = '# end templated\n'
def _hook_paths(hook_type, git_dir=None): def _hook_paths(
hook_type: str,
git_dir: Optional[str] = None,
) -> Tuple[str, str]:
git_dir = git_dir if git_dir is not None else git.get_git_dir() git_dir = git_dir if git_dir is not None else git.get_git_dir()
pth = os.path.join(git_dir, 'hooks', hook_type) pth = os.path.join(git_dir, 'hooks', hook_type)
return pth, f'{pth}.legacy' return pth, f'{pth}.legacy'
def is_our_script(filename): def is_our_script(filename: str) -> bool:
if not os.path.exists(filename): # pragma: windows no cover (symlink) if not os.path.exists(filename): # pragma: windows no cover (symlink)
return False return False
with open(filename) as f: with open(filename) as f:
@ -43,7 +50,7 @@ def is_our_script(filename):
return any(h in contents for h in (CURRENT_HASH,) + PRIOR_HASHES) return any(h in contents for h in (CURRENT_HASH,) + PRIOR_HASHES)
def shebang(): def shebang() -> str:
if sys.platform == 'win32': if sys.platform == 'win32':
py = 'python' py = 'python'
else: else:
@ -63,9 +70,12 @@ def shebang():
def _install_hook_script( def _install_hook_script(
config_file, hook_type, config_file: str,
overwrite=False, skip_on_missing_config=False, git_dir=None, hook_type: str,
): overwrite: bool = False,
skip_on_missing_config: bool = False,
git_dir: Optional[str] = None,
) -> None:
hook_path, legacy_path = _hook_paths(hook_type, git_dir=git_dir) hook_path, legacy_path = _hook_paths(hook_type, git_dir=git_dir)
mkdirp(os.path.dirname(hook_path)) mkdirp(os.path.dirname(hook_path))
@ -108,10 +118,14 @@ def _install_hook_script(
def install( def install(
config_file, store, hook_types, config_file: str,
overwrite=False, hooks=False, store: Store,
skip_on_missing_config=False, git_dir=None, hook_types: Sequence[str],
): overwrite: bool = False,
hooks: bool = False,
skip_on_missing_config: bool = False,
git_dir: Optional[str] = None,
) -> int:
if git.has_core_hookpaths_set(): if git.has_core_hookpaths_set():
logger.error( logger.error(
'Cowardly refusing to install hooks with `core.hooksPath` set.\n' 'Cowardly refusing to install hooks with `core.hooksPath` set.\n'
@ -133,11 +147,12 @@ def install(
return 0 return 0
def install_hooks(config_file, store): def install_hooks(config_file: str, store: Store) -> int:
install_hook_envs(all_hooks(load_config(config_file), store), store) install_hook_envs(all_hooks(load_config(config_file), store), store)
return 0
def _uninstall_hook_script(hook_type): # type: (str) -> None def _uninstall_hook_script(hook_type: str) -> None:
hook_path, legacy_path = _hook_paths(hook_type) hook_path, legacy_path = _hook_paths(hook_type)
# If our file doesn't exist or it isn't ours, gtfo. # If our file doesn't exist or it isn't ours, gtfo.
@ -152,7 +167,7 @@ def _uninstall_hook_script(hook_type): # type: (str) -> None
output.write_line(f'Restored previous hooks to {hook_path}') output.write_line(f'Restored previous hooks to {hook_path}')
def uninstall(hook_types): def uninstall(hook_types: Sequence[str]) -> int:
for hook_type in hook_types: for hook_type in hook_types:
_uninstall_hook_script(hook_type) _uninstall_hook_script(hook_type)
return 0 return 0

View file

@ -4,16 +4,16 @@ import yaml
from aspy.yaml import ordered_load from aspy.yaml import ordered_load
def _indent(s): def _indent(s: str) -> str:
lines = s.splitlines(True) lines = s.splitlines(True)
return ''.join(' ' * 4 + line if line.strip() else line for line in lines) return ''.join(' ' * 4 + line if line.strip() else line for line in lines)
def _is_header_line(line): def _is_header_line(line: str) -> bool:
return (line.startswith(('#', '---')) or not line.strip()) return line.startswith(('#', '---')) or not line.strip()
def _migrate_map(contents): def _migrate_map(contents: str) -> str:
# Find the first non-header line # Find the first non-header line
lines = contents.splitlines(True) lines = contents.splitlines(True)
i = 0 i = 0
@ -37,12 +37,12 @@ def _migrate_map(contents):
return contents return contents
def _migrate_sha_to_rev(contents): def _migrate_sha_to_rev(contents: str) -> str:
reg = re.compile(r'(\n\s+)sha:') reg = re.compile(r'(\n\s+)sha:')
return reg.sub(r'\1rev:', contents) return reg.sub(r'\1rev:', contents)
def migrate_config(config_file, quiet=False): def migrate_config(config_file: str, quiet: bool = False) -> int:
with open(config_file) as f: with open(config_file) as f:
orig_contents = contents = f.read() orig_contents = contents = f.read()
@ -56,3 +56,4 @@ def migrate_config(config_file, quiet=False):
print('Configuration has been migrated.') print('Configuration has been migrated.')
elif not quiet: elif not quiet:
print('Configuration is already migrated.') print('Configuration is already migrated.')
return 0

View file

@ -1,8 +1,17 @@
import argparse
import functools
import logging import logging
import os import os
import re import re
import subprocess import subprocess
import time import time
from typing import Any
from typing import Collection
from typing import Dict
from typing import List
from typing import Sequence
from typing import Set
from typing import Tuple
from identify.identify import tags_from_path from identify.identify import tags_from_path
@ -12,16 +21,23 @@ from pre_commit import output
from pre_commit.clientlib import load_config from pre_commit.clientlib import load_config
from pre_commit.output import get_hook_message from pre_commit.output import get_hook_message
from pre_commit.repository import all_hooks from pre_commit.repository import all_hooks
from pre_commit.repository import Hook
from pre_commit.repository import install_hook_envs from pre_commit.repository import install_hook_envs
from pre_commit.staged_files_only import staged_files_only from pre_commit.staged_files_only import staged_files_only
from pre_commit.store import Store
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
from pre_commit.util import EnvironT
from pre_commit.util import noop_context from pre_commit.util import noop_context
logger = logging.getLogger('pre_commit') logger = logging.getLogger('pre_commit')
def filter_by_include_exclude(names, include, exclude): def filter_by_include_exclude(
names: Collection[str],
include: str,
exclude: str,
) -> List[str]:
include_re, exclude_re = re.compile(include), re.compile(exclude) include_re, exclude_re = re.compile(include), re.compile(exclude)
return [ return [
filename for filename in names filename for filename in names
@ -31,24 +47,25 @@ def filter_by_include_exclude(names, include, exclude):
class Classifier: class Classifier:
def __init__(self, filenames): def __init__(self, filenames: Sequence[str]) -> None:
# on windows we normalize all filenames to use forward slashes # on windows we normalize all filenames to use forward slashes
# this makes it easier to filter using the `files:` regex # this makes it easier to filter using the `files:` regex
# this also makes improperly quoted shell-based hooks work better # this also makes improperly quoted shell-based hooks work better
# see #1173 # see #1173
if os.altsep == '/' and os.sep == '\\': if os.altsep == '/' and os.sep == '\\':
filenames = (f.replace(os.sep, os.altsep) for f in filenames) filenames = [f.replace(os.sep, os.altsep) for f in filenames]
self.filenames = [f for f in filenames if os.path.lexists(f)] self.filenames = [f for f in filenames if os.path.lexists(f)]
self._types_cache = {}
def _types_for_file(self, filename): @functools.lru_cache(maxsize=None)
try: def _types_for_file(self, filename: str) -> Set[str]:
return self._types_cache[filename] return tags_from_path(filename)
except KeyError:
ret = self._types_cache[filename] = tags_from_path(filename)
return ret
def by_types(self, names, types, exclude_types): def by_types(
self,
names: Sequence[str],
types: Collection[str],
exclude_types: Collection[str],
) -> List[str]:
types, exclude_types = frozenset(types), frozenset(exclude_types) types, exclude_types = frozenset(types), frozenset(exclude_types)
ret = [] ret = []
for filename in names: for filename in names:
@ -57,14 +74,14 @@ class Classifier:
ret.append(filename) ret.append(filename)
return ret return ret
def filenames_for_hook(self, hook): def filenames_for_hook(self, hook: Hook) -> Tuple[str, ...]:
names = self.filenames names = self.filenames
names = filter_by_include_exclude(names, hook.files, hook.exclude) names = filter_by_include_exclude(names, hook.files, hook.exclude)
names = self.by_types(names, hook.types, hook.exclude_types) names = self.by_types(names, hook.types, hook.exclude_types)
return names return tuple(names)
def _get_skips(environ): def _get_skips(environ: EnvironT) -> Set[str]:
skips = environ.get('SKIP', '') skips = environ.get('SKIP', '')
return {skip.strip() for skip in skips.split(',') if skip.strip()} return {skip.strip() for skip in skips.split(',') if skip.strip()}
@ -73,11 +90,18 @@ SKIPPED = 'Skipped'
NO_FILES = '(no files to check)' NO_FILES = '(no files to check)'
def _subtle_line(s, use_color): def _subtle_line(s: str, use_color: bool) -> None:
output.write_line(color.format_color(s, color.SUBTLE, use_color)) output.write_line(color.format_color(s, color.SUBTLE, use_color))
def _run_single_hook(classifier, hook, skips, cols, verbose, use_color): def _run_single_hook(
classifier: Classifier,
hook: Hook,
skips: Set[str],
cols: int,
verbose: bool,
use_color: bool,
) -> bool:
filenames = classifier.filenames_for_hook(hook) filenames = classifier.filenames_for_hook(hook)
if hook.id in skips or hook.alias in skips: if hook.id in skips or hook.alias in skips:
@ -115,7 +139,8 @@ def _run_single_hook(classifier, hook, skips, cols, verbose, use_color):
diff_cmd = ('git', 'diff', '--no-ext-diff') diff_cmd = ('git', 'diff', '--no-ext-diff')
diff_before = cmd_output_b(*diff_cmd, retcode=None) diff_before = cmd_output_b(*diff_cmd, retcode=None)
filenames = tuple(filenames) if hook.pass_filenames else () if not hook.pass_filenames:
filenames = ()
time_before = time.time() time_before = time.time()
retcode, out = hook.run(filenames, use_color) retcode, out = hook.run(filenames, use_color)
duration = round(time.time() - time_before, 2) or 0 duration = round(time.time() - time_before, 2) or 0
@ -154,7 +179,7 @@ def _run_single_hook(classifier, hook, skips, cols, verbose, use_color):
return files_modified or bool(retcode) return files_modified or bool(retcode)
def _compute_cols(hooks): def _compute_cols(hooks: Sequence[Hook]) -> int:
"""Compute the number of columns to display hook messages. The widest """Compute the number of columns to display hook messages. The widest
that will be displayed is in the no files skipped case: that will be displayed is in the no files skipped case:
@ -169,7 +194,7 @@ def _compute_cols(hooks):
return max(cols, 80) return max(cols, 80)
def _all_filenames(args): def _all_filenames(args: argparse.Namespace) -> Collection[str]:
if args.origin and args.source: if args.origin and args.source:
return git.get_changed_files(args.origin, args.source) return git.get_changed_files(args.origin, args.source)
elif args.hook_stage in {'prepare-commit-msg', 'commit-msg'}: elif args.hook_stage in {'prepare-commit-msg', 'commit-msg'}:
@ -184,7 +209,12 @@ def _all_filenames(args):
return git.get_staged_files() return git.get_staged_files()
def _run_hooks(config, hooks, args, environ): def _run_hooks(
config: Dict[str, Any],
hooks: Sequence[Hook],
args: argparse.Namespace,
environ: EnvironT,
) -> int:
"""Actually run the hooks.""" """Actually run the hooks."""
skips = _get_skips(environ) skips = _get_skips(environ)
cols = _compute_cols(hooks) cols = _compute_cols(hooks)
@ -221,12 +251,12 @@ def _run_hooks(config, hooks, args, environ):
return retval return retval
def _has_unmerged_paths(): def _has_unmerged_paths() -> bool:
_, stdout, _ = cmd_output_b('git', 'ls-files', '--unmerged') _, stdout, _ = cmd_output_b('git', 'ls-files', '--unmerged')
return bool(stdout.strip()) return bool(stdout.strip())
def _has_unstaged_config(config_file): def _has_unstaged_config(config_file: str) -> bool:
retcode, _, _ = cmd_output_b( retcode, _, _ = cmd_output_b(
'git', 'diff', '--no-ext-diff', '--exit-code', config_file, 'git', 'diff', '--no-ext-diff', '--exit-code', config_file,
retcode=None, retcode=None,
@ -235,7 +265,12 @@ def _has_unstaged_config(config_file):
return retcode == 1 return retcode == 1
def run(config_file, store, args, environ=os.environ): def run(
config_file: str,
store: Store,
args: argparse.Namespace,
environ: EnvironT = os.environ,
) -> int:
no_stash = args.all_files or bool(args.files) no_stash = args.all_files or bool(args.files)
# Check if we have unresolved merge conflict files and fail fast. # Check if we have unresolved merge conflict files and fail fast.

View file

@ -16,6 +16,6 @@ repos:
''' '''
def sample_config(): def sample_config() -> int:
print(SAMPLE_CONFIG, end='') print(SAMPLE_CONFIG, end='')
return 0 return 0

View file

@ -1,6 +1,8 @@
import argparse
import collections import collections
import logging import logging
import os.path import os.path
from typing import Tuple
from aspy.yaml import ordered_dump from aspy.yaml import ordered_dump
@ -17,7 +19,7 @@ from pre_commit.xargs import xargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _repo_ref(tmpdir, repo, ref): def _repo_ref(tmpdir: str, repo: str, ref: str) -> Tuple[str, str]:
# if `ref` is explicitly passed, use it # if `ref` is explicitly passed, use it
if ref: if ref:
return repo, ref return repo, ref
@ -47,7 +49,7 @@ def _repo_ref(tmpdir, repo, ref):
return repo, ref return repo, ref
def try_repo(args): def try_repo(args: argparse.Namespace) -> int:
with tmpdir() as tempdir: with tmpdir() as tempdir:
repo, ref = _repo_ref(tempdir, args.repo, args.ref) repo, ref = _repo_ref(tempdir, args.repo, args.ref)

View file

@ -1,10 +1,14 @@
import contextlib import contextlib
import enum import enum
import os import os
from typing import Generator
from typing import NamedTuple from typing import NamedTuple
from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Union from typing import Union
from pre_commit.util import EnvironT
class _Unset(enum.Enum): class _Unset(enum.Enum):
UNSET = 1 UNSET = 1
@ -23,7 +27,7 @@ ValueT = Union[str, _Unset, SubstitutionT]
PatchesT = Tuple[Tuple[str, ValueT], ...] PatchesT = Tuple[Tuple[str, ValueT], ...]
def format_env(parts, env): def format_env(parts: SubstitutionT, env: EnvironT) -> str:
return ''.join( return ''.join(
env.get(part.name, part.default) if isinstance(part, Var) else part env.get(part.name, part.default) if isinstance(part, Var) else part
for part in parts for part in parts
@ -31,7 +35,10 @@ def format_env(parts, env):
@contextlib.contextmanager @contextlib.contextmanager
def envcontext(patch, _env=None): def envcontext(
patch: PatchesT,
_env: Optional[EnvironT] = None,
) -> Generator[None, None, None]:
"""In this context, `os.environ` is modified according to `patch`. """In this context, `os.environ` is modified according to `patch`.
`patch` is an iterable of 2-tuples (key, value): `patch` is an iterable of 2-tuples (key, value):

View file

@ -2,6 +2,7 @@ import contextlib
import os.path import os.path
import sys import sys
import traceback import traceback
from typing import Generator
from typing import Union from typing import Union
import pre_commit.constants as C import pre_commit.constants as C
@ -14,14 +15,11 @@ class FatalError(RuntimeError):
pass pass
def _to_bytes(exc): def _to_bytes(exc: BaseException) -> bytes:
try:
return bytes(exc)
except Exception:
return str(exc).encode('UTF-8') return str(exc).encode('UTF-8')
def _log_and_exit(msg, exc, formatted): def _log_and_exit(msg: str, exc: BaseException, formatted: str) -> None:
error_msg = b''.join(( error_msg = b''.join((
five.to_bytes(msg), b': ', five.to_bytes(msg), b': ',
five.to_bytes(type(exc).__name__), b': ', five.to_bytes(type(exc).__name__), b': ',
@ -62,7 +60,7 @@ def _log_and_exit(msg, exc, formatted):
@contextlib.contextmanager @contextlib.contextmanager
def error_handler(): def error_handler() -> Generator[None, None, None]:
try: try:
yield yield
except (Exception, KeyboardInterrupt) as e: except (Exception, KeyboardInterrupt) as e:

View file

@ -1,6 +1,8 @@
import contextlib import contextlib
import errno import errno
import os import os
from typing import Callable
from typing import Generator
if os.name == 'nt': # pragma: no cover (windows) if os.name == 'nt': # pragma: no cover (windows)
@ -13,7 +15,10 @@ if os.name == 'nt': # pragma: no cover (windows)
_region = 0xffff _region = 0xffff
@contextlib.contextmanager @contextlib.contextmanager
def _locked(fileno, blocked_cb): def _locked(
fileno: int,
blocked_cb: Callable[[], None],
) -> Generator[None, None, None]:
try: try:
# TODO: https://github.com/python/typeshed/pull/3607 # TODO: https://github.com/python/typeshed/pull/3607
msvcrt.locking(fileno, msvcrt.LK_NBLCK, _region) # type: ignore msvcrt.locking(fileno, msvcrt.LK_NBLCK, _region) # type: ignore
@ -42,11 +47,14 @@ if os.name == 'nt': # pragma: no cover (windows)
# before closing a file or exiting the program." # before closing a file or exiting the program."
# TODO: https://github.com/python/typeshed/pull/3607 # TODO: https://github.com/python/typeshed/pull/3607
msvcrt.locking(fileno, msvcrt.LK_UNLCK, _region) # type: ignore msvcrt.locking(fileno, msvcrt.LK_UNLCK, _region) # type: ignore
else: # pramga: windows no cover else: # pragma: windows no cover
import fcntl import fcntl
@contextlib.contextmanager @contextlib.contextmanager
def _locked(fileno, blocked_cb): def _locked(
fileno: int,
blocked_cb: Callable[[], None],
) -> Generator[None, None, None]:
try: try:
fcntl.flock(fileno, fcntl.LOCK_EX | fcntl.LOCK_NB) fcntl.flock(fileno, fcntl.LOCK_EX | fcntl.LOCK_NB)
except OSError: # pragma: no cover (tests are single-threaded) except OSError: # pragma: no cover (tests are single-threaded)
@ -59,7 +67,10 @@ else: # pramga: windows no cover
@contextlib.contextmanager @contextlib.contextmanager
def lock(path, blocked_cb): def lock(
path: str,
blocked_cb: Callable[[], None],
) -> Generator[None, None, None]:
with open(path, 'a+') as f: with open(path, 'a+') as f:
with _locked(f.fileno(), blocked_cb): with _locked(f.fileno(), blocked_cb):
yield yield

View file

@ -1,8 +1,11 @@
def to_text(s): from typing import Union
def to_text(s: Union[str, bytes]) -> str:
return s if isinstance(s, str) else s.decode('UTF-8') return s if isinstance(s, str) else s.decode('UTF-8')
def to_bytes(s): def to_bytes(s: Union[str, bytes]) -> bytes:
return s if isinstance(s, bytes) else s.encode('UTF-8') return s if isinstance(s, bytes) else s.encode('UTF-8')

View file

@ -1,15 +1,20 @@
import logging import logging
import os.path import os.path
import sys import sys
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from pre_commit.util import cmd_output from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
from pre_commit.util import EnvironT
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def zsplit(s): def zsplit(s: str) -> List[str]:
s = s.strip('\0') s = s.strip('\0')
if s: if s:
return s.split('\0') return s.split('\0')
@ -17,7 +22,7 @@ def zsplit(s):
return [] return []
def no_git_env(_env=None): def no_git_env(_env: Optional[EnvironT] = None) -> Dict[str, str]:
# Too many bugs dealing with environment variables and GIT: # Too many bugs dealing with environment variables and GIT:
# https://github.com/pre-commit/pre-commit/issues/300 # https://github.com/pre-commit/pre-commit/issues/300
# In git 2.6.3 (maybe others), git exports GIT_WORK_TREE while running # In git 2.6.3 (maybe others), git exports GIT_WORK_TREE while running
@ -34,11 +39,11 @@ def no_git_env(_env=None):
} }
def get_root(): def get_root() -> str:
return cmd_output('git', 'rev-parse', '--show-toplevel')[1].strip() return cmd_output('git', 'rev-parse', '--show-toplevel')[1].strip()
def get_git_dir(git_root='.'): def get_git_dir(git_root: str = '.') -> str:
opts = ('--git-common-dir', '--git-dir') opts = ('--git-common-dir', '--git-dir')
_, out, _ = cmd_output('git', 'rev-parse', *opts, cwd=git_root) _, out, _ = cmd_output('git', 'rev-parse', *opts, cwd=git_root)
for line, opt in zip(out.splitlines(), opts): for line, opt in zip(out.splitlines(), opts):
@ -48,12 +53,12 @@ def get_git_dir(git_root='.'):
raise AssertionError('unreachable: no git dir') raise AssertionError('unreachable: no git dir')
def get_remote_url(git_root): def get_remote_url(git_root: str) -> str:
_, out, _ = cmd_output('git', 'config', 'remote.origin.url', cwd=git_root) _, out, _ = cmd_output('git', 'config', 'remote.origin.url', cwd=git_root)
return out.strip() return out.strip()
def is_in_merge_conflict(): def is_in_merge_conflict() -> bool:
git_dir = get_git_dir('.') git_dir = get_git_dir('.')
return ( return (
os.path.exists(os.path.join(git_dir, 'MERGE_MSG')) and os.path.exists(os.path.join(git_dir, 'MERGE_MSG')) and
@ -61,7 +66,7 @@ def is_in_merge_conflict():
) )
def parse_merge_msg_for_conflicts(merge_msg): def parse_merge_msg_for_conflicts(merge_msg: bytes) -> List[str]:
# Conflicted files start with tabs # Conflicted files start with tabs
return [ return [
line.lstrip(b'#').strip().decode('UTF-8') line.lstrip(b'#').strip().decode('UTF-8')
@ -71,7 +76,7 @@ def parse_merge_msg_for_conflicts(merge_msg):
] ]
def get_conflicted_files(): def get_conflicted_files() -> Set[str]:
logger.info('Checking merge-conflict files only.') logger.info('Checking merge-conflict files only.')
# Need to get the conflicted files from the MERGE_MSG because they could # Need to get the conflicted files from the MERGE_MSG because they could
# have resolved the conflict by choosing one side or the other # have resolved the conflict by choosing one side or the other
@ -92,7 +97,7 @@ def get_conflicted_files():
return set(merge_conflict_filenames) | set(merge_diff_filenames) return set(merge_conflict_filenames) | set(merge_diff_filenames)
def get_staged_files(cwd=None): def get_staged_files(cwd: Optional[str] = None) -> List[str]:
return zsplit( return zsplit(
cmd_output( cmd_output(
'git', 'diff', '--staged', '--name-only', '--no-ext-diff', '-z', 'git', 'diff', '--staged', '--name-only', '--no-ext-diff', '-z',
@ -103,7 +108,7 @@ def get_staged_files(cwd=None):
) )
def intent_to_add_files(): def intent_to_add_files() -> List[str]:
_, stdout, _ = cmd_output('git', 'status', '--porcelain', '-z') _, stdout, _ = cmd_output('git', 'status', '--porcelain', '-z')
parts = list(reversed(zsplit(stdout))) parts = list(reversed(zsplit(stdout)))
intent_to_add = [] intent_to_add = []
@ -117,11 +122,11 @@ def intent_to_add_files():
return intent_to_add return intent_to_add
def get_all_files(): def get_all_files() -> List[str]:
return zsplit(cmd_output('git', 'ls-files', '-z')[1]) return zsplit(cmd_output('git', 'ls-files', '-z')[1])
def get_changed_files(new, old): def get_changed_files(new: str, old: str) -> List[str]:
return zsplit( return zsplit(
cmd_output( cmd_output(
'git', 'diff', '--name-only', '--no-ext-diff', '-z', 'git', 'diff', '--name-only', '--no-ext-diff', '-z',
@ -130,24 +135,22 @@ def get_changed_files(new, old):
) )
def head_rev(remote): def head_rev(remote: str) -> str:
_, out, _ = cmd_output('git', 'ls-remote', '--exit-code', remote, 'HEAD') _, out, _ = cmd_output('git', 'ls-remote', '--exit-code', remote, 'HEAD')
return out.split()[0] return out.split()[0]
def has_diff(*args, **kwargs): def has_diff(*args: str, repo: str = '.') -> bool:
repo = kwargs.pop('repo', '.')
assert not kwargs, kwargs
cmd = ('git', 'diff', '--quiet', '--no-ext-diff') + args cmd = ('git', 'diff', '--quiet', '--no-ext-diff') + args
return cmd_output_b(*cmd, cwd=repo, retcode=None)[0] == 1 return cmd_output_b(*cmd, cwd=repo, retcode=None)[0] == 1
def has_core_hookpaths_set(): def has_core_hookpaths_set() -> bool:
_, out, _ = cmd_output_b('git', 'config', 'core.hooksPath', retcode=None) _, out, _ = cmd_output_b('git', 'config', 'core.hooksPath', retcode=None)
return bool(out.strip()) return bool(out.strip())
def init_repo(path, remote): def init_repo(path: str, remote: str) -> None:
if os.path.isdir(remote): if os.path.isdir(remote):
remote = os.path.abspath(remote) remote = os.path.abspath(remote)
@ -156,7 +159,7 @@ def init_repo(path, remote):
cmd_output_b('git', 'remote', 'add', 'origin', remote, cwd=path, env=env) cmd_output_b('git', 'remote', 'add', 'origin', remote, cwd=path, env=env)
def commit(repo='.'): def commit(repo: str = '.') -> None:
env = no_git_env() env = no_git_env()
name, email = 'pre-commit', 'asottile+pre-commit@umich.edu' name, email = 'pre-commit', 'asottile+pre-commit@umich.edu'
env['GIT_AUTHOR_NAME'] = env['GIT_COMMITTER_NAME'] = name env['GIT_AUTHOR_NAME'] = env['GIT_COMMITTER_NAME'] = name
@ -165,12 +168,12 @@ def commit(repo='.'):
cmd_output_b(*cmd, cwd=repo, env=env) cmd_output_b(*cmd, cwd=repo, env=env)
def git_path(name, repo='.'): def git_path(name: str, repo: str = '.') -> str:
_, out, _ = cmd_output('git', 'rev-parse', '--git-path', name, cwd=repo) _, out, _ = cmd_output('git', 'rev-parse', '--git-path', name, cwd=repo)
return os.path.join(repo, out.strip()) return os.path.join(repo, out.strip())
def check_for_cygwin_mismatch(): def check_for_cygwin_mismatch() -> None:
"""See https://github.com/pre-commit/pre-commit/issues/354""" """See https://github.com/pre-commit/pre-commit/issues/354"""
if sys.platform in ('cygwin', 'win32'): # pragma: no cover (windows) if sys.platform in ('cygwin', 'win32'): # pragma: no cover (windows)
is_cygwin_python = sys.platform == 'cygwin' is_cygwin_python = sys.platform == 'cygwin'

View file

@ -1,20 +1,29 @@
import contextlib import contextlib
import os import os
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.envcontext import envcontext from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import SubstitutionT from pre_commit.envcontext import SubstitutionT
from pre_commit.envcontext import UNSET from pre_commit.envcontext import UNSET
from pre_commit.envcontext import Var from pre_commit.envcontext import Var
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'conda' ENVIRONMENT_DIR = 'conda'
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy healthy = helpers.basic_healthy
def get_env_patch(env): def get_env_patch(env: str) -> PatchesT:
# On non-windows systems executable live in $CONDA_PREFIX/bin, on Windows # On non-windows systems executable live in $CONDA_PREFIX/bin, on Windows
# they can be in $CONDA_PREFIX/bin, $CONDA_PREFIX/Library/bin, # they can be in $CONDA_PREFIX/bin, $CONDA_PREFIX/Library/bin,
# $CONDA_PREFIX/Scripts and $CONDA_PREFIX. Whereas the latter only # $CONDA_PREFIX/Scripts and $CONDA_PREFIX. Whereas the latter only
@ -34,14 +43,21 @@ def get_env_patch(env):
@contextlib.contextmanager @contextlib.contextmanager
def in_env(prefix, language_version): def in_env(
prefix: Prefix,
language_version: str,
) -> Generator[None, None, None]:
directory = helpers.environment_dir(ENVIRONMENT_DIR, language_version) directory = helpers.environment_dir(ENVIRONMENT_DIR, language_version)
envdir = prefix.path(directory) envdir = prefix.path(directory)
with envcontext(get_env_patch(envdir)): with envcontext(get_env_patch(envdir)):
yield yield
def install_environment(prefix, version, additional_dependencies): def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> None:
helpers.assert_version_default('conda', version) helpers.assert_version_default('conda', version)
directory = helpers.environment_dir(ENVIRONMENT_DIR, version) directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
@ -58,7 +74,11 @@ def install_environment(prefix, version, additional_dependencies):
) )
def run_hook(hook, file_args, color): def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
# TODO: Some rare commands need to be run using `conda run` but mostly we # TODO: Some rare commands need to be run using `conda run` but mostly we
# can run them withot which is much quicker and produces a better # can run them withot which is much quicker and produces a better
# output. # output.

View file

@ -1,14 +1,18 @@
import hashlib import hashlib
import os import os
from typing import Sequence
from typing import Tuple from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit import five
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import CalledProcessError from pre_commit.util import CalledProcessError
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'docker' ENVIRONMENT_DIR = 'docker'
PRE_COMMIT_LABEL = 'PRE_COMMIT' PRE_COMMIT_LABEL = 'PRE_COMMIT'
@ -16,16 +20,16 @@ get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy healthy = helpers.basic_healthy
def md5(s): # pragma: windows no cover def md5(s: str) -> str: # pragma: windows no cover
return hashlib.md5(five.to_bytes(s)).hexdigest() return hashlib.md5(s.encode()).hexdigest()
def docker_tag(prefix): # pragma: windows no cover def docker_tag(prefix: Prefix) -> str: # pragma: windows no cover
md5sum = md5(os.path.basename(prefix.prefix_dir)).lower() md5sum = md5(os.path.basename(prefix.prefix_dir)).lower()
return f'pre-commit-{md5sum}' return f'pre-commit-{md5sum}'
def docker_is_running(): # pragma: windows no cover def docker_is_running() -> bool: # pragma: windows no cover
try: try:
cmd_output_b('docker', 'ps') cmd_output_b('docker', 'ps')
except CalledProcessError: except CalledProcessError:
@ -34,15 +38,17 @@ def docker_is_running(): # pragma: windows no cover
return True return True
def assert_docker_available(): # pragma: windows no cover def assert_docker_available() -> None: # pragma: windows no cover
assert docker_is_running(), ( assert docker_is_running(), (
'Docker is either not running or not configured in this environment' 'Docker is either not running or not configured in this environment'
) )
def build_docker_image(prefix, **kwargs): # pragma: windows no cover def build_docker_image(
pull = kwargs.pop('pull') prefix: Prefix,
assert not kwargs, kwargs *,
pull: bool,
) -> None: # pragma: windows no cover
cmd: Tuple[str, ...] = ( cmd: Tuple[str, ...] = (
'docker', 'build', 'docker', 'build',
'--tag', docker_tag(prefix), '--tag', docker_tag(prefix),
@ -56,8 +62,8 @@ def build_docker_image(prefix, **kwargs): # pragma: windows no cover
def install_environment( def install_environment(
prefix, version, additional_dependencies, prefix: Prefix, version: str, additional_dependencies: Sequence[str],
): # pragma: windows no cover ) -> None: # pragma: windows no cover
helpers.assert_version_default('docker', version) helpers.assert_version_default('docker', version)
helpers.assert_no_additional_deps('docker', additional_dependencies) helpers.assert_no_additional_deps('docker', additional_dependencies)
assert_docker_available() assert_docker_available()
@ -73,14 +79,14 @@ def install_environment(
os.mkdir(directory) os.mkdir(directory)
def get_docker_user(): # pragma: windows no cover def get_docker_user() -> str: # pragma: windows no cover
try: try:
return '{}:{}'.format(os.getuid(), os.getgid()) return '{}:{}'.format(os.getuid(), os.getgid())
except AttributeError: except AttributeError:
return '1000:1000' return '1000:1000'
def docker_cmd(): # pragma: windows no cover def docker_cmd() -> Tuple[str, ...]: # pragma: windows no cover
return ( return (
'docker', 'run', 'docker', 'run',
'--rm', '--rm',
@ -93,7 +99,11 @@ def docker_cmd(): # pragma: windows no cover
) )
def run_hook(hook, file_args, color): # pragma: windows no cover def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
assert_docker_available() assert_docker_available()
# Rebuild the docker image in case it has gone missing, as many people do # Rebuild the docker image in case it has gone missing, as many people do
# automated cleanup of docker images. # automated cleanup of docker images.

View file

@ -1,7 +1,13 @@
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.languages.docker import assert_docker_available from pre_commit.languages.docker import assert_docker_available
from pre_commit.languages.docker import docker_cmd from pre_commit.languages.docker import docker_cmd
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
@ -9,7 +15,11 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install install_environment = helpers.no_install
def run_hook(hook, file_args, color): # pragma: windows no cover def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
assert_docker_available() assert_docker_available()
cmd = docker_cmd() + hook.cmd cmd = docker_cmd() + hook.cmd
return helpers.run_xargs(hook, cmd, file_args, color=color) return helpers.run_xargs(hook, cmd, file_args, color=color)

View file

@ -1,5 +1,11 @@
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.languages import helpers from pre_commit.languages import helpers
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
@ -7,7 +13,11 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install install_environment = helpers.no_install
def run_hook(hook, file_args, color): def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
out = hook.entry.encode('UTF-8') + b'\n\n' out = hook.entry.encode('UTF-8') + b'\n\n'
out += b'\n'.join(f.encode('UTF-8') for f in file_args) + b'\n' out += b'\n'.join(f.encode('UTF-8') for f in file_args) + b'\n'
return 1, out return 1, out

View file

@ -1,31 +1,39 @@
import contextlib import contextlib
import os.path import os.path
import sys import sys
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit import git from pre_commit import git
from pre_commit.envcontext import envcontext from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var from pre_commit.envcontext import Var
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
from pre_commit.util import rmtree from pre_commit.util import rmtree
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'golangenv' ENVIRONMENT_DIR = 'golangenv'
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy healthy = helpers.basic_healthy
def get_env_patch(venv): def get_env_patch(venv: str) -> PatchesT:
return ( return (
('PATH', (os.path.join(venv, 'bin'), os.pathsep, Var('PATH'))), ('PATH', (os.path.join(venv, 'bin'), os.pathsep, Var('PATH'))),
) )
@contextlib.contextmanager @contextlib.contextmanager
def in_env(prefix): def in_env(prefix: Prefix) -> Generator[None, None, None]:
envdir = prefix.path( envdir = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT), helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
) )
@ -33,7 +41,7 @@ def in_env(prefix):
yield yield
def guess_go_dir(remote_url): def guess_go_dir(remote_url: str) -> str:
if remote_url.endswith('.git'): if remote_url.endswith('.git'):
remote_url = remote_url[:-1 * len('.git')] remote_url = remote_url[:-1 * len('.git')]
looks_like_url = ( looks_like_url = (
@ -49,7 +57,11 @@ def guess_go_dir(remote_url):
return 'unknown_src_dir' return 'unknown_src_dir'
def install_environment(prefix, version, additional_dependencies): def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> None:
helpers.assert_version_default('golang', version) helpers.assert_version_default('golang', version)
directory = prefix.path( directory = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT), helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
@ -79,6 +91,10 @@ def install_environment(prefix, version, additional_dependencies):
rmtree(pkgdir) rmtree(pkgdir)
def run_hook(hook, file_args, color): def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
with in_env(hook.prefix): with in_env(hook.prefix):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color) return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,33 +1,54 @@
import multiprocessing import multiprocessing
import os import os
import random import random
from typing import Any
from typing import List
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit.prefix import Prefix
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
from pre_commit.xargs import xargs from pre_commit.xargs import xargs
if TYPE_CHECKING:
from pre_commit.repository import Hook
FIXED_RANDOM_SEED = 1542676186 FIXED_RANDOM_SEED = 1542676186
def run_setup_cmd(prefix, cmd): def run_setup_cmd(prefix: Prefix, cmd: Tuple[str, ...]) -> None:
cmd_output_b(*cmd, cwd=prefix.prefix_dir) cmd_output_b(*cmd, cwd=prefix.prefix_dir)
def environment_dir(ENVIRONMENT_DIR, language_version): @overload
if ENVIRONMENT_DIR is None: def environment_dir(d: None, language_version: str) -> None: ...
@overload
def environment_dir(d: str, language_version: str) -> str: ...
def environment_dir(d: Optional[str], language_version: str) -> Optional[str]:
if d is None:
return None return None
else: else:
return f'{ENVIRONMENT_DIR}-{language_version}' return f'{d}-{language_version}'
def assert_version_default(binary, version): def assert_version_default(binary: str, version: str) -> None:
if version != C.DEFAULT: if version != C.DEFAULT:
raise AssertionError( raise AssertionError(
f'For now, pre-commit requires system-installed {binary}', f'For now, pre-commit requires system-installed {binary}',
) )
def assert_no_additional_deps(lang, additional_deps): def assert_no_additional_deps(
lang: str,
additional_deps: Sequence[str],
) -> None:
if additional_deps: if additional_deps:
raise AssertionError( raise AssertionError(
'For now, pre-commit does not support ' 'For now, pre-commit does not support '
@ -35,19 +56,23 @@ def assert_no_additional_deps(lang, additional_deps):
) )
def basic_get_default_version(): def basic_get_default_version() -> str:
return C.DEFAULT return C.DEFAULT
def basic_healthy(prefix, language_version): def basic_healthy(prefix: Prefix, language_version: str) -> bool:
return True return True
def no_install(prefix, version, additional_dependencies): def no_install(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> NoReturn:
raise AssertionError('This type is not installable') raise AssertionError('This type is not installable')
def target_concurrency(hook): def target_concurrency(hook: 'Hook') -> int:
if hook.require_serial or 'PRE_COMMIT_NO_CONCURRENCY' in os.environ: if hook.require_serial or 'PRE_COMMIT_NO_CONCURRENCY' in os.environ:
return 1 return 1
else: else:
@ -61,8 +86,8 @@ def target_concurrency(hook):
return 1 return 1
def _shuffled(seq): def _shuffled(seq: Sequence[str]) -> List[str]:
"""Deterministically shuffle identically under both py2 + py3.""" """Deterministically shuffle"""
fixed_random = random.Random() fixed_random = random.Random()
fixed_random.seed(FIXED_RANDOM_SEED, version=1) fixed_random.seed(FIXED_RANDOM_SEED, version=1)
@ -71,7 +96,12 @@ def _shuffled(seq):
return seq return seq
def run_xargs(hook, cmd, file_args, **kwargs): def run_xargs(
hook: 'Hook',
cmd: Tuple[str, ...],
file_args: Sequence[str],
**kwargs: Any,
) -> Tuple[int, bytes]:
# Shuffle the files so that they more evenly fill out the xargs partitions, # Shuffle the files so that they more evenly fill out the xargs partitions,
# but do it deterministically in case a hook cares about ordering. # but do it deterministically in case a hook cares about ordering.
file_args = _shuffled(file_args) file_args = _shuffled(file_args)

View file

@ -1,28 +1,36 @@
import contextlib import contextlib
import os import os
import sys import sys
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit.envcontext import envcontext from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var from pre_commit.envcontext import Var
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.languages.python import bin_dir from pre_commit.languages.python import bin_dir
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'node_env' ENVIRONMENT_DIR = 'node_env'
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy healthy = helpers.basic_healthy
def _envdir(prefix, version): def _envdir(prefix: Prefix, version: str) -> str:
directory = helpers.environment_dir(ENVIRONMENT_DIR, version) directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
return prefix.path(directory) return prefix.path(directory)
def get_env_patch(venv): # pragma: windows no cover def get_env_patch(venv: str) -> PatchesT: # pragma: windows no cover
if sys.platform == 'cygwin': # pragma: no cover if sys.platform == 'cygwin': # pragma: no cover
_, win_venv, _ = cmd_output('cygpath', '-w', venv) _, win_venv, _ = cmd_output('cygpath', '-w', venv)
install_prefix = r'{}\bin'.format(win_venv.strip()) install_prefix = r'{}\bin'.format(win_venv.strip())
@ -43,14 +51,17 @@ def get_env_patch(venv): # pragma: windows no cover
@contextlib.contextmanager @contextlib.contextmanager
def in_env(prefix, language_version): # pragma: windows no cover def in_env(
prefix: Prefix,
language_version: str,
) -> Generator[None, None, None]: # pragma: windows no cover
with envcontext(get_env_patch(_envdir(prefix, language_version))): with envcontext(get_env_patch(_envdir(prefix, language_version))):
yield yield
def install_environment( def install_environment(
prefix, version, additional_dependencies, prefix: Prefix, version: str, additional_dependencies: Sequence[str],
): # pragma: windows no cover ) -> None: # pragma: windows no cover
additional_dependencies = tuple(additional_dependencies) additional_dependencies = tuple(additional_dependencies)
assert prefix.exists('package.json') assert prefix.exists('package.json')
envdir = _envdir(prefix, version) envdir = _envdir(prefix, version)
@ -76,6 +87,10 @@ def install_environment(
) )
def run_hook(hook, file_args, color): # pragma: windows no cover def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
with in_env(hook.prefix, hook.language_version): with in_env(hook.prefix, hook.language_version):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color) return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,11 +1,18 @@
import argparse import argparse
import re import re
import sys import sys
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit import output from pre_commit import output
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.xargs import xargs from pre_commit.xargs import xargs
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
@ -13,7 +20,7 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install install_environment = helpers.no_install
def _process_filename_by_line(pattern, filename): def _process_filename_by_line(pattern: Pattern[bytes], filename: str) -> int:
retv = 0 retv = 0
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
for line_no, line in enumerate(f, start=1): for line_no, line in enumerate(f, start=1):
@ -24,7 +31,7 @@ def _process_filename_by_line(pattern, filename):
return retv return retv
def _process_filename_at_once(pattern, filename): def _process_filename_at_once(pattern: Pattern[bytes], filename: str) -> int:
retv = 0 retv = 0
with open(filename, 'rb') as f: with open(filename, 'rb') as f:
contents = f.read() contents = f.read()
@ -41,12 +48,16 @@ def _process_filename_at_once(pattern, filename):
return retv return retv
def run_hook(hook, file_args, color): def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
exe = (sys.executable, '-m', __name__) + tuple(hook.args) + (hook.entry,) exe = (sys.executable, '-m', __name__) + tuple(hook.args) + (hook.entry,)
return xargs(exe, file_args, color=color) return xargs(exe, file_args, color=color)
def main(argv=None): def main(argv: Optional[Sequence[str]] = None) -> int:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=( description=(
'grep-like finder using python regexes. Unlike grep, this tool ' 'grep-like finder using python regexes. Unlike grep, this tool '

View file

@ -2,29 +2,40 @@ import contextlib
import functools import functools
import os import os
import sys import sys
from typing import Callable
from typing import ContextManager
from typing import Generator
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit.envcontext import envcontext from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import UNSET from pre_commit.envcontext import UNSET
from pre_commit.envcontext import Var from pre_commit.envcontext import Var
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.parse_shebang import find_executable from pre_commit.parse_shebang import find_executable
from pre_commit.prefix import Prefix
from pre_commit.util import CalledProcessError from pre_commit.util import CalledProcessError
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'py_env' ENVIRONMENT_DIR = 'py_env'
def bin_dir(venv): def bin_dir(venv: str) -> str:
"""On windows there's a different directory for the virtualenv""" """On windows there's a different directory for the virtualenv"""
bin_part = 'Scripts' if os.name == 'nt' else 'bin' bin_part = 'Scripts' if os.name == 'nt' else 'bin'
return os.path.join(venv, bin_part) return os.path.join(venv, bin_part)
def get_env_patch(venv): def get_env_patch(venv: str) -> PatchesT:
return ( return (
('PYTHONHOME', UNSET), ('PYTHONHOME', UNSET),
('VIRTUAL_ENV', venv), ('VIRTUAL_ENV', venv),
@ -32,7 +43,9 @@ def get_env_patch(venv):
) )
def _find_by_py_launcher(version): # pragma: no cover (windows only) def _find_by_py_launcher(
version: str,
) -> Optional[str]: # pragma: no cover (windows only)
if version.startswith('python'): if version.startswith('python'):
try: try:
return cmd_output( return cmd_output(
@ -41,14 +54,16 @@ def _find_by_py_launcher(version): # pragma: no cover (windows only)
)[1].strip() )[1].strip()
except CalledProcessError: except CalledProcessError:
pass pass
return None
def _find_by_sys_executable(): def _find_by_sys_executable() -> Optional[str]:
def _norm(path): def _norm(path: str) -> Optional[str]:
_, exe = os.path.split(path.lower()) _, exe = os.path.split(path.lower())
exe, _, _ = exe.partition('.exe') exe, _, _ = exe.partition('.exe')
if find_executable(exe) and exe not in {'python', 'pythonw'}: if find_executable(exe) and exe not in {'python', 'pythonw'}:
return exe return exe
return None
# On linux, I see these common sys.executables: # On linux, I see these common sys.executables:
# #
@ -66,7 +81,7 @@ def _find_by_sys_executable():
@functools.lru_cache(maxsize=1) @functools.lru_cache(maxsize=1)
def get_default_version(): # pragma: no cover (platform dependent) def get_default_version() -> str: # pragma: no cover (platform dependent)
# First attempt from `sys.executable` (or the realpath) # First attempt from `sys.executable` (or the realpath)
exe = _find_by_sys_executable() exe = _find_by_sys_executable()
if exe: if exe:
@ -88,7 +103,7 @@ def get_default_version(): # pragma: no cover (platform dependent)
return C.DEFAULT return C.DEFAULT
def _sys_executable_matches(version): def _sys_executable_matches(version: str) -> bool:
if version == 'python': if version == 'python':
return True return True
elif not version.startswith('python'): elif not version.startswith('python'):
@ -102,7 +117,7 @@ def _sys_executable_matches(version):
return sys.version_info[:len(info)] == info return sys.version_info[:len(info)] == info
def norm_version(version): def norm_version(version: str) -> str:
# first see if our current executable is appropriate # first see if our current executable is appropriate
if _sys_executable_matches(version): if _sys_executable_matches(version):
return sys.executable return sys.executable
@ -126,14 +141,25 @@ def norm_version(version):
return os.path.expanduser(version) return os.path.expanduser(version)
def py_interface(_dir, _make_venv): def py_interface(
_dir: str,
_make_venv: Callable[[str, str], None],
) -> Tuple[
Callable[[Prefix, str], ContextManager[None]],
Callable[[Prefix, str], bool],
Callable[['Hook', Sequence[str], bool], Tuple[int, bytes]],
Callable[[Prefix, str, Sequence[str]], None],
]:
@contextlib.contextmanager @contextlib.contextmanager
def in_env(prefix, language_version): def in_env(
prefix: Prefix,
language_version: str,
) -> Generator[None, None, None]:
envdir = prefix.path(helpers.environment_dir(_dir, language_version)) envdir = prefix.path(helpers.environment_dir(_dir, language_version))
with envcontext(get_env_patch(envdir)): with envcontext(get_env_patch(envdir)):
yield yield
def healthy(prefix, language_version): def healthy(prefix: Prefix, language_version: str) -> bool:
with in_env(prefix, language_version): with in_env(prefix, language_version):
retcode, _, _ = cmd_output_b( retcode, _, _ = cmd_output_b(
'python', '-c', 'python', '-c',
@ -143,11 +169,19 @@ def py_interface(_dir, _make_venv):
) )
return retcode == 0 return retcode == 0
def run_hook(hook, file_args, color): def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
with in_env(hook.prefix, hook.language_version): with in_env(hook.prefix, hook.language_version):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color) return helpers.run_xargs(hook, hook.cmd, file_args, color=color)
def install_environment(prefix, version, additional_dependencies): def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> None:
additional_dependencies = tuple(additional_dependencies) additional_dependencies = tuple(additional_dependencies)
directory = helpers.environment_dir(_dir, version) directory = helpers.environment_dir(_dir, version)
@ -166,7 +200,7 @@ def py_interface(_dir, _make_venv):
return in_env, healthy, run_hook, install_environment return in_env, healthy, run_hook, install_environment
def make_venv(envdir, python): def make_venv(envdir: str, python: str) -> None:
env = dict(os.environ, VIRTUALENV_NO_DOWNLOAD='1') env = dict(os.environ, VIRTUALENV_NO_DOWNLOAD='1')
cmd = (sys.executable, '-mvirtualenv', envdir, '-p', python) cmd = (sys.executable, '-mvirtualenv', envdir, '-p', python)
cmd_output_b(*cmd, env=env, cwd='/') cmd_output_b(*cmd, env=env, cwd='/')

View file

@ -5,15 +5,11 @@ from pre_commit.util import CalledProcessError
from pre_commit.util import cmd_output from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
ENVIRONMENT_DIR = 'py_venv' ENVIRONMENT_DIR = 'py_venv'
get_default_version = python.get_default_version
def get_default_version(): # pragma: no cover (version specific) def orig_py_exe(exe: str) -> str: # pragma: no cover (platform specific)
return python.get_default_version()
def orig_py_exe(exe): # pragma: no cover (platform specific)
"""A -mvenv virtualenv made from a -mvirtualenv virtualenv installs """A -mvenv virtualenv made from a -mvirtualenv virtualenv installs
packages to the incorrect location. Attempt to find the _original_ exe packages to the incorrect location. Attempt to find the _original_ exe
and invoke `-mvenv` from there. and invoke `-mvenv` from there.
@ -42,7 +38,7 @@ def orig_py_exe(exe): # pragma: no cover (platform specific)
return exe return exe
def make_venv(envdir, python): def make_venv(envdir: str, python: str) -> None:
cmd_output_b(orig_py_exe(python), '-mvenv', envdir, cwd='/') cmd_output_b(orig_py_exe(python), '-mvenv', envdir, cwd='/')

View file

@ -2,23 +2,33 @@ import contextlib
import os.path import os.path
import shutil import shutil
import tarfile import tarfile
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit.envcontext import envcontext from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var from pre_commit.envcontext import Var
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import CalledProcessError from pre_commit.util import CalledProcessError
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import resource_bytesio from pre_commit.util import resource_bytesio
if TYPE_CHECKING:
from pre_comit.repository import Hook
ENVIRONMENT_DIR = 'rbenv' ENVIRONMENT_DIR = 'rbenv'
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy healthy = helpers.basic_healthy
def get_env_patch(venv, language_version): # pragma: windows no cover def get_env_patch(
venv: str,
language_version: str,
) -> PatchesT: # pragma: windows no cover
patches: PatchesT = ( patches: PatchesT = (
('GEM_HOME', os.path.join(venv, 'gems')), ('GEM_HOME', os.path.join(venv, 'gems')),
('RBENV_ROOT', venv), ('RBENV_ROOT', venv),
@ -36,8 +46,11 @@ def get_env_patch(venv, language_version): # pragma: windows no cover
return patches return patches
@contextlib.contextmanager @contextlib.contextmanager # pragma: windows no cover
def in_env(prefix, language_version): # pragma: windows no cover def in_env(
prefix: Prefix,
language_version: str,
) -> Generator[None, None, None]:
envdir = prefix.path( envdir = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, language_version), helpers.environment_dir(ENVIRONMENT_DIR, language_version),
) )
@ -45,13 +58,16 @@ def in_env(prefix, language_version): # pragma: windows no cover
yield yield
def _extract_resource(filename, dest): def _extract_resource(filename: str, dest: str) -> None:
with resource_bytesio(filename) as bio: with resource_bytesio(filename) as bio:
with tarfile.open(fileobj=bio) as tf: with tarfile.open(fileobj=bio) as tf:
tf.extractall(dest) tf.extractall(dest)
def _install_rbenv(prefix, version=C.DEFAULT): # pragma: windows no cover def _install_rbenv(
prefix: Prefix,
version: str = C.DEFAULT,
) -> None: # pragma: windows no cover
directory = helpers.environment_dir(ENVIRONMENT_DIR, version) directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
_extract_resource('rbenv.tar.gz', prefix.path('.')) _extract_resource('rbenv.tar.gz', prefix.path('.'))
@ -87,7 +103,10 @@ def _install_rbenv(prefix, version=C.DEFAULT): # pragma: windows no cover
activate_file.write(f'export RBENV_VERSION="{version}"\n') activate_file.write(f'export RBENV_VERSION="{version}"\n')
def _install_ruby(prefix, version): # pragma: windows no cover def _install_ruby(
prefix: Prefix,
version: str,
) -> None: # pragma: windows no cover
try: try:
helpers.run_setup_cmd(prefix, ('rbenv', 'download', version)) helpers.run_setup_cmd(prefix, ('rbenv', 'download', version))
except CalledProcessError: # pragma: no cover (usually find with download) except CalledProcessError: # pragma: no cover (usually find with download)
@ -96,8 +115,8 @@ def _install_ruby(prefix, version): # pragma: windows no cover
def install_environment( def install_environment(
prefix, version, additional_dependencies, prefix: Prefix, version: str, additional_dependencies: Sequence[str],
): # pragma: windows no cover ) -> None: # pragma: windows no cover
additional_dependencies = tuple(additional_dependencies) additional_dependencies = tuple(additional_dependencies)
directory = helpers.environment_dir(ENVIRONMENT_DIR, version) directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
with clean_path_on_failure(prefix.path(directory)): with clean_path_on_failure(prefix.path(directory)):
@ -122,6 +141,10 @@ def install_environment(
) )
def run_hook(hook, file_args, color): # pragma: windows no cover def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
with in_env(hook.prefix, hook.language_version): with in_env(hook.prefix, hook.language_version):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color) return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,24 +1,31 @@
import contextlib import contextlib
import os.path import os.path
from typing import Generator
from typing import Sequence
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import TYPE_CHECKING
import toml import toml
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit.envcontext import envcontext from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var from pre_commit.envcontext import Var
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'rustenv' ENVIRONMENT_DIR = 'rustenv'
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy healthy = helpers.basic_healthy
def get_env_patch(target_dir): def get_env_patch(target_dir: str) -> PatchesT:
return ( return (
( (
'PATH', 'PATH',
@ -28,7 +35,7 @@ def get_env_patch(target_dir):
@contextlib.contextmanager @contextlib.contextmanager
def in_env(prefix): def in_env(prefix: Prefix) -> Generator[None, None, None]:
target_dir = prefix.path( target_dir = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT), helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
) )
@ -36,7 +43,10 @@ def in_env(prefix):
yield yield
def _add_dependencies(cargo_toml_path, additional_dependencies): def _add_dependencies(
cargo_toml_path: str,
additional_dependencies: Set[str],
) -> None:
with open(cargo_toml_path, 'r+') as f: with open(cargo_toml_path, 'r+') as f:
cargo_toml = toml.load(f) cargo_toml = toml.load(f)
cargo_toml.setdefault('dependencies', {}) cargo_toml.setdefault('dependencies', {})
@ -48,7 +58,11 @@ def _add_dependencies(cargo_toml_path, additional_dependencies):
f.truncate() f.truncate()
def install_environment(prefix, version, additional_dependencies): def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Sequence[str],
) -> None:
helpers.assert_version_default('rust', version) helpers.assert_version_default('rust', version)
directory = prefix.path( directory = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT), helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
@ -82,13 +96,17 @@ def install_environment(prefix, version, additional_dependencies):
else: else:
packages_to_install.add((package,)) packages_to_install.add((package,))
for package in packages_to_install: for args in packages_to_install:
cmd_output_b( cmd_output_b(
'cargo', 'install', '--bins', '--root', directory, *package, 'cargo', 'install', '--bins', '--root', directory, *args,
cwd=prefix.prefix_dir, cwd=prefix.prefix_dir,
) )
def run_hook(hook, file_args, color): def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
with in_env(hook.prefix): with in_env(hook.prefix):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color) return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,5 +1,11 @@
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.languages import helpers from pre_commit.languages import helpers
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
@ -7,7 +13,11 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install install_environment = helpers.no_install
def run_hook(hook, file_args, color): def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
cmd = hook.cmd cmd = hook.cmd
cmd = (hook.prefix.path(cmd[0]),) + cmd[1:] cmd = (hook.prefix.path(cmd[0]),) + cmd[1:]
return helpers.run_xargs(hook, cmd, file_args, color=color) return helpers.run_xargs(hook, cmd, file_args, color=color)

View file

@ -1,13 +1,22 @@
import contextlib import contextlib
import os import os
from typing import Generator
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit.envcontext import envcontext from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var from pre_commit.envcontext import Var
from pre_commit.languages import helpers from pre_commit.languages import helpers
from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = 'swift_env' ENVIRONMENT_DIR = 'swift_env'
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
healthy = helpers.basic_healthy healthy = helpers.basic_healthy
@ -15,13 +24,13 @@ BUILD_DIR = '.build'
BUILD_CONFIG = 'release' BUILD_CONFIG = 'release'
def get_env_patch(venv): # pragma: windows no cover def get_env_patch(venv: str) -> PatchesT: # pragma: windows no cover
bin_path = os.path.join(venv, BUILD_DIR, BUILD_CONFIG) bin_path = os.path.join(venv, BUILD_DIR, BUILD_CONFIG)
return (('PATH', (bin_path, os.pathsep, Var('PATH'))),) return (('PATH', (bin_path, os.pathsep, Var('PATH'))),)
@contextlib.contextmanager @contextlib.contextmanager # pragma: windows no cover
def in_env(prefix): # pragma: windows no cover def in_env(prefix: Prefix) -> Generator[None, None, None]:
envdir = prefix.path( envdir = prefix.path(
helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT), helpers.environment_dir(ENVIRONMENT_DIR, C.DEFAULT),
) )
@ -30,8 +39,8 @@ def in_env(prefix): # pragma: windows no cover
def install_environment( def install_environment(
prefix, version, additional_dependencies, prefix: Prefix, version: str, additional_dependencies: Sequence[str],
): # pragma: windows no cover ) -> None: # pragma: windows no cover
helpers.assert_version_default('swift', version) helpers.assert_version_default('swift', version)
helpers.assert_no_additional_deps('swift', additional_dependencies) helpers.assert_no_additional_deps('swift', additional_dependencies)
directory = prefix.path( directory = prefix.path(
@ -49,6 +58,10 @@ def install_environment(
) )
def run_hook(hook, file_args, color): # pragma: windows no cover def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]: # pragma: windows no cover
with in_env(hook.prefix): with in_env(hook.prefix):
return helpers.run_xargs(hook, hook.cmd, file_args, color=color) return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,5 +1,12 @@
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from pre_commit.languages import helpers from pre_commit.languages import helpers
if TYPE_CHECKING:
from pre_commit.repository import Hook
ENVIRONMENT_DIR = None ENVIRONMENT_DIR = None
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
@ -7,5 +14,9 @@ healthy = helpers.basic_healthy
install_environment = helpers.no_install install_environment = helpers.no_install
def run_hook(hook, file_args, color): def run_hook(
hook: 'Hook',
file_args: Sequence[str],
color: bool,
) -> Tuple[int, bytes]:
return helpers.run_xargs(hook, hook.cmd, file_args, color=color) return helpers.run_xargs(hook, hook.cmd, file_args, color=color)

View file

@ -1,10 +1,10 @@
import contextlib import contextlib
import logging import logging
from typing import Generator
from pre_commit import color from pre_commit import color
from pre_commit import output from pre_commit import output
logger = logging.getLogger('pre_commit') logger = logging.getLogger('pre_commit')
LOG_LEVEL_COLORS = { LOG_LEVEL_COLORS = {
@ -16,11 +16,11 @@ LOG_LEVEL_COLORS = {
class LoggingHandler(logging.Handler): class LoggingHandler(logging.Handler):
def __init__(self, use_color): def __init__(self, use_color: bool) -> None:
super().__init__() super().__init__()
self.use_color = use_color self.use_color = use_color
def emit(self, record): def emit(self, record: logging.LogRecord) -> None:
output.write_line( output.write_line(
'{} {}'.format( '{} {}'.format(
color.format_color( color.format_color(
@ -34,8 +34,8 @@ class LoggingHandler(logging.Handler):
@contextlib.contextmanager @contextlib.contextmanager
def logging_handler(*args, **kwargs): def logging_handler(use_color: bool) -> Generator[None, None, None]:
handler = LoggingHandler(*args, **kwargs) handler = LoggingHandler(use_color)
logger.addHandler(handler) logger.addHandler(handler)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
try: try:

View file

@ -2,6 +2,10 @@ import argparse
import logging import logging
import os import os
import sys import sys
from typing import Any
from typing import Optional
from typing import Sequence
from typing import Union
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit import color from pre_commit import color
@ -37,7 +41,7 @@ os.environ.pop('__PYVENV_LAUNCHER__', None)
COMMANDS_NO_GIT = {'clean', 'gc', 'init-templatedir', 'sample-config'} COMMANDS_NO_GIT = {'clean', 'gc', 'init-templatedir', 'sample-config'}
def _add_color_option(parser): def _add_color_option(parser: argparse.ArgumentParser) -> None:
parser.add_argument( parser.add_argument(
'--color', default=os.environ.get('PRE_COMMIT_COLOR', 'auto'), '--color', default=os.environ.get('PRE_COMMIT_COLOR', 'auto'),
type=color.use_color, type=color.use_color,
@ -46,7 +50,7 @@ def _add_color_option(parser):
) )
def _add_config_option(parser): def _add_config_option(parser: argparse.ArgumentParser) -> None:
parser.add_argument( parser.add_argument(
'-c', '--config', default=C.CONFIG_FILE, '-c', '--config', default=C.CONFIG_FILE,
help='Path to alternate config file', help='Path to alternate config file',
@ -54,18 +58,24 @@ def _add_config_option(parser):
class AppendReplaceDefault(argparse.Action): class AppendReplaceDefault(argparse.Action):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.appended = False self.appended = False
def __call__(self, parser, namespace, values, option_string=None): def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Union[str, Sequence[str], None],
option_string: Optional[str] = None,
) -> None:
if not self.appended: if not self.appended:
setattr(namespace, self.dest, []) setattr(namespace, self.dest, [])
self.appended = True self.appended = True
getattr(namespace, self.dest).append(values) getattr(namespace, self.dest).append(values)
def _add_hook_type_option(parser): def _add_hook_type_option(parser: argparse.ArgumentParser) -> None:
parser.add_argument( parser.add_argument(
'-t', '--hook-type', choices=( '-t', '--hook-type', choices=(
'pre-commit', 'pre-merge-commit', 'pre-push', 'pre-commit', 'pre-merge-commit', 'pre-push',
@ -77,7 +87,7 @@ def _add_hook_type_option(parser):
) )
def _add_run_options(parser): def _add_run_options(parser: argparse.ArgumentParser) -> None:
parser.add_argument('hook', nargs='?', help='A single hook-id to run') parser.add_argument('hook', nargs='?', help='A single hook-id to run')
parser.add_argument('--verbose', '-v', action='store_true', default=False) parser.add_argument('--verbose', '-v', action='store_true', default=False)
parser.add_argument( parser.add_argument(
@ -111,7 +121,7 @@ def _add_run_options(parser):
) )
def _adjust_args_and_chdir(args): def _adjust_args_and_chdir(args: argparse.Namespace) -> None:
# `--config` was specified relative to the non-root working directory # `--config` was specified relative to the non-root working directory
if os.path.exists(args.config): if os.path.exists(args.config):
args.config = os.path.abspath(args.config) args.config = os.path.abspath(args.config)
@ -143,7 +153,7 @@ def _adjust_args_and_chdir(args):
args.repo = os.path.relpath(args.repo) args.repo = os.path.relpath(args.repo)
def main(argv=None): def main(argv: Optional[Sequence[str]] = None) -> int:
argv = argv if argv is not None else sys.argv[1:] argv = argv if argv is not None else sys.argv[1:]
argv = [five.to_text(arg) for arg in argv] argv = [five.to_text(arg) for arg in argv]
parser = argparse.ArgumentParser(prog='pre-commit') parser = argparse.ArgumentParser(prog='pre-commit')

View file

@ -1,6 +1,8 @@
import argparse import argparse
import os.path import os.path
import tarfile import tarfile
from typing import Optional
from typing import Sequence
from pre_commit import output from pre_commit import output
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
@ -23,7 +25,7 @@ REPOS = (
) )
def make_archive(name, repo, ref, destdir): def make_archive(name: str, repo: str, ref: str, destdir: str) -> str:
"""Makes an archive of a repository in the given destdir. """Makes an archive of a repository in the given destdir.
:param text name: Name to give the archive. For instance foo. The file :param text name: Name to give the archive. For instance foo. The file
@ -49,7 +51,7 @@ def make_archive(name, repo, ref, destdir):
return output_path return output_path
def main(argv=None): def main(argv: Optional[Sequence[str]] = None) -> int:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--dest', default='pre_commit/resources') parser.add_argument('--dest', default='pre_commit/resources')
args = parser.parse_args(argv) args = parser.parse_args(argv)
@ -58,6 +60,7 @@ def main(argv=None):
f'Making {archive_name}.tar.gz for {repo}@{ref}', f'Making {archive_name}.tar.gz for {repo}@{ref}',
) )
make_archive(archive_name, repo, ref, args.dest) make_archive(archive_name, repo, ref, args.dest)
return 0
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -1,4 +1,6 @@
import argparse import argparse
from typing import Optional
from typing import Sequence
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit import git from pre_commit import git
@ -8,7 +10,7 @@ from pre_commit.repository import all_hooks
from pre_commit.store import Store from pre_commit.store import Store
def check_all_hooks_match_files(config_file): def check_all_hooks_match_files(config_file: str) -> int:
classifier = Classifier(git.get_all_files()) classifier = Classifier(git.get_all_files())
retv = 0 retv = 0
@ -22,7 +24,7 @@ def check_all_hooks_match_files(config_file):
return retv return retv
def main(argv=None): def main(argv: Optional[Sequence[str]] = None) -> int:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', default=[C.CONFIG_FILE]) parser.add_argument('filenames', nargs='*', default=[C.CONFIG_FILE])
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -1,5 +1,7 @@
import argparse import argparse
import re import re
from typing import Optional
from typing import Sequence
from cfgv import apply_defaults from cfgv import apply_defaults
@ -10,7 +12,11 @@ from pre_commit.clientlib import MANIFEST_HOOK_DICT
from pre_commit.commands.run import Classifier from pre_commit.commands.run import Classifier
def exclude_matches_any(filenames, include, exclude): def exclude_matches_any(
filenames: Sequence[str],
include: str,
exclude: str,
) -> bool:
if exclude == '^$': if exclude == '^$':
return True return True
include_re, exclude_re = re.compile(include), re.compile(exclude) include_re, exclude_re = re.compile(include), re.compile(exclude)
@ -20,7 +26,7 @@ def exclude_matches_any(filenames, include, exclude):
return False return False
def check_useless_excludes(config_file): def check_useless_excludes(config_file: str) -> int:
config = load_config(config_file) config = load_config(config_file)
classifier = Classifier(git.get_all_files()) classifier = Classifier(git.get_all_files())
retv = 0 retv = 0
@ -52,7 +58,7 @@ def check_useless_excludes(config_file):
return retv return retv
def main(argv=None): def main(argv: Optional[Sequence[str]] = None) -> int:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', default=[C.CONFIG_FILE]) parser.add_argument('filenames', nargs='*', default=[C.CONFIG_FILE])
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -1,12 +1,15 @@
import sys import sys
from typing import Optional
from typing import Sequence
from pre_commit import output from pre_commit import output
def main(argv=None): def main(argv: Optional[Sequence[str]] = None) -> int:
argv = argv if argv is not None else sys.argv[1:] argv = argv if argv is not None else sys.argv[1:]
for arg in argv: for arg in argv:
output.write_line(arg) output.write_line(arg)
return 0
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -1,19 +1,22 @@
import contextlib import contextlib
import sys import sys
from typing import IO
from typing import Optional
from typing import Union
from pre_commit import color from pre_commit import color
from pre_commit import five from pre_commit import five
def get_hook_message( def get_hook_message(
start, start: str,
postfix='', postfix: str = '',
end_msg=None, end_msg: Optional[str] = None,
end_len=0, end_len: int = 0,
end_color=None, end_color: Optional[str] = None,
use_color=None, use_color: Optional[bool] = None,
cols=80, cols: int = 80,
): ) -> str:
"""Prints a message for running a hook. """Prints a message for running a hook.
This currently supports three approaches: This currently supports three approaches:
@ -44,16 +47,13 @@ def get_hook_message(
) )
start...........................................................postfix end start...........................................................postfix end
""" """
if bool(end_msg) == bool(end_len):
raise ValueError('Expected one of (`end_msg`, `end_len`)')
if end_msg is not None and (end_color is None or use_color is None):
raise ValueError(
'`end_color` and `use_color` are required with `end_msg`',
)
if end_len: if end_len:
assert end_msg is None, end_msg
return start + '.' * (cols - len(start) - end_len - 1) return start + '.' * (cols - len(start) - end_len - 1)
else: else:
assert end_msg is not None
assert end_color is not None
assert use_color is not None
return '{}{}{}{}\n'.format( return '{}{}{}{}\n'.format(
start, start,
'.' * (cols - len(start) - len(postfix) - len(end_msg) - 1), '.' * (cols - len(start) - len(postfix) - len(end_msg) - 1),
@ -62,15 +62,16 @@ def get_hook_message(
) )
stdout_byte_stream = getattr(sys.stdout, 'buffer', sys.stdout) def write(s: str, stream: IO[bytes] = sys.stdout.buffer) -> None:
def write(s, stream=stdout_byte_stream):
stream.write(five.to_bytes(s)) stream.write(five.to_bytes(s))
stream.flush() stream.flush()
def write_line(s=None, stream=stdout_byte_stream, logfile_name=None): def write_line(
s: Union[None, str, bytes] = None,
stream: IO[bytes] = sys.stdout.buffer,
logfile_name: Optional[str] = None,
) -> None:
with contextlib.ExitStack() as exit_stack: with contextlib.ExitStack() as exit_stack:
output_streams = [stream] output_streams = [stream]
if logfile_name: if logfile_name:

View file

@ -1,21 +1,28 @@
import os.path import os.path
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import Tuple
from identify.identify import parse_shebang_from_file from identify.identify import parse_shebang_from_file
class ExecutableNotFoundError(OSError): class ExecutableNotFoundError(OSError):
def to_output(self): def to_output(self) -> Tuple[int, bytes, None]:
return (1, self.args[0].encode('UTF-8'), b'') return (1, self.args[0].encode('UTF-8'), None)
def parse_filename(filename): def parse_filename(filename: str) -> Tuple[str, ...]:
if not os.path.exists(filename): if not os.path.exists(filename):
return () return ()
else: else:
return parse_shebang_from_file(filename) return parse_shebang_from_file(filename)
def find_executable(exe, _environ=None): def find_executable(
exe: str,
_environ: Optional[Mapping[str, str]] = None,
) -> Optional[str]:
exe = os.path.normpath(exe) exe = os.path.normpath(exe)
if os.sep in exe: if os.sep in exe:
return exe return exe
@ -39,8 +46,8 @@ def find_executable(exe, _environ=None):
return None return None
def normexe(orig): def normexe(orig: str) -> str:
def _error(msg): def _error(msg: str) -> NoReturn:
raise ExecutableNotFoundError(f'Executable `{orig}` {msg}') raise ExecutableNotFoundError(f'Executable `{orig}` {msg}')
if os.sep not in orig and (not os.altsep or os.altsep not in orig): if os.sep not in orig and (not os.altsep or os.altsep not in orig):
@ -58,7 +65,7 @@ def normexe(orig):
return orig return orig
def normalize_cmd(cmd): def normalize_cmd(cmd: Tuple[str, ...]) -> Tuple[str, ...]:
"""Fixes for the following issues on windows """Fixes for the following issues on windows
- https://bugs.python.org/issue8557 - https://bugs.python.org/issue8557
- windows does not parse shebangs - windows does not parse shebangs

View file

@ -1,16 +1,17 @@
import collections
import os.path import os.path
from typing import NamedTuple
from typing import Tuple
class Prefix(collections.namedtuple('Prefix', ('prefix_dir',))): class Prefix(NamedTuple):
__slots__ = () prefix_dir: str
def path(self, *parts): def path(self, *parts: str) -> str:
return os.path.normpath(os.path.join(self.prefix_dir, *parts)) return os.path.normpath(os.path.join(self.prefix_dir, *parts))
def exists(self, *parts): def exists(self, *parts: str) -> bool:
return os.path.exists(self.path(*parts)) return os.path.exists(self.path(*parts))
def star(self, end): def star(self, end: str) -> Tuple[str, ...]:
paths = os.listdir(self.prefix_dir) paths = os.listdir(self.prefix_dir)
return tuple(path for path in paths if path.endswith(end)) return tuple(path for path in paths if path.endswith(end))

View file

@ -2,9 +2,14 @@ import json
import logging import logging
import os import os
import shlex import shlex
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple from typing import NamedTuple
from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Set from typing import Set
from typing import Tuple
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit import five from pre_commit import five
@ -15,6 +20,7 @@ from pre_commit.clientlib import META
from pre_commit.languages.all import languages from pre_commit.languages.all import languages
from pre_commit.languages.helpers import environment_dir from pre_commit.languages.helpers import environment_dir
from pre_commit.prefix import Prefix from pre_commit.prefix import Prefix
from pre_commit.store import Store
from pre_commit.util import parse_version from pre_commit.util import parse_version
from pre_commit.util import rmtree from pre_commit.util import rmtree
@ -22,15 +28,15 @@ from pre_commit.util import rmtree
logger = logging.getLogger('pre_commit') logger = logging.getLogger('pre_commit')
def _state(additional_deps): def _state(additional_deps: Sequence[str]) -> object:
return {'additional_dependencies': sorted(additional_deps)} return {'additional_dependencies': sorted(additional_deps)}
def _state_filename(prefix, venv): def _state_filename(prefix: Prefix, venv: str) -> str:
return prefix.path(venv, '.install_state_v' + C.INSTALLED_STATE_VERSION) return prefix.path(venv, '.install_state_v' + C.INSTALLED_STATE_VERSION)
def _read_state(prefix, venv): def _read_state(prefix: Prefix, venv: str) -> Optional[object]:
filename = _state_filename(prefix, venv) filename = _state_filename(prefix, venv)
if not os.path.exists(filename): if not os.path.exists(filename):
return None return None
@ -39,7 +45,7 @@ def _read_state(prefix, venv):
return json.load(f) return json.load(f)
def _write_state(prefix, venv, state): def _write_state(prefix: Prefix, venv: str, state: object) -> None:
state_filename = _state_filename(prefix, venv) state_filename = _state_filename(prefix, venv)
staging = state_filename + 'staging' staging = state_filename + 'staging'
with open(staging, 'w') as state_file: with open(staging, 'w') as state_file:
@ -76,11 +82,11 @@ class Hook(NamedTuple):
verbose: bool verbose: bool
@property @property
def cmd(self): def cmd(self) -> Tuple[str, ...]:
return tuple(shlex.split(self.entry)) + tuple(self.args) return tuple(shlex.split(self.entry)) + tuple(self.args)
@property @property
def install_key(self): def install_key(self) -> Tuple[Prefix, str, str, Tuple[str, ...]]:
return ( return (
self.prefix, self.prefix,
self.language, self.language,
@ -88,7 +94,7 @@ class Hook(NamedTuple):
tuple(self.additional_dependencies), tuple(self.additional_dependencies),
) )
def installed(self): def installed(self) -> bool:
lang = languages[self.language] lang = languages[self.language]
venv = environment_dir(lang.ENVIRONMENT_DIR, self.language_version) venv = environment_dir(lang.ENVIRONMENT_DIR, self.language_version)
return ( return (
@ -101,7 +107,7 @@ class Hook(NamedTuple):
) )
) )
def install(self): def install(self) -> None:
logger.info(f'Installing environment for {self.src}.') logger.info(f'Installing environment for {self.src}.')
logger.info('Once installed this environment will be reused.') logger.info('Once installed this environment will be reused.')
logger.info('This may take a few minutes...') logger.info('This may take a few minutes...')
@ -120,12 +126,12 @@ class Hook(NamedTuple):
# Write our state to indicate we're installed # Write our state to indicate we're installed
_write_state(self.prefix, venv, _state(self.additional_dependencies)) _write_state(self.prefix, venv, _state(self.additional_dependencies))
def run(self, file_args, color): def run(self, file_args: Sequence[str], color: bool) -> Tuple[int, bytes]:
lang = languages[self.language] lang = languages[self.language]
return lang.run_hook(self, file_args, color) return lang.run_hook(self, file_args, color)
@classmethod @classmethod
def create(cls, src, prefix, dct): def create(cls, src: str, prefix: Prefix, dct: Dict[str, Any]) -> 'Hook':
# TODO: have cfgv do this (?) # TODO: have cfgv do this (?)
extra_keys = set(dct) - set(_KEYS) extra_keys = set(dct) - set(_KEYS)
if extra_keys: if extra_keys:
@ -136,9 +142,10 @@ class Hook(NamedTuple):
return cls(src=src, prefix=prefix, **{k: dct[k] for k in _KEYS}) return cls(src=src, prefix=prefix, **{k: dct[k] for k in _KEYS})
def _hook(*hook_dicts, **kwargs): def _hook(
root_config = kwargs.pop('root_config') *hook_dicts: Dict[str, Any],
assert not kwargs, kwargs root_config: Dict[str, Any],
) -> Dict[str, Any]:
ret, rest = dict(hook_dicts[0]), hook_dicts[1:] ret, rest = dict(hook_dicts[0]), hook_dicts[1:]
for dct in rest: for dct in rest:
ret.update(dct) ret.update(dct)
@ -166,8 +173,12 @@ def _hook(*hook_dicts, **kwargs):
return ret return ret
def _non_cloned_repository_hooks(repo_config, store, root_config): def _non_cloned_repository_hooks(
def _prefix(language_name, deps): repo_config: Dict[str, Any],
store: Store,
root_config: Dict[str, Any],
) -> Tuple[Hook, ...]:
def _prefix(language_name: str, deps: Sequence[str]) -> Prefix:
language = languages[language_name] language = languages[language_name]
# pygrep / script / system / docker_image do not have # pygrep / script / system / docker_image do not have
# environments so they work out of the current directory # environments so they work out of the current directory
@ -186,7 +197,11 @@ def _non_cloned_repository_hooks(repo_config, store, root_config):
) )
def _cloned_repository_hooks(repo_config, store, root_config): def _cloned_repository_hooks(
repo_config: Dict[str, Any],
store: Store,
root_config: Dict[str, Any],
) -> Tuple[Hook, ...]:
repo, rev = repo_config['repo'], repo_config['rev'] repo, rev = repo_config['repo'], repo_config['rev']
manifest_path = os.path.join(store.clone(repo, rev), C.MANIFEST_FILE) manifest_path = os.path.join(store.clone(repo, rev), C.MANIFEST_FILE)
by_id = {hook['id']: hook for hook in load_manifest(manifest_path)} by_id = {hook['id']: hook for hook in load_manifest(manifest_path)}
@ -215,16 +230,20 @@ def _cloned_repository_hooks(repo_config, store, root_config):
) )
def _repository_hooks(repo_config, store, root_config): def _repository_hooks(
repo_config: Dict[str, Any],
store: Store,
root_config: Dict[str, Any],
) -> Tuple[Hook, ...]:
if repo_config['repo'] in {LOCAL, META}: if repo_config['repo'] in {LOCAL, META}:
return _non_cloned_repository_hooks(repo_config, store, root_config) return _non_cloned_repository_hooks(repo_config, store, root_config)
else: else:
return _cloned_repository_hooks(repo_config, store, root_config) return _cloned_repository_hooks(repo_config, store, root_config)
def install_hook_envs(hooks, store): def install_hook_envs(hooks: Sequence[Hook], store: Store) -> None:
def _need_installed(): def _need_installed() -> List[Hook]:
seen: Set[Hook] = set() seen: Set[Tuple[Prefix, str, str, Tuple[str, ...]]] = set()
ret = [] ret = []
for hook in hooks: for hook in hooks:
if hook.install_key not in seen and not hook.installed(): if hook.install_key not in seen and not hook.installed():
@ -240,7 +259,7 @@ def install_hook_envs(hooks, store):
hook.install() hook.install()
def all_hooks(root_config, store): def all_hooks(root_config: Dict[str, Any], store: Store) -> Tuple[Hook, ...]:
return tuple( return tuple(
hook hook
for repo in root_config['repos'] for repo in root_config['repos']

View file

@ -4,6 +4,8 @@ import distutils.spawn
import os import os
import subprocess import subprocess
import sys import sys
from typing import Callable
from typing import Dict
from typing import Tuple from typing import Tuple
# work around https://github.com/Homebrew/homebrew-core/issues/30445 # work around https://github.com/Homebrew/homebrew-core/issues/30445
@ -28,7 +30,7 @@ class FatalError(RuntimeError):
pass pass
def _norm_exe(exe): def _norm_exe(exe: str) -> Tuple[str, ...]:
"""Necessary for shebang support on windows. """Necessary for shebang support on windows.
roughly lifted from `identify.identify.parse_shebang` roughly lifted from `identify.identify.parse_shebang`
@ -47,7 +49,7 @@ def _norm_exe(exe):
return tuple(cmd) return tuple(cmd)
def _run_legacy(): def _run_legacy() -> Tuple[int, bytes]:
if __file__.endswith('.legacy'): if __file__.endswith('.legacy'):
raise SystemExit( raise SystemExit(
"bug: pre-commit's script is installed in migration mode\n" "bug: pre-commit's script is installed in migration mode\n"
@ -59,9 +61,9 @@ def _run_legacy():
) )
if HOOK_TYPE == 'pre-push': if HOOK_TYPE == 'pre-push':
stdin = getattr(sys.stdin, 'buffer', sys.stdin).read() stdin = sys.stdin.buffer.read()
else: else:
stdin = None stdin = b''
legacy_hook = os.path.join(HERE, f'{HOOK_TYPE}.legacy') legacy_hook = os.path.join(HERE, f'{HOOK_TYPE}.legacy')
if os.access(legacy_hook, os.X_OK): if os.access(legacy_hook, os.X_OK):
@ -73,7 +75,7 @@ def _run_legacy():
return 0, stdin return 0, stdin
def _validate_config(): def _validate_config() -> None:
cmd = ('git', 'rev-parse', '--show-toplevel') cmd = ('git', 'rev-parse', '--show-toplevel')
top_level = subprocess.check_output(cmd).decode('UTF-8').strip() top_level = subprocess.check_output(cmd).decode('UTF-8').strip()
cfg = os.path.join(top_level, CONFIG) cfg = os.path.join(top_level, CONFIG)
@ -97,7 +99,7 @@ def _validate_config():
) )
def _exe(): def _exe() -> Tuple[str, ...]:
with open(os.devnull, 'wb') as devnull: with open(os.devnull, 'wb') as devnull:
for exe in (INSTALL_PYTHON, sys.executable): for exe in (INSTALL_PYTHON, sys.executable):
try: try:
@ -117,11 +119,11 @@ def _exe():
) )
def _rev_exists(rev): def _rev_exists(rev: str) -> bool:
return not subprocess.call(('git', 'rev-list', '--quiet', rev)) return not subprocess.call(('git', 'rev-list', '--quiet', rev))
def _pre_push(stdin): def _pre_push(stdin: bytes) -> Tuple[str, ...]:
remote = sys.argv[1] remote = sys.argv[1]
opts: Tuple[str, ...] = () opts: Tuple[str, ...] = ()
@ -158,8 +160,8 @@ def _pre_push(stdin):
raise EarlyExit() raise EarlyExit()
def _opts(stdin): def _opts(stdin: bytes) -> Tuple[str, ...]:
fns = { fns: Dict[str, Callable[[bytes], Tuple[str, ...]]] = {
'prepare-commit-msg': lambda _: ('--commit-msg-filename', sys.argv[1]), 'prepare-commit-msg': lambda _: ('--commit-msg-filename', sys.argv[1]),
'commit-msg': lambda _: ('--commit-msg-filename', sys.argv[1]), 'commit-msg': lambda _: ('--commit-msg-filename', sys.argv[1]),
'pre-merge-commit': lambda _: (), 'pre-merge-commit': lambda _: (),
@ -171,13 +173,14 @@ def _opts(stdin):
if sys.version_info < (3, 7): # https://bugs.python.org/issue25942 if sys.version_info < (3, 7): # https://bugs.python.org/issue25942
def _subprocess_call(cmd): # this is the python 2.7 implementation # this is the python 2.7 implementation
def _subprocess_call(cmd: Tuple[str, ...]) -> int:
return subprocess.Popen(cmd).wait() return subprocess.Popen(cmd).wait()
else: else:
_subprocess_call = subprocess.call _subprocess_call = subprocess.call
def main(): def main() -> int:
retv, stdin = _run_legacy() retv, stdin = _run_legacy()
try: try:
_validate_config() _validate_config()

View file

@ -2,6 +2,7 @@ import contextlib
import logging import logging
import os.path import os.path
import time import time
from typing import Generator
from pre_commit import git from pre_commit import git
from pre_commit.util import CalledProcessError from pre_commit.util import CalledProcessError
@ -14,7 +15,7 @@ from pre_commit.xargs import xargs
logger = logging.getLogger('pre_commit') logger = logging.getLogger('pre_commit')
def _git_apply(patch): def _git_apply(patch: str) -> None:
args = ('apply', '--whitespace=nowarn', patch) args = ('apply', '--whitespace=nowarn', patch)
try: try:
cmd_output_b('git', *args) cmd_output_b('git', *args)
@ -24,7 +25,7 @@ def _git_apply(patch):
@contextlib.contextmanager @contextlib.contextmanager
def _intent_to_add_cleared(): def _intent_to_add_cleared() -> Generator[None, None, None]:
intent_to_add = git.intent_to_add_files() intent_to_add = git.intent_to_add_files()
if intent_to_add: if intent_to_add:
logger.warning('Unstaged intent-to-add files detected.') logger.warning('Unstaged intent-to-add files detected.')
@ -39,7 +40,7 @@ def _intent_to_add_cleared():
@contextlib.contextmanager @contextlib.contextmanager
def _unstaged_changes_cleared(patch_dir): def _unstaged_changes_cleared(patch_dir: str) -> Generator[None, None, None]:
tree = cmd_output('git', 'write-tree')[1].strip() tree = cmd_output('git', 'write-tree')[1].strip()
retcode, diff_stdout_binary, _ = cmd_output_b( retcode, diff_stdout_binary, _ = cmd_output_b(
'git', 'diff-index', '--ignore-submodules', '--binary', 'git', 'diff-index', '--ignore-submodules', '--binary',
@ -84,7 +85,7 @@ def _unstaged_changes_cleared(patch_dir):
@contextlib.contextmanager @contextlib.contextmanager
def staged_files_only(patch_dir): def staged_files_only(patch_dir: str) -> Generator[None, None, None]:
"""Clear any unstaged changes from the git working directory inside this """Clear any unstaged changes from the git working directory inside this
context. context.
""" """

View file

@ -3,6 +3,12 @@ import logging
import os.path import os.path
import sqlite3 import sqlite3
import tempfile import tempfile
from typing import Callable
from typing import Generator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
import pre_commit.constants as C import pre_commit.constants as C
from pre_commit import file_lock from pre_commit import file_lock
@ -18,7 +24,7 @@ from pre_commit.util import rmtree
logger = logging.getLogger('pre_commit') logger = logging.getLogger('pre_commit')
def _get_default_directory(): def _get_default_directory() -> str:
"""Returns the default directory for the Store. This is intentionally """Returns the default directory for the Store. This is intentionally
underscored to indicate that `Store.get_default_directory` is the intended underscored to indicate that `Store.get_default_directory` is the intended
way to get this information. This is also done so way to get this information. This is also done so
@ -34,7 +40,7 @@ def _get_default_directory():
class Store: class Store:
get_default_directory = staticmethod(_get_default_directory) get_default_directory = staticmethod(_get_default_directory)
def __init__(self, directory=None): def __init__(self, directory: Optional[str] = None) -> None:
self.directory = directory or Store.get_default_directory() self.directory = directory or Store.get_default_directory()
self.db_path = os.path.join(self.directory, 'db.db') self.db_path = os.path.join(self.directory, 'db.db')
@ -66,21 +72,24 @@ class Store:
' PRIMARY KEY (repo, ref)' ' PRIMARY KEY (repo, ref)'
');', ');',
) )
self._create_config_table_if_not_exists(db) self._create_config_table(db)
# Atomic file move # Atomic file move
os.rename(tmpfile, self.db_path) os.rename(tmpfile, self.db_path)
@contextlib.contextmanager @contextlib.contextmanager
def exclusive_lock(self): def exclusive_lock(self) -> Generator[None, None, None]:
def blocked_cb(): # pragma: no cover (tests are single-process) def blocked_cb() -> None: # pragma: no cover (tests are in-process)
logger.info('Locking pre-commit directory') logger.info('Locking pre-commit directory')
with file_lock.lock(os.path.join(self.directory, '.lock'), blocked_cb): with file_lock.lock(os.path.join(self.directory, '.lock'), blocked_cb):
yield yield
@contextlib.contextmanager @contextlib.contextmanager
def connect(self, db_path=None): def connect(
self,
db_path: Optional[str] = None,
) -> Generator[sqlite3.Connection, None, None]:
db_path = db_path or self.db_path db_path = db_path or self.db_path
# sqlite doesn't close its fd with its contextmanager >.< # sqlite doesn't close its fd with its contextmanager >.<
# contextlib.closing fixes this. # contextlib.closing fixes this.
@ -91,24 +100,29 @@ class Store:
yield db yield db
@classmethod @classmethod
def db_repo_name(cls, repo, deps): def db_repo_name(cls, repo: str, deps: Sequence[str]) -> str:
if deps: if deps:
return '{}:{}'.format(repo, ','.join(sorted(deps))) return '{}:{}'.format(repo, ','.join(sorted(deps)))
else: else:
return repo return repo
def _new_repo(self, repo, ref, deps, make_strategy): def _new_repo(
self,
repo: str,
ref: str,
deps: Sequence[str],
make_strategy: Callable[[str], None],
) -> str:
repo = self.db_repo_name(repo, deps) repo = self.db_repo_name(repo, deps)
def _get_result(): def _get_result() -> Optional[str]:
# Check if we already exist # Check if we already exist
with self.connect() as db: with self.connect() as db:
result = db.execute( result = db.execute(
'SELECT path FROM repos WHERE repo = ? AND ref = ?', 'SELECT path FROM repos WHERE repo = ? AND ref = ?',
(repo, ref), (repo, ref),
).fetchone() ).fetchone()
if result: return result[0] if result else None
return result[0]
result = _get_result() result = _get_result()
if result: if result:
@ -133,14 +147,14 @@ class Store:
) )
return directory return directory
def _complete_clone(self, ref, git_cmd): def _complete_clone(self, ref: str, git_cmd: Callable[..., None]) -> None:
"""Perform a complete clone of a repository and its submodules """ """Perform a complete clone of a repository and its submodules """
git_cmd('fetch', 'origin', '--tags') git_cmd('fetch', 'origin', '--tags')
git_cmd('checkout', ref) git_cmd('checkout', ref)
git_cmd('submodule', 'update', '--init', '--recursive') git_cmd('submodule', 'update', '--init', '--recursive')
def _shallow_clone(self, ref, git_cmd): def _shallow_clone(self, ref: str, git_cmd: Callable[..., None]) -> None:
"""Perform a shallow clone of a repository and its submodules """ """Perform a shallow clone of a repository and its submodules """
git_config = 'protocol.version=2' git_config = 'protocol.version=2'
@ -151,14 +165,14 @@ class Store:
'--depth=1', '--depth=1',
) )
def clone(self, repo, ref, deps=()): def clone(self, repo: str, ref: str, deps: Sequence[str] = ()) -> str:
"""Clone the given url and checkout the specific ref.""" """Clone the given url and checkout the specific ref."""
def clone_strategy(directory): def clone_strategy(directory: str) -> None:
git.init_repo(directory, repo) git.init_repo(directory, repo)
env = git.no_git_env() env = git.no_git_env()
def _git_cmd(*args): def _git_cmd(*args: str) -> None:
cmd_output_b('git', *args, cwd=directory, env=env) cmd_output_b('git', *args, cwd=directory, env=env)
try: try:
@ -173,8 +187,8 @@ class Store:
'pre_commit_dummy_package.gemspec', 'setup.py', 'environment.yml', 'pre_commit_dummy_package.gemspec', 'setup.py', 'environment.yml',
) )
def make_local(self, deps): def make_local(self, deps: Sequence[str]) -> str:
def make_local_strategy(directory): def make_local_strategy(directory: str) -> None:
for resource in self.LOCAL_RESOURCES: for resource in self.LOCAL_RESOURCES:
contents = resource_text(f'empty_template_{resource}') contents = resource_text(f'empty_template_{resource}')
with open(os.path.join(directory, resource), 'w') as f: with open(os.path.join(directory, resource), 'w') as f:
@ -183,7 +197,7 @@ class Store:
env = git.no_git_env() env = git.no_git_env()
# initialize the git repository so it looks more like cloned repos # initialize the git repository so it looks more like cloned repos
def _git_cmd(*args): def _git_cmd(*args: str) -> None:
cmd_output_b('git', *args, cwd=directory, env=env) cmd_output_b('git', *args, cwd=directory, env=env)
git.init_repo(directory, '<<unknown>>') git.init_repo(directory, '<<unknown>>')
@ -194,7 +208,7 @@ class Store:
'local', C.LOCAL_REPO_VERSION, deps, make_local_strategy, 'local', C.LOCAL_REPO_VERSION, deps, make_local_strategy,
) )
def _create_config_table_if_not_exists(self, db): def _create_config_table(self, db: sqlite3.Connection) -> None:
db.executescript( db.executescript(
'CREATE TABLE IF NOT EXISTS configs (' 'CREATE TABLE IF NOT EXISTS configs ('
' path TEXT NOT NULL,' ' path TEXT NOT NULL,'
@ -202,32 +216,32 @@ class Store:
');', ');',
) )
def mark_config_used(self, path): def mark_config_used(self, path: str) -> None:
path = os.path.realpath(path) path = os.path.realpath(path)
# don't insert config files that do not exist # don't insert config files that do not exist
if not os.path.exists(path): if not os.path.exists(path):
return return
with self.connect() as db: with self.connect() as db:
# TODO: eventually remove this and only create in _create # TODO: eventually remove this and only create in _create
self._create_config_table_if_not_exists(db) self._create_config_table(db)
db.execute('INSERT OR IGNORE INTO configs VALUES (?)', (path,)) db.execute('INSERT OR IGNORE INTO configs VALUES (?)', (path,))
def select_all_configs(self): def select_all_configs(self) -> List[str]:
with self.connect() as db: with self.connect() as db:
self._create_config_table_if_not_exists(db) self._create_config_table(db)
rows = db.execute('SELECT path FROM configs').fetchall() rows = db.execute('SELECT path FROM configs').fetchall()
return [path for path, in rows] return [path for path, in rows]
def delete_configs(self, configs): def delete_configs(self, configs: List[str]) -> None:
with self.connect() as db: with self.connect() as db:
rows = [(path,) for path in configs] rows = [(path,) for path in configs]
db.executemany('DELETE FROM configs WHERE path = ?', rows) db.executemany('DELETE FROM configs WHERE path = ?', rows)
def select_all_repos(self): def select_all_repos(self) -> List[Tuple[str, str, str]]:
with self.connect() as db: with self.connect() as db:
return db.execute('SELECT repo, ref, path from repos').fetchall() return db.execute('SELECT repo, ref, path from repos').fetchall()
def delete_repo(self, db_repo_name, ref, path): def delete_repo(self, db_repo_name: str, ref: str, path: str) -> None:
with self.connect() as db: with self.connect() as db:
db.execute( db.execute(
'DELETE FROM repos WHERE repo = ? and ref = ?', 'DELETE FROM repos WHERE repo = ? and ref = ?',

View file

@ -6,6 +6,16 @@ import stat
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import IO
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union
from pre_commit import five from pre_commit import five
from pre_commit import parse_shebang from pre_commit import parse_shebang
@ -17,8 +27,10 @@ else: # pragma: no cover (<PY37)
from importlib_resources import open_binary from importlib_resources import open_binary
from importlib_resources import read_text from importlib_resources import read_text
EnvironT = Union[Dict[str, str], 'os._Environ']
def mkdirp(path):
def mkdirp(path: str) -> None:
try: try:
os.makedirs(path) os.makedirs(path)
except OSError: except OSError:
@ -27,7 +39,7 @@ def mkdirp(path):
@contextlib.contextmanager @contextlib.contextmanager
def clean_path_on_failure(path): def clean_path_on_failure(path: str) -> Generator[None, None, None]:
"""Cleans up the directory on an exceptional failure.""" """Cleans up the directory on an exceptional failure."""
try: try:
yield yield
@ -38,12 +50,12 @@ def clean_path_on_failure(path):
@contextlib.contextmanager @contextlib.contextmanager
def noop_context(): def noop_context() -> Generator[None, None, None]:
yield yield
@contextlib.contextmanager @contextlib.contextmanager
def tmpdir(): def tmpdir() -> Generator[str, None, None]:
"""Contextmanager to create a temporary directory. It will be cleaned up """Contextmanager to create a temporary directory. It will be cleaned up
afterwards. afterwards.
""" """
@ -54,15 +66,15 @@ def tmpdir():
rmtree(tempdir) rmtree(tempdir)
def resource_bytesio(filename): def resource_bytesio(filename: str) -> IO[bytes]:
return open_binary('pre_commit.resources', filename) return open_binary('pre_commit.resources', filename)
def resource_text(filename): def resource_text(filename: str) -> str:
return read_text('pre_commit.resources', filename) return read_text('pre_commit.resources', filename)
def make_executable(filename): def make_executable(filename: str) -> None:
original_mode = os.stat(filename).st_mode original_mode = os.stat(filename).st_mode
os.chmod( os.chmod(
filename, original_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH, filename, original_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH,
@ -70,18 +82,23 @@ def make_executable(filename):
class CalledProcessError(RuntimeError): class CalledProcessError(RuntimeError):
def __init__(self, returncode, cmd, expected_returncode, stdout, stderr): def __init__(
super().__init__( self,
returncode, cmd, expected_returncode, stdout, stderr, returncode: int,
) cmd: Tuple[str, ...],
expected_returncode: int,
stdout: bytes,
stderr: Optional[bytes],
) -> None:
super().__init__(returncode, cmd, expected_returncode, stdout, stderr)
self.returncode = returncode self.returncode = returncode
self.cmd = cmd self.cmd = cmd
self.expected_returncode = expected_returncode self.expected_returncode = expected_returncode
self.stdout = stdout self.stdout = stdout
self.stderr = stderr self.stderr = stderr
def __bytes__(self): def __bytes__(self) -> bytes:
def _indent_or_none(part): def _indent_or_none(part: Optional[bytes]) -> bytes:
if part: if part:
return b'\n ' + part.replace(b'\n', b'\n ') return b'\n ' + part.replace(b'\n', b'\n ')
else: else:
@ -97,11 +114,14 @@ class CalledProcessError(RuntimeError):
b'stderr:', _indent_or_none(self.stderr), b'stderr:', _indent_or_none(self.stderr),
)) ))
def __str__(self): def __str__(self) -> str:
return self.__bytes__().decode('UTF-8') return self.__bytes__().decode('UTF-8')
def _cmd_kwargs(*cmd, **kwargs): def _cmd_kwargs(
*cmd: str,
**kwargs: Any,
) -> Tuple[Tuple[str, ...], Dict[str, Any]]:
# py2/py3 on windows are more strict about the types here # py2/py3 on windows are more strict about the types here
cmd = tuple(five.n(arg) for arg in cmd) cmd = tuple(five.n(arg) for arg in cmd)
kwargs['env'] = { kwargs['env'] = {
@ -113,7 +133,10 @@ def _cmd_kwargs(*cmd, **kwargs):
return cmd, kwargs return cmd, kwargs
def cmd_output_b(*cmd, **kwargs): def cmd_output_b(
*cmd: str,
**kwargs: Any,
) -> Tuple[int, bytes, Optional[bytes]]:
retcode = kwargs.pop('retcode', 0) retcode = kwargs.pop('retcode', 0)
cmd, kwargs = _cmd_kwargs(*cmd, **kwargs) cmd, kwargs = _cmd_kwargs(*cmd, **kwargs)
@ -132,7 +155,7 @@ def cmd_output_b(*cmd, **kwargs):
return returncode, stdout_b, stderr_b return returncode, stdout_b, stderr_b
def cmd_output(*cmd, **kwargs): def cmd_output(*cmd: str, **kwargs: Any) -> Tuple[int, str, Optional[str]]:
returncode, stdout_b, stderr_b = cmd_output_b(*cmd, **kwargs) returncode, stdout_b, stderr_b = cmd_output_b(*cmd, **kwargs)
stdout = stdout_b.decode('UTF-8') if stdout_b is not None else None stdout = stdout_b.decode('UTF-8') if stdout_b is not None else None
stderr = stderr_b.decode('UTF-8') if stderr_b is not None else None stderr = stderr_b.decode('UTF-8') if stderr_b is not None else None
@ -144,10 +167,11 @@ if os.name != 'nt': # pragma: windows no cover
import termios import termios
class Pty: class Pty:
def __init__(self): def __init__(self) -> None:
self.r = self.w = None self.r: Optional[int] = None
self.w: Optional[int] = None
def __enter__(self): def __enter__(self) -> 'Pty':
self.r, self.w = openpty() self.r, self.w = openpty()
# tty flags normally change \n to \r\n # tty flags normally change \n to \r\n
@ -158,21 +182,29 @@ if os.name != 'nt': # pragma: windows no cover
return self return self
def close_w(self): def close_w(self) -> None:
if self.w is not None: if self.w is not None:
os.close(self.w) os.close(self.w)
self.w = None self.w = None
def close_r(self): def close_r(self) -> None:
assert self.r is not None assert self.r is not None
os.close(self.r) os.close(self.r)
self.r = None self.r = None
def __exit__(self, exc_type, exc_value, traceback): def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.close_w() self.close_w()
self.close_r() self.close_r()
def cmd_output_p(*cmd, **kwargs): def cmd_output_p(
*cmd: str,
**kwargs: Any,
) -> Tuple[int, bytes, Optional[bytes]]:
assert kwargs.pop('retcode') is None assert kwargs.pop('retcode') is None
assert kwargs['stderr'] == subprocess.STDOUT, kwargs['stderr'] assert kwargs['stderr'] == subprocess.STDOUT, kwargs['stderr']
cmd, kwargs = _cmd_kwargs(*cmd, **kwargs) cmd, kwargs = _cmd_kwargs(*cmd, **kwargs)
@ -183,6 +215,7 @@ if os.name != 'nt': # pragma: windows no cover
return e.to_output() return e.to_output()
with open(os.devnull) as devnull, Pty() as pty: with open(os.devnull) as devnull, Pty() as pty:
assert pty.r is not None
kwargs.update({'stdin': devnull, 'stdout': pty.w, 'stderr': pty.w}) kwargs.update({'stdin': devnull, 'stdout': pty.w, 'stderr': pty.w})
proc = subprocess.Popen(cmd, **kwargs) proc = subprocess.Popen(cmd, **kwargs)
pty.close_w() pty.close_w()
@ -206,9 +239,13 @@ else: # pragma: no cover
cmd_output_p = cmd_output_b cmd_output_p = cmd_output_b
def rmtree(path): def rmtree(path: str) -> None:
"""On windows, rmtree fails for readonly dirs.""" """On windows, rmtree fails for readonly dirs."""
def handle_remove_readonly(func, path, exc): def handle_remove_readonly(
func: Callable[..., Any],
path: str,
exc: Tuple[Type[OSError], OSError, TracebackType],
) -> None:
excvalue = exc[1] excvalue = exc[1]
if ( if (
func in (os.rmdir, os.remove, os.unlink) and func in (os.rmdir, os.remove, os.unlink) and
@ -222,6 +259,6 @@ def rmtree(path):
shutil.rmtree(path, ignore_errors=False, onerror=handle_remove_readonly) shutil.rmtree(path, ignore_errors=False, onerror=handle_remove_readonly)
def parse_version(s): def parse_version(s: str) -> Tuple[int, ...]:
"""poor man's version comparison""" """poor man's version comparison"""
return tuple(int(p) for p in s.split('.')) return tuple(int(p) for p in s.split('.'))

View file

@ -4,14 +4,26 @@ import math
import os import os
import subprocess import subprocess
import sys import sys
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterable
from typing import List from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TypeVar
from pre_commit import parse_shebang from pre_commit import parse_shebang
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
from pre_commit.util import cmd_output_p from pre_commit.util import cmd_output_p
from pre_commit.util import EnvironT
TArg = TypeVar('TArg')
TRet = TypeVar('TRet')
def _environ_size(_env=None): def _environ_size(_env: Optional[EnvironT] = None) -> int:
environ = _env if _env is not None else getattr(os, 'environb', os.environ) environ = _env if _env is not None else getattr(os, 'environb', os.environ)
size = 8 * len(environ) # number of pointers in `envp` size = 8 * len(environ) # number of pointers in `envp`
for k, v in environ.items(): for k, v in environ.items():
@ -19,7 +31,7 @@ def _environ_size(_env=None):
return size return size
def _get_platform_max_length(): # pragma: no cover (platform specific) def _get_platform_max_length() -> int: # pragma: no cover (platform specific)
if os.name == 'posix': if os.name == 'posix':
maximum = os.sysconf('SC_ARG_MAX') - 2048 - _environ_size() maximum = os.sysconf('SC_ARG_MAX') - 2048 - _environ_size()
maximum = max(min(maximum, 2 ** 17), 2 ** 12) maximum = max(min(maximum, 2 ** 17), 2 ** 12)
@ -31,7 +43,7 @@ def _get_platform_max_length(): # pragma: no cover (platform specific)
return 2 ** 12 return 2 ** 12
def _command_length(*cmd): def _command_length(*cmd: str) -> int:
full_cmd = ' '.join(cmd) full_cmd = ' '.join(cmd)
# win32 uses the amount of characters, more details at: # win32 uses the amount of characters, more details at:
@ -47,7 +59,12 @@ class ArgumentTooLongError(RuntimeError):
pass pass
def partition(cmd, varargs, target_concurrency, _max_length=None): def partition(
cmd: Sequence[str],
varargs: Sequence[str],
target_concurrency: int,
_max_length: Optional[int] = None,
) -> Tuple[Tuple[str, ...], ...]:
_max_length = _max_length or _get_platform_max_length() _max_length = _max_length or _get_platform_max_length()
# Generally, we try to partition evenly into at least `target_concurrency` # Generally, we try to partition evenly into at least `target_concurrency`
@ -87,7 +104,10 @@ def partition(cmd, varargs, target_concurrency, _max_length=None):
@contextlib.contextmanager @contextlib.contextmanager
def _thread_mapper(maxsize): def _thread_mapper(maxsize: int) -> Generator[
Callable[[Callable[[TArg], TRet], Iterable[TArg]], Iterable[TRet]],
None, None,
]:
if maxsize == 1: if maxsize == 1:
yield map yield map
else: else:
@ -95,7 +115,11 @@ def _thread_mapper(maxsize):
yield ex.map yield ex.map
def xargs(cmd, varargs, **kwargs): def xargs(
cmd: Tuple[str, ...],
varargs: Sequence[str],
**kwargs: Any,
) -> Tuple[int, bytes]:
"""A simplified implementation of xargs. """A simplified implementation of xargs.
color: Make a pty if on a platform that supports it color: Make a pty if on a platform that supports it
@ -115,7 +139,9 @@ def xargs(cmd, varargs, **kwargs):
partitions = partition(cmd, varargs, target_concurrency, max_length) partitions = partition(cmd, varargs, target_concurrency, max_length)
def run_cmd_partition(run_cmd): def run_cmd_partition(
run_cmd: Tuple[str, ...],
) -> Tuple[int, bytes, Optional[bytes]]:
return cmd_fn( return cmd_fn(
*run_cmd, retcode=None, stderr=subprocess.STDOUT, **kwargs, *run_cmd, retcode=None, stderr=subprocess.STDOUT, **kwargs,
) )

View file

@ -57,6 +57,7 @@ universal = True
check_untyped_defs = true check_untyped_defs = true
disallow_any_generics = true disallow_any_generics = true
disallow_incomplete_defs = true disallow_incomplete_defs = true
disallow_untyped_defs = true
no_implicit_optional = true no_implicit_optional = true
[mypy-testing.*] [mypy-testing.*]

View file

@ -37,21 +37,21 @@ def test_use_color_no_tty():
def test_use_color_tty_with_color_support(): def test_use_color_tty_with_color_support():
with mock.patch.object(sys.stdout, 'isatty', return_value=True): with mock.patch.object(sys.stdout, 'isatty', return_value=True):
with mock.patch('pre_commit.color.terminal_supports_color', True): with mock.patch('pre_commit.color.terminal_supports_color', True):
with envcontext.envcontext([('TERM', envcontext.UNSET)]): with envcontext.envcontext((('TERM', envcontext.UNSET),)):
assert use_color('auto') is True assert use_color('auto') is True
def test_use_color_tty_without_color_support(): def test_use_color_tty_without_color_support():
with mock.patch.object(sys.stdout, 'isatty', return_value=True): with mock.patch.object(sys.stdout, 'isatty', return_value=True):
with mock.patch('pre_commit.color.terminal_supports_color', False): with mock.patch('pre_commit.color.terminal_supports_color', False):
with envcontext.envcontext([('TERM', envcontext.UNSET)]): with envcontext.envcontext((('TERM', envcontext.UNSET),)):
assert use_color('auto') is False assert use_color('auto') is False
def test_use_color_dumb_term(): def test_use_color_dumb_term():
with mock.patch.object(sys.stdout, 'isatty', return_value=True): with mock.patch.object(sys.stdout, 'isatty', return_value=True):
with mock.patch('pre_commit.color.terminal_supports_color', True): with mock.patch('pre_commit.color.terminal_supports_color', True):
with envcontext.envcontext([('TERM', 'dumb')]): with envcontext.envcontext((('TERM', 'dumb'),)):
assert use_color('auto') is False assert use_color('auto') is False

View file

@ -24,7 +24,7 @@ def test_init_templatedir(tmpdir, tempdir_factory, store, cap_out):
'[WARNING] maybe `git config --global init.templateDir', '[WARNING] maybe `git config --global init.templateDir',
) )
with envcontext([('GIT_TEMPLATE_DIR', target)]): with envcontext((('GIT_TEMPLATE_DIR', target),)):
path = make_consuming_repo(tempdir_factory, 'script_hooks_repo') path = make_consuming_repo(tempdir_factory, 'script_hooks_repo')
with cwd(path): with cwd(path):
@ -52,7 +52,7 @@ def test_init_templatedir_already_set(tmpdir, tempdir_factory, store, cap_out):
def test_init_templatedir_not_set(tmpdir, store, cap_out): def test_init_templatedir_not_set(tmpdir, store, cap_out):
# set HOME to ignore the current `.gitconfig` # set HOME to ignore the current `.gitconfig`
with envcontext([('HOME', str(tmpdir))]): with envcontext((('HOME', str(tmpdir)),)):
with tmpdir.join('tmpl').ensure_dir().as_cwd(): with tmpdir.join('tmpl').ensure_dir().as_cwd():
# we have not set init.templateDir so this should produce a warning # we have not set init.templateDir so this should produce a warning
init_templatedir( init_templatedir(

View file

@ -274,5 +274,5 @@ def fake_log_handler():
@pytest.fixture(scope='session', autouse=True) @pytest.fixture(scope='session', autouse=True)
def set_git_templatedir(tmpdir_factory): def set_git_templatedir(tmpdir_factory):
tdir = str(tmpdir_factory.mktemp('git_template_dir')) tdir = str(tmpdir_factory.mktemp('git_template_dir'))
with envcontext([('GIT_TEMPLATE_DIR', tdir)]): with envcontext((('GIT_TEMPLATE_DIR', tdir),)):
yield yield

View file

@ -93,7 +93,7 @@ def test_exception_safety():
env = {'hello': 'world'} env = {'hello': 'world'}
with pytest.raises(MyError): with pytest.raises(MyError):
with envcontext([('foo', 'bar')], _env=env): with envcontext((('foo', 'bar'),), _env=env):
raise MyError() raise MyError()
assert env == {'hello': 'world'} assert env == {'hello': 'world'}
@ -101,6 +101,6 @@ def test_exception_safety():
def test_integration_os_environ(): def test_integration_os_environ():
with mock.patch.dict(os.environ, {'FOO': 'bar'}, clear=True): with mock.patch.dict(os.environ, {'FOO': 'bar'}, clear=True):
assert os.environ == {'FOO': 'bar'} assert os.environ == {'FOO': 'bar'}
with envcontext([('HERP', 'derp')]): with envcontext((('HERP', 'derp'),)):
assert os.environ == {'FOO': 'bar', 'HERP': 'derp'} assert os.environ == {'FOO': 'bar', 'HERP': 'derp'}
assert os.environ == {'FOO': 'bar'} assert os.environ == {'FOO': 'bar'}

View file

@ -1,23 +1,31 @@
import functools
import inspect import inspect
from typing import Sequence
from typing import Tuple
import pytest import pytest
from pre_commit.languages.all import all_languages from pre_commit.languages.all import all_languages
from pre_commit.languages.all import languages from pre_commit.languages.all import languages
from pre_commit.prefix import Prefix
ArgSpec = functools.partial( def _argspec(annotations):
inspect.FullArgSpec, varargs=None, varkw=None, defaults=None, args = [k for k in annotations if k != 'return']
kwonlyargs=[], kwonlydefaults=None, annotations={}, return inspect.FullArgSpec(
args=args, annotations=annotations,
varargs=None, varkw=None, defaults=None,
kwonlyargs=[], kwonlydefaults=None,
) )
@pytest.mark.parametrize('language', all_languages) @pytest.mark.parametrize('language', all_languages)
def test_install_environment_argspec(language): def test_install_environment_argspec(language):
expected_argspec = ArgSpec( expected_argspec = _argspec({
args=['prefix', 'version', 'additional_dependencies'], 'return': None,
) 'prefix': Prefix,
'version': str,
'additional_dependencies': Sequence[str],
})
argspec = inspect.getfullargspec(languages[language].install_environment) argspec = inspect.getfullargspec(languages[language].install_environment)
assert argspec == expected_argspec assert argspec == expected_argspec
@ -29,20 +37,26 @@ def test_ENVIRONMENT_DIR(language):
@pytest.mark.parametrize('language', all_languages) @pytest.mark.parametrize('language', all_languages)
def test_run_hook_argspec(language): def test_run_hook_argspec(language):
expected_argspec = ArgSpec(args=['hook', 'file_args', 'color']) expected_argspec = _argspec({
'return': Tuple[int, bytes],
'hook': 'Hook', 'file_args': Sequence[str], 'color': bool,
})
argspec = inspect.getfullargspec(languages[language].run_hook) argspec = inspect.getfullargspec(languages[language].run_hook)
assert argspec == expected_argspec assert argspec == expected_argspec
@pytest.mark.parametrize('language', all_languages) @pytest.mark.parametrize('language', all_languages)
def test_get_default_version_argspec(language): def test_get_default_version_argspec(language):
expected_argspec = ArgSpec(args=[]) expected_argspec = _argspec({'return': str})
argspec = inspect.getfullargspec(languages[language].get_default_version) argspec = inspect.getfullargspec(languages[language].get_default_version)
assert argspec == expected_argspec assert argspec == expected_argspec
@pytest.mark.parametrize('language', all_languages) @pytest.mark.parametrize('language', all_languages)
def test_healthy_argspec(language): def test_healthy_argspec(language):
expected_argspec = ArgSpec(args=['prefix', 'language_version']) expected_argspec = _argspec({
'return': bool,
'prefix': Prefix, 'language_version': str,
})
argspec = inspect.getfullargspec(languages[language].healthy) argspec = inspect.getfullargspec(languages[language].healthy)
assert argspec == expected_argspec assert argspec == expected_argspec

View file

@ -7,7 +7,7 @@ from pre_commit.util import CalledProcessError
def test_docker_is_running_process_error(): def test_docker_is_running_process_error():
with mock.patch( with mock.patch(
'pre_commit.languages.docker.cmd_output_b', 'pre_commit.languages.docker.cmd_output_b',
side_effect=CalledProcessError(None, None, None, None, None), side_effect=CalledProcessError(1, (), 0, b'', None),
): ):
assert docker.docker_is_running() is False assert docker.docker_is_running() is False

View file

@ -17,7 +17,7 @@ def test_basic_get_default_version():
def test_basic_healthy(): def test_basic_healthy():
assert helpers.basic_healthy(None, None) is True assert helpers.basic_healthy(Prefix('.'), 'default') is True
def test_failed_setup_command_does_not_unicode_error(): def test_failed_setup_command_does_not_unicode_error():
@ -77,4 +77,6 @@ def test_target_concurrency_cpu_count_not_implemented():
def test_shuffled_is_deterministic(): def test_shuffled_is_deterministic():
assert helpers._shuffled(range(10)) == [3, 7, 8, 2, 4, 6, 5, 1, 0, 9] seq = [str(i) for i in range(10)]
expected = ['3', '7', '8', '2', '4', '6', '5', '1', '0', '9']
assert helpers._shuffled(seq) == expected

View file

@ -1,25 +1,21 @@
import logging
from pre_commit import color from pre_commit import color
from pre_commit.logging_handler import LoggingHandler from pre_commit.logging_handler import LoggingHandler
class FakeLogRecord: def _log_record(message, level):
def __init__(self, message, levelname, levelno): return logging.LogRecord('name', level, '', 1, message, {}, None)
self.message = message
self.levelname = levelname
self.levelno = levelno
def getMessage(self):
return self.message
def test_logging_handler_color(cap_out): def test_logging_handler_color(cap_out):
handler = LoggingHandler(True) handler = LoggingHandler(True)
handler.emit(FakeLogRecord('hi', 'WARNING', 30)) handler.emit(_log_record('hi', logging.WARNING))
ret = cap_out.get() ret = cap_out.get()
assert ret == color.YELLOW + '[WARNING]' + color.NORMAL + ' hi\n' assert ret == color.YELLOW + '[WARNING]' + color.NORMAL + ' hi\n'
def test_logging_handler_no_color(cap_out): def test_logging_handler_no_color(cap_out):
handler = LoggingHandler(False) handler = LoggingHandler(False)
handler.emit(FakeLogRecord('hi', 'WARNING', 30)) handler.emit(_log_record('hi', logging.WARNING))
assert cap_out.get() == '[WARNING] hi\n' assert cap_out.get() == '[WARNING] hi\n'

View file

@ -1,8 +1,5 @@
import argparse import argparse
import os.path import os.path
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from unittest import mock from unittest import mock
import pytest import pytest
@ -27,25 +24,24 @@ def test_append_replace_default(argv, expected):
assert parser.parse_args(argv).f == expected assert parser.parse_args(argv).f == expected
class Args(NamedTuple): def _args(**kwargs):
command: str = 'help' kwargs.setdefault('command', 'help')
config: str = C.CONFIG_FILE kwargs.setdefault('config', C.CONFIG_FILE)
files: Sequence[str] = [] return argparse.Namespace(**kwargs)
repo: Optional[str] = None
def test_adjust_args_and_chdir_not_in_git_dir(in_tmpdir): def test_adjust_args_and_chdir_not_in_git_dir(in_tmpdir):
with pytest.raises(FatalError): with pytest.raises(FatalError):
main._adjust_args_and_chdir(Args()) main._adjust_args_and_chdir(_args())
def test_adjust_args_and_chdir_in_dot_git_dir(in_git_dir): def test_adjust_args_and_chdir_in_dot_git_dir(in_git_dir):
with in_git_dir.join('.git').as_cwd(), pytest.raises(FatalError): with in_git_dir.join('.git').as_cwd(), pytest.raises(FatalError):
main._adjust_args_and_chdir(Args()) main._adjust_args_and_chdir(_args())
def test_adjust_args_and_chdir_noop(in_git_dir): def test_adjust_args_and_chdir_noop(in_git_dir):
args = Args(command='run', files=['f1', 'f2']) args = _args(command='run', files=['f1', 'f2'])
main._adjust_args_and_chdir(args) main._adjust_args_and_chdir(args)
assert os.getcwd() == in_git_dir assert os.getcwd() == in_git_dir
assert args.config == C.CONFIG_FILE assert args.config == C.CONFIG_FILE
@ -56,7 +52,7 @@ def test_adjust_args_and_chdir_relative_things(in_git_dir):
in_git_dir.join('foo/cfg.yaml').ensure() in_git_dir.join('foo/cfg.yaml').ensure()
in_git_dir.join('foo').chdir() in_git_dir.join('foo').chdir()
args = Args(command='run', files=['f1', 'f2'], config='cfg.yaml') args = _args(command='run', files=['f1', 'f2'], config='cfg.yaml')
main._adjust_args_and_chdir(args) main._adjust_args_and_chdir(args)
assert os.getcwd() == in_git_dir assert os.getcwd() == in_git_dir
assert args.config == os.path.join('foo', 'cfg.yaml') assert args.config == os.path.join('foo', 'cfg.yaml')
@ -66,7 +62,7 @@ def test_adjust_args_and_chdir_relative_things(in_git_dir):
def test_adjust_args_and_chdir_non_relative_config(in_git_dir): def test_adjust_args_and_chdir_non_relative_config(in_git_dir):
in_git_dir.join('foo').ensure_dir().chdir() in_git_dir.join('foo').ensure_dir().chdir()
args = Args() args = _args()
main._adjust_args_and_chdir(args) main._adjust_args_and_chdir(args)
assert os.getcwd() == in_git_dir assert os.getcwd() == in_git_dir
assert args.config == C.CONFIG_FILE assert args.config == C.CONFIG_FILE
@ -75,7 +71,7 @@ def test_adjust_args_and_chdir_non_relative_config(in_git_dir):
def test_adjust_args_try_repo_repo_relative(in_git_dir): def test_adjust_args_try_repo_repo_relative(in_git_dir):
in_git_dir.join('foo').ensure_dir().chdir() in_git_dir.join('foo').ensure_dir().chdir()
args = Args(command='try-repo', repo='../foo', files=[]) args = _args(command='try-repo', repo='../foo', files=[])
assert args.repo is not None assert args.repo is not None
assert os.path.exists(args.repo) assert os.path.exists(args.repo)
main._adjust_args_and_chdir(args) main._adjust_args_and_chdir(args)

View file

@ -22,7 +22,7 @@ from pre_commit import output
), ),
) )
def test_get_hook_message_raises(kwargs): def test_get_hook_message_raises(kwargs):
with pytest.raises(ValueError): with pytest.raises(AssertionError):
output.get_hook_message('start', **kwargs) output.get_hook_message('start', **kwargs)

View file

@ -311,7 +311,7 @@ def test_golang_hook(tempdir_factory, store):
def test_golang_hook_still_works_when_gobin_is_set(tempdir_factory, store): def test_golang_hook_still_works_when_gobin_is_set(tempdir_factory, store):
gobin_dir = tempdir_factory.get() gobin_dir = tempdir_factory.get()
with envcontext([('GOBIN', gobin_dir)]): with envcontext((('GOBIN', gobin_dir),)):
test_golang_hook(tempdir_factory, store) test_golang_hook(tempdir_factory, store)
assert os.listdir(gobin_dir) == [] assert os.listdir(gobin_dir) == []

View file

@ -120,7 +120,7 @@ def test_clone_shallow_failure_fallback_to_complete(
# Force shallow clone failure # Force shallow clone failure
def fake_shallow_clone(self, *args, **kwargs): def fake_shallow_clone(self, *args, **kwargs):
raise CalledProcessError(None, None, None, None, None) raise CalledProcessError(1, (), 0, b'', None)
store._shallow_clone = fake_shallow_clone store._shallow_clone = fake_shallow_clone
ret = store.clone(path, rev) ret = store.clone(path, rev)

View file

@ -15,9 +15,9 @@ from pre_commit.util import tmpdir
def test_CalledProcessError_str(): def test_CalledProcessError_str():
error = CalledProcessError(1, ['exe'], 0, b'output', b'errors') error = CalledProcessError(1, ('exe',), 0, b'output', b'errors')
assert str(error) == ( assert str(error) == (
"command: ['exe']\n" "command: ('exe',)\n"
'return code: 1\n' 'return code: 1\n'
'expected return code: 0\n' 'expected return code: 0\n'
'stdout:\n' 'stdout:\n'
@ -28,9 +28,9 @@ def test_CalledProcessError_str():
def test_CalledProcessError_str_nooutput(): def test_CalledProcessError_str_nooutput():
error = CalledProcessError(1, ['exe'], 0, b'', b'') error = CalledProcessError(1, ('exe',), 0, b'', b'')
assert str(error) == ( assert str(error) == (
"command: ['exe']\n" "command: ('exe',)\n"
'return code: 1\n' 'return code: 1\n'
'expected return code: 0\n' 'expected return code: 0\n'
'stdout: (none)\n' 'stdout: (none)\n'

View file

@ -2,6 +2,7 @@ import concurrent.futures
import os import os
import sys import sys
import time import time
from typing import Tuple
from unittest import mock from unittest import mock
import pytest import pytest
@ -166,9 +167,8 @@ def test_xargs_concurrency():
def test_thread_mapper_concurrency_uses_threadpoolexecutor_map(): def test_thread_mapper_concurrency_uses_threadpoolexecutor_map():
with xargs._thread_mapper(10) as thread_map: with xargs._thread_mapper(10) as thread_map:
assert isinstance( _self = thread_map.__self__ # type: ignore
thread_map.__self__, concurrent.futures.ThreadPoolExecutor, assert isinstance(_self, concurrent.futures.ThreadPoolExecutor)
) is True
def test_thread_mapper_concurrency_uses_regular_map(): def test_thread_mapper_concurrency_uses_regular_map():
@ -178,7 +178,7 @@ def test_thread_mapper_concurrency_uses_regular_map():
def test_xargs_propagate_kwargs_to_cmd(): def test_xargs_propagate_kwargs_to_cmd():
env = {'PRE_COMMIT_TEST_VAR': 'Pre commit is awesome'} env = {'PRE_COMMIT_TEST_VAR': 'Pre commit is awesome'}
cmd = ('bash', '-c', 'echo $PRE_COMMIT_TEST_VAR', '--') cmd: Tuple[str, ...] = ('bash', '-c', 'echo $PRE_COMMIT_TEST_VAR', '--')
cmd = parse_shebang.normalize_cmd(cmd) cmd = parse_shebang.normalize_cmd(cmd)
ret, stdout = xargs.xargs(cmd, ('1',), env=env) ret, stdout = xargs.xargs(cmd, ('1',), env=env)