Skip to content

Commit

Permalink
Implement basic SDK for convai agents
Browse files Browse the repository at this point in the history
Early prototype, subject to change based on user feedback.

Takes care of the websocket session and message handling, exposing a
simplified audio interface to the client that can be hooked up to
the appropriate audio inputs / outputs based on the use case.

Also implements a basic speaker/microphone interface, via optional
dependency on pyaudio.
  • Loading branch information
lacop11 committed Oct 25, 2024
1 parent 5a2c536 commit 1426346
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 1 deletion.
1 change: 1 addition & 0 deletions .fernignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Specify files that shouldn't be modified by Fern

src/elevenlabs/client.py
src/elevenlabs/conversation.py
src/elevenlabs/play.py
src/elevenlabs/realtime_tts.py

Expand Down
39 changes: 38 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,21 @@ requests = ">=2.20"
typing_extensions = ">= 4.0.0"
websockets = ">=11.0"

# Optional extras.
pyaudio = { version = ">=0.2.14", optional = true }

[tool.poetry.dev-dependencies]
mypy = "1.0.1"
pytest = "^7.4.0"
pytest-asyncio = "^0.23.5"
python-dateutil = "^2.9.0"
types-pyaudio = "^0.2.16.20240516"
types-python-dateutil = "^2.9.0.20240316"
ruff = "^0.5.6"

[tool.poetry.extras]
pyaudio = ["pyaudio"]

[tool.pytest.ini_options]
testpaths = [ "tests" ]
asyncio_mode = "auto"
Expand Down
289 changes: 289 additions & 0 deletions src/elevenlabs/conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
from abc import ABC, abstractmethod
import base64
import json
import queue
import threading
from typing import Callable, Optional

from websockets.sync.client import connect

from .base_client import BaseElevenLabs


class AudioInterface(ABC):
    """Abstract contract for the audio input/output side of a conversation."""

    @abstractmethod
    def start(self, input_callback: Callable[[bytes], None]):
        """Start the audio interface.

        Invoked once, before the conversation begins.

        Implementations should call `input_callback` regularly with chunks of
        the user's input audio, encoded as 16-bit PCM mono at 16kHz. The
        recommended chunk size is 4000 samples (250 milliseconds).
        """

    @abstractmethod
    def stop(self):
        """Stop the audio interface.

        Invoked once, after the conversation has ended. Implementations should
        release every resource they hold and stop any audio streams. The
        `input_callback` received in `start` must not be called once this
        method has been called.
        """

    @abstractmethod
    def output(self, audio: bytes):
        """Play `audio` to the user.

        `audio` is 16-bit PCM mono at 16kHz. Implementations are free to add
        their own buffering, but this method should return quickly and must
        not block the calling thread.
        """

    @abstractmethod
    def interrupt(self):
        """Signal that audio output must stop.

        The user has interrupted the agent, so all previously buffered audio
        output should be stopped.
        """


