From fe166ec03918825d0f5274fef72a1f027137fdd6 Mon Sep 17 00:00:00 2001 From: perry2of5 Date: Fri, 22 Nov 2024 09:35:23 -0800 Subject: [PATCH 01/14] Specify workspaceFolder so devcontainer will start with local docker from vscode (#44273) --- .devcontainer/mysql/devcontainer.json | 7 ++++++- .devcontainer/postgres/devcontainer.json | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.devcontainer/mysql/devcontainer.json b/.devcontainer/mysql/devcontainer.json index 5a25b6ad5062..011ff292d271 100644 --- a/.devcontainer/mysql/devcontainer.json +++ b/.devcontainer/mysql/devcontainer.json @@ -22,5 +22,10 @@ "rogalmic.bash-debug" ], "service": "airflow", - "forwardPorts": [8080,5555,5432,6379] + "forwardPorts": [8080,5555,5432,6379], + "workspaceFolder": "/opt/airflow", + // for users who use non-standard git config patterns + // https://github.com/microsoft/vscode-remote-release/issues/2084#issuecomment-989756268 + "initializeCommand": "cd \"${localWorkspaceFolder}\" && git config --local user.email \"$(git config user.email)\" && git config --local user.name \"$(git config user.name)\"", + "overrideCommand": true } diff --git a/.devcontainer/postgres/devcontainer.json b/.devcontainer/postgres/devcontainer.json index 46ba305b5855..419dbedfa1d3 100644 --- a/.devcontainer/postgres/devcontainer.json +++ b/.devcontainer/postgres/devcontainer.json @@ -22,5 +22,10 @@ "rogalmic.bash-debug" ], "service": "airflow", - "forwardPorts": [8080,5555,5432,6379] + "forwardPorts": [8080,5555,5432,6379], + "workspaceFolder": "/opt/airflow", + // for users who use non-standard git config patterns + // https://github.com/microsoft/vscode-remote-release/issues/2084#issuecomment-989756268 + "initializeCommand": "cd \"${localWorkspaceFolder}\" && git config --local user.email \"$(git config user.email)\" && git config --local user.name \"$(git config user.name)\"", + "overrideCommand": true } From 5b2a96ee72a4bd7b7fbd4925c2f8c8468f047a64 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 22 Nov 2024 23:51:28 +0530 Subject: [PATCH 02/14] Don't exit doc preparation even if changelog is empty for any provider (#44207) --- .../airflow_breeze/commands/release_management_commands.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py index 01efdb77fabe..71d4d2a35fb1 100644 --- a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py @@ -2195,7 +2195,7 @@ class ProviderPRInfo(NamedTuple): f"[warning]Skipping provider {provider_id}. 
" "The changelog file doesn't contain any PRs for the release.\n" ) - return + continue provider_prs[provider_id] = [pr for pr in prs if pr not in excluded_prs] all_prs.update(provider_prs[provider_id]) g = Github(github_token) @@ -2245,6 +2245,8 @@ class ProviderPRInfo(NamedTuple): progress.advance(task) providers: dict[str, ProviderPRInfo] = {} for provider_id in prepared_package_ids: + if provider_id not in provider_prs: + continue pull_request_list = [pull_requests[pr] for pr in provider_prs[provider_id] if pr in pull_requests] provider_yaml_dict = yaml.safe_load( ( From d79c6c21f2d571bae236419bad87bc48bf9c97ce Mon Sep 17 00:00:00 2001 From: AutomationDev85 <96178949+AutomationDev85@users.noreply.github.com> Date: Fri, 22 Nov 2024 21:02:47 +0100 Subject: [PATCH 03/14] [edge] Clean up of dead tasks in edge_jobs table (#44280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add edge_job clean up * Reworked unit test * Reworked unit test --------- Co-authored-by: Marco Küttelwesch --- .../src/airflow/providers/edge/CHANGELOG.rst | 8 +++ .../src/airflow/providers/edge/__init__.py | 2 +- .../providers/edge/executors/edge_executor.py | 41 +++++++++++-- .../src/airflow/providers/edge/provider.yaml | 2 +- .../edge/executors/test_edge_executor.py | 59 +++++++++++++++++-- 5 files changed, 99 insertions(+), 13 deletions(-) diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst b/providers/src/airflow/providers/edge/CHANGELOG.rst index 48c7a76b5f0e..d24373463f88 100644 --- a/providers/src/airflow/providers/edge/CHANGELOG.rst +++ b/providers/src/airflow/providers/edge/CHANGELOG.rst @@ -27,6 +27,14 @@ Changelog --------- +0.6.1pre0 +......... + +Misc +~~~~ + +* ``Update jobs or edge workers who have been killed to clean up job table.`` + 0.6.0pre0 ......... 
diff --git a/providers/src/airflow/providers/edge/__init__.py b/providers/src/airflow/providers/edge/__init__.py index 1613f44510a7..b3f545f82448 100644 --- a/providers/src/airflow/providers/edge/__init__.py +++ b/providers/src/airflow/providers/edge/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "0.6.0pre0" +__version__ = "0.6.1pre0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.10.0" diff --git a/providers/src/airflow/providers/edge/executors/edge_executor.py b/providers/src/airflow/providers/edge/executors/edge_executor.py index a13552fbf8a2..b990a3311571 100644 --- a/providers/src/airflow/providers/edge/executors/edge_executor.py +++ b/providers/src/airflow/providers/edge/executors/edge_executor.py @@ -26,7 +26,7 @@ from airflow.configuration import conf from airflow.executors.base_executor import BaseExecutor from airflow.models.abstractoperator import DEFAULT_QUEUE -from airflow.models.taskinstance import TaskInstanceState +from airflow.models.taskinstance import TaskInstance, TaskInstanceState from airflow.providers.edge.cli.edge_command import EDGE_COMMANDS from airflow.providers.edge.models.edge_job import EdgeJobModel from airflow.providers.edge.models.edge_logs import EdgeLogsModel @@ -42,7 +42,6 @@ from sqlalchemy.orm import Session from airflow.executors.base_executor import CommandType - from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey PARALLELISM: int = conf.getint("core", "PARALLELISM") @@ -108,6 +107,30 @@ def _check_worker_liveness(self, session: Session) -> bool: return changed + def _update_orphaned_jobs(self, session: Session) -> bool: + """Update status ob jobs when workers die and don't update anymore.""" + heartbeat_interval: int = conf.getint("scheduler", "scheduler_zombie_task_threshold") + lifeless_jobs: list[EdgeJobModel] = ( + session.query(EdgeJobModel) + .filter( + EdgeJobModel.state == TaskInstanceState.RUNNING, + EdgeJobModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval)), + ) + .all() + ) + + for job in lifeless_jobs: + ti = TaskInstance.get_task_instance( + dag_id=job.dag_id, + run_id=job.run_id, + task_id=job.task_id, + map_index=job.map_index, + session=session, + ) + job.state = ti.state if ti else TaskInstanceState.REMOVED + + return bool(lifeless_jobs) + def _purge_jobs(self, session: Session) -> bool: """Clean finished jobs.""" purged_marker = False @@ -117,7 +140,12 @@ def _purge_jobs(self, session: Session) -> bool: session.query(EdgeJobModel) .filter( EdgeJobModel.state.in_( - [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS, TaskInstanceState.FAILED] + [ + TaskInstanceState.RUNNING, + TaskInstanceState.SUCCESS, + TaskInstanceState.FAILED, + TaskInstanceState.REMOVED, + ] ) ) .all() @@ -145,7 +173,7 @@ def _purge_jobs(self, session: Session) -> bool: job.state == TaskInstanceState.SUCCESS and job.last_update_t < (datetime.now() - timedelta(minutes=job_success_purge)).timestamp() ) or ( - job.state == TaskInstanceState.FAILED + job.state in (TaskInstanceState.FAILED, TaskInstanceState.REMOVED) and job.last_update_t < (datetime.now() - timedelta(minutes=job_fail_purge)).timestamp() ): if job.key in self.last_reported_state: @@ -168,7 +196,10 @@ def _purge_jobs(self, session: Session) -> bool: def sync(self, session: Session = NEW_SESSION) -> None: """Sync will get called periodically by the heartbeat method.""" with Stats.timer("edge_executor.sync.duration"): - if 
self._purge_jobs(session) or self._check_worker_liveness(session): + orphaned = self._update_orphaned_jobs(session) + purged = self._purge_jobs(session) + liveness = self._check_worker_liveness(session) + if purged or liveness or orphaned: session.commit() def end(self) -> None: diff --git a/providers/src/airflow/providers/edge/provider.yaml b/providers/src/airflow/providers/edge/provider.yaml index 96ce7f152f7f..6fe609502aa2 100644 --- a/providers/src/airflow/providers/edge/provider.yaml +++ b/providers/src/airflow/providers/edge/provider.yaml @@ -27,7 +27,7 @@ source-date-epoch: 1729683247 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.6.0pre0 + - 0.6.1pre0 dependencies: - apache-airflow>=2.10.0 diff --git a/providers/tests/edge/executors/test_edge_executor.py b/providers/tests/edge/executors/test_edge_executor.py index 7970e5fad04c..126afa1fb70b 100644 --- a/providers/tests/edge/executors/test_edge_executor.py +++ b/providers/tests/edge/executors/test_edge_executor.py @@ -16,11 +16,12 @@ # under the License. from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta from unittest.mock import patch import pytest +from airflow.configuration import conf from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.edge.executors.edge_executor import EdgeExecutor from airflow.providers.edge.models.edge_job import EdgeJobModel @@ -65,6 +66,44 @@ def test_execute_async_ok_command(self): assert jobs[0].run_id == "test_run" assert jobs[0].task_id == "test_task" + def test_sync_orphaned_tasks(self): + executor = EdgeExecutor() + + delta_to_purge = timedelta(minutes=conf.getint("edge", "job_fail_purge") + 1) + delta_to_orphaned = timedelta(seconds=conf.getint("scheduler", "scheduler_zombie_task_threshold") + 1) + + with create_session() as session: + for task_id, state, last_update in [ + ( + "started_running_orphaned", + TaskInstanceState.RUNNING, + timezone.utcnow() - delta_to_orphaned, + ), + ("started_removed", TaskInstanceState.REMOVED, timezone.utcnow() - delta_to_purge), + ]: + session.add( + EdgeJobModel( + dag_id="test_dag", + task_id=task_id, + run_id="test_run", + map_index=-1, + try_number=1, + state=state, + queue="default", + command="dummy", + last_update=last_update, + ) + ) + session.commit() + + executor.sync() + + with create_session() as session: + jobs = session.query(EdgeJobModel).all() + assert len(jobs) == 1 + assert jobs[0].task_id == "started_running_orphaned" + assert jobs[0].task_id == "started_running_orphaned" + @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.running_state") @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.success") @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.fail") @@ -77,12 +116,14 @@ def remove_from_running(key: TaskInstanceKey): mock_success.side_effect = remove_from_running mock_fail.side_effect = remove_from_running + delta_to_purge = timedelta(minutes=conf.getint("edge", "job_fail_purge") + 1) + # Prepare some data with create_session() as session: - for task_id, state in [ - ("started_running", TaskInstanceState.RUNNING), - ("started_success", TaskInstanceState.SUCCESS), - ("started_failed", TaskInstanceState.FAILED), + for task_id, state, last_update in [ + ("started_running", TaskInstanceState.RUNNING, timezone.utcnow()), + ("started_success", TaskInstanceState.SUCCESS, timezone.utcnow() - delta_to_purge), + ("started_failed", TaskInstanceState.FAILED, 
timezone.utcnow() - delta_to_purge), ]: session.add( EdgeJobModel( @@ -94,7 +135,7 @@ def remove_from_running(key: TaskInstanceKey): state=state, queue="default", command="dummy", - last_update=timezone.utcnow(), + last_update=last_update, ) ) key = TaskInstanceKey( @@ -106,6 +147,12 @@ def remove_from_running(key: TaskInstanceKey): executor.sync() + with create_session() as session: + jobs = session.query(EdgeJobModel).all() + assert len(session.query(EdgeJobModel).all()) == 1 + assert jobs[0].task_id == "started_running" + assert jobs[0].state == TaskInstanceState.RUNNING + assert len(executor.running) == 1 mock_running_state.assert_called_once() mock_success.assert_called_once() From 66407e8d2260ae8bf96f75048f8d1b2d1db3766c Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:37:09 -0800 Subject: [PATCH 04/14] Update backfill `list` endpoint to be async (#44208) This is a sort of hello world / proof of concept for having an route implemented using asyncio. Gotta start somewhere. --- airflow/api_fastapi/common/db/common.py | 84 ++++++++++++++++++- .../core_api/routes/public/backfills.py | 15 ++-- airflow/settings.py | 10 +-- airflow/utils/db.py | 16 ++++ airflow/utils/session.py | 18 ++++ tests/utils/test_session.py | 4 +- 6 files changed, 129 insertions(+), 18 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 17da1eafacc9..2d7da4bff737 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -24,8 +24,10 @@ from typing import TYPE_CHECKING, Literal, Sequence, overload -from airflow.utils.db import get_query_count -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from sqlalchemy.ext.asyncio import AsyncSession + +from airflow.utils.db import get_query_count, get_query_count_async +from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -53,7 +55,9 @@ def your_route(session: Annotated[Session, Depends(get_session)]): def apply_filters_to_select( - *, base_select: Select, filters: Sequence[BaseParam | None] | None = None + *, + base_select: Select, + filters: Sequence[BaseParam | None] | None = None, ) -> Select: if filters is None: return base_select @@ -65,6 +69,80 @@ def apply_filters_to_select( return base_select +async def get_async_session() -> AsyncSession: + """ + Dependency for providing a session. + + Example usage: + + .. code:: python + + @router.get("/your_path") + def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]): + pass + """ + async with create_session_async() as session: + yield session + + +@overload +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: Literal[True] = True, +) -> tuple[Select, int]: ... + + +@overload +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: Literal[False], +) -> tuple[Select, None]: ... 
+ + +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam | None] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: bool = True, +) -> tuple[Select, int | None]: + query = apply_filters_to_select( + base_select=query, + filters=filters, + ) + + total_entries = None + if return_total_entries: + total_entries = await get_query_count_async(query, session=session) + + # TODO: Re-enable when permissions are handled. Readable / writable entities, + # for instance: + # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) + # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) + + query = apply_filters_to_select( + base_select=query, + filters=[order_by, offset, limit], + ) + + return query, total_entries + + @overload def paginated_select( *, diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index aa6f540d3279..78b2beb55889 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -20,9 +20,10 @@ from fastapi import Depends, HTTPException, status from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from airflow.api_fastapi.common.db.common import get_session, paginated_select +from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.backfills import ( @@ -49,7 +50,7 @@ @backfills_router.get( path="", ) -def list_backfills( +async def list_backfills( dag_id: str, limit: QueryLimit, offset: QueryOffset, @@ -57,18 +58,16 @@ def list_backfills( SortParam, Depends(SortParam(["id"], Backfill).dynamic_depends()), ], - session: Annotated[Session, Depends(get_session)], + session: Annotated[AsyncSession, Depends(get_async_session)], ) -> BackfillCollectionResponse: - select_stmt, total_entries = paginated_select( - select=select(Backfill).where(Backfill.dag_id == dag_id), + select_stmt, total_entries = await paginated_select_async( + query=select(Backfill).where(Backfill.dag_id == dag_id), order_by=order_by, offset=offset, limit=limit, session=session, ) - - backfills = session.scalars(select_stmt) - + backfills = await session.scalars(select_stmt) return BackfillCollectionResponse( backfills=backfills, total_entries=total_entries, diff --git a/airflow/settings.py b/airflow/settings.py index 5b458efcba47..76b3e948964f 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -31,7 +31,7 @@ import pluggy from packaging.version import Version from sqlalchemy import create_engine, exc, text -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import NullPool @@ -111,7 +111,7 @@ # this is achieved by the Session factory above. 
NonScopedSession: Callable[..., SASession] async_engine: AsyncEngine -create_async_session: Callable[..., AsyncSession] +AsyncSession: Callable[..., SAAsyncSession] # The JSON library to use for DAG Serialization and De-Serialization json = json @@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None): global Session global engine global async_engine - global create_async_session + global AsyncSession global NonScopedSession if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": @@ -498,11 +498,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None): engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True) async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True) - create_async_session = sessionmaker( + AsyncSession = sessionmaker( bind=async_engine, autocommit=False, autoflush=False, - class_=AsyncSession, + class_=SAAsyncSession, expire_on_commit=False, ) mask_secret(engine.url.password) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index d8939a117317..c899ebf615d0 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -70,6 +70,7 @@ from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from sqlalchemy.sql.elements import ClauseElement, TextClause from sqlalchemy.sql.selectable import Select @@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int: return session.scalar(count_stmt) +async def get_query_count_async(query: Select, *, session: AsyncSession) -> int: + """ + Get count of a query. + + A SELECT COUNT() FROM is issued against the subquery built from the + given statement. The ORDER BY clause is stripped from the statement + since it's unnecessary for COUNT, and can impact query planning and + degrade performance. + + :meta private: + """ + count_stmt = select(func.count()).select_from(query.order_by(None).subquery()) + return await session.scalar(count_stmt) + + def check_query_exists(query_stmt: Select, *, session: Session) -> bool: """ Check whether there is at least one row matching a query. diff --git a/airflow/utils/session.py b/airflow/utils/session.py index a63d3f3f937a..49383cdf4a8b 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]: session.close() +@contextlib.asynccontextmanager +async def create_session_async(): + """ + Context manager to create async session. 
+ + :meta private: + """ + from airflow.settings import AsyncSession + + async with AsyncSession() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + PS = ParamSpec("PS") RT = TypeVar("RT") diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py index 02cba9e070dc..8d26a25c626a 100644 --- a/tests/utils/test_session.py +++ b/tests/utils/test_session.py @@ -58,9 +58,9 @@ def test_provide_session_with_kwargs(self): @pytest.mark.asyncio async def test_async_session(self): - from airflow.settings import create_async_session + from airflow.settings import AsyncSession - session = create_async_session() + session = AsyncSession() session.add(Log(event="hihi1234")) await session.commit() my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1)) From fedfaa703e51e6857a0478e0dadc498c5f6106bc Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 22 Nov 2024 20:42:03 +0000 Subject: [PATCH 05/14] Use `JSONB` type for `XCom.value` column in PostgreSQL (#44290) `JSONB` is more efficient, we already use that in migration, this makes it consistent when DB tables are created from ORM. --- airflow/models/xcom.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 45208e353bdc..5b6f83f4d59b 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -34,6 +34,7 @@ select, text, ) +from sqlalchemy.dialects import postgresql from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Query, reconstructor, relationship @@ -79,7 +80,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin): dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - value = Column(JSON) + value = Column(JSON().with_variant(postgresql.JSONB, "postgresql")) timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False) __table_args__ = ( From ab38d01c5c06a6b88a703933eed52f2e8ffe19bd Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Fri, 22 Nov 2024 21:20:36 +0000 Subject: [PATCH 06/14] Bump Ruff to 0.8.0 (#44287) --- .pre-commit-config.yaml | 4 ++-- airflow/models/baseoperator.py | 2 +- airflow/serialization/serialized_objects.py | 4 ++-- dev/breeze/src/airflow_breeze/utils/console.py | 2 +- hatch_build.py | 2 +- pyproject.toml | 15 +++++++-------- task_sdk/pyproject.toml | 4 ++-- 7 files changed, 16 insertions(+), 17 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d072d21055cf..1525333ed9db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -360,7 +360,7 @@ repos: types_or: [python, pyi] args: [--fix] require_serial: true - additional_dependencies: ["ruff==0.7.3"] + additional_dependencies: ["ruff==0.8.0"] exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py|^performance/tests/test_.*.py - id: ruff-format name: Run 'ruff format' @@ -370,7 +370,7 @@ repos: types_or: [python, pyi] args: [] require_serial: true - additional_dependencies: ["ruff==0.7.3"] + additional_dependencies: ["ruff==0.8.0"] exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py$ - id: replace-bad-characters name: Replace bad characters diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 108d9e51d758..169fd548afe0 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -116,7 +116,7 @@ # Todo: AIP-44: Once we get rid of AIP-44 we can remove this. 
But without this here pydantic fails to resolve # types for serialization -from airflow.utils.task_group import TaskGroup # noqa: TCH001 +from airflow.utils.task_group import TaskGroup # noqa: TC001 TaskPreExecuteHook = Callable[[Context], None] TaskPostExecuteHook = Callable[[Context, Any], None] diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 61d851aaed11..21fa575676a5 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -110,9 +110,9 @@ HAS_KUBERNETES: bool try: - from kubernetes.client import models as k8s # noqa: TCH004 + from kubernetes.client import models as k8s # noqa: TC004 - from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TCH004 + from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004 except ImportError: pass diff --git a/dev/breeze/src/airflow_breeze/utils/console.py b/dev/breeze/src/airflow_breeze/utils/console.py index 0b8861673883..910a687004d5 100644 --- a/dev/breeze/src/airflow_breeze/utils/console.py +++ b/dev/breeze/src/airflow_breeze/utils/console.py @@ -83,7 +83,7 @@ class Output(NamedTuple): @property def file(self) -> TextIO: - return open(self.file_name, "a+t") + return open(self.file_name, "a+") @property def escaped_title(self) -> str: diff --git a/hatch_build.py b/hatch_build.py index 22627bfe94ad..ddaa7aa2671d 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -248,7 +248,7 @@ ], "devel-static-checks": [ "black>=23.12.0", - "ruff==0.7.3", + "ruff==0.8.0", "yamllint>=1.33.0", ], "devel-tests": [ diff --git a/pyproject.toml b/pyproject.toml index a6026395bcdc..719dd892383f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,7 +216,7 @@ extend-select = [ "UP", # Pyupgrade "ASYNC", # subset of flake8-async rules "ISC", # Checks for implicit literal string concatenation (auto-fixable) - "TCH", # Rules around TYPE_CHECKING blocks + "TC", # Rules around TYPE_CHECKING blocks "G", # flake8-logging-format rules "LOG", # flake8-logging rules, most of them autofixable "PT", # flake8-pytest-style rules @@ -260,8 +260,7 @@ ignore = [ "D203", "D212", # Conflicts with D213. Both can not be enabled. "E731", # Do not assign a lambda expression, use a def - "TCH003", # Do not move imports from stdlib to TYPE_CHECKING block - "PT004", # Fixture does not return anything, add leading underscore + "TC003", # Do not move imports from stdlib to TYPE_CHECKING block "PT006", # Wrong type of names in @pytest.mark.parametrize "PT007", # Wrong type of values in @pytest.mark.parametrize "PT013", # silly rule prohibiting e.g. 
`from pytest import param`
@@ -313,8 +312,8 @@ section-order = [
 testing = ["dev", "providers.tests", "task_sdk.tests", "tests_common", "tests"]
 
 [tool.ruff.lint.extend-per-file-ignores]
-"airflow/__init__.py" = ["F401", "TCH004", "I002"]
-"airflow/models/__init__.py" = ["F401", "TCH004"]
+"airflow/__init__.py" = ["F401", "TC004", "I002"]
+"airflow/models/__init__.py" = ["F401", "TC004"]
 "airflow/models/sqla_models.py" = ["F401"]
 "providers/src/airflow/providers/__init__.py" = ["I002"]
 "providers/src/airflow/__init__.py" = ["I002"]
@@ -326,12 +325,12 @@
 # The Pydantic representations of SqlAlchemy Models are not parsed well with Pydantic
 # when __future__.annotations is used so we need to skip them from upgrading
 # Pydantic also require models to be imported during execution
-"airflow/serialization/pydantic/*.py" = ["I002", "UP007", "TCH001"]
+"airflow/serialization/pydantic/*.py" = ["I002", "UP007", "TC001"]
 
 # Failing to detect types and functions used in `Annotated[...]` syntax as required at runtime.
 # Annotated is central for FastAPI dependency injection, skipping rules for FastAPI folders.
-"airflow/api_fastapi/*" = ["TCH001", "TCH002"]
-"tests/api_fastapi/*" = ["TCH001", "TCH002"]
+"airflow/api_fastapi/*" = ["TC001", "TC002"]
+"tests/api_fastapi/*" = ["TC001", "TC002"]
 
 # Ignore pydoc style from these
 "*.pyi" = ["D"]
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index 5da673a79bf0..170aff5ec244 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -62,11 +62,11 @@ namespace-packages = ["src/airflow"]
 
 # Pycharm barfs if this "stub" file has future imports
 "src/airflow/__init__.py" = ["I002"]
-"src/airflow/sdk/__init__.py" = ["TCH004"]
+"src/airflow/sdk/__init__.py" = ["TC004"]
 
 # msgspec needs types for annotations to be defined, even with future
 # annotations, so disable the "type check only import" for these files
-"src/airflow/sdk/api/datamodels/*.py" = ["TCH001"]
+"src/airflow/sdk/api/datamodels/*.py" = ["TC001"]
 
 # Only the public API should _require_ docstrings on classes
 "!src/airflow/sdk/definitions/*" = ["D101"]

From 91bd1eafb035ff311a9d573b3004699e68200e08 Mon Sep 17 00:00:00 2001
From: olegkachur-e
Date: Fri, 22 Nov 2024 22:59:31 +0100
Subject: [PATCH 07/14] Introduce gcp advance API (V3) translate native datasets operators (#44271)

- Add support for native datasets for Cloud Translation API.
- The datasets created via the automl API are considered legacy; they keep
  being supported, but all new enhancements will be available for native
  datasets (recommended), created by the Cloud Translate API. See more:
  https://cloud.google.com/translate/docs/advanced/automl-upgrade.
Co-authored-by: Oleg Kachur --- .../operators/cloud/translate.rst | 85 +++++ docs/spelling_wordlist.txt | 1 + .../providers/google/cloud/hooks/translate.py | 240 +++++++++++- .../providers/google/cloud/links/translate.py | 56 +++ .../google/cloud/operators/translate.py | 344 +++++++++++++++++- .../airflow/providers/google/provider.yaml | 2 + .../google/cloud/operators/test_translate.py | 205 ++++++++++- .../translate/example_translate_dataset.py | 153 ++++++++ 8 files changed, 1075 insertions(+), 11 deletions(-) create mode 100644 providers/tests/system/google/cloud/translate/example_translate_dataset.py diff --git a/docs/apache-airflow-providers-google/operators/cloud/translate.rst b/docs/apache-airflow-providers-google/operators/cloud/translate.rst index 6bcc32ec669c..d56fac26dbeb 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/translate.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/translate.rst @@ -100,11 +100,96 @@ For parameter definition, take a look at :class:`~airflow.providers.google.cloud.operators.translate.TranslateTextBatchOperator` +.. _howto/operator:TranslateCreateDatasetOperator: + +TranslateCreateDatasetOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Create a native translation dataset using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_create_dataset] + :end-before: [END howto_operator_translate_automl_create_dataset] + + +.. _howto/operator:TranslateImportDataOperator: + +TranslateImportDataOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Import data to the existing native dataset, using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateImportDataOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_import_data] + :end-before: [END howto_operator_translate_automl_import_data] + + +.. _howto/operator:TranslateDatasetsListOperator: + +TranslateDatasetsListOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Get list of translation datasets using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_list_datasets] + :end-before: [END howto_operator_translate_automl_list_datasets] + + +.. _howto/operator:TranslateDeleteDatasetOperator: + +TranslateDeleteDatasetOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Delete a native translation dataset using Cloud Translate API (Advanced V3). 
+ +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_delete_dataset] + :end-before: [END howto_operator_translate_automl_delete_dataset] + + More information """""""""""""""""" See: Base (V2) `Google Cloud Translate documentation `_. Advanced (V3) `Google Cloud Translate (Advanced) documentation `_. +Datasets `Legacy and native dataset comparison `_. Reference diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 20e10c44a12c..fa8ffb4a2c3c 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -967,6 +967,7 @@ LineItem lineterminator linter linux +ListDatasetsPager ListGenerator ListInfoTypesResponse ListSecretsPager diff --git a/providers/src/airflow/providers/google/cloud/hooks/translate.py b/providers/src/airflow/providers/google/cloud/hooks/translate.py index 51cb88f1bace..6ddb220f3e78 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/translate.py +++ b/providers/src/airflow/providers/google/cloud/hooks/translate.py @@ -29,6 +29,7 @@ from google.api_core.exceptions import GoogleAPICallError from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.api_core.retry import Retry from google.cloud.translate_v2 import Client from google.cloud.translate_v3 import TranslationServiceClient @@ -38,13 +39,31 @@ if TYPE_CHECKING: from google.api_core.operation import Operation - from google.api_core.retry import Retry + from google.cloud.translate_v3.services.translation_service import pagers from google.cloud.translate_v3.types import ( + DatasetInputConfig, InputConfig, OutputConfig, TranslateTextGlossaryConfig, TransliterationConfig, + automl_translation, ) + from proto import Message + + +class WaitOperationNotDoneYetError(Exception): + """Wait operation not done yet error.""" + + pass + + +def _if_exc_is_wait_failed_error(exc: Exception): + return isinstance(exc, WaitOperationNotDoneYetError) + + +def _check_if_operation_done(operation: Operation): + if not operation.done(): + raise WaitOperationNotDoneYetError("Operation is not done yet.") class CloudTranslateHook(GoogleBaseHook): @@ -163,7 +182,42 @@ def get_client(self) -> TranslationServiceClient: return self._client @staticmethod - def wait_for_operation(operation: Operation, timeout: int | None = None): + def wait_for_operation_done( + *, + operation: Operation, + timeout: float | None = None, + initial: float = 3, + multiplier: float = 2, + maximum: float = 3600, + ) -> None: + """ + Wait for long-running operation to be done. + + Calls operation.done() until success or timeout exhaustion, following the back-off retry strategy. + See `google.api_core.retry.Retry`. + It's intended use on `Operation` instances that have empty result + (:class `google.protobuf.empty_pb2.Empty`) by design. + Thus calling operation.result() for such operation triggers the exception + ``GoogleAPICallError("Unexpected state: Long-running operation had neither response nor error set.")`` + even though operation itself is totally fine. 
+ """ + wait_op_for_done = Retry( + predicate=_if_exc_is_wait_failed_error, + initial=initial, + timeout=timeout, + multiplier=multiplier, + maximum=maximum, + )(_check_if_operation_done) + try: + wait_op_for_done(operation=operation) + except GoogleAPICallError: + if timeout: + timeout = int(timeout) + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + @staticmethod + def wait_for_operation_result(operation: Operation, timeout: int | None = None) -> Message: """Wait for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -171,6 +225,11 @@ def wait_for_operation(operation: Operation, timeout: int | None = None): error = operation.exception(timeout=timeout) raise AirflowException(error) + @staticmethod + def extract_object_id(obj: dict) -> str: + """Return unique id of the object.""" + return obj["name"].rpartition("/")[-1] + def translate_text( self, *, @@ -208,12 +267,10 @@ def translate_text( If not specified, 'global' is used. Non-global location is required for requests using AutoML models or custom glossaries. - Models and glossaries must be within the same region (have the same location-id). :param model: Optional. The ``model`` type requested for this translation. If not provided, the default Google model (NMT) will be used. - The format depends on model type: - AutoML Translation models: @@ -308,8 +365,8 @@ def batch_translate_text( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. - :returns: Operation object with the batch text translate results, - that are returned by batches as they are ready. + :return: Operation object with the batch text translate results, + that are returned by batches as they are ready. """ client = self.get_client() if location == "global": @@ -334,3 +391,174 @@ def batch_translate_text( metadata=metadata, ) return result + + def create_dataset( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + dataset: dict | automl_translation.Dataset, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + retry: Retry | _MethodDefault | None = DEFAULT, + ) -> Operation: + """ + Create the translation dataset. + + :param dataset: The dataset to create. If a dict is provided, it must correspond to + the automl_translation.Dataset type. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object for the dataset to be created. + """ + client = self.get_client() + parent = f"projects/{project_id or self.project_id}/locations/{location}" + return client.create_dataset( + request={"parent": parent, "dataset": dataset}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def get_dataset( + self, + dataset_id: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + ) -> automl_translation.Dataset: + """ + Retrieve the dataset for the given dataset_id. 
+ + :param dataset_id: ID of translation dataset to be retrieved. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `automl_translation.Dataset` instance. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + return client.get_dataset( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def import_dataset_data( + self, + dataset_id: str, + location: str, + input_config: dict | DatasetInputConfig, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Import data into the translation dataset. + + :param dataset_id: ID of the translation dataset. + :param input_config: The desired input location and its domain specific semantics, if any. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object for the import data. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.import_data( + request={"dataset": name, "input_config": input_config}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def list_datasets( + self, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: + """ + List translation datasets in a project. + + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: ``pagers.ListDatasetsPager`` instance, iterable object to retrieve the datasets list. 
+ """ + client = self.get_client() + parent = f"projects/{project_id}/locations/{location}" + result = client.list_datasets( + request={"parent": parent}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def delete_dataset( + self, + dataset_id: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete the translation dataset and all of its contents. + + :param dataset_id: ID of dataset to be deleted. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object with dataset deletion results, when finished. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.delete_dataset( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/providers/src/airflow/providers/google/cloud/links/translate.py b/providers/src/airflow/providers/google/cloud/links/translate.py index d8cbd18d00da..0d1489ddcfc8 100644 --- a/providers/src/airflow/providers/google/cloud/links/translate.py +++ b/providers/src/airflow/providers/google/cloud/links/translate.py @@ -45,6 +45,11 @@ TRANSLATION_TRANSLATE_TEXT_BATCH = BASE_LINK + "/storage/browser/{output_uri_prefix}?project={project_id}" +TRANSLATION_NATIVE_DATASET_LINK = ( + TRANSLATION_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/sentences?project={project_id}" +) +TRANSLATION_NATIVE_LIST_LINK = TRANSLATION_BASE_LINK + "/datasets?project={project_id}" + class TranslationLegacyDatasetLink(BaseGoogleLink): """ @@ -214,3 +219,54 @@ def persist( "output_uri_prefix": TranslateTextBatchLink.extract_output_uri_prefix(output_config), }, ) + + +class TranslationNativeDatasetLink(BaseGoogleLink): + """ + Helper class for constructing Legacy Translation Dataset link. + + Legacy Datasets are created and managed by AutoML API. + """ + + name = "Translation Native Dataset" + key = "translation_naive_dataset" + format_str = TRANSLATION_NATIVE_DATASET_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + dataset_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationNativeDatasetLink.key, + value={"location": task_instance.location, "dataset_id": dataset_id, "project_id": project_id}, + ) + + +class TranslationDatasetsListLink(BaseGoogleLink): + """ + Helper class for constructing Translation Datasets List link. + + Both legacy and native datasets are available under this link. 
+ """ + + name = "Translation Dataset List" + key = "translation_dataset_list" + format_str = TRANSLATION_DATASET_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationDatasetsListLink.key, + value={ + "project_id": project_id, + }, + ) diff --git a/providers/src/airflow/providers/google/cloud/operators/translate.py b/providers/src/airflow/providers/google/cloud/operators/translate.py index a0fa9243e01a..d384e9b8efa9 100644 --- a/providers/src/airflow/providers/google/cloud/operators/translate.py +++ b/providers/src/airflow/providers/google/cloud/operators/translate.py @@ -26,17 +26,23 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook, TranslateHook -from airflow.providers.google.cloud.links.translate import TranslateTextBatchLink +from airflow.providers.google.cloud.links.translate import ( + TranslateTextBatchLink, + TranslationDatasetsListLink, + TranslationNativeDatasetLink, +) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID if TYPE_CHECKING: from google.api_core.retry import Retry from google.cloud.translate_v3.types import ( + DatasetInputConfig, InputConfig, OutputConfig, TranslateTextGlossaryConfig, TransliterationConfig, + automl_translation, ) from airflow.utils.context import Context @@ -266,7 +272,7 @@ class TranslateTextBatchOperator(GoogleCloudBaseOperator): See https://cloud.google.com/translate/docs/advanced/batch-translation For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:TranslateTextBatchOperator`. + :ref:`howto/operator:TranslateTextBatchOperator`. :param project_id: Optional. The ID of the Google Cloud project that the service belongs to. If not specified the hook project_id will be used. @@ -381,6 +387,338 @@ def execute(self, context: Context) -> dict: project_id=self.project_id or hook.project_id, output_config=self.output_config, ) - hook.wait_for_operation(translate_operation) + hook.wait_for_operation_result(translate_operation) self.log.info("Translate text batch job finished") return {"batch_text_translate_results": self.output_config["gcs_destination"]} + + +class TranslateCreateDatasetOperator(GoogleCloudBaseOperator): + """ + Create a Google Cloud Translate dataset. + + Creates a `native` translation dataset, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateCreateDatasetOperator`. + + :param dataset: The dataset to create. If a dict is provided, it must correspond to + the automl_translation.Dataset type. + :param project_id: ID of the Google Cloud project where dataset is located. + If not provided default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationNativeDatasetLink(),) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + dataset: dict | automl_translation.Dataset, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault | None = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.dataset = dataset + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> str: + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Dataset creation started %s...", self.dataset) + result_operation = hook.create_dataset( + dataset=self.dataset, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = hook.wait_for_operation_result(result_operation) + result = type(result).to_dict(result) + dataset_id = hook.extract_object_id(result) + self.xcom_push(context, key="dataset_id", value=dataset_id) + self.log.info("Dataset creation complete. The dataset_id: %s.", dataset_id) + + project_id = self.project_id or hook.project_id + TranslationNativeDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=dataset_id, + project_id=project_id, + ) + return result + + +class TranslateDatasetsListOperator(GoogleCloudBaseOperator): + """ + Get a list of native Google Cloud Translation datasets in a project. + + Get project's list of `native` translation datasets, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateDatasetsListOperator`. + + :param project_id: ID of the Google Cloud project where dataset is located. + If not provided default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationDatasetsListLink(),) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + project_id = self.project_id or hook.project_id + TranslationDatasetsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + self.log.info("Requesting datasets list") + results_pager = hook.list_datasets( + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result_ids = [] + for ds_item in results_pager: + ds_data = type(ds_item).to_dict(ds_item) + ds_id = hook.extract_object_id(ds_data) + result_ids.append(ds_id) + + self.log.info("Fetching the datasets list complete.") + return result_ids + + +class TranslateImportDataOperator(GoogleCloudBaseOperator): + """ + Import data to the translation dataset. + + Loads data to the translation dataset, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateImportDataOperator`. + + :param dataset_id: The dataset_id of target native dataset to import data to. + :param input_config: The desired input location of translations language pairs file. If a dict provided, + must follow the structure of DatasetInputConfig. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "dataset_id", + "input_config", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationNativeDatasetLink(),) + + def __init__( + self, + *, + dataset_id: str, + location: str, + input_config: dict | DatasetInputConfig, + project_id: str = PROVIDE_PROJECT_ID, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.input_config = input_config + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Importing data to dataset...") + operation = hook.import_dataset_data( + dataset_id=self.dataset_id, + input_config=self.input_config, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + project_id = self.project_id or hook.project_id + TranslationNativeDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) + hook.wait_for_operation_done(operation=operation, timeout=self.timeout) + self.log.info("Importing data finished!") + + +class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator): + """ + Delete translation dataset and all of its contents. + + Deletes the translation dataset and it's data, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateDeleteDatasetOperator`. + + :param dataset_id: The dataset_id of target native dataset to be deleted. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "dataset_id", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + dataset_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Deleting the dataset %s...", self.dataset_id) + operation = hook.delete_dataset( + dataset_id=self.dataset_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation_done(operation=operation, timeout=self.timeout) + self.log.info("Dataset deletion complete!") diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index ce8ce057432d..57c5ddcd5ec3 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -1292,6 +1292,8 @@ extra-links: - airflow.providers.google.cloud.links.translate.TranslationLegacyModelTrainLink - airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink - airflow.providers.google.cloud.links.translate.TranslateTextBatchLink + - airflow.providers.google.cloud.links.translate.TranslationNativeDatasetLink + - airflow.providers.google.cloud.links.translate.TranslationDatasetsListLink secrets-backends: diff --git a/providers/tests/google/cloud/operators/test_translate.py b/providers/tests/google/cloud/operators/test_translate.py index 79f65395369b..45af2dae9289 100644 --- a/providers/tests/google/cloud/operators/test_translate.py +++ b/providers/tests/google/cloud/operators/test_translate.py @@ -19,15 +19,27 @@ from unittest import mock +from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.translate_v3.types import automl_translation + +from airflow.providers.google.cloud.hooks.translate import TranslateHook from airflow.providers.google.cloud.operators.translate import ( CloudTranslateTextOperator, + TranslateCreateDatasetOperator, + TranslateDatasetsListOperator, + TranslateDeleteDatasetOperator, + TranslateImportDataOperator, TranslateTextBatchOperator, TranslateTextOperator, ) +from providers.tests.system.google.cloud.tasks.example_tasks import LOCATION + GCP_CONN_ID = "google_cloud_default" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] PROJECT_ID = "test-project-id" +DATASET_ID = "sample_ds_id" +TIMEOUT_VALUE = 30 class TestCloudTranslate: @@ -97,7 +109,7 @@ def test_minimal_green_path(self, mock_hook): target_language_code="en", gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, - timeout=30, + timeout=TIMEOUT_VALUE, retry=None, model=None, ) @@ -117,7 +129,7 @@ def test_minimal_green_path(self, mock_hook): model=None, transliteration_config=None, glossary_config=None, - timeout=30, + timeout=TIMEOUT_VALUE, retry=None, metadata=(), ) @@ -185,3 +197,192 @@ def 
test_minimal_green_path(self, mock_hook, mock_link_persist): project_id=PROJECT_ID, output_config=OUTPUT_CONFIG, ) + + +class TestTranslateDatasetCreate: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationNativeDatasetLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator.xcom_push") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): + DS_CREATION_RESULT_SAMPLE = { + "display_name": "", + "example_count": 0, + "name": f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}", + "source_language_code": "", + "target_language_code": "", + "test_example_count": 0, + "train_example_count": 0, + "validate_example_count": 0, + } + sample_operation = mock.MagicMock() + sample_operation.result.return_value = automl_translation.Dataset(DS_CREATION_RESULT_SAMPLE) + + mock_hook.return_value.create_dataset.return_value = sample_operation + mock_hook.return_value.wait_for_operation_result.side_effect = lambda operation: operation.result() + mock_hook.return_value.extract_object_id = TranslateHook.extract_object_id + + DATASET_DATA = { + "display_name": "sample ds name", + "source_language_code": "es", + "target_language_code": "uk", + } + op = TranslateCreateDatasetOperator( + task_id="task_id", + dataset=DATASET_DATA, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=None, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.create_dataset.assert_called_once_with( + dataset=DATASET_DATA, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=None, + metadata=(), + ) + mock_xcom_push.assert_called_once_with(context, key="dataset_id", value=DATASET_ID) + mock_link_persist.assert_called_once_with( + context=context, + dataset_id=DATASET_ID, + task_instance=op, + project_id=PROJECT_ID, + ) + assert result == DS_CREATION_RESULT_SAMPLE + + +class TestTranslateListDatasets: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationDatasetsListLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_link_persist): + DS_ID_1 = "sample_ds_1" + DS_ID_2 = "sample_ds_2" + dataset_result_1 = automl_translation.Dataset( + dict( + name=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DS_ID_1}", + display_name="ds1_display_name", + ) + ) + dataset_result_2 = automl_translation.Dataset( + dict( + name=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DS_ID_2}", + display_name="ds1_display_name", + ) + ) + mock_hook.return_value.list_datasets.return_value = [dataset_result_1, dataset_result_2] + mock_hook.return_value.extract_object_id = TranslateHook.extract_object_id + + op = TranslateDatasetsListOperator( + task_id="task_id", + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + 
mock_hook.return_value.list_datasets.assert_called_once_with( + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + mock_link_persist.assert_called_once_with( + context=context, + task_instance=op, + project_id=PROJECT_ID, + ) + assert result == [DS_ID_1, DS_ID_2] + + +class TestTranslateImportData: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationNativeDatasetLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_link_persist): + INPUT_CONFIG = { + "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "import data gcs path"}}] + } + mock_hook.return_value.import_dataset_data.return_value = mock.MagicMock() + op = TranslateImportDataOperator( + task_id="task_id", + dataset_id=DATASET_ID, + input_config=INPUT_CONFIG, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.import_dataset_data.assert_called_once_with( + dataset_id=DATASET_ID, + input_config=INPUT_CONFIG, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + mock_link_persist.assert_called_once_with( + context=context, + dataset_id=DATASET_ID, + task_instance=op, + project_id=PROJECT_ID, + ) + + +class TestTranslateDeleteData: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook): + m_delete_method_result = mock.MagicMock() + mock_hook.return_value.delete_dataset.return_value = m_delete_method_result + + wait_for_done = mock_hook.return_value.wait_for_operation_done + + op = TranslateDeleteDatasetOperator( + task_id="task_id", + dataset_id=DATASET_ID, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.delete_dataset.assert_called_once_with( + dataset_id=DATASET_ID, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + wait_for_done.assert_called_once_with(operation=m_delete_method_result, timeout=TIMEOUT_VALUE) diff --git a/providers/tests/system/google/cloud/translate/example_translate_dataset.py b/providers/tests/system/google/cloud/translate/example_translate_dataset.py new file mode 100644 index 000000000000..3ad732862449 --- /dev/null +++ b/providers/tests/system/google/cloud/translate/example_translate_dataset.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that translates text in Google Cloud Translate using V3 API version +service in the Google Cloud. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.providers.google.cloud.operators.translate import ( + TranslateCreateDatasetOperator, + TranslateDatasetsListOperator, + TranslateDeleteDatasetOperator, + TranslateImportDataOperator, +) +from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator +from airflow.utils.trigger_rule import TriggerRule + +DAG_ID = "gcp_translate_automl_native_dataset" +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +REGION = "us-central1" +RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" + +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") + +DATA_FILE_NAME = "import_en-es.tsv" + +RESOURCE_PATH = f"V3_translate/create_ds/import_data/{DATA_FILE_NAME}" +COPY_DATA_PATH = f"gs://{RESOURCE_DATA_BUCKET}/V3_translate/create_ds/import_data/{DATA_FILE_NAME}" + +DST_PATH = f"translate/import/{DATA_FILE_NAME}" + +DATASET_DATA_PATH = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/{DST_PATH}" + + +DATASET = { + "display_name": f"op_ds_native{DAG_ID}_{ENV_ID}", + "source_language_code": "es", + "target_language_code": "en", +} + +with DAG( + DAG_ID, + schedule="@once", # Override to match your needs + start_date=datetime(2024, 11, 1), + catchup=False, + tags=[ + "example", + "translate_dataset", + ], +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=REGION, + ) + copy_dataset_source_tsv = GCSToGCSOperator( + task_id="copy_dataset_file", + source_bucket=RESOURCE_DATA_BUCKET, + source_object=RESOURCE_PATH, + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object=DST_PATH, + ) + + # [START howto_operator_translate_automl_create_dataset] + create_dataset_op = TranslateCreateDatasetOperator( + task_id="translate_v3_ds_create", + dataset=DATASET, + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_create_dataset] + + # [START howto_operator_translate_automl_import_data] + import_ds_data_op = TranslateImportDataOperator( + task_id="translate_v3_ds_import_data", + dataset_id=create_dataset_op.output["dataset_id"], + input_config={ + "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": DATASET_DATA_PATH}}] + }, + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_import_data] + + # [START howto_operator_translate_automl_list_datasets] + list_datasets_op = TranslateDatasetsListOperator( + task_id="translate_v3_list_ds", + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_list_datasets] + + # [START howto_operator_translate_automl_delete_dataset] + delete_ds_op = TranslateDeleteDatasetOperator( + 
task_id="translate_v3_ds_delete", + dataset_id=create_dataset_op.output["dataset_id"], + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_delete_dataset] + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + [create_bucket >> copy_dataset_source_tsv] + # TEST BODY + >> create_dataset_op + >> import_ds_data_op + >> list_datasets_op + >> delete_ds_op + # TEST TEARDOWN + >> delete_bucket + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From d6364992b8a48c907ac119fb21900e47d4adaf3c Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 22 Nov 2024 22:53:08 +0000 Subject: [PATCH 08/14] Allow dropping `_xcom_archive` table via CLI (#44291) This tables was created to not cause data loss (in https://github.com/apache/airflow/pull/44166) when upgrading from AF 2 to AF 3 if a user had pickled values in XCom table. - Introduced `ARCHIVED_TABLES_FROM_DB_MIGRATIONS` to track tables created during database migrations, such as `_xcom_archive`. - Added `_xcom_archive` to the db cleanup `config_list` for handling its records based on `timestamp`. - Add support in `airflow db drop-archived` to drop `_xcom_archive`. --- airflow/utils/db_cleanup.py | 16 ++++++++++++++-- newsfragments/aip-72.significant.rst | 6 +++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py index 9f0f8d63fe12..f71caf06ac8f 100644 --- a/airflow/utils/db_cleanup.py +++ b/airflow/utils/db_cleanup.py @@ -53,6 +53,10 @@ logger = logging.getLogger(__name__) ARCHIVE_TABLE_PREFIX = "_airflow_deleted__" +# Archived tables created by DB migrations +ARCHIVED_TABLES_FROM_DB_MIGRATIONS = [ + "_xcom_archive" # Table created by the AF 2 -> 3.0.0 migration when the XComs had pickled values +] @dataclass @@ -116,6 +120,7 @@ def readable_config(self): _TableConfig(table_name="task_instance_history", recency_column_name="start_date"), _TableConfig(table_name="task_reschedule", recency_column_name="start_date"), _TableConfig(table_name="xcom", recency_column_name="timestamp"), + _TableConfig(table_name="_xcom_archive", recency_column_name="timestamp"), _TableConfig(table_name="callback_request", recency_column_name="created_at"), _TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"), _TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"), @@ -380,13 +385,20 @@ def _effective_table_names(*, table_names: list[str] | None) -> tuple[set[str], def _get_archived_table_names(table_names: list[str] | None, session: Session) -> list[str]: inspector = inspect(session.bind) - db_table_names = [x for x in inspector.get_table_names() if x.startswith(ARCHIVE_TABLE_PREFIX)] + db_table_names = [ + x + for x in inspector.get_table_names() + if x.startswith(ARCHIVE_TABLE_PREFIX) or x in ARCHIVED_TABLES_FROM_DB_MIGRATIONS + ] effective_table_names, _ = _effective_table_names(table_names=table_names) # Filter out tables that don't start with the archive prefix archived_table_names = [ table_name for table_name in db_table_names - 
if any("__" + x + "__" in table_name for x in effective_table_names) + if ( + any("__" + x + "__" in table_name for x in effective_table_names) + or table_name in ARCHIVED_TABLES_FROM_DB_MIGRATIONS + ) ] return archived_table_names diff --git a/newsfragments/aip-72.significant.rst b/newsfragments/aip-72.significant.rst index 9fc34004de7a..e43e0c2f86c0 100644 --- a/newsfragments/aip-72.significant.rst +++ b/newsfragments/aip-72.significant.rst @@ -30,4 +30,8 @@ As part of this change the following breaking changes have occurred: The ``value`` field in the XCom table has been changed to a ``JSON`` type via DB migration. The XCom records that contains pickled data are archived in the ``_xcom_archive`` table. You can safely drop this table if you don't need - the data anymore. + the data anymore. To drop the table, you can use the following command or manually drop the table from the database. + + .. code-block:: bash + + airflow db drop-archived -t "_xcom_archive" From b22e3c1fcd5d92238f0c187c8338c11bdae73acb Mon Sep 17 00:00:00 2001 From: Amir Mor <49829354+amirmor1@users.noreply.github.com> Date: Sat, 23 Nov 2024 01:01:15 +0200 Subject: [PATCH 09/14] Fix Dataplex Data Quality partial update (#44262) * 44012 - Update index.rst * Fix Dataplex Data Quality Task partial update When we try to update dataplex data quality task using the DataplexCreateOrUpdateDataQualityScanOperator, it will first try to create the task, and only if it fails with AlreadyExists exception, it will try to update the task, but if you want to provide a partial parameters to the update (and not to replace the entire data scan properties), it will fail with AirflowException `Error creating Data Quality scan` because its missing mandatory parameters in the DataScan, and will never update the task. I've added a check to see if update_mask is not None, first try to do this update, and only if not -> try to create the task. 
Also moved the update section into a private function to reuse it this check, and later if we are trying to do a full update of the task * add empty line for lint * add test to verify update when update_mask is not none --------- Co-authored-by: Amir Mor --- .../google/cloud/operators/dataplex.py | 65 ++++++++++--------- .../google/cloud/operators/test_dataplex.py | 22 +++++++ 2 files changed, 57 insertions(+), 30 deletions(-) diff --git a/providers/src/airflow/providers/google/cloud/operators/dataplex.py b/providers/src/airflow/providers/google/cloud/operators/dataplex.py index 04edc10795f0..f77c648f20e1 100644 --- a/providers/src/airflow/providers/google/cloud/operators/dataplex.py +++ b/providers/src/airflow/providers/google/cloud/operators/dataplex.py @@ -686,39 +686,44 @@ def execute(self, context: Context): impersonation_chain=self.impersonation_chain, ) - self.log.info("Creating Dataplex Data Quality scan %s", self.data_scan_id) - try: - operation = hook.create_data_scan( - project_id=self.project_id, - region=self.region, - data_scan_id=self.data_scan_id, - body=self.body, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - hook.wait_for_operation(timeout=self.timeout, operation=operation) - self.log.info("Dataplex Data Quality scan %s created successfully!", self.data_scan_id) - except AlreadyExists: - self.log.info("Dataplex Data Quality scan already exists: %s", {self.data_scan_id}) - - operation = hook.update_data_scan( - project_id=self.project_id, - region=self.region, - data_scan_id=self.data_scan_id, - body=self.body, - update_mask=self.update_mask, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - hook.wait_for_operation(timeout=self.timeout, operation=operation) - self.log.info("Dataplex Data Quality scan %s updated successfully!", self.data_scan_id) - except GoogleAPICallError as e: - raise AirflowException(f"Error creating Data Quality scan {self.data_scan_id}", e) + if self.update_mask is not None: + self._update_data_scan(hook) + else: + self.log.info("Creating Dataplex Data Quality scan %s", self.data_scan_id) + try: + operation = hook.create_data_scan( + project_id=self.project_id, + region=self.region, + data_scan_id=self.data_scan_id, + body=self.body, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Dataplex Data Quality scan %s created successfully!", self.data_scan_id) + except AlreadyExists: + self._update_data_scan(hook) + except GoogleAPICallError as e: + raise AirflowException(f"Error creating Data Quality scan {self.data_scan_id}", e) return self.data_scan_id + def _update_data_scan(self, hook: DataplexHook): + self.log.info("Dataplex Data Quality scan already exists: %s", {self.data_scan_id}) + operation = hook.update_data_scan( + project_id=self.project_id, + region=self.region, + data_scan_id=self.data_scan_id, + body=self.body, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Dataplex Data Quality scan %s updated successfully!", self.data_scan_id) + class DataplexGetDataQualityScanOperator(GoogleCloudBaseOperator): """ diff --git a/providers/tests/google/cloud/operators/test_dataplex.py b/providers/tests/google/cloud/operators/test_dataplex.py index 67c9b8ca10f9..1eec9008e2c1 100644 --- a/providers/tests/google/cloud/operators/test_dataplex.py +++ 
b/providers/tests/google/cloud/operators/test_dataplex.py @@ -672,6 +672,18 @@ def test_execute(self, hook_mock): api_version=API_VERSION, impersonation_chain=IMPERSONATION_CHAIN, ) + update_operator = DataplexCreateOrUpdateDataQualityScanOperator( + task_id=TASK_ID, + project_id=PROJECT_ID, + region=REGION, + data_scan_id=DATA_SCAN_ID, + body={}, + update_mask={}, + api_version=API_VERSION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + update_operator.execute(context=mock.MagicMock()) hook_mock.return_value.create_data_scan.assert_called_once_with( project_id=PROJECT_ID, region=REGION, @@ -681,6 +693,16 @@ def test_execute(self, hook_mock): timeout=None, metadata=(), ) + hook_mock.return_value.update_data_scan.assert_called_once_with( + project_id=PROJECT_ID, + region=REGION, + data_scan_id=DATA_SCAN_ID, + body={}, + update_mask={}, + retry=DEFAULT, + timeout=None, + metadata=(), + ) class TestDataplexCreateDataProfileScanOperator: From 5f6b233906e69c6437bf556827844842683a5555 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Sat, 23 Nov 2024 00:37:29 +0100 Subject: [PATCH 10/14] feat: add OpenLineage support for BigQueryToBigQueryOperator (#44214) Signed-off-by: Kacper Muda --- .../google/cloud/openlineage/utils.py | 75 ++-- .../cloud/transfers/bigquery_to_bigquery.py | 90 ++++- .../google/cloud/transfers/bigquery_to_gcs.py | 13 +- .../google/cloud/transfers/gcs_to_bigquery.py | 19 +- .../google/cloud/openlineage/test_utils.py | 137 +++++++- .../transfers/test_bigquery_to_bigquery.py | 332 +++++++++++++++++- .../cloud/transfers/test_bigquery_to_gcs.py | 21 +- .../cloud/transfers/test_gcs_to_bigquery.py | 11 +- 8 files changed, 594 insertions(+), 104 deletions(-) diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py index 82172d5d241c..403023f7b431 100644 --- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py +++ b/providers/src/airflow/providers/google/cloud/openlineage/utils.py @@ -27,6 +27,7 @@ from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.common.compat.openlineage.facet import ( + BaseFacet, ColumnLineageDatasetFacet, DocumentationDatasetFacet, Fields, @@ -41,50 +42,82 @@ BIGQUERY_URI = "bigquery" -def get_facets_from_bq_table(table: Table) -> dict[Any, Any]: +def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]: """Get facets from BigQuery table object.""" - facets = { - "schema": SchemaDatasetFacet( + facets: dict[str, BaseFacet] = {} + if table.schema: + facets["schema"] = SchemaDatasetFacet( fields=[ SchemaDatasetFacetFields( - name=field.name, type=field.field_type, description=field.description + name=schema_field.name, type=schema_field.field_type, description=schema_field.description ) - for field in table.schema + for schema_field in table.schema ] - ), - "documentation": DocumentationDatasetFacet(description=table.description or ""), - } + ) + if table.description: + facets["documentation"] = DocumentationDatasetFacet(description=table.description) return facets def get_identity_column_lineage_facet( - field_names: list[str], + dest_field_names: list[str], input_datasets: list[Dataset], -) -> ColumnLineageDatasetFacet: +) -> dict[str, ColumnLineageDatasetFacet]: """ - Get column lineage facet. - - Simple lineage will be created, where each source column corresponds to single destination column - in each input dataset and there are no transformations made. 
+ Get column lineage facet for identity transformations. + + This function generates a simple column lineage facet, where each destination column + consists of source columns of the same name from all input datasets that have that column. + The lineage assumes there are no transformations applied, meaning the columns retain their + identity between the source and destination datasets. + + Args: + dest_field_names: A list of destination column names for which lineage should be determined. + input_datasets: A list of input datasets with schema facets. + + Returns: + A dictionary containing a single key, `columnLineage`, mapped to a `ColumnLineageDatasetFacet`. + If no column lineage can be determined, an empty dictionary is returned - see Notes below. + + Notes: + - If any input dataset lacks a schema facet, the function immediately returns an empty dictionary. + - If any field in the source dataset's schema is not present in the destination table, + the function returns an empty dictionary. The destination table can contain extra fields, but all + source columns should be present in the destination table. + - If none of the destination columns can be matched to input dataset columns, an empty + dictionary is returned. + - Extra columns in the destination table that do not exist in the input datasets are ignored and + skipped in the lineage facet, as they cannot be traced back to a source column. + - The function assumes there are no transformations applied, meaning the columns retain their + identity between the source and destination datasets. """ - if field_names and not input_datasets: - raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.") + fields_sources: dict[str, list[Dataset]] = {} + for ds in input_datasets: + if not ds.facets or "schema" not in ds.facets: + return {} + for schema_field in ds.facets["schema"].fields: # type: ignore[attr-defined] + if schema_field.name not in dest_field_names: + return {} + fields_sources[schema_field.name] = fields_sources.get(schema_field.name, []) + [ds] + + if not fields_sources: + return {} column_lineage_facet = ColumnLineageDatasetFacet( fields={ - field: Fields( + field_name: Fields( inputFields=[ - InputField(namespace=dataset.namespace, name=dataset.name, field=field) - for dataset in input_datasets + InputField(namespace=dataset.namespace, name=dataset.name, field=field_name) + for dataset in source_datasets ], transformationType="IDENTITY", transformationDescription="identical", ) - for field in field_names + for field_name, source_datasets in fields_sources.items() } ) - return column_lineage_facet + return {"columnLineage": column_lineage_facet} @define diff --git a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py index 7be147d09b60..e1f3d3b13f56 100644 --- a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +++ b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py @@ -110,6 +110,7 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.hook: BigQueryHook | None = None + self._job_conf: dict = {} def _prepare_job_configuration(self): self.source_project_dataset_tables = ( @@ -154,39 +155,94 @@ def _prepare_job_configuration(self): return configuration - def _submit_job( - self, - hook: BigQueryHook, - configuration: dict, - ) -> str: - job = hook.insert_job(configuration=configuration, 
project_id=hook.project_id) - return job.job_id - def execute(self, context: Context) -> None: self.log.info( "Executing copy of %s into: %s", self.source_project_dataset_tables, self.destination_project_dataset_table, ) - hook = BigQueryHook( + self.hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, location=self.location, impersonation_chain=self.impersonation_chain, ) - self.hook = hook - if not hook.project_id: + if not self.hook.project_id: raise ValueError("The project_id should be set") configuration = self._prepare_job_configuration() - job_id = self._submit_job(hook=hook, configuration=configuration) + self._job_conf = self.hook.insert_job( + configuration=configuration, project_id=self.hook.project_id + ).to_api_repr() - job = hook.get_job(job_id=job_id, location=self.location).to_api_repr() - conf = job["configuration"]["copy"]["destinationTable"] + dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"] BigQueryTableLink.persist( context=context, task_instance=self, - dataset_id=conf["datasetId"], - project_id=conf["projectId"], - table_id=conf["tableId"], + dataset_id=dest_table_info["datasetId"], + project_id=dest_table_info["projectId"], + table_id=dest_table_info["tableId"], + ) + + def get_openlineage_facets_on_complete(self, task_instance): + """Implement on_complete as we will include final BQ job id.""" + from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + ExternalQueryRunFacet, + ) + from airflow.providers.google.cloud.openlineage.utils import ( + BIGQUERY_NAMESPACE, + get_facets_from_bq_table, + get_identity_column_lineage_facet, ) + from airflow.providers.openlineage.extractors import OperatorLineage + + if not self.hook: + self.hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + if not self._job_conf: + self.log.debug("OpenLineage could not find BQ job configuration.") + return OperatorLineage() + + bq_job_id = self._job_conf["jobReference"]["jobId"] + source_tables_info = self._job_conf["configuration"]["copy"]["sourceTables"] + dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"] + + run_facets = { + "externalQuery": ExternalQueryRunFacet(externalQueryId=bq_job_id, source="bigquery"), + } + + input_datasets = [] + for in_table_info in source_tables_info: + table_id = ".".join( + (in_table_info["projectId"], in_table_info["datasetId"], in_table_info["tableId"]) + ) + table_object = self.hook.get_client().get_table(table_id) + input_datasets.append( + Dataset( + namespace=BIGQUERY_NAMESPACE, name=table_id, facets=get_facets_from_bq_table(table_object) + ) + ) + + out_table_id = ".".join( + (dest_table_info["projectId"], dest_table_info["datasetId"], dest_table_info["tableId"]) + ) + out_table_object = self.hook.get_client().get_table(out_table_id) + output_dataset_facets = { + **get_facets_from_bq_table(out_table_object), + **get_identity_column_lineage_facet( + dest_field_names=[field.name for field in out_table_object.schema], + input_datasets=input_datasets, + ), + } + output_dataset = Dataset( + namespace=BIGQUERY_NAMESPACE, + name=out_table_id, + facets=output_dataset_facets, + ) + + return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets) diff --git a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index e2588b8976e3..2833f79a3e81 100644 --- 
a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -294,6 +294,7 @@ def get_openlineage_facets_on_complete(self, task_instance): from pathlib import Path from airflow.providers.common.compat.openlineage.facet import ( + BaseFacet, Dataset, ExternalQueryRunFacet, Identifier, @@ -322,12 +323,12 @@ def get_openlineage_facets_on_complete(self, task_instance): facets=get_facets_from_bq_table(table_object), ) - output_dataset_facets = { - "schema": input_dataset.facets["schema"], - "columnLineage": get_identity_column_lineage_facet( - field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset] - ), - } + output_dataset_facets: dict[str, BaseFacet] = get_identity_column_lineage_facet( + dest_field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset] + ) + if "schema" in input_dataset.facets: + output_dataset_facets["schema"] = input_dataset.facets["schema"] + output_datasets = [] for uri in sorted(self.destination_cloud_storage_uris): bucket, blob = _parse_gcs_url(uri) diff --git a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 06b9a94171b1..6dfe1bd1a2c3 100644 --- a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -784,9 +784,10 @@ def get_openlineage_facets_on_complete(self, task_instance): source_objects = ( self.source_objects if isinstance(self.source_objects, list) else [self.source_objects] ) - input_dataset_facets = { - "schema": output_dataset_facets["schema"], - } + input_dataset_facets = {} + if "schema" in output_dataset_facets: + input_dataset_facets["schema"] = output_dataset_facets["schema"] + input_datasets = [] for blob in sorted(source_objects): additional_facets = {} @@ -811,14 +812,16 @@ def get_openlineage_facets_on_complete(self, task_instance): ) input_datasets.append(dataset) - output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet( - field_names=[field.name for field in table_object.schema], input_datasets=input_datasets - ) - output_dataset = Dataset( namespace="bigquery", name=str(table_object.reference), - facets=output_dataset_facets, + facets={ + **output_dataset_facets, + **get_identity_column_lineage_facet( + dest_field_names=[field.name for field in table_object.schema], + input_datasets=input_datasets, + ), + }, ) run_facets = {} diff --git a/providers/tests/google/cloud/openlineage/test_utils.py b/providers/tests/google/cloud/openlineage/test_utils.py index 4f2db0038b7b..e3f40bee1549 100644 --- a/providers/tests/google/cloud/openlineage/test_utils.py +++ b/providers/tests/google/cloud/openlineage/test_utils.py @@ -19,7 +19,6 @@ import json from unittest.mock import MagicMock -import pytest from google.cloud.bigquery.table import Table from airflow.providers.common.compat.openlineage.facet import ( @@ -89,19 +88,78 @@ def test_get_facets_from_bq_table(): def test_get_facets_from_empty_bq_table(): - expected_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - } result = get_facets_from_bq_table(TEST_EMPTY_TABLE) - assert result == expected_facets + assert result == {} + + +def test_get_identity_column_lineage_facet_source_datasets_schemas_are_subsets(): + field_names = ["field1", "field2", "field3"] + 
input_datasets = [ + Dataset( + namespace="gs://first_bucket", + name="dir1", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + ] + ) + }, + ), + Dataset( + namespace="gs://second_bucket", + name="dir2", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field2", type="STRING"), + ] + ) + }, + ), + ] + expected_facet = ColumnLineageDatasetFacet( + fields={ + "field1": Fields( + inputFields=[ + InputField( + namespace="gs://first_bucket", + name="dir1", + field="field1", + ) + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + "field2": Fields( + inputFields=[ + InputField( + namespace="gs://second_bucket", + name="dir2", + field="field2", + ), + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + # field3 is missing here as it's not present in any source dataset + } + ) + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {"columnLineage": expected_facet} def test_get_identity_column_lineage_facet_multiple_input_datasets(): field_names = ["field1", "field2"] + schema_facet = SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", type="STRING"), + ] + ) input_datasets = [ - Dataset(namespace="gs://first_bucket", name="dir1"), - Dataset(namespace="gs://second_bucket", name="dir2"), + Dataset(namespace="gs://first_bucket", name="dir1", facets={"schema": schema_facet}), + Dataset(namespace="gs://second_bucket", name="dir2", facets={"schema": schema_facet}), ] expected_facet = ColumnLineageDatasetFacet( fields={ @@ -139,24 +197,69 @@ def test_get_identity_column_lineage_facet_multiple_input_datasets(): ), } ) - result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) - assert result == expected_facet + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {"columnLineage": expected_facet} + + +def test_get_identity_column_lineage_facet_dest_cols_not_in_input_datasets(): + field_names = ["x", "y"] + input_datasets = [ + Dataset( + namespace="gs://first_bucket", + name="dir1", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + ] + ) + }, + ), + Dataset( + namespace="gs://second_bucket", + name="dir2", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field2", type="STRING"), + ] + ) + }, + ), + ] + + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {} + + +def test_get_identity_column_lineage_facet_no_schema_in_input_dataset(): + field_names = ["field1", "field2"] + input_datasets = [ + Dataset(namespace="gs://first_bucket", name="dir1"), + ] + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {} def test_get_identity_column_lineage_facet_no_field_names(): field_names = [] + schema_facet = SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", type="STRING"), + ] + ) input_datasets = [ - Dataset(namespace="gs://first_bucket", name="dir1"), - Dataset(namespace="gs://second_bucket", name="dir2"), + Dataset(namespace="gs://first_bucket", name="dir1", 
facets={"schema": schema_facet}), + Dataset(namespace="gs://second_bucket", name="dir2", facets={"schema": schema_facet}), ] - expected_facet = ColumnLineageDatasetFacet(fields={}) - result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) - assert result == expected_facet + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {} def test_get_identity_column_lineage_facet_no_input_datasets(): field_names = ["field1", "field2"] input_datasets = [] - with pytest.raises(ValueError): - get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {} diff --git a/providers/tests/google/cloud/transfers/test_bigquery_to_bigquery.py b/providers/tests/google/cloud/transfers/test_bigquery_to_bigquery.py index ed06928c2ccf..304694126ffc 100644 --- a/providers/tests/google/cloud/transfers/test_bigquery_to_bigquery.py +++ b/providers/tests/google/cloud/transfers/test_bigquery_to_bigquery.py @@ -19,16 +19,29 @@ from unittest import mock +from google.cloud.bigquery import Table + +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Dataset, + DocumentationDatasetFacet, + ExternalQueryRunFacet, + Fields, + InputField, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator BQ_HOOK_PATH = "airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryHook" -TASK_ID = "test-bq-create-table-operator" +TASK_ID = "test-bq-to-bq-operator" TEST_GCP_PROJECT_ID = "test-project" TEST_DATASET = "test-dataset" TEST_TABLE_ID = "test-table-id" -SOURCE_PROJECT_DATASET_TABLES = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" -DESTINATION_PROJECT_DATASET_TABLE = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET + '_new'}.{TEST_TABLE_ID}" +SOURCE_PROJECT_DATASET_TABLE = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" +SOURCE_PROJECT_DATASET_TABLE2 = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}-2" +DESTINATION_PROJECT_DATASET_TABLE = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}_new.{TEST_TABLE_ID}" WRITE_DISPOSITION = "WRITE_EMPTY" CREATE_DISPOSITION = "CREATE_IF_NEEDED" LABELS = {"k1": "v1"} @@ -36,12 +49,18 @@ def split_tablename_side_effect(*args, **kwargs): - if kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLES: + if kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLE: return ( TEST_GCP_PROJECT_ID, TEST_DATASET, TEST_TABLE_ID, ) + elif kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLE2: + return ( + TEST_GCP_PROJECT_ID, + TEST_DATASET, + TEST_TABLE_ID + "-2", + ) elif kwargs["table_input"] == DESTINATION_PROJECT_DATASET_TABLE: return ( TEST_GCP_PROJECT_ID, @@ -55,7 +74,7 @@ class TestBigQueryToBigQueryOperator: def test_execute_without_location_should_execute_successfully(self, mock_hook): operator = BigQueryToBigQueryOperator( task_id=TASK_ID, - source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLES, + source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLE, destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE, write_disposition=WRITE_DISPOSITION, create_disposition=CREATE_DISPOSITION, @@ -95,7 +114,48 @@ def test_execute_single_regional_location_should_execute_successfully(self, mock operator = BigQueryToBigQueryOperator( task_id=TASK_ID, - 
source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLES, + source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLE, + destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE, + write_disposition=WRITE_DISPOSITION, + create_disposition=CREATE_DISPOSITION, + labels=LABELS, + encryption_configuration=ENCRYPTION_CONFIGURATION, + location=location, + ) + + mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect + operator.execute(context=mock.MagicMock()) + mock_hook.return_value.insert_job.assert_called_once_with( + configuration={ + "copy": { + "createDisposition": CREATE_DISPOSITION, + "destinationEncryptionConfiguration": ENCRYPTION_CONFIGURATION, + "destinationTable": { + "datasetId": TEST_DATASET + "_new", + "projectId": TEST_GCP_PROJECT_ID, + "tableId": TEST_TABLE_ID, + }, + "sourceTables": [ + { + "datasetId": TEST_DATASET, + "projectId": TEST_GCP_PROJECT_ID, + "tableId": TEST_TABLE_ID, + }, + ], + "writeDisposition": WRITE_DISPOSITION, + }, + "labels": LABELS, + }, + project_id=mock_hook.return_value.project_id, + ) + + @mock.patch(BQ_HOOK_PATH) + def test_get_openlineage_facets_on_complete_single_source_table(self, mock_hook): + location = "us-central1" + + operator = BigQueryToBigQueryOperator( + task_id=TASK_ID, + source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLE, destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE, write_disposition=WRITE_DISPOSITION, create_disposition=CREATE_DISPOSITION, @@ -104,9 +164,265 @@ def test_execute_single_regional_location_should_execute_successfully(self, mock location=location, ) + source_table_api_repr = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + "description": "Table description.", + "schema": { + "fields": [ + {"name": "field1", "type": "STRING"}, + {"name": "field2", "type": "INTEGER"}, + ] + }, + } + dest_table_api_repr = {**source_table_api_repr} + dest_table_api_repr["tableReference"]["datasetId"] = TEST_DATASET + "_new" + mock_table_data = { + SOURCE_PROJECT_DATASET_TABLE: Table.from_api_repr(source_table_api_repr), + DESTINATION_PROJECT_DATASET_TABLE: Table.from_api_repr(dest_table_api_repr), + } + + mock_hook.return_value.insert_job.return_value.to_api_repr.return_value = { + "jobReference": { + "projectId": TEST_GCP_PROJECT_ID, + "jobId": "actual_job_id", + "location": location, + }, + "configuration": { + "copy": { + "sourceTables": [ + { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + ], + "destinationTable": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET + "_new", + "tableId": TEST_TABLE_ID, + }, + } + }, + } mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect + mock_hook.return_value.get_client.return_value.get_table.side_effect = ( + lambda table_id: mock_table_data[table_id] + ) + operator.execute(context=mock.MagicMock()) - mock_hook.return_value.get_job.assert_called_once_with( - job_id=mock_hook.return_value.insert_job.return_value.job_id, + result = operator.get_openlineage_facets_on_complete(None) + + assert result.job_facets == {} + assert result.run_facets == { + "externalQuery": ExternalQueryRunFacet(externalQueryId="actual_job_id", source="bigquery") + } + assert len(result.inputs) == 1 + assert result.inputs[0] == Dataset( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", 
type="STRING"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet("Table description."), + }, + ) + assert len(result.outputs) == 1 + assert result.outputs[0] == Dataset( + namespace="bigquery", + name=DESTINATION_PROJECT_DATASET_TABLE, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet("Table description."), + "columnLineage": ColumnLineageDatasetFacet( + fields={ + "field1": Fields( + inputFields=[ + InputField( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + field="field1", + transformations=[], + ) + ], + transformationDescription="identical", + transformationType="IDENTITY", + ), + "field2": Fields( + inputFields=[ + InputField( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + field="field2", + transformations=[], + ) + ], + transformationDescription="identical", + transformationType="IDENTITY", + ), + }, + dataset=[], + ), + }, + ) + + @mock.patch(BQ_HOOK_PATH) + def test_get_openlineage_facets_on_complete_multiple_source_tables(self, mock_hook): + location = "us-central1" + + operator = BigQueryToBigQueryOperator( + task_id=TASK_ID, + source_project_dataset_tables=[ + SOURCE_PROJECT_DATASET_TABLE, + SOURCE_PROJECT_DATASET_TABLE2, + ], + destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE, + write_disposition=WRITE_DISPOSITION, + create_disposition=CREATE_DISPOSITION, + labels=LABELS, + encryption_configuration=ENCRYPTION_CONFIGURATION, location=location, ) + + source_table_repr = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + "schema": { + "fields": [ + {"name": "field1", "type": "STRING"}, + ] + }, + } + source_table_repr2 = {**source_table_repr} + source_table_repr2["tableReference"]["tableId"] = TEST_TABLE_ID + "-2" + dest_table_api_repr = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET + "_new", + "tableId": TEST_TABLE_ID, + }, + "schema": { + "fields": [ + {"name": "field1", "type": "STRING"}, + {"name": "field2", "type": "INTEGER"}, + ] + }, + } + mock_table_data = { + SOURCE_PROJECT_DATASET_TABLE: Table.from_api_repr(source_table_repr), + SOURCE_PROJECT_DATASET_TABLE2: Table.from_api_repr(source_table_repr2), + DESTINATION_PROJECT_DATASET_TABLE: Table.from_api_repr(dest_table_api_repr), + } + + mock_hook.return_value.insert_job.return_value.to_api_repr.return_value = { + "jobReference": { + "projectId": TEST_GCP_PROJECT_ID, + "jobId": "actual_job_id", + "location": location, + }, + "configuration": { + "copy": { + "sourceTables": [ + { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID + "-2", + }, + ], + "destinationTable": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET + "_new", + "tableId": TEST_TABLE_ID, + }, + } + }, + } + mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect + mock_hook.return_value.get_client.return_value.get_table.side_effect = ( + lambda table_id: mock_table_data[table_id] + ) + operator.execute(context=mock.MagicMock()) + result = operator.get_openlineage_facets_on_complete(None) + assert result.job_facets == {} + assert result.run_facets == { + "externalQuery": 
ExternalQueryRunFacet(externalQueryId="actual_job_id", source="bigquery") + } + assert len(result.inputs) == 2 + assert result.inputs[0] == Dataset( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + ] + ) + }, + ) + assert result.inputs[1] == Dataset( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE2, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + ] + ) + }, + ) + assert len(result.outputs) == 1 + assert result.outputs[0] == Dataset( + namespace="bigquery", + name=DESTINATION_PROJECT_DATASET_TABLE, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), + ] + ), + "columnLineage": ColumnLineageDatasetFacet( + fields={ + "field1": Fields( + inputFields=[ + InputField( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + field="field1", + transformations=[], + ), + InputField( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE2, + field="field1", + transformations=[], + ), + ], + transformationDescription="identical", + transformationType="IDENTITY", + ) + }, + dataset=[], + ), + }, + ) diff --git a/providers/tests/google/cloud/transfers/test_bigquery_to_gcs.py b/providers/tests/google/cloud/transfers/test_bigquery_to_gcs.py index 7c2a39825375..b451d2037efc 100644 --- a/providers/tests/google/cloud/transfers/test_bigquery_to_gcs.py +++ b/providers/tests/google/cloud/transfers/test_bigquery_to_gcs.py @@ -299,11 +299,6 @@ def test_get_openlineage_facets_on_complete_bq_dataset(self, mock_hook): def test_get_openlineage_facets_on_complete_bq_dataset_empty_table(self, mock_hook): source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" - expected_input_dataset_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - } - mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) mock_hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE @@ -320,7 +315,7 @@ def test_get_openlineage_facets_on_complete_bq_dataset_empty_table(self, mock_ho assert lineage.inputs[0] == Dataset( namespace="bigquery", name=source_project_dataset_table, - facets=expected_input_dataset_facets, + facets={}, ) @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook") @@ -330,16 +325,6 @@ def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mo real_job_id = "123456_hash" bq_namespace = "bigquery" - expected_input_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - } - - expected_output_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "columnLineage": ColumnLineageDatasetFacet(fields={}), - } - mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) mock_hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE @@ -357,12 +342,12 @@ def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mo assert len(lineage.inputs) == 1 assert len(lineage.outputs) == 1 assert lineage.inputs[0] == Dataset( - namespace=bq_namespace, 
name=source_project_dataset_table, facets=expected_input_facets + namespace=bq_namespace, name=source_project_dataset_table, facets={} ) assert lineage.outputs[0] == Dataset( namespace=f"gs://{TEST_BUCKET}", name=f"{TEST_FOLDER}/{TEST_OBJECT_NO_WILDCARD}", - facets=expected_output_facets, + facets={}, ) assert lineage.run_facets == { "externalQuery": ExternalQueryRunFacet(externalQueryId=real_job_id, source=bq_namespace) diff --git a/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py b/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py index 0ba2e07bb05e..299fc9fead54 100644 --- a/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py @@ -1439,12 +1439,6 @@ def test_get_openlineage_facets_on_complete_empty_table(self, hook): hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE) hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE - expected_output_dataset_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - "columnLineage": ColumnLineageDatasetFacet(fields={}), - } - operator = GCSToBigQueryOperator( project_id=JOB_PROJECT_ID, task_id=TASK_ID, @@ -1461,18 +1455,17 @@ def test_get_openlineage_facets_on_complete_empty_table(self, hook): assert lineage.outputs[0] == Dataset( namespace="bigquery", name=TEST_EXPLICIT_DEST, - facets=expected_output_dataset_facets, + facets={}, ) assert lineage.inputs[0] == Dataset( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, - facets={"schema": SchemaDatasetFacet(fields=[])}, + facets={}, ) assert lineage.inputs[1] == Dataset( namespace=f"gs://{TEST_BUCKET}", name="/", facets={ - "schema": SchemaDatasetFacet(fields=[]), "symlink": SymlinksDatasetFacet( identifiers=[ Identifier( From da9468652056794a4a4cd5f11ef4abaac6a16a9a Mon Sep 17 00:00:00 2001 From: Sean Rose <1994030+sean-rose@users.noreply.github.com> Date: Fri, 22 Nov 2024 15:39:10 -0800 Subject: [PATCH 11/14] Fix incorrect query in `BigQueryAsyncHook.create_job_for_partition_get`. (#44225) * The table ID column in `INFORMATION_SCHEMA.PARTITIONS` is named `table_name`, not `table_id`. * The table ID string value needs to be quoted in the SQL. 
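For illustration, a minimal sketch of the corrected query construction (the project,
dataset and table names below are hypothetical):

    project_id, dataset_id, table_id = "my-project", "my_dataset", "my_table"
    query = (
        "SELECT partition_id "
        f"FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.PARTITIONS`"
        + (f" WHERE table_name='{table_id}'" if table_id else "")
    )
    # renders as:
    # SELECT partition_id FROM `my-project.my_dataset.INFORMATION_SCHEMA.PARTITIONS` WHERE table_name='my_table'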
--- providers/src/airflow/providers/google/cloud/hooks/bigquery.py | 2 +- providers/tests/google/cloud/hooks/test_bigquery.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/src/airflow/providers/google/cloud/hooks/bigquery.py index 159a8f3f639b..bce89e9e4184 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -2095,7 +2095,7 @@ async def create_job_for_partition_get( query_request = { "query": "SELECT partition_id " f"FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.PARTITIONS`" - + (f" WHERE table_id={table_id}" if table_id else ""), + + (f" WHERE table_name='{table_id}'" if table_id else ""), "useLegacySql": False, } job_query_resp = await job_client.query(query_request, cast(Session, session)) diff --git a/providers/tests/google/cloud/hooks/test_bigquery.py b/providers/tests/google/cloud/hooks/test_bigquery.py index b0e7f8efb209..ee0f904bb94b 100644 --- a/providers/tests/google/cloud/hooks/test_bigquery.py +++ b/providers/tests/google/cloud/hooks/test_bigquery.py @@ -1604,7 +1604,7 @@ async def test_create_job_for_partition_get_with_table(self, mock_job_instance, expected_query_request = { "query": "SELECT partition_id " f"FROM `{PROJECT_ID}.{DATASET_ID}.INFORMATION_SCHEMA.PARTITIONS`" - f" WHERE table_id={TABLE_ID}", + f" WHERE table_name='{TABLE_ID}'", "useLegacySql": False, } await hook.create_job_for_partition_get( From 32f064f73c9b55325a87e0c6f1124a8c1f551137 Mon Sep 17 00:00:00 2001 From: vatsrahul1001 <43964496+vatsrahul1001@users.noreply.github.com> Date: Sat, 23 Nov 2024 05:11:54 +0530 Subject: [PATCH 12/14] Trigger openlineage test when asset files changes (#44172) --- .../airflow_breeze/utils/selective_checks.py | 24 ++++++++++++++- dev/breeze/tests/test_selective_checks.py | 30 +++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py index f29678dadcfd..c5ecef552617 100644 --- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py +++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py @@ -119,6 +119,7 @@ class FileGroupForCi(Enum): ALL_PROVIDER_YAML_FILES = "all_provider_yaml_files" ALL_DOCS_PYTHON_FILES = "all_docs_python_files" TESTS_UTILS_FILES = "test_utils_files" + ASSET_FILES = "asset_files" class AllProvidersSentinel: @@ -253,6 +254,12 @@ def __hash__(self): r"^task_sdk/src/airflow/sdk/.*\.py$", r"^task_sdk/tests/.*\.py$", ], + FileGroupForCi.ASSET_FILES: [ + r"^airflow/assets/", + r"^airflow/models/assets/", + r"^task_sdk/src/airflow/sdk/definitions/asset/", + r"^airflow/datasets/", + ], } ) @@ -696,6 +703,10 @@ def needs_javascript_scans(self) -> bool: def needs_api_tests(self) -> bool: return self._should_be_run(FileGroupForCi.API_TEST_FILES) + @cached_property + def needs_ol_tests(self) -> bool: + return self._should_be_run(FileGroupForCi.ASSET_FILES) + @cached_property def needs_api_codegen(self) -> bool: return self._should_be_run(FileGroupForCi.API_CODEGEN_FILES) @@ -860,7 +871,15 @@ def _get_providers_test_types_to_run(self, split_to_individual_providers: bool = all_providers_source_files = self._matching_files( FileGroupForCi.ALL_PROVIDERS_PYTHON_FILES, CI_FILE_GROUP_MATCHES, CI_FILE_GROUP_EXCLUDES ) - if len(all_providers_source_files) == 0 and not self.needs_api_tests: + assets_source_files = self._matching_files( + 
FileGroupForCi.ASSET_FILES, CI_FILE_GROUP_MATCHES, CI_FILE_GROUP_EXCLUDES + ) + + if ( + len(all_providers_source_files) == 0 + and len(assets_source_files) == 0 + and not self.needs_api_tests + ): # IF API tests are needed, that will trigger extra provider checks return [] else: @@ -1440,6 +1459,8 @@ def _find_all_providers_affected(self, include_docs: bool) -> list[str] | AllPro all_providers.add(provider) if self.needs_api_tests: all_providers.add("fab") + if self.needs_ol_tests: + all_providers.add("openlineage") if all_providers_affected: return ALL_PROVIDERS_SENTINEL if suspended_providers: @@ -1473,6 +1494,7 @@ def _find_all_providers_affected(self, include_docs: bool) -> list[str] | AllPro ) if not all_providers: return None + for provider in list(all_providers): all_providers.update( get_related_providers(provider, upstream_dependencies=True, downstream_dependencies=True) diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py index 749b4e1fa6bd..2f00013f89f0 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -1672,6 +1672,36 @@ def test_expected_output_push( }, id="pre commit ts-compile-format-lint should not be ignored if openapi spec changed.", ), + pytest.param( + ( + "airflow/assets/", + "airflow/models/assets/", + "task_sdk/src/airflow/sdk/definitions/asset/", + "airflow/datasets/", + ), + { + "selected-providers-list-as-string": "amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino", + "all-python-versions": "['3.9']", + "all-python-versions-list-as-string": "3.9", + "ci-image-build": "true", + "prod-image-build": "false", + "needs-helm-tests": "false", + "run-tests": "true", + "skip-providers-tests": "false", + "test-groups": "['core', 'providers']", + "docs-build": "true", + "docs-list-as-string": "apache-airflow amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino", + "skip-pre-commits": "check-provider-yaml-valid,flynt,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk," + "ts-compile-format-lint-ui,ts-compile-format-lint-www", + "run-kubernetes-tests": "false", + "upgrade-to-newer-dependencies": "false", + "core-test-types-list-as-string": "API Always CLI Core Operators Other Serialization WWW", + "providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,mysql,openlineage,postgres,sftp,snowflake,trino] Providers[google]", + "needs-mypy": "false", + "mypy-checks": "[]", + }, + id="Trigger openlineage and related providers tests when Assets files changed", + ), ], ) def test_expected_output_pull_request_target( From e5de5506d54aeacbd78da319a5975411db4b03cd Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Sat, 23 Nov 2024 01:18:18 +0100 Subject: [PATCH 13/14] Fix the Show Down text (#44292) --- providers/src/airflow/providers/edge/cli/edge_command.py | 2 +- providers/src/airflow/providers/edge/models/edge_worker.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py index 4d0b46d74e79..8cb19fb2da9d 100644 --- a/providers/src/airflow/providers/edge/cli/edge_command.py +++ b/providers/src/airflow/providers/edge/cli/edge_command.py @@ -156,7 +156,7 @@ def __init__( @staticmethod def 
signal_handler(sig, frame): - logger.info("Request to show down Edge Worker received, waiting for jobs to complete.") + logger.info("Request to shut down Edge Worker received, waiting for jobs to complete.") _EdgeWorkerCli.drain = True def shutdown_handler(self, sig, frame): diff --git a/providers/src/airflow/providers/edge/models/edge_worker.py b/providers/src/airflow/providers/edge/models/edge_worker.py index b65d93503885..7fdcb0cf3d41 100644 --- a/providers/src/airflow/providers/edge/models/edge_worker.py +++ b/providers/src/airflow/providers/edge/models/edge_worker.py @@ -62,7 +62,7 @@ class EdgeWorkerState(str, Enum): TERMINATING = "terminating" """Edge Worker is completing work and stopping.""" OFFLINE = "offline" - """Edge Worker was show down.""" + """Edge Worker was shut down.""" UNKNOWN = "unknown" """No heartbeat signal from worker for some time, Edge Worker probably down.""" From 3c58e01266f884544fdebc70f92b63848c610d2d Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 23 Nov 2024 03:42:11 +0000 Subject: [PATCH 14/14] Bump `google-cloud-translate` to `3.16` (#44297) --- generated/provider_dependencies.json | 2 +- providers/src/airflow/providers/google/provider.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index c23d2c27825a..24745db0f400 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -679,7 +679,7 @@ "google-cloud-storage>=2.7.0", "google-cloud-tasks>=2.13.0", "google-cloud-texttospeech>=2.14.1", - "google-cloud-translate>=3.11.0", + "google-cloud-translate>=3.16.0", "google-cloud-videointelligence>=2.11.0", "google-cloud-vision>=3.4.0", "google-cloud-workflows>=1.10.0", diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index 57c5ddcd5ec3..b0107286ca7c 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -150,7 +150,7 @@ dependencies: - google-cloud-storage-transfer>=1.4.1 - google-cloud-tasks>=2.13.0 - google-cloud-texttospeech>=2.14.1 - - google-cloud-translate>=3.11.0 + - google-cloud-translate>=3.16.0 - google-cloud-videointelligence>=2.11.0 - google-cloud-vision>=3.4.0 - google-cloud-workflows>=1.10.0