Skip to content

Commit

Permalink
Airflow operator extra links support regex extraction from container …
Browse files Browse the repository at this point in the history
…logs (#4145)

Co-authored-by: Martynas Asipauskas <Martynas.Asipauskas@gresearch.co.uk>
  • Loading branch information
masipauskas and Martynas Asipauskas authored Jan 15, 2025
1 parent 7e6ca31 commit f89883c
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 109 deletions.
70 changes: 3 additions & 67 deletions docs/python_airflow_operator.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ and handles job cancellation if the Airflow task is killed.
* **reattach_policy** (*Optional**[**str**] **| **Callable**[**[**JobState**, **str**]**, **bool**]*) –


* **extra_links** (*Optional**[**Dict**[**str**, **str**]**]*) –
* **extra_links** (*Optional**[**Dict**[**str**, **Union**[**str**, **re.Pattern**]**]**]*) –



Expand Down Expand Up @@ -199,74 +199,10 @@ acknowledged by Armada.
:param reattach_policy: Operator reattach policy to use (defaults to: never)
:type reattach_policy: Optional[str] | Callable[[JobState, str], bool]
:param kwargs: Additional keyword arguments to pass to the BaseOperator.
:param extra_links: Extra links to be shown in addition to Lookout URL.
:type extra_links: Optional[Dict[str, str]]
:param extra_links: Extra links to be shown in addition to Lookout URL. Regex patterns will be extracted from container logs (taking first match).
:type extra_links: Optional[Dict[str, Union[str, re.Pattern]]]
:param kwargs: Additional keyword arguments to pass to the BaseOperator.


### _class_ armada.operators.armada.DynamicLink(name)
Bases: `BaseOperatorLink`, `LoggingMixin`


* **Parameters**

**name** (*str*) –



#### get_link(operator, \*, ti_key)
Link to external system.

Note: The old signature of this function was `(self, operator, dttm: datetime)`. That is still
supported at runtime but is deprecated.


* **Parameters**


* **operator** (*BaseOperator*) – The Airflow operator object this link is associated to.


* **ti_key** (*TaskInstanceKey*) – TaskInstance ID to return link for.



* **Returns**

link to external system



#### name(_: st_ )

### _class_ armada.operators.armada.LookoutLink()
Bases: `BaseOperatorLink`


#### get_link(operator, \*, ti_key)
Link to external system.

Note: The old signature of this function was `(self, operator, dttm: datetime)`. That is still
supported at runtime but is deprecated.


* **Parameters**


* **operator** (*BaseOperator*) – The Airflow operator object this link is associated to.


* **ti_key** (*TaskInstanceKey*) – TaskInstance ID to return link for.



* **Returns**

link to external system



#### name(_ = 'Lookout_ )
## armada.triggers.armada module

## armada.auth module
Expand Down
4 changes: 2 additions & 2 deletions third_party/airflow/armada/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def get_provider_info():
"name": "Armada Airflow Operator",
"description": "Armada Airflow Operator.",
"extra-links": [
"armada.operators.armada.LookoutLink",
"armada.operators.armada.DynamicLink",
"armada.links.LookoutLink",
"armada.links.DynamicLink",
],
"versions": ["1.0.0"],
}
86 changes: 86 additions & 0 deletions third_party/airflow/armada/links.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

import re
import attrs

from typing import Dict, Optional, Union
from airflow.models import XCom
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.models import BaseOperator, BaseOperatorLink
from airflow.models.taskinstancekey import TaskInstanceKey


def get_link_value(ti_key: TaskInstanceKey, name: str) -> Optional[str]:
return XCom.get_value(ti_key=ti_key, key=f"armada_{name.lower()}_url")


def persist_link_value(ti_key: TaskInstanceKey, name: str, value: str):
XCom.set(
key=f"armada_{name.lower()}_url",
value=value,
dag_id=ti_key.dag_id,
task_id=ti_key.task_id,
run_id=ti_key.run_id,
map_index=ti_key.map_index,
)


class LookoutLink(BaseOperatorLink):
name = "Lookout"

