Feat/progress #1335

Merged
merged 11 commits on Sep 16, 2024
3 changes: 3 additions & 0 deletions doc/progress.rst
@@ -9,6 +9,9 @@ Changelog
next
~~~~~~

* ADD #1335: Improve MinIO support.
* Add a progress bar for downloading MinIO files. Enable it by setting `show_progress` to true, either on `openml.config` or in the configuration file (see the sketch after this hunk).
* When using `download_all_files`, files are only downloaded if they do not yet exist in the cache.
* MAINT #1340: Add Numpy 2.0 support. Update tests to work with scikit-learn <= 1.5.
* ADD #1342: Add HTTP header to requests to indicate they are from openml-python.
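
Not part of the diff, but as a quick illustration of the `show_progress` option described above: it can be toggled in code or persisted in the configuration file. A minimal sketch, assuming the option is available as added in `openml/config.py` below; dataset ID 61 is just an arbitrary example.

    import openml

    # Enable progress bars for MinIO downloads for this session ...
    openml.config.show_progress = True

    # ... or persist it by adding the line `show_progress=true` to the configuration
    # file (its path is returned by `openml.config.determine_config_file_path()`).

    # Subsequent dataset file downloads then display a tqdm progress bar.
    dataset = openml.datasets.get_dataset(61)
    X, y, _, _ = dataset.get_data()  # the parquet download now shows a progress bar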

9 changes: 9 additions & 0 deletions examples/20_basic/simple_datasets_tutorial.py
@@ -50,6 +50,15 @@
X, y, categorical_indicator, attribute_names = dataset.get_data(
dataset_format="dataframe", target=dataset.default_target_attribute
)

############################################################################
# Tip: you can get a progress bar for dataset downloads by enabling it in
# the configuration, either in code or in the configuration file
# (see also the introduction tutorial).

openml.config.show_progress = True


############################################################################
# Visualize the dataset
# =====================
15 changes: 9 additions & 6 deletions openml/_api_calls.py
@@ -1,6 +1,7 @@
# License: BSD 3-Clause
from __future__ import annotations

import contextlib
import hashlib
import logging
import math
@@ -26,6 +27,7 @@
OpenMLServerException,
OpenMLServerNoResult,
)
from .utils import ProgressBar

_HEADERS = {"user-agent": f"openml-python/{__version__}"}

@@ -161,12 +163,12 @@ def _download_minio_file(
proxy_client = ProxyManager(proxy) if proxy else None

client = minio.Minio(endpoint=parsed_url.netloc, secure=False, http_client=proxy_client)

try:
client.fget_object(
bucket_name=bucket,
object_name=object_name,
file_path=str(destination),
progress=ProgressBar() if config.show_progress else None,
request_headers=_HEADERS,
)
if destination.is_file() and destination.suffix == ".zip":
@@ -206,11 +208,12 @@ def _download_minio_bucket(source: str, destination: str | Path) -> None:
if file_object.object_name is None:
raise ValueError("Object name is None.")

_download_minio_file(
source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1],
destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]),
exists_ok=True,
)
with contextlib.suppress(FileExistsError): # Simply use cached version instead
_download_minio_file(
source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1],
destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]),
exists_ok=False,
)


def _download_text_file(
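
The `_download_minio_bucket` change above treats a file that already exists as a cache hit: each file is requested with `exists_ok=False` and the resulting `FileExistsError` is suppressed. A minimal standalone sketch of that pattern, with a hypothetical `fetch` helper standing in for `_download_minio_file`:

    import contextlib
    from pathlib import Path

    def fetch(url: str, destination: Path) -> None:
        """Pretend download that refuses to overwrite an existing file."""
        if destination.exists():
            raise FileExistsError(destination)
        destination.write_text(f"contents of {url}")  # placeholder for the real transfer

    target = Path("dataset.pq")
    for _ in range(2):  # the second pass finds the file and is silently skipped
        with contextlib.suppress(FileExistsError):
            fetch("https://example.org/bucket/dataset.pq", target)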
16 changes: 11 additions & 5 deletions openml/config.py
@@ -28,6 +28,7 @@ class _Config(TypedDict):
avoid_duplicate_runs: bool
retry_policy: Literal["human", "robot"]
connection_n_retries: int
show_progress: bool


def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT001, FBT002
@@ -111,6 +112,7 @@ def set_file_log_level(file_output_level: int) -> None:
"avoid_duplicate_runs": True,
"retry_policy": "human",
"connection_n_retries": 5,
"show_progress": False,
}

# Default values are actually added here in the _setup() function which is
@@ -131,6 +133,7 @@ def get_server_base_url() -> str:


apikey: str = _defaults["apikey"]
show_progress: bool = _defaults["show_progress"]
# The current cache directory (without the server name)
_root_cache_directory = Path(_defaults["cachedir"])
avoid_duplicate_runs = _defaults["avoid_duplicate_runs"]
@@ -238,6 +241,7 @@ def _setup(config: _Config | None = None) -> None:
global server # noqa: PLW0603
global _root_cache_directory # noqa: PLW0603
global avoid_duplicate_runs # noqa: PLW0603
global show_progress # noqa: PLW0603

config_file = determine_config_file_path()
config_dir = config_file.parent
@@ -255,6 +259,7 @@
avoid_duplicate_runs = config["avoid_duplicate_runs"]
apikey = config["apikey"]
server = config["server"]
show_progress = config["show_progress"]
short_cache_dir = Path(config["cachedir"])
n_retries = int(config["connection_n_retries"])

@@ -328,11 +333,11 @@ def _parse_config(config_file: str | Path) -> _Config:
logger.info("Error opening file %s: %s", config_file, e.args[0])
config_file_.seek(0)
config.read_file(config_file_)
if isinstance(config["FAKE_SECTION"]["avoid_duplicate_runs"], str):
config["FAKE_SECTION"]["avoid_duplicate_runs"] = config["FAKE_SECTION"].getboolean(
"avoid_duplicate_runs"
) # type: ignore
return dict(config.items("FAKE_SECTION")) # type: ignore
configuration = dict(config.items("FAKE_SECTION"))
for boolean_field in ["avoid_duplicate_runs", "show_progress"]:
if isinstance(config["FAKE_SECTION"][boolean_field], str):
configuration[boolean_field] = config["FAKE_SECTION"].getboolean(boolean_field) # type: ignore
return configuration # type: ignore


def get_config_as_dict() -> _Config:
@@ -343,6 +348,7 @@ def get_config_as_dict() -> _Config:
"avoid_duplicate_runs": avoid_duplicate_runs,
"connection_n_retries": connection_n_retries,
"retry_policy": retry_policy,
"show_progress": show_progress,
}


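
For context on the `_parse_config` change: `configparser` requires a section header, so the config file is wrapped in a fake one, all values are read back as strings, and the boolean fields are then coerced with `getboolean` (which accepts `true`/`false`, `yes`/`no`, `on`/`off`, `1`/`0`). A self-contained sketch of that parsing approach, not the library's exact code:

    import configparser
    from io import StringIO

    raw = "apikey=abcd\nshow_progress=true\navoid_duplicate_runs=false"

    parser = configparser.RawConfigParser()
    parser.read_file(StringIO(f"[FAKE_SECTION]\n{raw}"))

    settings = dict(parser.items("FAKE_SECTION"))
    for boolean_field in ("show_progress", "avoid_duplicate_runs"):
        settings[boolean_field] = parser["FAKE_SECTION"].getboolean(boolean_field)

    print(settings)  # {'apikey': 'abcd', 'show_progress': True, 'avoid_duplicate_runs': False}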
7 changes: 3 additions & 4 deletions openml/datasets/functions.py
@@ -1262,10 +1262,9 @@ def _get_dataset_parquet(
if old_file_path.is_file():
old_file_path.rename(output_file_path)

# For this release, we want to be able to force a new download even if the
# parquet file is already present when ``download_all_files`` is set.
# For now, it would be the only way for the user to fetch the additional
# files in the bucket (no function exists on an OpenMLDataset to do this).
# The call below skips files that are already on disk, so it avoids downloading the
# parquet file twice. To force the old behavior of always downloading everything,
# use the `force_refresh_cache` parameter of `get_dataset`.
if download_all_files:
openml._api_calls._download_minio_bucket(source=url, destination=cache_directory)

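
For context on the updated comment in `_get_dataset_parquet`: with this change, `download_all_files` only fetches bucket files that are missing from the cache, and a full re-download has to be requested explicitly. A hedged usage sketch, with both parameter names taken from the comment above and dataset ID 61 as an arbitrary example:

    import openml

    # Fetch every file in the dataset's MinIO bucket, skipping anything already cached.
    dataset = openml.datasets.get_dataset(61, download_all_files=True)

    # Drop the cached copy first, then download everything again.
    dataset = openml.datasets.get_dataset(61, force_refresh_cache=True, download_all_files=True)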
38 changes: 38 additions & 0 deletions openml/utils.py
@@ -12,6 +12,8 @@
import numpy as np
import pandas as pd
import xmltodict
from minio.helpers import ProgressType
from tqdm import tqdm

import openml
import openml._api_calls
@@ -471,3 +473,39 @@ def _create_lockfiles_dir() -> Path:
with contextlib.suppress(OSError):
path.mkdir(exist_ok=True, parents=True)
return path


class ProgressBar(ProgressType):
"""Progressbar for MinIO function's `progress` parameter."""

def __init__(self) -> None:
self._object_name = ""
self._progress_bar: tqdm | None = None

def set_meta(self, object_name: str, total_length: int) -> None:
"""Initializes the progress bar.

Parameters
----------
object_name: str
Name of the object being downloaded; stored but not otherwise used.

total_length: int
File size of the object in bytes.
"""
self._object_name = object_name
self._progress_bar = tqdm(total=total_length, unit_scale=True, unit="B")

def update(self, length: int) -> None:
"""Updates the progress bar.

Parameters
----------
length: int
Number of bytes downloaded since last `update` call.
"""
if not self._progress_bar:
raise RuntimeError("Call `set_meta` before calling `update`.")
self._progress_bar.update(length)
if self._progress_bar.total <= self._progress_bar.n:
self._progress_bar.close()
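
The class above only needs to satisfy the `ProgressType` interface that MinIO drives during a download: as the docstrings suggest, `set_meta` is called once with the object's size, then `update` is called with each chunk of bytes written. A standalone sketch that exercises it by hand, with no network involved (the numbers are arbitrary):

    from openml.utils import ProgressBar

    bar = ProgressBar()
    bar.set_meta(object_name="dataset.pq", total_length=1_000_000)
    for _ in range(10):
        bar.update(100_000)  # the bar closes itself once the total is reached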
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
"numpy>=1.6.2",
"minio",
"pyarrow",
"tqdm", # For MinIO download progress bars
"packaging",
]
requires-python = ">=3.8"
41 changes: 41 additions & 0 deletions tests/test_openml/test_api_calls.py
@@ -1,11 +1,16 @@
from __future__ import annotations

import unittest.mock
from pathlib import Path
from typing import NamedTuple, Iterable, Iterator
from unittest import mock

import minio
import pytest

import openml
import openml.testing
from openml._api_calls import _download_minio_bucket


class TestConfig(openml.testing.TestBase):
@@ -30,3 +35,39 @@ def test_retry_on_database_error(self, Session_class_mock, _):
openml._api_calls._send_request("get", "/abc", {})

assert Session_class_mock.return_value.__enter__.return_value.get.call_count == 20

class FakeObject(NamedTuple):
object_name: str

class FakeMinio:
def __init__(self, objects: Iterable[FakeObject] | None = None):
self._objects = objects or []

def list_objects(self, *args, **kwargs) -> Iterator[FakeObject]:
yield from self._objects

def fget_object(self, object_name: str, file_path: str, *args, **kwargs) -> None:
if object_name in [obj.object_name for obj in self._objects]:
Path(file_path).write_text("foo")
return
raise FileNotFoundError


@mock.patch.object(minio, "Minio")
def test_download_all_files_observes_cache(mock_minio, tmp_path: Path) -> None:
some_prefix, some_filename = "some/prefix", "dataset.arff"
some_object_path = f"{some_prefix}/{some_filename}"
some_url = f"https://not.real.com/bucket/{some_object_path}"
mock_minio.return_value = FakeMinio(
objects=[
FakeObject(some_object_path),
],
)

_download_minio_bucket(source=some_url, destination=tmp_path)
time_created = (tmp_path / some_filename).stat().st_mtime

_download_minio_bucket(source=some_url, destination=tmp_path)
time_modified = (tmp_path / some_filename).stat().st_mtime

assert time_created == time_modified
10 changes: 10 additions & 0 deletions tests/test_openml/test_config.py
@@ -133,3 +133,13 @@ def test_configuration_file_not_overwritten_on_load():

assert config_file_content == new_file_content
assert "abcd" == read_config["apikey"]

def test_configuration_loads_booleans(tmp_path):
    config_file_content = "avoid_duplicate_runs=true\nshow_progress=false"
    with (tmp_path / "config").open("w") as config_file:
        config_file.write(config_file_content)
    read_config = openml.config._parse_config(tmp_path / "config")

    # Explicit checks to avoid truthy/falsy values of other types
    assert read_config["avoid_duplicate_runs"] is True
    assert read_config["show_progress"] is False