Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ODSC-46634/utilize oci UploadManager to upload model artifacts #304

Merged
merged 8 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions ads/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@

# declare custom exception class

# OCI path schema
OCI_SCHEME = "oci"
OCI_PREFIX = f"{OCI_SCHEME}://"


class FileOverwriteError(Exception): # pragma: no cover
pass
Expand Down Expand Up @@ -1599,3 +1603,35 @@ def is_path_exists(uri: str, auth: Optional[Dict] = None) -> bool:
if fsspec.filesystem(path_scheme, **storage_options).exists(uri):
return True
return False


def parse_os_uri(uri: str):
mingkang111 marked this conversation as resolved.
Show resolved Hide resolved
"""
Parse an OCI object storage URI, returning tuple (bucket, namespace, path).

Parameters
----------
uri: str
The OCI Object Storage URI.

Returns
-------
Tuple
The (bucket, ns, type)

Raise
-----
ValueError
If provided URI is not an OCI OS bucket URI.
"""
parsed = urlparse(uri)
if parsed.scheme.lower() != OCI_SCHEME:
raise ValueError("Not an OCI object storage URI: %s" % uri)
path = parsed.path

if path.startswith("/"):
path = path[1:]

bucket, ns = parsed.netloc.split("@")

return bucket, ns, path
94 changes: 73 additions & 21 deletions ads/model/artifact_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
from abc import ABC, abstractmethod
from typing import Dict, Optional

from ads import logger
from ads.common import utils
from ads.common.oci_client import OCIClientFactory
from ads.model.common import utils as model_utils
from ads.model.service.oci_datascience_model import OCIDataScienceModel

from oci import object_storage


class ArtifactUploader(ABC):
"""The abstract class to upload model artifacts."""
Expand Down Expand Up @@ -94,6 +98,8 @@ def _upload(self):


class SmallArtifactUploader(ArtifactUploader):
"""The class helper to upload small model artifacts."""

PROGRESS_STEPS_COUNT = 1

def _upload(self):
Expand All @@ -104,6 +110,39 @@ def _upload(self):


class LargeArtifactUploader(ArtifactUploader):
"""
The class helper to upload large model artifacts.

Attributes
mingkang111 marked this conversation as resolved.
Show resolved Hide resolved
----------
artifact_path: str
The model artifact location.
artifact_zip_path: str
The uri of the zip of model artifact.
auth: dict
The default authetication is set using `ads.set_auth` API.
If you need to override the default, use the `ads.common.auth.api_keys` or
`ads.common.auth.resource_principal` to create appropriate authentication signer
and kwargs required to instantiate IdentityClient object.
bucket_uri: str
The OCI Object Storage URI where model artifacts will be copied to.
The `bucket_uri` is only necessary for uploading large artifacts which
size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`.
dsc_model: OCIDataScienceModel
The data scince model instance.
overwrite_existing_artifact: bool
Overwrite target bucket artifact if exists.
progress: TqdmProgressBar
An instance of the TqdmProgressBar.
region: str
The destination Object Storage bucket region.
By default the value will be extracted from the `OCI_REGION_METADATA` environment variables.
remove_existing_artifact: bool
Wether artifacts uploaded to object storage bucket need to be removed or not.
upload_manager: UploadManager
The uploadManager simplifies interaction with the Object Storage service.
"""

PROGRESS_STEPS_COUNT = 4

