Add types to pre-commit

This commit is contained in:
Anthony Sottile 2020-01-10 23:32:28 -08:00
parent fa536a8693
commit 327ed924a3
62 changed files with 911 additions and 411 deletions

View file

@ -1,15 +1,20 @@
import logging
import os.path
import sys
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from pre_commit.util import cmd_output
from pre_commit.util import cmd_output_b
from pre_commit.util import EnvironT
logger = logging.getLogger(__name__)
def zsplit(s):
def zsplit(s: str) -> List[str]:
s = s.strip('\0')
if s:
return s.split('\0')
@ -17,7 +22,7 @@ def zsplit(s):
return []
def no_git_env(_env=None):
def no_git_env(_env: Optional[EnvironT] = None) -> Dict[str, str]:
# Too many bugs dealing with environment variables and GIT:
# https://github.com/pre-commit/pre-commit/issues/300
# In git 2.6.3 (maybe others), git exports GIT_WORK_TREE while running
@ -34,11 +39,11 @@ def no_git_env(_env=None):
}
def get_root():
def get_root() -> str:
return cmd_output('git', 'rev-parse', '--show-toplevel')[1].strip()
def get_git_dir(git_root='.'):
def get_git_dir(git_root: str = '.') -> str:
opts = ('--git-common-dir', '--git-dir')
_, out, _ = cmd_output('git', 'rev-parse', *opts, cwd=git_root)
for line, opt in zip(out.splitlines(), opts):
@ -48,12 +53,12 @@ def get_git_dir(git_root='.'):
raise AssertionError('unreachable: no git dir')
def get_remote_url(git_root):
def get_remote_url(git_root: str) -> str:
_, out, _ = cmd_output('git', 'config', 'remote.origin.url', cwd=git_root)
return out.strip()
def is_in_merge_conflict():
def is_in_merge_conflict() -> bool:
git_dir = get_git_dir('.')
return (
os.path.exists(os.path.join(git_dir, 'MERGE_MSG')) and
@ -61,7 +66,7 @@ def is_in_merge_conflict():
)
def parse_merge_msg_for_conflicts(merge_msg):
def parse_merge_msg_for_conflicts(merge_msg: bytes) -> List[str]:
# Conflicted files start with tabs
return [
line.lstrip(b'#').strip().decode('UTF-8')
@ -71,7 +76,7 @@ def parse_merge_msg_for_conflicts(merge_msg):
]
def get_conflicted_files():
def get_conflicted_files() -> Set[str]:
logger.info('Checking merge-conflict files only.')
# Need to get the conflicted files from the MERGE_MSG because they could
# have resolved the conflict by choosing one side or the other
@ -92,7 +97,7 @@ def get_conflicted_files():
return set(merge_conflict_filenames) | set(merge_diff_filenames)
def get_staged_files(cwd=None):
def get_staged_files(cwd: Optional[str] = None) -> List[str]:
return zsplit(
cmd_output(
'git', 'diff', '--staged', '--name-only', '--no-ext-diff', '-z',
@ -103,7 +108,7 @@ def get_staged_files(cwd=None):
)
def intent_to_add_files():
def intent_to_add_files() -> List[str]:
_, stdout, _ = cmd_output('git', 'status', '--porcelain', '-z')
parts = list(reversed(zsplit(stdout)))
intent_to_add = []
@ -117,11 +122,11 @@ def intent_to_add_files():
return intent_to_add
def get_all_files():
def get_all_files() -> List[str]:
return zsplit(cmd_output('git', 'ls-files', '-z')[1])
def get_changed_files(new, old):
def get_changed_files(new: str, old: str) -> List[str]:
return zsplit(
cmd_output(
'git', 'diff', '--name-only', '--no-ext-diff', '-z',
@ -130,24 +135,22 @@ def get_changed_files(new, old):
)
def head_rev(remote):
def head_rev(remote: str) -> str:
_, out, _ = cmd_output('git', 'ls-remote', '--exit-code', remote, 'HEAD')
return out.split()[0]
def has_diff(*args, **kwargs):
repo = kwargs.pop('repo', '.')
assert not kwargs, kwargs
def has_diff(*args: str, repo: str = '.') -> bool:
cmd = ('git', 'diff', '--quiet', '--no-ext-diff') + args
return cmd_output_b(*cmd, cwd=repo, retcode=None)[0] == 1
def has_core_hookpaths_set():
def has_core_hookpaths_set() -> bool:
_, out, _ = cmd_output_b('git', 'config', 'core.hooksPath', retcode=None)
return bool(out.strip())
def init_repo(path, remote):
def init_repo(path: str, remote: str) -> None:
if os.path.isdir(remote):
remote = os.path.abspath(remote)
@ -156,7 +159,7 @@ def init_repo(path, remote):
cmd_output_b('git', 'remote', 'add', 'origin', remote, cwd=path, env=env)
def commit(repo='.'):
def commit(repo: str = '.') -> None:
env = no_git_env()
name, email = 'pre-commit', 'asottile+pre-commit@umich.edu'
env['GIT_AUTHOR_NAME'] = env['GIT_COMMITTER_NAME'] = name
@ -165,12 +168,12 @@ def commit(repo='.'):
cmd_output_b(*cmd, cwd=repo, env=env)
def git_path(name, repo='.'):
def git_path(name: str, repo: str = '.') -> str:
_, out, _ = cmd_output('git', 'rev-parse', '--git-path', name, cwd=repo)
return os.path.join(repo, out.strip())
def check_for_cygwin_mismatch():
def check_for_cygwin_mismatch() -> None:
"""See https://github.com/pre-commit/pre-commit/issues/354"""
if sys.platform in ('cygwin', 'win32'): # pragma: no cover (windows)
is_cygwin_python = sys.platform == 'cygwin'