-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement basic SDK for convai agents
Early prototype, subject to change based on user feedback. Takes care of the websocket session and message handling, exposing a simplified audio interface to the client that can be hooked up to the appropriate audio inputs / outputs based on the use case. Also implements a basic speaker/microphone interface via an optional dependency on pyaudio.
- Loading branch information
Showing
4 changed files
with
335 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,289 @@ | ||
from abc import ABC, abstractmethod | ||
import base64 | ||
import json | ||
import queue | ||
import threading | ||
from typing import Callable, Optional | ||
|
||
from websockets.sync.client import connect | ||
|
||
from .base_client import BaseElevenLabs | ||
|
||
|
||
class AudioInterface(ABC):
    """AudioInterface provides an abstraction for handling audio input and output."""

    @abstractmethod
    def start(self, input_callback: Callable[[bytes], None]):
        """Starts the audio interface.

        Called one time before the conversation starts.
        The `input_callback` should be called regularly with input audio chunks from
        the user. The audio should be in 16-bit PCM mono format at 16kHz. Recommended
        chunk size is 4000 samples (250 milliseconds).
        """
        pass

    @abstractmethod
    def stop(self):
        """Stops the audio interface.

        Called one time after the conversation ends. Should clean up any resources
        used by the audio interface and stop any audio streams. Do not call the
        `input_callback` from `start` after this method is called.
        """
        pass

    @abstractmethod
    def output(self, audio: bytes):
        """Output audio to the user.

        The `audio` input is in 16-bit PCM mono format at 16kHz. Implementations can
        choose to do additional buffering. This method should return quickly and not
        block the calling thread.
        """
        pass

    @abstractmethod
    def interrupt(self):
        """Interruption signal to stop any audio output.

        User has interrupted the agent and all previously buffered audio output should
        be stopped.
        """
        pass
||
class Conversation:
    """Conversational AI session over a websocket.

    Owns the websocket connection in a background thread, forwards user audio
    to the agent, and dispatches agent events (audio, transcripts, responses,
    interruptions, pings) to the audio interface and optional callbacks.
    """

    client: BaseElevenLabs
    agent_id: str
    requires_auth: bool

    audio_interface: AudioInterface
    callback_agent_response: Optional[Callable[[str], None]]
    callback_agent_response_correction: Optional[Callable[[str, str], None]]
    callback_user_transcript: Optional[Callable[[str], None]]
    callback_latency_measurement: Optional[Callable[[int], None]]

    _thread: Optional[threading.Thread]
    _should_stop: threading.Event
    _conversation_id: Optional[str]
    _last_interrupt_id: int

    def __init__(
        self,
        client: BaseElevenLabs,
        agent_id: str,
        *,
        requires_auth: bool,
        audio_interface: AudioInterface,
        callback_agent_response: Optional[Callable[[str], None]] = None,
        callback_agent_response_correction: Optional[Callable[[str, str], None]] = None,
        callback_user_transcript: Optional[Callable[[str], None]] = None,
        callback_latency_measurement: Optional[Callable[[int], None]] = None,
    ):
        """Conversational AI session.

        BETA: This API is subject to change without regard to backwards compatibility.

        Args:
            client: The ElevenLabs client to use for the conversation.
            agent_id: The ID of the agent to converse with.
            requires_auth: Whether the agent requires authentication.
            audio_interface: The audio interface to use for input and output.
            callback_agent_response: Callback for agent responses.
            callback_agent_response_correction: Callback for agent response corrections.
                First argument is the original response (previously given to
                callback_agent_response), second argument is the corrected response.
            callback_user_transcript: Callback for user transcripts.
            callback_latency_measurement: Callback for latency measurements (in milliseconds).
        """
        self.client = client
        self.agent_id = agent_id
        self.requires_auth = requires_auth

        self.audio_interface = audio_interface
        self.callback_agent_response = callback_agent_response
        self.callback_agent_response_correction = callback_agent_response_correction
        self.callback_user_transcript = callback_user_transcript
        self.callback_latency_measurement = callback_latency_measurement

        # Per-instance session state. Created here (not as class attributes)
        # so that the stop event and conversation id are never shared between
        # Conversation instances.
        self._thread = None
        self._should_stop = threading.Event()
        self._conversation_id = None
        self._last_interrupt_id = 0

    def start_session(self):
        """Starts the conversation session.

        Will run in background thread until `end_session` is called.
        """
        ws_url = self._get_signed_url() if self.requires_auth else self._get_wss_url()
        self._thread = threading.Thread(target=self._run, args=(ws_url,))
        self._thread.start()

    def end_session(self):
        """Ends the conversation session."""
        self.audio_interface.stop()
        self._should_stop.set()

    def wait_for_session_end(self) -> Optional[str]:
        """Waits for the conversation session to end.

        You must call `end_session` before calling this method, otherwise it will block.

        Returns the conversation ID, if available.
        """
        if not self._thread:
            raise RuntimeError("Session not started.")
        self._thread.join()
        return self._conversation_id

    def _run(self, ws_url: str):
        """Websocket loop: streams user audio up and handles agent messages."""
        with connect(ws_url) as ws:

            def input_callback(audio):
                ws.send(
                    json.dumps(
                        {
                            "user_audio_chunk": base64.b64encode(audio).decode(),
                        }
                    )
                )

            self.audio_interface.start(input_callback)
            while not self._should_stop.is_set():
                try:
                    # Short timeout so the stop event is polled regularly.
                    message = json.loads(ws.recv(timeout=0.5))
                    self._handle_message(message, ws)
                except TimeoutError:
                    pass

    def _handle_message(self, message, ws):
        """Dispatches a single decoded agent message. Unknown types are ignored."""
        if message["type"] == "conversation_initiation_metadata":
            event = message["conversation_initiation_metadata_event"]
            assert self._conversation_id is None
            self._conversation_id = event["conversation_id"]
        elif message["type"] == "audio":
            event = message["audio_event"]
            # Drop audio that was produced before the latest interruption.
            if int(event["event_id"]) <= self._last_interrupt_id:
                return
            audio = base64.b64decode(event["audio_base_64"])
            self.audio_interface.output(audio)
        elif message["type"] == "agent_response":
            if self.callback_agent_response:
                event = message["agent_response_event"]
                self.callback_agent_response(event["agent_response"].strip())
        elif message["type"] == "agent_response_correction":
            if self.callback_agent_response_correction:
                event = message["agent_response_correction_event"]
                self.callback_agent_response_correction(
                    event["original_agent_response"].strip(),
                    event["corrected_agent_response"].strip(),
                )
        elif message["type"] == "user_transcript":
            if self.callback_user_transcript:
                event = message["user_transcription_event"]
                self.callback_user_transcript(event["user_transcript"].strip())
        elif message["type"] == "interruption":
            event = message["interruption_event"]
            # Fix: was `self.last_interrupt_id`, which left `_last_interrupt_id`
            # at 0 so the audio suppression check above never dropped stale audio.
            self._last_interrupt_id = int(event["event_id"])
            self.audio_interface.interrupt()
        elif message["type"] == "ping":
            event = message["ping_event"]
            ws.send(
                json.dumps(
                    {
                        "type": "pong",
                        "event_id": event["event_id"],
                    }
                )
            )
            if self.callback_latency_measurement and event["ping_ms"]:
                self.callback_latency_measurement(int(event["ping_ms"]))
        else:
            pass  # Ignore all other message types.

    def _get_wss_url(self):
        """Builds the unauthenticated websocket URL from the client's base URL."""
        base_url = self.client._client_wrapper._base_url
        # Replace http(s) with ws(s).
        base_ws_url = base_url.replace("http", "ws", 1)  # First occurrence only.
        return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}"

    def _get_signed_url(self):
        """Fetches a signed websocket URL for agents that require authentication."""
        # TODO: Use generated SDK method once available.
        response = self.client._client_wrapper.httpx_client.request(
            f"v1/convai/conversation/get_signed_url?agent_id={self.agent_id}",
            method="GET",
        )
        return response.json()["signed_url"]