def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
task_state = XCom.get_value(ti_key=ti_key, key="job_context")
if not task_state:
return ""

return task_state.get("armada_lookout_url", "")


@attrs.define(init=True)
class DynamicLink(BaseOperatorLink, LoggingMixin):
name: str

def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
url = get_link_value(ti_key, self.name)
if not url:
return ""
return url


class UrlFromLogsExtractor:
"""Extracts and persists URLs from log messages based on regex patterns."""

def __init__(self, extra_links: Dict[str, re.Pattern], ti_key: TaskInstanceKey):
"""
:param extra_links: Dictionary of link names to regex patterns for URLs
:param ti_key: TaskInstanceKey for XCom
"""
self._extra_links = extra_links
self._ti_key = ti_key

@staticmethod
def create(
extra_links: Dict[str, Union[str, re.Pattern]], ti_key: TaskInstanceKey
) -> UrlFromLogsExtractor:
valid_links = {
name: pattern
for name, pattern in extra_links.items()
if isinstance(pattern, re.Pattern) and not get_link_value(ti_key, name)
}
return UrlFromLogsExtractor(valid_links, ti_key)

def extract_and_persist_urls(self, message: str):
if not self._extra_links:
return

matches = []
for name in self._extra_links:
pattern = self._extra_links[name]
match = re.search(pattern, message)
if match:
url = match.group(0)
persist_link_value(self._ti_key, name, url)
matches.append(name)

for m in matches:
del self._extra_links[m]
29 changes: 20 additions & 9 deletions third_party/airflow/armada/log_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import pendulum
import tenacity
from airflow.utils.log.logging_mixin import LoggingMixin
from armada.auth import TokenRetriever
from kubernetes import client, config
from pendulum import DateTime
from pendulum.parsing.exceptions import ParserError
from urllib3.exceptions import HTTPError

from .auth import TokenRetriever
from .links import UrlFromLogsExtractor


class KubernetesPodLogManager(LoggingMixin):
"""Monitor logs of Kubernetes pods asynchronously."""
Expand Down Expand Up @@ -54,6 +56,7 @@ def fetch_container_logs(
pod: str,
container: str,
since_time: Optional[DateTime],
link_extractor: UrlFromLogsExtractor,
) -> Optional[DateTime]:
"""
Fetches container logs, do not follow container logs.
Expand Down Expand Up @@ -81,10 +84,14 @@ def fetch_container_logs(
self.log.exception(f"There was an error reading the kubernetes API: {e}.")
raise

return self._stream_logs(container, since_time, logs)
return self._stream_logs(container, since_time, logs, link_extractor)

def _stream_logs(
self, container: str, since_time: Optional[DateTime], logs: HTTPResponse
self,
container: str,
since_time: Optional[DateTime],
logs: HTTPResponse,
link_extractor: UrlFromLogsExtractor,
) -> Optional[DateTime]:
messages: List[str] = []
message_timestamp = None
Expand All @@ -97,21 +104,25 @@ def _stream_logs(
if line_timestamp: # detect new log-line (starts with timestamp)
if since_time and line_timestamp <= since_time:
continue
self._log_container_message(container, messages)
messages.clear()
self._log_container_message(container, messages, link_extractor)
message_timestamp = line_timestamp
messages.append(message)
except HTTPError as e:
self.log.warning(
f"Reading of logs interrupted for container {container} with error {e}."
)

self._log_container_message(container, messages)
self._log_container_message(container, messages, link_extractor)
return message_timestamp

def _log_container_message(self, container: str, messages: List[str]):
if messages:
self.log.info("[%s] %s", container, "\n".join(messages))
def _log_container_message(
self, container: str, messages: List[str], link_extractor: UrlFromLogsExtractor
):
message = "\n".join(messages)
if message:
link_extractor.extract_and_persist_urls(message)
self.log.info("[%s] %s", container, message)
messages.clear()

def _parse_log_line(self, line: bytes) -> Tuple[DateTime | None, str]:
"""
Expand Down
42 changes: 12 additions & 30 deletions third_party/airflow/armada/operators/armada.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,20 @@
# under the License.
from __future__ import annotations