class Conversation:
    """Conversational AI session with an agent, run over a websocket.

    The session runs on a background thread that receives server messages,
    forwards agent audio to the `audio_interface` and invokes the optional
    callbacks. User audio produced by the `audio_interface` is sent to the
    server as base64-encoded chunks.
    """

    client: "BaseElevenLabs"
    agent_id: str
    requires_auth: bool

    audio_interface: "AudioInterface"
    callback_agent_response: Optional[Callable[[str], None]]
    callback_agent_response_correction: Optional[Callable[[str, str], None]]
    callback_user_transcript: Optional[Callable[[str], None]]
    callback_latency_measurement: Optional[Callable[[int], None]]

    _thread: Optional[threading.Thread] = None
    # Assigned per instance in __init__: an Event created at class level would
    # be one shared object, so ending one session would stop every session.
    _should_stop: threading.Event
    _conversation_id: Optional[str] = None
    _last_interrupt_id: int = 0

    def __init__(
        self,
        client: "BaseElevenLabs",
        agent_id: str,
        *,
        requires_auth: bool,
        audio_interface: "AudioInterface",
        callback_agent_response: Optional[Callable[[str], None]] = None,
        callback_agent_response_correction: Optional[Callable[[str, str], None]] = None,
        callback_user_transcript: Optional[Callable[[str], None]] = None,
        callback_latency_measurement: Optional[Callable[[int], None]] = None,
    ):
        """Conversational AI session.

        BETA: This API is subject to change without regard to backwards compatibility.

        Args:
            client: The ElevenLabs client to use for the conversation.
            agent_id: The ID of the agent to converse with.
            requires_auth: Whether the agent requires authentication.
            audio_interface: The audio interface to use for input and output.
            callback_agent_response: Callback for agent responses.
            callback_agent_response_correction: Callback for agent response corrections.
                First argument is the original response (previously given to
                callback_agent_response), second argument is the corrected response.
            callback_user_transcript: Callback for user transcripts.
            callback_latency_measurement: Callback for latency measurements (in milliseconds).
        """

        self.client = client
        self.agent_id = agent_id
        self.requires_auth = requires_auth

        self.audio_interface = audio_interface
        self.callback_agent_response = callback_agent_response
        self.callback_agent_response_correction = callback_agent_response_correction
        self.callback_user_transcript = callback_user_transcript
        self.callback_latency_measurement = callback_latency_measurement

        self._should_stop = threading.Event()

    def start_session(self):
        """Starts the conversation session.

        Will run in background thread until `end_session` is called.
        """
        ws_url = self._get_signed_url() if self.requires_auth else self._get_wss_url()
        self._thread = threading.Thread(target=self._run, args=(ws_url,))
        self._thread.start()

    def end_session(self):
        """Ends the conversation session.

        Stops the audio interface first, then asks the background thread to
        exit. The stop flag is set even when stopping the audio interface
        raises, so the background thread is never left running.
        """
        try:
            self.audio_interface.stop()
        finally:
            self._should_stop.set()

    def wait_for_session_end(self) -> Optional[str]:
        """Waits for the conversation session to end.

        You must call `end_session` before calling this method, otherwise it will block.

        Returns:
            The conversation ID, if available.

        Raises:
            RuntimeError: If the session was never started.
        """
        if not self._thread:
            raise RuntimeError("Session not started.")
        self._thread.join()
        return self._conversation_id

    def _run(self, ws_url: str):
        """Background thread body: connect, start audio, and pump messages until stopped."""
        with connect(ws_url) as ws:

            def input_callback(audio):
                # User audio is sent to the server as a base64-encoded chunk.
                ws.send(
                    json.dumps(
                        {
                            "user_audio_chunk": base64.b64encode(audio).decode(),
                        }
                    )
                )

            self.audio_interface.start(input_callback)
            while not self._should_stop.is_set():
                try:
                    # Short timeout so the stop flag is re-checked regularly.
                    message = json.loads(ws.recv(timeout=0.5))
                    self._handle_message(message, ws)
                except TimeoutError:
                    pass

    def _handle_message(self, message, ws):
        """Dispatch one decoded server message according to its `type` field.

        Args:
            message: The JSON-decoded message received from the server.
            ws: The websocket connection, used to answer pings.
        """
        if message["type"] == "conversation_initiation_metadata":
            event = message["conversation_initiation_metadata_event"]
            assert self._conversation_id is None
            self._conversation_id = event["conversation_id"]
        elif message["type"] == "audio":
            event = message["audio_event"]
            # Drop audio belonging to events at or before the last interruption.
            if int(event["event_id"]) <= self._last_interrupt_id:
                return
            audio = base64.b64decode(event["audio_base_64"])
            self.audio_interface.output(audio)
        elif message["type"] == "agent_response":
            if self.callback_agent_response:
                event = message["agent_response_event"]
                self.callback_agent_response(event["agent_response"].strip())
        elif message["type"] == "agent_response_correction":
            if self.callback_agent_response_correction:
                event = message["agent_response_correction_event"]
                self.callback_agent_response_correction(
                    event["original_agent_response"].strip(),
                    event["corrected_agent_response"].strip(),
                )
        elif message["type"] == "user_transcript":
            if self.callback_user_transcript:
                event = message["user_transcription_event"]
                self.callback_user_transcript(event["user_transcript"].strip())
        elif message["type"] == "interruption":
            event = message["interruption_event"]
            # Must be the same attribute the "audio" branch compares against,
            # otherwise stale audio is never filtered out.
            self._last_interrupt_id = int(event["event_id"])
            self.audio_interface.interrupt()
        elif message["type"] == "ping":
            event = message["ping_event"]
            ws.send(
                json.dumps(
                    {
                        "type": "pong",
                        "event_id": event["event_id"],
                    }
                )
            )
            if self.callback_latency_measurement and event["ping_ms"]:
                self.callback_latency_measurement(int(event["ping_ms"]))
        else:
            pass  # Ignore all other message types.

    def _get_wss_url(self):
        """Build the unauthenticated websocket URL from the client's base URL."""
        base_url = self.client._client_wrapper._base_url
        # Replace http(s) with ws(s).
        base_ws_url = base_url.replace("http", "ws", 1)  # First occurrence only.
        return f"{base_ws_url}/v1/convai/conversation?agent_id={self.agent_id}"

    def _get_signed_url(self):
        """Request a signed websocket URL for this agent through the client's HTTP wrapper."""
        # TODO: Use generated SDK method once available.
        response = self.client._client_wrapper.httpx_client.request(
            f"v1/convai/conversation/get_signed_url?agent_id={self.agent_id}",
            method="GET",
        )
        return response.json()["signed_url"]


