diff --git a/.devcontainer/mysql/devcontainer.json b/.devcontainer/mysql/devcontainer.json index 5a25b6ad50625..011ff292d2718 100644 --- a/.devcontainer/mysql/devcontainer.json +++ b/.devcontainer/mysql/devcontainer.json @@ -22,5 +22,10 @@ "rogalmic.bash-debug" ], "service": "airflow", - "forwardPorts": [8080,5555,5432,6379] + "forwardPorts": [8080,5555,5432,6379], + "workspaceFolder": "/opt/airflow", + // for users who use non-standard git config patterns + // https://github.com/microsoft/vscode-remote-release/issues/2084#issuecomment-989756268 + "initializeCommand": "cd \"${localWorkspaceFolder}\" && git config --local user.email \"$(git config user.email)\" && git config --local user.name \"$(git config user.name)\"", + "overrideCommand": true } diff --git a/.devcontainer/postgres/devcontainer.json b/.devcontainer/postgres/devcontainer.json index 46ba305b58554..419dbedfa1d38 100644 --- a/.devcontainer/postgres/devcontainer.json +++ b/.devcontainer/postgres/devcontainer.json @@ -22,5 +22,10 @@ "rogalmic.bash-debug" ], "service": "airflow", - "forwardPorts": [8080,5555,5432,6379] + "forwardPorts": [8080,5555,5432,6379], + "workspaceFolder": "/opt/airflow", + // for users who use non-standard git config patterns + // https://github.com/microsoft/vscode-remote-release/issues/2084#issuecomment-989756268 + "initializeCommand": "cd \"${localWorkspaceFolder}\" && git config --local user.email \"$(git config user.email)\" && git config --local user.name \"$(git config user.name)\"", + "overrideCommand": true } diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d072d21055cff..1525333ed9dbe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -360,7 +360,7 @@ repos: types_or: [python, pyi] args: [--fix] require_serial: true - additional_dependencies: ["ruff==0.7.3"] + additional_dependencies: ["ruff==0.8.0"] exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py|^performance/tests/test_.*.py - id: ruff-format name: Run 'ruff format' @@ -370,7 +370,7 @@ repos: types_or: [python, pyi] args: [] require_serial: true - additional_dependencies: ["ruff==0.7.3"] + additional_dependencies: ["ruff==0.8.0"] exclude: ^.*/.*_vendor/|^tests/dags/test_imports.py$ - id: replace-bad-characters name: Replace bad characters diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 17da1eafacc93..2d7da4bff7376 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -24,8 +24,10 @@ from typing import TYPE_CHECKING, Literal, Sequence, overload -from airflow.utils.db import get_query_count -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from sqlalchemy.ext.asyncio import AsyncSession + +from airflow.utils.db import get_query_count, get_query_count_async +from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -53,7 +55,9 @@ def your_route(session: Annotated[Session, Depends(get_session)]): def apply_filters_to_select( - *, base_select: Select, filters: Sequence[BaseParam | None] | None = None + *, + base_select: Select, + filters: Sequence[BaseParam | None] | None = None, ) -> Select: if filters is None: return base_select @@ -65,6 +69,80 @@ def apply_filters_to_select( return base_select +async def get_async_session() -> AsyncSession: + """ + Dependency for providing a session. + + Example usage: + + .. 
code:: python + + @router.get("/your_path") + def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]): + pass + """ + async with create_session_async() as session: + yield session + + +@overload +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: Literal[True] = True, +) -> tuple[Select, int]: ... + + +@overload +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: Literal[False], +) -> tuple[Select, None]: ... + + +async def paginated_select_async( + *, + query: Select, + filters: Sequence[BaseParam | None] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: bool = True, +) -> tuple[Select, int | None]: + query = apply_filters_to_select( + base_select=query, + filters=filters, + ) + + total_entries = None + if return_total_entries: + total_entries = await get_query_count_async(query, session=session) + + # TODO: Re-enable when permissions are handled. Readable / writable entities, + # for instance: + # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) + # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) + + query = apply_filters_to_select( + base_select=query, + filters=[order_by, offset, limit], + ) + + return query, total_entries + + @overload def paginated_select( *, diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index aa6f540d32791..78b2beb558895 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -20,9 +20,10 @@ from fastapi import Depends, HTTPException, status from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from airflow.api_fastapi.common.db.common import get_session, paginated_select +from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.backfills import ( @@ -49,7 +50,7 @@ @backfills_router.get( path="", ) -def list_backfills( +async def list_backfills( dag_id: str, limit: QueryLimit, offset: QueryOffset, @@ -57,18 +58,16 @@ def list_backfills( SortParam, Depends(SortParam(["id"], Backfill).dynamic_depends()), ], - session: Annotated[Session, Depends(get_session)], + session: Annotated[AsyncSession, Depends(get_async_session)], ) -> BackfillCollectionResponse: - select_stmt, total_entries = paginated_select( - select=select(Backfill).where(Backfill.dag_id == dag_id), + select_stmt, total_entries = await paginated_select_async( + query=select(Backfill).where(Backfill.dag_id == dag_id), order_by=order_by, offset=offset, limit=limit, session=session, ) - - backfills = session.scalars(select_stmt) - + backfills = await session.scalars(select_stmt) return BackfillCollectionResponse( backfills=backfills, total_entries=total_entries, diff 
--git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 108d9e51d7588..169fd548afe08 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -116,7 +116,7 @@ # Todo: AIP-44: Once we get rid of AIP-44 we can remove this. But without this here pydantic fails to resolve # types for serialization -from airflow.utils.task_group import TaskGroup # noqa: TCH001 +from airflow.utils.task_group import TaskGroup # noqa: TC001 TaskPreExecuteHook = Callable[[Context], None] TaskPostExecuteHook = Callable[[Context, Any], None] diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 45208e353bdc1..5b6f83f4d59b2 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -34,6 +34,7 @@ select, text, ) +from sqlalchemy.dialects import postgresql from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import Query, reconstructor, relationship @@ -79,7 +80,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin): dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - value = Column(JSON) + value = Column(JSON().with_variant(postgresql.JSONB, "postgresql")) timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False) __table_args__ = ( diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 61d851aaed118..21fa575676a57 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -110,9 +110,9 @@ HAS_KUBERNETES: bool try: - from kubernetes.client import models as k8s # noqa: TCH004 + from kubernetes.client import models as k8s # noqa: TC004 - from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TCH004 + from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004 except ImportError: pass diff --git a/airflow/settings.py b/airflow/settings.py index 5b458efcba473..76b3e948964f3 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -31,7 +31,7 @@ import pluggy from packaging.version import Version from sqlalchemy import create_engine, exc, text -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import NullPool @@ -111,7 +111,7 @@ # this is achieved by the Session factory above. 
NonScopedSession: Callable[..., SASession] async_engine: AsyncEngine -create_async_session: Callable[..., AsyncSession] +AsyncSession: Callable[..., SAAsyncSession] # The JSON library to use for DAG Serialization and De-Serialization json = json @@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None): global Session global engine global async_engine - global create_async_session + global AsyncSession global NonScopedSession if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": @@ -498,11 +498,11 @@ def configure_orm(disable_connection_pool=False, pool_class=None): engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True) async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True) - create_async_session = sessionmaker( + AsyncSession = sessionmaker( bind=async_engine, autocommit=False, autoflush=False, - class_=AsyncSession, + class_=SAAsyncSession, expire_on_commit=False, ) mask_secret(engine.url.password) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index d8939a117317f..c899ebf615d06 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -70,6 +70,7 @@ from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from sqlalchemy.sql.elements import ClauseElement, TextClause from sqlalchemy.sql.selectable import Select @@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int: return session.scalar(count_stmt) +async def get_query_count_async(query: Select, *, session: AsyncSession) -> int: + """ + Get count of a query. + + A SELECT COUNT() FROM is issued against the subquery built from the + given statement. The ORDER BY clause is stripped from the statement + since it's unnecessary for COUNT, and can impact query planning and + degrade performance. + + :meta private: + """ + count_stmt = select(func.count()).select_from(query.order_by(None).subquery()) + return await session.scalar(count_stmt) + + def check_query_exists(query_stmt: Select, *, session: Session) -> bool: """ Check whether there is at least one row matching a query. 
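The async helpers introduced above (``paginated_select_async``, ``get_async_session``, ``get_query_count_async``) and the ``create_session_async`` context manager added in ``airflow/utils/session.py`` further below are meant to be used together. A minimal usage sketch, not part of the patch, assuming the import paths shown in the hunks and that ``Backfill`` is importable from ``airflow.models.backfill``:

.. code-block:: python

    from sqlalchemy import select

    from airflow.api_fastapi.common.db.common import paginated_select_async
    from airflow.models.backfill import Backfill  # assumed module path
    from airflow.utils.session import create_session_async


    async def count_and_list_backfills(dag_id: str):
        # create_session_async yields a SQLAlchemy AsyncSession and commits or rolls back on exit.
        async with create_session_async() as session:
            # paginated_select_async applies filters/ordering/limits and, by default,
            # also runs a COUNT() over the same statement via get_query_count_async.
            query, total_entries = await paginated_select_async(
                query=select(Backfill).where(Backfill.dag_id == dag_id),
                session=session,
            )
            backfills = (await session.scalars(query)).all()
        return backfills, total_entries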
diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py index 9f0f8d63fe12c..f71caf06ac8f2 100644 --- a/airflow/utils/db_cleanup.py +++ b/airflow/utils/db_cleanup.py @@ -53,6 +53,10 @@ logger = logging.getLogger(__name__) ARCHIVE_TABLE_PREFIX = "_airflow_deleted__" +# Archived tables created by DB migrations +ARCHIVED_TABLES_FROM_DB_MIGRATIONS = [ + "_xcom_archive" # Table created by the AF 2 -> 3.0.0 migration when the XComs had pickled values +] @dataclass @@ -116,6 +120,7 @@ def readable_config(self): _TableConfig(table_name="task_instance_history", recency_column_name="start_date"), _TableConfig(table_name="task_reschedule", recency_column_name="start_date"), _TableConfig(table_name="xcom", recency_column_name="timestamp"), + _TableConfig(table_name="_xcom_archive", recency_column_name="timestamp"), _TableConfig(table_name="callback_request", recency_column_name="created_at"), _TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"), _TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"), @@ -380,13 +385,20 @@ def _effective_table_names(*, table_names: list[str] | None) -> tuple[set[str], def _get_archived_table_names(table_names: list[str] | None, session: Session) -> list[str]: inspector = inspect(session.bind) - db_table_names = [x for x in inspector.get_table_names() if x.startswith(ARCHIVE_TABLE_PREFIX)] + db_table_names = [ + x + for x in inspector.get_table_names() + if x.startswith(ARCHIVE_TABLE_PREFIX) or x in ARCHIVED_TABLES_FROM_DB_MIGRATIONS + ] effective_table_names, _ = _effective_table_names(table_names=table_names) # Filter out tables that don't start with the archive prefix archived_table_names = [ table_name for table_name in db_table_names - if any("__" + x + "__" in table_name for x in effective_table_names) + if ( + any("__" + x + "__" in table_name for x in effective_table_names) + or table_name in ARCHIVED_TABLES_FROM_DB_MIGRATIONS + ) ] return archived_table_names diff --git a/airflow/utils/session.py b/airflow/utils/session.py index a63d3f3f937a8..49383cdf4a8bf 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]: session.close() +@contextlib.asynccontextmanager +async def create_session_async(): + """ + Context manager to create async session. + + :meta private: + """ + from airflow.settings import AsyncSession + + async with AsyncSession() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + PS = ParamSpec("PS") RT = TypeVar("RT") diff --git a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py index 01efdb77fabe3..71d4d2a35fb17 100644 --- a/dev/breeze/src/airflow_breeze/commands/release_management_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/release_management_commands.py @@ -2195,7 +2195,7 @@ class ProviderPRInfo(NamedTuple): f"[warning]Skipping provider {provider_id}. 
" "The changelog file doesn't contain any PRs for the release.\n" ) - return + continue provider_prs[provider_id] = [pr for pr in prs if pr not in excluded_prs] all_prs.update(provider_prs[provider_id]) g = Github(github_token) @@ -2245,6 +2245,8 @@ class ProviderPRInfo(NamedTuple): progress.advance(task) providers: dict[str, ProviderPRInfo] = {} for provider_id in prepared_package_ids: + if provider_id not in provider_prs: + continue pull_request_list = [pull_requests[pr] for pr in provider_prs[provider_id] if pr in pull_requests] provider_yaml_dict = yaml.safe_load( ( diff --git a/dev/breeze/src/airflow_breeze/utils/console.py b/dev/breeze/src/airflow_breeze/utils/console.py index 0b8861673883a..910a687004d5a 100644 --- a/dev/breeze/src/airflow_breeze/utils/console.py +++ b/dev/breeze/src/airflow_breeze/utils/console.py @@ -83,7 +83,7 @@ class Output(NamedTuple): @property def file(self) -> TextIO: - return open(self.file_name, "a+t") + return open(self.file_name, "a+") @property def escaped_title(self) -> str: diff --git a/dev/breeze/src/airflow_breeze/utils/selective_checks.py b/dev/breeze/src/airflow_breeze/utils/selective_checks.py index f29678dadcfdc..c5ecef5526171 100644 --- a/dev/breeze/src/airflow_breeze/utils/selective_checks.py +++ b/dev/breeze/src/airflow_breeze/utils/selective_checks.py @@ -119,6 +119,7 @@ class FileGroupForCi(Enum): ALL_PROVIDER_YAML_FILES = "all_provider_yaml_files" ALL_DOCS_PYTHON_FILES = "all_docs_python_files" TESTS_UTILS_FILES = "test_utils_files" + ASSET_FILES = "asset_files" class AllProvidersSentinel: @@ -253,6 +254,12 @@ def __hash__(self): r"^task_sdk/src/airflow/sdk/.*\.py$", r"^task_sdk/tests/.*\.py$", ], + FileGroupForCi.ASSET_FILES: [ + r"^airflow/assets/", + r"^airflow/models/assets/", + r"^task_sdk/src/airflow/sdk/definitions/asset/", + r"^airflow/datasets/", + ], } ) @@ -696,6 +703,10 @@ def needs_javascript_scans(self) -> bool: def needs_api_tests(self) -> bool: return self._should_be_run(FileGroupForCi.API_TEST_FILES) + @cached_property + def needs_ol_tests(self) -> bool: + return self._should_be_run(FileGroupForCi.ASSET_FILES) + @cached_property def needs_api_codegen(self) -> bool: return self._should_be_run(FileGroupForCi.API_CODEGEN_FILES) @@ -860,7 +871,15 @@ def _get_providers_test_types_to_run(self, split_to_individual_providers: bool = all_providers_source_files = self._matching_files( FileGroupForCi.ALL_PROVIDERS_PYTHON_FILES, CI_FILE_GROUP_MATCHES, CI_FILE_GROUP_EXCLUDES ) - if len(all_providers_source_files) == 0 and not self.needs_api_tests: + assets_source_files = self._matching_files( + FileGroupForCi.ASSET_FILES, CI_FILE_GROUP_MATCHES, CI_FILE_GROUP_EXCLUDES + ) + + if ( + len(all_providers_source_files) == 0 + and len(assets_source_files) == 0 + and not self.needs_api_tests + ): # IF API tests are needed, that will trigger extra provider checks return [] else: @@ -1440,6 +1459,8 @@ def _find_all_providers_affected(self, include_docs: bool) -> list[str] | AllPro all_providers.add(provider) if self.needs_api_tests: all_providers.add("fab") + if self.needs_ol_tests: + all_providers.add("openlineage") if all_providers_affected: return ALL_PROVIDERS_SENTINEL if suspended_providers: @@ -1473,6 +1494,7 @@ def _find_all_providers_affected(self, include_docs: bool) -> list[str] | AllPro ) if not all_providers: return None + for provider in list(all_providers): all_providers.update( get_related_providers(provider, upstream_dependencies=True, downstream_dependencies=True) diff --git a/dev/breeze/tests/test_selective_checks.py 
b/dev/breeze/tests/test_selective_checks.py index 749b4e1fa6bdf..2f00013f89f08 100644 --- a/dev/breeze/tests/test_selective_checks.py +++ b/dev/breeze/tests/test_selective_checks.py @@ -1672,6 +1672,36 @@ def test_expected_output_push( }, id="pre commit ts-compile-format-lint should not be ignored if openapi spec changed.", ), + pytest.param( + ( + "airflow/assets/", + "airflow/models/assets/", + "task_sdk/src/airflow/sdk/definitions/asset/", + "airflow/datasets/", + ), + { + "selected-providers-list-as-string": "amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino", + "all-python-versions": "['3.9']", + "all-python-versions-list-as-string": "3.9", + "ci-image-build": "true", + "prod-image-build": "false", + "needs-helm-tests": "false", + "run-tests": "true", + "skip-providers-tests": "false", + "test-groups": "['core', 'providers']", + "docs-build": "true", + "docs-list-as-string": "apache-airflow amazon common.compat common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp snowflake trino", + "skip-pre-commits": "check-provider-yaml-valid,flynt,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk," + "ts-compile-format-lint-ui,ts-compile-format-lint-www", + "run-kubernetes-tests": "false", + "upgrade-to-newer-dependencies": "false", + "core-test-types-list-as-string": "API Always CLI Core Operators Other Serialization WWW", + "providers-test-types-list-as-string": "Providers[amazon] Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,mysql,openlineage,postgres,sftp,snowflake,trino] Providers[google]", + "needs-mypy": "false", + "mypy-checks": "[]", + }, + id="Trigger openlineage and related providers tests when Assets files changed", + ), ], ) def test_expected_output_pull_request_target( diff --git a/docs/apache-airflow-providers-google/operators/cloud/translate.rst b/docs/apache-airflow-providers-google/operators/cloud/translate.rst index 6bcc32ec669ce..d56fac26dbeb4 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/translate.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/translate.rst @@ -100,11 +100,96 @@ For parameter definition, take a look at :class:`~airflow.providers.google.cloud.operators.translate.TranslateTextBatchOperator` +.. _howto/operator:TranslateCreateDatasetOperator: + +TranslateCreateDatasetOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Create a native translation dataset using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_create_dataset] + :end-before: [END howto_operator_translate_automl_create_dataset] + + +.. _howto/operator:TranslateImportDataOperator: + +TranslateImportDataOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Import data to the existing native dataset, using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateImportDataOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. 
exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_import_data] + :end-before: [END howto_operator_translate_automl_import_data] + + +.. _howto/operator:TranslateDatasetsListOperator: + +TranslateDatasetsListOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Get list of translation datasets using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_list_datasets] + :end-before: [END howto_operator_translate_automl_list_datasets] + + +.. _howto/operator:TranslateDeleteDatasetOperator: + +TranslateDeleteDatasetOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Delete a native translation dataset using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_delete_dataset] + :end-before: [END howto_operator_translate_automl_delete_dataset] + + More information """""""""""""""""" See: Base (V2) `Google Cloud Translate documentation `_. Advanced (V3) `Google Cloud Translate (Advanced) documentation `_. +Datasets `Legacy and native dataset comparison `_. Reference diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 20e10c44a12c6..fa8ffb4a2c3c6 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -967,6 +967,7 @@ LineItem lineterminator linter linux +ListDatasetsPager ListGenerator ListInfoTypesResponse ListSecretsPager diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index c23d2c27825a8..24745db0f4002 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -679,7 +679,7 @@ "google-cloud-storage>=2.7.0", "google-cloud-tasks>=2.13.0", "google-cloud-texttospeech>=2.14.1", - "google-cloud-translate>=3.11.0", + "google-cloud-translate>=3.16.0", "google-cloud-videointelligence>=2.11.0", "google-cloud-vision>=3.4.0", "google-cloud-workflows>=1.10.0", diff --git a/hatch_build.py b/hatch_build.py index 22627bfe94ad4..ddaa7aa2671dc 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -248,7 +248,7 @@ ], "devel-static-checks": [ "black>=23.12.0", - "ruff==0.7.3", + "ruff==0.8.0", "yamllint>=1.33.0", ], "devel-tests": [ diff --git a/newsfragments/aip-72.significant.rst b/newsfragments/aip-72.significant.rst index 9fc34004de7a5..e43e0c2f86c01 100644 --- a/newsfragments/aip-72.significant.rst +++ b/newsfragments/aip-72.significant.rst @@ -30,4 +30,8 @@ As part of this change the following breaking changes have occurred: The ``value`` field in the XCom table has been changed to a ``JSON`` type via DB migration. The XCom records that contains pickled data are archived in the ``_xcom_archive`` table. You can safely drop this table if you don't need - the data anymore. + the data anymore. 
To drop the table, you can use the following command or manually drop the table from the database. + + .. code-block:: bash + + airflow db drop-archived -t "_xcom_archive" diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst b/providers/src/airflow/providers/edge/CHANGELOG.rst index 48c7a76b5f0ed..d24373463f88c 100644 --- a/providers/src/airflow/providers/edge/CHANGELOG.rst +++ b/providers/src/airflow/providers/edge/CHANGELOG.rst @@ -27,6 +27,14 @@ Changelog --------- +0.6.1pre0 +......... + +Misc +~~~~ + +* ``Update jobs or edge workers who have been killed to clean up job table.`` + 0.6.0pre0 ......... diff --git a/providers/src/airflow/providers/edge/__init__.py b/providers/src/airflow/providers/edge/__init__.py index 1613f44510a7b..b3f545f82448a 100644 --- a/providers/src/airflow/providers/edge/__init__.py +++ b/providers/src/airflow/providers/edge/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "0.6.0pre0" +__version__ = "0.6.1pre0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.10.0" diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py index 4d0b46d74e79c..8cb19fb2da9df 100644 --- a/providers/src/airflow/providers/edge/cli/edge_command.py +++ b/providers/src/airflow/providers/edge/cli/edge_command.py @@ -156,7 +156,7 @@ def __init__( @staticmethod def signal_handler(sig, frame): - logger.info("Request to show down Edge Worker received, waiting for jobs to complete.") + logger.info("Request to shut down Edge Worker received, waiting for jobs to complete.") _EdgeWorkerCli.drain = True def shutdown_handler(self, sig, frame): diff --git a/providers/src/airflow/providers/edge/executors/edge_executor.py b/providers/src/airflow/providers/edge/executors/edge_executor.py index a13552fbf8a20..b990a3311571c 100644 --- a/providers/src/airflow/providers/edge/executors/edge_executor.py +++ b/providers/src/airflow/providers/edge/executors/edge_executor.py @@ -26,7 +26,7 @@ from airflow.configuration import conf from airflow.executors.base_executor import BaseExecutor from airflow.models.abstractoperator import DEFAULT_QUEUE -from airflow.models.taskinstance import TaskInstanceState +from airflow.models.taskinstance import TaskInstance, TaskInstanceState from airflow.providers.edge.cli.edge_command import EDGE_COMMANDS from airflow.providers.edge.models.edge_job import EdgeJobModel from airflow.providers.edge.models.edge_logs import EdgeLogsModel @@ -42,7 +42,6 @@ from sqlalchemy.orm import Session from airflow.executors.base_executor import CommandType - from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey PARALLELISM: int = conf.getint("core", "PARALLELISM") @@ -108,6 +107,30 @@ def _check_worker_liveness(self, session: Session) -> bool: return changed + def _update_orphaned_jobs(self, session: Session) -> bool: + """Update status of jobs when workers die and don't update anymore.""" + heartbeat_interval: int = conf.getint("scheduler", "scheduler_zombie_task_threshold") + lifeless_jobs: list[EdgeJobModel] = ( + session.query(EdgeJobModel) + .filter( + EdgeJobModel.state == TaskInstanceState.RUNNING, + EdgeJobModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval)), + ) + .all() + ) + + for job in lifeless_jobs: + ti = TaskInstance.get_task_instance( + dag_id=job.dag_id, + run_id=job.run_id, + task_id=job.task_id, + map_index=job.map_index,
session=session, + ) + job.state = ti.state if ti else TaskInstanceState.REMOVED + + return bool(lifeless_jobs) + def _purge_jobs(self, session: Session) -> bool: """Clean finished jobs.""" purged_marker = False @@ -117,7 +140,12 @@ def _purge_jobs(self, session: Session) -> bool: session.query(EdgeJobModel) .filter( EdgeJobModel.state.in_( - [TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS, TaskInstanceState.FAILED] + [ + TaskInstanceState.RUNNING, + TaskInstanceState.SUCCESS, + TaskInstanceState.FAILED, + TaskInstanceState.REMOVED, + ] ) ) .all() @@ -145,7 +173,7 @@ def _purge_jobs(self, session: Session) -> bool: job.state == TaskInstanceState.SUCCESS and job.last_update_t < (datetime.now() - timedelta(minutes=job_success_purge)).timestamp() ) or ( - job.state == TaskInstanceState.FAILED + job.state in (TaskInstanceState.FAILED, TaskInstanceState.REMOVED) and job.last_update_t < (datetime.now() - timedelta(minutes=job_fail_purge)).timestamp() ): if job.key in self.last_reported_state: @@ -168,7 +196,10 @@ def _purge_jobs(self, session: Session) -> bool: def sync(self, session: Session = NEW_SESSION) -> None: """Sync will get called periodically by the heartbeat method.""" with Stats.timer("edge_executor.sync.duration"): - if self._purge_jobs(session) or self._check_worker_liveness(session): + orphaned = self._update_orphaned_jobs(session) + purged = self._purge_jobs(session) + liveness = self._check_worker_liveness(session) + if purged or liveness or orphaned: session.commit() def end(self) -> None: diff --git a/providers/src/airflow/providers/edge/models/edge_worker.py b/providers/src/airflow/providers/edge/models/edge_worker.py index b65d935038854..7fdcb0cf3d413 100644 --- a/providers/src/airflow/providers/edge/models/edge_worker.py +++ b/providers/src/airflow/providers/edge/models/edge_worker.py @@ -62,7 +62,7 @@ class EdgeWorkerState(str, Enum): TERMINATING = "terminating" """Edge Worker is completing work and stopping.""" OFFLINE = "offline" - """Edge Worker was show down.""" + """Edge Worker was shut down.""" UNKNOWN = "unknown" """No heartbeat signal from worker for some time, Edge Worker probably down.""" diff --git a/providers/src/airflow/providers/edge/provider.yaml b/providers/src/airflow/providers/edge/provider.yaml index 96ce7f152f7f5..6fe609502aa2b 100644 --- a/providers/src/airflow/providers/edge/provider.yaml +++ b/providers/src/airflow/providers/edge/provider.yaml @@ -27,7 +27,7 @@ source-date-epoch: 1729683247 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.6.0pre0 + - 0.6.1pre0 dependencies: - apache-airflow>=2.10.0 diff --git a/providers/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/src/airflow/providers/google/cloud/hooks/bigquery.py index 159a8f3f639b5..bce89e9e4184b 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -2095,7 +2095,7 @@ async def create_job_for_partition_get( query_request = { "query": "SELECT partition_id " f"FROM `{project_id}.{dataset_id}.INFORMATION_SCHEMA.PARTITIONS`" - + (f" WHERE table_id={table_id}" if table_id else ""), + + (f" WHERE table_name='{table_id}'" if table_id else ""), "useLegacySql": False, } job_query_resp = await job_client.query(query_request, cast(Session, session)) diff --git a/providers/src/airflow/providers/google/cloud/hooks/translate.py b/providers/src/airflow/providers/google/cloud/hooks/translate.py index 51cb88f1bacef..6ddb220f3e789 100644 
--- a/providers/src/airflow/providers/google/cloud/hooks/translate.py +++ b/providers/src/airflow/providers/google/cloud/hooks/translate.py @@ -29,6 +29,7 @@ from google.api_core.exceptions import GoogleAPICallError from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.api_core.retry import Retry from google.cloud.translate_v2 import Client from google.cloud.translate_v3 import TranslationServiceClient @@ -38,13 +39,31 @@ if TYPE_CHECKING: from google.api_core.operation import Operation - from google.api_core.retry import Retry + from google.cloud.translate_v3.services.translation_service import pagers from google.cloud.translate_v3.types import ( + DatasetInputConfig, InputConfig, OutputConfig, TranslateTextGlossaryConfig, TransliterationConfig, + automl_translation, ) + from proto import Message + + +class WaitOperationNotDoneYetError(Exception): + """Wait operation not done yet error.""" + + pass + + +def _if_exc_is_wait_failed_error(exc: Exception): + return isinstance(exc, WaitOperationNotDoneYetError) + + +def _check_if_operation_done(operation: Operation): + if not operation.done(): + raise WaitOperationNotDoneYetError("Operation is not done yet.") class CloudTranslateHook(GoogleBaseHook): @@ -163,7 +182,42 @@ def get_client(self) -> TranslationServiceClient: return self._client @staticmethod - def wait_for_operation(operation: Operation, timeout: int | None = None): + def wait_for_operation_done( + *, + operation: Operation, + timeout: float | None = None, + initial: float = 3, + multiplier: float = 2, + maximum: float = 3600, + ) -> None: + """ + Wait for long-running operation to be done. + + Calls operation.done() until success or timeout exhaustion, following the back-off retry strategy. + See `google.api_core.retry.Retry`. + It's intended for use on `Operation` instances that have an empty result + (:class:`google.protobuf.empty_pb2.Empty`) by design. + Thus calling operation.result() for such an operation triggers the exception + ``GoogleAPICallError("Unexpected state: Long-running operation had neither response nor error set.")`` + even though the operation itself is fine. + """ + wait_op_for_done = Retry( + predicate=_if_exc_is_wait_failed_error, + initial=initial, + timeout=timeout, + multiplier=multiplier, + maximum=maximum, + )(_check_if_operation_done) + try: + wait_op_for_done(operation=operation) + except GoogleAPICallError: + if timeout: + timeout = int(timeout) + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + @staticmethod + def wait_for_operation_result(operation: Operation, timeout: int | None = None) -> Message: """Wait for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -171,6 +225,11 @@ def wait_for_operation(operation: Operation, timeout: int | None = None): error = operation.exception(timeout=timeout) raise AirflowException(error) + @staticmethod + def extract_object_id(obj: dict) -> str: + """Return unique id of the object.""" + return obj["name"].rpartition("/")[-1] + def translate_text( self, *, @@ -208,12 +267,10 @@ def translate_text( If not specified, 'global' is used. Non-global location is required for requests using AutoML models or custom glossaries. - Models and glossaries must be within the same region (have the same location-id). :param model: Optional. The ``model`` type requested for this translation. If not provided, the default Google model (NMT) will be used.
- The format depends on model type: - AutoML Translation models: @@ -308,8 +365,8 @@ def batch_translate_text( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. - :returns: Operation object with the batch text translate results, - that are returned by batches as they are ready. + :return: Operation object with the batch text translate results, + that are returned by batches as they are ready. """ client = self.get_client() if location == "global": @@ -334,3 +391,174 @@ def batch_translate_text( metadata=metadata, ) return result + + def create_dataset( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + dataset: dict | automl_translation.Dataset, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + retry: Retry | _MethodDefault | None = DEFAULT, + ) -> Operation: + """ + Create the translation dataset. + + :param dataset: The dataset to create. If a dict is provided, it must correspond to + the automl_translation.Dataset type. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object for the dataset to be created. + """ + client = self.get_client() + parent = f"projects/{project_id or self.project_id}/locations/{location}" + return client.create_dataset( + request={"parent": parent, "dataset": dataset}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def get_dataset( + self, + dataset_id: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + ) -> automl_translation.Dataset: + """ + Retrieve the dataset for the given dataset_id. + + :param dataset_id: ID of translation dataset to be retrieved. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `automl_translation.Dataset` instance. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + return client.get_dataset( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def import_dataset_data( + self, + dataset_id: str, + location: str, + input_config: dict | DatasetInputConfig, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Import data into the translation dataset. + + :param dataset_id: ID of the translation dataset. 
+ :param input_config: The desired input location and its domain specific semantics, if any. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object for the import data. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.import_data( + request={"dataset": name, "input_config": input_config}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def list_datasets( + self, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: + """ + List translation datasets in a project. + + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: ``pagers.ListDatasetsPager`` instance, iterable object to retrieve the datasets list. + """ + client = self.get_client() + parent = f"projects/{project_id}/locations/{location}" + result = client.list_datasets( + request={"parent": parent}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def delete_dataset( + self, + dataset_id: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete the translation dataset and all of its contents. + + :param dataset_id: ID of dataset to be deleted. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object with dataset deletion results, when finished. 
+ """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.delete_dataset( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/providers/src/airflow/providers/google/cloud/links/translate.py b/providers/src/airflow/providers/google/cloud/links/translate.py index d8cbd18d00dad..0d1489ddcfc81 100644 --- a/providers/src/airflow/providers/google/cloud/links/translate.py +++ b/providers/src/airflow/providers/google/cloud/links/translate.py @@ -45,6 +45,11 @@ TRANSLATION_TRANSLATE_TEXT_BATCH = BASE_LINK + "/storage/browser/{output_uri_prefix}?project={project_id}" +TRANSLATION_NATIVE_DATASET_LINK = ( + TRANSLATION_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/sentences?project={project_id}" +) +TRANSLATION_NATIVE_LIST_LINK = TRANSLATION_BASE_LINK + "/datasets?project={project_id}" + class TranslationLegacyDatasetLink(BaseGoogleLink): """ @@ -214,3 +219,54 @@ def persist( "output_uri_prefix": TranslateTextBatchLink.extract_output_uri_prefix(output_config), }, ) + + +class TranslationNativeDatasetLink(BaseGoogleLink): + """ + Helper class for constructing Legacy Translation Dataset link. + + Legacy Datasets are created and managed by AutoML API. + """ + + name = "Translation Native Dataset" + key = "translation_naive_dataset" + format_str = TRANSLATION_NATIVE_DATASET_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + dataset_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationNativeDatasetLink.key, + value={"location": task_instance.location, "dataset_id": dataset_id, "project_id": project_id}, + ) + + +class TranslationDatasetsListLink(BaseGoogleLink): + """ + Helper class for constructing Translation Datasets List link. + + Both legacy and native datasets are available under this link. 
+ """ + + name = "Translation Dataset List" + key = "translation_dataset_list" + format_str = TRANSLATION_DATASET_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationDatasetsListLink.key, + value={ + "project_id": project_id, + }, + ) diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py index 82172d5d241c9..403023f7b4315 100644 --- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py +++ b/providers/src/airflow/providers/google/cloud/openlineage/utils.py @@ -27,6 +27,7 @@ from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.common.compat.openlineage.facet import ( + BaseFacet, ColumnLineageDatasetFacet, DocumentationDatasetFacet, Fields, @@ -41,50 +42,82 @@ BIGQUERY_URI = "bigquery" -def get_facets_from_bq_table(table: Table) -> dict[Any, Any]: +def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]: """Get facets from BigQuery table object.""" - facets = { - "schema": SchemaDatasetFacet( + facets: dict[str, BaseFacet] = {} + if table.schema: + facets["schema"] = SchemaDatasetFacet( fields=[ SchemaDatasetFacetFields( - name=field.name, type=field.field_type, description=field.description + name=schema_field.name, type=schema_field.field_type, description=schema_field.description ) - for field in table.schema + for schema_field in table.schema ] - ), - "documentation": DocumentationDatasetFacet(description=table.description or ""), - } + ) + if table.description: + facets["documentation"] = DocumentationDatasetFacet(description=table.description) return facets def get_identity_column_lineage_facet( - field_names: list[str], + dest_field_names: list[str], input_datasets: list[Dataset], -) -> ColumnLineageDatasetFacet: +) -> dict[str, ColumnLineageDatasetFacet]: """ - Get column lineage facet. - - Simple lineage will be created, where each source column corresponds to single destination column - in each input dataset and there are no transformations made. + Get column lineage facet for identity transformations. + + This function generates a simple column lineage facet, where each destination column + consists of source columns of the same name from all input datasets that have that column. + The lineage assumes there are no transformations applied, meaning the columns retain their + identity between the source and destination datasets. + + Args: + dest_field_names: A list of destination column names for which lineage should be determined. + input_datasets: A list of input datasets with schema facets. + + Returns: + A dictionary containing a single key, `columnLineage`, mapped to a `ColumnLineageDatasetFacet`. + If no column lineage can be determined, an empty dictionary is returned - see Notes below. + + Notes: + - If any input dataset lacks a schema facet, the function immediately returns an empty dictionary. + - If any field in the source dataset's schema is not present in the destination table, + the function returns an empty dictionary. The destination table can contain extra fields, but all + source columns should be present in the destination table. + - If none of the destination columns can be matched to input dataset columns, an empty + dictionary is returned. 
+ - Extra columns in the destination table that do not exist in the input datasets are ignored and + skipped in the lineage facet, as they cannot be traced back to a source column. + - The function assumes there are no transformations applied, meaning the columns retain their + identity between the source and destination datasets. """ - if field_names and not input_datasets: - raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.") + fields_sources: dict[str, list[Dataset]] = {} + for ds in input_datasets: + if not ds.facets or "schema" not in ds.facets: + return {} + for schema_field in ds.facets["schema"].fields: # type: ignore[attr-defined] + if schema_field.name not in dest_field_names: + return {} + fields_sources[schema_field.name] = fields_sources.get(schema_field.name, []) + [ds] + + if not fields_sources: + return {} column_lineage_facet = ColumnLineageDatasetFacet( fields={ - field: Fields( + field_name: Fields( inputFields=[ - InputField(namespace=dataset.namespace, name=dataset.name, field=field) - for dataset in input_datasets + InputField(namespace=dataset.namespace, name=dataset.name, field=field_name) + for dataset in source_datasets ], transformationType="IDENTITY", transformationDescription="identical", ) - for field in field_names + for field_name, source_datasets in fields_sources.items() } ) - return column_lineage_facet + return {"columnLineage": column_lineage_facet} @define diff --git a/providers/src/airflow/providers/google/cloud/operators/dataplex.py b/providers/src/airflow/providers/google/cloud/operators/dataplex.py index 04edc10795f02..f77c648f20e18 100644 --- a/providers/src/airflow/providers/google/cloud/operators/dataplex.py +++ b/providers/src/airflow/providers/google/cloud/operators/dataplex.py @@ -686,39 +686,44 @@ def execute(self, context: Context): impersonation_chain=self.impersonation_chain, ) - self.log.info("Creating Dataplex Data Quality scan %s", self.data_scan_id) - try: - operation = hook.create_data_scan( - project_id=self.project_id, - region=self.region, - data_scan_id=self.data_scan_id, - body=self.body, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - hook.wait_for_operation(timeout=self.timeout, operation=operation) - self.log.info("Dataplex Data Quality scan %s created successfully!", self.data_scan_id) - except AlreadyExists: - self.log.info("Dataplex Data Quality scan already exists: %s", {self.data_scan_id}) - - operation = hook.update_data_scan( - project_id=self.project_id, - region=self.region, - data_scan_id=self.data_scan_id, - body=self.body, - update_mask=self.update_mask, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - hook.wait_for_operation(timeout=self.timeout, operation=operation) - self.log.info("Dataplex Data Quality scan %s updated successfully!", self.data_scan_id) - except GoogleAPICallError as e: - raise AirflowException(f"Error creating Data Quality scan {self.data_scan_id}", e) + if self.update_mask is not None: + self._update_data_scan(hook) + else: + self.log.info("Creating Dataplex Data Quality scan %s", self.data_scan_id) + try: + operation = hook.create_data_scan( + project_id=self.project_id, + region=self.region, + data_scan_id=self.data_scan_id, + body=self.body, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Dataplex Data Quality scan %s created successfully!", self.data_scan_id) + except 
AlreadyExists: + self._update_data_scan(hook) + except GoogleAPICallError as e: + raise AirflowException(f"Error creating Data Quality scan {self.data_scan_id}", e) return self.data_scan_id + def _update_data_scan(self, hook: DataplexHook): + self.log.info("Dataplex Data Quality scan already exists: %s", {self.data_scan_id}) + operation = hook.update_data_scan( + project_id=self.project_id, + region=self.region, + data_scan_id=self.data_scan_id, + body=self.body, + update_mask=self.update_mask, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation(timeout=self.timeout, operation=operation) + self.log.info("Dataplex Data Quality scan %s updated successfully!", self.data_scan_id) + class DataplexGetDataQualityScanOperator(GoogleCloudBaseOperator): """ diff --git a/providers/src/airflow/providers/google/cloud/operators/translate.py b/providers/src/airflow/providers/google/cloud/operators/translate.py index a0fa9243e01a4..d384e9b8efa9b 100644 --- a/providers/src/airflow/providers/google/cloud/operators/translate.py +++ b/providers/src/airflow/providers/google/cloud/operators/translate.py @@ -26,17 +26,23 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook, TranslateHook -from airflow.providers.google.cloud.links.translate import TranslateTextBatchLink +from airflow.providers.google.cloud.links.translate import ( + TranslateTextBatchLink, + TranslationDatasetsListLink, + TranslationNativeDatasetLink, +) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID if TYPE_CHECKING: from google.api_core.retry import Retry from google.cloud.translate_v3.types import ( + DatasetInputConfig, InputConfig, OutputConfig, TranslateTextGlossaryConfig, TransliterationConfig, + automl_translation, ) from airflow.utils.context import Context @@ -266,7 +272,7 @@ class TranslateTextBatchOperator(GoogleCloudBaseOperator): See https://cloud.google.com/translate/docs/advanced/batch-translation For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:TranslateTextBatchOperator`. + :ref:`howto/operator:TranslateTextBatchOperator`. :param project_id: Optional. The ID of the Google Cloud project that the service belongs to. If not specified the hook project_id will be used. @@ -381,6 +387,338 @@ def execute(self, context: Context) -> dict: project_id=self.project_id or hook.project_id, output_config=self.output_config, ) - hook.wait_for_operation(translate_operation) + hook.wait_for_operation_result(translate_operation) self.log.info("Translate text batch job finished") return {"batch_text_translate_results": self.output_config["gcs_destination"]} + + +class TranslateCreateDatasetOperator(GoogleCloudBaseOperator): + """ + Create a Google Cloud Translate dataset. + + Creates a `native` translation dataset, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateCreateDatasetOperator`. + + :param dataset: The dataset to create. If a dict is provided, it must correspond to + the automl_translation.Dataset type. + :param project_id: ID of the Google Cloud project where dataset is located. + If not provided default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. 
+ :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationNativeDatasetLink(),) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + dataset: dict | automl_translation.Dataset, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault | None = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.dataset = dataset + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> str: + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Dataset creation started %s...", self.dataset) + result_operation = hook.create_dataset( + dataset=self.dataset, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = hook.wait_for_operation_result(result_operation) + result = type(result).to_dict(result) + dataset_id = hook.extract_object_id(result) + self.xcom_push(context, key="dataset_id", value=dataset_id) + self.log.info("Dataset creation complete. The dataset_id: %s.", dataset_id) + + project_id = self.project_id or hook.project_id + TranslationNativeDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=dataset_id, + project_id=project_id, + ) + return result + + +class TranslateDatasetsListOperator(GoogleCloudBaseOperator): + """ + Get a list of native Google Cloud Translation datasets in a project. + + Get project's list of `native` translation datasets, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateDatasetsListOperator`. + + :param project_id: ID of the Google Cloud project where dataset is located. + If not provided default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationDatasetsListLink(),) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + project_id = self.project_id or hook.project_id + TranslationDatasetsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + self.log.info("Requesting datasets list") + results_pager = hook.list_datasets( + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result_ids = [] + for ds_item in results_pager: + ds_data = type(ds_item).to_dict(ds_item) + ds_id = hook.extract_object_id(ds_data) + result_ids.append(ds_id) + + self.log.info("Fetching the datasets list complete.") + return result_ids + + +class TranslateImportDataOperator(GoogleCloudBaseOperator): + """ + Import data to the translation dataset. + + Loads data to the translation dataset, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateImportDataOperator`. + + :param dataset_id: The dataset_id of target native dataset to import data to. + :param input_config: The desired input location of translations language pairs file. If a dict provided, + must follow the structure of DatasetInputConfig. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset_id", + "input_config", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationNativeDatasetLink(),) + + def __init__( + self, + *, + dataset_id: str, + location: str, + input_config: dict | DatasetInputConfig, + project_id: str = PROVIDE_PROJECT_ID, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.input_config = input_config + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Importing data to dataset...") + operation = hook.import_dataset_data( + dataset_id=self.dataset_id, + input_config=self.input_config, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + project_id = self.project_id or hook.project_id + TranslationNativeDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) + hook.wait_for_operation_done(operation=operation, timeout=self.timeout) + self.log.info("Importing data finished!") + + +class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator): + """ + Delete translation dataset and all of its contents. + + Deletes the translation dataset and it's data, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateDeleteDatasetOperator`. + + :param dataset_id: The dataset_id of target native dataset to be deleted. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "dataset_id", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + dataset_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Deleting the dataset %s...", self.dataset_id) + operation = hook.delete_dataset( + dataset_id=self.dataset_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation_done(operation=operation, timeout=self.timeout) + self.log.info("Dataset deletion complete!") diff --git a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py index 7be147d09b607..e1f3d3b13f565 100644 --- a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +++ b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py @@ -110,6 +110,7 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.hook: BigQueryHook | None = None + self._job_conf: dict = {} def _prepare_job_configuration(self): self.source_project_dataset_tables = ( @@ -154,39 +155,94 @@ def _prepare_job_configuration(self): return configuration - def _submit_job( - self, - hook: BigQueryHook, - configuration: dict, - ) -> str: - job = hook.insert_job(configuration=configuration, project_id=hook.project_id) - return job.job_id - def execute(self, context: Context) -> None: self.log.info( "Executing copy of %s into: %s", self.source_project_dataset_tables, self.destination_project_dataset_table, ) - hook = BigQueryHook( + self.hook = BigQueryHook( gcp_conn_id=self.gcp_conn_id, location=self.location, impersonation_chain=self.impersonation_chain, ) - self.hook = hook - if not hook.project_id: + if not self.hook.project_id: raise ValueError("The project_id should be set") configuration = self._prepare_job_configuration() - job_id = self._submit_job(hook=hook, configuration=configuration) + self._job_conf = self.hook.insert_job( + configuration=configuration, project_id=self.hook.project_id + ).to_api_repr() - job = hook.get_job(job_id=job_id, location=self.location).to_api_repr() - conf = job["configuration"]["copy"]["destinationTable"] + dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"] BigQueryTableLink.persist( context=context, task_instance=self, - dataset_id=conf["datasetId"], - project_id=conf["projectId"], - table_id=conf["tableId"], + dataset_id=dest_table_info["datasetId"], + project_id=dest_table_info["projectId"], + table_id=dest_table_info["tableId"], + ) + + def get_openlineage_facets_on_complete(self, task_instance): + """Implement on_complete as we will include final BQ job id.""" + from 
airflow.providers.common.compat.openlineage.facet import ( + Dataset, + ExternalQueryRunFacet, + ) + from airflow.providers.google.cloud.openlineage.utils import ( + BIGQUERY_NAMESPACE, + get_facets_from_bq_table, + get_identity_column_lineage_facet, ) + from airflow.providers.openlineage.extractors import OperatorLineage + + if not self.hook: + self.hook = BigQueryHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + if not self._job_conf: + self.log.debug("OpenLineage could not find BQ job configuration.") + return OperatorLineage() + + bq_job_id = self._job_conf["jobReference"]["jobId"] + source_tables_info = self._job_conf["configuration"]["copy"]["sourceTables"] + dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"] + + run_facets = { + "externalQuery": ExternalQueryRunFacet(externalQueryId=bq_job_id, source="bigquery"), + } + + input_datasets = [] + for in_table_info in source_tables_info: + table_id = ".".join( + (in_table_info["projectId"], in_table_info["datasetId"], in_table_info["tableId"]) + ) + table_object = self.hook.get_client().get_table(table_id) + input_datasets.append( + Dataset( + namespace=BIGQUERY_NAMESPACE, name=table_id, facets=get_facets_from_bq_table(table_object) + ) + ) + + out_table_id = ".".join( + (dest_table_info["projectId"], dest_table_info["datasetId"], dest_table_info["tableId"]) + ) + out_table_object = self.hook.get_client().get_table(out_table_id) + output_dataset_facets = { + **get_facets_from_bq_table(out_table_object), + **get_identity_column_lineage_facet( + dest_field_names=[field.name for field in out_table_object.schema], + input_datasets=input_datasets, + ), + } + output_dataset = Dataset( + namespace=BIGQUERY_NAMESPACE, + name=out_table_id, + facets=output_dataset_facets, + ) + + return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets) diff --git a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index e2588b8976e33..2833f79a3e81a 100644 --- a/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/providers/src/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -294,6 +294,7 @@ def get_openlineage_facets_on_complete(self, task_instance): from pathlib import Path from airflow.providers.common.compat.openlineage.facet import ( + BaseFacet, Dataset, ExternalQueryRunFacet, Identifier, @@ -322,12 +323,12 @@ def get_openlineage_facets_on_complete(self, task_instance): facets=get_facets_from_bq_table(table_object), ) - output_dataset_facets = { - "schema": input_dataset.facets["schema"], - "columnLineage": get_identity_column_lineage_facet( - field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset] - ), - } + output_dataset_facets: dict[str, BaseFacet] = get_identity_column_lineage_facet( + dest_field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset] + ) + if "schema" in input_dataset.facets: + output_dataset_facets["schema"] = input_dataset.facets["schema"] + output_datasets = [] for uri in sorted(self.destination_cloud_storage_uris): bucket, blob = _parse_gcs_url(uri) diff --git a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py index 06b9a94171b10..6dfe1bd1a2c31 100644 --- 
a/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py +++ b/providers/src/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py @@ -784,9 +784,10 @@ def get_openlineage_facets_on_complete(self, task_instance): source_objects = ( self.source_objects if isinstance(self.source_objects, list) else [self.source_objects] ) - input_dataset_facets = { - "schema": output_dataset_facets["schema"], - } + input_dataset_facets = {} + if "schema" in output_dataset_facets: + input_dataset_facets["schema"] = output_dataset_facets["schema"] + input_datasets = [] for blob in sorted(source_objects): additional_facets = {} @@ -811,14 +812,16 @@ def get_openlineage_facets_on_complete(self, task_instance): ) input_datasets.append(dataset) - output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet( - field_names=[field.name for field in table_object.schema], input_datasets=input_datasets - ) - output_dataset = Dataset( namespace="bigquery", name=str(table_object.reference), - facets=output_dataset_facets, + facets={ + **output_dataset_facets, + **get_identity_column_lineage_facet( + dest_field_names=[field.name for field in table_object.schema], + input_datasets=input_datasets, + ), + }, ) run_facets = {} diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index ce8ce057432d9..b0107286ca7ca 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -150,7 +150,7 @@ dependencies: - google-cloud-storage-transfer>=1.4.1 - google-cloud-tasks>=2.13.0 - google-cloud-texttospeech>=2.14.1 - - google-cloud-translate>=3.11.0 + - google-cloud-translate>=3.16.0 - google-cloud-videointelligence>=2.11.0 - google-cloud-vision>=3.4.0 - google-cloud-workflows>=1.10.0 @@ -1292,6 +1292,8 @@ extra-links: - airflow.providers.google.cloud.links.translate.TranslationLegacyModelTrainLink - airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink - airflow.providers.google.cloud.links.translate.TranslateTextBatchLink + - airflow.providers.google.cloud.links.translate.TranslationNativeDatasetLink + - airflow.providers.google.cloud.links.translate.TranslationDatasetsListLink secrets-backends: diff --git a/providers/tests/edge/executors/test_edge_executor.py b/providers/tests/edge/executors/test_edge_executor.py index 7970e5fad04ce..126afa1fb70b5 100644 --- a/providers/tests/edge/executors/test_edge_executor.py +++ b/providers/tests/edge/executors/test_edge_executor.py @@ -16,11 +16,12 @@ # under the License. 
from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta from unittest.mock import patch import pytest +from airflow.configuration import conf from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.edge.executors.edge_executor import EdgeExecutor from airflow.providers.edge.models.edge_job import EdgeJobModel @@ -65,6 +66,44 @@ def test_execute_async_ok_command(self): assert jobs[0].run_id == "test_run" assert jobs[0].task_id == "test_task" + def test_sync_orphaned_tasks(self): + executor = EdgeExecutor() + + delta_to_purge = timedelta(minutes=conf.getint("edge", "job_fail_purge") + 1) + delta_to_orphaned = timedelta(seconds=conf.getint("scheduler", "scheduler_zombie_task_threshold") + 1) + + with create_session() as session: + for task_id, state, last_update in [ + ( + "started_running_orphaned", + TaskInstanceState.RUNNING, + timezone.utcnow() - delta_to_orphaned, + ), + ("started_removed", TaskInstanceState.REMOVED, timezone.utcnow() - delta_to_purge), + ]: + session.add( + EdgeJobModel( + dag_id="test_dag", + task_id=task_id, + run_id="test_run", + map_index=-1, + try_number=1, + state=state, + queue="default", + command="dummy", + last_update=last_update, + ) + ) + session.commit() + + executor.sync() + + with create_session() as session: + jobs = session.query(EdgeJobModel).all() + assert len(jobs) == 1 + assert jobs[0].task_id == "started_running_orphaned" + assert jobs[0].task_id == "started_running_orphaned" + @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.running_state") @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.success") @patch("airflow.providers.edge.executors.edge_executor.EdgeExecutor.fail") @@ -77,12 +116,14 @@ def remove_from_running(key: TaskInstanceKey): mock_success.side_effect = remove_from_running mock_fail.side_effect = remove_from_running + delta_to_purge = timedelta(minutes=conf.getint("edge", "job_fail_purge") + 1) + # Prepare some data with create_session() as session: - for task_id, state in [ - ("started_running", TaskInstanceState.RUNNING), - ("started_success", TaskInstanceState.SUCCESS), - ("started_failed", TaskInstanceState.FAILED), + for task_id, state, last_update in [ + ("started_running", TaskInstanceState.RUNNING, timezone.utcnow()), + ("started_success", TaskInstanceState.SUCCESS, timezone.utcnow() - delta_to_purge), + ("started_failed", TaskInstanceState.FAILED, timezone.utcnow() - delta_to_purge), ]: session.add( EdgeJobModel( @@ -94,7 +135,7 @@ def remove_from_running(key: TaskInstanceKey): state=state, queue="default", command="dummy", - last_update=timezone.utcnow(), + last_update=last_update, ) ) key = TaskInstanceKey( @@ -106,6 +147,12 @@ def remove_from_running(key: TaskInstanceKey): executor.sync() + with create_session() as session: + jobs = session.query(EdgeJobModel).all() + assert len(session.query(EdgeJobModel).all()) == 1 + assert jobs[0].task_id == "started_running" + assert jobs[0].state == TaskInstanceState.RUNNING + assert len(executor.running) == 1 mock_running_state.assert_called_once() mock_success.assert_called_once() diff --git a/providers/tests/google/cloud/hooks/test_bigquery.py b/providers/tests/google/cloud/hooks/test_bigquery.py index b0e7f8efb2098..ee0f904bb94b6 100644 --- a/providers/tests/google/cloud/hooks/test_bigquery.py +++ b/providers/tests/google/cloud/hooks/test_bigquery.py @@ -1604,7 +1604,7 @@ async def test_create_job_for_partition_get_with_table(self, mock_job_instance, 
expected_query_request = { "query": "SELECT partition_id " f"FROM `{PROJECT_ID}.{DATASET_ID}.INFORMATION_SCHEMA.PARTITIONS`" - f" WHERE table_id={TABLE_ID}", + f" WHERE table_name='{TABLE_ID}'", "useLegacySql": False, } await hook.create_job_for_partition_get( diff --git a/providers/tests/google/cloud/openlineage/test_utils.py b/providers/tests/google/cloud/openlineage/test_utils.py index 4f2db0038b7b7..e3f40bee1549e 100644 --- a/providers/tests/google/cloud/openlineage/test_utils.py +++ b/providers/tests/google/cloud/openlineage/test_utils.py @@ -19,7 +19,6 @@ import json from unittest.mock import MagicMock -import pytest from google.cloud.bigquery.table import Table from airflow.providers.common.compat.openlineage.facet import ( @@ -89,19 +88,78 @@ def test_get_facets_from_bq_table(): def test_get_facets_from_empty_bq_table(): - expected_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - } result = get_facets_from_bq_table(TEST_EMPTY_TABLE) - assert result == expected_facets + assert result == {} + + +def test_get_identity_column_lineage_facet_source_datasets_schemas_are_subsets(): + field_names = ["field1", "field2", "field3"] + input_datasets = [ + Dataset( + namespace="gs://first_bucket", + name="dir1", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + ] + ) + }, + ), + Dataset( + namespace="gs://second_bucket", + name="dir2", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field2", type="STRING"), + ] + ) + }, + ), + ] + expected_facet = ColumnLineageDatasetFacet( + fields={ + "field1": Fields( + inputFields=[ + InputField( + namespace="gs://first_bucket", + name="dir1", + field="field1", + ) + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + "field2": Fields( + inputFields=[ + InputField( + namespace="gs://second_bucket", + name="dir2", + field="field2", + ), + ], + transformationType="IDENTITY", + transformationDescription="identical", + ), + # field3 is missing here as it's not present in any source dataset + } + ) + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {"columnLineage": expected_facet} def test_get_identity_column_lineage_facet_multiple_input_datasets(): field_names = ["field1", "field2"] + schema_facet = SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", type="STRING"), + ] + ) input_datasets = [ - Dataset(namespace="gs://first_bucket", name="dir1"), - Dataset(namespace="gs://second_bucket", name="dir2"), + Dataset(namespace="gs://first_bucket", name="dir1", facets={"schema": schema_facet}), + Dataset(namespace="gs://second_bucket", name="dir2", facets={"schema": schema_facet}), ] expected_facet = ColumnLineageDatasetFacet( fields={ @@ -139,24 +197,69 @@ def test_get_identity_column_lineage_facet_multiple_input_datasets(): ), } ) - result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) - assert result == expected_facet + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {"columnLineage": expected_facet} + + +def test_get_identity_column_lineage_facet_dest_cols_not_in_input_datasets(): + field_names = ["x", "y"] + input_datasets = [ + Dataset( + namespace="gs://first_bucket", + name="dir1", + facets={ 
+ "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + ] + ) + }, + ), + Dataset( + namespace="gs://second_bucket", + name="dir2", + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field2", type="STRING"), + ] + ) + }, + ), + ] + + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {} + + +def test_get_identity_column_lineage_facet_no_schema_in_input_dataset(): + field_names = ["field1", "field2"] + input_datasets = [ + Dataset(namespace="gs://first_bucket", name="dir1"), + ] + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {} def test_get_identity_column_lineage_facet_no_field_names(): field_names = [] + schema_facet = SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", type="STRING"), + ] + ) input_datasets = [ - Dataset(namespace="gs://first_bucket", name="dir1"), - Dataset(namespace="gs://second_bucket", name="dir2"), + Dataset(namespace="gs://first_bucket", name="dir1", facets={"schema": schema_facet}), + Dataset(namespace="gs://second_bucket", name="dir2", facets={"schema": schema_facet}), ] - expected_facet = ColumnLineageDatasetFacet(fields={}) - result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) - assert result == expected_facet + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {} def test_get_identity_column_lineage_facet_no_input_datasets(): field_names = ["field1", "field2"] input_datasets = [] - with pytest.raises(ValueError): - get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets) + result = get_identity_column_lineage_facet(dest_field_names=field_names, input_datasets=input_datasets) + assert result == {} diff --git a/providers/tests/google/cloud/operators/test_dataplex.py b/providers/tests/google/cloud/operators/test_dataplex.py index 67c9b8ca10f9f..1eec9008e2c10 100644 --- a/providers/tests/google/cloud/operators/test_dataplex.py +++ b/providers/tests/google/cloud/operators/test_dataplex.py @@ -672,6 +672,18 @@ def test_execute(self, hook_mock): api_version=API_VERSION, impersonation_chain=IMPERSONATION_CHAIN, ) + update_operator = DataplexCreateOrUpdateDataQualityScanOperator( + task_id=TASK_ID, + project_id=PROJECT_ID, + region=REGION, + data_scan_id=DATA_SCAN_ID, + body={}, + update_mask={}, + api_version=API_VERSION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + update_operator.execute(context=mock.MagicMock()) hook_mock.return_value.create_data_scan.assert_called_once_with( project_id=PROJECT_ID, region=REGION, @@ -681,6 +693,16 @@ def test_execute(self, hook_mock): timeout=None, metadata=(), ) + hook_mock.return_value.update_data_scan.assert_called_once_with( + project_id=PROJECT_ID, + region=REGION, + data_scan_id=DATA_SCAN_ID, + body={}, + update_mask={}, + retry=DEFAULT, + timeout=None, + metadata=(), + ) class TestDataplexCreateDataProfileScanOperator: diff --git a/providers/tests/google/cloud/operators/test_translate.py b/providers/tests/google/cloud/operators/test_translate.py index 79f65395369b8..45af2dae92890 100644 --- a/providers/tests/google/cloud/operators/test_translate.py +++ b/providers/tests/google/cloud/operators/test_translate.py @@ -19,15 +19,27 @@ from 
unittest import mock +from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.translate_v3.types import automl_translation + +from airflow.providers.google.cloud.hooks.translate import TranslateHook from airflow.providers.google.cloud.operators.translate import ( CloudTranslateTextOperator, + TranslateCreateDatasetOperator, + TranslateDatasetsListOperator, + TranslateDeleteDatasetOperator, + TranslateImportDataOperator, TranslateTextBatchOperator, TranslateTextOperator, ) +from providers.tests.system.google.cloud.tasks.example_tasks import LOCATION + GCP_CONN_ID = "google_cloud_default" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] PROJECT_ID = "test-project-id" +DATASET_ID = "sample_ds_id" +TIMEOUT_VALUE = 30 class TestCloudTranslate: @@ -97,7 +109,7 @@ def test_minimal_green_path(self, mock_hook): target_language_code="en", gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, - timeout=30, + timeout=TIMEOUT_VALUE, retry=None, model=None, ) @@ -117,7 +129,7 @@ def test_minimal_green_path(self, mock_hook): model=None, transliteration_config=None, glossary_config=None, - timeout=30, + timeout=TIMEOUT_VALUE, retry=None, metadata=(), ) @@ -185,3 +197,192 @@ def test_minimal_green_path(self, mock_hook, mock_link_persist): project_id=PROJECT_ID, output_config=OUTPUT_CONFIG, ) + + +class TestTranslateDatasetCreate: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationNativeDatasetLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator.xcom_push") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): + DS_CREATION_RESULT_SAMPLE = { + "display_name": "", + "example_count": 0, + "name": f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}", + "source_language_code": "", + "target_language_code": "", + "test_example_count": 0, + "train_example_count": 0, + "validate_example_count": 0, + } + sample_operation = mock.MagicMock() + sample_operation.result.return_value = automl_translation.Dataset(DS_CREATION_RESULT_SAMPLE) + + mock_hook.return_value.create_dataset.return_value = sample_operation + mock_hook.return_value.wait_for_operation_result.side_effect = lambda operation: operation.result() + mock_hook.return_value.extract_object_id = TranslateHook.extract_object_id + + DATASET_DATA = { + "display_name": "sample ds name", + "source_language_code": "es", + "target_language_code": "uk", + } + op = TranslateCreateDatasetOperator( + task_id="task_id", + dataset=DATASET_DATA, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=None, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.create_dataset.assert_called_once_with( + dataset=DATASET_DATA, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=None, + metadata=(), + ) + mock_xcom_push.assert_called_once_with(context, key="dataset_id", value=DATASET_ID) + mock_link_persist.assert_called_once_with( + context=context, + dataset_id=DATASET_ID, + task_instance=op, + project_id=PROJECT_ID, + ) + assert result == DS_CREATION_RESULT_SAMPLE + + +class TestTranslateListDatasets: + 
@mock.patch("airflow.providers.google.cloud.operators.translate.TranslationDatasetsListLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_link_persist): + DS_ID_1 = "sample_ds_1" + DS_ID_2 = "sample_ds_2" + dataset_result_1 = automl_translation.Dataset( + dict( + name=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DS_ID_1}", + display_name="ds1_display_name", + ) + ) + dataset_result_2 = automl_translation.Dataset( + dict( + name=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DS_ID_2}", + display_name="ds1_display_name", + ) + ) + mock_hook.return_value.list_datasets.return_value = [dataset_result_1, dataset_result_2] + mock_hook.return_value.extract_object_id = TranslateHook.extract_object_id + + op = TranslateDatasetsListOperator( + task_id="task_id", + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.list_datasets.assert_called_once_with( + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + mock_link_persist.assert_called_once_with( + context=context, + task_instance=op, + project_id=PROJECT_ID, + ) + assert result == [DS_ID_1, DS_ID_2] + + +class TestTranslateImportData: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationNativeDatasetLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_link_persist): + INPUT_CONFIG = { + "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "import data gcs path"}}] + } + mock_hook.return_value.import_dataset_data.return_value = mock.MagicMock() + op = TranslateImportDataOperator( + task_id="task_id", + dataset_id=DATASET_ID, + input_config=INPUT_CONFIG, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.import_dataset_data.assert_called_once_with( + dataset_id=DATASET_ID, + input_config=INPUT_CONFIG, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + mock_link_persist.assert_called_once_with( + context=context, + dataset_id=DATASET_ID, + task_instance=op, + project_id=PROJECT_ID, + ) + + +class TestTranslateDeleteData: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook): + m_delete_method_result = mock.MagicMock() + mock_hook.return_value.delete_dataset.return_value = m_delete_method_result + + wait_for_done = mock_hook.return_value.wait_for_operation_done + + op = TranslateDeleteDatasetOperator( + task_id="task_id", + dataset_id=DATASET_ID, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + op.execute(context=context) + mock_hook.assert_called_once_with( + 
gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.delete_dataset.assert_called_once_with( + dataset_id=DATASET_ID, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + wait_for_done.assert_called_once_with(operation=m_delete_method_result, timeout=TIMEOUT_VALUE) diff --git a/providers/tests/google/cloud/transfers/test_bigquery_to_bigquery.py b/providers/tests/google/cloud/transfers/test_bigquery_to_bigquery.py index ed06928c2ccff..304694126ffc1 100644 --- a/providers/tests/google/cloud/transfers/test_bigquery_to_bigquery.py +++ b/providers/tests/google/cloud/transfers/test_bigquery_to_bigquery.py @@ -19,16 +19,29 @@ from unittest import mock +from google.cloud.bigquery import Table + +from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Dataset, + DocumentationDatasetFacet, + ExternalQueryRunFacet, + Fields, + InputField, + SchemaDatasetFacet, + SchemaDatasetFacetFields, +) from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator BQ_HOOK_PATH = "airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryHook" -TASK_ID = "test-bq-create-table-operator" +TASK_ID = "test-bq-to-bq-operator" TEST_GCP_PROJECT_ID = "test-project" TEST_DATASET = "test-dataset" TEST_TABLE_ID = "test-table-id" -SOURCE_PROJECT_DATASET_TABLES = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" -DESTINATION_PROJECT_DATASET_TABLE = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET + '_new'}.{TEST_TABLE_ID}" +SOURCE_PROJECT_DATASET_TABLE = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" +SOURCE_PROJECT_DATASET_TABLE2 = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}-2" +DESTINATION_PROJECT_DATASET_TABLE = f"{TEST_GCP_PROJECT_ID}.{TEST_DATASET}_new.{TEST_TABLE_ID}" WRITE_DISPOSITION = "WRITE_EMPTY" CREATE_DISPOSITION = "CREATE_IF_NEEDED" LABELS = {"k1": "v1"} @@ -36,12 +49,18 @@ def split_tablename_side_effect(*args, **kwargs): - if kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLES: + if kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLE: return ( TEST_GCP_PROJECT_ID, TEST_DATASET, TEST_TABLE_ID, ) + elif kwargs["table_input"] == SOURCE_PROJECT_DATASET_TABLE2: + return ( + TEST_GCP_PROJECT_ID, + TEST_DATASET, + TEST_TABLE_ID + "-2", + ) elif kwargs["table_input"] == DESTINATION_PROJECT_DATASET_TABLE: return ( TEST_GCP_PROJECT_ID, @@ -55,7 +74,7 @@ class TestBigQueryToBigQueryOperator: def test_execute_without_location_should_execute_successfully(self, mock_hook): operator = BigQueryToBigQueryOperator( task_id=TASK_ID, - source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLES, + source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLE, destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE, write_disposition=WRITE_DISPOSITION, create_disposition=CREATE_DISPOSITION, @@ -95,7 +114,48 @@ def test_execute_single_regional_location_should_execute_successfully(self, mock operator = BigQueryToBigQueryOperator( task_id=TASK_ID, - source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLES, + source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLE, + destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE, + write_disposition=WRITE_DISPOSITION, + create_disposition=CREATE_DISPOSITION, + labels=LABELS, + encryption_configuration=ENCRYPTION_CONFIGURATION, + location=location, + ) + + mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect + 
operator.execute(context=mock.MagicMock()) + mock_hook.return_value.insert_job.assert_called_once_with( + configuration={ + "copy": { + "createDisposition": CREATE_DISPOSITION, + "destinationEncryptionConfiguration": ENCRYPTION_CONFIGURATION, + "destinationTable": { + "datasetId": TEST_DATASET + "_new", + "projectId": TEST_GCP_PROJECT_ID, + "tableId": TEST_TABLE_ID, + }, + "sourceTables": [ + { + "datasetId": TEST_DATASET, + "projectId": TEST_GCP_PROJECT_ID, + "tableId": TEST_TABLE_ID, + }, + ], + "writeDisposition": WRITE_DISPOSITION, + }, + "labels": LABELS, + }, + project_id=mock_hook.return_value.project_id, + ) + + @mock.patch(BQ_HOOK_PATH) + def test_get_openlineage_facets_on_complete_single_source_table(self, mock_hook): + location = "us-central1" + + operator = BigQueryToBigQueryOperator( + task_id=TASK_ID, + source_project_dataset_tables=SOURCE_PROJECT_DATASET_TABLE, destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE, write_disposition=WRITE_DISPOSITION, create_disposition=CREATE_DISPOSITION, @@ -104,9 +164,265 @@ def test_execute_single_regional_location_should_execute_successfully(self, mock location=location, ) + source_table_api_repr = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + "description": "Table description.", + "schema": { + "fields": [ + {"name": "field1", "type": "STRING"}, + {"name": "field2", "type": "INTEGER"}, + ] + }, + } + dest_table_api_repr = {**source_table_api_repr} + dest_table_api_repr["tableReference"]["datasetId"] = TEST_DATASET + "_new" + mock_table_data = { + SOURCE_PROJECT_DATASET_TABLE: Table.from_api_repr(source_table_api_repr), + DESTINATION_PROJECT_DATASET_TABLE: Table.from_api_repr(dest_table_api_repr), + } + + mock_hook.return_value.insert_job.return_value.to_api_repr.return_value = { + "jobReference": { + "projectId": TEST_GCP_PROJECT_ID, + "jobId": "actual_job_id", + "location": location, + }, + "configuration": { + "copy": { + "sourceTables": [ + { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + ], + "destinationTable": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET + "_new", + "tableId": TEST_TABLE_ID, + }, + } + }, + } mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect + mock_hook.return_value.get_client.return_value.get_table.side_effect = ( + lambda table_id: mock_table_data[table_id] + ) + operator.execute(context=mock.MagicMock()) - mock_hook.return_value.get_job.assert_called_once_with( - job_id=mock_hook.return_value.insert_job.return_value.job_id, + result = operator.get_openlineage_facets_on_complete(None) + + assert result.job_facets == {} + assert result.run_facets == { + "externalQuery": ExternalQueryRunFacet(externalQueryId="actual_job_id", source="bigquery") + } + assert len(result.inputs) == 1 + assert result.inputs[0] == Dataset( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet("Table description."), + }, + ) + assert len(result.outputs) == 1 + assert result.outputs[0] == Dataset( + namespace="bigquery", + name=DESTINATION_PROJECT_DATASET_TABLE, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", 
type="INTEGER"), + ] + ), + "documentation": DocumentationDatasetFacet("Table description."), + "columnLineage": ColumnLineageDatasetFacet( + fields={ + "field1": Fields( + inputFields=[ + InputField( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + field="field1", + transformations=[], + ) + ], + transformationDescription="identical", + transformationType="IDENTITY", + ), + "field2": Fields( + inputFields=[ + InputField( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + field="field2", + transformations=[], + ) + ], + transformationDescription="identical", + transformationType="IDENTITY", + ), + }, + dataset=[], + ), + }, + ) + + @mock.patch(BQ_HOOK_PATH) + def test_get_openlineage_facets_on_complete_multiple_source_tables(self, mock_hook): + location = "us-central1" + + operator = BigQueryToBigQueryOperator( + task_id=TASK_ID, + source_project_dataset_tables=[ + SOURCE_PROJECT_DATASET_TABLE, + SOURCE_PROJECT_DATASET_TABLE2, + ], + destination_project_dataset_table=DESTINATION_PROJECT_DATASET_TABLE, + write_disposition=WRITE_DISPOSITION, + create_disposition=CREATE_DISPOSITION, + labels=LABELS, + encryption_configuration=ENCRYPTION_CONFIGURATION, location=location, ) + + source_table_repr = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + "schema": { + "fields": [ + {"name": "field1", "type": "STRING"}, + ] + }, + } + source_table_repr2 = {**source_table_repr} + source_table_repr2["tableReference"]["tableId"] = TEST_TABLE_ID + "-2" + dest_table_api_repr = { + "tableReference": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET + "_new", + "tableId": TEST_TABLE_ID, + }, + "schema": { + "fields": [ + {"name": "field1", "type": "STRING"}, + {"name": "field2", "type": "INTEGER"}, + ] + }, + } + mock_table_data = { + SOURCE_PROJECT_DATASET_TABLE: Table.from_api_repr(source_table_repr), + SOURCE_PROJECT_DATASET_TABLE2: Table.from_api_repr(source_table_repr2), + DESTINATION_PROJECT_DATASET_TABLE: Table.from_api_repr(dest_table_api_repr), + } + + mock_hook.return_value.insert_job.return_value.to_api_repr.return_value = { + "jobReference": { + "projectId": TEST_GCP_PROJECT_ID, + "jobId": "actual_job_id", + "location": location, + }, + "configuration": { + "copy": { + "sourceTables": [ + { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID, + }, + { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET, + "tableId": TEST_TABLE_ID + "-2", + }, + ], + "destinationTable": { + "projectId": TEST_GCP_PROJECT_ID, + "datasetId": TEST_DATASET + "_new", + "tableId": TEST_TABLE_ID, + }, + } + }, + } + mock_hook.return_value.split_tablename.side_effect = split_tablename_side_effect + mock_hook.return_value.get_client.return_value.get_table.side_effect = ( + lambda table_id: mock_table_data[table_id] + ) + operator.execute(context=mock.MagicMock()) + result = operator.get_openlineage_facets_on_complete(None) + assert result.job_facets == {} + assert result.run_facets == { + "externalQuery": ExternalQueryRunFacet(externalQueryId="actual_job_id", source="bigquery") + } + assert len(result.inputs) == 2 + assert result.inputs[0] == Dataset( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + ] + ) + }, + ) + assert result.inputs[1] == Dataset( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE2, + facets={ + "schema": 
SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + ] + ) + }, + ) + assert len(result.outputs) == 1 + assert result.outputs[0] == Dataset( + namespace="bigquery", + name=DESTINATION_PROJECT_DATASET_TABLE, + facets={ + "schema": SchemaDatasetFacet( + fields=[ + SchemaDatasetFacetFields(name="field1", type="STRING"), + SchemaDatasetFacetFields(name="field2", type="INTEGER"), + ] + ), + "columnLineage": ColumnLineageDatasetFacet( + fields={ + "field1": Fields( + inputFields=[ + InputField( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE, + field="field1", + transformations=[], + ), + InputField( + namespace="bigquery", + name=SOURCE_PROJECT_DATASET_TABLE2, + field="field1", + transformations=[], + ), + ], + transformationDescription="identical", + transformationType="IDENTITY", + ) + }, + dataset=[], + ), + }, + ) diff --git a/providers/tests/google/cloud/transfers/test_bigquery_to_gcs.py b/providers/tests/google/cloud/transfers/test_bigquery_to_gcs.py index 7c2a398253752..b451d2037efcb 100644 --- a/providers/tests/google/cloud/transfers/test_bigquery_to_gcs.py +++ b/providers/tests/google/cloud/transfers/test_bigquery_to_gcs.py @@ -299,11 +299,6 @@ def test_get_openlineage_facets_on_complete_bq_dataset(self, mock_hook): def test_get_openlineage_facets_on_complete_bq_dataset_empty_table(self, mock_hook): source_project_dataset_table = f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}" - expected_input_dataset_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - } - mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) mock_hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE @@ -320,7 +315,7 @@ def test_get_openlineage_facets_on_complete_bq_dataset_empty_table(self, mock_ho assert lineage.inputs[0] == Dataset( namespace="bigquery", name=source_project_dataset_table, - facets=expected_input_dataset_facets, + facets={}, ) @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook") @@ -330,16 +325,6 @@ def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mo real_job_id = "123456_hash" bq_namespace = "bigquery" - expected_input_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - } - - expected_output_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "columnLineage": ColumnLineageDatasetFacet(fields={}), - } - mock_hook.return_value.split_tablename.return_value = (PROJECT_ID, TEST_DATASET, TEST_TABLE_ID) mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) mock_hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE @@ -357,12 +342,12 @@ def test_get_openlineage_facets_on_complete_gcs_no_wildcard_empty_table(self, mo assert len(lineage.inputs) == 1 assert len(lineage.outputs) == 1 assert lineage.inputs[0] == Dataset( - namespace=bq_namespace, name=source_project_dataset_table, facets=expected_input_facets + namespace=bq_namespace, name=source_project_dataset_table, facets={} ) assert lineage.outputs[0] == Dataset( namespace=f"gs://{TEST_BUCKET}", name=f"{TEST_FOLDER}/{TEST_OBJECT_NO_WILDCARD}", - facets=expected_output_facets, + facets={}, ) assert lineage.run_facets == { "externalQuery": ExternalQueryRunFacet(externalQueryId=real_job_id, source=bq_namespace) diff --git 
a/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py b/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py index 0ba2e07bb05eb..299fc9fead54d 100644 --- a/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py +++ b/providers/tests/google/cloud/transfers/test_gcs_to_bigquery.py @@ -1439,12 +1439,6 @@ def test_get_openlineage_facets_on_complete_empty_table(self, hook): hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE) hook.return_value.get_client.return_value.get_table.return_value = TEST_EMPTY_TABLE - expected_output_dataset_facets = { - "schema": SchemaDatasetFacet(fields=[]), - "documentation": DocumentationDatasetFacet(description=""), - "columnLineage": ColumnLineageDatasetFacet(fields={}), - } - operator = GCSToBigQueryOperator( project_id=JOB_PROJECT_ID, task_id=TASK_ID, @@ -1461,18 +1455,17 @@ def test_get_openlineage_facets_on_complete_empty_table(self, hook): assert lineage.outputs[0] == Dataset( namespace="bigquery", name=TEST_EXPLICIT_DEST, - facets=expected_output_dataset_facets, + facets={}, ) assert lineage.inputs[0] == Dataset( namespace=f"gs://{TEST_BUCKET}", name=TEST_OBJECT_NO_WILDCARD, - facets={"schema": SchemaDatasetFacet(fields=[])}, + facets={}, ) assert lineage.inputs[1] == Dataset( namespace=f"gs://{TEST_BUCKET}", name="/", facets={ - "schema": SchemaDatasetFacet(fields=[]), "symlink": SymlinksDatasetFacet( identifiers=[ Identifier( diff --git a/providers/tests/system/google/cloud/translate/example_translate_dataset.py b/providers/tests/system/google/cloud/translate/example_translate_dataset.py new file mode 100644 index 0000000000000..3ad732862449d --- /dev/null +++ b/providers/tests/system/google/cloud/translate/example_translate_dataset.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that translates text in Google Cloud Translate using V3 API version +service in the Google Cloud. 
+""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.providers.google.cloud.operators.translate import ( + TranslateCreateDatasetOperator, + TranslateDatasetsListOperator, + TranslateDeleteDatasetOperator, + TranslateImportDataOperator, +) +from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator +from airflow.utils.trigger_rule import TriggerRule + +DAG_ID = "gcp_translate_automl_native_dataset" +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +REGION = "us-central1" +RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" + +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") + +DATA_FILE_NAME = "import_en-es.tsv" + +RESOURCE_PATH = f"V3_translate/create_ds/import_data/{DATA_FILE_NAME}" +COPY_DATA_PATH = f"gs://{RESOURCE_DATA_BUCKET}/V3_translate/create_ds/import_data/{DATA_FILE_NAME}" + +DST_PATH = f"translate/import/{DATA_FILE_NAME}" + +DATASET_DATA_PATH = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/{DST_PATH}" + + +DATASET = { + "display_name": f"op_ds_native{DAG_ID}_{ENV_ID}", + "source_language_code": "es", + "target_language_code": "en", +} + +with DAG( + DAG_ID, + schedule="@once", # Override to match your needs + start_date=datetime(2024, 11, 1), + catchup=False, + tags=[ + "example", + "translate_dataset", + ], +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=REGION, + ) + copy_dataset_source_tsv = GCSToGCSOperator( + task_id="copy_dataset_file", + source_bucket=RESOURCE_DATA_BUCKET, + source_object=RESOURCE_PATH, + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object=DST_PATH, + ) + + # [START howto_operator_translate_automl_create_dataset] + create_dataset_op = TranslateCreateDatasetOperator( + task_id="translate_v3_ds_create", + dataset=DATASET, + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_create_dataset] + + # [START howto_operator_translate_automl_import_data] + import_ds_data_op = TranslateImportDataOperator( + task_id="translate_v3_ds_import_data", + dataset_id=create_dataset_op.output["dataset_id"], + input_config={ + "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": DATASET_DATA_PATH}}] + }, + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_import_data] + + # [START howto_operator_translate_automl_list_datasets] + list_datasets_op = TranslateDatasetsListOperator( + task_id="translate_v3_list_ds", + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_list_datasets] + + # [START howto_operator_translate_automl_delete_dataset] + delete_ds_op = TranslateDeleteDatasetOperator( + task_id="translate_v3_ds_delete", + dataset_id=create_dataset_op.output["dataset_id"], + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_delete_dataset] + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + [create_bucket >> copy_dataset_source_tsv] + # TEST BODY + >> create_dataset_op + >> import_ds_data_op + >> list_datasets_op + >> delete_ds_op + # TEST TEARDOWN + >> 
delete_bucket + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/pyproject.toml b/pyproject.toml index a6026395bcdc0..719dd892383f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,7 +216,7 @@ extend-select = [ "UP", # Pyupgrade "ASYNC", # subset of flake8-async rules "ISC", # Checks for implicit literal string concatenation (auto-fixable) - "TCH", # Rules around TYPE_CHECKING blocks + "TC", # Rules around TYPE_CHECKING blocks "G", # flake8-logging-format rules "LOG", # flake8-logging rules, most of them autofixable "PT", # flake8-pytest-style rules @@ -260,8 +260,7 @@ ignore = [ "D203", "D212", # Conflicts with D213. Both can not be enabled. "E731", # Do not assign a lambda expression, use a def - "TCH003", # Do not move imports from stdlib to TYPE_CHECKING block - "PT004", # Fixture does not return anything, add leading underscore + "TC003", # Do not move imports from stdlib to TYPE_CHECKING block "PT006", # Wrong type of names in @pytest.mark.parametrize "PT007", # Wrong type of values in @pytest.mark.parametrize "PT013", # silly rule prohibiting e.g. `from pytest import param` @@ -313,8 +312,8 @@ section-order = [ testing = ["dev", "providers.tests", "task_sdk.tests", "tests_common", "tests"] [tool.ruff.lint.extend-per-file-ignores] -"airflow/__init__.py" = ["F401", "TCH004", "I002"] -"airflow/models/__init__.py" = ["F401", "TCH004"] +"airflow/__init__.py" = ["F401", "TC004", "I002"] +"airflow/models/__init__.py" = ["F401", "TC004"] "airflow/models/sqla_models.py" = ["F401"] "providers/src/airflow/providers/__init__.py" = ["I002"] "providers/src/airflow/__init__.py" = ["I002"] @@ -326,12 +325,12 @@ testing = ["dev", "providers.tests", "task_sdk.tests", "tests_common", "tests"] # The Pydantic representations of SqlAlchemy Models are not parsed well with Pydantic # when __future__.annotations is used so we need to skip them from upgrading # Pydantic also require models to be imported during execution -"airflow/serialization/pydantic/*.py" = ["I002", "UP007", "TCH001"] +"airflow/serialization/pydantic/*.py" = ["I002", "UP007", "TC001"] # Failing to detect types and functions used in `Annotated[...]` syntax as required at runtime. # Annotated is central for FastAPI dependency injection, skipping rules for FastAPI folders. 
-"airflow/api_fastapi/*" = ["TCH001", "TCH002"] -"tests/api_fastapi/*" = ["TCH001", "TCH002"] +"airflow/api_fastapi/*" = ["TC001", "TC002"] +"tests/api_fastapi/*" = ["T001", "TC002"] # Ignore pydoc style from these "*.pyi" = ["D"] diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index 5da673a79bf05..170aff5ec2440 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -62,11 +62,11 @@ namespace-packages = ["src/airflow"] # Pycharm barfs if this "stub" file has future imports "src/airflow/__init__.py" = ["I002"] -"src/airflow/sdk/__init__.py" = ["TCH004"] +"src/airflow/sdk/__init__.py" = ["TC004"] # msgspec needs types for annotations to be defined, even with future # annotations, so disable the "type check only import" for these files -"src/airflow/sdk/api/datamodels/*.py" = ["TCH001"] +"src/airflow/sdk/api/datamodels/*.py" = ["TC001"] # Only the public API should _require_ docstrings on classes "!src/airflow/sdk/definitions/*" = ["D101"] diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py index 02cba9e070dc4..8d26a25c626a5 100644 --- a/tests/utils/test_session.py +++ b/tests/utils/test_session.py @@ -58,9 +58,9 @@ def test_provide_session_with_kwargs(self): @pytest.mark.asyncio async def test_async_session(self): - from airflow.settings import create_async_session + from airflow.settings import AsyncSession - session = create_async_session() + session = AsyncSession() session.add(Log(event="hihi1234")) await session.commit() my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1))