Add a 'download' language for fetching pre-built binaries

This commit implements https://github.com/pre-commit/pre-commit/issues/1453#issuecomment-1702880253
This commit is contained in:
Fufu Fang 2023-08-12 09:59:52 +01:00
parent 19aa121db0
commit d0a5fde7ae
3 changed files with 462 additions and 0 deletions

View file

@ -7,6 +7,7 @@ from pre_commit.languages import dart
from pre_commit.languages import docker
from pre_commit.languages import docker_image
from pre_commit.languages import dotnet
from pre_commit.languages import download
from pre_commit.languages import fail
from pre_commit.languages import golang
from pre_commit.languages import haskell
@ -30,6 +31,7 @@ languages: dict[str, Language] = {
'docker': docker,
'docker_image': docker_image,
'dotnet': dotnet,
'download': download,
'fail': fail,
'golang': golang,
'haskell': haskell,

View file

@ -0,0 +1,328 @@
from __future__ import annotations
import contextlib
import hashlib
import os.path
import platform
import stat
from base64 import standard_b64decode as b64decode
from base64 import standard_b64encode as b64encode
from netrc import netrc
from os import chmod
from pathlib import Path
from pathlib import PurePath
from types import MappingProxyType
from typing import Collection
from typing import Final
from typing import Generator
from typing import Iterator
from typing import Mapping
from typing import Protocol
from typing import TYPE_CHECKING
from urllib.parse import urlparse
from urllib.request import build_opener
from urllib.request import HTTPBasicAuthHandler
from urllib.request import HTTPPasswordMgrWithDefaultRealm
from urllib.request import Request
from pre_commit import lang_base
from pre_commit.envcontext import envcontext
from pre_commit.envcontext import PatchesT
from pre_commit.envcontext import Var
from pre_commit.prefix import Prefix
if TYPE_CHECKING:
from _typeshed import SupportsRead
ENVIRONMENT_DIR = 'download'
get_default_version = lang_base.basic_get_default_version
run_hook = lang_base.basic_run_hook
def get_env_patch(target_dir: str) -> PatchesT:
return (
('PATH', (target_dir, os.pathsep, Var('PATH'))),
)
@contextlib.contextmanager
def in_env(prefix: Prefix, version: str) -> Generator[None, None, None]:
envdir = lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version)
with envcontext(get_env_patch(envdir)):
yield
class Platform:
OS: Final[Mapping[str, str]] = MappingProxyType({
'Linux': 'linux',
'Darwin': 'darwin',
'Windows': 'windows',
'DragonFly': 'dragonfly',
'FreeBSD': 'freebsd',
})
CPU: Final[Mapping[str, str]] = MappingProxyType({
'aarch64': 'arm64',
'aarch64_be': 'arm64be',
'arm': 'arm',
'i386': '386',
'i686': '386',
'x86_64': 'amd64',
'AMD64': 'amd64',
'ppc': 'ppc',
'ppc64': 'ppc64',
'ppc64le': 'ppc64le',
})
def __init__(self, value: str) -> None:
self._value = value
os, cpu = self.parts
if os not in self.OS.values():
raise ValueError(f"invalid operating system `{os}`, \
valid values are: {','.join(self.OS.values())}")
if cpu not in self.CPU.values():
raise ValueError(f"invalid CPU `{cpu}`, \
valid values are: {','.join(self.CPU.values())}")
@property
def parts(self) -> tuple[str, str]:
first, second = self.value.split('/', 1)
return (first, second)
@property
def os(self) -> str:
os, _ = self.parts
return os
@property
def cpu(self) -> str:
_, cpu = self.parts
return cpu
@property
def value(self) -> str:
return self._value
@classmethod
def host(cls: type[Platform]) -> Platform:
os = cls.OS[platform.system()]
cpu = cls.CPU[platform.machine()]
return cls(f'{os}/{cpu}')
def __str__(self) -> str:
return f'{self.os}/{self.cpu}'
def __eq__(self, other: object) -> bool:
if not isinstance(other, Platform):
return False
return other.os == self.os and other.cpu == self.cpu
class ChecksumMismatchError(ValueError):
def __init__(
self, *, expected: SRI, actual: SRI,
message: str = 'checksum mismatch',
):
self._expected = expected
self._actual = actual
self._message = message
@property
def expected(self) -> SRI:
return self._expected
@property
def actual(self) -> SRI:
return self._actual
@property
def message(self) -> str:
return self._message
def __str__(self) -> str:
return f'''{self._message}:
- expected: {self.expected}
- actual : {self.actual}
'''
class HasBinaryRead(Protocol):
def read(self, __size: int = -1) -> bytes | None: ...
class SRI:
def __init__(self, value: str):
self._value = value
self._algorithm, self._checksum = self.value.split('-', 1)
if self.algorithm not in hashlib.algorithms_available:
raise ValueError(f"`{self.algorithm}` is not available, \
choose one of: {','.join(hashlib.algorithms_available)}`")
if b64encode(b64decode(self.checksum)).decode('utf-8') !=\
self.checksum:
raise ValueError('Invalid checksum string, \
the checksum string has to be encoded in base64.')
hasher = hashlib.new(self.algorithm)
checksum_len = len(b64decode(self.checksum))
if checksum_len != hasher.digest_size:
raise ValueError(
f'Invalid checksum string length of {checksum_len} for \
{self.algorithm}, expected {hasher.digest_size}',
)
@property
def value(self) -> str:
return self._value
@property
def algorithm(self) -> str:
return self._algorithm
@property
def checksum(self) -> str:
return self._checksum
def __str__(self) -> str:
return self.value
def check(
self, io: SupportsRead[bytes],
chunk: int = 4096,
) -> Iterator[bytes]:
hasher = hashlib.new(self.algorithm)
while buffer := io.read(chunk):
hasher.update(buffer)
yield buffer
digest = b64encode(hasher.digest()).decode('utf-8)')
if digest != self.checksum:
raise ChecksumMismatchError(
expected=self,
actual=SRI(f'{self.algorithm}-{digest}'),
)
def __eq__(self, other: object) -> bool:
if not isinstance(other, SRI):
return False
return other.value == self.value
class URI:
def __init__(self, value: str) -> None:
self._value = value
url = urlparse(self.value)
if not all([url.scheme, url.netloc]):
raise ValueError(f'Invalid URI: {self.value}')
self._netloc = url.netloc
@property
def netloc(self) -> str:
return self._netloc
@property
def value(self) -> str:
return self._value
def __str__(self) -> str:
return self.value
class Metadata:
def __init__(self, value: str) -> None:
self._value = value
@property
def parts(self) -> tuple[str, str, str, str]:
first, second, third, fourth = self.value.splitlines()
return (first, second, third, fourth)
@property
def platform(self) -> Platform:
platform, _, _, _ = self.parts
return Platform(platform)
@property
def sri(self) -> SRI:
_, sri, _, _ = self.parts
return SRI(sri)
@property
def uri(self) -> URI:
_, _, uri, _ = self.parts
return URI(uri)
@property
def filename(self) -> PurePath:
_, _, _, path = self.parts
return PurePath(path)
@property
def value(self) -> str:
return self._value
def download(uri: URI, sri: SRI, filename: Path) -> None:
request = Request(str(uri))
manager = HTTPPasswordMgrWithDefaultRealm()
handler = HTTPBasicAuthHandler(manager)
opener = build_opener(handler)
try:
netrc()
except FileNotFoundError:
pass
else:
authenticators = netrc().authenticators(uri.netloc)
if authenticators is not None:
login, _, password = authenticators
manager.add_password(None, str(uri), login, password or '')
with opener.open(request) as ws:
with filename.open('wb') as fp:
for buffer in sri.check(ws):
fp.write(buffer)
fp.flush()
def install_environment(
prefix: Prefix,
version: str,
additional_dependencies: Collection[str],
) -> None:
host = Platform.host()
for dep in additional_dependencies:
m = Metadata(dep)
if host == m.platform:
envdir = Path(
lang_base.environment_dir(
prefix, ENVIRONMENT_DIR,
version,
),
)
envdir.mkdir(parents=True, exist_ok=True)
filename = Path(m.filename)
download(m.uri, m.sri, envdir / filename)
chmod(envdir / filename, stat.S_IRUSR | stat.S_IXUSR)
srisum = envdir / 'health.srisum'
with srisum.open('w', encoding='utf8') as stream:
stream.write(f'{m.sri} {filename}\n')
return
raise KeyError(f'Failed to find platform `{host}` in \
`additional_dependencies`: {additional_dependencies}')
def health_check(prefix: Prefix, version: str) -> str | None:
envdir = Path(lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version))
srisum = envdir / 'health.srisum'
with srisum.open(encoding='utf8') as stream:
for line in stream:
sri_str, filepath = line.strip().split(' ', 1)
filename = envdir / filepath
sri = SRI(sri_str)
with filename.open('rb') as fp:
try:
for _ in sri.check(fp):
pass
except ChecksumMismatchError as err:
return f'{filepath} {err}\
Please reinstall the Download environment'
return None

