diff --git a/python/kserve/kserve/storage/storage.py b/python/kserve/kserve/storage/storage.py
index fa82d90864c..6466ea883fc 100644
--- a/python/kserve/kserve/storage/storage.py
+++ b/python/kserve/kserve/storage/storage.py
@@ -112,13 +112,17 @@ def _update_with_storage_spec():
                 storage_secret_json[key] = value
 
         if storage_secret_json.get("type", "") == "s3":
-            os.environ["AWS_ENDPOINT_URL"] = storage_secret_json.get("endpoint_url", "")
-            os.environ["AWS_ACCESS_KEY_ID"] = storage_secret_json.get("access_key_id", "")
-            os.environ["AWS_SECRET_ACCESS_KEY"] = storage_secret_json.get("secret_access_key", "")
-            os.environ["AWS_DEFAULT_REGION"] = storage_secret_json.get("region", "")
-            os.environ["AWS_CA_BUNDLE"] = storage_secret_json.get("certificate", "")
-            os.environ["S3_VERIFY_SSL"] = storage_secret_json.get("verify_ssl", "1")
-            os.environ["awsAnonymousCredential"] = storage_secret_json.get("anonymous", "")
+            for env_var, key in (
+                ("AWS_ENDPOINT_URL", "endpoint_url"),
+                ("AWS_ACCESS_KEY_ID", "access_key_id"),
+                ("AWS_SECRET_ACCESS_KEY", "secret_access_key"),
+                ("AWS_DEFAULT_REGION", "region"),
+                ("AWS_CA_BUNDLE", "certificate"),
+                ("S3_VERIFY_SSL", "verify_ssl"),
+                ("awsAnonymousCredential", "anonymous"),
+            ):
+                if key in storage_secret_json:
+                    os.environ[env_var] = storage_secret_json.get(key)
 
         if storage_secret_json.get("type", "") == "hdfs" or storage_secret_json.get("type", "") == "webhdfs":
             temp_dir = tempfile.mkdtemp()
diff --git a/python/kserve/kserve/storage/test/test_s3_storage.py b/python/kserve/kserve/storage/test/test_s3_storage.py
index 14a2f5e52fd..f5328f4553c 100644
--- a/python/kserve/kserve/storage/test/test_s3_storage.py
+++ b/python/kserve/kserve/storage/test/test_s3_storage.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import json
 import unittest.mock as mock
 
 from botocore.client import Config
@@ -212,3 +213,50 @@ def test_get_S3_config():
     with mock.patch.dict(os.environ, {"S3_USER_VIRTUAL_BUCKET": "True"}):
         config7 = Storage.get_S3_config()
         assert config7.s3["addressing_style"] == VIRTUAL_CONFIG.s3["addressing_style"]
+
+
+def test_update_with_storage_spec_s3(monkeypatch):
+    # save the environment and restore it after the test to avoid mutating it
+    # since _update_with_storage_spec modifies it
+    previous_env = os.environ.copy()
+
+    monkeypatch.setenv("STORAGE_CONFIG", '{"type": "s3"}')
+    Storage._update_with_storage_spec()
+
+    for var in (
+        "AWS_ENDPOINT_URL",
+        "AWS_ACCESS_KEY_ID",
+        "AWS_SECRET_ACCESS_KEY",
+        "AWS_DEFAULT_REGION",
+        "AWS_CA_BUNDLE",
+        "S3_VERIFY_SSL",
+        "awsAnonymousCredential",
+    ):
+        assert os.getenv(var) is None
+
+    storage_config = {
+        "access_key_id": "xxxxxxxxxxxxxxxxxxxx",
+        "bucket": "abucketname",
+        "default_bucket": "abucketname",
+        "endpoint_url": "https://s3.us-east-2.amazonaws.com/",
+        "region": "us-east-2",
+        "secret_access_key": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+        "type": "s3",
+        "certificate": "/path/to/ca.bundle",
+        "verify_ssl": "false",
+        "anonymous": "True",
+    }
+
+    monkeypatch.setenv("STORAGE_CONFIG", json.dumps(storage_config))
+    Storage._update_with_storage_spec()
+
+    assert os.getenv("AWS_ENDPOINT_URL") == storage_config["endpoint_url"]
+    assert os.getenv("AWS_ACCESS_KEY_ID") == storage_config["access_key_id"]
+    assert os.getenv("AWS_SECRET_ACCESS_KEY") == storage_config["secret_access_key"]
+    assert os.getenv("AWS_DEFAULT_REGION") == storage_config["region"]
+    assert os.getenv("AWS_CA_BUNDLE") == storage_config["certificate"]
+    assert os.getenv("S3_VERIFY_SSL") == storage_config["verify_ssl"]
+    assert os.getenv("awsAnonymousCredential") == storage_config["anonymous"]
+
+    # revert changes
+    os.environ = previous_env
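Note for reviewers (not part of the patch): the behavioral change is that spec keys absent from STORAGE_CONFIG no longer overwrite the corresponding environment variables with defaults ("" for most, "1" for S3_VERIFY_SSL); they are simply skipped. Below is a minimal standalone sketch of that behavior. It assumes kserve is installed, that Storage is importable from kserve.storage.storage (the file touched by this patch), and that the AWS_* variables are not already set in the calling environment; it is illustrative only.

    import json
    import os

    from kserve.storage.storage import Storage

    # Spec with only "type" and "region": under the new loop, only
    # AWS_DEFAULT_REGION is written; the other AWS_* variables are left alone.
    os.environ["STORAGE_CONFIG"] = json.dumps({"type": "s3", "region": "us-east-2"})
    Storage._update_with_storage_spec()

    print(os.getenv("AWS_DEFAULT_REGION"))  # "us-east-2"
    print(os.getenv("AWS_ENDPOINT_URL"))    # None (the old code would have set it to "")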