diff --git a/pre_commit/all_languages.py b/pre_commit/all_languages.py index 476bad9d..e4c8a929 100644 --- a/pre_commit/all_languages.py +++ b/pre_commit/all_languages.py @@ -7,6 +7,7 @@ from pre_commit.languages import dart from pre_commit.languages import docker from pre_commit.languages import docker_image from pre_commit.languages import dotnet +from pre_commit.languages import download from pre_commit.languages import fail from pre_commit.languages import golang from pre_commit.languages import haskell @@ -30,6 +31,7 @@ languages: dict[str, Language] = { 'docker': docker, 'docker_image': docker_image, 'dotnet': dotnet, + 'download': download, 'fail': fail, 'golang': golang, 'haskell': haskell, diff --git a/pre_commit/languages/download.py b/pre_commit/languages/download.py new file mode 100644 index 00000000..80c320af --- /dev/null +++ b/pre_commit/languages/download.py @@ -0,0 +1,328 @@ +from __future__ import annotations + +import contextlib +import hashlib +import os.path +import platform +import stat +from base64 import standard_b64decode as b64decode +from base64 import standard_b64encode as b64encode +from netrc import netrc +from os import chmod +from pathlib import Path +from pathlib import PurePath +from types import MappingProxyType +from typing import Collection +from typing import Final +from typing import Generator +from typing import Iterator +from typing import Mapping +from typing import Protocol +from typing import TYPE_CHECKING +from urllib.parse import urlparse +from urllib.request import build_opener +from urllib.request import HTTPBasicAuthHandler +from urllib.request import HTTPPasswordMgrWithDefaultRealm +from urllib.request import Request + +from pre_commit import lang_base +from pre_commit.envcontext import envcontext +from pre_commit.envcontext import PatchesT +from pre_commit.envcontext import Var +from pre_commit.prefix import Prefix + +if TYPE_CHECKING: + from _typeshed import SupportsRead + +ENVIRONMENT_DIR = 'download' 
class Platform:
    """A target platform written as ``<os>/<cpu>``, e.g. ``linux/amd64``.

    ``OS`` maps the names returned by ``platform.system()`` to the
    operating-system component of a platform string; ``CPU`` does the same
    for ``platform.machine()``.  Only the *values* of those two mappings are
    accepted as components of a platform string.
    """

    OS: Final[Mapping[str, str]] = MappingProxyType({
        'Linux': 'linux',
        'Darwin': 'darwin',
        'Windows': 'windows',
        'DragonFly': 'dragonfly',
        'FreeBSD': 'freebsd',
    })
    CPU: Final[Mapping[str, str]] = MappingProxyType({
        'aarch64': 'arm64',
        # macOS on Apple Silicon reports `arm64` and Windows on ARM reports
        # `ARM64`; without these entries `host()` cannot resolve such hosts.
        'arm64': 'arm64',
        'ARM64': 'arm64',
        'aarch64_be': 'arm64be',
        'arm': 'arm',
        'i386': '386',
        'i686': '386',
        'x86_64': 'amd64',
        'AMD64': 'amd64',
        'ppc': 'ppc',
        'ppc64': 'ppc64',
        'ppc64le': 'ppc64le',
    })

    def __init__(self, value: str) -> None:
        """Validate and store a platform string.

        :param value: platform in ``<os>/<cpu>`` form.
        :raises ValueError: if ``value`` has no ``/`` separator or names an
            operating system or CPU that is not a value of ``OS`` / ``CPU``.
        """
        self._value = value
        os_name, cpu = self.parts
        # `dict.fromkeys` de-duplicates while keeping declaration order, so
        # every valid value is listed exactly once in the messages below.
        valid_os = tuple(dict.fromkeys(self.OS.values()))
        valid_cpu = tuple(dict.fromkeys(self.CPU.values()))
        if os_name not in valid_os:
            raise ValueError(
                f'invalid operating system `{os_name}`, '
                f"valid values are: {','.join(valid_os)}",
            )
        if cpu not in valid_cpu:
            raise ValueError(
                f'invalid CPU `{cpu}`, '
                f"valid values are: {','.join(valid_cpu)}",
            )

    @property
    def parts(self) -> tuple[str, str]:
        """The ``(os, cpu)`` pair, split on the first ``/``.

        :raises ValueError: if the stored value contains no ``/``.
        """
        os_name, separator, cpu = self.value.partition('/')
        if not separator:
            raise ValueError(
                f'invalid platform `{self.value}`, expected `<os>/<cpu>`',
            )
        return (os_name, cpu)

    @property
    def os(self) -> str:
        """The operating-system component."""
        os_name, _ = self.parts
        return os_name

    @property
    def cpu(self) -> str:
        """The CPU component."""
        _, cpu = self.parts
        return cpu

    @property
    def value(self) -> str:
        """The platform string exactly as passed to the constructor."""
        return self._value

    @classmethod
    def host(cls: type[Platform]) -> Platform:
        """Build the platform describing the machine running this process.

        :raises KeyError: if ``platform.system()`` or ``platform.machine()``
            returns a name that is not a key of ``OS`` / ``CPU``.
        """
        system = platform.system()
        machine = platform.machine()
        try:
            os_name = cls.OS[system]
        except KeyError:
            raise KeyError(
                f'unsupported host operating system `{system}`',
            ) from None
        try:
            cpu = cls.CPU[machine]
        except KeyError:
            raise KeyError(f'unsupported host CPU `{machine}`') from None
        return cls(f'{os_name}/{cpu}')

    def __str__(self) -> str:
        return f'{self.os}/{self.cpu}'

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.value!r})'

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Platform):
            return False
        return other.os == self.os and other.cpu == self.cpu

    def __hash__(self) -> int:
        # Defining `__eq__` alone sets `__hash__` to None; hash the same
        # fields that equality compares so equal platforms hash equally.
        return hash((self.os, self.cpu))
