pre-commit/pre_commit/store.py
anthony sottile 6ee2b2dfb0 wip
2026-03-19 13:45:12 -04:00

249 lines
8.9 KiB
Python

from __future__ import annotations
import contextlib
import json
import logging
import os.path
import sqlite3
import tempfile
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
import pre_commit.constants as C
from pre_commit import clientlib
from pre_commit import file_lock
from pre_commit import git
from pre_commit.util import CalledProcessError
from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b
from pre_commit.util import resource_text
logger = logging.getLogger('pre_commit')
def _get_default_directory() -> str:
"""Returns the default directory for the Store. This is intentionally
underscored to indicate that `Store.get_default_directory` is the intended
way to get this information. This is also done so
`Store.get_default_directory` can be mocked in tests and
`_get_default_directory` can be tested.
"""
ret = os.environ.get('PRE_COMMIT_HOME') or os.path.join(
os.environ.get('XDG_CACHE_HOME') or os.path.expanduser('~/.cache'),
'pre-commit',
)
return os.path.realpath(ret)
_LOCAL_RESOURCES = (
'Cargo.toml', 'main.go', 'go.mod', 'main.rs', '.npmignore',
'package.json', 'pre-commit-package-dev-1.rockspec',
'pre_commit_placeholder_package.gemspec', 'setup.py',
'environment.yml', 'Makefile.PL', 'pubspec.yaml',
'renv.lock', 'renv/activate.R', 'renv/LICENSE.renv',
)
def _make_local_repo(directory: str) -> None:
for resource in _LOCAL_RESOURCES:
resource_dirname, resource_basename = os.path.split(resource)
contents = resource_text(f'empty_template_{resource_basename}')
target_dir = os.path.join(directory, resource_dirname)
target_file = os.path.join(target_dir, resource_basename)
os.makedirs(target_dir, exist_ok=True)
with open(target_file, 'w') as f:
f.write(contents)
class Store:
get_default_directory = staticmethod(_get_default_directory)
def __init__(self, directory: str | None = None) -> None:
self.directory = directory or Store.get_default_directory()
self.db_path = os.path.join(self.directory, 'db5.db')
self.readonly = (
os.path.exists(self.directory) and
not os.access(self.directory, os.W_OK)
)
if not os.path.exists(self.directory):
os.makedirs(self.directory, exist_ok=True)
with open(os.path.join(self.directory, 'README'), 'w') as f:
f.write(
'This directory is maintained by the pre-commit project.\n'
'Learn more: https://github.com/pre-commit/pre-commit\n',
)
if os.path.exists(self.db_path):
return
with self.exclusive_lock():
# Another process may have already completed this work
if os.path.exists(self.db_path): # pragma: no cover (race)
return
# To avoid a race where someone ^Cs between db creation and
# execution of the CREATE TABLE statements
fd, tmpfile = tempfile.mkstemp(dir=self.directory)
# We'll be managing this file ourselves
os.close(fd)
with self.connect(db_path=tmpfile) as db:
db.executescript(
'CREATE TABLE configs ('
' path TEXT NOT NULL,'
' PRIMARY KEY (path)'
');',
)
db.executescript(
'CREATE TABLE manifests ('
' repo TEXT NOT NULL,'
' rev TEXT NOT NULL,'
' manifest TEXT NOT NULL,'
' PRIMARY KEY (repo, rev)'
');',
)
db.executescript(
'CREATE TABLE clones ('
' repo TEXT NOT NULL,'
' rev TEXT NOT NULL,'
' path TEXT NOT NULL,'
' PRIMARY KEY (repo, rev)'
');',
)
db.executescript(
'CREATE TABLE installs ('
' key TEXT NOT NULL,'
' path TEXT NOT NULL,'
' PRIMARY KEY (key)'
');',
)
# Atomic file move
os.replace(tmpfile, self.db_path)
@contextlib.contextmanager
def exclusive_lock(self) -> Generator[None]:
def blocked_cb() -> None: # pragma: no cover (tests are in-process)
logger.info('Locking pre-commit directory')
with file_lock.lock(os.path.join(self.directory, '.lock'), blocked_cb):
yield
@contextlib.contextmanager
def connect(
self,
db_path: str | None = None,
) -> Generator[sqlite3.Connection]:
db_path = db_path or self.db_path
# sqlite doesn't close its fd with its contextmanager >.<
# contextlib.closing fixes this.
# See: https://stackoverflow.com/a/28032829/812183
with contextlib.closing(sqlite3.connect(db_path)) as db:
# this creates a transaction
with db:
yield db
def _complete_clone(self, rev: str, git_cmd: Callable[..., None]) -> None:
"""Perform a complete clone of a repository and its submodules """
git_cmd('fetch', 'origin', '--tags')
git_cmd('checkout', rev)
git_cmd('submodule', 'update', '--init', '--recursive')
def _shallow_clone(self, rev: str, git_cmd: Callable[..., None]) -> None:
"""Perform a shallow clone of a repository and its submodules """
v2 = ('-c', 'protocol.version=2')
git_cmd(*v2, 'fetch', 'origin', rev, '--depth=1')
git_cmd('checkout', 'FETCH_HEAD')
git_cmd(
*v2, 'submodule', 'update', '--init', '--recursive',
'--depth=1',
)
def clone(self, repo: str, rev: str) -> str:
def _get_result() -> str | None:
with self.connect() as db:
result = db.execute(
'SELECT path FROM clones WHERE repo = ? AND rev = ?',
(repo, rev),
).fetchone()
return result[0] if result else None
result = _get_result()
if result:
return result
with self.exclusive_lock():
# Another process may have already completed this work
result = _get_result()
if result: # pragma: no cover (race)
return result
logger.info(f'Cloning {repo}...')
directory = tempfile.mkdtemp(prefix='clone', dir=self.directory)
with clean_path_on_failure(directory):
git.init_repo(directory, repo)
env = git.no_git_env()
def _git_cmd(*args: str) -> None:
cmd_output_b('git', *args, cwd=directory, env=env)
try:
self._shallow_clone(rev, _git_cmd)
except CalledProcessError:
self._complete_clone(rev, _git_cmd)
manifest = clientlib.load_raw(
os.path.join(directory, C.MANIFEST_FILE),
display_filename=f'({repo})/{C.MANIFEST_FILE}',
exc_tp=clientlib.InvalidManifestError,
)
# Update our db with the created repo
with self.connect() as db:
db.execute(
'INSERT INTO clones VALUES (?, ?, ?)',
(repo, rev, directory),
)
db.execute(
'INSERT INTO manifests VALUES (?, ?, ?)',
(repo, rev, json.dumps(manifest)),
)
clientlib.warn_for_stages_on_repo_init(repo, directory)
return directory
def maybe_manifest(
self,
repo: str,
rev: str,
) -> dict[str, dict[str, Any]] | None:
with self.connect() as db:
result = db.execute(
'SELECT manifest FROM manifests WHERE repo = ? AND rev = ?',
(repo, rev),
).fetchone()
if result is not None:
return clientlib.load_manifest_contents(repo, result[0])
else:
return None
def manifest(self, repo: str, rev: str) -> dict[str, dict[str, Any]]:
ret = self.maybe_manifest(repo, rev)
if ret is None:
self.clone(repo, rev)
ret = self.maybe_manifest(repo, rev)
assert ret is not None
return ret
def mark_config_used(self, path: str) -> None:
if self.readonly: # pragma: win32 no cover
return
path = os.path.realpath(path)
# don't insert config files that do not exist
if not os.path.exists(path):
return
with self.connect() as db:
db.execute('INSERT OR IGNORE INTO configs VALUES (?)', (path,))