diff --git a/pre_commit/xargs.py b/pre_commit/xargs.py index eff57ce7..a7493c01 100644 --- a/pre_commit/xargs.py +++ b/pre_commit/xargs.py @@ -24,6 +24,14 @@ TRet = TypeVar('TRet') def cpu_count() -> int: + try: + # On systems that support it, this will return a more accurate count of + # usable CPUs for the current process, which will take into account + # cgroup limits + return len(os.sched_getaffinity(0)) + except AttributeError: + pass + try: return multiprocessing.cpu_count() except NotImplementedError: diff --git a/tests/lang_base_test.py b/tests/lang_base_test.py index a532b6a5..1cffa0e5 100644 --- a/tests/lang_base_test.py +++ b/tests/lang_base_test.py @@ -30,6 +30,19 @@ def homedir_mck(): yield +@pytest.fixture +def no_sched_getaffinity(): + # Simulates an OS without os.sched_getaffinity available (mac/windows) + # https://docs.python.org/3/library/os.html#interface-to-the-scheduler + with mock.patch.object( + os, + 'sched_getaffinity', + create=True, + side_effect=AttributeError, + ): + yield + + def test_exe_exists_does_not_exist(find_exe_mck, homedir_mck): find_exe_mck.return_value = None assert lang_base.exe_exists('ruby') is False @@ -116,7 +129,17 @@ def test_no_env_noop(tmp_path): assert before == inside == after -def test_target_concurrency_normal(): +def test_target_concurrency_sched_getaffinity(no_sched_getaffinity): + with mock.patch.object( + os, + 'sched_getaffinity', + return_value=set(range(345)), + ): + with mock.patch.dict(os.environ, clear=True): + assert lang_base.target_concurrency() == 345 + + +def test_target_concurrency_without_sched_getaffinity(no_sched_getaffinity): with mock.patch.object(multiprocessing, 'cpu_count', return_value=123): with mock.patch.dict(os.environ, {}, clear=True): assert lang_base.target_concurrency() == 123 @@ -134,7 +157,7 @@ def test_target_concurrency_on_travis(): assert lang_base.target_concurrency() == 2 -def test_target_concurrency_cpu_count_not_implemented(): +def test_target_concurrency_cpu_count_not_implemented(no_sched_getaffinity): with mock.patch.object( multiprocessing, 'cpu_count', side_effect=NotImplementedError, ):