Skip to content

Commit

Permalink
Better creation workflow and messages, plus database creation
Browse files Browse the repository at this point in the history
  • Loading branch information
iakov-aws committed Dec 26, 2023
1 parent 113b92f commit e25d45b
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 75 deletions.
2 changes: 1 addition & 1 deletion cid/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from cid.helpers.athena import Athena
from cid.helpers.glue import Glue
from cid.helpers.s3 import S3
from cid.helpers.athena import Athena
from cid.helpers.iam import IAM
from cid.helpers.cur import CUR
from cid.helpers.diff import diff
Expand Down
124 changes: 56 additions & 68 deletions cid/helpers/athena.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import re
import csv
import json
import time
import uuid
import logging
from string import Template
from io import StringIO
from pkg_resources import resource_string

from cid.base import CidBase
from cid.helpers.s3 import S3
from cid.helpers import S3
from cid.helpers import Glue
from cid.utils import get_parameter, get_parameters, cid_print
from cid.helpers.diff import diff
from cid.exceptions import CidCritical, CidError
Expand Down Expand Up @@ -42,13 +38,13 @@ def client(self):

@property
def CatalogName(self) -> str:
""" Check if AWS Datacalog and Athena database exist """
""" Check if AWS DataCatalog and Athena database exist """
if not self._CatalogName:
# Get AWS Glue DataCatalogs
glue_data_catalogs = [d for d in self.list_data_catalogs() if d['Type'] == 'GLUE']
if not len(glue_data_catalogs):
logger.error('AWS DataCatog of type GLUE not found!')
self._status = 'AWS DataCatog of type GLUE not found'
logger.error('AWS DataCatalog of type GLUE not found!')
self._status = 'AWS DataCatalog of type GLUE not found'
if len(glue_data_catalogs) == 1:
self._CatalogName = glue_data_catalogs.pop().get('CatalogName')
elif len(glue_data_catalogs) > 1:
Expand All @@ -58,7 +54,7 @@ def CatalogName(self) -> str:
message="Select AWS DataCatalog to use",
choices=[catalog.get('CatalogName') for catalog in glue_data_catalogs],
)
logger.info(f'Using datacatalog: {self._CatalogName}')
logger.info(f'Using DataCatalog: {self._CatalogName}')
return self._CatalogName

@CatalogName.setter
Expand All @@ -77,38 +73,30 @@ def DatabaseName(self) -> str:
raise CidCritical(f'Database {self._DatabaseName} not found in Athena catalog {self.CatalogName}')
except Exception as exc:
if 'AccessDeniedException' in str(exc):
logger.warning(f'{type(exc)} - Missing athena:GetDatabase permission. Cannot verify existance of {self._DatabaseName} in {self.CatalogName}. Hope you have it there.')
logger.warning(f'{type(exc)} - Missing athena:GetDatabase permission. Cannot verify existence of {self._DatabaseName} in {self.CatalogName}. Hope you have it there.')
return self._DatabaseName
raise
# Get AWS Athena databases
athena_databases = self.list_databases()
if not len(athena_databases):
self._status = 'AWS Athena databases not found'
raise CidCritical(self._status)
if len(athena_databases) == 1:
# Silently choose an existing database
self._DatabaseName = athena_databases.pop().get('Name')
elif len(athena_databases) > 1:
# Remove empty databases from the list
for d in athena_databases:
tables = self.list_table_metadata(
DatabaseName=d.get('Name'),
max_items=1000, # This is an impiric limit. User can have up to 200k tables in one DB we need to draw a line somewhere
)
if not len(tables):
athena_databases.remove(d)
# Select default database if present
default_databases = [d for d in athena_databases if d['Name'] == self.defaults.get('DatabaseName')]
if len(default_databases):
# Silently choose an existing default database
self._DatabaseName = default_databases.pop().get('Name')
else:
# Ask user
self._DatabaseName = get_parameter(
param_name='athena-database',
message="Select AWS Athena database to use",
choices=[d['Name'] for d in athena_databases],
)