class SRI:
    """An integrity string of the form ``<algorithm>-<base64 digest>``.

    ``algorithm`` must be a name known to ``hashlib`` and the digest must be
    standard base64 of exactly that algorithm's digest size.
    """

    def __init__(self, value: str):
        """Parse and validate an integrity string.

        :param value: text such as ``sha256-<base64 digest>``.
        :raises ValueError: if the ``-`` separator is missing, the algorithm
            is unavailable or has no fixed digest size, the checksum is not
            canonical base64, or the decoded length does not match the
            algorithm's digest size.
        """
        self._value = value
        algorithm, separator, checksum = value.partition('-')
        if not separator:
            raise ValueError(
                f'Invalid integrity string `{value}`, '
                'expected `<algorithm>-<base64 checksum>`',
            )
        self._algorithm = algorithm
        self._checksum = checksum

        if algorithm not in hashlib.algorithms_available:
            # Sorted so the message is stable; the set has no defined order.
            available = ','.join(sorted(hashlib.algorithms_available))
            raise ValueError(
                f'`{algorithm}` is not available, '
                f'choose one of: {available}',
            )

        # Round-tripping rejects non-canonical input such as stray
        # characters that a lenient base64 decode would silently discard.
        raw = b64decode(checksum)
        if b64encode(raw).decode('utf-8') != checksum:
            raise ValueError(
                'Invalid checksum string, '
                'the checksum string has to be encoded in base64.',
            )

        hasher = hashlib.new(algorithm)
        if hasher.digest_size == 0:
            # Variable-length digests (the `shake_*` family) report a size
            # of 0 and need an explicit length for `digest()`, which
            # `check()` does not supply.
            raise ValueError(
                f'`{algorithm}` has no fixed digest size and '
                'cannot be used for integrity checks',
            )
        if len(raw) != hasher.digest_size:
            raise ValueError(
                f'Invalid checksum string length of {len(raw)} for '
                f'{algorithm}, expected {hasher.digest_size}',
            )

    @property
    def value(self) -> str:
        """The integrity string exactly as passed to the constructor."""
        return self._value

    @property
    def algorithm(self) -> str:
        """The hash algorithm name (the text before the first ``-``)."""
        return self._algorithm

    @property
    def checksum(self) -> str:
        """The base64-encoded digest (the text after the first ``-``)."""
        return self._checksum

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return f'{type(self).__name__}({self.value!r})'

    def check(
            self, io: SupportsRead[bytes],
            chunk: int = 4096,
    ) -> Iterator[bytes]:
        """Stream ``io`` in pieces while hashing it, then verify the digest.

        Every chunk is yielded *before* the digest is known, so a consumer
        must treat what it received as untrusted until the generator has
        been exhausted without raising.

        :param io: binary readable whose ``read(n)`` returns bytes.
        :param chunk: maximum number of bytes requested per read.
        :raises ChecksumMismatchError: once the stream is exhausted, if the
            computed digest differs from ``checksum``.
        """
        hasher = hashlib.new(self.algorithm)
        while buffer := io.read(chunk):
            hasher.update(buffer)
            yield buffer
        # base64 output is pure ASCII; the previous codec name was the typo
        # 'utf-8)', which only resolved thanks to codec-name normalisation.
        digest = b64encode(hasher.digest()).decode('ascii')
        if digest != self.checksum:
            raise ChecksumMismatchError(
                expected=self,
                actual=SRI(f'{self.algorithm}-{digest}'),
            )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SRI):
            return False
        return other.value == self.value

    def __hash__(self) -> int:
        # Defining `__eq__` alone sets `__hash__` to None; hash the same
        # field that equality compares.
        return hash(self.value)
