Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/apache/airflow into kalyan/…
Browse files Browse the repository at this point in the history
…AIP-84/fix_dag_run_id
  • Loading branch information
rawwar committed Nov 23, 2024
2 parents 4d02397 + 3c58e01 commit 2390504
Show file tree
Hide file tree
Showing 49 changed files with 2,065 additions and 207 deletions.
7 changes: 6 additions & 1 deletion .devcontainer/mysql/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,10 @@
"rogalmic.bash-debug"
],
"service": "airflow",
"forwardPorts": [8080,5555,5432,6379]
"forwardPorts": [8080,5555,5432,6379],
"workspaceFolder": "/opt/airflow",
// for users who use non-standard git config patterns
// https://github.com/microsoft/vscode-remote-release/issues/2084#issuecomment-989756268
"initializeCommand": "cd \"${localWorkspaceFolder}\" && git config --local user.email \"$(git config user.email)\" && git config --local user.name \"$(git config user.name)\"",
"overrideCommand": true
}
7 changes: 6 additions & 1 deletion .devcontainer/postgres/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,10 @@
"rogalmic.bash-debug"
],
"service": "airflow",
"forwardPorts": [8080,5555,5432,6379]
"forwardPorts": [8080,5555,5432,6379],
"workspaceFolder": "/opt/airflow",
// for users who use non-standard git config patterns
// https://github.com/microsoft/vscode-remote-release/issues/2084#issuecomment-989756268
"initializeCommand": "cd \"${localWorkspaceFolder}\" && git config --local user.email \"$(git config user.email)\" && git config --local user.name \"$(git config user.name)\"",
"overrideCommand": true
}
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ repos:
types_or: [python, pyi]
args: [--fix]
require_serial: true
additional_dependencies: ["ruff==0.7.3"]
additional_dependencies: ["ruff==0.8.0"]
exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py|^performance/tests/test_.*.py
- id: ruff-format
name: Run 'ruff format'
Expand All @@ -370,7 +370,7 @@ repos:
types_or: [python, pyi]
args: []
require_serial: true
additional_dependencies: ["ruff==0.7.3"]
additional_dependencies: ["ruff==0.8.0"]
exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py$
- id: replace-bad-characters
name: Replace bad characters
Expand Down
84 changes: 81 additions & 3 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@

from typing import TYPE_CHECKING, Literal, Sequence, overload

from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from sqlalchemy.ext.asyncio import AsyncSession

from airflow.utils.db import get_query_count, get_query_count_async
from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -53,7 +55,9 @@ def your_route(session: Annotated[Session, Depends(get_session)]):