# Select default database if present
default_databases = [database for database in athena_databases if database['Name'] == self.defaults.get('DatabaseName')]
if len(default_databases):
# Silently choose an existing default database
self._DatabaseName = default_databases.pop().get('Name')
else:
# Ask user
choices = [d['Name'] for d in athena_databases]
if self.defaults.get('DatabaseName') not in choices:
choices.append(self.defaults.get('DatabaseName') + ' (CREATE NEW)')
self._DatabaseName = get_parameter(
param_name='athena-database',
message="Select AWS Athena database to use",
choices=choices,
)
if self._DatabaseName.endswith( ' (CREATE NEW)'):
Glue(self.session).create_database(name=self.defaults.get('DatabaseName'))
self._DatabaseName = self.defaults.get('DatabaseName')
logger.info(f'Using Athena database: {self._DatabaseName}')
return self._DatabaseName

Expand Down Expand Up @@ -147,7 +135,7 @@ def WorkGroup(self) -> str:
if ' (create new)' in selected_workgroup:
selected_workgroup = selected_workgroup.replace(' (create new)', '')
self.WorkGroup = self._ensure_workgroup(name=selected_workgroup)

logger.info(f'Selected workgroup: "{self._WorkGroup}"')
return self._WorkGroup

Expand All @@ -165,14 +153,13 @@ def WorkGroup(self, name: str):
logger.info(f'Selected Athena WorkGroup: "{self._WorkGroup}"')

def _ensure_workgroup(self, name: str) -> str:
"""Ensure a workgroup exists and configured with an S3 bucket"""
try:
s3 = S3(session=self.session)
bucket_name = f'{self.partition}-athena-query-results-cid-{self.account_id}-{self.region}'

