from __future__ import annotations

import contextlib
import json
import logging
import os.path
import sqlite3
import tempfile
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Sequence
from typing import Any

import pre_commit.constants as C
from pre_commit import clientlib
from pre_commit import file_lock
from pre_commit import git
from pre_commit.util import CalledProcessError
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b
from pre_commit.util import resource_text
from pre_commit.util import rmtree

logger = logging.getLogger('pre_commit')


def _get_default_directory() -> str:
    """Returns the default directory for the Store.  This is intentionally
    underscored to indicate that `Store.get_default_directory` is the intended
    way to get this information.  This is also done so
    `Store.get_default_directory` can be mocked in tests and
    `_get_default_directory` can be tested.
    """
    ret = os.environ.get('PRE_COMMIT_HOME') or os.path.join(
        os.environ.get('XDG_CACHE_HOME') or os.path.expanduser('~/.cache'),
        'pre-commit',
    )
    return os.path.realpath(ret)


_LOCAL_RESOURCES = (
    'Cargo.toml', 'main.go', 'go.mod', 'main.rs', '.npmignore',
    'package.json', 'pre-commit-package-dev-1.rockspec',
    'pre_commit_placeholder_package.gemspec', 'setup.py',
    'environment.yml', 'Makefile.PL', 'pubspec.yaml',
    'renv.lock', 'renv/activate.R', 'renv/LICENSE.renv',
)


def _make_local_repo(directory: str) -> None:
    for resource in _LOCAL_RESOURCES:
        resource_dirname, resource_basename = os.path.split(resource)
        contents = resource_text(f'empty_template_{resource_basename}')

        target_dir = os.path.join(directory, resource_dirname)
        target_file = os.path.join(target_dir, resource_basename)
        os.makedirs(target_dir, exist_ok=True)
        with open(target_file, 'w') as f:
            f.write(contents)


class Store:
    get_default_directory = staticmethod(_get_default_directory)

    def __init__(self, directory: str | None = None) -> None:
        self.directory = directory or Store.get_default_directory()
        self.db_path = os.path.join(self.directory, 'db5.db')
        self.readonly = (
            os.path.exists(self.directory) and
            not os.access(self.directory, os.W_OK)
        )

        if not os.path.exists(self.directory):
            os.makedirs(self.directory, exist_ok=True)
            with open(os.path.join(self.directory, 'README'), 'w') as f:
                f.write(
                    'This directory is maintained by the pre-commit project.\n'
                    'Learn more: https://github.com/pre-commit/pre-commit\n',
                )

        if os.path.exists(self.db_path):
            return
        with self.exclusive_lock():
            # Another process may have already completed this work
            if os.path.exists(self.db_path):  # pragma: no cover (race)
                return
            # To avoid a race where someone ^Cs between db creation and
            # execution of the CREATE TABLE statements
            fd, tmpfile = tempfile.mkstemp(dir=self.directory)
            # We'll be managing this file ourselves
            os.close(fd)
            with self.connect(db_path=tmpfile) as db:
                db.executescript(
                    'CREATE TABLE configs ('
                    '    path TEXT NOT NULL,'
                    '    PRIMARY KEY (path)'
                    ');',
                )
                db.executescript(
                    'CREATE TABLE manifests ('
                    '    repo TEXT NOT NULL,'
                    '    rev TEXT NOT NULL,'
                    '    manifest TEXT NOT NULL,'
                    '    PRIMARY KEY (repo, rev)'
                    ');',
                )
                db.executescript(
                    'CREATE TABLE clones ('
                    '    repo TEXT NOT NULL,'
                    '    rev TEXT NOT NULL,'
                    '    path TEXT NOT NULL,'
                    '    PRIMARY KEY (repo, rev)'
                    ');',
                )
                db.executescript(
                    'CREATE TABLE installs ('
                    '    repo TEXT NOT NULL,'
                    '    rev TEXT NOT NULL,'
                    '    language TEXT NOT NULL,'
                    '    language_version TEXT NOT NULL,'
                    '    additional_dependencies TEXT NOT NULL,'
                    '    path TEXT NOT NULL,'
                    '    PRIMARY KEY ('
                    '        repo, rev, language, language_version,'
                    '        additional_dependencies'
                    '    )'
                    ');',
                )

            # Atomic file move
            os.replace(tmpfile, self.db_path)

    @contextlib.contextmanager
    def exclusive_lock(self) -> Generator[None]:
        def blocked_cb() -> None:  # pragma: no cover (tests are in-process)
            logger.info('Locking pre-commit directory')

        with file_lock.lock(os.path.join(self.directory, '.lock'), blocked_cb):
            yield

    @contextlib.contextmanager
    def connect(
            self,
            db_path: str | None = None,
    ) -> Generator[sqlite3.Connection]:
        db_path = db_path or self.db_path
        # sqlite doesn't close its fd with its contextmanager >.<
        # contextlib.closing fixes this.
        # See: https://stackoverflow.com/a/28032829/812183
        with contextlib.closing(sqlite3.connect(db_path)) as db:
            # this creates a transaction
            with db:
                yield db

    def _new_repo(
            self,
            repo: str,
            rev: str,
            make_strategy: Callable[[str], None],
    ) -> str:
        def _get_result() -> str | None:
            # Check if we already exist
            with self.connect() as db:
                result = db.execute(
                    'SELECT path FROM clones WHERE repo = ? AND rev = ?',
                    (repo, rev),
                ).fetchone()
                return result[0] if result else None

        result = _get_result()
        if result:
            return result
        with self.exclusive_lock():
            # Another process may have already completed this work
            result = _get_result()
            if result:  # pragma: no cover (race)
                return result

            logger.info(f'Cloning {repo}...')

            directory = tempfile.mkdtemp(prefix='clone', dir=self.directory)
            with clean_path_on_failure(directory):
                make_strategy(directory)

            manifest = clientlib.load_manifest(
                os.path.join(directory, C.MANIFEST_FILE),
                display_filename=f'({repo})/{C.MANIFEST_FILE}',
            )
            by_id = {hook['id']: hook for hook in manifest}

            # Update our db with the created repo
            with self.connect() as db:
                db.execute(
                    'INSERT INTO clones VALUES (?, ?, ?)',
                    (repo, rev, directory),
                )
                # OR REPLACE: a stale manifest row may remain after gc
                # removed the clone for this (repo, rev)
                db.execute(
                    'INSERT OR REPLACE INTO manifests VALUES (?, ?, ?)',
                    (repo, rev, json.dumps(by_id)),
                )

            clientlib.warn_for_stages_on_repo_init(repo, directory)

            return directory

    def _complete_clone(self, rev: str, git_cmd: Callable[..., None]) -> None:
        """Perform a complete clone of a repository and its submodules """

        git_cmd('fetch', 'origin', '--tags')
        git_cmd('checkout', rev)
        git_cmd('submodule', 'update', '--init', '--recursive')

    def _shallow_clone(self, rev: str, git_cmd: Callable[..., None]) -> None:
        """Perform a shallow clone of a repository and its submodules """

        v2 = ('-c', 'protocol.version=2')
        git_cmd(*v2, 'fetch', 'origin', rev, '--depth=1')
        git_cmd('checkout', 'FETCH_HEAD')
        git_cmd(
            *v2, 'submodule', 'update', '--init', '--recursive', '--depth=1',
        )

    def clone(self, repo: str, rev: str) -> str:
        """Clone the given url and checkout the specific rev."""

        def clone_strategy(directory: str) -> None:
            git.init_repo(directory, repo)
            env = git.no_git_env()

            def _git_cmd(*args: str) -> None:
                cmd_output_b('git', *args, cwd=directory, env=env)

            try:
                self._shallow_clone(rev, _git_cmd)
            except CalledProcessError:
                self._complete_clone(rev, _git_cmd)

        return self._new_repo(repo, rev, clone_strategy)

    def make_local(self) -> str:
        return self._new_repo('local', C.LOCAL_REPO_VERSION, _make_local_repo)

    def mark_config_used(self, path: str) -> None:
        if self.readonly:  # pragma: win32 no cover
            return
        path = os.path.realpath(path)
        # don't insert config files that do not exist
        if not os.path.exists(path):
            return
        with self.connect() as db:
            db.execute('INSERT OR IGNORE INTO configs VALUES (?)', (path,))
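
    # Assumed helper: _mark_used_installs (below) keys local/remote usage by
    # a deps-qualified repo name but no such method is defined here; this
    # sketch mirrors pre-commit's Store.db_repo_name.
    def db_repo_name(self, repo: str, deps: Sequence[str] | None) -> str:
        if deps:
            return f'{repo}:{",".join(deps)}'
        else:
            return repo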

    def _mark_used_installs(
            self,
            manifests: dict[tuple[str, str], dict[str, dict[str, Any]]],
            unused_repos: set[tuple[str, str]],
            repo: dict[str, Any],
    ) -> None:
        if repo['repo'] == clientlib.META:
            return
        elif repo['repo'] == clientlib.LOCAL:
            for hook in repo['hooks']:
                deps = hook.get('additional_dependencies')
                unused_repos.discard((
                    self.db_repo_name(repo['repo'], deps),
                    C.LOCAL_REPO_VERSION,
                ))
        else:
            key = (repo['repo'], repo['rev'])
            # manifests are stored per (repo, rev) -- the values are already
            # keyed by hook id (see _new_repo)
            by_id = manifests.get(key)
            # can't inspect the manifest of a repo which was never cloned
            if by_id is None:
                return
            unused_repos.discard(key)

            for hook in repo['hooks']:
                if hook['id'] not in by_id:
                    continue

                deps = hook.get(
                    'additional_dependencies',
                    by_id[hook['id']]['additional_dependencies'],
                )
                unused_repos.discard((
                    self.db_repo_name(repo['repo'], deps), repo['rev'],
                ))

    def gc(self) -> int:
        with self.exclusive_lock(), self.connect() as db:
            all_installs = {
                (
                    repo, rev, language, language_version,
                    tuple(json.loads(deps)),
                ): path
                for repo, rev, language, language_version, deps, path
                in db.execute(
                    'SELECT repo, rev, language, language_version, '
                    'additional_dependencies, path FROM installs',
                ).fetchall()
            }

            manifests = {
                (repo, rev): json.loads(manifest)
                for repo, rev, manifest in db.execute(
                    'SELECT repo, rev, manifest FROM manifests',
                ).fetchall()
            }

            # everything starts out unused until a live config references it
            unused_repos = {
                (repo, rev) for repo, rev, _, _, _ in all_installs
            } | set(manifests)

            configs_rows = db.execute('SELECT path FROM configs').fetchall()
            configs = [path for path, in configs_rows]

            dead_configs = []
            for config_path in configs:
                try:
                    config = clientlib.load_config(config_path)
                except clientlib.InvalidConfigError:
                    dead_configs.append(config_path)
                    continue
                else:
                    for repo in config['repos']:
                        self._mark_used_installs(manifests, unused_repos, repo)

            paths = [(path,) for path in dead_configs]
            db.executemany('DELETE FROM configs WHERE path = ?', paths)

            repo_revs = sorted(unused_repos)
            db.executemany(
                'DELETE FROM manifests WHERE repo = ? AND rev = ?', repo_revs,
            )
            db.executemany(
                'DELETE FROM installs WHERE repo = ? AND rev = ?', repo_revs,
            )
            for key, install_path in all_installs.items():
                if (key[0], key[1]) in unused_repos:
                    rmtree(install_path)

            # clones are only needed while building installs -- remove them
            res = db.execute('SELECT path FROM clones').fetchall()
            clones = [path for path, in res]
            db.execute('DELETE FROM clones')
            for path in clones:
                rmtree(path)

            return len(unused_repos)