def __init__(
Expand Down Expand Up @@ -150,36 +189,49 @@ def __init__(
self.bucket_uri = bucket_uri
self.overwrite_existing_artifact = overwrite_existing_artifact
self.remove_existing_artifact = remove_existing_artifact
self.upload_manager = object_storage.UploadManager(
OCIClientFactory(**self.auth).object_storage
)

def _upload(self):
"""Uploads model artifacts to the model catalog."""
self.progress.update("Copying model artifact to the Object Storage bucket")

try:
bucket_uri = self.bucket_uri
bucket_uri_file_name = os.path.basename(bucket_uri)

if not bucket_uri_file_name:
bucket_uri = os.path.join(bucket_uri, f"{self.dsc_model.id}.zip")
elif not bucket_uri.lower().endswith(".zip"):
bucket_uri = f"{bucket_uri}.zip"

bucket_file_name = utils.copy_file(
self.artifact_zip_path,
bucket_uri,
force_overwrite=self.overwrite_existing_artifact,
auth=self.auth,
progressbar_description="Copying model artifact to the Object Storage bucket",
)
except FileExistsError:
bucket_uri = self.bucket_uri
bucket_uri_file_name = os.path.basename(bucket_uri)

if not bucket_uri_file_name:
bucket_uri = os.path.join(bucket_uri, f"{self.dsc_model.id}.zip")
elif not bucket_uri.lower().endswith(".zip"):
bucket_uri = f"{bucket_uri}.zip"

if not self.overwrite_existing_artifact and utils.is_path_exists(
uri=bucket_uri, auth=self.auth
):
raise FileExistsError(
f"The `{self.bucket_uri}` exists. Please use a new file name or "
f"The bucket_uri=`{self.bucket_uri}` exists. Please use a new file name or "
"set `overwrite_existing_artifact` to `True` if you wish to overwrite."
)

bucket_name, namespace_name, object_name = utils.parse_os_uri(bucket_uri)
logger.debug(f"{bucket_name=}, {namespace_name=}, {object_name=}")
try:
response = self.upload_manager.upload_file(
mingkang111 marked this conversation as resolved.
Show resolved Hide resolved
namespace_name=namespace_name,
bucket_name=bucket_name,
object_name=object_name,
file_path=self.artifact_zip_path,
)
logger.debug(response)
assert response.status == 200
except Exception as ex:
raise RuntimeError(
f"Failed to upload model artifact to the given Object Storage path `{self.bucket_uri}`."
f"Exception: {ex}"
)

self.progress.update("Exporting model artifact to the model catalog")
self.dsc_model.export_model_artifact(
bucket_uri=bucket_file_name, region=self.region
)
self.dsc_model.export_model_artifact(bucket_uri=bucket_uri, region=self.region)

if self.remove_existing_artifact:
self.progress.update(
Expand Down
12 changes: 12 additions & 0 deletions tests/unitary/default_setup/common/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,15 @@ def test_extract_region(self, input_params, expected_result):
return_value={"config": {"region": "default_signer_region"}},
):
assert extract_region(input_params["auth"]) == expected_result

def test_parse_os_uri(self):
bucket, namespace, path = utils.parse_os_uri(
"oci://my-bucket@my-namespace/my-artifact-path"
)
assert bucket == "my-bucket"
assert namespace == "my-namespace"
assert path == "my-artifact-path"

def test_parse_os_uri_with_invalid_scheme(self):
with pytest.raises(ValueError):
utils.parse_os_uri("s3://my-bucket/my-artifact-path")
87 changes: 45 additions & 42 deletions tests/unitary/default_setup/model/test_artifact_uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import pytest
from ads.model.artifact_uploader import LargeArtifactUploader, SmallArtifactUploader
from ads.model.common.utils import zip_artifact
from ads.common.auth import default_signer
from oci import object_storage

MODEL_OCID = "ocid1.datasciencemodel.oc1.xxx"

Expand Down Expand Up @@ -60,7 +62,6 @@ def test__init__(self):

# Ensures the LargeArtifactUploader can be successfully initialized
with patch("os.path.exists", return_value=True):

with pytest.raises(ValueError, match="The `bucket_uri` must be provided."):
lg_artifact_uploader = LargeArtifactUploader(
dsc_model=self.mock_dsc_model,
Expand All @@ -71,11 +72,11 @@ def test__init__(self):
overwrite_existing_artifact=False,
remove_existing_artifact=False,
)

auth = default_signer()
lg_artifact_uploader = LargeArtifactUploader(
dsc_model=self.mock_dsc_model,
artifact_path="existing_path",
auth=self.mock_auth,
auth=auth,
region=self.mock_region,
bucket_uri="test_bucket_uri",
overwrite_existing_artifact=False,
Expand All @@ -85,14 +86,16 @@ def test__init__(self):
assert lg_artifact_uploader.artifact_path == "existing_path"
assert lg_artifact_uploader.artifact_zip_path == None
assert lg_artifact_uploader.progress == None
assert lg_artifact_uploader.auth == self.mock_auth
assert lg_artifact_uploader.auth == auth
assert lg_artifact_uploader.region == self.mock_region
assert lg_artifact_uploader.bucket_uri == "test_bucket_uri"
assert lg_artifact_uploader.overwrite_existing_artifact == False
assert lg_artifact_uploader.remove_existing_artifact == False
assert isinstance(
lg_artifact_uploader.upload_manager, object_storage.UploadManager
)

def test_prepare_artiact_tmp_zip(self):

# Tests case when a folder provided as artifacts location
with patch("ads.model.common.utils.zip_artifact") as mock_zip_artifact:
mock_zip_artifact.return_value = "test_artifact.zip"
Expand Down Expand Up @@ -167,50 +170,50 @@ def test_upload_small_artifact(self):
mock_remove_artiact_tmp_zip.assert_called()
self.mock_dsc_model.create_model_artifact.assert_called()

def test_upload_large_artifact(self):
with tempfile.TemporaryDirectory() as tmp_artifact_dir:
test_bucket_file_name = os.path.join(tmp_artifact_dir, f"{MODEL_OCID}.zip")
# Case when artifact will be created and left in the TMP folder
@patch("ads.common.utils.is_path_exists")
@patch.object(object_storage.UploadManager, "upload_file")
def test_upload_large_artifact(self, mock_upload_file, mock_is_path_exists):
class MockResponse:
def __init__(self, status_code):
self.status = status_code

# Case when artifact already exists and overwrite_existing_artifact==True
dest_path = "oci://my-bucket@my-namespace/my-artifact-path"
test_bucket_file_name = os.path.join(dest_path, f"{MODEL_OCID}.zip")
mock_upload_file.return_value = MockResponse(200)
mock_is_path_exists.return_value = True
artifact_uploader = LargeArtifactUploader(
dsc_model=self.mock_dsc_model,
artifact_path=self.mock_artifact_zip_path,
bucket_uri=dest_path + "/",
auth=default_signer(),
region=self.mock_region,
overwrite_existing_artifact=True,
remove_existing_artifact=False,
)
artifact_uploader.upload()
mock_upload_file.assert_called_with(
namespace_name="my-namespace",
bucket_name="my-bucket",
object_name=f"my-artifact-path/{MODEL_OCID}.zip",
file_path=self.mock_artifact_zip_path,
)
self.mock_dsc_model.export_model_artifact.assert_called_with(
bucket_uri=test_bucket_file_name, region=self.mock_region
)

# Case when artifact already exists and overwrite_existing_artifact==False
with pytest.raises(FileExistsError):
artifact_uploader = LargeArtifactUploader(
dsc_model=self.mock_dsc_model,
artifact_path=self.mock_artifact_path,
bucket_uri=tmp_artifact_dir + "/",
auth=self.mock_auth,
artifact_path=self.mock_artifact_zip_path,
bucket_uri=dest_path + "/",
auth=default_signer(),
region=self.mock_region,
overwrite_existing_artifact=False,
remove_existing_artifact=False,
)
artifact_uploader.upload()
self.mock_dsc_model.export_model_artifact.assert_called_with(
bucket_uri=test_bucket_file_name, region=self.mock_region
)
assert os.path.exists(test_bucket_file_name)

# Case when artifact already exists and overwrite_existing_artifact==False
with pytest.raises(FileExistsError):
artifact_uploader = LargeArtifactUploader(
dsc_model=self.mock_dsc_model,
artifact_path=self.mock_artifact_path,
bucket_uri=tmp_artifact_dir + "/",
auth=self.mock_auth,
region=self.mock_region,
overwrite_existing_artifact=False,
remove_existing_artifact=False,
)
artifact_uploader.upload()

# Case when artifact already exists and overwrite_existing_artifact==True
artifact_uploader = LargeArtifactUploader(
dsc_model=self.mock_dsc_model,
artifact_path=self.mock_artifact_path,
bucket_uri=tmp_artifact_dir + "/",
auth=self.mock_auth,
region=self.mock_region,
overwrite_existing_artifact=True,
remove_existing_artifact=True,
)
artifact_uploader.upload()
assert not os.path.exists(test_bucket_file_name)

def test_zip_artifact_fail(self):
with pytest.raises(ValueError, match="The `artifact_dir` must be provided."):
Expand Down