Skip to content

Commit

Permalink
fix mypy check
Browse files Browse the repository at this point in the history
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
  • Loading branch information
kenwoodjw committed Jan 7, 2025
1 parent de573be commit 130db9a
Showing 1 changed file with 71 additions and 29 deletions.
100 changes: 71 additions & 29 deletions api/extensions/storage/azure_blob_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def _get_cached_token(self, cache_key: str) -> str | None:
"""Get cached token from Redis."""
cached = redis_client.get(cache_key)
if cached:
return cached.decode("utf-8")
result: str = str(cached.decode("utf-8"))
return result
return None

def _cache_token(self, cache_key: str, token: str, expires_in: int):
Expand All @@ -45,14 +46,19 @@ def _cache_token(self, cache_key: str, token: str, expires_in: int):

def _sync_client(self) -> BlobServiceClient:
"""Create a BlobServiceClient based on the configured authentication method."""
if not self.account_url:
raise ValueError("account_url is required")

if self.auth_type == "account_key":
if not self.account_key:
raise ValueError("account_key is required for account_key authentication")
return BlobServiceClient(account_url=self.account_url, credential=self.account_key)
elif self.auth_type == "sas_token":
# If SAS token is provided directly, use it
if self.sas_token:
return BlobServiceClient(account_url=self.account_url, credential=self.sas_token)
# Generate and cache SAS token if account key is available
elif self.account_key:
elif self.account_key and self.account_name:
cache_key = f"azure_blob_sas_token_{self.account_name}"
sas_token = self._get_cached_token(cache_key)

Expand All @@ -66,26 +72,40 @@ def _sync_client(self) -> BlobServiceClient:
),
expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
)
self._cache_token(cache_key, sas_token, 3600) # Cache for 1 hour
self._cache_token(cache_key, sas_token, 3600)

return BlobServiceClient(account_url=self.account_url, credential=sas_token)
else:
raise ValueError("Neither SAS token nor account key provided for SAS token authentication")

elif self.auth_type == "service_principal":
if not all([self.tenant_id, self.client_id, self.client_secret]):
raise ValueError(
"tenant_id, client_id, and client_secret are required "
"for service_principal authentication"
)
cache_key = f"azure_sp_token_{self.client_id}"
cached_token = self._get_cached_token(cache_key)

if cached_token:
return BlobServiceClient(account_url=self.account_url, credential=cached_token)

try:
assert self.tenant_id is not None
assert self.client_id is not None
assert self.client_secret is not None

credential = ClientSecretCredential(
tenant_id=self.tenant_id, client_id=self.client_id, client_secret=self.client_secret
tenant_id=self.tenant_id,
client_id=self.client_id,
client_secret=self.client_secret
)
# Get token with 1 hour expiry
token = credential.get_token("https://storage.azure.com/.default")
self._cache_token(cache_key, token.token, token.expires_in)
expires_in = 3600
if hasattr(token, 'expires_on') and isinstance(token.expires_on, datetime):
time_diff = token.expires_on - datetime.now(UTC)
expires_in = int(time_diff.total_seconds())
self._cache_token(cache_key, token.token, expires_in)

return BlobServiceClient(account_url=self.account_url, credential=credential)
except ClientAuthenticationError as e:
Expand All @@ -99,51 +119,73 @@ def _sync_client(self) -> BlobServiceClient:
return BlobServiceClient(account_url=self.account_url, credential=cached_token)

try:
credential = DefaultAzureCredential()
# Get token with 1 hour expiry
credential: ClientSecretCredential = DefaultAzureCredential() # type: ignore
token = credential.get_token("https://storage.azure.com/.default")
self._cache_token(cache_key, token.token, token.expires_in)
expires_in = 3600
if hasattr(token, 'expires_on') and isinstance(token.expires_on, datetime):
time_diff = token.expires_on - datetime.now(UTC)
expires_in = int(time_diff.total_seconds())
self._cache_token(cache_key, token.token, expires_in)

return BlobServiceClient(account_url=self.account_url, credential=credential)
except ClientAuthenticationError as e:
raise ValueError(f"Failed to authenticate with Managed Identity: {str(e)}")
else:
raise ValueError(f"Unsupported authentication type: {self.auth_type}")

def save(self, filename, data):
def save(self, filename: str, data: bytes) -> None:
if not self.bucket_name:
raise ValueError("bucket_name is required")

client = self._sync_client()
blob_container = client.get_container_client(container=self.bucket_name)
blob_container.upload_blob(filename, data)
blob_client = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_client.upload_blob(data)

def load_once(self, filename: str) -> bytes:
if not self.bucket_name:
raise ValueError("bucket_name is required")

client = self._sync_client()
blob = client.get_container_client(container=self.bucket_name)
blob = blob.get_blob_client(blob=filename)
data: bytes = blob.download_blob().readall()
blob_client = client.get_blob_client(container=self.bucket_name, blob=filename)
downloaded = blob_client.download_blob().readall()
# Ensure we return bytes
if isinstance(downloaded, str):
data = downloaded.encode('utf-8')
else:
data = downloaded
return data

def load_stream(self, filename: str) -> Generator:
if not self.bucket_name:
raise ValueError("bucket_name is required")

client = self._sync_client()
blob = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_data = blob.download_blob()
blob_client = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_data = blob_client.download_blob()
yield from blob_data.chunks()

def download(self, filename, target_filepath):
def download(self, filename: str, target_filepath: str) -> None:
if not self.bucket_name:
raise ValueError("bucket_name is required")

client = self._sync_client()

blob = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_client = client.get_blob_client(container=self.bucket_name, blob=filename)
with open(target_filepath, "wb") as my_blob:
blob_data = blob.download_blob()
blob_data = blob_client.download_blob()
blob_data.readinto(my_blob)

def exists(self, filename):
def exists(self, filename: str) -> bool:
if not self.bucket_name:
raise ValueError("bucket_name is required")

client = self._sync_client()
blob_client = client.get_blob_client(container=self.bucket_name, blob=filename)
return blob_client.exists()

blob = client.get_blob_client(container=self.bucket_name, blob=filename)
return blob.exists()

def delete(self, filename):
def delete(self, filename: str) -> None:
if not self.bucket_name:
raise ValueError("bucket_name is required")
client = self._sync_client()

blob_container = client.get_container_client(container=self.bucket_name)
blob_container.delete_blob(filename)
blob_client = client.get_blob_client(container=self.bucket_name, blob=filename)
blob_client.delete_blob()

0 comments on commit 130db9a

Please sign in to comment.