diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index f7d2bda..5a10827 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -4,6 +4,8 @@ CHANGELOG
 0.9.4dev
 --------
 
+* Feature - `--skip-docker` uses the configured repository in soopervisor.yaml when exporting to AWS Batch. (by @DennisJLi)
+
 0.9.3 (2024-09-18)
 ------------------
 
diff --git a/setup.py b/setup.py
index 2bec3bc..467065e 100644
--- a/setup.py
+++ b/setup.py
@@ -29,7 +29,7 @@ def read(*names, **kwargs):
     "tqdm",
     "pydantic",
     "Jinja2",
-    "pyyaml",
+    "pyyaml>=6.0.2",
     "ploomber>=0.14.6",
     "ploomber-core>=0.0.11",
     # sdist is generated using python -m build, so adding this here.
diff --git a/src/soopervisor/assets/airflow/kubernetes.py b/src/soopervisor/assets/airflow/kubernetes.py
index 83b4dee..2912085 100644
--- a/src/soopervisor/assets/airflow/kubernetes.py
+++ b/src/soopervisor/assets/airflow/kubernetes.py
@@ -3,7 +3,7 @@
 from airflow import DAG
 from airflow.utils.dates import days_ago
 
-from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
+from airflow.providers.cncf.kubernetes.operators.pod import (
     KubernetesPodOperator,
 )
 
diff --git a/src/soopervisor/aws/batch.py b/src/soopervisor/aws/batch.py
index 4c4a4d2..5d07c51 100644
--- a/src/soopervisor/aws/batch.py
+++ b/src/soopervisor/aws/batch.py
@@ -280,13 +280,11 @@ def _export(
         "submit all tasks regardless of status"
     )
     if skip_docker:
-        pkg_name, version = source.find_package_name_and_version()
+        pkg_name, _ = source.find_package_name_and_version()
+        image = f"{cfg.repository}:latest"
         default_image_key = get_default_image_key()
-        if default_image_key:
-            image_local = f"{pkg_name}:{version}-"
-            f"{docker.modify_wildcard(default_image_key)}"
         image_map = {}
-        image_map[default_image_key] = image_local
+        image_map[default_image_key] = image
     else:
         pkg_name, image_map = docker.build(
             cmdr,
diff --git a/tests/airflow/test_airflow_export.py b/tests/airflow/test_airflow_export.py
index d01a5de..f0355c6 100644
--- a/tests/airflow/test_airflow_export.py
+++ b/tests/airflow/test_airflow_export.py
@@ -6,7 +6,7 @@
 import json
 
 from airflow import DAG
-from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
+from airflow.providers.cncf.kubernetes.operators.pod import (
     KubernetesPodOperator,
 )
 from airflow.operators.bash import BashOperator
diff --git a/tests/aws_batch/test_export.py b/tests/aws_batch/test_export.py
index ace7d12..47dbb83 100644
--- a/tests/aws_batch/test_export.py
+++ b/tests/aws_batch/test_export.py
@@ -169,7 +169,7 @@ def test_export(
 
 
 # TODO: check error if wrong task name
-# TODO: check errro when task is up to date
+# TODO: check error when task is up to date
 # TODO: check error if dependencies from submitted task are outdated
 @pytest.mark.parametrize(
     "mode, args",
@@ -662,3 +662,73 @@ def test_lazy_load(mock_aws_batch_lazy_load, monkeypatch):
         path_to_config="soopervisor.yaml", env_name="train", lazy_import=True
     )
     exporter.export(mode="incremental", lazy_import=True)
+
+
+def test_export_with_skip_docker_uses_configured_repository(
+    mock_batch,
+    monkeypatch,
+    tmp_sample_project_multiple_requirement,
+    monkeypatch_docker_client,
+    skip_repo_validation,
+    boto3_mock,
+    monkeypatch_docker_commons,
+    load_tasks_mock,
+):
+    monkeypatch.setattr(batch, "uuid4", lambda: "uuid4")
+    monkeypatch.setattr(batch.boto3, "client", lambda name, region_name: boto3_mock)
+    monkeypatch.setattr(commons, "load_tasks", load_tasks_mock)
+
+    repository = "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-repository/model"
+
+    exporter = batch.AWSBatchExporter.new("soopervisor.yaml", "some-env")
+    exporter._cfg.repository = repository
+    exporter.add()
+
+    # mock commander
+    commander_mock = MagicMock()
+    monkeypatch.setattr(
+        batch, "Commander", lambda workspace, templates_path: commander_mock
+    )
+
+    exporter.export(mode="incremental", skip_docker=True)
+
+    jobs = mock_batch.list_jobs(jobQueue="your-job-queue")["jobSummaryList"]
+
+    # get jobs information
+    jobs_info = mock_batch.describe_jobs(jobs=[job["jobId"] for job in jobs])["jobs"]
+
+    job_defs = mock_batch.describe_job_definitions(
+        jobDefinitions=[job["jobDefinition"] for job in jobs_info]
+    )["jobDefinitions"]
+
+    # check all tasks submitted
+    assert {j["jobName"] for j in jobs_info} == {"raw", "clean-1", "plot", "clean-2"}
+
+    # check submitted to the right queue
+    assert all(["your-job-queue" in j["jobQueue"] for j in jobs_info])
+
+    # check created a job definition with the right name
+    job_definitions = {j["jobName"]: j["jobDefinition"] for j in jobs_info}
+    assert job_definitions == {
+        "raw": "arn:aws:batch:us-east-1:123456789012:job-definition/"
+        "multiple_requirements_project-uuid4:1",
+        "clean-1": "arn:aws:batch:us-east-1:123456789012:job-definition/"
+        "multiple_requirements_project-uuid4:1",
+        "clean-2": "arn:aws:batch:us-east-1:123456789012:job-definition/"
+        "multiple_requirements_project-uuid4:1",
+        "plot": "arn:aws:batch:us-east-1:123456789012:job-definition/"
+        "multiple_requirements_project-uuid4:1",
+    }
+
+    job_images = {
+        j["jobDefinitionArn"]: j["containerProperties"]["image"] for j in job_defs
+    }
+
+    expected_image = f"{repository}:latest"
+
+    expected = {
+        "arn:aws:batch:us-east-1:123456789012:job-definition/"
+        "multiple_requirements_project-uuid4:1": expected_image,
+    }
+
+    assert job_images == expected
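
Note for reviewers (not part of the patch): a minimal sketch of the image-resolution change in the src/soopervisor/aws/batch.py hunk, assuming `repository` is set in soopervisor.yaml (both the changelog entry and the new test rely on that). The repository URI mirrors the one in the new test; the literal "image" key is a stand-in for whatever `get_default_image_key()` returns.

    # Python sketch of the --skip-docker image resolution after this change.
    repository = "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-repository/model"

    # Before: a local tag was derived from package metadata. Note that in the
    # removed code the second f-string was a standalone expression statement,
    # so it was never appended to image_local (implicit concatenation only
    # joins literals within a single expression).
    # After: every task's job definition points at the configured repository:
    image = f"{repository}:latest"
    image_map = {"image": image}  # stand-in key for get_default_image_key()

    assert image_map["image"].endswith(":latest")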