Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add update_options to STTs #1131

Draft
wants to merge 33 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a581c4c
openai stt update options
jayeshp19 Nov 21, 2024
1089b9c
stt update options for fal
jayeshp19 Nov 23, 2024
0bc397f
stt update options for clova
jayeshp19 Nov 23, 2024
f22e42e
wip
jayeshp19 Nov 23, 2024
704a1d0
Merge branch 'main' of https://github.com/livekit/agents into stt-upd…
jayeshp19 Nov 25, 2024
b250d8d
Merge branch 'main' of https://github.com/livekit/agents into stt-upd…
jayeshp19 Nov 26, 2024
6333159
deepgram stt wip
jayeshp19 Nov 26, 2024
6474992
deepgram stt wip
jayeshp19 Nov 26, 2024
7ea51cb
deepgram stt wip
jayeshp19 Nov 26, 2024
a03d17f
deepgram stt wip
jayeshp19 Nov 26, 2024
72267c6
deepgram stt wip
jayeshp19 Nov 26, 2024
4deec8b
deepgram stt wip
jayeshp19 Nov 26, 2024
3ec13d3
deepgram stt wip
jayeshp19 Nov 26, 2024
fe3e1fd
deepgram stt wip
jayeshp19 Nov 27, 2024
b5c5cb8
merge main
jayeshp19 Nov 27, 2024
64df6e0
deepgram stt wip
jayeshp19 Nov 27, 2024
4862ecf
deepgram stt wip
jayeshp19 Nov 27, 2024
aacb8fc
deepgram stt wip
jayeshp19 Nov 27, 2024
6618d13
deepgram stt wip
jayeshp19 Nov 27, 2024
e2de5f5
deepgram stt wip
jayeshp19 Nov 27, 2024
12597ae
deepgram stt wip
jayeshp19 Nov 27, 2024
1f0922b
deepgram stt wip
jayeshp19 Nov 27, 2024
549e301
deepgram stt wip
jayeshp19 Nov 27, 2024
7e02271
deepgram stt wip
jayeshp19 Nov 27, 2024
bd9cf1f
wip
jayeshp19 Nov 28, 2024
9046b2c
wip
jayeshp19 Nov 28, 2024
2445f85
wip
jayeshp19 Nov 29, 2024
77acebd
wip
jayeshp19 Nov 29, 2024
1295db9
wip
jayeshp19 Nov 29, 2024
460a56b
deepgram update options
jayeshp19 Nov 29, 2024
cd54a77
assembly ai stt update options
jayeshp19 Nov 29, 2024
d857f5b
updates
jayeshp19 Nov 30, 2024
c962139
Merge branch 'main' of https://github.com/livekit/agents into stt-upd…
jayeshp19 Dec 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def __init__(
end_utterance_silence_threshold=end_utterance_silence_threshold,
)
self._session = http_session
self._active_speech_stream: Optional[SpeechStream] = None