import attrs
import dataclasses
import datetime
import os
import time
from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import jinja2
import tenacity
import re

from airflow.configuration import conf
from airflow.exceptions import AirflowFailException
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.models import BaseOperator
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.serde import deserialize
from airflow.utils.context import Context
from airflow.utils.log.logging_mixin import LoggingMixin
Expand All @@ -50,26 +48,7 @@
from ..policies.reattach import external_job_uri, policy
from ..triggers import ArmadaPollJobTrigger
from ..utils import log_exceptions, xcom_pull_for_ti, resolve_parameter_value


class LookoutLink(BaseOperatorLink):
name = "Lookout"

def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
task_state = XCom.get_value(ti_key=ti_key, key="job_context")
if not task_state:
return ""

return task_state.get("armada_lookout_url", "")


@attrs.define(init=True)
class DynamicLink(BaseOperatorLink, LoggingMixin):
name: str

def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
url = XCom.get_value(ti_key=ti_key, key=f"armada_{self.name.lower()}_url")
return url
from ..links import LookoutLink, DynamicLink, persist_link_value, UrlFromLogsExtractor


class ArmadaOperator(BaseOperator, LoggingMixin):
Expand Down Expand Up @@ -118,8 +97,9 @@ class ArmadaOperator(BaseOperator, LoggingMixin):
:param reattach_policy: Operator reattach policy to use (defaults to: never)
:type reattach_policy: Optional[str] | Callable[[JobState, str], bool]
:param kwargs: Additional keyword arguments to pass to the BaseOperator.
:param extra_links: Extra links to be shown in addition to Lookout URL.
:type extra_links: Optional[Dict[str, str]]
:param extra_links: Extra links to be shown in addition to Lookout URL. \
Regex patterns will be extracted from container logs (taking first match).
:type extra_links: Optional[Dict[str, Union[str, re.Pattern]]]
:param kwargs: Additional keyword arguments to pass to the BaseOperator.
"""

Expand Down Expand Up @@ -149,7 +129,7 @@ def __init__(
"armada_operator", "default_dry_run", fallback=False
),
reattach_policy: Optional[str] | Callable[[JobState, str], bool] = None,
extra_links: Optional[Dict[str, str]] = None,
extra_links: Optional[Dict[str, Union[str, re.Pattern]]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -326,9 +306,7 @@ def render_extra_links_urls(
continue
try:
rendered_url = jinja_env.from_string(url).render(context)
self._xcom_push(
context, key=f"armada_{name.lower()}_url", value=rendered_url
)
persist_link_value(context["ti"].key, name, rendered_url)
except jinja2.TemplateError as e:
self.log.error(f"Error rendering template for {name} ({url}): {e}")

Expand Down Expand Up @@ -485,12 +463,16 @@ def _check_job_status_and_fetch_logs(self, context) -> None:

if self._should_have_a_pod_in_k8s() and self.container_logs:
try:
link_extractor = UrlFromLogsExtractor.create(
self.extra_links, context["ti"].key
)
last_log_time = self.pod_manager.fetch_container_logs(
k8s_context=self.job_context.cluster,
namespace=self.job_request.namespace,
pod=f"armada-{self.job_context.job_id}-0",
container=self.container_logs,
since_time=self.job_context.last_log_time,
link_extractor=link_extractor,
)
if last_log_time:
self.job_context = dataclasses.replace(
Expand Down
3 changes: 2 additions & 1 deletion third_party/airflow/test/unit/operators/test_armada.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
from datetime import timedelta
from typing import Optional
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, ANY

import pytest
from airflow.exceptions import TaskDeferred
Expand Down Expand Up @@ -226,6 +226,7 @@ def test_polls_for_logs(context):
pod="armada-test_job-0",
container="alpine",
since_time=None,
link_extractor=ANY,
)


Expand Down

0 comments on commit f89883c

Please sign in to comment.