From a581c4c510935f5a30d6760de0f4e9cb1afd81d5 Mon Sep 17 00:00:00 2001 From: jayesh Date: Thu, 21 Nov 2024 15:18:58 +0530 Subject: [PATCH 01/31] openai stt update options --- .../livekit-plugins-openai/livekit/plugins/openai/stt.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py index 4b79ba038..0ccf0d585 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/stt.py @@ -86,6 +86,12 @@ def __init__( ), ) + def update_options( + self, *, model: WhisperModels | GroqAudioModels | None, language: str | None + ) -> None: + self._opts.model = model or self._opts.model + self._opts.language = language or self._opts.language + @staticmethod def with_groq( *, From 1089b9c4e5c500240fdeba640c34b1aadb56098f Mon Sep 17 00:00:00 2001 From: jayesh Date: Sun, 24 Nov 2024 00:32:51 +0530 Subject: [PATCH 02/31] stt update options for fal --- livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py b/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py index fc275ed21..89354fd27 100644 --- a/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py +++ b/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py @@ -47,6 +47,9 @@ def __init__( "FAL AI API key is required. It should be set with env FAL_KEY" ) + def update_options(self, *, language: str | None) -> None: + self._opts.language = language or self._opts.language + def _sanitize_options( self, *, From 0bc397f4a661464bd1d452ba355958824f2f8baf Mon Sep 17 00:00:00 2001 From: jayesh Date: Sun, 24 Nov 2024 00:52:41 +0530 Subject: [PATCH 03/31] stt update options for clova --- .../livekit-plugins-clova/livekit/plugins/clova/stt.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/stt.py b/livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/stt.py index ef7367c77..a09eaee2b 100644 --- a/livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/stt.py +++ b/livekit-plugins/livekit-plugins-clova/livekit/plugins/clova/stt.py @@ -68,6 +68,11 @@ def __init__( ) self.threshold = threshold + def update_options(self, *, language: str | None) -> None: + self._language = ( + clova_languages_mapping.get(language, language) or self._language + ) + def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: self._session = utils.http_context.http_session() From f22e42e53414767b5ea0f5fe738ad892daf6d0c2 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sun, 24 Nov 2024 01:12:33 +0530 Subject: [PATCH 04/31] wip --- .../livekit-plugins-fal/livekit/plugins/fal/stt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py b/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py index 89354fd27..70f912834 100644 --- a/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py +++ b/livekit-plugins/livekit-plugins-fal/livekit/plugins/fal/stt.py @@ -1,7 +1,7 @@ import dataclasses import os from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import fal_client from livekit.agents import ( @@ -47,7 +47,7 @@ def __init__( "FAL AI API key is required. 
It should be set with env FAL_KEY" ) - def update_options(self, *, language: str | None) -> None: + def update_options(self, *, language: Union[str, None]) -> None: self._opts.language = language or self._opts.language def _sanitize_options( From 6333159906aca52a75ede4567f1604b752196170 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 26 Nov 2024 17:59:06 +0530 Subject: [PATCH 05/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 253 +++++++++++++----- 1 file changed, 180 insertions(+), 73 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 4c4d46cc5..ca3dac004 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -180,6 +180,7 @@ def __init__( energy_filter=energy_filter, ) self._session = http_session + self._active_streams = set() def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: @@ -245,13 +246,28 @@ def stream( conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> "SpeechStream": config = self._sanitize_options(language=language) - return SpeechStream( + + stream = SpeechStream( stt=self, conn_options=conn_options, opts=config, api_key=self._api_key, http_session=self._ensure_session(), ) + self._active_streams.add(stream) + return stream + + def remove_stream(self, stream: "SpeechStream") -> None: + """Remove a SpeechStream from the active streams set.""" + self._active_streams.discard(stream) + + def update_options(self, language: DeepgramLanguages | str | None) -> None: + """Update the STT options and propagate changes to active streams.""" + # Update the options stored inside the class + self._opts.language = language or self._opts.language + # Propagate updated options to active streams + for stream in self._active_streams: + asyncio.create_task(stream.update_options(language=language)) def _sanitize_options(self, *, language: str | None = None) -> STTOptions: config = dataclasses.replace(self._opts) @@ -303,33 +319,117 @@ def __init__( self._pushed_audio_duration = 0.0 self._request_id = "" - async def _run(self) -> None: - closing_ws = False + self._reconnect_event = asyncio.Event() + self._closed = False + self._ws: Optional[aiohttp.ClientWebSocketResponse] = None + self._ws_lock = asyncio.Lock() + self._stt = stt + self._stt._active_streams.add(self) + + async def update_options(self, language: DeepgramLanguages | str | None) -> None: + """Update the options and trigger reconnection.""" + async with self._ws_lock: + self._opts.language = language or self._opts.language + self._reconnect_event.set() + logger.info("options updated, reconnection requested.") - async def keepalive_task(ws: aiohttp.ClientWebSocketResponse): - # if we want to keep the connection alive even if no audio is sent, - # Deepgram expects a keepalive message. 
- # https://developers.deepgram.com/reference/listen-live#stream-keepalive + async def _run(self) -> None: + while not self._closed: try: - while True: - await ws.send_str(SpeechStream._KEEPALIVE_MSG) - await asyncio.sleep(5) + await self._connect_ws() + send_task = asyncio.create_task(self._send_task()) + recv_task = asyncio.create_task(self._recv_task()) + keepalive_task = asyncio.create_task(self._keepalive_task()) + reconnect_wait_task = asyncio.create_task(self._reconnect_event.wait()) + + tasks = [send_task, recv_task, keepalive_task] + done, pending = await asyncio.wait( + [reconnect_wait_task] + tasks, + return_when=asyncio.FIRST_COMPLETED, + ) + + if reconnect_wait_task.done(): + self._reconnect_event.clear() + logger.info("reconnecting with updated options...") + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + + continue + except Exception: - return + logger.exception("Error in SpeechStream _run method") + # Decide whether to retry or break based on the exception type + break # For now, we break the loop on exceptions - async def send_task(ws: aiohttp.ClientWebSocketResponse): - nonlocal closing_ws + await self._cleanup() - # forward audio to deepgram in chunks of 50ms - samples_50ms = self._opts.sample_rate // 20 - audio_bstream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, - num_channels=self._opts.num_channels, - samples_per_channel=samples_50ms, - ) + async def _connect_ws(self): + """Establish the websocket connection using the current options.""" + async with self._ws_lock: + if self._ws and not self._ws.closed: + await self._ws.close() + + live_config = { + "model": self._opts.model, + "punctuate": self._opts.punctuate, + "smart_format": self._opts.smart_format, + "no_delay": self._opts.no_delay, + "interim_results": self._opts.interim_results, + "encoding": "linear16", + "vad_events": True, + "sample_rate": self._opts.sample_rate, + "channels": self._opts.num_channels, + "endpointing": False + if self._opts.endpointing_ms == 0 + else self._opts.endpointing_ms, + "filler_words": self._opts.filler_words, + "keywords": self._opts.keywords, + "profanity_filter": self._opts.profanity_filter, + } + + if self._opts.language: + live_config["language"] = self._opts.language + + try: + self._ws = await asyncio.wait_for( + self._session.ws_connect( + _to_deepgram_url(live_config, websocket=True), + headers={"Authorization": f"Token {self._api_key}"}, + ), + self._conn_options.timeout, + ) + logger.info("WebSocket connection established.") + except Exception as e: + logger.exception("Failed to establish WebSocket connection.") + raise APIConnectionError() from e + + async def _send_task(self): + """Task for sending audio data to the websocket.""" + # Ensure the websocket is connected + if not self._ws or self._ws.closed: + logger.error("WebSocket is not connected in send_task.") + return + + ws = self._ws + closing_ws = False + + # Forward audio to deepgram in chunks of 50ms + samples_50ms = self._opts.sample_rate // 20 + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=self._opts.num_channels, + samples_per_channel=samples_50ms, + ) + + has_ended = False + last_frame: Optional[rtc.AudioFrame] = None - has_ended = False - last_frame: Optional[rtc.AudioFrame] = None + try: async for data in self._input_ch: frames: list[rtc.AudioFrame] = [] if isinstance(data, rtc.AudioFrame): @@ -369,8 
+469,24 @@ async def send_task(ws: aiohttp.ClientWebSocketResponse): closing_ws = True await ws.send_str(SpeechStream._CLOSE_MSG) - async def recv_task(ws: aiohttp.ClientWebSocketResponse): - nonlocal closing_ws + except asyncio.CancelledError: + # Task was cancelled due to reconnection or closure + pass + except Exception: + logger.exception("Error in send_task") + finally: + if closing_ws: + await ws.close() + + async def _recv_task(self): + """Task for receiving data from the websocket.""" + # Ensure the websocket is connected + if not self._ws or self._ws.closed: + logger.error("WebSocket is not connected in recv_task.") + return + + ws = self._ws + try: while True: msg = await ws.receive() if msg.type in ( @@ -378,11 +494,8 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse): aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING, ): - if closing_ws: # close is expected, see SpeechStream.aclose - return - - # this will trigger a reconnection, see the _run loop - raise Exception("deepgram connection closed unexpectedly") + logger.info("WebSocket connection closed.") + break if msg.type != aiohttp.WSMsgType.TEXT: logger.warning("unexpected deepgram message type %s", msg.type) @@ -391,53 +504,33 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse): try: self._process_stream_event(json.loads(msg.data)) except Exception: - logger.exception("failed to process deepgram message") + logger.exception("Failed to process message from Deepgram") - ws: aiohttp.ClientWebSocketResponse | None = None - - try: - live_config = { - "model": self._opts.model, - "punctuate": self._opts.punctuate, - "smart_format": self._opts.smart_format, - "no_delay": self._opts.no_delay, - "interim_results": self._opts.interim_results, - "encoding": "linear16", - "vad_events": True, - "sample_rate": self._opts.sample_rate, - "channels": self._opts.num_channels, - "endpointing": False - if self._opts.endpointing_ms == 0 - else self._opts.endpointing_ms, - "filler_words": self._opts.filler_words, - "keywords": self._opts.keywords, - "profanity_filter": self._opts.profanity_filter, - } - - if self._opts.language: - live_config["language"] = self._opts.language - - ws = await asyncio.wait_for( - self._session.ws_connect( - _to_deepgram_url(live_config, websocket=True), - headers={"Authorization": f"Token {self._api_key}"}, - ), - self._conn_options.timeout, - ) + except asyncio.CancelledError: + # Task was cancelled due to reconnection or closure + pass + except Exception: + logger.exception("Error in recv_task") + finally: + await ws.close() - tasks = [ - asyncio.create_task(send_task(ws)), - asyncio.create_task(recv_task(ws)), - asyncio.create_task(keepalive_task(ws)), - ] + async def _keepalive_task(self): + """Task for sending keepalive messages.""" + # Ensure the websocket is connected + if not self._ws or self._ws.closed: + logger.error("WebSocket is not connected in keepalive_task.") + return - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) - finally: - if ws is not None: - await ws.close() + ws = self._ws + try: + while True: + await ws.send_str(SpeechStream._KEEPALIVE_MSG) + await asyncio.sleep(5) + except asyncio.CancelledError: + # Task was cancelled due to reconnection or closure + pass + except Exception: + logger.exception("Error in keepalive_task") def _check_energy_state(self, frame: rtc.AudioFrame) -> AudioEnergyFilter.State: if self._audio_energy_filter: @@ -519,6 +612,20 @@ def _process_stream_event(self, data: dict) -> None: else: 
logger.warning("received unexpected message from deepgram %s", data) + async def aclose(self) -> None: + """Close the stream and clean up resources.""" + self._closed = True + self._reconnect_event.set() # Trigger any waiting loops to exit + self._stt.remove_stream(self) + await super().aclose() + await self._cleanup() + + async def _cleanup(self): + """Cleanup resources.""" + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + def live_transcription_to_speech_data( language: str, data: dict From 64749924065030c47f3ce827830d7ac1029646c0 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 26 Nov 2024 19:00:36 +0530 Subject: [PATCH 06/31] deepgram stt wip --- .../livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index ca3dac004..008cbaff2 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -323,7 +323,7 @@ def __init__( self._closed = False self._ws: Optional[aiohttp.ClientWebSocketResponse] = None self._ws_lock = asyncio.Lock() - self._stt = stt + self._stt: STT = stt self._stt._active_streams.add(self) async def update_options(self, language: DeepgramLanguages | str | None) -> None: From 7ea51cbadc107b313fd2dc15391771c8273b8ad4 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 26 Nov 2024 19:06:27 +0530 Subject: [PATCH 07/31] deepgram stt wip --- .../livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 008cbaff2..64184a097 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -180,7 +180,7 @@ def __init__( energy_filter=energy_filter, ) self._session = http_session - self._active_streams = set() + self._active_streams: set(SpeechStream) = set() def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: From a03d17f633bb8553a185dc86517039855386a633 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 26 Nov 2024 19:12:37 +0530 Subject: [PATCH 08/31] deepgram stt wip --- .../livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 64184a097..119fb5fbf 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -180,7 +180,7 @@ def __init__( energy_filter=energy_filter, ) self._session = http_session - self._active_streams: set(SpeechStream) = set() + self._active_streams: set[SpeechStream] = set() def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: @@ -263,9 +263,7 @@ def remove_stream(self, stream: "SpeechStream") -> None: def update_options(self, language: DeepgramLanguages | str | None) -> None: """Update the STT options and propagate changes to active streams.""" - # Update the options stored inside the class self._opts.language = language or 
self._opts.language - # Propagate updated options to active streams for stream in self._active_streams: asyncio.create_task(stream.update_options(language=language)) @@ -327,7 +325,6 @@ def __init__( self._stt._active_streams.add(self) async def update_options(self, language: DeepgramLanguages | str | None) -> None: - """Update the options and trigger reconnection.""" async with self._ws_lock: self._opts.language = language or self._opts.language self._reconnect_event.set() From 72267c679f3bbcf4fd430398110184df0c2ea083 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 26 Nov 2024 19:40:03 +0530 Subject: [PATCH 09/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 119fb5fbf..73fe54afb 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -357,7 +357,11 @@ async def _run(self) -> None: self._ws = None continue - + if any(task in done for task in [send_task, recv_task]): + for task in tasks: + if task.done() and task.exception(): + raise task.exception() + break except Exception: logger.exception("Error in SpeechStream _run method") # Decide whether to retry or break based on the exception type From 4deec8bf66570a23d1a902f3fce0b274214c2a69 Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 26 Nov 2024 20:18:38 +0530 Subject: [PATCH 10/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 73fe54afb..d324813e8 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -350,6 +350,7 @@ async def _run(self) -> None: logger.info("reconnecting with updated options...") for task in tasks: task.cancel() + reconnect_wait_task.cancel() await asyncio.gather(*tasks, return_exceptions=True) if self._ws and not self._ws.closed: @@ -357,15 +358,17 @@ async def _run(self) -> None: self._ws = None continue - if any(task in done for task in [send_task, recv_task]): + if recv_task in done: for task in tasks: if task.done() and task.exception(): raise task.exception() - break + break # Exit the loop if the recv_task has completed + + reconnect_wait_task.cancel() + await asyncio.gather(reconnect_wait_task, return_exceptions=True) except Exception: logger.exception("Error in SpeechStream _run method") - # Decide whether to retry or break based on the exception type - break # For now, we break the loop on exceptions + break await self._cleanup() From 3ec13d37f9d9d67ca29c9e565066ad7b2fef605d Mon Sep 17 00:00:00 2001 From: jayesh Date: Tue, 26 Nov 2024 20:18:49 +0530 Subject: [PATCH 11/31] deepgram stt wip --- .../livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index d324813e8..b8e69bfb0 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ 
b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -362,7 +362,7 @@ async def _run(self) -> None: for task in tasks: if task.done() and task.exception(): raise task.exception() - break # Exit the loop if the recv_task has completed + break reconnect_wait_task.cancel() await asyncio.gather(reconnect_wait_task, return_exceptions=True) From fe3e1fdbae9115135092d57f97bd992986f8fd29 Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 11:25:17 +0530 Subject: [PATCH 12/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index b8e69bfb0..c163aeb3a 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -362,12 +362,24 @@ async def _run(self) -> None: for task in tasks: if task.done() and task.exception(): raise task.exception() + if not task.done(): + task.cancel() + reconnect_wait_task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) break - + if send_task in done: + if send_task.exception(): + raise send_task.exception() + continue reconnect_wait_task.cancel() await asyncio.gather(reconnect_wait_task, return_exceptions=True) except Exception: logger.exception("Error in SpeechStream _run method") + for task in tasks + [reconnect_wait_task]: + task.cancel() + await asyncio.gather( + *tasks + [reconnect_wait_task], return_exceptions=True + ) break await self._cleanup() @@ -516,7 +528,8 @@ async def _recv_task(self): except Exception: logger.exception("Error in recv_task") finally: - await ws.close() + if not ws.closed: + await ws.close() async def _keepalive_task(self): """Task for sending keepalive messages.""" @@ -530,7 +543,11 @@ async def _keepalive_task(self): while True: await ws.send_str(SpeechStream._KEEPALIVE_MSG) await asyncio.sleep(5) - except asyncio.CancelledError: + except ( + asyncio.CancelledError, + aiohttp.ClientConnectionError, + aiohttp.ClientConnectionError, + ): # Task was cancelled due to reconnection or closure pass except Exception: @@ -619,9 +636,10 @@ def _process_stream_event(self, data: dict) -> None: async def aclose(self) -> None: """Close the stream and clean up resources.""" self._closed = True - self._reconnect_event.set() # Trigger any waiting loops to exit + self._reconnect_event.set() self._stt.remove_stream(self) await super().aclose() + await self._task await self._cleanup() async def _cleanup(self): From 64df6e018f0e44bc1383e18d705b2d1599c5e441 Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 11:51:35 +0530 Subject: [PATCH 13/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index ccb35b5cf..7a5b0f941 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -335,6 +335,8 @@ async def update_options(self, language: DeepgramLanguages | str | None) -> None logger.info("options updated, reconnection requested.") async def _run(self) -> None: + tasks = [] + reconnect_wait_task = None while not 
self._closed: try: await self._connect_ws() @@ -379,11 +381,13 @@ async def _run(self) -> None: await asyncio.gather(reconnect_wait_task, return_exceptions=True) except Exception: logger.exception("Error in SpeechStream _run method") - for task in tasks + [reconnect_wait_task]: - task.cancel() - await asyncio.gather( - *tasks + [reconnect_wait_task], return_exceptions=True - ) + if tasks: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + if reconnect_wait_task: + reconnect_wait_task.cancel() + await asyncio.gather(reconnect_wait_task, return_exceptions=True) break await self._cleanup() @@ -418,7 +422,7 @@ async def _connect_ws(self): try: self._ws = await asyncio.wait_for( self._session.ws_connect( - _to_deepgram_url(live_config, websocket=True), + _to_deepgram_url(live_config, self._base_url, websocket=True), headers={"Authorization": f"Token {self._api_key}"}, ), self._conn_options.timeout, @@ -489,7 +493,10 @@ async def _send_task(self): closing_ws = True await ws.send_str(SpeechStream._CLOSE_MSG) - except asyncio.CancelledError: + except ( + asyncio.CancelledError, + aiohttp.ClientConnectionError, + ): # Task was cancelled due to reconnection or closure pass except Exception: @@ -550,7 +557,6 @@ async def _keepalive_task(self): except ( asyncio.CancelledError, aiohttp.ClientConnectionError, - aiohttp.ClientConnectionError, ): # Task was cancelled due to reconnection or closure pass From 4862ecfea4b07f0e6f97c7d92091100a685f668a Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 12:06:06 +0530 Subject: [PATCH 14/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 7a5b0f941..a49a8c06f 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -366,16 +366,25 @@ async def _run(self) -> None: continue if recv_task in done: for task in tasks: - if task.done() and task.exception(): - raise task.exception() + if task.done(): + exc = task.exception() + if exc: + if isinstance(exc, Exception): + raise exc + else: + raise Exception(f"Task raised exception: {exc}") if not task.done(): task.cancel() reconnect_wait_task.cancel() await asyncio.gather(*tasks, return_exceptions=True) break if send_task in done: - if send_task.exception(): - raise send_task.exception() + exc = task.exception() + if exc: + if isinstance(exc, Exception): + raise exc + else: + raise Exception(f"Task raised exception: {exc}") continue reconnect_wait_task.cancel() await asyncio.gather(reconnect_wait_task, return_exceptions=True) From aacb8fc80077b177ce9691469df968fffeb1c64d Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 12:29:32 +0530 Subject: [PATCH 15/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 99 +++++++++---------- 1 file changed, 47 insertions(+), 52 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index a49a8c06f..29ec4efe0 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -181,7 +181,7 @@ def __init__( energy_filter=energy_filter, ) 
self._session = http_session - self._active_streams: set[SpeechStream] = set() + self._active_stream: SpeechStream = None def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: @@ -256,18 +256,16 @@ def stream( http_session=self._ensure_session(), base_url=self._base_url, ) - self._active_streams.add(stream) - return stream + self._active_stream = stream + return self._active_stream def remove_stream(self, stream: "SpeechStream") -> None: - """Remove a SpeechStream from the active streams set.""" - self._active_streams.discard(stream) + self._active_stream = None def update_options(self, language: DeepgramLanguages | str | None) -> None: """Update the STT options and propagate changes to active streams.""" self._opts.language = language or self._opts.language - for stream in self._active_streams: - asyncio.create_task(stream.update_options(language=language)) + asyncio.create_task(self._active_stream.update_options(language=language)) def _sanitize_options(self, *, language: str | None = None) -> STTOptions: config = dataclasses.replace(self._opts) @@ -324,15 +322,12 @@ def __init__( self._reconnect_event = asyncio.Event() self._closed = False self._ws: Optional[aiohttp.ClientWebSocketResponse] = None - self._ws_lock = asyncio.Lock() self._stt: STT = stt - self._stt._active_streams.add(self) async def update_options(self, language: DeepgramLanguages | str | None) -> None: - async with self._ws_lock: - self._opts.language = language or self._opts.language - self._reconnect_event.set() - logger.info("options updated, reconnection requested.") + self._opts.language = language or self._opts.language + self._reconnect_event.set() + logger.info("options updated, reconnection requested.") async def _run(self) -> None: tasks = [] @@ -362,8 +357,8 @@ async def _run(self) -> None: if self._ws and not self._ws.closed: await self._ws.close() self._ws = None - continue + if recv_task in done: for task in tasks: if task.done(): @@ -378,8 +373,9 @@ async def _run(self) -> None: reconnect_wait_task.cancel() await asyncio.gather(*tasks, return_exceptions=True) break + if send_task in done: - exc = task.exception() + exc = send_task.exception() if exc: if isinstance(exc, Exception): raise exc @@ -388,6 +384,7 @@ async def _run(self) -> None: continue reconnect_wait_task.cancel() await asyncio.gather(reconnect_wait_task, return_exceptions=True) + except Exception: logger.exception("Error in SpeechStream _run method") if tasks: @@ -403,43 +400,42 @@ async def _run(self) -> None: async def _connect_ws(self): """Establish the websocket connection using the current options.""" - async with self._ws_lock: - if self._ws and not self._ws.closed: - await self._ws.close() - - live_config = { - "model": self._opts.model, - "punctuate": self._opts.punctuate, - "smart_format": self._opts.smart_format, - "no_delay": self._opts.no_delay, - "interim_results": self._opts.interim_results, - "encoding": "linear16", - "vad_events": True, - "sample_rate": self._opts.sample_rate, - "channels": self._opts.num_channels, - "endpointing": False - if self._opts.endpointing_ms == 0 - else self._opts.endpointing_ms, - "filler_words": self._opts.filler_words, - "keywords": self._opts.keywords, - "profanity_filter": self._opts.profanity_filter, - } - - if self._opts.language: - live_config["language"] = self._opts.language + if self._ws and not self._ws.closed: + await self._ws.close() - try: - self._ws = await asyncio.wait_for( - self._session.ws_connect( - _to_deepgram_url(live_config, self._base_url, 
websocket=True), - headers={"Authorization": f"Token {self._api_key}"}, - ), - self._conn_options.timeout, - ) - logger.info("WebSocket connection established.") - except Exception as e: - logger.exception("Failed to establish WebSocket connection.") - raise APIConnectionError() from e + live_config = { + "model": self._opts.model, + "punctuate": self._opts.punctuate, + "smart_format": self._opts.smart_format, + "no_delay": self._opts.no_delay, + "interim_results": self._opts.interim_results, + "encoding": "linear16", + "vad_events": True, + "sample_rate": self._opts.sample_rate, + "channels": self._opts.num_channels, + "endpointing": False + if self._opts.endpointing_ms == 0 + else self._opts.endpointing_ms, + "filler_words": self._opts.filler_words, + "keywords": self._opts.keywords, + "profanity_filter": self._opts.profanity_filter, + } + + if self._opts.language: + live_config["language"] = self._opts.language + + try: + self._ws = await asyncio.wait_for( + self._session.ws_connect( + _to_deepgram_url(live_config, self._base_url, websocket=True), + headers={"Authorization": f"Token {self._api_key}"}, + ), + self._conn_options.timeout, + ) + logger.info("WebSocket connection established.") + except Exception as e: + logger.exception("Failed to establish WebSocket connection.") + raise APIConnectionError() from e async def _send_task(self): """Task for sending audio data to the websocket.""" @@ -658,7 +654,6 @@ async def aclose(self) -> None: self._reconnect_event.set() self._stt.remove_stream(self) await super().aclose() - await self._task await self._cleanup() async def _cleanup(self): From 6618d13665ac952d9752ae8825957b7f5bfedc59 Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 13:28:18 +0530 Subject: [PATCH 16/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 376 +++++++----------- 1 file changed, 139 insertions(+), 237 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 29ec4efe0..702fbae77 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -181,7 +181,7 @@ def __init__( energy_filter=energy_filter, ) self._session = http_session - self._active_stream: SpeechStream = None + self._active_speech_stream: Optional[SpeechStream] = None def _ensure_session(self) -> aiohttp.ClientSession: if not self._session: @@ -247,8 +247,7 @@ def stream( conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> "SpeechStream": config = self._sanitize_options(language=language) - - stream = SpeechStream( + self._active_speech_stream = SpeechStream( stt=self, conn_options=conn_options, opts=config, @@ -256,16 +255,16 @@ def stream( http_session=self._ensure_session(), base_url=self._base_url, ) - self._active_stream = stream - return self._active_stream - - def remove_stream(self, stream: "SpeechStream") -> None: - self._active_stream = None + return self._active_speech_stream - def update_options(self, language: DeepgramLanguages | str | None) -> None: - """Update the STT options and propagate changes to active streams.""" - self._opts.language = language or self._opts.language - asyncio.create_task(self._active_stream.update_options(language=language)) + async def update_options(self, **kwargs): + for key, value in kwargs.items(): + if hasattr(self._opts, key): + setattr(self._opts, key, value) + else: + raise 
AttributeError(f"Invalid option: {key}") + if self._active_speech_stream is not None: + await self._active_speech_stream.update_options(self._opts) def _sanitize_options(self, *, language: str | None = None) -> STTOptions: config = dataclasses.replace(self._opts) @@ -319,90 +318,142 @@ def __init__( self._pushed_audio_duration = 0.0 self._request_id = "" - self._reconnect_event = asyncio.Event() + self._reconnect_needed = False + self._options_update_event = asyncio.Event() self._closed = False - self._ws: Optional[aiohttp.ClientWebSocketResponse] = None - self._stt: STT = stt - async def update_options(self, language: DeepgramLanguages | str | None) -> None: - self._opts.language = language or self._opts.language - self._reconnect_event.set() - logger.info("options updated, reconnection requested.") + async def update_options(self, opts: STTOptions): + self._opts = opts + self._reconnect_needed = True + self._options_update_event.set() async def _run(self) -> None: - tasks = [] - reconnect_wait_task = None while not self._closed: + closing_ws = False + + async def keepalive_task(ws: aiohttp.ClientWebSocketResponse): + # if we want to keep the connection alive even if no audio is sent, + # Deepgram expects a keepalive message. + # https://developers.deepgram.com/reference/listen-live#stream-keepalive + try: + while True: + await ws.send_str(SpeechStream._KEEPALIVE_MSG) + await asyncio.sleep(5) + except Exception: + return + + async def send_task(ws: aiohttp.ClientWebSocketResponse): + nonlocal closing_ws + + # forward audio to deepgram in chunks of 50ms + samples_50ms = self._opts.sample_rate // 20 + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=self._opts.num_channels, + samples_per_channel=samples_50ms, + ) + + has_ended = False + last_frame: Optional[rtc.AudioFrame] = None + async for data in self._input_ch: + frames: list[rtc.AudioFrame] = [] + if isinstance(data, rtc.AudioFrame): + state = self._check_energy_state(data) + if state in ( + AudioEnergyFilter.State.START, + AudioEnergyFilter.State.SPEAKING, + ): + if last_frame: + frames.extend( + audio_bstream.write(last_frame.data.tobytes()) + ) + last_frame = None + frames.extend(audio_bstream.write(data.data.tobytes())) + elif state == AudioEnergyFilter.State.END: + # no need to buffer as we have cooldown period + frames = audio_bstream.flush() + has_ended = True + elif state == AudioEnergyFilter.State.SILENCE: + # buffer the last silence frame, since it could contain beginning of speech + # TODO: improve accuracy by using a ring buffer with longer window + last_frame = data + elif isinstance(data, self._FlushSentinel): + frames = audio_bstream.flush() + has_ended = True + + for frame in frames: + self._audio_duration_collector.push(frame.duration) + await ws.send_bytes(frame.data.tobytes()) + + if has_ended: + self._audio_duration_collector.flush() + await ws.send_str(SpeechStream._FINALIZE_MSG) + has_ended = False + + # tell deepgram we are done sending audio/inputs + closing_ws = True + await ws.send_str(SpeechStream._CLOSE_MSG) + + async def recv_task(ws: aiohttp.ClientWebSocketResponse): + nonlocal closing_ws + while True: + msg = await ws.receive() + if msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + if closing_ws: # close is expected, see SpeechStream.aclose + return + + # this will trigger a reconnection, see the _run loop + raise Exception("deepgram connection closed unexpectedly") + + if msg.type != 
aiohttp.WSMsgType.TEXT: + logger.warning("unexpected deepgram message type %s", msg.type) + continue + + try: + self._process_stream_event(json.loads(msg.data)) + except Exception: + logger.exception("failed to process deepgram message") + + async def wait_for_reconnect(): + await self._options_update_event.wait() + + ws: aiohttp.ClientWebSocketResponse | None = None + try: - await self._connect_ws() - send_task = asyncio.create_task(self._send_task()) - recv_task = asyncio.create_task(self._recv_task()) - keepalive_task = asyncio.create_task(self._keepalive_task()) - reconnect_wait_task = asyncio.create_task(self._reconnect_event.wait()) + ws = await self._connect_ws() + + tasks = [ + asyncio.create_task(send_task(ws)), + asyncio.create_task(recv_task(ws)), + asyncio.create_task(keepalive_task(ws)), + asyncio.create_task(wait_for_reconnect()), + ] - tasks = [send_task, recv_task, keepalive_task] done, pending = await asyncio.wait( - [reconnect_wait_task] + tasks, + tasks, return_when=asyncio.FIRST_COMPLETED, ) - if reconnect_wait_task.done(): - self._reconnect_event.clear() - logger.info("reconnecting with updated options...") - for task in tasks: - task.cancel() - reconnect_wait_task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - - if self._ws and not self._ws.closed: - await self._ws.close() - self._ws = None + if self._reconnect_needed: + self._reconnect_needed = False + self._options_update_event.clear() + await utils.aio.gracefully_cancel(*pending) + if ws is not None: + await ws.close() continue - - if recv_task in done: - for task in tasks: - if task.done(): - exc = task.exception() - if exc: - if isinstance(exc, Exception): - raise exc - else: - raise Exception(f"Task raised exception: {exc}") - if not task.done(): - task.cancel() - reconnect_wait_task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + else: + self._closed = True break - if send_task in done: - exc = send_task.exception() - if exc: - if isinstance(exc, Exception): - raise exc - else: - raise Exception(f"Task raised exception: {exc}") - continue - reconnect_wait_task.cancel() - await asyncio.gather(reconnect_wait_task, return_exceptions=True) - - except Exception: - logger.exception("Error in SpeechStream _run method") - if tasks: - for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - if reconnect_wait_task: - reconnect_wait_task.cancel() - await asyncio.gather(reconnect_wait_task, return_exceptions=True) - break - - await self._cleanup() - - async def _connect_ws(self): - """Establish the websocket connection using the current options.""" - if self._ws and not self._ws.closed: - await self._ws.close() + finally: + if ws is not None: + await ws.close() + async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: live_config = { "model": self._opts.model, "punctuate": self._opts.punctuate, @@ -424,149 +475,14 @@ async def _connect_ws(self): if self._opts.language: live_config["language"] = self._opts.language - try: - self._ws = await asyncio.wait_for( - self._session.ws_connect( - _to_deepgram_url(live_config, self._base_url, websocket=True), - headers={"Authorization": f"Token {self._api_key}"}, - ), - self._conn_options.timeout, - ) - logger.info("WebSocket connection established.") - except Exception as e: - logger.exception("Failed to establish WebSocket connection.") - raise APIConnectionError() from e - - async def _send_task(self): - """Task for sending audio data to the websocket.""" - # Ensure the websocket is connected - if not 
self._ws or self._ws.closed: - logger.error("WebSocket is not connected in send_task.") - return - - ws = self._ws - closing_ws = False - - # Forward audio to deepgram in chunks of 50ms - samples_50ms = self._opts.sample_rate // 20 - audio_bstream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, - num_channels=self._opts.num_channels, - samples_per_channel=samples_50ms, + ws = await asyncio.wait_for( + self._session.ws_connect( + _to_deepgram_url(live_config, base_url=self._base_url, websocket=True), + headers={"Authorization": f"Token {self._api_key}"}, + ), + self._conn_options.timeout, ) - - has_ended = False - last_frame: Optional[rtc.AudioFrame] = None - - try: - async for data in self._input_ch: - frames: list[rtc.AudioFrame] = [] - if isinstance(data, rtc.AudioFrame): - state = self._check_energy_state(data) - if state in ( - AudioEnergyFilter.State.START, - AudioEnergyFilter.State.SPEAKING, - ): - if last_frame: - frames.extend( - audio_bstream.write(last_frame.data.tobytes()) - ) - last_frame = None - frames.extend(audio_bstream.write(data.data.tobytes())) - elif state == AudioEnergyFilter.State.END: - # no need to buffer as we have cooldown period - frames = audio_bstream.flush() - has_ended = True - elif state == AudioEnergyFilter.State.SILENCE: - # buffer the last silence frame, since it could contain beginning of speech - # TODO: improve accuracy by using a ring buffer with longer window - last_frame = data - elif isinstance(data, self._FlushSentinel): - frames = audio_bstream.flush() - has_ended = True - - for frame in frames: - self._audio_duration_collector.push(frame.duration) - await ws.send_bytes(frame.data.tobytes()) - - if has_ended: - self._audio_duration_collector.flush() - await ws.send_str(SpeechStream._FINALIZE_MSG) - has_ended = False - - # tell deepgram we are done sending audio/inputs - closing_ws = True - await ws.send_str(SpeechStream._CLOSE_MSG) - - except ( - asyncio.CancelledError, - aiohttp.ClientConnectionError, - ): - # Task was cancelled due to reconnection or closure - pass - except Exception: - logger.exception("Error in send_task") - finally: - if closing_ws: - await ws.close() - - async def _recv_task(self): - """Task for receiving data from the websocket.""" - # Ensure the websocket is connected - if not self._ws or self._ws.closed: - logger.error("WebSocket is not connected in recv_task.") - return - - ws = self._ws - try: - while True: - msg = await ws.receive() - if msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - ): - logger.info("WebSocket connection closed.") - break - - if msg.type != aiohttp.WSMsgType.TEXT: - logger.warning("unexpected deepgram message type %s", msg.type) - continue - - try: - self._process_stream_event(json.loads(msg.data)) - except Exception: - logger.exception("Failed to process message from Deepgram") - - except asyncio.CancelledError: - # Task was cancelled due to reconnection or closure - pass - except Exception: - logger.exception("Error in recv_task") - finally: - if not ws.closed: - await ws.close() - - async def _keepalive_task(self): - """Task for sending keepalive messages.""" - # Ensure the websocket is connected - if not self._ws or self._ws.closed: - logger.error("WebSocket is not connected in keepalive_task.") - return - - ws = self._ws - try: - while True: - await ws.send_str(SpeechStream._KEEPALIVE_MSG) - await asyncio.sleep(5) - except ( - asyncio.CancelledError, - aiohttp.ClientConnectionError, - ): - # Task was cancelled due to 
reconnection or closure - pass - except Exception: - logger.exception("Error in keepalive_task") + return ws def _check_energy_state(self, frame: rtc.AudioFrame) -> AudioEnergyFilter.State: if self._audio_energy_filter: @@ -648,20 +564,6 @@ def _process_stream_event(self, data: dict) -> None: else: logger.warning("received unexpected message from deepgram %s", data) - async def aclose(self) -> None: - """Close the stream and clean up resources.""" - self._closed = True - self._reconnect_event.set() - self._stt.remove_stream(self) - await super().aclose() - await self._cleanup() - - async def _cleanup(self): - """Cleanup resources.""" - if self._ws and not self._ws.closed: - await self._ws.close() - self._ws = None - def live_transcription_to_speech_data( language: str, data: dict From e2de5f5c5926a08bbf78f4ded0242fb89d35425d Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 14:27:44 +0530 Subject: [PATCH 17/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 186 +++++++++--------- 1 file changed, 90 insertions(+), 96 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 702fbae77..1fc58748c 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -329,108 +329,17 @@ async def update_options(self, opts: STTOptions): async def _run(self) -> None: while not self._closed: - closing_ws = False - - async def keepalive_task(ws: aiohttp.ClientWebSocketResponse): - # if we want to keep the connection alive even if no audio is sent, - # Deepgram expects a keepalive message. - # https://developers.deepgram.com/reference/listen-live#stream-keepalive - try: - while True: - await ws.send_str(SpeechStream._KEEPALIVE_MSG) - await asyncio.sleep(5) - except Exception: - return - - async def send_task(ws: aiohttp.ClientWebSocketResponse): - nonlocal closing_ws - - # forward audio to deepgram in chunks of 50ms - samples_50ms = self._opts.sample_rate // 20 - audio_bstream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, - num_channels=self._opts.num_channels, - samples_per_channel=samples_50ms, - ) - - has_ended = False - last_frame: Optional[rtc.AudioFrame] = None - async for data in self._input_ch: - frames: list[rtc.AudioFrame] = [] - if isinstance(data, rtc.AudioFrame): - state = self._check_energy_state(data) - if state in ( - AudioEnergyFilter.State.START, - AudioEnergyFilter.State.SPEAKING, - ): - if last_frame: - frames.extend( - audio_bstream.write(last_frame.data.tobytes()) - ) - last_frame = None - frames.extend(audio_bstream.write(data.data.tobytes())) - elif state == AudioEnergyFilter.State.END: - # no need to buffer as we have cooldown period - frames = audio_bstream.flush() - has_ended = True - elif state == AudioEnergyFilter.State.SILENCE: - # buffer the last silence frame, since it could contain beginning of speech - # TODO: improve accuracy by using a ring buffer with longer window - last_frame = data - elif isinstance(data, self._FlushSentinel): - frames = audio_bstream.flush() - has_ended = True - - for frame in frames: - self._audio_duration_collector.push(frame.duration) - await ws.send_bytes(frame.data.tobytes()) - - if has_ended: - self._audio_duration_collector.flush() - await ws.send_str(SpeechStream._FINALIZE_MSG) - has_ended = False - - # tell deepgram we are done sending audio/inputs - closing_ws = True - 
await ws.send_str(SpeechStream._CLOSE_MSG) - - async def recv_task(ws: aiohttp.ClientWebSocketResponse): - nonlocal closing_ws - while True: - msg = await ws.receive() - if msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - ): - if closing_ws: # close is expected, see SpeechStream.aclose - return - - # this will trigger a reconnection, see the _run loop - raise Exception("deepgram connection closed unexpectedly") - - if msg.type != aiohttp.WSMsgType.TEXT: - logger.warning("unexpected deepgram message type %s", msg.type) - continue - - try: - self._process_stream_event(json.loads(msg.data)) - except Exception: - logger.exception("failed to process deepgram message") - - async def wait_for_reconnect(): - await self._options_update_event.wait() - + self._closing_ws = False ws: aiohttp.ClientWebSocketResponse | None = None try: ws = await self._connect_ws() tasks = [ - asyncio.create_task(send_task(ws)), - asyncio.create_task(recv_task(ws)), - asyncio.create_task(keepalive_task(ws)), - asyncio.create_task(wait_for_reconnect()), + asyncio.create_task(self._send_task(ws)), + asyncio.create_task(self._recv_task(ws)), + asyncio.create_task(self._keepalive_task(ws)), + asyncio.create_task(self._wait_for_reconnect()), ] done, pending = await asyncio.wait( @@ -453,6 +362,91 @@ async def wait_for_reconnect(): if ws is not None: await ws.close() + async def _keepalive_task(self, ws: aiohttp.ClientWebSocketResponse): + # if we want to keep the connection alive even if no audio is sent, + # Deepgram expects a keepalive message. + # https://developers.deepgram.com/reference/listen-live#stream-keepalive + try: + while True: + await ws.send_str(SpeechStream._KEEPALIVE_MSG) + await asyncio.sleep(5) + except Exception: + return + + async def _send_task(self, ws: aiohttp.ClientWebSocketResponse): + # forward audio to deepgram in chunks of 50ms + samples_50ms = self._opts.sample_rate // 20 + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=self._opts.num_channels, + samples_per_channel=samples_50ms, + ) + + has_ended = False + last_frame: Optional[rtc.AudioFrame] = None + async for data in self._input_ch: + frames: list[rtc.AudioFrame] = [] + if isinstance(data, rtc.AudioFrame): + state = self._check_energy_state(data) + if state in ( + AudioEnergyFilter.State.START, + AudioEnergyFilter.State.SPEAKING, + ): + if last_frame: + frames.extend(audio_bstream.write(last_frame.data.tobytes())) + last_frame = None + frames.extend(audio_bstream.write(data.data.tobytes())) + elif state == AudioEnergyFilter.State.END: + # no need to buffer as we have cooldown period + frames = audio_bstream.flush() + has_ended = True + elif state == AudioEnergyFilter.State.SILENCE: + # buffer the last silence frame, since it could contain beginning of speech + # TODO: improve accuracy by using a ring buffer with longer window + last_frame = data + elif isinstance(data, self._FlushSentinel): + frames = audio_bstream.flush() + has_ended = True + + for frame in frames: + self._audio_duration_collector.push(frame.duration) + await ws.send_bytes(frame.data.tobytes()) + + if has_ended: + self._audio_duration_collector.flush() + await ws.send_str(SpeechStream._FINALIZE_MSG) + has_ended = False + + # tell deepgram we are done sending audio/inputs + self._self._closing_ws = True + await ws.send_str(SpeechStream._CLOSE_MSG) + + async def _recv_task(self, ws: aiohttp.ClientWebSocketResponse): + while True: + msg = await ws.receive() + if msg.type in ( + 
aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + if self._closing_ws: # close is expected, see SpeechStream.aclose + return + + # this will trigger a reconnection, see the _run loop + raise Exception("deepgram connection closed unexpectedly") + + if msg.type != aiohttp.WSMsgType.TEXT: + logger.warning("unexpected deepgram message type %s", msg.type) + continue + + try: + self._process_stream_event(json.loads(msg.data)) + except Exception: + logger.exception("failed to process deepgram message") + + async def _wait_for_reconnect(self): + await self._options_update_event.wait() + async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: live_config = { "model": self._opts.model, From 12597aefe48be61e22e213872072ac8eb1864aab Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 14:32:17 +0530 Subject: [PATCH 18/31] deepgram stt wip --- .../livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 1fc58748c..ee016dfbb 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -418,7 +418,7 @@ async def _send_task(self, ws: aiohttp.ClientWebSocketResponse): has_ended = False # tell deepgram we are done sending audio/inputs - self._self._closing_ws = True + self._closing_ws = True await ws.send_str(SpeechStream._CLOSE_MSG) async def _recv_task(self, ws: aiohttp.ClientWebSocketResponse): From 1f0922b380292cac6957d1a885186afecbe70ccc Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 16:14:44 +0530 Subject: [PATCH 19/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 56 +++++++++---------- 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index ee016dfbb..bb33fed96 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -276,6 +276,10 @@ def _sanitize_options(self, *, language: str | None = None) -> STTOptions: return config +class ReconnectRequired(Exception): + pass + + class SpeechStream(stt.SpeechStream): _KEEPALIVE_MSG: str = json.dumps({"type": "KeepAlive"}) _CLOSE_MSG: str = json.dumps({"type": "CloseStream"}) @@ -328,39 +332,30 @@ async def update_options(self, opts: STTOptions): self._options_update_event.set() async def _run(self) -> None: - while not self._closed: - self._closing_ws = False - ws: aiohttp.ClientWebSocketResponse | None = None + self._closing_ws = False + ws: aiohttp.ClientWebSocketResponse | None = None - try: - ws = await self._connect_ws() - - tasks = [ - asyncio.create_task(self._send_task(ws)), - asyncio.create_task(self._recv_task(ws)), - asyncio.create_task(self._keepalive_task(ws)), - asyncio.create_task(self._wait_for_reconnect()), - ] - - done, pending = await asyncio.wait( - tasks, - return_when=asyncio.FIRST_COMPLETED, - ) + try: + ws = await self._connect_ws() - if self._reconnect_needed: - self._reconnect_needed = False - self._options_update_event.clear() - await utils.aio.gracefully_cancel(*pending) - if ws is not None: - await ws.close() - continue - else: - self._closed = True - break + 
tasks = [ + asyncio.create_task(self._send_task(ws)), + asyncio.create_task(self._recv_task(ws)), + asyncio.create_task(self._keepalive_task(ws)), + ] + try: + await asyncio.gather(*tasks) + # except ReconnectRequired: + # logger.info("Reconnection requested.") + # await utils.aio.gracefully_cancel(*tasks) + # await ws.close() finally: - if ws is not None: - await ws.close() + await utils.aio.gracefully_cancel(*tasks) + + finally: + if ws is not None: + await ws.close() async def _keepalive_task(self, ws: aiohttp.ClientWebSocketResponse): # if we want to keep the connection alive even if no audio is sent, @@ -446,6 +441,9 @@ async def _recv_task(self, ws: aiohttp.ClientWebSocketResponse): async def _wait_for_reconnect(self): await self._options_update_event.wait() + if not self._closed: + self._options_update_event.clear() + raise ReconnectRequired() async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: live_config = { From 549e301d44e2892458a04a919497428657f45cc3 Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 16:43:10 +0530 Subject: [PATCH 20/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 57 +++++++------------ 1 file changed, 20 insertions(+), 37 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index bb33fed96..9359ba64e 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -276,10 +276,6 @@ def _sanitize_options(self, *, language: str | None = None) -> STTOptions: return config -class ReconnectRequired(Exception): - pass - - class SpeechStream(stt.SpeechStream): _KEEPALIVE_MSG: str = json.dumps({"type": "KeepAlive"}) _CLOSE_MSG: str = json.dumps({"type": "CloseStream"}) @@ -322,53 +318,44 @@ def __init__( self._pushed_audio_duration = 0.0 self._request_id = "" - self._reconnect_needed = False - self._options_update_event = asyncio.Event() - self._closed = False + self._ws: aiohttp.ClientWebSocketResponse | None = None async def update_options(self, opts: STTOptions): self._opts = opts - self._reconnect_needed = True - self._options_update_event.set() + await self._connect_ws() async def _run(self) -> None: self._closing_ws = False - ws: aiohttp.ClientWebSocketResponse | None = None - try: - ws = await self._connect_ws() + await self._connect_ws() tasks = [ - asyncio.create_task(self._send_task(ws)), - asyncio.create_task(self._recv_task(ws)), - asyncio.create_task(self._keepalive_task(ws)), + asyncio.create_task(self._send_task()), + asyncio.create_task(self._recv_task()), + asyncio.create_task(self._keepalive_task()), ] try: await asyncio.gather(*tasks) - # except ReconnectRequired: - # logger.info("Reconnection requested.") - # await utils.aio.gracefully_cancel(*tasks) - # await ws.close() finally: await utils.aio.gracefully_cancel(*tasks) finally: - if ws is not None: - await ws.close() + if self._ws is not None: + await self._ws.close() - async def _keepalive_task(self, ws: aiohttp.ClientWebSocketResponse): + async def _keepalive_task(self): # if we want to keep the connection alive even if no audio is sent, # Deepgram expects a keepalive message. 
# https://developers.deepgram.com/reference/listen-live#stream-keepalive try: while True: - await ws.send_str(SpeechStream._KEEPALIVE_MSG) + await self._ws.send_str(SpeechStream._KEEPALIVE_MSG) await asyncio.sleep(5) except Exception: return - async def _send_task(self, ws: aiohttp.ClientWebSocketResponse): + async def _send_task(self): # forward audio to deepgram in chunks of 50ms samples_50ms = self._opts.sample_rate // 20 audio_bstream = utils.audio.AudioByteStream( @@ -405,20 +392,20 @@ async def _send_task(self, ws: aiohttp.ClientWebSocketResponse): for frame in frames: self._audio_duration_collector.push(frame.duration) - await ws.send_bytes(frame.data.tobytes()) + await self._ws.send_bytes(frame.data.tobytes()) if has_ended: self._audio_duration_collector.flush() - await ws.send_str(SpeechStream._FINALIZE_MSG) + await self._ws.send_str(SpeechStream._FINALIZE_MSG) has_ended = False # tell deepgram we are done sending audio/inputs self._closing_ws = True - await ws.send_str(SpeechStream._CLOSE_MSG) + await self._ws.send_str(SpeechStream._CLOSE_MSG) - async def _recv_task(self, ws: aiohttp.ClientWebSocketResponse): + async def _recv_task(self): while True: - msg = await ws.receive() + msg = await self._ws.receive() if msg.type in ( aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, @@ -439,13 +426,10 @@ async def _recv_task(self, ws: aiohttp.ClientWebSocketResponse): except Exception: logger.exception("failed to process deepgram message") - async def _wait_for_reconnect(self): - await self._options_update_event.wait() - if not self._closed: - self._options_update_event.clear() - raise ReconnectRequired() - async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: + if self._ws is not None: + await self._ws.close() + self._ws = None live_config = { "model": self._opts.model, "punctuate": self._opts.punctuate, @@ -467,14 +451,13 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: if self._opts.language: live_config["language"] = self._opts.language - ws = await asyncio.wait_for( + self._ws = await asyncio.wait_for( self._session.ws_connect( _to_deepgram_url(live_config, base_url=self._base_url, websocket=True), headers={"Authorization": f"Token {self._api_key}"}, ), self._conn_options.timeout, ) - return ws def _check_energy_state(self, frame: rtc.AudioFrame) -> AudioEnergyFilter.State: if self._audio_energy_filter: From 7e02271c89846eac31dd2f32789678ad1ad03a84 Mon Sep 17 00:00:00 2001 From: jayesh Date: Wed, 27 Nov 2024 17:18:40 +0530 Subject: [PATCH 21/31] deepgram stt wip --- .../livekit/plugins/deepgram/stt.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 9359ba64e..df6b289e2 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -257,14 +257,9 @@ def stream( ) return self._active_speech_stream - async def update_options(self, **kwargs): - for key, value in kwargs.items(): - if hasattr(self._opts, key): - setattr(self._opts, key, value) - else: - raise AttributeError(f"Invalid option: {key}") - if self._active_speech_stream is not None: - await self._active_speech_stream.update_options(self._opts) + async def update_options(self, language: str | None = None): + if self._active_speech_stream is not None and language is not None: + await 
self._active_speech_stream.update_options(language) def _sanitize_options(self, *, language: str | None = None) -> STTOptions: config = dataclasses.replace(self._opts) @@ -320,8 +315,8 @@ def __init__( self._ws: aiohttp.ClientWebSocketResponse | None = None - async def update_options(self, opts: STTOptions): - self._opts = opts + async def update_options(self, language: str | None = None): + self._opts.language = language or self._opts.language await self._connect_ws() async def _run(self) -> None: @@ -426,10 +421,7 @@ async def _recv_task(self): except Exception: logger.exception("failed to process deepgram message") - async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: - if self._ws is not None: - await self._ws.close() - self._ws = None + async def _connect_ws(self): live_config = { "model": self._opts.model, "punctuate": self._opts.punctuate, From bd9cf1f05bf1961947767d0e48485a7a22604731 Mon Sep 17 00:00:00 2001 From: jayesh Date: Thu, 28 Nov 2024 17:25:33 +0530 Subject: [PATCH 22/31] wip --- .../livekit/plugins/deepgram/stt.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index df6b289e2..54f7a1f42 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -312,12 +312,12 @@ def __init__( self._pushed_audio_duration = 0.0 self._request_id = "" - + self._reconnect_event = asyncio.Event() self._ws: aiohttp.ClientWebSocketResponse | None = None async def update_options(self, language: str | None = None): self._opts.language = language or self._opts.language - await self._connect_ws() + self._reconnect_event.set() async def _run(self) -> None: self._closing_ws = False @@ -325,13 +325,27 @@ async def _run(self) -> None: await self._connect_ws() tasks = [ - asyncio.create_task(self._send_task()), - asyncio.create_task(self._recv_task()), - asyncio.create_task(self._keepalive_task()), + asyncio.create_task(self._send_task(), name="send_task"), + asyncio.create_task(self._recv_task(), name="recv_task"), + asyncio.create_task(self._keepalive_task(), name="keepalive_task"), ] try: - await asyncio.gather(*tasks) + while True: + done, pending = await asyncio.wait( + [ + asyncio.gather(*tasks), + asyncio.create_task( + self._reconnect_event.wait(), name="reconnect" + ), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + if self._reconnect_event.is_set(): + self._reconnect_event.clear() + await self._connect_ws() + else: + break finally: await utils.aio.gracefully_cancel(*tasks) From 9046b2c5ba97f52e09d2ec31d2e959405b00d13d Mon Sep 17 00:00:00 2001 From: jayesh Date: Thu, 28 Nov 2024 17:59:25 +0530 Subject: [PATCH 23/31] wip --- .../livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 54f7a1f42..3c794e086 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -332,7 +332,7 @@ async def _run(self) -> None: try: while True: - done, pending = await asyncio.wait( + await asyncio.wait( [ asyncio.gather(*tasks), asyncio.create_task( From 
2445f8596b955a0bc83c09742a71bf16bfea7142 Mon Sep 17 00:00:00 2001 From: jayesh Date: Fri, 29 Nov 2024 12:56:47 +0530 Subject: [PATCH 24/31] wip --- .../livekit/plugins/deepgram/stt.py | 267 +++++++++--------- 1 file changed, 136 insertions(+), 131 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 3c794e086..7183722f5 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -257,9 +257,9 @@ def stream( ) return self._active_speech_stream - async def update_options(self, language: str | None = None): + def update_options(self, language: str | None = None): if self._active_speech_stream is not None and language is not None: - await self._active_speech_stream.update_options(language) + self._active_speech_stream.update_options(language) def _sanitize_options(self, *, language: str | None = None) -> STTOptions: config = dataclasses.replace(self._opts) @@ -315,155 +315,160 @@ def __init__( self._reconnect_event = asyncio.Event() self._ws: aiohttp.ClientWebSocketResponse | None = None - async def update_options(self, language: str | None = None): + def update_options(self, language: str | None = None): self._opts.language = language or self._opts.language self._reconnect_event.set() async def _run(self) -> None: - self._closing_ws = False - try: - await self._connect_ws() - - tasks = [ - asyncio.create_task(self._send_task(), name="send_task"), - asyncio.create_task(self._recv_task(), name="recv_task"), - asyncio.create_task(self._keepalive_task(), name="keepalive_task"), - ] + closing_ws = False + async def keepalive_task(ws: aiohttp.ClientWebSocketResponse): + # if we want to keep the connection alive even if no audio is sent, + # Deepgram expects a keepalive message. 
+ # https://developers.deepgram.com/reference/listen-live#stream-keepalive try: while True: + await ws.send_str(SpeechStream._KEEPALIVE_MSG) + await asyncio.sleep(5) + except Exception: + return + + async def send_task(ws: aiohttp.ClientWebSocketResponse): + nonlocal closing_ws + + # forward audio to deepgram in chunks of 50ms + samples_50ms = self._opts.sample_rate // 20 + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=self._opts.num_channels, + samples_per_channel=samples_50ms, + ) + + has_ended = False + last_frame: Optional[rtc.AudioFrame] = None + async for data in self._input_ch: + frames: list[rtc.AudioFrame] = [] + if isinstance(data, rtc.AudioFrame): + state = self._check_energy_state(data) + if state in ( + AudioEnergyFilter.State.START, + AudioEnergyFilter.State.SPEAKING, + ): + if last_frame: + frames.extend( + audio_bstream.write(last_frame.data.tobytes()) + ) + last_frame = None + frames.extend(audio_bstream.write(data.data.tobytes())) + elif state == AudioEnergyFilter.State.END: + # no need to buffer as we have cooldown period + frames = audio_bstream.flush() + has_ended = True + elif state == AudioEnergyFilter.State.SILENCE: + # buffer the last silence frame, since it could contain beginning of speech + # TODO: improve accuracy by using a ring buffer with longer window + last_frame = data + elif isinstance(data, self._FlushSentinel): + frames = audio_bstream.flush() + has_ended = True + + for frame in frames: + self._audio_duration_collector.push(frame.duration) + await ws.send_bytes(frame.data.tobytes()) + + if has_ended: + self._audio_duration_collector.flush() + await ws.send_str(SpeechStream._FINALIZE_MSG) + has_ended = False + + # tell deepgram we are done sending audio/inputs + closing_ws = True + await ws.send_str(SpeechStream._CLOSE_MSG) + + async def recv_task(ws: aiohttp.ClientWebSocketResponse): + nonlocal closing_ws + while True: + msg = await ws.receive() + if msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + if closing_ws: # close is expected, see SpeechStream.aclose + return + + # this will trigger a reconnection, see the _run loop + raise Exception("deepgram connection closed unexpectedly") + + if msg.type != aiohttp.WSMsgType.TEXT: + logger.warning("unexpected deepgram message type %s", msg.type) + continue + + try: + self._process_stream_event(json.loads(msg.data)) + except Exception: + logger.exception("failed to process deepgram message") + + ws: aiohttp.ClientWebSocketResponse | None = None + + try: + while True: + live_config = { + "model": self._opts.model, + "punctuate": self._opts.punctuate, + "smart_format": self._opts.smart_format, + "no_delay": self._opts.no_delay, + "interim_results": self._opts.interim_results, + "encoding": "linear16", + "vad_events": True, + "sample_rate": self._opts.sample_rate, + "channels": self._opts.num_channels, + "endpointing": False + if self._opts.endpointing_ms == 0 + else self._opts.endpointing_ms, + "filler_words": self._opts.filler_words, + "keywords": self._opts.keywords, + "profanity_filter": self._opts.profanity_filter, + } + + if self._opts.language: + live_config["language"] = self._opts.language + + ws = await asyncio.wait_for( + self._session.ws_connect( + _to_deepgram_url( + live_config, base_url=self._base_url, websocket=True + ), + headers={"Authorization": f"Token {self._api_key}"}, + ), + self._conn_options.timeout, + ) + + tasks = [ + asyncio.create_task(send_task(ws)), + 
asyncio.create_task(recv_task(ws)), + asyncio.create_task(keepalive_task(ws)), + ] + try: await asyncio.wait( [ asyncio.gather(*tasks), asyncio.create_task( - self._reconnect_event.wait(), name="reconnect" + self._reconnect_event.wait(), name="wait-for-reconnect" ), ], return_when=asyncio.FIRST_COMPLETED, ) if self._reconnect_event.is_set(): self._reconnect_event.clear() - await self._connect_ws() + continue else: break - finally: - await utils.aio.gracefully_cancel(*tasks) - + finally: + await utils.aio.gracefully_cancel(*tasks) finally: - if self._ws is not None: - await self._ws.close() - - async def _keepalive_task(self): - # if we want to keep the connection alive even if no audio is sent, - # Deepgram expects a keepalive message. - # https://developers.deepgram.com/reference/listen-live#stream-keepalive - try: - while True: - await self._ws.send_str(SpeechStream._KEEPALIVE_MSG) - await asyncio.sleep(5) - except Exception: - return - - async def _send_task(self): - # forward audio to deepgram in chunks of 50ms - samples_50ms = self._opts.sample_rate // 20 - audio_bstream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, - num_channels=self._opts.num_channels, - samples_per_channel=samples_50ms, - ) - - has_ended = False - last_frame: Optional[rtc.AudioFrame] = None - async for data in self._input_ch: - frames: list[rtc.AudioFrame] = [] - if isinstance(data, rtc.AudioFrame): - state = self._check_energy_state(data) - if state in ( - AudioEnergyFilter.State.START, - AudioEnergyFilter.State.SPEAKING, - ): - if last_frame: - frames.extend(audio_bstream.write(last_frame.data.tobytes())) - last_frame = None - frames.extend(audio_bstream.write(data.data.tobytes())) - elif state == AudioEnergyFilter.State.END: - # no need to buffer as we have cooldown period - frames = audio_bstream.flush() - has_ended = True - elif state == AudioEnergyFilter.State.SILENCE: - # buffer the last silence frame, since it could contain beginning of speech - # TODO: improve accuracy by using a ring buffer with longer window - last_frame = data - elif isinstance(data, self._FlushSentinel): - frames = audio_bstream.flush() - has_ended = True - - for frame in frames: - self._audio_duration_collector.push(frame.duration) - await self._ws.send_bytes(frame.data.tobytes()) - - if has_ended: - self._audio_duration_collector.flush() - await self._ws.send_str(SpeechStream._FINALIZE_MSG) - has_ended = False - - # tell deepgram we are done sending audio/inputs - self._closing_ws = True - await self._ws.send_str(SpeechStream._CLOSE_MSG) - - async def _recv_task(self): - while True: - msg = await self._ws.receive() - if msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - ): - if self._closing_ws: # close is expected, see SpeechStream.aclose - return - - # this will trigger a reconnection, see the _run loop - raise Exception("deepgram connection closed unexpectedly") - - if msg.type != aiohttp.WSMsgType.TEXT: - logger.warning("unexpected deepgram message type %s", msg.type) - continue - - try: - self._process_stream_event(json.loads(msg.data)) - except Exception: - logger.exception("failed to process deepgram message") - - async def _connect_ws(self): - live_config = { - "model": self._opts.model, - "punctuate": self._opts.punctuate, - "smart_format": self._opts.smart_format, - "no_delay": self._opts.no_delay, - "interim_results": self._opts.interim_results, - "encoding": "linear16", - "vad_events": True, - "sample_rate": self._opts.sample_rate, - "channels": 
self._opts.num_channels, - "endpointing": False - if self._opts.endpointing_ms == 0 - else self._opts.endpointing_ms, - "filler_words": self._opts.filler_words, - "keywords": self._opts.keywords, - "profanity_filter": self._opts.profanity_filter, - } - - if self._opts.language: - live_config["language"] = self._opts.language - - self._ws = await asyncio.wait_for( - self._session.ws_connect( - _to_deepgram_url(live_config, base_url=self._base_url, websocket=True), - headers={"Authorization": f"Token {self._api_key}"}, - ), - self._conn_options.timeout, - ) + if ws is not None: + await ws.close() def _check_energy_state(self, frame: rtc.AudioFrame) -> AudioEnergyFilter.State: if self._audio_energy_filter: From 77acebd821423bbfd7e98596a5b3b291cb863ea3 Mon Sep 17 00:00:00 2001 From: jayesh Date: Fri, 29 Nov 2024 14:00:23 +0530 Subject: [PATCH 25/31] wip --- .../livekit/plugins/deepgram/stt.py | 83 ++++++++++--------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 7183722f5..cf906dc74 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -412,44 +412,14 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse): ws: aiohttp.ClientWebSocketResponse | None = None try: - while True: - live_config = { - "model": self._opts.model, - "punctuate": self._opts.punctuate, - "smart_format": self._opts.smart_format, - "no_delay": self._opts.no_delay, - "interim_results": self._opts.interim_results, - "encoding": "linear16", - "vad_events": True, - "sample_rate": self._opts.sample_rate, - "channels": self._opts.num_channels, - "endpointing": False - if self._opts.endpointing_ms == 0 - else self._opts.endpointing_ms, - "filler_words": self._opts.filler_words, - "keywords": self._opts.keywords, - "profanity_filter": self._opts.profanity_filter, - } - - if self._opts.language: - live_config["language"] = self._opts.language - - ws = await asyncio.wait_for( - self._session.ws_connect( - _to_deepgram_url( - live_config, base_url=self._base_url, websocket=True - ), - headers={"Authorization": f"Token {self._api_key}"}, - ), - self._conn_options.timeout, - ) - - tasks = [ - asyncio.create_task(send_task(ws)), - asyncio.create_task(recv_task(ws)), - asyncio.create_task(keepalive_task(ws)), - ] - try: + ws = await self._connect_ws() + tasks = [ + asyncio.create_task(send_task(ws)), + asyncio.create_task(recv_task(ws)), + asyncio.create_task(keepalive_task(ws)), + ] + try: + while True: await asyncio.wait( [ asyncio.gather(*tasks), @@ -461,15 +431,46 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse): ) if self._reconnect_event.is_set(): self._reconnect_event.clear() - continue + ws = await self._connect_ws() else: break - finally: - await utils.aio.gracefully_cancel(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) finally: if ws is not None: await ws.close() + async def _connect_ws(self): + live_config = { + "model": self._opts.model, + "punctuate": self._opts.punctuate, + "smart_format": self._opts.smart_format, + "no_delay": self._opts.no_delay, + "interim_results": self._opts.interim_results, + "encoding": "linear16", + "vad_events": True, + "sample_rate": self._opts.sample_rate, + "channels": self._opts.num_channels, + "endpointing": False + if self._opts.endpointing_ms == 0 + else self._opts.endpointing_ms, + 
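# note: endpointing_ms == 0 is mapped to False so endpointing is disabled outright rather than configured as a 0 ms silence threshold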
"filler_words": self._opts.filler_words, + "keywords": self._opts.keywords, + "profanity_filter": self._opts.profanity_filter, + } + + if self._opts.language: + live_config["language"] = self._opts.language + + ws = await asyncio.wait_for( + self._session.ws_connect( + _to_deepgram_url(live_config, base_url=self._base_url, websocket=True), + headers={"Authorization": f"Token {self._api_key}"}, + ), + self._conn_options.timeout, + ) + return ws + def _check_energy_state(self, frame: rtc.AudioFrame) -> AudioEnergyFilter.State: if self._audio_energy_filter: return self._audio_energy_filter.update(frame) From 1295db9316ffdfd8ab22b9800e7276cde017ce8f Mon Sep 17 00:00:00 2001 From: jayesh Date: Fri, 29 Nov 2024 14:19:46 +0530 Subject: [PATCH 26/31] wip --- .../livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index cf906dc74..5abbd9791 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -440,7 +440,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse): if ws is not None: await ws.close() - async def _connect_ws(self): + async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: live_config = { "model": self._opts.model, "punctuate": self._opts.punctuate, From 460a56b77130942a8c17adae19597ac22bfe2ec9 Mon Sep 17 00:00:00 2001 From: jayesh Date: Fri, 29 Nov 2024 18:46:45 +0530 Subject: [PATCH 27/31] deepgram update options --- .../livekit/plugins/deepgram/stt.py | 187 ++++++++++-------- 1 file changed, 102 insertions(+), 85 deletions(-) diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index 5abbd9791..beb53cde4 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -330,112 +330,129 @@ async def keepalive_task(ws: aiohttp.ClientWebSocketResponse): while True: await ws.send_str(SpeechStream._KEEPALIVE_MSG) await asyncio.sleep(5) + except asyncio.CancelledError: + pass except Exception: return async def send_task(ws: aiohttp.ClientWebSocketResponse): - nonlocal closing_ws - - # forward audio to deepgram in chunks of 50ms - samples_50ms = self._opts.sample_rate // 20 - audio_bstream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, - num_channels=self._opts.num_channels, - samples_per_channel=samples_50ms, - ) + try: + nonlocal closing_ws + + # forward audio to deepgram in chunks of 50ms + samples_50ms = self._opts.sample_rate // 20 + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=self._opts.num_channels, + samples_per_channel=samples_50ms, + ) - has_ended = False - last_frame: Optional[rtc.AudioFrame] = None - async for data in self._input_ch: - frames: list[rtc.AudioFrame] = [] - if isinstance(data, rtc.AudioFrame): - state = self._check_energy_state(data) - if state in ( - AudioEnergyFilter.State.START, - AudioEnergyFilter.State.SPEAKING, - ): - if last_frame: - frames.extend( - audio_bstream.write(last_frame.data.tobytes()) - ) - last_frame = None - frames.extend(audio_bstream.write(data.data.tobytes())) - elif state == AudioEnergyFilter.State.END: - # no need to 
buffer as we have cooldown period + has_ended = False + last_frame: Optional[rtc.AudioFrame] = None + async for data in self._input_ch: + frames: list[rtc.AudioFrame] = [] + if isinstance(data, rtc.AudioFrame): + state = self._check_energy_state(data) + if state in ( + AudioEnergyFilter.State.START, + AudioEnergyFilter.State.SPEAKING, + ): + if last_frame: + frames.extend( + audio_bstream.write(last_frame.data.tobytes()) + ) + last_frame = None + frames.extend(audio_bstream.write(data.data.tobytes())) + elif state == AudioEnergyFilter.State.END: + # no need to buffer as we have cooldown period + frames = audio_bstream.flush() + has_ended = True + elif state == AudioEnergyFilter.State.SILENCE: + # buffer the last silence frame, since it could contain beginning of speech + # TODO: improve accuracy by using a ring buffer with longer window + last_frame = data + elif isinstance(data, self._FlushSentinel): frames = audio_bstream.flush() has_ended = True - elif state == AudioEnergyFilter.State.SILENCE: - # buffer the last silence frame, since it could contain beginning of speech - # TODO: improve accuracy by using a ring buffer with longer window - last_frame = data - elif isinstance(data, self._FlushSentinel): - frames = audio_bstream.flush() - has_ended = True - - for frame in frames: - self._audio_duration_collector.push(frame.duration) - await ws.send_bytes(frame.data.tobytes()) - - if has_ended: - self._audio_duration_collector.flush() - await ws.send_str(SpeechStream._FINALIZE_MSG) - has_ended = False - - # tell deepgram we are done sending audio/inputs - closing_ws = True - await ws.send_str(SpeechStream._CLOSE_MSG) + + for frame in frames: + self._audio_duration_collector.push(frame.duration) + await ws.send_bytes(frame.data.tobytes()) + + if has_ended: + self._audio_duration_collector.flush() + await ws.send_str(SpeechStream._FINALIZE_MSG) + has_ended = False + + # tell deepgram we are done sending audio/inputs + closing_ws = True + await ws.send_str(SpeechStream._CLOSE_MSG) + except asyncio.CancelledError: + pass + except Exception as e: + logger.error("unexpected error in send_task %s", e) + return async def recv_task(ws: aiohttp.ClientWebSocketResponse): - nonlocal closing_ws - while True: - msg = await ws.receive() - if msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - ): - if closing_ws: # close is expected, see SpeechStream.aclose - return - - # this will trigger a reconnection, see the _run loop - raise Exception("deepgram connection closed unexpectedly") - - if msg.type != aiohttp.WSMsgType.TEXT: - logger.warning("unexpected deepgram message type %s", msg.type) - continue + try: + nonlocal closing_ws + while True: + msg = await ws.receive() + if msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + if closing_ws: # close is expected, see SpeechStream.aclose + return + + # this will trigger a reconnection, see the _run loop + raise Exception("deepgram connection closed unexpectedly") + + if msg.type != aiohttp.WSMsgType.TEXT: + logger.warning("unexpected deepgram message type %s", msg.type) + continue + + try: + self._process_stream_event(json.loads(msg.data)) + except Exception: + logger.exception("failed to process deepgram message") + except asyncio.CancelledError: + pass + except Exception as e: + logger.error("unexpected error in recv_task %s", e) + return - try: - self._process_stream_event(json.loads(msg.data)) - except Exception: - logger.exception("failed to process 
deepgram message") + async def _wait_for_reconnect(): + await self._reconnect_event.wait() ws: aiohttp.ClientWebSocketResponse | None = None try: - ws = await self._connect_ws() - tasks = [ - asyncio.create_task(send_task(ws)), - asyncio.create_task(recv_task(ws)), - asyncio.create_task(keepalive_task(ws)), - ] - try: - while True: + while True: + ws = await self._connect_ws() + tasks = [ + asyncio.create_task(send_task(ws)), + asyncio.create_task(recv_task(ws)), + asyncio.create_task(keepalive_task(ws)), + ] + reconnect_task = asyncio.create_task(_wait_for_reconnect()) + try: await asyncio.wait( - [ - asyncio.gather(*tasks), - asyncio.create_task( - self._reconnect_event.wait(), name="wait-for-reconnect" - ), - ], + [asyncio.gather(*tasks), reconnect_task], return_when=asyncio.FIRST_COMPLETED, ) if self._reconnect_event.is_set(): self._reconnect_event.clear() - ws = await self._connect_ws() + await utils.aio.gracefully_cancel(*tasks) + await ws.close() + continue else: + await utils.aio.gracefully_cancel(*tasks) + await utils.aio.gracefully_cancel(reconnect_task) break - finally: - await utils.aio.gracefully_cancel(*tasks) + finally: + await utils.aio.gracefully_cancel(*tasks) finally: if ws is not None: await ws.close() From cd54a7799ed554de739f9d7ade928bad9b64beae Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 30 Nov 2024 00:28:37 +0530 Subject: [PATCH 28/31] assembly ai stt update options --- .../livekit/plugins/assemblyai/stt.py | 214 +++++++++++------- .../livekit/plugins/deepgram/stt.py | 1 - 2 files changed, 131 insertions(+), 84 deletions(-) diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py index 6e94efade..d946a3682 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py +++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py @@ -93,6 +93,7 @@ def __init__( end_utterance_silence_threshold=end_utterance_silence_threshold, ) self._session = http_session + self._active_speech_stream: Optional[SpeechStream] = None @property def session(self) -> aiohttp.ClientSession: @@ -116,13 +117,18 @@ def stream( conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> "SpeechStream": config = dataclasses.replace(self._opts) - return SpeechStream( + self._active_speech_stream = SpeechStream( stt=self, conn_options=conn_options, opts=config, api_key=self._api_key, http_session=self.session, ) + return self._active_speech_stream + + def update_options(self, language: str | None = None): + if self._active_speech_stream is not None and language is not None: + self._active_speech_stream.update_options(language) class SpeechStream(stt.SpeechStream): @@ -149,12 +155,135 @@ def __init__( # keep a list of final transcripts to combine them inside the END_OF_SPEECH event self._final_events: List[stt.SpeechEvent] = [] + self._reconnect_event = asyncio.Event() + + def update_options(self, language: str | None = None): + self._opts.language = language or self._opts.language + self._reconnect_event.set() async def _run(self) -> None: """ Run a single websocket connection to AssemblyAI and make sure to reconnect when something went wrong. 
""" + + closing_ws = False + + async def send_task(ws: aiohttp.ClientWebSocketResponse): + try: + nonlocal closing_ws + + if self._opts.end_utterance_silence_threshold: + await ws.send_str( + json.dumps( + { + "end_utterance_silence_threshold": self._opts.end_utterance_silence_threshold + } + ) + ) + + samples_per_buffer = self._opts.sample_rate // round( + 1 / self._opts.buffer_size_seconds + ) + audio_bstream = utils.audio.AudioByteStream( + sample_rate=self._opts.sample_rate, + num_channels=1, + samples_per_channel=samples_per_buffer, + ) + + # forward inputs to AssemblyAI + # if we receive a close message, signal it to AssemblyAI and break. + # the recv task will then make sure to process the remaining audio and stop + async for data in self._input_ch: + if isinstance(data, self._FlushSentinel): + frames = audio_bstream.flush() + else: + frames = audio_bstream.write(data.data.tobytes()) + + for frame in frames: + self._speech_duration += frame.duration + await ws.send_bytes(frame.data.tobytes()) + + closing_ws = True + await ws.send_str(SpeechStream._CLOSE_MSG) + except asyncio.CancelledError: + pass + except Exception: + logger.error("failed to send audio to AssemblyAI") + return + + async def recv_task(ws: aiohttp.ClientWebSocketResponse): + try: + nonlocal closing_ws + while True: + try: + msg = await asyncio.wait_for(ws.receive(), timeout=5) + except asyncio.TimeoutError: + if closing_ws: + break + continue + + if msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + if closing_ws: # close is expected, see SpeechStream.aclose + return + + raise Exception( + "AssemblyAI connection closed unexpectedly", + ) # this will trigger a reconnection, see the _run loop + + if msg.type != aiohttp.WSMsgType.TEXT: + logger.error("unexpected AssemblyAI message type %s", msg.type) + continue + + try: + # received a message from AssemblyAI + data = json.loads(msg.data) + self._process_stream_event(data, closing_ws) + except Exception: + logger.exception("failed to process AssemblyAI message") + except asyncio.CancelledError: + pass + except Exception: + logger.error("failed to receive messages from AssemblyAI") + return + + async def _wait_for_reconnect(): + await self._reconnect_event.wait() + + try: + while True: + ws = await self._connect_ws() + tasks = [ + asyncio.create_task(send_task(ws)), + asyncio.create_task(recv_task(ws)), + ] + reconnect_task = asyncio.create_task(_wait_for_reconnect()) + + try: + await asyncio.wait( + [asyncio.gather(*tasks), reconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if self._reconnect_event.is_set(): + self._reconnect_event.clear() + await utils.aio.gracefully_cancel(*tasks) + await ws.close() + continue + else: + await utils.aio.gracefully_cancel(*tasks) + await utils.aio.gracefully_cancel(reconnect_task) + break + finally: + await utils.aio.gracefully_cancel(*tasks) + finally: + if ws is not None: + await ws.close() + + async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: live_config = { "sample_rate": self._opts.sample_rate, "word_boost": self._opts.word_boost, @@ -172,88 +301,7 @@ async def _run(self) -> None: filtered_config = {k: v for k, v in live_config.items() if v is not None} url = f"{ws_url}?{urlencode(filtered_config).lower()}" ws = await self._session.ws_connect(url, headers=headers) - - closing_ws = False - - async def send_task(): - nonlocal closing_ws - - if self._opts.end_utterance_silence_threshold: - await ws.send_str( - json.dumps( - { - 
"end_utterance_silence_threshold": self._opts.end_utterance_silence_threshold - } - ) - ) - - samples_per_buffer = self._opts.sample_rate // round( - 1 / self._opts.buffer_size_seconds - ) - audio_bstream = utils.audio.AudioByteStream( - sample_rate=self._opts.sample_rate, - num_channels=1, - samples_per_channel=samples_per_buffer, - ) - - # forward inputs to AssemblyAI - # if we receive a close message, signal it to AssemblyAI and break. - # the recv task will then make sure to process the remaining audio and stop - async for data in self._input_ch: - if isinstance(data, self._FlushSentinel): - frames = audio_bstream.flush() - else: - frames = audio_bstream.write(data.data.tobytes()) - - for frame in frames: - self._speech_duration += frame.duration - await ws.send_bytes(frame.data.tobytes()) - - closing_ws = True - await ws.send_str(SpeechStream._CLOSE_MSG) - - async def recv_task(): - nonlocal closing_ws - while True: - try: - msg = await asyncio.wait_for(ws.receive(), timeout=5) - except asyncio.TimeoutError: - if closing_ws: - break - continue - - if msg.type in ( - aiohttp.WSMsgType.CLOSED, - aiohttp.WSMsgType.CLOSE, - aiohttp.WSMsgType.CLOSING, - ): - if closing_ws: # close is expected, see SpeechStream.aclose - return - - raise Exception( - "AssemblyAI connection closed unexpectedly", - ) # this will trigger a reconnection, see the _run loop - - if msg.type != aiohttp.WSMsgType.TEXT: - logger.error("unexpected AssemblyAI message type %s", msg.type) - continue - - try: - # received a message from AssemblyAI - data = json.loads(msg.data) - self._process_stream_event(data, closing_ws) - except Exception: - logger.exception("failed to process AssemblyAI message") - - tasks = [ - asyncio.create_task(send_task()), - asyncio.create_task(recv_task()), - ] - - try: - await asyncio.gather(*tasks) - finally: - await utils.aio.gracefully_cancel(*tasks) + return ws def _process_stream_event(self, data: dict, closing_ws: bool) -> None: # see this page: diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index beb53cde4..f5c891a29 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -313,7 +313,6 @@ def __init__( self._pushed_audio_duration = 0.0 self._request_id = "" self._reconnect_event = asyncio.Event() - self._ws: aiohttp.ClientWebSocketResponse | None = None def update_options(self, language: str | None = None): self._opts.language = language or self._opts.language From d857f5b56072da854993c415d0ed6379197b36a3 Mon Sep 17 00:00:00 2001 From: jayesh Date: Sat, 30 Nov 2024 11:27:00 +0530 Subject: [PATCH 29/31] updates --- .../livekit/plugins/assemblyai/stt.py | 20 ++++++++----------- .../livekit/plugins/deepgram/stt.py | 20 ++++++++----------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py index d946a3682..f7efe3754 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py +++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py @@ -254,8 +254,8 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse): async def _wait_for_reconnect(): await self._reconnect_event.wait() - try: - while True: + while True: + try: ws = await self._connect_ws() 
tasks = [ asyncio.create_task(send_task(ws)), @@ -264,24 +264,20 @@ async def _wait_for_reconnect(): reconnect_task = asyncio.create_task(_wait_for_reconnect()) try: - await asyncio.wait( + done, _ = await asyncio.wait( [asyncio.gather(*tasks), reconnect_task], return_when=asyncio.FIRST_COMPLETED, ) - if self._reconnect_event.is_set(): + if reconnect_task in done and self._reconnect_event.is_set(): self._reconnect_event.clear() - await utils.aio.gracefully_cancel(*tasks) - await ws.close() continue else: - await utils.aio.gracefully_cancel(*tasks) - await utils.aio.gracefully_cancel(reconnect_task) break finally: - await utils.aio.gracefully_cancel(*tasks) - finally: - if ws is not None: - await ws.close() + await utils.aio.gracefully_cancel(*tasks, reconnect_task) + finally: + if ws is not None: + await ws.close() async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: live_config = { diff --git a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py index f5c891a29..458e6018d 100644 --- a/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py +++ b/livekit-plugins/livekit-plugins-deepgram/livekit/plugins/deepgram/stt.py @@ -427,8 +427,8 @@ async def _wait_for_reconnect(): ws: aiohttp.ClientWebSocketResponse | None = None - try: - while True: + while True: + try: ws = await self._connect_ws() tasks = [ asyncio.create_task(send_task(ws)), @@ -437,24 +437,20 @@ async def _wait_for_reconnect(): ] reconnect_task = asyncio.create_task(_wait_for_reconnect()) try: - await asyncio.wait( + done, _ = await asyncio.wait( [asyncio.gather(*tasks), reconnect_task], return_when=asyncio.FIRST_COMPLETED, ) - if self._reconnect_event.is_set(): + if reconnect_task in done and self._reconnect_event.is_set(): self._reconnect_event.clear() - await utils.aio.gracefully_cancel(*tasks) - await ws.close() continue else: - await utils.aio.gracefully_cancel(*tasks) - await utils.aio.gracefully_cancel(reconnect_task) break finally: - await utils.aio.gracefully_cancel(*tasks) - finally: - if ws is not None: - await ws.close() + await utils.aio.gracefully_cancel(*tasks, reconnect_task) + finally: + if ws is not None: + await ws.close() async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: live_config = { From 3fa6004a8a3c4b7a1d02ad5fede52bc2b8503781 Mon Sep 17 00:00:00 2001 From: jayesh Date: Mon, 2 Dec 2024 14:57:33 +0530 Subject: [PATCH 30/31] azure stt update options --- .../livekit/plugins/azure/stt.py | 91 ++++++++++++------- 1 file changed, 56 insertions(+), 35 deletions(-) diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py index 864fe50ec..b2c4d38ba 100644 --- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py +++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py @@ -16,6 +16,7 @@ import contextlib import os from dataclasses import dataclass +from typing import Optional from livekit import rtc from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils @@ -94,6 +95,7 @@ def __init__( segmentation_max_time_ms=segmentation_max_time_ms, segmentation_strategy=segmentation_strategy, ) + self._active_speech_stream: Optional[SpeechStream] = None async def _recognize_impl( self, @@ -110,7 +112,14 @@ def stream( language: str | None = None, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS, ) -> "SpeechStream": - return 
SpeechStream(stt=self, opts=self._config, conn_options=conn_options) + self._active_speech_stream = SpeechStream( + stt=self, opts=self._config, conn_options=conn_options + ) + return self._active_speech_stream + + def update_options(self, language: str | None = None): + if self._active_speech_stream is not None and language is not None: + self._active_speech_stream.update_options(language) class SpeechStream(stt.SpeechStream): @@ -127,44 +136,56 @@ def __init__( self._session_started_event = asyncio.Event() self._loop = asyncio.get_running_loop() + self._reconnect_event = asyncio.Event() + + def update_options(self, language: str | None = None): + self._opts.languages = [language] + self._reconnect_event.set() async def _run(self) -> None: - self._stream = speechsdk.audio.PushAudioInputStream( - stream_format=speechsdk.audio.AudioStreamFormat( - samples_per_second=self._opts.sample_rate, - bits_per_sample=16, - channels=self._opts.num_channels, + while True: + self._stream = speechsdk.audio.PushAudioInputStream( + stream_format=speechsdk.audio.AudioStreamFormat( + samples_per_second=self._opts.sample_rate, + bits_per_sample=16, + channels=self._opts.num_channels, + ) ) - ) - self._recognizer = _create_speech_recognizer( - config=self._opts, stream=self._stream - ) - self._recognizer.recognizing.connect(self._on_recognizing) - self._recognizer.recognized.connect(self._on_recognized) - self._recognizer.speech_start_detected.connect(self._on_speech_start) - self._recognizer.speech_end_detected.connect(self._on_speech_end) - self._recognizer.session_started.connect(self._on_session_started) - self._recognizer.session_stopped.connect(self._on_session_stopped) - self._recognizer.start_continuous_recognition() - - try: - await asyncio.wait_for( - self._session_started_event.wait(), self._conn_options.timeout + self._recognizer = _create_speech_recognizer( + config=self._opts, stream=self._stream ) - - async for input in self._input_ch: - if isinstance(input, rtc.AudioFrame): - self._stream.write(input.data.tobytes()) - - self._stream.close() - await self._session_stopped_event.wait() - finally: - - def _cleanup(): - self._recognizer.stop_continuous_recognition() - del self._recognizer - - await asyncio.to_thread(_cleanup) + self._recognizer.recognizing.connect(self._on_recognizing) + self._recognizer.recognized.connect(self._on_recognized) + self._recognizer.speech_start_detected.connect(self._on_speech_start) + self._recognizer.speech_end_detected.connect(self._on_speech_end) + self._recognizer.session_started.connect(self._on_session_started) + self._recognizer.session_stopped.connect(self._on_session_stopped) + self._recognizer.start_continuous_recognition() + + try: + await asyncio.wait_for( + self._session_started_event.wait(), self._conn_options.timeout + ) + + async for input in self._input_ch: + if self._reconnect_event.is_set(): + break + if isinstance(input, rtc.AudioFrame): + self._stream.write(input.data.tobytes()) + + self._stream.close() + await self._session_stopped_event.wait() + finally: + + def _cleanup(): + self._recognizer.stop_continuous_recognition() + del self._recognizer + + await asyncio.to_thread(_cleanup) + if self._reconnect_event.is_set(): + self._reconnect_event.clear() + else: + break def _on_recognized(self, evt: speechsdk.SpeechRecognitionEventArgs): detected_lg = speechsdk.AutoDetectSourceLanguageResult(evt.result).language From e5a58a73d2fa639717c79913f1f4d729bed3cead Mon Sep 17 00:00:00 2001 From: jayesh Date: Mon, 2 Dec 2024 15:03:16 +0530 Subject: [PATCH 
31/31] azure stt update options

---
 .../livekit-plugins-azure/livekit/plugins/azure/stt.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py
index b2c4d38ba..9a7a729a0 100644
--- a/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py
+++ b/livekit-plugins/livekit-plugins-azure/livekit/plugins/azure/stt.py
@@ -139,8 +139,9 @@ def __init__(
         self._reconnect_event = asyncio.Event()
 
     def update_options(self, language: str | None = None):
-        self._opts.languages = [language]
-        self._reconnect_event.set()
+        if language:
+            self._opts.languages = [language]
+            self._reconnect_event.set()
 
     async def _run(self) -> None:
         while True:
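Taken together, the later patches in this series give the streaming plugins they touch (Deepgram, AssemblyAI, Azure) the same non-blocking surface: a synchronous update_options() on the STT that forwards the new language to the active stream, which reconnects in the background. A minimal usage sketch; the constructor arguments, the agent wiring and the language code are assumptions for illustration, not taken from these patches:

    from livekit.plugins import deepgram

    stt_impl = deepgram.STT()        # defaults assumed
    stream = stt_impl.stream()       # the STT keeps a reference to the active stream
    # ... push audio frames into `stream` as usual ...

    # later, e.g. once the caller's language is detected:
    stt_impl.update_options(language="hi")
    # the stream sets its reconnect event; its _run loop cancels the send/recv
    # workers, closes the current websocket and dials again with the updated config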