Use temp env file to install add_deps for conda

This commit is contained in:
Simon Boehm 2020-12-21 12:08:00 +01:00
parent 0ed7930976
commit 6ee227326c
No known key found for this signature in database
GPG key ID: 59DF9146ABE67092

View file

@ -1,5 +1,6 @@
import contextlib import contextlib
import os import os
from tempfile import NamedTemporaryFile
from typing import Generator from typing import Generator
from typing import Sequence from typing import Sequence
from typing import Tuple from typing import Tuple
@ -14,6 +15,8 @@ from pre_commit.languages import helpers
from pre_commit.prefix import Prefix from pre_commit.prefix import Prefix
from pre_commit.util import clean_path_on_failure from pre_commit.util import clean_path_on_failure
from pre_commit.util import cmd_output_b from pre_commit.util import cmd_output_b
from pre_commit.util import yaml_dump
from pre_commit.util import yaml_load
ENVIRONMENT_DIR = 'conda' ENVIRONMENT_DIR = 'conda'
get_default_version = helpers.basic_get_default_version get_default_version = helpers.basic_get_default_version
@ -59,16 +62,25 @@ def install_environment(
directory = helpers.environment_dir(ENVIRONMENT_DIR, version) directory = helpers.environment_dir(ENVIRONMENT_DIR, version)
env_dir = prefix.path(directory) env_dir = prefix.path(directory)
env_yaml_path = prefix.path('environment.yml')
with clean_path_on_failure(env_dir): with clean_path_on_failure(env_dir):
with open(env_yaml_path) as env_file:
env_yaml = yaml_load(env_file)
env_yaml['dependencies'] += additional_dependencies
try:
with NamedTemporaryFile(
suffix='.yml',
mode='w',
delete=False,
) as tmp_env_file:
yaml_dump(env_yaml, stream=tmp_env_file)
cmd_output_b( cmd_output_b(
'conda', 'env', 'create', '-p', env_dir, '--file', 'conda', 'env', 'create', '-p', env_dir, '--file',
'environment.yml', cwd=prefix.prefix_dir, tmp_env_file.name, cwd=prefix.prefix_dir,
)
if additional_dependencies:
cmd_output_b(
'conda', 'install', '-p', env_dir, *additional_dependencies,
cwd=prefix.prefix_dir,
) )
finally:
os.remove(tmp_env_file.name)
def run_hook( def run_hook(