View file

@ -0,0 +1,132 @@
from __future__ import annotations
import platform
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler
from http.server import HTTPServer
from threading import Thread
from typing import Collection
from typing import Iterator
import py.path
from pytest import fixture
import pre_commit.constants as C
from pre_commit.languages import download
from pre_commit.languages.download import ChecksumMismatchError
from pre_commit.languages.download import SRI
from pre_commit.languages.download import URI
from pre_commit.prefix import Prefix
@dataclass(frozen=True)
class Script:
content: bytes
integrity: SRI
@fixture
def shell() -> Script:
content = b'#!/bin/sh\necho hello\nexit 123'
integrity = SRI('sha256-oRJkj6Cr8nWIivZ9d3W+rVZt/aSW1l9YtxSVh+GtIHM=')
return Script(content, integrity)
@fixture
def batch() -> Script:
content = b'@echo off\necho hello\nexit 123'
integrity = SRI('sha256-L63Nefq+fKVIm24IKqlcqbJmc1rrJD3dKhIvutFK+IA=')
return Script(content, integrity)
@fixture
def script(shell: Script, batch: Script) -> Script:
if platform.system() == 'Windows':
return batch
else:
return shell
@dataclass(frozen=True)
class Server:
uri: URI
script: Script
@fixture
def server(script: Script) -> Iterator[Server]:
class HTTPRequestHandler(BaseHTTPRequestHandler):
def do_GET(self):
self.send_response(200)
self.end_headers()
self.wfile.write(script.content)
httpd = HTTPServer(('localhost', 5555), HTTPRequestHandler)
thread = Thread(target=httpd.serve_forever, name='HTTP Server')
thread.start()
host, port = httpd.server_address
uri = URI(f'http://{str(host)}:{str(port)}')
try:
yield Server(uri, fixture)
finally:
httpd.shutdown()
httpd.server_close()
thread.join(1)
assert not thread.is_alive()
@dataclass(frozen=True)
class Fixture:
server: Server
dependencies: Collection[str]
@fixture
def healthy(server: Server, shell: Script, batch: Script) -> Fixture:
dependencies = (
f"""linux/amd64
{shell.integrity}
{server.uri}
test.bat""",
f"""windows/amd64
{batch.integrity}
{server.uri}
test.bat""",
)
return Fixture(server, dependencies)
@fixture
def unhealthy(server: Server, shell: Script, batch: Script) -> Fixture:
dependencies = (
f"""linux/amd64
{batch.integrity}
{server.uri}
test.bat""",
f"""windows/amd64
{shell.integrity}
{server.uri}
test.bat""",
)
return Fixture(server, dependencies)
@fixture
def prefix(tmpdir: py.path) -> Iterator[Prefix]:
with tmpdir.as_cwd():
directory = tmpdir.join('prefix').ensure_dir()
prefix = Prefix(str(directory))
yield prefix
def test_download_healthy(prefix: Prefix, healthy: Fixture) -> None:
"""Do a download test with healthy SRI checksum"""
download.install_environment(prefix, C.DEFAULT, healthy.dependencies)
assert download.health_check(prefix, C.DEFAULT) is None
def test_download_unhealthy(prefix: Prefix, unhealthy: Fixture) -> None:
"""Do a download test with unhealthy SRI checksum"""
try:
download.install_environment(prefix, C.DEFAULT, unhealthy.dependencies)
except ChecksumMismatchError:
pass