Merge pull request #153 from pre-commit/error_handler

Error handler
This commit is contained in:
Ken Struys 2014-08-13 11:34:39 -07:00
commit f3b5886c8a
15 changed files with 199 additions and 89 deletions

View file

@ -6,10 +6,10 @@ import jsonschema
import jsonschema.exceptions import jsonschema.exceptions
import os.path import os.path
import re import re
import sys
import yaml import yaml
from pre_commit.jsonschema_extensions import apply_defaults from pre_commit.jsonschema_extensions import apply_defaults
from pre_commit.util import entry
def is_regex_valid(regex): def is_regex_valid(regex):
@ -64,8 +64,8 @@ def get_validator(
def get_run_function(filenames_help, validate_strategy, exception_cls): def get_run_function(filenames_help, validate_strategy, exception_cls):
@entry def run(argv=None):
def run(argv): argv = argv if argv is not None else sys.argv[1:]
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('filenames', nargs='*', help=filenames_help) parser.add_argument('filenames', nargs='*', help=filenames_help)
args = parser.parse_args(argv) args = parser.parse_args(argv)

View file

@ -1,13 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import sys
from pre_commit.clientlib.validate_base import get_run_function from pre_commit.clientlib.validate_base import get_run_function
from pre_commit.clientlib.validate_base import get_validator from pre_commit.clientlib.validate_base import get_validator
from pre_commit.clientlib.validate_base import is_regex_valid from pre_commit.clientlib.validate_base import is_regex_valid
from pre_commit.errors import FatalError
class InvalidConfigError(ValueError): class InvalidConfigError(FatalError):
pass pass
@ -71,4 +70,4 @@ run = get_run_function('Config filenames.', load_config, InvalidConfigError)
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(run()) exit(run())

View file

@ -1,7 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import sys
from pre_commit.clientlib.validate_base import get_run_function from pre_commit.clientlib.validate_base import get_run_function
from pre_commit.clientlib.validate_base import get_validator from pre_commit.clientlib.validate_base import get_validator
from pre_commit.clientlib.validate_base import is_regex_valid from pre_commit.clientlib.validate_base import is_regex_valid
@ -74,4 +72,4 @@ run = get_run_function(
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(run()) exit(run())

View file

@ -0,0 +1,42 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import contextlib
import io
import os.path
import traceback
from pre_commit.errors import FatalError
from pre_commit.store import Store
# For testing purposes
class PreCommitSystemExit(SystemExit):
pass
def _log_and_exit(msg, exc, formatted, print_fn=print):
error_msg = '{0}: {1}: {2}'.format(msg, type(exc).__name__, exc)
print_fn(error_msg)
print_fn('Check the log at ~/.pre-commit/pre-commit.log')
store = Store()
store.require_created()
with io.open(os.path.join(store.directory, 'pre-commit.log'), 'w') as log:
log.write(error_msg + '\n')
log.write(formatted + '\n')
raise PreCommitSystemExit(1)
@contextlib.contextmanager
def error_handler():
try:
yield
except FatalError as e:
_log_and_exit('An error has occurred', e, traceback.format_exc())
except Exception as e:
_log_and_exit(
'An unexpected error has occurred',
e,
traceback.format_exc(),
)

6
pre_commit/errors.py Normal file
View file

@ -0,0 +1,6 @@
from __future__ import absolute_import
from __future__ import unicode_literals
class FatalError(RuntimeError):
pass

View file

@ -1,6 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
"""five: six, redux"""
# pylint:disable=invalid-name # pylint:disable=invalid-name
PY2 = str is bytes PY2 = str is bytes
PY3 = str is not bytes PY3 = str is not bytes

View file

@ -7,13 +7,13 @@ import os.path
import re import re
from plumbum import local from plumbum import local
from pre_commit.errors import FatalError
from pre_commit.util import memoize_by_cwd from pre_commit.util import memoize_by_cwd
logger = logging.getLogger('pre_commit') logger = logging.getLogger('pre_commit')
@memoize_by_cwd
def get_root(): def get_root():
path = os.getcwd() path = os.getcwd()
while len(path) > 1: while len(path) > 1:
@ -21,7 +21,10 @@ def get_root():
return path return path
else: else:
path = os.path.normpath(os.path.join(path, '../')) path = os.path.normpath(os.path.join(path, '../'))
raise AssertionError('called from outside of the gits') raise FatalError(
'Called from outside of the gits. '
'Please cd to a git repository.'
)
def is_in_merge_conflict(): def is_in_merge_conflict():

View file

@ -20,20 +20,20 @@ def extend_validator_cls(validator_cls, modify):
def default_values(properties, instance): def default_values(properties, instance):
for property, subschema in properties.items(): for prop, subschema in properties.items():
if 'default' in subschema: if 'default' in subschema:
instance.setdefault( instance.setdefault(
property, copy.deepcopy(subschema['default']), prop, copy.deepcopy(subschema['default']),
) )
def remove_default_values(properties, instance): def remove_default_values(properties, instance):
for property, subschema in properties.items(): for prop, subschema in properties.items():
if ( if (
'default' in subschema and 'default' in subschema and
instance.get(property) == subschema['default'] instance.get(prop) == subschema['default']
): ):
del instance[property] del instance[prop]
_AddDefaultsValidator = extend_validator_cls( _AddDefaultsValidator = extend_validator_cls(

View file

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import argparse import argparse
import pkg_resources import pkg_resources
import sys
from pre_commit import color from pre_commit import color
from pre_commit.commands.autoupdate import autoupdate from pre_commit.commands.autoupdate import autoupdate
@ -9,12 +10,12 @@ from pre_commit.commands.clean import clean
from pre_commit.commands.install_uninstall import install from pre_commit.commands.install_uninstall import install
from pre_commit.commands.install_uninstall import uninstall from pre_commit.commands.install_uninstall import uninstall
from pre_commit.commands.run import run from pre_commit.commands.run import run
from pre_commit.error_handler import error_handler
from pre_commit.runner import Runner from pre_commit.runner import Runner
from pre_commit.util import entry
@entry def main(argv=None):
def main(argv): argv = argv if argv is not None else sys.argv[1:]
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# http://stackoverflow.com/a/8521644/812183 # http://stackoverflow.com/a/8521644/812183
@ -83,6 +84,7 @@ def main(argv):
else: else:
parser.parse_args(['--help']) parser.parse_args(['--help'])
with error_handler():
runner = Runner.create() runner = Runner.create()
if args.command == 'install': if args.command == 'install':

View file

@ -7,7 +7,6 @@ import os
import os.path import os.path
import pkg_resources import pkg_resources
import shutil import shutil
import sys
import tarfile import tarfile
import tempfile import tempfile
@ -29,19 +28,6 @@ def memoize_by_cwd(func):
return wrapper return wrapper
def entry(func):
"""Allows a function that has `argv` as an argument to be used as a
commandline entry. This will make the function callable using either
explicitly passed argv or defaulting to sys.argv[1:]
"""
@functools.wraps(func)
def wrapper(argv=None):
if argv is None:
argv = sys.argv[1:]
return func(argv)
return wrapper
@contextlib.contextmanager @contextlib.contextmanager
def clean_path_on_failure(path): def clean_path_on_failure(path):
"""Cleans up the directory on an exceptional failure.""" """Cleans up the directory on an exceptional failure."""

View file

@ -1,5 +1,5 @@
[MESSAGES CONTROL] [MESSAGES CONTROL]
disable=missing-docstring,abstract-method,redefined-builtin,useless-else-on-loop,redefined-outer-name,invalid-name disable=locally-disabled,fixme,missing-docstring,abstract-method,useless-else-on-loop,invalid-name
[REPORTS] [REPORTS]
output-format=colorized output-format=colorized

View file

@ -0,0 +1,93 @@
from __future__ import absolute_import
from __future__ import unicode_literals
import io
import os.path
import mock
import pytest
import re
from pre_commit import error_handler
from pre_commit.errors import FatalError
@pytest.yield_fixture
def mocked_log_and_exit():
with mock.patch.object(error_handler, '_log_and_exit') as log_and_exit:
yield log_and_exit
def test_error_handler_no_exception(mocked_log_and_exit):
with error_handler.error_handler():
pass
assert mocked_log_and_exit.call_count == 0
def test_error_handler_fatal_error(mocked_log_and_exit):
exc = FatalError('just a test')
with error_handler.error_handler():
raise exc
mocked_log_and_exit.assert_called_once_with(
'An error has occurred',
exc,
# Tested below
mock.ANY,
)
assert re.match(
'Traceback \(most recent call last\):\n'
' File ".+/pre_commit/error_handler.py", line \d+, in error_handler\n'
' yield\n'
' File ".+/tests/error_handler_test.py", line \d+, '
'in test_error_handler_fatal_error\n'
' raise exc\n'
'(pre_commit\.errors\.)?FatalError: just a test\n',
mocked_log_and_exit.call_args[0][2],
)
def test_error_handler_uncaught_error(mocked_log_and_exit):
exc = ValueError('another test')
with error_handler.error_handler():
raise exc
mocked_log_and_exit.assert_called_once_with(
'An unexpected error has occurred',
exc,
# Tested below
mock.ANY,
)
assert re.match(
'Traceback \(most recent call last\):\n'
' File ".+/pre_commit/error_handler.py", line \d+, in error_handler\n'
' yield\n'
' File ".+/tests/error_handler_test.py", line \d+, '
'in test_error_handler_uncaught_error\n'
' raise exc\n'
'ValueError: another test\n',
mocked_log_and_exit.call_args[0][2],
)
def test_log_and_exit(mock_out_store_directory):
mocked_print = mock.Mock()
with pytest.raises(error_handler.PreCommitSystemExit):
error_handler._log_and_exit(
'msg', FatalError('hai'), "I'm a stacktrace",
print_fn=mocked_print,
)
printed = '\n'.join(call[0][0] for call in mocked_print.call_args_list)
assert printed == (
'msg: FatalError: hai\n'
'Check the log at ~/.pre-commit/pre-commit.log'
)
log_file = os.path.join(mock_out_store_directory, 'pre-commit.log')
assert os.path.exists(log_file)
contents = io.open(log_file).read()
assert contents == (
'msg: FatalError: hai\n'
"I'm a stacktrace\n"
)

View file

@ -6,6 +6,7 @@ import pytest
from plumbum import local from plumbum import local
from pre_commit import git from pre_commit import git
from pre_commit.errors import FatalError
from testing.fixtures import git_dir from testing.fixtures import git_dir
@ -24,6 +25,12 @@ def test_get_root_deeper(tmpdir_factory):
assert git.get_root() == path assert git.get_root() == path
def test_get_root_not_git_dir(tmpdir_factory):
with local.cwd(tmpdir_factory.get()):
with pytest.raises(FatalError):
git.get_root()
def test_is_not_in_merge_conflict(tmpdir_factory): def test_is_not_in_merge_conflict(tmpdir_factory):
path = git_dir(tmpdir_factory) path = git_dir(tmpdir_factory)
with local.cwd(path): with local.cwd(path):

View file

@ -1,15 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import mock
import pytest import pytest
import os import os
import os.path import os.path
import random import random
import sys
from plumbum import local from plumbum import local
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import entry
from pre_commit.util import memoize_by_cwd from pre_commit.util import memoize_by_cwd
from pre_commit.util import shell_escape from pre_commit.util import shell_escape
from pre_commit.util import tmpdir from pre_commit.util import tmpdir
@ -46,28 +43,6 @@ def test_memoized_by_cwd_changes_with_different_cwd(memoized_by_cwd):
assert ret != ret2 assert ret != ret2
@pytest.fixture
def entry_func():
@entry
def func(argv):
return argv
return func
def test_explicitly_passed_argv_are_passed(entry_func):
input = object()
ret = entry_func(input)
assert ret is input
def test_no_arguments_passed_uses_argv(entry_func):
argv = [1, 2, 3, 4]
with mock.patch.object(sys, 'argv', argv):
ret = entry_func()
assert ret == argv[1:]
def test_clean_on_failure_noop(in_tmpdir): def test_clean_on_failure_noop(in_tmpdir):
with clean_path_on_failure('foo'): with clean_path_on_failure('foo'):
pass pass