Skip to content

Commit

Permalink
add proxy cur command
Browse files Browse the repository at this point in the history
  • Loading branch information
iakov-aws committed May 6, 2024
1 parent 47786f9 commit 30df100
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 45 deletions.
18 changes: 17 additions & 1 deletion cid/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def cid_command(func):
def wrapper(ctx, **kwargs):

def get_command_line():
params = get_parameters()
return ('cid-cmd ' + ctx.info_name
+ ''.join([f" --{k.replace('_','-')}" for k, v in ctx.params.items() if isinstance(v, bool) and v])
+ ''.join([f" --{k.replace('_','-')} '{v}'" for k, v in ctx.params.items() if not isinstance(v, bool) and v is not None])
Expand All @@ -43,7 +44,7 @@ def get_command_line():
res = func(ctx, **kwargs)
except (CidCritical, CidError) as exc:
logger.debug(exc, exc_info=True)
logger(f'When running {get_command_line()}')
logger.debug(f'When running {get_command_line()}')
logger.error(exc)
params = get_parameters()
logger.info('Next time you can use following command:')
Expand Down Expand Up @@ -263,6 +264,21 @@ def create_cur_table(ctx, **kwargs):

ctx.obj.create_cur_table(**kwargs)

@click.option('-v', '--verbose', count=True)
@click.option('--cur-version', help='Cur Version (1 or 2)')
@click.option('--fields', help='CUR fields', default='')
@cid_command
def create_cur_proxy(ctx, cur_version, fields, **kwargs):
    """Create CUR proxy
    \b
    --cur-version (1|2) Version of CUR
    --fields Comma Separated list of additional CUR fields
    """
    # The declared options are bound to named parameters, so they are no
    # longer part of **kwargs; forward them explicitly instead of dropping them.
    ctx.obj.create_cur_proxy(cur_version=cur_version, fields=fields, **kwargs)


@click.option('-v', '--verbose', count=True)
@click.option('-y', '--yes', help='confirm all', is_flag=True, default=False)
@cid_command
Expand Down
29 changes: 24 additions & 5 deletions cid/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,7 +1363,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
# Read dataset definition from template
data = self.get_data_from_definition('dataset', dataset_definition)
template = Template(json.dumps(data))
cur_required = dataset_definition.get('dependsOn', dict()).get('cur')
# Either 'cur' or 'cur1' under dependsOn marks a CUR1 dependency (get_view_query checks the same pair of keys)
cur1_required = dataset_definition.get('dependsOn', dict()).get('cur') or dataset_definition.get('dependsOn', dict()).get('cur1')
cur2_required = dataset_definition.get('dependsOn', dict()).get('cur2')
athena_datasource = None

Expand Down Expand Up @@ -1487,7 +1487,8 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
columns_tpl = {
'athena_datasource_arn': athena_datasource.arn,
'athena_database_name': self.athena.DatabaseName,
'cur_table_name': self.cur1.table_name if cur_required else None,
'cur_table_name': self.cur1.table_name if cur1_required else None,
'cur1_table_name': self.cur1.table_name if cur1_required else None,
'cur2_table_name': self.cur2.table_name if cur2_required else None,
}

Expand Down Expand Up @@ -1518,7 +1519,7 @@ def create_or_update_dataset(self, dataset_definition: dict, dataset_id: str=Non
elif found_dataset.name != compiled_dataset.get('Name'):
print(f"Dataset found with name {found_dataset.name}, but {compiled_dataset.get('Name')} expected. Updating.")
update_dataset = True
if update_dataset and get_parameters().get('on-drift', 'show').lower() != 'override' and isatty() and not cur_required and not cur2_required:
if update_dataset and get_parameters().get('on-drift', 'show').lower() != 'override' and isatty() and not cur1_required and not cur2_required:
while True:
diff = self.qs.dataset_diff(found_dataset.raw, compiled_dataset)
if diff and diff['diff']:
Expand Down Expand Up @@ -1737,7 +1738,7 @@ def get_view_query(self, view_name: str) -> str:
""" Returns a fully compiled AHQ """
# View path
view_definition = self.get_definition("view", name=view_name)
cur_required = view_definition.get('dependsOn', dict()).get('cur') or view_definition.get('dependsOn', dict()).get('cur1')
cur1_required = view_definition.get('dependsOn', dict()).get('cur') or view_definition.get('dependsOn', dict()).get('cur1')
cur2_required = view_definition.get('dependsOn', dict()).get('cur2')
#if cur_required and self.cur.has_savings_plans and self.cur.has_reservations and view_definition.get('spriFile'):
# view_definition['File'] = view_definition.get('spriFile')
Expand All @@ -1760,7 +1761,8 @@ def get_view_query(self, view_name: str) -> str:

# Prepare template parameters
columns_tpl = {
'cur_table_name': self.cur1.table_name if cur_required else None,
'cur_table_name': self.cur1.table_name if cur1_required else None,
'cur1_table_name': self.cur1.table_name if cur1_required else None,
'cur2_table_name': self.cur2.table_name if cur2_required else None,
'athenaTableName': view_name,
'athena_database_name': self.athena.DatabaseName,
Expand Down Expand Up @@ -1806,6 +1808,23 @@ def init_qs(self, **kwargs):
""" Initialize QuickSight resources for deployment """
return InitQsCommand(cid=self, **kwargs).execute()

@command
def create_cur_proxy(self, cur_version=None, fields=None, **kwargs):
    """ Create or update the CUR proxy view for the requested CUR version

    :param cur_version: '1' or '2'; when empty the value is requested via
        get_parameter('cur-version')
    :param fields: comma separated names (or a list) of additional CUR fields
        to expose; when empty the 'fields' entry of get_parameters() is used
    :raises ValueError: if the CUR version starts with neither 1 nor 2
    """
    cid_print(f'Using {self.cur.table_name}') # need to call self.cur
    cur_version = str(cur_version or get_parameter(
        'cur-version',
        message='Enter a version of CUR you want to create or update',
        choices=['1', '2'],
    ))
    if cur_version.startswith('1'):
        cur_proxy = self.cur1
    elif cur_version.startswith('2'):
        cur_proxy = self.cur2
    else:
        # Previously this path fell through and crashed later with UnboundLocalError
        raise ValueError(f'Unsupported CUR version {cur_version!r}. Expected 1 or 2.')

    # An explicitly passed value wins; otherwise fall back to the global parameters
    fields = fields or get_parameters().get('fields') or ''
    if isinstance(fields, str):
        fields = fields.split(',')
    extra_fields = [name.strip() for name in fields if name and name.strip()]

    # NOTE(review): bare attribute access kept from the original; presumably it
    # loads the CUR metadata as a side effect before the proxy is used - confirm
    cur_proxy.metadata
    cur_proxy.proxy.fields_to_expose += extra_fields
    cur_proxy.proxy.create_or_update_view()

@command
def create_cur_table(self, **kwargs):
""" Initialize CUR """
Expand Down
2 changes: 0 additions & 2 deletions cid/helpers/cur.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ def table_is_cur(self, table: dict=None, name: str=None, return_reason: bool=Fal
return False if not return_reason else (False, f'cannot get table {name}. {exc}.')

table_name = table.get('Name')
if '_proxy' in table_name:
return False if not return_reason else (False, f"Table {table_name} most likely is a proxy.")
columns = [col.get('Name') for col in table.get('Columns')]
missing_columns = [col for col in self.cur_minimal_required_columns if col not in columns]
# Diagnostic only: the list is logged even when empty, so it must not be reported at CRITICAL level
logger.debug(f'Missing CUR columns in {table_name}: {missing_columns}')
Expand Down
63 changes: 26 additions & 37 deletions cid/helpers/cur_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,12 @@ class ProxyView():
creates a proxy view for CUR
"""
def __init__(self, cur, target_cur_version, fields_to_expose=None):
def __init__(self, cur, target_cur_version):
self.cur = cur
self.target_cur_version = target_cur_version
self.current_cur_version = self.cur.version
logger.debug(f'CUR proxy from {self.current_cur_version } to {self.target_cur_version }')
self.fields_to_expose = list(set(default_columns[self.target_cur_version] + (fields_to_expose or []) ))
self.fields_to_expose = list(set(default_columns[self.target_cur_version]))
self.athena = self.cur.athena
self.name = f'cur{self.target_cur_version}_proxy'
self.exposed_fields = []
Expand Down Expand Up @@ -403,17 +403,17 @@ def source_column_equivalent(self, field):
for cur2map in cur2_maps:
if field.startswith(cur2map + '_'):
return cur2map
logger.warning(f"{field} not known field of CUR1. needs to be added in code. Please create a github issue")
logger.warning(f"{field} not known field of CUR2. needs to be added in code. Please create a github issue")
res = cur1to2_mapping.get(field, field)
return res.split('[')[0]
if self.current_cur_version.startswith('1') and self.target_cur_version.startswith('2'): # field from CUR2 to CUR1
matches = re.findall(r"(\w+)\['(\w+)'\]", field)
if matches:
field, key = matches[0]
if field not in self.fields_to_expose_in_maps:
self.fields_to_expose_in_maps = set()
self.fields_to_expose_in_maps[key].add(key)
return f'{field}_{key}'
map_field, key = matches[0]
if map_field not in self.fields_to_expose_in_maps:
self.fields_to_expose_in_maps[map_field] = set()
self.fields_to_expose_in_maps[map_field].add(key)
return f'{map_field}_{key}'
cur2to1_mapping = {value: key for key, value in cur1to2_mapping.items()}
if field not in cur2to1_mapping:
logger.warning(f"{field} not known field of CUR1. needs to be added in code. Please create a github issue")
Expand All @@ -432,12 +432,19 @@ def get_sql_expression(self, field, field_type):
if self.current_cur_version.startswith('1') and self.target_cur_version.startswith('1'): # field from CUR1 to CUR2
return field
if self.current_cur_version.startswith('1') and self.target_cur_version.startswith('2'): # field from CUR1 to CUR2
if field_type.startswith('map'):
if field_type.lower().startswith('map'):
self.source_column_equivalent(field) # Do not remove this
map_field = field.split('[')[0]
map_mapping = {}
keys = set(self.exposed_maps.get(field, set())).update(self.fields_to_expose_in_maps.get(field, set()))
for key in keys:
if f'{field}_{key}' in self.cur.fields:
map_mapping[key] = f'{field}_{key}'
keys_set = set(self.exposed_maps.get(field, set()))
keys_set.update(self.fields_to_expose_in_maps.get(map_field, set()))
# Diagnostics go to the module logger instead of being printed to stdout
logger.debug(f'field={field} map_field={map_field} fields_to_expose_in_maps={self.fields_to_expose_in_maps} keys_set={keys_set}')
for key in keys_set:
if f'{map_field}_{key}' in self.cur.fields:
map_mapping[key] = f'{map_field}_{key}'
else:
map_mapping[key] = 'CAST(NULL as VARCHAR)'
if not map_mapping:
Expand All @@ -464,15 +471,16 @@ def create_or_update_view(self):
""" Create or update view
"""
self.read_from_athena()
all_fields = sorted(list(set(self.exposed_fields + self.fields_to_expose)))
all_target_fields = sorted(list(set(self.exposed_fields + self.fields_to_expose)))
logger.debug(f'all_target_fields {all_target_fields}')  # diagnostic output belongs in the log, not on stdout
lines = {}
print('all_fields', all_fields)
for field in all_fields:
for field in all_target_fields:
target_field = field.split('[')[0] # take a first part only
field_type = self.cur.get_type_of_column(field)
field_type = self.cur.get_type_of_column(target_field)
mapped_expression = self.get_sql_expression(field, field_type)
logger.debug(f'{field} {target_field} {field_type} {mapped_expression}')  # diagnostic output belongs in the log, not on stdout
requirement = mapped_expression.split('[')[0]
if not re.match(r'^[a-zA-Z0-9_]+$', requirement) or self.cur.column_exists(requirement):
if field_type.lower().startswith('map') or (not re.match(r'^[a-zA-Z0-9_]+$', requirement) or self.cur.column_exists(requirement)):
expression = mapped_expression
else:
if field_type not in empty:
Expand All @@ -494,22 +502,3 @@ def create_or_update_view(self):

def get_table_metadata(self):
    """ Return what self.athena.get_table_metadata reports for self.name
    """
    fetch_metadata = self.athena.get_table_metadata
    return fetch_metadata(self.name)


if __name__ == '__main__':
    # Manual run: creates or updates the CUR2 proxy view using default boto3 sessions
    import boto3
    from cid.helpers.athena import Athena
    from cid.helpers.glue import Glue
    from cid.helpers.cur import CUR

    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.DEBUG)

    athena_helper = Athena(session=boto3.Session())
    glue_helper = Glue(session=boto3.Session())
    source_cur = CUR(athena=athena_helper, glue=glue_helper)
    ProxyView(
        cur=source_cur,
        target_cur_version='2',
    ).create_or_update_view()

0 comments on commit 30df100

Please sign in to comment.