workgroup = self.client.get_work_group(WorkGroup=name)
# "${AWS::Partition}-athena-query-results-cid-${AWS::AccountId}-${AWS::Region}"
if not workgroup.get('WorkGroup', {}).get('Configuration', {}).get('ResultConfiguration', {}).get('OutputLocation', None):
s3 = S3(session=self.session)
buckets = s3.list_buckets(region_name=self.region)
if bucket_name not in buckets:
buckets.append(f'{bucket_name} (create new)')
Expand All @@ -184,7 +171,7 @@ def _ensure_workgroup(self, name: str) -> str:
if ' (create new)' in bucket_name:
bucket_name = bucket_name.replace(' (create new)', '')
s3.ensure_bucket(name=bucket_name)
response = self.client.update_work_group(
self.client.update_work_group(
WorkGroup=name,
Description='string',
ConfigurationUpdates={
Expand All @@ -200,11 +187,11 @@ def _ensure_workgroup(self, name: str) -> str:
}
)
return name
except self.client.exceptions.InvalidRequestException as ex:
except self.client.exceptions.InvalidRequestException as exc:
# Workgroup does not exist
if 'WorkGroup is not found' in ex.response.get('Error', {}).get('Message'):
if 'WorkGroup is not found' in exc.response.get('Error', {}).get('Message'):
s3.ensure_bucket(name=bucket_name)
response = self.client.create_work_group(
self.client.create_work_group(
Name=name,
Configuration={
'ResultConfiguration': {
Expand All @@ -219,28 +206,29 @@ def _ensure_workgroup(self, name: str) -> str:
}
)
return name
except Exception as ex:
raise CidCritical('Failed to create Athena work group') from ex

else:
raise
except Exception as exc:
logger.exception(exc)
raise CidCritical(f'Failed to create Athena work group ({exc})') from exc

def list_data_catalogs(self) -> list:
    """Return the summary list of Athena data catalogs visible to this account.

    :returns: the ``DataCatalogsSummary`` list from the Athena
        ``ListDataCatalogs`` API (each item has ``CatalogName`` and ``Type``).
    """
    response = self.client.list_data_catalogs()
    return response.get('DataCatalogsSummary')

def list_databases(self) -> list:
    """Return the databases of the currently selected data catalog.

    :returns: the ``DatabaseList`` from the Athena ``ListDatabases`` API
        for ``self.CatalogName`` (each item has at least a ``Name`` key).
    """
    response = self.client.list_databases(CatalogName=self.CatalogName)
    return response.get('DatabaseList')

def get_database(self, DatabaseName: str=None) -> bool:
""" Check if AWS Datacalog and Athena database exist """
if not DatabaseName:
DatabaseName = self.DatabaseName
""" Check if AWS DataCatalog and Athena database exist """
DatabaseName = DatabaseName or self.DatabaseName
try:
self.client.get_database(CatalogName=self.CatalogName, DatabaseName=DatabaseName).get('Database')
return True
except Exception as exc:
return self.client.get_database(CatalogName=self.CatalogName, DatabaseName=DatabaseName).get('Database')
except self.client.exceptions.ClientError as exc:
if 'AccessDeniedException' in str(exc):
raise
else:
logger.debug(exc, exc_info=True)
return False
return None

def list_table_metadata(self, DatabaseName: str=None, max_items: int=None) -> dict:
params = {
Expand All @@ -261,13 +249,13 @@ def list_table_metadata(self, DatabaseName: str=None, max_items: int=None) -> di
except Exception as e:
logger.error(f'Failed to list tables in {DatabaseName if DatabaseName else self.DatabaseName}')
logger.error(e)

return table_metadata

def list_work_groups(self) -> list:
""" List AWS Athena workgroups """
result = self.client.list_work_groups()
logger.debug(f'Workgroups: {result.get("WorkGroups")}')
logger.debug(f'WorkGroups: {result.get("WorkGroups")}')
return result.get('WorkGroups')

def get_table_metadata(self, TableName: str) -> dict:
Expand Down Expand Up @@ -352,13 +340,13 @@ def query(self, sql, include_header=False, **kwargs) -> list:
execution_id = self.execute_query(sql, **kwargs)
results = self.get_query_results(execution_id)
#logger.debug(f'results = {json.dumps(results, indent=2)}')
prarsed = self.parse_response_as_table(results, include_header)
logger.debug(f'prarsed res = {json.dumps(prarsed, indent=2)}')
return prarsed
parsed = self.parse_response_as_table(results, include_header)
logger.debug(f'parsed res = {json.dumps(parsed, indent=2)}')
return parsed


def discover_views(self, views: dict={}) -> None:
""" Discover views from a given list of view names and cahe them. """
""" Discover views from a given list of view names and cache them. """
for view_name in views:
try:
self.get_table_metadata(TableName=view_name)
Expand Down Expand Up @@ -390,7 +378,7 @@ def delete_table(self, name: str, catalog: str=None, database: str=None):
return False

try:
res = self.execute_query(
self.execute_query(
f'DROP TABLE IF EXISTS {name};',
catalog=catalog,
database=database,
Expand All @@ -414,7 +402,7 @@ def delete_view(self, name: str, catalog: str=None, database: str=None):
return False

try:
res = self.execute_query(
self.execute_query(
f'DROP VIEW IF EXISTS {name};',
catalog=catalog,
database=database,
Expand All @@ -430,7 +418,7 @@ def delete_view(self, name: str, catalog: str=None, database: str=None):
return True

def get_view_diff(self, name, sql):
""" returns a diff between existing and new viws. """
""" returns a diff between existing and new views. """
tmp_name = 'cid_tmp_deleteme'
existing_sql = ''
try:
Expand Down Expand Up @@ -490,7 +478,7 @@ def _recursively_process_view(view):
all_views[view]["dependsOn"]['views'] = []
deps = re.findall(r'FROM\W+?([\w."]+)', sql)
for dep_view in deps:
#FIXME: need to add cross Database Dependancies
#FIXME: need to add cross Database Dependencies
if dep_view.upper() in ('SELECT', 'VALUES'): # remove "FROM SELECT" and "FROM VALUES"
continue
dep_view = dep_view.replace('"', '')
Expand Down
9 changes: 9 additions & 0 deletions cid/helpers/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ def create_or_update_table(self, view_name: str, definition: dict) -> None:
logger.error(definition)
raise

def create_database(self, name, description: str='Cloud Intelligence Dashboards Database'):
    """Create a Glue database.

    :param name: name of the database to create
    :param description: human-readable description stored on the database
    :returns: the raw Glue ``CreateDatabase`` API response
    """
    database_input = {
        'Name': name,
        'Description': description,
    }
    return self.client.create_database(DatabaseInput=database_input)

def get_table(self, name, catalog, database):
"""Get table"""
return self.client.get_table(
Expand Down
3 changes: 3 additions & 0 deletions cid/helpers/iam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
''' IAM helper
'''
import json
import time
import logging

from cid.base import CidBase
Expand Down Expand Up @@ -96,12 +97,14 @@ def ensure_role_with_policy(self, role_name, assume_role_policy_document, policy
PolicyDocument=json.dumps(assume_role_policy_document)
)
logger.info(f'Role {role_name} updated')

except self.client.exceptions.NoSuchEntityException:
self.client.create_role(
RoleName=role_name,
AssumeRolePolicyDocument=json.dumps(assume_role_policy_document)
)
logger.info(f'Role {role_name} created')
time.sleep(5) # Some times the role cannot be assumed without this delay after creation

try:
policy_response = self.client.get_policy(PolicyArn=f"arn:aws:iam::{self.account_id}:policy/{policy_name}")
Expand Down
23 changes: 20 additions & 3 deletions cid/helpers/quicksight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,12 @@ def create_data_source(self, athena_workgroup, datasource_id: str=None, role_arn
if not datasource.status.endswith('IN_PROGRESS'):
break
if not datasource.is_healthy:
logger.error(f'Data source creation failed: {datasource.error_info}')
logger.error(f'Data source creation failed: {datasource.error_info}.')
if "The QuickSight service role required to access your AWS resources has not been created yet." in str(datasource.error_info):
logger.error(
'Please check that QuickSight has a default role that can access S3 Buckets and Athena https://quicksight.aws.amazon.com/sn/admin?#aws '
'OR provide a custom datasource role as a parameter --quicksight-datasource-role-arn'
)
if get_parameter(
param_name='quicksight-delete-failed-datasource',
message=f'Data source creation failed: {datasource.error_info}. Delete?',
Expand All @@ -472,8 +477,20 @@ def create_data_source(self, athena_workgroup, datasource_id: str=None, role_arn
return None
return datasource
except self.client.exceptions.ResourceExistsException:
logger.error('Data source already exists')
return self.describe_data_source(datasource_id, update=True)
datasource = self.describe_data_source(datasource_id, update=True)
logger.error(f'Data source already exists {datasource.raw}')
if not datasource.is_healthy:
if get_parameter(
param_name='quicksight-delete-failed-datasource',
message=f'Data source creation failed: {datasource.error_info}. Delete?',
choices=['yes', 'no'],
) == 'yes':
try:
self.delete_data_source(datasource.id)
raise CidCritical('Issue on datasource creation. Please retry.')
except self.client.exceptions.AccessDeniedException:
raise CidCritical('Access denied deleting datasource in QS. Please cleanup manually and retry.')
return datasource
except self.client.exceptions.AccessDeniedException as exc:
logger.info('Access denied creating Athena datasource')
logger.debug(exc, exc_info=True)
Expand Down
6 changes: 3 additions & 3 deletions cid/helpers/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def list_buckets(self, region_name: Optional[str] = None) -> List[str]:
'''
bucket_names = [bucket['Name'] for bucket in self.client.list_buckets()['Buckets']]
if region_name:
bucket_names = bucket_names.filter(
bucket_names = list(filter(
lambda bucket_name: self.client.get_bucket_location(Bucket=bucket_name).get('LocationConstraint', 'us-east-1') == region_name,
bucket_names,
lambda bucket_name: self.client.get_bucket_location(Bucket=bucket_name).get('LocationConstraint', 'us-east-1') == region_name
)
))
return bucket_names

def iterate_objects(self, bucket: str, prefix: str='/', search: str='Contents[].Key') -> List[str]:
Expand Down

0 comments on commit e25d45b

Please sign in to comment.