diff --git a/pre_commit/languages/golang.py b/pre_commit/languages/golang.py index 60908796..678c04b1 100644 --- a/pre_commit/languages/golang.py +++ b/pre_commit/languages/golang.py @@ -75,6 +75,7 @@ def get_env_patch(venv: str, version: str) -> PatchesT: return ( ('GOROOT', os.path.join(venv, '.go')), + ('GOTOOLCHAIN', 'local'), ( 'PATH', ( os.path.join(venv, 'bin'), os.pathsep, @@ -145,6 +146,7 @@ def install_environment( env = no_git_env(dict(os.environ, GOPATH=gopath)) env.pop('GOBIN', None) if version != 'system': + env['GOTOOLCHAIN'] = 'local' env['GOROOT'] = os.path.join(env_dir, '.go') env['PATH'] = os.pathsep.join(( os.path.join(env_dir, '.go', 'bin'), os.environ['PATH'], diff --git a/tests/languages/golang_test.py b/tests/languages/golang_test.py index 02e35d71..7fb6ab18 100644 --- a/tests/languages/golang_test.py +++ b/tests/languages/golang_test.py @@ -11,11 +11,13 @@ from pre_commit.commands.install_uninstall import install from pre_commit.envcontext import envcontext from pre_commit.languages import golang from pre_commit.store import _make_local_repo +from pre_commit.util import CalledProcessError from pre_commit.util import cmd_output from testing.fixtures import add_config_to_repo from testing.fixtures import make_config_from_repo from testing.language_helpers import run_language from testing.util import cmd_output_mocked_pre_commit_home +from testing.util import cwd from testing.util import git_commit @@ -165,3 +167,70 @@ def test_during_commit_all(tmp_path, tempdir_factory, store, in_git_dir): fn=cmd_output_mocked_pre_commit_home, tempdir_factory=tempdir_factory, ) + + +def test_automatic_toolchain_switching(tmp_path): + go_mod = '''\ +module toolchain-version-test + +go 1.23.1 +''' + main_go = '''\ +package main + +func main() {} +''' + tmp_path.joinpath('go.mod').write_text(go_mod) + mod_dir = tmp_path.joinpath('toolchain-version-test') + mod_dir.mkdir() + main_file = mod_dir.joinpath('main.go') + main_file.write_text(main_go) + + with pytest.raises(CalledProcessError) as excinfo: + run_language( + path=tmp_path, + language=golang, + version='1.22.0', + exe='golang-version-test', + ) + + assert 'go.mod requires go >= 1.23.1' in excinfo.value.stderr.decode() + + +def test_automatic_toolchain_switching_go_fmt(tmp_path, monkeypatch): + go_mod_hook = '''\ +module toolchain-version-test + +go 1.22.0 +''' + go_mod = '''\ +module toolchain-version-test + +go 1.23.1 +''' + main_go = '''\ +package main + +func main() {} +''' + hook_dir = tmp_path.joinpath('hook') + hook_dir.mkdir() + hook_dir.joinpath('go.mod').write_text(go_mod_hook) + + test_dir = tmp_path.joinpath('test') + test_dir.mkdir() + test_dir.joinpath('go.mod').write_text(go_mod) + main_file = test_dir.joinpath('main.go') + main_file.write_text(main_go) + + with cwd(test_dir): + ret, out = run_language( + path=hook_dir, + language=golang, + version='1.22.0', + exe='go fmt', + file_args=(str(main_file),), + ) + + assert ret == 1 + assert 'go.mod requires go >= 1.23.1' in out.decode()