diff --git a/src/emr_cli/deployments/emr_eks.py b/src/emr_cli/deployments/emr_eks.py
index 0f996f7..b05f5ad 100644
--- a/src/emr_cli/deployments/emr_eks.py
+++ b/src/emr_cli/deployments/emr_eks.py
@@ -1,6 +1,7 @@
 import re
 import sys
 from os.path import join
+from platform import release
 from time import sleep
 from typing import List, Optional
 
@@ -19,10 +20,31 @@ def __init__(
         self.s3_client = boto3.client("s3")
         if region:
             self.client = boto3.client("emr-containers", region_name=region)
+            self.emr_client = boto3.client("emr", region_name=region)
         else:
             # Note that boto3 uses AWS_DEFAULT_REGION, not AWS_REGION
             # We may want to add an extra check here for the latter.
             self.client = boto3.client("emr-containers")
+            self.emr_client = boto3.client("emr")
+
+    def fetch_latest_release_label(self, prefix: str = "emr-6") -> str:
+        """Return the first Spark release label EMR lists for ``prefix``.
+
+        Only one result is requested from the EMR ``ListReleaseLabels`` API, so
+        this relies on the service listing the newest release first.
+        NOTE(review): confirm that ordering, and that every label returned is
+        also available on EMR on EKS.
+
+        Exits the process with status 1 when no label matches.
+        """
+        response = self.emr_client.list_release_labels(
+            Filters={"Application": "Spark", "Prefix": prefix}, MaxResults=1
+        )
+        labels = response.get("ReleaseLabels") or []
+        if not labels:
+            console_log("Error: No release labels found")
+            sys.exit(1)
+        return labels[0]
 
     def run_job(
         self,
@@ -32,10 +54,20 @@ def run_job(
         wait: bool = True,
         show_logs: bool = False,
         s3_logs_uri: Optional[str] = None,
+        release_label: Optional[str] = None,
     ):
         if show_logs and not s3_logs_uri:
             raise RuntimeError("--show-stdout requires --s3-logs-uri to be set.")
 
+        if not release_label:
+            release_label = self.fetch_latest_release_label()
+            console_log(f"Using latest release label {release_label}")
+        # Append "-latest" unless the label already ends in a build suffix
+        # (emr-6.15.0-latest, emr-6.15.0-20231109); appending again would
+        # produce an invalid "...-latest-latest" label.
+        if not re.search(r"-(latest|\d{8})$", release_label):
+            release_label = f"{release_label}-latest"
+
         # If job_name is the default, just replace the space.
         # Otherwise throw an error
         if job_name == "emr-cli job":
@@ -70,7 +102,7 @@ def run_job(
             name=job_name,
             jobDriver=jobDriver,
             configurationOverrides=config_overrides,
-            releaseLabel="emr-6.15.0-latest",
+            releaseLabel=release_label,
         )
 
         job_run_id = response.get("id")
diff --git a/src/emr_cli/emr_cli.py b/src/emr_cli/emr_cli.py
index d985eee..ae1f904 100644
--- a/src/emr_cli/emr_cli.py
+++ b/src/emr_cli/emr_cli.py
@@ -210,6 +210,9 @@ def deploy(project, entry_point, s3_code_uri):
     help="Update the config file with the provided options",
     is_flag=True,
 )
+@click.option(
+    "--emr-eks-release-label", help="EMR on EKS release label (emr-6.15.0) - defaults to latest release", default=None
+)
 @click.pass_obj
 @click.pass_context
 def run(
@@ -229,6 +232,7 @@ def run(
     build,
     show_stdout,
     save_config,
+    emr_eks_release_label,
 ):
     """
     Run a project on EMR, optionally build and deploy
@@ -242,7 +246,7 @@ def run(
     )
 
     # Only one resource ID can be specified
-    if resource_ids.count(None) != (len(resource_ids)-1):
+    if resource_ids.count(None) != (len(resource_ids) - 1):
         raise click.BadArgumentUsage(
             "Only one of --application-id, --cluster-id, or --virtual-cluster-id can be specified"
         )
@@ -252,6 +256,13 @@ def run(
         raise click.BadArgumentUsage("--entry-point and --s3-code-uri are required.")
     p = project(entry_point, s3_code_uri)
 
+    # Do a brief validation of the EMR on EKS release label
+    if emr_eks_release_label:
+        if not virtual_cluster_id:
+            raise click.BadArgumentUsage("--emr-eks-release-label can only be used with --virtual-cluster-id")
+        elif not emr_eks_release_label.startswith("emr-"):
+            raise click.BadArgumentUsage(f"--emr-eks-release-label must start with 'emr-', provided '{emr_eks_release_label}'")
+
     # If the user passes --save-config, update our stored config file
     if save_config:
         run_config = {"run": ctx.__dict__.get("params")}
@@ -289,7 +300,7 @@ def run(
         if job_args:
             job_args = job_args.split(",")
         emreks = EMREKS(virtual_cluster_id, job_role, p)
-        emreks.run_job(job_name, job_args, spark_submit_opts, wait, show_stdout, s3_logs_uri)
+        emreks.run_job(job_name, job_args, spark_submit_opts, wait, show_stdout, s3_logs_uri, emr_eks_release_label)
 
 
 cli.add_command(package)