Compute the maximum command length more accurately

This commit is contained in:
Anthony Sottile 2019-03-09 14:36:24 -08:00
parent 518a72d7e7
commit 985f09ff88
2 changed files with 45 additions and 15 deletions

View file

@@ -5,6 +5,7 @@ from __future__ import unicode_literals
import concurrent.futures import concurrent.futures
import contextlib import contextlib
import math import math
import os
import sys import sys
import six import six
@@ -13,10 +14,24 @@ from pre_commit import parse_shebang
from pre_commit.util import cmd_output from pre_commit.util import cmd_output
# TODO: properly compute max_length value def _environ_size(_env=None):
def _get_platform_max_length(): environ = _env if _env is not None else getattr(os, 'environb', os.environ)
size = 8 * len(environ) # number of pointers in `envp`
for k, v in environ.items():
size += len(k) + len(v) + 2 # c strings in `envp`
return size
def _get_platform_max_length(): # pragma: no cover (platform specific)
if os.name == 'posix':
maximum = os.sysconf(str('SC_ARG_MAX')) - 2048 - _environ_size()
maximum = min(maximum, 2 ** 17)
return maximum
elif os.name == 'nt':
return 2 ** 15 - 2048 # UNICODE_STRING max - headroom
else:
# posix minimum # posix minimum
return 4 * 1024 return 2 ** 12
def _command_length(*cmd): def _command_length(*cmd):
@@ -52,7 +67,7 @@ def partition(cmd, varargs, target_concurrency, _max_length=None):
# Reversed so arguments are in order # Reversed so arguments are in order
varargs = list(reversed(varargs)) varargs = list(reversed(varargs))
total_length = _command_length(*cmd) total_length = _command_length(*cmd) + 1
while varargs: while varargs:
arg = varargs.pop() arg = varargs.pop()
@@ -69,7 +84,7 @@ def partition(cmd, varargs, target_concurrency, _max_length=None):
# We've exceeded the length, yield a command # We've exceeded the length, yield a command
ret.append(cmd + tuple(ret_cmd)) ret.append(cmd + tuple(ret_cmd))
ret_cmd = [] ret_cmd = []
total_length = _command_length(*cmd) total_length = _command_length(*cmd) + 1
varargs.append(arg) varargs.append(arg)
ret.append(cmd + tuple(ret_cmd)) ret.append(cmd + tuple(ret_cmd))
@@ -99,7 +114,7 @@ def xargs(cmd, varargs, **kwargs):
stderr = b'' stderr = b''
try: try:
parse_shebang.normexe(cmd[0]) cmd = parse_shebang.normalize_cmd(cmd)
except parse_shebang.ExecutableNotFoundError as e: except parse_shebang.ExecutableNotFoundError as e:
return e.to_output() return e.to_output()

View file

@@ -10,9 +10,24 @@ import mock
import pytest import pytest
import six import six
from pre_commit import parse_shebang
from pre_commit import xargs from pre_commit import xargs
@pytest.mark.parametrize(
('env', 'expected'),
(
({}, 0),
({b'x': b'1'}, 12),
({b'x': b'12'}, 13),
({b'x': b'1', b'y': b'2'}, 24),
),
)
def test_environ_size(env, expected):
# normalize integer sizing
assert xargs._environ_size(_env=env) == expected
@pytest.fixture @pytest.fixture
def win32_py2_mock(): def win32_py2_mock():
with mock.patch.object(sys, 'getfilesystemencoding', return_value='utf-8'): with mock.patch.object(sys, 'getfilesystemencoding', return_value='utf-8'):
@@ -56,7 +71,7 @@ def test_partition_limits():
'.' * 6, '.' * 6,
), ),
1, 1,
_max_length=20, _max_length=21,
) )
assert ret == ( assert ret == (
('ninechars', '.' * 5, '.' * 4), ('ninechars', '.' * 5, '.' * 4),
@@ -70,21 +85,21 @@ def test_partition_limit_win32_py3(win32_py3_mock):
cmd = ('ninechars',) cmd = ('ninechars',)
# counted as half because of utf-16 encode # counted as half because of utf-16 encode
varargs = ('😑' * 5,) varargs = ('😑' * 5,)
ret = xargs.partition(cmd, varargs, 1, _max_length=20) ret = xargs.partition(cmd, varargs, 1, _max_length=21)
assert ret == (cmd + varargs,) assert ret == (cmd + varargs,)
def test_partition_limit_win32_py2(win32_py2_mock): def test_partition_limit_win32_py2(win32_py2_mock):
cmd = ('ninechars',) cmd = ('ninechars',)
varargs = ('😑' * 5,) # 4 bytes * 5 varargs = ('😑' * 5,) # 4 bytes * 5
ret = xargs.partition(cmd, varargs, 1, _max_length=30) ret = xargs.partition(cmd, varargs, 1, _max_length=31)
assert ret == (cmd + varargs,) assert ret == (cmd + varargs,)
def test_partition_limit_linux(linux_mock): def test_partition_limit_linux(linux_mock):
cmd = ('ninechars',) cmd = ('ninechars',)
varargs = ('😑' * 5,) varargs = ('😑' * 5,)
ret = xargs.partition(cmd, varargs, 1, _max_length=30) ret = xargs.partition(cmd, varargs, 1, _max_length=31)
assert ret == (cmd + varargs,) assert ret == (cmd + varargs,)
@@ -134,9 +149,9 @@ def test_xargs_smoke():
assert err == b'' assert err == b''
exit_cmd = ('bash', '-c', 'exit $1', '--') exit_cmd = parse_shebang.normalize_cmd(('bash', '-c', 'exit $1', '--'))
# Abuse max_length to control the exit code # Abuse max_length to control the exit code
max_length = len(' '.join(exit_cmd)) + 2 max_length = len(' '.join(exit_cmd)) + 3
def test_xargs_negate(): def test_xargs_negate():
@@ -165,14 +180,14 @@ def test_xargs_concurrency():
def test_xargs_concurrency(): def test_xargs_concurrency():
bash_cmd = ('bash', '-c') bash_cmd = parse_shebang.normalize_cmd(('bash', '-c'))
print_pid = ('sleep 0.5 && echo $$',) print_pid = ('sleep 0.5 && echo $$',)
start = time.time() start = time.time()
ret, stdout, _ = xargs.xargs( ret, stdout, _ = xargs.xargs(
bash_cmd, print_pid * 5, bash_cmd, print_pid * 5,
target_concurrency=5, target_concurrency=5,
_max_length=len(' '.join(bash_cmd + print_pid)), _max_length=len(' '.join(bash_cmd + print_pid)) + 1,
) )
elapsed = time.time() - start elapsed = time.time() - start
assert ret == 0 assert ret == 0