Skip to content

Commit

Permalink
Use the configured repository when exporting AWS Batch jobs w…
Browse files Browse the repository at this point in the history
…ith `--skip-docker`.
  • Loading branch information
Dennis Li committed Dec 6, 2024
1 parent b0ea667 commit 14b353f
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ CHANGELOG
0.9.4dev
--------

* Feature - `--skip-docker` uses the configured repository in soopervisor.yaml when exporting to AWS Batch. (by @DennisJLi)

0.9.3 (2024-09-18)
------------------

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def read(*names, **kwargs):
"tqdm",
"pydantic",
"Jinja2",
"pyyaml",
"pyyaml>=6.0.2",
"ploomber>=0.14.6",
"ploomber-core>=0.0.11",
# sdist is generated using python -m build, so adding this here.
Expand Down
2 changes: 1 addition & 1 deletion src/soopervisor/assets/airflow/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from airflow import DAG
from airflow.utils.dates import days_ago
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
from airflow.providers.cncf.kubernetes.operators.pod import (
KubernetesPodOperator,
)

Expand Down
8 changes: 3 additions & 5 deletions src/soopervisor/aws/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,11 @@ def _export(
"submit all tasks regardless of status"
)
if skip_docker:
pkg_name, version = source.find_package_name_and_version()
pkg_name, _ = source.find_package_name_and_version()
image = f"{cfg.repository}:latest"
default_image_key = get_default_image_key()
if default_image_key:
image_local = f"{pkg_name}:{version}-"
f"{docker.modify_wildcard(default_image_key)}"
image_map = {}
image_map[default_image_key] = image_local
image_map[default_image_key] = image
else:
pkg_name, image_map = docker.build(
cmdr,
Expand Down
2 changes: 1 addition & 1 deletion tests/airflow/test_airflow_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json

from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
from airflow.providers.cncf.kubernetes.operators.pod import (
KubernetesPodOperator,
)
from airflow.operators.bash import BashOperator
Expand Down
72 changes: 71 additions & 1 deletion tests/aws_batch/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_export(


# TODO: check error if wrong task name
# TODO: check errro when task is up to date
# TODO: check error when task is up to date
# TODO: check error if dependencies from submitted task are outdated
@pytest.mark.parametrize(
"mode, args",
Expand Down Expand Up @@ -662,3 +662,73 @@ def test_lazy_load(mock_aws_batch_lazy_load, monkeypatch):
path_to_config="soopervisor.yaml", env_name="train", lazy_import=True
)
exporter.export(mode="incremental", lazy_import=True)


def test_export_with_skip_docker_uses_configured_repository(
    mock_batch,
    monkeypatch,
    tmp_sample_project_multiple_requirement,
    monkeypatch_docker_client,
    skip_repo_validation,
    boto3_mock,
    monkeypatch_docker_commons,
    load_tasks_mock,
):
    """`--skip-docker` must submit jobs whose definition points at the image
    built from the repository configured in soopervisor.yaml, without
    building a local docker image.
    """
    monkeypatch.setattr(batch, "uuid4", lambda: "uuid4")
    monkeypatch.setattr(batch.boto3, "client", lambda name, region_name: boto3_mock)
    monkeypatch.setattr(commons, "load_tasks", load_tasks_mock)

    repository = "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-repository/model"

    exporter = batch.AWSBatchExporter.new("soopervisor.yaml", "some-env")
    exporter._cfg.repository = repository
    exporter.add()

    # replace Commander so no docker/shell commands actually run
    fake_commander = MagicMock()
    monkeypatch.setattr(
        batch, "Commander", lambda workspace, templates_path: fake_commander
    )

    exporter.export(mode="incremental", skip_docker=True)

    summaries = mock_batch.list_jobs(jobQueue="your-job-queue")["jobSummaryList"]

    # fetch full details for every submitted job
    submitted = mock_batch.describe_jobs(
        jobs=[summary["jobId"] for summary in summaries]
    )["jobs"]

    definitions = mock_batch.describe_job_definitions(
        jobDefinitions=[job["jobDefinition"] for job in submitted]
    )["jobDefinitions"]

    # every task in the pipeline was submitted
    submitted_names = {job["jobName"] for job in submitted}
    assert submitted_names == {"raw", "clean-1", "plot", "clean-2"}

    # every job went to the configured queue
    for job in submitted:
        assert "your-job-queue" in job["jobQueue"]

    expected_arn = (
        "arn:aws:batch:us-east-1:123456789012:job-definition/"
        "multiple_requirements_project-uuid4:1"
    )

    # all tasks share the single job definition created for this export
    assert {job["jobName"]: job["jobDefinition"] for job in submitted} == {
        "raw": expected_arn,
        "clean-1": expected_arn,
        "clean-2": expected_arn,
        "plot": expected_arn,
    }

    # the job definition's image comes from the configured repository
    images_by_arn = {
        d["jobDefinitionArn"]: d["containerProperties"]["image"] for d in definitions
    }
    assert images_by_arn == {expected_arn: f"{repository}:latest"}

0 comments on commit 14b353f

Please sign in to comment.