Skip to content

Commit

Permalink
Update high-level SDK used for exporting project|task|job `data…
Browse files Browse the repository at this point in the history
…sets`|`backups` (#8255)

- Fixed exporting the same dataset or backup twice in a row using
high-level SDK (switched to new export API version) (related
#8256)
- Fixed exporting a dataset or backup using high-level SDK when the
default project or task location refers to cloud storage
- Added ability to explicitly specify location when exporting datasets
and backups using high-level SDK

## Summary by CodeRabbit

- **New Features**
- Introduced mixins for exporting datasets and downloading backups,
enhancing functionality across multiple classes.
- Added a new fixture for testing tasks with specified target storage,
improving test coverage.

- **Bug Fixes**
- Improved error handling in the file download process to ensure
validity before proceeding.

- **Refactor**
- Restructured the downloading mechanism for better modularity and
maintainability.
- Removed outdated methods in favor of mixin functionality, streamlining
class design.

- **Tests**
- Enhanced the test suite with additional scenarios and flexibility for
task management and dataset downloading.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Maxim Zhiltsov <zhiltsov.max35@gmail.com>
  • Loading branch information
Marishka17 and zhiltsov-max authored Aug 30, 2024
1 parent 2d6ac62 commit 878bb41
Show file tree
Hide file tree
Showing 18 changed files with 616 additions and 256 deletions.
12 changes: 12 additions & 0 deletions changelog.d/20240827_171721_maria_update_high_level_export_sdk.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
### Fixed

- An issue that occurred when exporting the same dataset or backup twice in a row using SDK
(<https://github.com/cvat-ai/cvat/pull/8255>)
- An issue that occurred when exporting a dataset or backup using SDK
when the default project or task location refers to cloud storage
(<https://github.com/cvat-ai/cvat/pull/8255>)

### Added

- Ability to specify location when exporting datasets and backups using SDK
(<https://github.com/cvat-ai/cvat/pull/8255>)
34 changes: 28 additions & 6 deletions cvat-sdk/cvat_sdk/core/downloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,12 @@ def download_file(
pbar.advance(len(chunk))
fd.write(chunk)

def prepare_and_download_file_from_endpoint(
def prepare_file(
self,
endpoint: Endpoint,
filename: Path,
*,
url_params: Optional[Dict[str, Any]] = None,
query_params: Optional[Dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
client = self._client
Expand All @@ -91,7 +89,7 @@ def prepare_and_download_file_from_endpoint(

# initialize background process
response = client.api_client.rest_client.request(
method="GET",
method=endpoint.settings["http_method"],
url=url,
headers=client.api_client.get_common_headers(),
)
Expand All @@ -106,5 +104,29 @@ def prepare_and_download_file_from_endpoint(
rq_id, status_check_period=status_check_period
)

downloader = Downloader(client)
downloader.download_file(request.result_url, output_path=filename, pbar=pbar)
return request

def prepare_and_download_file_from_endpoint(
self,
endpoint: Endpoint,
filename: Path,
*,
url_params: Optional[Dict[str, Any]] = None,
query_params: Optional[Dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
client = self._client

if status_check_period is None:
status_check_period = client.config.status_check_period

export_request = self.prepare_file(
endpoint,
url_params=url_params,
query_params=query_params,
status_check_period=status_check_period,
)

assert export_request.result_url, "Result url was not found in server response"
self.download_file(export_request.result_url, output_path=filename, pbar=pbar)
34 changes: 2 additions & 32 deletions cvat-sdk/cvat_sdk/core/proxies/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from PIL import Image

from cvat_sdk.api_client import apis, models
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.issues import Issue
from cvat_sdk.core.proxies.model_proxy import (
ExportDatasetMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
Expand All @@ -38,6 +38,7 @@ class Job(
_JobEntityBase,
ModelUpdateMixin[models.IPatchedJobWriteRequest],
AnnotationCrudMixin,
ExportDatasetMixin,
):
_model_partial_update_arg = "patched_job_write_request"
_put_annotations_data_param = "job_annotations_update_request"
Expand Down Expand Up @@ -67,37 +68,6 @@ def import_annotations(

self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded")

def export_dataset(
self,
format_name: str,
filename: StrPath,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""

filename = Path(filename)

if include_images:
endpoint = self.api.retrieve_dataset_endpoint
else:
endpoint = self.api.retrieve_annotations_endpoint

Downloader(self._client).prepare_and_download_file_from_endpoint(
endpoint=endpoint,
filename=filename,
url_params={"id": self.id},
query_params={"format": format_name},
pbar=pbar,
status_check_period=status_check_period,
)

self._client.logger.info(f"Dataset for job {self.id} has been downloaded to {filename}")

def get_frame(
self,
frame_id: int,
Expand Down
171 changes: 171 additions & 0 deletions cvat-sdk/cvat_sdk/core/proxies/model_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import json
from abc import ABC
from copy import deepcopy
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Expand All @@ -24,10 +26,16 @@

from typing_extensions import Self

from cvat_sdk.api_client import exceptions
from cvat_sdk.api_client.model_utils import IModelData, ModelNormal, to_json
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.types import Location

if TYPE_CHECKING:
from _typeshed import StrPath

from cvat_sdk.core.client import Client

IModel = TypeVar("IModel", bound=IModelData)
Expand Down Expand Up @@ -210,3 +218,166 @@ def remove(self: Entity) -> None:
"""

self.api.destroy(id=getattr(self, self._model_id_field))


class _ExportMixin(Generic[_EntityT]):
def export(
self,
endpoint: Callable,
filename: StrPath,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
location: Optional[Location] = None,
cloud_storage_id: Optional[int] = None,
**query_params,
) -> None:
query_params = {
**query_params,
**({"location": location} if location else {}),
}

if location == Location.CLOUD_STORAGE:
if not cloud_storage_id:
raise ValueError(
f"Cloud storage ID must be specified when {location!r} location is used"
)

query_params["cloud_storage_id"] = cloud_storage_id

local_downloading = (
location == Location.LOCAL
or not location
and (not self.target_storage or self.target_storage.location.value == Location.LOCAL)
)

if not local_downloading:
query_params["filename"] = str(filename)

downloader = Downloader(self._client)
export_request = downloader.prepare_file(
endpoint,
url_params={"id": self.id},
query_params=query_params,
status_check_period=status_check_period,
)

result_url = export_request.result_url

if (
location == Location.LOCAL
and not result_url
or location == Location.CLOUD_STORAGE
and result_url
):
raise exceptions.ServiceException(500, "Server handled export parameters incorrectly")
elif not location and (
(not self.target_storage or self.target_storage.location.value == Location.LOCAL)
and not result_url
or (
self.target_storage
and self.target_storage.location.value == Location.CLOUD_STORAGE
and result_url
)
):
# SDK should not raise an exception here, because most likely
# a SDK model was outdated while export finished successfully
self._client.logger.warn(
f"{self.__class__.__name__.title()} was outdated. "
f"Use .fetch() method to obtain {self.__class__.__name__.lower()!r} actual version"
)

if result_url:
downloader.download_file(result_url, output_path=Path(filename), pbar=pbar)


class ExportDatasetMixin(_ExportMixin):
def export_dataset(
self,
format_name: str,
filename: StrPath,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
location: Optional[Location] = None,
cloud_storage_id: Optional[int] = None,
) -> None:
"""
Export a dataset in the specified format (e.g. 'YOLO ZIP 1.0').
By default, a result file will be downloaded based on the default configuration.
To force file downloading, pass `location=Location.LOCAL`.
To save a file to a specific cloud storage, use the `location` and `cloud_storage_id` arguments.
Args:
filename (StrPath): A path to which a file will be downloaded
status_check_period (int, optional): Sleep interval in seconds between status checks.
Defaults to None, which means the `Config.status_check_period` is used.
pbar (Optional[ProgressReporter], optional): Can be used to show a progress when downloading file locally.
Defaults to None.
location (Optional[Location], optional): Location to which a file will be uploaded.
Can be Location.LOCAL or Location.CLOUD_STORAGE. Defaults to None.
cloud_storage_id (Optional[int], optional): ID of cloud storage to which a file should be uploaded. Defaults to None.
Raises:
ValueError: When location is Location.CLOUD_STORAGE but no cloud_storage_id is passed
"""

self.export(
self.api.create_dataset_export_endpoint,
filename,
pbar=pbar,
status_check_period=status_check_period,
location=location,
cloud_storage_id=cloud_storage_id,
format=format_name,
save_images=include_images,
)

self._client.logger.info(
f"Dataset for {self.__class__.__name__.lower()} {self.id} has been downloaded to {filename}"
)


class DownloadBackupMixin(_ExportMixin):
def download_backup(
self,
filename: StrPath,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
location: Optional[str] = None,
cloud_storage_id: Optional[int] = None,
) -> None:
"""
Create a resource backup and download it locally or upload to a cloud storage.
By default, a result file will be downloaded based on the default configuration.
To force file downloading, pass `location=Location.LOCAL`.
To save a file to a specific cloud storage, use the `location` and `cloud_storage_id` arguments.
Args:
filename (StrPath): A path to which a file will be downloaded
status_check_period (int, optional): Sleep interval in seconds between status checks.
Defaults to None, which means the `Config.status_check_period` is used.
pbar (Optional[ProgressReporter], optional): Can be used to show a progress when downloading file locally.
Defaults to None.
location (Optional[Location], optional): Location to which a file will be uploaded.
Can be Location.LOCAL or Location.CLOUD_STORAGE. Defaults to None.
cloud_storage_id (Optional[int], optional): ID of cloud storage to which a file should be uploaded. Defaults to None.
Raises:
ValueError: When location is Location.CLOUD_STORAGE but no cloud_storage_id is passed
"""

self.export(
self.api.create_backup_export_endpoint,
filename,
pbar=pbar,
status_check_period=status_check_period,
location=location,
cloud_storage_id=cloud_storage_id,
)

self._client.logger.info(
f"Backup for {self.__class__.__name__.lower()} {self.id} has been downloaded to {filename}"
)
Loading

0 comments on commit 878bb41

Please sign in to comment.