diff --git a/pre_commit/git.py b/pre_commit/git.py index ea065370..3802b5c0 100644 --- a/pre_commit/git.py +++ b/pre_commit/git.py @@ -1,14 +1,12 @@ +import functools import os -import re import pkg_resources +import re from plumbum import local from pre_commit.util import memoize_by_cwd -def _get_root_original(): - return local['git']['rev-parse', '--show-toplevel']().strip() - def _get_root_new(): path = os.getcwd() while len(path) > 1: @@ -22,7 +20,6 @@ def _get_root_new(): @memoize_by_cwd def get_root(): return _get_root_new() - return local['git']['rev-parse', '--show-toplevel']().strip() @memoize_by_cwd @@ -50,23 +47,24 @@ def get_staged_files(): return local['git']['diff', '--staged', '--name-only']().splitlines() -@memoize_by_cwd -def get_staged_files_matching(expr): - regex = re.compile(expr) - return set( - filename for filename in get_staged_files() if regex.search(filename) - ) - - @memoize_by_cwd def get_all_files(): return local['git']['ls-files']().splitlines() -# Smell: this is duplicated above -@memoize_by_cwd -def get_all_files_matching(expr): - regex = re.compile(expr) - return set( - filename for filename in get_all_files() if regex.search(filename) - ) \ No newline at end of file +def get_files_matching(all_file_list_strategy): + @functools.wraps(all_file_list_strategy) + @memoize_by_cwd + def wrapper(expr): + regex = re.compile(expr) + return set( + filename + for filename in all_file_list_strategy() + if regex.search(filename) + ) + return wrapper + + + +get_staged_files_matching = get_files_matching(get_staged_files) +get_all_files_matching = get_files_matching(get_all_files) diff --git a/pre_commit/util.py b/pre_commit/util.py index 3a238590..648b7892 100644 --- a/pre_commit/util.py +++ b/pre_commit/util.py @@ -22,15 +22,16 @@ class cached_property(object): def memoize_by_cwd(func): """Memoize a function call based on os.getcwd().""" - cache = {} @functools.wraps(func) def wrapper(*args): cwd = os.getcwd() key = (cwd,) + args try: - return cache[key] + return wrapper._cache[key] except KeyError: - ret = cache[key] = func(*args) + ret = wrapper._cache[key] = func(*args) return ret + wrapper._cache = {} + return wrapper