From a1fc3f3d8a5c10a5bbe350bbb9b76e73d5e1251d Mon Sep 17 00:00:00 2001 From: Morten Kuhlwein Date: Wed, 30 Oct 2024 16:52:11 +0100 Subject: [PATCH] Add per_step_k8s_config to k8s_job_executor (#25561) ## Summary & Motivation Dagster currently requires that k8s config at the op-level be defined statically with no option for configuring at launch-time (#22138). This was resolved for the celery-k8s-job-executor in #23053. This PR aims to add an identical solution to the k8s-job-executor. ## How I Tested These Changes I added a unit test in the k8s package to test that the executor per_step_k8s_config overwrites other sources (op-tag, job-tag, executor step_k8s_config). ## Changelog [dagster-k8s] Added a per_step_k8s_config configuration option to the k8s_job_executor, allowing the k8s configuration of individual steps to be configured at run launch time. --- .../dagster-k8s/dagster_k8s/executor.py | 21 +++++- .../unit_tests/test_executor.py | 66 +++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py index c3e6cd6ffd186..c8b0d7289d4dc 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py @@ -4,6 +4,7 @@ from dagster import ( Field, IntSource, + Map, Noneable, StringSource, _check as check, @@ -73,6 +74,12 @@ is_required=False, description="Raw Kubernetes configuration for each step launched by the executor.", ), + "per_step_k8s_config": Field( + Map(str, USER_DEFINED_K8S_JOB_CONFIG_SCHEMA, key_label_name="step_name"), + is_required=False, + default_value={}, + description="Per op k8s configuration overrides.", + ), }, ) @@ -161,6 +168,7 @@ def k8s_job_executor(init_context: InitExecutorContext) -> Executor: container_context=k8s_container_context, load_incluster_config=load_incluster_config, kubeconfig_file=kubeconfig_file, + per_step_k8s_config=exc_cfg.get("per_step_k8s_config", {}), ), retries=RetryMode.from_config(exc_cfg["retries"]), # type: ignore max_concurrent=check.opt_int_elem(exc_cfg, "max_concurrent"), @@ -181,6 +189,7 @@ def __init__( load_incluster_config: bool, kubeconfig_file: Optional[str], k8s_client_batch_api=None, + per_step_k8s_config=None, ): super().__init__() @@ -202,6 +211,9 @@ def __init__( self._api_client = DagsterKubernetesClient.production_client( batch_api_override=k8s_client_batch_api ) + self._per_step_k8s_config = check.opt_dict_param( + per_step_k8s_config, "per_step_k8s_config", key_type=str, value_type=dict + ) def _get_step_key(self, step_handler_context: StepHandlerContext) -> str: step_keys_to_execute = cast( @@ -225,7 +237,14 @@ def _get_container_context( user_defined_k8s_config = get_user_defined_k8s_config( step_handler_context.step_tags[step_key] ) - return context.merge(K8sContainerContext(run_k8s_config=user_defined_k8s_config)) + + per_op_override = UserDefinedDagsterK8sConfig.from_dict( + self._per_step_k8s_config.get(step_key, {}) + ) + + return context.merge(K8sContainerContext(run_k8s_config=user_defined_k8s_config)).merge( + K8sContainerContext(run_k8s_config=per_op_override) + ) def _get_k8s_step_job_name(self, step_handler_context: StepHandlerContext): step_key = self._get_step_key(step_handler_context) diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_executor.py index a2d14d4f021b8..91591ab251927 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_executor.py @@ -59,6 +59,11 @@ def foo(): "requests": {"cpu": "2500m", "memory": "1280Mi"}, } +FOURTH_RESOURCES_TAGS = { + "limits": {"cpu": "6000m", "memory": "3560Mi"}, + "requests": {"cpu": "3500m", "memory": "2280Mi"}, +} + @job( executor_def=k8s_job_executor, @@ -668,3 +673,64 @@ def test_step_raw_k8s_config_inheritance( assert raw_k8s_config.container_config["resources"] == OTHER_RESOURCE_TAGS assert raw_k8s_config.container_config["working_dir"] == "MY_WORKING_DIR" assert raw_k8s_config.container_config["volume_mounts"] == OTHER_VOLUME_MOUNTS_TAGS + + +def test_per_step_k8s_config(k8s_run_launcher_instance, python_origin_with_container_context): + container_context_config = { + "k8s": { + "run_k8s_config": {"container_config": {"volume_mounts": OTHER_VOLUME_MOUNTS_TAGS}}, + } + } + + python_origin = reconstructable(bar_with_tags_in_job_and_op).get_python_origin() + + python_origin_with_container_context = python_origin._replace( + repository_origin=python_origin.repository_origin._replace( + container_context=container_context_config + ) + ) + + # Verifies that k8s config for step pods is pulled from the container context and + # executor-level per_step_k8s_config, and that per_step_k8s_config precedes step_k8s_config + executor = _get_executor( + k8s_run_launcher_instance, + reconstructable(bar_with_tags_in_job_and_op), + { + "step_k8s_config": { # injected into every step + "container_config": { + "working_dir": "MY_WORKING_DIR", # set on every step + "resources": THIRD_RESOURCES_TAGS, # overridden by the per_step level, so ignored + } + }, + "per_step_k8s_config": { + "foo": { # injected only for "foo" step + "container_config": { + "resources": FOURTH_RESOURCES_TAGS, + } + } + }, + }, + ) + + run = create_run_for_test( + k8s_run_launcher_instance, + job_name="bar_with_tags_in_job_and_op", + job_code_origin=python_origin_with_container_context, + ) + + step_handler_context = _step_handler_context( + job_def=reconstructable(bar_with_tags_in_job_and_op), + dagster_run=run, + instance=k8s_run_launcher_instance, + executor=executor, + ) + + container_context = executor._step_handler._get_container_context( # noqa: SLF001 + step_handler_context + ) + + raw_k8s_config = container_context.run_k8s_config + + assert raw_k8s_config.container_config["resources"] == FOURTH_RESOURCES_TAGS + assert raw_k8s_config.container_config["working_dir"] == "MY_WORKING_DIR" + assert raw_k8s_config.container_config["volume_mounts"] == OTHER_VOLUME_MOUNTS_TAGS