diff --git a/frameworks/AutoGluon/__init__.py b/frameworks/AutoGluon/__init__.py
index 2d5734e33..bb2c0f615 100644
--- a/frameworks/AutoGluon/__init__.py
+++ b/frameworks/AutoGluon/__init__.py
@@ -1,4 +1,3 @@
-
 from amlb.utils import call_script_in_same_dir
 from amlb.benchmark import TaskConfig
 from amlb.data import Dataset, DatasetType
@@ -8,8 +7,8 @@
 def setup(*args, **kwargs):
     call_script_in_same_dir(__file__, "setup.sh", *args, **kwargs)
 
 
-def run(dataset: Dataset, config: TaskConfig):
+def run(dataset: Dataset, config: TaskConfig):
     if dataset.type == DatasetType.timeseries:
         return run_autogluon_timeseries(dataset, config)
     else:
@@ -18,23 +17,33 @@ def run(dataset: Dataset, config: TaskConfig):
 
 def run_autogluon_tabular(dataset: Dataset, config: TaskConfig):
     from frameworks.shared.caller import run_in_venv
+
     data = dict(
-        train=dict(path=dataset.train.data_path('parquet')),
-        test=dict(path=dataset.test.data_path('parquet')),
-        target=dict(
-            name=dataset.target.name,
-            classes=dataset.target.values
-        ),
+        train=dict(path=dataset.train.data_path("parquet")),
+        test=dict(path=dataset.test.data_path("parquet")),
+        target=dict(name=dataset.target.name, classes=dataset.target.values),
         problem_type=dataset.type.name,  # AutoGluon problem_type is using same names as amlb.data.DatasetType
     )
     if config.measure_inference_time:
-        data["inference_subsample_files"] = dataset.inference_subsample_files(fmt="parquet")
+        data["inference_subsample_files"] = dataset.inference_subsample_files(
+            fmt="parquet"
+        )
+
+    options = {"serialization": {"numpy_allow_pickle": False}}
+
+    return run_in_venv(
+        __file__,
+        "exec.py",
+        input_data=data,
+        dataset=dataset,
+        config=config,
+        options=options,
+    )
 
-    return run_in_venv(__file__, "exec.py",
-                       input_data=data, dataset=dataset, config=config)
 
 
 def run_autogluon_timeseries(dataset: Dataset, config: TaskConfig):
     from frameworks.shared.caller import run_in_venv
+
     dataset = deepcopy(dataset)
     data = dict(
@@ -50,5 +59,6 @@ def run_autogluon_timeseries(dataset: Dataset, config: TaskConfig):
         repeated_item_id=dataset.repeated_item_id,
     )
 
-    return run_in_venv(__file__, "exec_ts.py",
-                       input_data=data, dataset=dataset, config=config)
+    return run_in_venv(
+        __file__, "exec_ts.py", input_data=data, dataset=dataset, config=config
+    )