diff --git a/conda/post-link.sh b/conda/post-link.sh index aaf0608f9..24a9a62bb 100644 --- a/conda/post-link.sh +++ b/conda/post-link.sh @@ -5,4 +5,5 @@ $PREFIX/bin/pip install \ 'protobuf>=5.27.2,<6.0' \ 'influxdb3-python>=0.7,<1.0' \ 'pyiceberg[pyarrow,glue]>=0.7,<0.8' \ -'redis[hiredis]>=5.2.0,<6' +'redis[hiredis]>=5.2.0,<6' \ +'paho-mqtt>=2.1.0,<3' diff --git a/pyproject.toml b/pyproject.toml index 1f64ba568..cbcdcb782 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,8 @@ all = [ "psycopg2-binary>=2.9.9,<3", "boto3>=1.35.65,<2.0", "boto3-stubs>=1.35.65,<2.0", - "redis[hiredis]>=5.2.0,<6" + "redis[hiredis]>=5.2.0,<6", + "paho-mqtt>=2.1.0,<3" ] avro = ["fastavro>=1.8,<2.0"] @@ -50,6 +51,7 @@ pubsub = ["google-cloud-pubsub>=2.23.1,<3"] postgresql = ["psycopg2-binary>=2.9.9,<3"] kinesis = ["boto3>=1.35.65,<2.0", "boto3-stubs[kinesis]>=1.35.65,<2.0"] redis=["redis[hiredis]>=5.2.0,<6"] +mqtt=["paho-mqtt>=2.1.0,<3"] [tool.setuptools.packages.find] include = ["quixstreams*"] diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py new file mode 100644 index 000000000..51a2284a0 --- /dev/null +++ b/quixstreams/sinks/community/mqtt.py @@ -0,0 +1,169 @@ +import json +from datetime import datetime +from typing import Any, List, Tuple + +from quixstreams.models.types import HeaderValue +from quixstreams.sinks.base.sink import BaseSink + +try: + import paho.mqtt.client as paho + from paho import mqtt +except ImportError as exc: + raise ImportError( + 'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it" + ) from exc + + +class MQTTSink(BaseSink): + """ + A sink that publishes messages to an MQTT broker. + """ + + def __init__( + self, + mqtt_client_id: str, + mqtt_server: str, + mqtt_port: int, + mqtt_topic_root: str, + mqtt_username: str = None, + mqtt_password: str = None, + mqtt_version: str = "3.1.1", + tls_enabled: bool = True, + qos: int = 1, + ): + """ + Initialize the MQTTSink. + + :param mqtt_client_id: MQTT client identifier. + :param mqtt_server: MQTT broker server address. + :param mqtt_port: MQTT broker server port. + :param mqtt_topic_root: Root topic to publish messages to. + :param mqtt_username: Username for MQTT broker authentication. Defaults to None + :param mqtt_password: Password for MQTT broker authentication. Defaults to None + :param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1 + :param tls_enabled: Whether to use TLS encryption. Defaults to True + :param qos: Quality of Service level (0, 1, or 2). Defaults to 1 + """ + + super().__init__() + + self.mqtt_version = mqtt_version + self.mqtt_username = mqtt_username + self.mqtt_password = mqtt_password + self.mqtt_topic_root = mqtt_topic_root + self.tls_enabled = tls_enabled + self.qos = qos + + self.mqtt_client = paho.Client( + callback_api_version=paho.CallbackAPIVersion.VERSION2, + client_id=mqtt_client_id, + userdata=None, + protocol=self._mqtt_protocol_version(), + ) + + if self.tls_enabled: + self.mqtt_client.tls_set( + tls_version=mqtt.client.ssl.PROTOCOL_TLS + ) # we'll be using tls now + + self.mqtt_client.reconnect_delay_set(5, 60) + self._configure_authentication() + self.mqtt_client.on_connect = self._mqtt_on_connect_cb + self.mqtt_client.on_disconnect = self._mqtt_on_disconnect_cb + self.mqtt_client.connect(mqtt_server, int(mqtt_port)) + + # setting callbacks for different events to see if it works, print the message etc. + def _mqtt_on_connect_cb( + self, + client: paho.Client, + userdata: any, + connect_flags: paho.ConnectFlags, + reason_code: paho.ReasonCode, + properties: paho.Properties, + ): + if reason_code == 0: + print("CONNECTED!") # required for Quix to know this has connected + else: + print(f"ERROR ({reason_code.value}). {reason_code.getName()}") + + def _mqtt_on_disconnect_cb( + self, + client: paho.Client, + userdata: any, + disconnect_flags: paho.DisconnectFlags, + reason_code: paho.ReasonCode, + properties: paho.Properties, + ): + print( + f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!" + ) + + def _mqtt_protocol_version(self): + if self.mqtt_version == "3.1": + return paho.MQTTv31 + elif self.mqtt_version == "3.1.1": + return paho.MQTTv311 + elif self.mqtt_version == "5": + return paho.MQTTv5 + else: + raise ValueError(f"Unsupported MQTT version: {self.mqtt_version}") + + def _configure_authentication(self): + if self.mqtt_username: + self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password) + + def _publish_to_mqtt( + self, + data: str, + key: bytes, + timestamp: datetime, + headers: List[Tuple[str, HeaderValue]], + ): + if isinstance(data, bytes): + data = data.decode("utf-8") # Decode bytes to string using utf-8 + + json_data = json.dumps(data) + message_key_string = key.decode( + "utf-8" + ) # Convert to string using utf-8 encoding + # publish to MQTT + self.mqtt_client.publish( + self.mqtt_topic_root + "/" + message_key_string, + payload=json_data, + qos=self.qos, + ) + + def add( + self, + topic: str, + partition: int, + offset: int, + key: bytes, + value: bytes, + timestamp: datetime, + headers: List[Tuple[str, HeaderValue]], + **kwargs: Any, + ): + self._publish_to_mqtt(value, key, timestamp, headers) + + def _construct_topic(self, key): + if key: + key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key) + return f"{self.mqtt_topic_root}/{key_str}" + else: + return self.mqtt_topic_root + + def on_paused(self, topic: str, partition: int): + # not used + pass + + def flush(self, topic: str, partition: str): + # not used + pass + + def cleanup(self): + self.mqtt_client.loop_stop() + self.mqtt_client.disconnect() + + def __del__(self): + self.cleanup() diff --git a/tests/requirements.txt b/tests/requirements.txt index 8032d747b..ba046b08d 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -7,3 +7,4 @@ protobuf>=5.27.2 influxdb3-python>=0.7.0,<1.0 pyiceberg[pyarrow,glue]>=0.7,<0.8 redis[hiredis]>=5.2.0,<6 +paho-mqtt==2.1.0 diff --git a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py new file mode 100644 index 000000000..05b6b332b --- /dev/null +++ b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py @@ -0,0 +1,85 @@ +from datetime import datetime +from unittest.mock import patch + +import pytest + +from quixstreams.sinks.community.mqtt import MQTTSink + + +@pytest.fixture() +def mqtt_sink_factory(): + def factory( + mqtt_client_id: str = "test_client", + mqtt_server: str = "localhost", + mqtt_port: int = 1883, + mqtt_topic_root: str = "test/topic", + mqtt_username: str = None, + mqtt_password: str = None, + mqtt_version: str = "3.1.1", + tls_enabled: bool = True, + qos: int = 1, + ) -> MQTTSink: + with patch("paho.mqtt.client.Client") as MockClient: + mock_mqtt_client = MockClient.return_value + sink = MQTTSink( + mqtt_client_id=mqtt_client_id, + mqtt_server=mqtt_server, + mqtt_port=mqtt_port, + mqtt_topic_root=mqtt_topic_root, + mqtt_username=mqtt_username, + mqtt_password=mqtt_password, + mqtt_version=mqtt_version, + tls_enabled=tls_enabled, + qos=qos, + ) + sink.mqtt_client = mock_mqtt_client + return sink, mock_mqtt_client + + return factory + + +class TestMQTTSink: + def test_mqtt_connect(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + mock_mqtt_client.connect.assert_called_once_with("localhost", 1883) + + def test_mqtt_tls_enabled(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=True) + mock_mqtt_client.tls_set.assert_called_once() + + def test_mqtt_tls_disabled(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=False) + mock_mqtt_client.tls_set.assert_not_called() + + def test_mqtt_publish(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + data = "test_data" + key = b"test_key" + timestamp = datetime.now() + headers = [] + + sink.add( + topic="test-topic", + partition=0, + offset=1, + key=key, + value=data.encode("utf-8"), + timestamp=timestamp, + headers=headers, + ) + + mock_mqtt_client.publish.assert_called_once_with( + "test/topic/test_key", payload='"test_data"', qos=1 + ) + + def test_mqtt_authentication(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory( + mqtt_username="user", mqtt_password="pass" + ) + mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass") + + def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + sink.cleanup() # Explicitly call cleanup + mock_mqtt_client.loop_stop.assert_called_once() + mock_mqtt_client.disconnect.assert_called_once()