class DefaultAudioInterface(AudioInterface):
    """AudioInterface implementation built on the optional pyaudio package.

    Input is captured through a pyaudio stream callback and passed straight to
    the conversation. Output is queued and written by a dedicated thread so
    that pending audio can be discarded on interruption.
    """

    INPUT_FRAMES_PER_BUFFER = 4000  # 250ms @ 16kHz
    OUTPUT_FRAMES_PER_BUFFER = 1000  # 62.5ms @ 16kHz

    def __init__(self):
        """Import pyaudio lazily so it stays an optional dependency.

        Raises:
            ImportError: If pyaudio is not installed.
        """
        try:
            import pyaudio
        except ImportError as err:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError("To use DefaultAudioInterface you must install pyaudio.") from err
        self.pyaudio = pyaudio

    def start(self, input_callback: Callable[[bytes], None]):
        """Open the input and output streams and start the output thread."""
        # Audio input is using callbacks from pyaudio which we simply pass through.
        self.input_callback = input_callback

        # Audio output is buffered so we can handle interruptions.
        # Start a separate thread to handle writing to the output stream.
        self.output_queue: queue.Queue[bytes] = queue.Queue()
        self.should_stop = threading.Event()
        self.output_thread = threading.Thread(target=self._output_thread)

        self.p = self.pyaudio.PyAudio()
        self.in_stream = self.p.open(
            format=self.pyaudio.paInt16,
            channels=1,
            rate=16000,
            input=True,
            stream_callback=self._in_callback,
            frames_per_buffer=self.INPUT_FRAMES_PER_BUFFER,
            start=True,
        )
        self.out_stream = self.p.open(
            format=self.pyaudio.paInt16,
            channels=1,
            rate=16000,
            output=True,
            frames_per_buffer=self.OUTPUT_FRAMES_PER_BUFFER,
            start=True,
        )

        self.output_thread.start()

    def stop(self):
        """Stop the output thread, close both streams and terminate pyaudio."""
        self.should_stop.set()
        self.output_thread.join()
        self.in_stream.stop_stream()
        self.in_stream.close()
        self.out_stream.close()
        self.p.terminate()

    def output(self, audio: bytes):
        """Queue `audio` for playback by the output thread."""
        self.output_queue.put(audio)

    def interrupt(self):
        """Discard all audio that is queued but not yet written to the output stream."""
        # Clear the output queue to stop any audio that is currently playing.
        # Note: We can't atomically clear the whole queue, but we are doing
        # it from the message handling thread so no new audio will be added
        # while we are clearing.
        try:
            while True:
                _ = self.output_queue.get(block=False)
        except queue.Empty:
            pass

    def _output_thread(self):
        """Write queued audio to the output stream until asked to stop."""
        while not self.should_stop.is_set():
            try:
                # Timeout so the stop flag is re-checked while the queue is empty.
                audio = self.output_queue.get(timeout=0.25)
                self.out_stream.write(audio)
            except queue.Empty:
                pass

    def _in_callback(self, in_data, frame_count, time_info, status):
        """pyaudio stream callback: forward captured audio to the input callback."""
        if self.input_callback:
            self.input_callback(in_data)
        return (None, self.pyaudio.paContinue)

0 comments on commit 1426346

Please sign in to comment.