Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split next_dagruns_to_examine function into two #42386

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,9 +1192,10 @@ def _do_scheduling(self, session: Session) -> int:

self._start_queued_dagruns(session)
guard.commit()
dag_runs = self._get_next_dagruns_to_examine(DagRunState.RUNNING, session)

# Bulk fetch the currently active dag runs for the dags we are
# examining, rather than making one query per DagRun
dag_runs = DagRun.running_dag_runs_to_examine(session=session)

callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session)

Expand Down Expand Up @@ -1248,11 +1249,6 @@ def _do_scheduling(self, session: Session) -> int:

return num_queued_tis

@retry_db_transaction
def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session) -> Query:
"""Get Next DagRuns to Examine with retries."""
return DagRun.next_dagruns_to_examine(state, session)

@retry_db_transaction
def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
"""Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError."""
Expand Down Expand Up @@ -1486,7 +1482,7 @@ def _should_update_dag_next_dagruns(
def _start_queued_dagruns(self, session: Session) -> None:
"""Find DagRuns in queued state and decide moving them to running state."""
# added all() to save runtime, otherwise query is executed more than once
dag_runs: Collection[DagRun] = self._get_next_dagruns_to_examine(DagRunState.QUEUED, session).all()
dag_runs: Collection[DagRun] = DagRun.queued_dag_runs_to_set_running(session).all()

active_runs_of_dags = Counter(
DagRun.active_runs_of_dags((dr.dag_id for dr in dag_runs), only_running=True, session=session),
Expand Down
84 changes: 59 additions & 25 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from airflow.utils.dates import datetime_to_nano
from airflow.utils.helpers import chunks, is_container, prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
Expand Down Expand Up @@ -388,55 +389,88 @@ def active_runs_of_dags(
return dict(iter(session.execute(query)))

@classmethod
@retry_db_transaction
def running_dag_runs_to_examine(cls, session: Session) -> Query:
    """
    Return the next running DagRuns that the scheduler should attempt to schedule.

    This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
    query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
    the transaction is committed it will be unlocked.

    :param session: ORM session used for the query; row locks live for its transaction.
    :return: scalar result of at most ``DEFAULT_DAGRUNS_TO_EXAMINE`` locked DagRun rows.

    :meta private:
    """
    # Imported here to avoid a circular import with airflow.models.dag.
    from airflow.models.dag import DagModel

    # TODO: Bake this query, it is run _A lot_
    query = (
        select(cls)
        # Hint MySQL at the dedicated index for running dag runs.
        .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
        # Only running, non-backfill runs are scheduled here.
        .where(cls.state == DagRunState.RUNNING, cls.run_type != DagRunType.BACKFILL_JOB)
        .join(DagModel, DagModel.dag_id == cls.dag_id)
        # Skip paused or deactivated dags entirely.
        .where(DagModel.is_paused == false(), DagModel.is_active == true())
        # Runs never examined yet (NULL decision time) go first, then oldest execution date.
        .order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )
    )

    if not settings.ALLOW_FUTURE_EXEC_DATES:
        # Do not schedule runs whose logical date is still in the future.
        query = query.where(DagRun.execution_date <= func.now())

    # SELECT ... FOR UPDATE SKIP LOCKED so concurrent schedulers never block
    # on each other; rows stay locked until the caller's transaction commits.
    return session.scalars(
        with_row_locks(
            query.limit(cls.DEFAULT_DAGRUNS_TO_EXAMINE),
            of=cls,
            session=session,
            skip_locked=True,
        )
    )

@classmethod
@retry_db_transaction
def queued_dag_runs_to_set_running(cls, session: Session) -> Query:
    """
    Return the next queued DagRuns that the scheduler should attempt to schedule.

    This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... FOR UPDATE"
    query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
    the transaction is committed it will be unlocked.

    :param session: ORM session used for the query; row locks live for its transaction.
    :return: scalar result of at most ``DEFAULT_DAGRUNS_TO_EXAMINE`` locked DagRun rows.

    :meta private:
    """
    # Imported here to avoid a circular import with airflow.models.dag.
    from airflow.models.dag import DagModel

    # For dag runs in the queued state, we check if they have reached the max_active_runs limit
    # and if so we drop them
    running_drs = (
        select(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
        .where(DagRun.state == DagRunState.RUNNING)
        .group_by(DagRun.dag_id)
        .subquery()
    )
    query = (
        select(cls)
        # Only queued, non-backfill runs are candidates for promotion to running.
        .where(cls.state == DagRunState.QUEUED, cls.run_type != DagRunType.BACKFILL_JOB)
        .join(DagModel, DagModel.dag_id == cls.dag_id)
        # Skip paused or deactivated dags entirely.
        .where(DagModel.is_paused == false(), DagModel.is_active == true())
        # Outer join so dags with zero running runs (no subquery row) are kept;
        # coalesce turns the missing count into 0 for the limit comparison.
        .outerjoin(running_drs, running_drs.c.dag_id == DagRun.dag_id)
        .where(func.coalesce(running_drs.c.num_running, 0) < DagModel.max_active_runs)
        # Runs never examined yet (NULL decision time) go first, then oldest execution date.
        .order_by(
            nulls_first(cls.last_scheduling_decision, session=session),
            cls.execution_date,
        )
    )

    if not settings.ALLOW_FUTURE_EXEC_DATES:
        # Do not schedule runs whose logical date is still in the future.
        query = query.where(DagRun.execution_date <= func.now())

    # SELECT ... FOR UPDATE SKIP LOCKED so concurrent schedulers never block
    # on each other; rows stay locked until the caller's transaction commits.
    return session.scalars(
        with_row_locks(
            query.limit(cls.DEFAULT_DAGRUNS_TO_EXAMINE),
            of=cls,
            session=session,
            skip_locked=True,
        )
    )

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6014,7 +6014,7 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d
self.job_runner.processor_agent = mock_agent

with assert_queries_count(expected_query_count, margin=15):
with mock.patch.object(DagRun, "next_dagruns_to_examine") as mock_dagruns:
with mock.patch.object(DagRun, "running_dag_runs_to_examine") as mock_dagruns:
query = MagicMock()
query.all.return_value = dagruns
mock_dagruns.return_value = query
Expand Down
8 changes: 6 additions & 2 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,14 +931,18 @@ def test_next_dagruns_to_examine_only_unpaused(self, session, state):
**triggered_by_kwargs,
)

runs = DagRun.next_dagruns_to_examine(state, session).all()
if state == DagRunState.RUNNING:
func = DagRun.running_dag_runs_to_examine
else:
func = DagRun.queued_dag_runs_to_set_running
runs = func(session).all()

assert runs == [dr]

orm_dag.is_paused = True
session.flush()

runs = DagRun.next_dagruns_to_examine(state, session).all()
runs = func(session).all()
assert runs == []

@mock.patch.object(Stats, "timing")
Expand Down