def health_check(prefix: Prefix, version: str) -> str | None:
    """Re-verify every downloaded file against its recorded checksum.

    Reads ``health.srisum`` from the environment directory; each non-blank
    line is ``<sri> <filename>``.  Every listed file is re-hashed with
    ``SRI.check``.

    :returns: ``None`` when every file matches its checksum, otherwise a
        message describing the first problem found.  An unreadable
        ``health.srisum``, a malformed entry, or a missing downloaded file
        is reported the same way instead of being raised.
    """
    reinstall = 'Please reinstall the Download environment'
    envdir = Path(lang_base.environment_dir(prefix, ENVIRONMENT_DIR, version))
    srisum = envdir / 'health.srisum'

    try:
        with srisum.open(encoding='utf8') as stream:
            entries = [line.strip() for line in stream if line.strip()]
    except OSError as err:
        return f'{srisum.name} could not be read: {err}\n{reinstall}'

    for entry in entries:
        sri_str, separator, filepath = entry.partition(' ')
        if not separator:
            return f'{srisum.name} has a malformed entry `{entry}`\n{reinstall}'
        try:
            sri = SRI(sri_str)
        except ValueError as err:
            return f'{srisum.name} has an invalid checksum: {err}\n{reinstall}'

        try:
            with (envdir / filepath).open('rb') as fp:
                # `check` is a generator: it has to be drained for the
                # digest comparison at the end of the stream to run.
                for _ in sri.check(fp):
                    pass
        except ChecksumMismatchError as err:
            # The error text already ends with a newline.
            return f'{filepath} {err}{reinstall}'
        except OSError as err:
            return f'{filepath} could not be read: {err}\n{reinstall}'
    return None
@fixture
def server(script: Script) -> Iterator[Server]:
    """Serve ``script.content`` over HTTP for the duration of one test.

    Every GET request, whatever its path, is answered with the script body.
    The server runs in a background thread and is shut down when the test
    finishes.
    """
    class HTTPRequestHandler(BaseHTTPRequestHandler):
        def do_GET(self) -> None:
            self.send_response(200)
            self.send_header('Content-Length', str(len(script.content)))
            self.end_headers()
            self.wfile.write(script.content)

    # Port 0 lets the OS pick a free port, so parallel test runs (or an
    # unrelated process already bound to a fixed port) cannot collide.
    httpd = HTTPServer(('localhost', 0), HTTPRequestHandler)
    # Daemon thread: a failed shutdown must not keep the test process alive.
    thread = Thread(
        target=httpd.serve_forever, name='HTTP Server', daemon=True,
    )
    thread.start()
    try:
        host, port = httpd.server_address[:2]
        uri = URI(f'http://{str(host)}:{str(port)}')
        # Pair the URI with the script being served; the previous code
        # passed pytest's `fixture` decorator here by mistake.
        yield Server(uri, script)
    finally:
        httpd.shutdown()
        httpd.server_close()
        thread.join(1)
        assert not thread.is_alive()
def test_download_unhealthy(prefix: Prefix, unhealthy: Fixture) -> None:
    """Installing with a wrong SRI checksum must fail.

    The ``unhealthy`` fixture pairs each platform with the checksum of the
    *other* script, so the download is expected to raise
    ``ChecksumMismatchError``.  The test fails if nothing is raised.
    """
    # Imported locally because the module only imports `fixture` from pytest.
    import pytest

    with pytest.raises(ChecksumMismatchError):
        download.install_environment(
            prefix, C.DEFAULT, unhealthy.dependencies,
        )