From 537effccc032e600fdb4657998c774c1ef108b96 Mon Sep 17 00:00:00 2001
From: PGijsbers
Date: Sun, 17 Nov 2024 15:42:39 +0200
Subject: [PATCH] Add hack that allows specifying the user directory as git
---
runbenchmark.py | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/runbenchmark.py b/runbenchmark.py
index ad5eac7e3..6662d2f4c 100644
--- a/runbenchmark.py
+++ b/runbenchmark.py
@@ -4,6 +4,7 @@
import re
import shutil
import sys
+from pathlib import Path
# prevent asap other modules from defining the root logger using basicConfig
import amlb.logger
@@ -11,7 +12,7 @@
import openml
import amlb
-from amlb.utils import Namespace as ns, config_load, datetime_iso, str2bool, str_sanitize, zip_path
+from amlb.utils import Namespace as ns, config_load, datetime_iso, str2bool, str_sanitize, zip_path, run_cmd
from amlb import log, AutoMLError
from amlb.defaults import default_dirs
@@ -99,6 +100,18 @@
# help="The region on which to run the benchmark when using AWS.")
args = parser.parse_args()
+
+GIT_PATTERN = re.compile(r"https://(?:www.)?\w+.\w+/([a-zA-Z0-9_\-\.]+)/([a-zA-Z0-9_\-\.]+).git")
+if args.userdir and (match := GIT_PATTERN.match(args.userdir)):
+ user, repo = match.groups()
+ DOWNLOAD_DIRECTORY = Path(__file__).parent / "downloads"
+ download_path = DOWNLOAD_DIRECTORY / user / repo
+ if not download_path.exists():
+ download_path.mkdir(parents=True)
+ cmd = f"clone {args.userdir}"
+ run_cmd(f"git {cmd} {download_path}", _log_level_=logging.DEBUG)[0].strip()
+ args.userdir = str(download_path)
+
script_name = os.path.splitext(os.path.basename(__file__))[0]
extras = {t[0]: t[1] if len(t) > 1 else True for t in [x.split('=', 1) for x in args.extra]}