diff --git a/pre_commit/commands/run.py b/pre_commit/commands/run.py index 98ae25dc..a62d0474 100644 --- a/pre_commit/commands/run.py +++ b/pre_commit/commands/run.py @@ -48,12 +48,17 @@ def _filter_by_include_exclude(filenames, include, exclude): } -def _filter_by_types(filenames, types, exclude_types): +def _filter_by_types(filenames, + types, + exclude_types, + get_tags=tags_from_path): types, exclude_types = frozenset(types), frozenset(exclude_types) + valid_types = types - exclude_types + ret = [] for filename in filenames: - tags = tags_from_path(filename) - if tags >= types and not tags & exclude_types: + tags = frozenset(get_tags(filename)) + if len(valid_types.intersection(tags)) > 0: ret.append(filename) return tuple(ret) diff --git a/tests/commands/run_test.py b/tests/commands/run_test.py index d664e801..e9e66534 100644 --- a/tests/commands/run_test.py +++ b/tests/commands/run_test.py @@ -9,12 +9,14 @@ from collections import OrderedDict import pytest +from identify.identify import tags_from_interpreter import pre_commit.constants as C from pre_commit.commands.install_uninstall import install from pre_commit.commands.run import _compute_cols from pre_commit.commands.run import _filter_by_include_exclude from pre_commit.commands.run import _get_skips from pre_commit.commands.run import _has_unmerged_paths +from pre_commit.commands.run import _filter_by_types from pre_commit.commands.run import run from pre_commit.runner import Runner from pre_commit.util import cmd_output @@ -831,3 +833,17 @@ def test_include_exclude_does_search_instead_of_match(some_filenames): def test_include_exclude_exclude_removes_files(some_filenames): ret = _filter_by_include_exclude(some_filenames, '', r'\.py$') assert ret == {'.pre-commit-hooks.yaml'} + + +def get_tags_stub(interpreter): + return lambda x: tags_from_interpreter(interpreter) + + +def test_filter_by_types_for_bash_by_interpreter(): + ret = _filter_by_types(['bash_script'], ['shell', 'sh', 'bash'], [], get_tags=get_tags_stub('bash')) + assert ret == ('bash_script',) + + +def test_filter_by_types_for_python_by_interpreter(): + ret = _filter_by_types(['script.py'], ['python'], [], get_tags=get_tags_stub('python')) + assert ret == ('script.py',)