def apply_filters_to_select(
*, base_select: Select, filters: Sequence[BaseParam | None] | None = None
*,
base_select: Select,
filters: Sequence[BaseParam | None] | None = None,
) -> Select:
if filters is None:
return base_select
Expand All @@ -65,6 +69,80 @@ def apply_filters_to_select(
return base_select


async def get_async_session() -> AsyncSession:
"""
Dependency for providing a session.
Example usage:
.. code:: python
@router.get("/your_path")
def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]):
pass
"""
async with create_session_async() as session:
yield session


@overload
async def paginated_select_async(
*,
query: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: AsyncSession,
return_total_entries: Literal[True] = True,
) -> tuple[Select, int]: ...


@overload
async def paginated_select_async(
*,
query: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: AsyncSession,
return_total_entries: Literal[False],
) -> tuple[Select, None]: ...


async def paginated_select_async(
*,
query: Select,
filters: Sequence[BaseParam | None] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: AsyncSession,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
query = apply_filters_to_select(
base_select=query,
filters=filters,
)

total_entries = None
if return_total_entries:
total_entries = await get_query_count_async(query, session=session)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

query = apply_filters_to_select(
base_select=query,
filters=[order_by, offset, limit],
)

return query, total_entries


@overload
def paginated_select(
*,
Expand Down
15 changes: 7 additions & 8 deletions airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

from fastapi import Depends, HTTPException, status
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_session, paginated_select
from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.backfills import (
Expand All @@ -49,26 +50,24 @@
@backfills_router.get(
path="",
)
def list_backfills(
async def list_backfills(
dag_id: str,
limit: QueryLimit,
offset: QueryOffset,
order_by: Annotated[
SortParam,
Depends(SortParam(["id"], Backfill).dynamic_depends()),
],
session: Annotated[Session, Depends(get_session)],
session: Annotated[AsyncSession, Depends(get_async_session)],
) -> BackfillCollectionResponse:
select_stmt, total_entries = paginated_select(
select=select(Backfill).where(Backfill.dag_id == dag_id),
select_stmt, total_entries = await paginated_select_async(
query=select(Backfill).where(Backfill.dag_id == dag_id),
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

backfills = session.scalars(select_stmt)

backfills = await session.scalars(select_stmt)
return BackfillCollectionResponse(
backfills=backfills,
total_entries=total_entries,
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@

# Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic fails to resolve
# types for serialization
from airflow.utils.task_group import TaskGroup # noqa: TCH001
from airflow.utils.task_group import TaskGroup # noqa: TC001

TaskPreExecuteHook = Callable[[Context], None]
TaskPostExecuteHook = Callable[[Context, Any], None]
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
select,
text,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, reconstructor, relationship

Expand Down Expand Up @@ -79,7 +80,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)

value = Column(JSON)
value = Column(JSON().with_variant(postgresql.JSONB, "postgresql"))
timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

__table_args__ = (
Expand Down
4 changes: 2 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@

HAS_KUBERNETES: bool
try:
from kubernetes.client import models as k8s # noqa: TCH004
from kubernetes.client import models as k8s # noqa: TC004

from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TCH004
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004
except ImportError:
pass

Expand Down
10 changes: 5 additions & 5 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import pluggy
from packaging.version import Version
from sqlalchemy import create_engine, exc, text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import NullPool

Expand Down Expand Up @@ -111,7 +111,7 @@
# this is achieved by the Session factory above.
NonScopedSession: Callable[..., SASession]
async_engine: AsyncEngine
create_async_session: Callable[..., AsyncSession]
AsyncSession: Callable[..., SAAsyncSession]

# The JSON library to use for DAG Serialization and De-Serialization
json = json
Expand Down Expand Up @@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
global Session
global engine
global async_engine
global create_async_session
global AsyncSession
global NonScopedSession

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
Expand Down Expand Up @@ -498,11 +498,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None):

engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True)
async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
create_async_session = sessionmaker(
AsyncSession = sessionmaker(
bind=async_engine,
autocommit=False,
autoflush=False,
class_=AsyncSession,
class_=SAAsyncSession,
expire_on_commit=False,
)
mask_secret(engine.url.password)
Expand Down
16 changes: 16 additions & 0 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import ClauseElement, TextClause
from sqlalchemy.sql.selectable import Select
Expand Down Expand Up @@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int:
return session.scalar(count_stmt)


async def get_query_count_async(query: Select, *, session: AsyncSession) -> int:
"""
Get count of a query.
A SELECT COUNT() FROM is issued against the subquery built from the
given statement. The ORDER BY clause is stripped from the statement
since it's unnecessary for COUNT, and can impact query planning and
degrade performance.
:meta private:
"""
count_stmt = select(func.count()).select_from(query.order_by(None).subquery())
return await session.scalar(count_stmt)


def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
"""
Check whether there is at least one row matching a query.
Expand Down
16 changes: 14 additions & 2 deletions airflow/utils/db_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@
logger = logging.getLogger(__name__)

ARCHIVE_TABLE_PREFIX = "_airflow_deleted__"
# Archived tables created by DB migrations
ARCHIVED_TABLES_FROM_DB_MIGRATIONS = [
"_xcom_archive" # Table created by the AF 2 -> 3.0.0 migration when the XComs had pickled values
]


@dataclass
Expand Down Expand Up @@ -116,6 +120,7 @@ def readable_config(self):
_TableConfig(table_name="task_instance_history", recency_column_name="start_date"),
_TableConfig(table_name="task_reschedule", recency_column_name="start_date"),
_TableConfig(table_name="xcom", recency_column_name="timestamp"),
_TableConfig(table_name="_xcom_archive", recency_column_name="timestamp"),
_TableConfig(table_name="callback_request", recency_column_name="created_at"),
_TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"),
_TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"),
Expand Down Expand Up @@ -380,13 +385,20 @@ def _effective_table_names(*, table_names: list[str] | None) -> tuple[set[str],

def _get_archived_table_names(table_names: list[str] | None, session: Session) -> list[str]:
inspector = inspect(session.bind)
db_table_names = [x for x in inspector.get_table_names() if x.startswith(ARCHIVE_TABLE_PREFIX)]
db_table_names = [
x
for x in inspector.get_table_names()
if x.startswith(ARCHIVE_TABLE_PREFIX) or x in ARCHIVED_TABLES_FROM_DB_MIGRATIONS
]
effective_table_names, _ = _effective_table_names(table_names=table_names)
# Filter out tables that don't start with the archive prefix
archived_table_names = [
table_name
for table_name in db_table_names
if any("__" + x + "__" in table_name for x in effective_table_names)
if (
any("__" + x + "__" in table_name for x in effective_table_names)
or table_name in ARCHIVED_TABLES_FROM_DB_MIGRATIONS
)
]
return archived_table_names

Expand Down
18 changes: 18 additions & 0 deletions airflow/utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]:
session.close()


@contextlib.asynccontextmanager
async def create_session_async():
"""
Context manager to create async session.
:meta private:
"""
from airflow.settings import AsyncSession

async with AsyncSession() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise


PS = ParamSpec("PS")
RT = TypeVar("RT")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2195,7 +2195,7 @@ class ProviderPRInfo(NamedTuple):
f"[warning]Skipping provider {provider_id}. "
"The changelog file doesn't contain any PRs for the release.\n"
)
return
continue
provider_prs[provider_id] = [pr for pr in prs if pr not in excluded_prs]
all_prs.update(provider_prs[provider_id])
g = Github(github_token)
Expand Down Expand Up @@ -2245,6 +2245,8 @@ class ProviderPRInfo(NamedTuple):
progress.advance(task)
providers: dict[str, ProviderPRInfo] = {}
for provider_id in prepared_package_ids:
if provider_id not in provider_prs:
continue
pull_request_list = [pull_requests[pr] for pr in provider_prs[provider_id] if pr in pull_requests]
provider_yaml_dict = yaml.safe_load(
(
Expand Down
2 changes: 1 addition & 1 deletion dev/breeze/src/airflow_breeze/utils/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Output(NamedTuple):

@property
def file(self) -> TextIO:
return open(self.file_name, "a+t")
return open(self.file_name, "a+")

@property
def escaped_title(self) -> str:
Expand Down
Loading

0 comments on commit 2390504

Please sign in to comment.