||
|
||
class DefaultAudioInterface(AudioInterface):
    """Default microphone/speaker implementation of AudioInterface using pyaudio.

    Input is captured via a pyaudio callback; output is buffered through a
    queue drained by a dedicated writer thread so playback can be interrupted.
    """

    INPUT_FRAMES_PER_BUFFER = 4000  # 250ms @ 16kHz
    OUTPUT_FRAMES_PER_BUFFER = 1000  # 62.5ms @ 16kHz

    def __init__(self):
        # pyaudio is an optional dependency, so import it only when this
        # interface is actually used.
        try:
            import pyaudio
        except ImportError:
            raise ImportError("To use DefaultAudioInterface you must install pyaudio.")
        self.pyaudio = pyaudio
        # Set defensively so `_in_callback` is safe even if it fires before
        # `start` has assigned the real callback.
        self.input_callback = None

    def start(self, input_callback: Callable[[bytes], None]):
        # Audio input is using callbacks from pyaudio which we simply pass through.
        self.input_callback = input_callback

        # Audio output is buffered so we can handle interruptions.
        # Start a separate thread to handle writing to the output stream.
        self.output_queue: queue.Queue[bytes] = queue.Queue()
        self.should_stop = threading.Event()
        self.output_thread = threading.Thread(target=self._output_thread)

        self.p = self.pyaudio.PyAudio()
        self.in_stream = self.p.open(
            format=self.pyaudio.paInt16,
            channels=1,
            rate=16000,
            input=True,
            stream_callback=self._in_callback,
            frames_per_buffer=self.INPUT_FRAMES_PER_BUFFER,
            start=True,
        )
        self.out_stream = self.p.open(
            format=self.pyaudio.paInt16,
            channels=1,
            rate=16000,
            output=True,
            frames_per_buffer=self.OUTPUT_FRAMES_PER_BUFFER,
            start=True,
        )

        self.output_thread.start()

    def stop(self):
        # Stop the writer thread first so nothing writes to a closing stream.
        self.should_stop.set()
        self.output_thread.join()
        self.in_stream.stop_stream()
        self.in_stream.close()
        # Fix: stop the output stream before closing it, mirroring the input
        # stream cleanup (previously it was closed while still started).
        self.out_stream.stop_stream()
        self.out_stream.close()
        self.p.terminate()

    def output(self, audio: bytes):
        self.output_queue.put(audio)

    def interrupt(self):
        # Clear the output queue to stop any audio that is currently playing.
        # Note: We can't atomically clear the whole queue, but we are doing
        # it from the message handling thread so no new audio will be added
        # while we are clearing.
        try:
            while True:
                _ = self.output_queue.get(block=False)
        except queue.Empty:
            pass

    def _output_thread(self):
        # Drain the output queue, polling the stop event every 250ms.
        while not self.should_stop.is_set():
            try:
                audio = self.output_queue.get(timeout=0.25)
                self.out_stream.write(audio)
            except queue.Empty:
                pass

    def _in_callback(self, in_data, frame_count, time_info, status):
        # pyaudio input callback: forward each captured chunk to the session.
        if self.input_callback:
            self.input_callback(in_data)
        return (None, self.pyaudio.paContinue)