@property
def session(self) -> aiohttp.ClientSession:
Expand All @@ -116,13 +117,18 @@ def stream(
conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
) -> "SpeechStream":
config = dataclasses.replace(self._opts)
return SpeechStream(
self._active_speech_stream = SpeechStream(
stt=self,
conn_options=conn_options,
opts=config,
api_key=self._api_key,
http_session=self.session,
)
return self._active_speech_stream

def update_options(self, language: str | None = None):
    """Change the recognition language for current and future streams.

    Args:
        language: New language code. ``None`` leaves everything untouched;
            an empty string keeps the currently configured language.
    """
    if language is None:
        return

    # Persist the choice on the STT itself: stream() copies self._opts, so
    # without this a stream created later would still use the old language.
    self._opts.language = language or self._opts.language

    if self._active_speech_stream is not None:
        self._active_speech_stream.update_options(language)


class SpeechStream(stt.SpeechStream):
Expand All @@ -149,12 +155,131 @@ def __init__(

# keep a list of final transcripts to combine them inside the END_OF_SPEECH event
self._final_events: List[stt.SpeechEvent] = []
self._reconnect_event = asyncio.Event()

def update_options(self, language: str | None = None):
    """Switch the stream to a new language and request a reconnect.

    Setting ``_reconnect_event`` makes the ``_run`` loop drop the current
    websocket and open a new one, so it is only set when the language
    actually changes.

    Args:
        language: New language code. ``None``, an empty string, or the
            language already in use is a no-op.
    """
    if not language or language == self._opts.language:
        # Nothing changed: avoid tearing down a healthy connection.
        return

    self._opts.language = language
    self._reconnect_event.set()

async def _run(self) -> None:
    """
    Run the AssemblyAI websocket session, reconnecting when options change.

    Each loop iteration opens one websocket via ``_connect_ws`` and runs a
    sender and a receiver task on it. When ``_reconnect_event`` is set (see
    ``update_options``) the tasks are cancelled, the socket is closed and a
    new connection is opened. Otherwise the loop ends once both tasks have
    finished.

    Raises:
        Exception: any failure raised while connecting, sending or receiving
            is logged and re-raised to the caller instead of being dropped.
    """

    closing_ws = False

    async def send_task(ws: aiohttp.ClientWebSocketResponse):
        """Forward buffered input audio to AssemblyAI, then send the close message."""
        nonlocal closing_ws
        try:
            if self._opts.end_utterance_silence_threshold:
                await ws.send_str(
                    json.dumps(
                        {
                            "end_utterance_silence_threshold": self._opts.end_utterance_silence_threshold
                        }
                    )
                )

            samples_per_buffer = self._opts.sample_rate // round(
                1 / self._opts.buffer_size_seconds
            )
            audio_bstream = utils.audio.AudioByteStream(
                sample_rate=self._opts.sample_rate,
                num_channels=1,
                samples_per_channel=samples_per_buffer,
            )

            # forward inputs to AssemblyAI
            # if we receive a close message, signal it to AssemblyAI and break.
            # the recv task will then make sure to process the remaining audio and stop
            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    frames = audio_bstream.flush()
                else:
                    frames = audio_bstream.write(data.data.tobytes())

                for frame in frames:
                    self._speech_duration += frame.duration
                    await ws.send_bytes(frame.data.tobytes())

            closing_ws = True
            await ws.send_str(SpeechStream._CLOSE_MSG)
        except asyncio.CancelledError:
            pass
        except Exception:
            # Keep the traceback and let the failure reach the _run loop.
            logger.exception("failed to send audio to AssemblyAI")
            raise

    async def recv_task(ws: aiohttp.ClientWebSocketResponse):
        """Read messages from AssemblyAI and hand them to ``_process_stream_event``."""
        nonlocal closing_ws
        try:
            while True:
                try:
                    msg = await asyncio.wait_for(ws.receive(), timeout=5)
                except asyncio.TimeoutError:
                    # No message within 5s: stop only if the sender is done.
                    if closing_ws:
                        break
                    continue

                if msg.type in (
                    aiohttp.WSMsgType.CLOSED,
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.CLOSING,
                ):
                    if closing_ws:  # close is expected, see SpeechStream.aclose
                        return

                    raise Exception(
                        "AssemblyAI connection closed unexpectedly",
                    )  # re-raised below so the failure leaves _run

                if msg.type != aiohttp.WSMsgType.TEXT:
                    logger.error("unexpected AssemblyAI message type %s", msg.type)
                    continue

                try:
                    # received a message from AssemblyAI
                    data = json.loads(msg.data)
                    self._process_stream_event(data, closing_ws)
                except Exception:
                    logger.exception("failed to process AssemblyAI message")
        except asyncio.CancelledError:
            pass
        except Exception:
            # Keep the traceback and let the failure reach the _run loop.
            logger.exception("failed to receive messages from AssemblyAI")
            raise

    while True:
        # Bind before the try block so the cleanup below is safe even when
        # _connect_ws() itself raises.
        ws = None
        try:
            ws = await self._connect_ws()
            tasks = [
                asyncio.create_task(send_task(ws)),
                asyncio.create_task(recv_task(ws)),
            ]
            reconnect_task = asyncio.create_task(self._reconnect_event.wait())
            io_future = asyncio.gather(*tasks)

            try:
                done, _ = await asyncio.wait(
                    [io_future, reconnect_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )
                if reconnect_task in done and self._reconnect_event.is_set():
                    self._reconnect_event.clear()
                    continue

                # Re-raise a send/recv failure instead of ending silently.
                io_future.result()
                break
            finally:
                await utils.aio.gracefully_cancel(*tasks, reconnect_task)
        finally:
            if ws is not None:
                await ws.close()

async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
live_config = {
"sample_rate": self._opts.sample_rate,
"word_boost": self._opts.word_boost,
Expand All @@ -172,88 +297,7 @@ async def _run(self) -> None:
filtered_config = {k: v for k, v in live_config.items() if v is not None}
url = f"{ws_url}?{urlencode(filtered_config).lower()}"
ws = await self._session.ws_connect(url, headers=headers)

closing_ws = False

async def send_task():
nonlocal closing_ws

if self._opts.end_utterance_silence_threshold:
await ws.send_str(
json.dumps(
{
"end_utterance_silence_threshold": self._opts.end_utterance_silence_threshold
}
)
)

samples_per_buffer = self._opts.sample_rate // round(
1 / self._opts.buffer_size_seconds
)
audio_bstream = utils.audio.AudioByteStream(
sample_rate=self._opts.sample_rate,
num_channels=1,
samples_per_channel=samples_per_buffer,
)

# forward inputs to AssemblyAI
# if we receive a close message, signal it to AssemblyAI and break.
# the recv task will then make sure to process the remaining audio and stop
async for data in self._input_ch:
if isinstance(data, self._FlushSentinel):
frames = audio_bstream.flush()
else:
frames = audio_bstream.write(data.data.tobytes())

for frame in frames:
self._speech_duration += frame.duration
await ws.send_bytes(frame.data.tobytes())

closing_ws = True
await ws.send_str(SpeechStream._CLOSE_MSG)

async def recv_task():
nonlocal closing_ws
while True:
try:
msg = await asyncio.wait_for(ws.receive(), timeout=5)
except asyncio.TimeoutError:
if closing_ws:
break
continue

if msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.CLOSING,
):
if closing_ws: # close is expected, see SpeechStream.aclose
return

raise Exception(
"AssemblyAI connection closed unexpectedly",
) # this will trigger a reconnection, see the _run loop

if msg.type != aiohttp.WSMsgType.TEXT:
logger.error("unexpected AssemblyAI message type %s", msg.type)
continue

try:
# received a message from AssemblyAI
data = json.loads(msg.data)
self._process_stream_event(data, closing_ws)
except Exception:
logger.exception("failed to process AssemblyAI message")

tasks = [
asyncio.create_task(send_task()),
asyncio.create_task(recv_task()),
]

try:
await asyncio.gather(*tasks)
finally:
await utils.aio.gracefully_cancel(*tasks)
return ws

def _process_stream_event(self, data: dict, closing_ws: bool) -> None:
# see this page:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def __init__(
)
self.threshold = threshold

def update_options(self, *, language: str | None) -> None:
    """Update the recognition language.

    The value is looked up in ``clova_languages_mapping``; a value with no
    entry there is used as given. A falsy result (for example ``None``)
    keeps the current language.
    """
    resolved = clova_languages_mapping.get(language, language)
    self._language = resolved if resolved else self._language

def _ensure_session(self) -> aiohttp.ClientSession:
if not self._session:
self._session = utils.http_context.http_session()
Expand Down
Loading
Loading