diff --git a/recipes/voice-llm/python/main.py b/recipes/voice-llm/python/main.py
new file mode 100644
index 0000000..8bdc08f
--- /dev/null
+++ b/recipes/voice-llm/python/main.py
@@ -0,0 +1,375 @@
+import signal
+import time
+from argparse import ArgumentParser
+from enum import Enum
+from multiprocessing import (
+    Pipe,
+    Process,
+)
+from typing import (
+    Optional,
+    Sequence,
+)
+
+import picollm
+import pvcheetah
+import pvorca
+import pvporcupine
+from pvrecorder import PvRecorder
+
+
+class Logger:
+    class Levels(Enum):
+        DEBUG = 'DEBUG'
+        INFO = 'INFO'
+
+    def __init__(self, level: 'Logger.Levels' = Levels.INFO) -> None:
+        self._level = level
+
+    def debug(self, message: str, end: str = '\n') -> None:
+        if self._level is self.Levels.DEBUG:
+            print(message, end=end, flush=True)
+
+    # noinspection PyMethodMayBeStatic
+    def info(self, message: str, end: str = '\n') -> None:
+        print(message, end=end, flush=True)
+
+
+class RTFProfiler:
+    def __init__(self, sample_rate: int) -> None:
+        self._sample_rate = sample_rate
+        self._compute_sec = 0.
+        self._audio_sec = 0.
+        self._tick_sec = 0.
+
+    def tick(self) -> None:
+        self._tick_sec = time.time()
+
+    def tock(self, audio: Optional[Sequence[int]] = None) -> None:
+        self._compute_sec += time.time() - self._tick_sec
+        self._audio_sec += (len(audio) / self._sample_rate) if audio is not None else 0.
+
+    def rtf(self) -> float:
+        # Guard against division by zero when no audio has been measured yet.
+        rtf = (self._compute_sec / self._audio_sec) if self._audio_sec > 0. else 0.
+        self._compute_sec = 0.
+        self._audio_sec = 0.
+        return rtf
+
+
+class TPSProfiler(object):
+    def __init__(self) -> None:
+        self._num_tokens = 0
+        self._start_sec = 0.
+
+    def tock(self) -> None:
+        if self._start_sec == 0.:
+            self._start_sec = time.time()
+        else:
+            self._num_tokens += 1
+
+    def tps(self) -> float:
+        tps = self._num_tokens / (time.time() - self._start_sec)
+        self._num_tokens = 0
+        self._start_sec = 0.
+        return tps
+
+
+def orca_worker(access_key: str, connection, warmup_sec: float, stream_frame_sec: float = 0.03) -> None:
+    # NumPy itself is unused here, but sounddevice hands the callback NumPy buffers;
+    # importing it up front surfaces a missing install early.
+    # noinspection PyUnresolvedReferences
+    import numpy as np
+    from sounddevice import OutputStream
+
+    orca = pvorca.create(access_key=access_key)
+    orca_stream = orca.stream_open()
+
+    texts = list()
+    pcm_buffer = list()
+    warmup = [False]
+    synthesize = False
+    flush = False
+    close = False
+    utterance_end_sec = 0.
+    delay_sec = [-1.]
+
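+    # Editorial note: `warmup` and `delay_sec` are single-element lists used as
+    # mutable cells so that the nested `callback` and `buffer_pcm` below can update
+    # them; a bare assignment inside a nested function would create a new local
+    # instead. A `nonlocal` declaration would work equally well, e.g. (illustrative
+    # sketch, not part of this recipe):
+    #
+    #     warmup = False
+    #
+    #     def callback(data, frames, timestamp, status) -> None:
+    #         nonlocal warmup
+    #         warmup = False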
+    def callback(data, _, __, ___) -> None:
+        if warmup[0]:
+            if len(pcm_buffer) < int(warmup_sec * orca.sample_rate):
+                data[:, 0] = 0
+                return
+            else:
+                warmup[0] = False
+
+        if len(pcm_buffer) < data.shape[0]:
+            pcm_buffer.extend([0] * (data.shape[0] - len(pcm_buffer)))
+
+        data[:, 0] = pcm_buffer[:data.shape[0]]
+        del pcm_buffer[:data.shape[0]]
+
+    stream = OutputStream(
+        samplerate=orca.sample_rate,
+        blocksize=int(stream_frame_sec * orca.sample_rate),
+        channels=1,
+        dtype='int16',
+        callback=callback)
+
+    connection.send({'version': orca.version})
+
+    orca_profiler = RTFProfiler(orca.sample_rate)
+
+    def buffer_pcm(x: Optional[Sequence[int]]) -> None:
+        if x is not None:
+            pcm_buffer.extend(x)
+            if delay_sec[0] == -1:
+                delay_sec[0] = time.time() - utterance_end_sec
+
+    while True:
+        if synthesize and len(texts) > 0:
+            orca_profiler.tick()
+            pcm = orca_stream.synthesize(texts.pop(0))
+            orca_profiler.tock(pcm)
+            buffer_pcm(pcm)
+        elif flush:
+            while len(texts) > 0:
+                orca_profiler.tick()
+                pcm = orca_stream.synthesize(texts.pop(0))
+                orca_profiler.tock(pcm)
+                buffer_pcm(pcm)
+            orca_profiler.tick()
+            pcm = orca_stream.flush()
+            orca_profiler.tock(pcm)
+            buffer_pcm(pcm)
+            connection.send({'rtf': orca_profiler.rtf(), 'delay': delay_sec[0]})
+            flush = False
+            while len(pcm_buffer) > 0:
+                time.sleep(stream_frame_sec)
+            stream.stop()
+            delay_sec[0] = -1
+            connection.send({'done': True})
+        elif close:
+            break
+        else:
+            time.sleep(stream_frame_sec)
+
+        while connection.poll():
+            message = connection.recv()
+            if message['command'] == 'synthesize':
+                texts.append(message['text'])
+                if not stream.active:
+                    stream.start()
+                    warmup[0] = True
+                utterance_end_sec = message['utterance_end_sec']
+                synthesize = True
+            elif message['command'] == 'flush':
+                synthesize = False
+                flush = True
+            elif message['command'] == 'close':
+                close = True
+
+    stream.close()
+    orca_stream.close()
+    orca.delete()
+
+
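+# Editorial note: `main` (below) and `orca_worker` (above) talk over a
+# `multiprocessing.Pipe`. The parent sends {'command': 'synthesize', 'text': ...,
+# 'utterance_end_sec': ...} for each streamed LLM token, {'command': 'flush'} when
+# generation ends, and {'command': 'close'} on shutdown; the worker replies with
+# {'version': ...} at startup, then {'rtf': ..., 'delay': ...} and {'done': True}
+# after each flush. A minimal driver would look like (illustrative sketch only):
+#
+#     parent, child = Pipe()
+#     Process(target=orca_worker, args=(access_key, child, 0.)).start()
+#     print(parent.recv()['version'])
+#     parent.send({'command': 'synthesize', 'text': 'Hello.', 'utterance_end_sec': time.time()})
+#     parent.send({'command': 'flush'})
+#     parent.send({'command': 'close'})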
+def main() -> None:
+    parser = ArgumentParser()
+    parser.add_argument(
+        '--access_key',
+        required=True,
+        help='`AccessKey` obtained from `Picovoice Console` (https://console.picovoice.ai/).')
+    parser.add_argument(
+        '--picollm_model_path',
+        required=True,
+        help='Absolute path to the file containing LLM parameters.')
+    parser.add_argument(
+        '--keyword_model_path',
+        help='Absolute path to the keyword model file. If not set, `Picovoice` will be used as the wake phrase.')
+    parser.add_argument(
+        '--cheetah_endpoint_duration_sec',
+        type=float,
+        default=1.,
+        help="Duration of silence (pause) after the user's utterance to consider it the end of the utterance.")
+    parser.add_argument(
+        '--picollm_device',
+        help="String representation of the device (e.g., CPU or GPU) to use for inference. If set to `best`, picoLLM "
+             "picks the most suitable device. If set to `gpu`, the engine uses the first available GPU device. To "
+             "select a specific GPU device, set this argument to `gpu:${GPU_INDEX}`, where `${GPU_INDEX}` is the "
+             "index of the target GPU. If set to `cpu`, the engine will run on the CPU with the default number of "
+             "threads. To specify the number of threads, set this argument to `cpu:${NUM_THREADS}`, where "
+             "`${NUM_THREADS}` is the desired number of threads.")
+    parser.add_argument(
+        '--picollm_completion_token_limit',
+        type=int,
+        help="Maximum number of tokens in the completion. If not set, no limit is imposed.")
+    parser.add_argument(
+        '--picollm_presence_penalty',
+        type=float,
+        default=0.,
+        help="If set to a positive value, it penalizes logits already appearing in the partial completion. If set "
+             "to `0.0`, it has no effect.")
+    parser.add_argument(
+        '--picollm_frequency_penalty',
+        type=float,
+        default=0.,
+        help="If set to a positive floating-point value, it penalizes logits proportional to the frequency of their "
+             "appearance in the partial completion. If set to `0.0`, it has no effect.")
+    parser.add_argument(
+        '--picollm_temperature',
+        type=float,
+        default=0.,
+        help="Sampling temperature. Temperature is a non-negative floating-point value that controls the randomness "
+             "of the sampler. A higher temperature smooths the sampler's output, increasing the randomness. In "
+             "contrast, a lower temperature creates a narrower distribution and reduces variability. Setting it to "
+             "`0` selects the maximum logit during sampling.")
+    parser.add_argument(
+        '--picollm_top_p',
+        type=float,
+        default=1.,
+        help="A positive floating-point number within (0, 1]. It restricts the sampler's choices to high-probability "
+             "logits that form the `top_p` portion of the probability mass. Hence, it avoids randomly selecting "
+             "unlikely logits. A value of `1.` enables the sampler to pick any token with non-zero probability, "
+             "turning off the feature.")
+    parser.add_argument(
+        '--orca_warmup_sec',
+        type=float,
+        default=0.,
+        help="Duration of the synthesized audio to buffer before streaming it out. A higher value helps slower "
+             "devices (e.g., Raspberry Pi) keep up with real-time at the cost of a longer initial delay.")
+    parser.add_argument(
+        '--log_level',
+        choices=[x.value for x in Logger.Levels],
+        default=Logger.Levels.INFO.value,
+        help='Log level verbosity.')
+    args = parser.parse_args()
+
+    access_key = args.access_key
+    picollm_model_path = args.picollm_model_path
+    keyword_model_path = args.keyword_model_path
+    cheetah_endpoint_duration_sec = args.cheetah_endpoint_duration_sec
+    picollm_device = args.picollm_device
+    picollm_completion_token_limit = args.picollm_completion_token_limit
+    picollm_presence_penalty = args.picollm_presence_penalty
+    picollm_frequency_penalty = args.picollm_frequency_penalty
+    picollm_temperature = args.picollm_temperature
+    picollm_top_p = args.picollm_top_p
+    orca_warmup_sec = args.orca_warmup_sec
+    log_level = Logger.Levels(args.log_level)
+
+    log = Logger(log_level)
+
+    if keyword_model_path is None:
+        porcupine = pvporcupine.create(access_key=access_key, keywords=['picovoice'])
+    else:
+        porcupine = pvporcupine.create(access_key=access_key, keyword_paths=[keyword_model_path])
+    log.info(f"→ Porcupine V{porcupine.version}")
+
+    cheetah = pvcheetah.create(access_key=access_key, endpoint_duration_sec=cheetah_endpoint_duration_sec)
+    log.info(f"→ Cheetah V{cheetah.version}")
+
+    pllm = picollm.create(access_key=access_key, model_path=picollm_model_path, device=picollm_device)
+    dialog = pllm.get_dialog()
+    log.info(f"→ picoLLM V{pllm.version} {pllm.model}")
+
+    main_connection, orca_process_connection = Pipe()
+    orca_process = Process(target=orca_worker, args=(access_key, orca_process_connection, orca_warmup_sec))
+    orca_process.start()
+    while not main_connection.poll():
+        time.sleep(0.01)
+    log.info(f"→ Orca V{main_connection.recv()['version']}")
+
+    mic = PvRecorder(frame_length=porcupine.frame_length)
+    mic.start()
+
+    log.info("\n$ Say `Picovoice` ...")
+
+    stop = [False]
+
+    def handler(_, __) -> None:
+        stop[0] = True
+
+    signal.signal(signal.SIGINT, handler)
+
+    wake_word_detected = False
+    human_request = ''
+    endpoint_reached = False
+    utterance_end_sec = 0.
+
+    porcupine_profiler = RTFProfiler(porcupine.sample_rate)
+    cheetah_profiler = RTFProfiler(cheetah.sample_rate)
+
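+    # Editorial note: RTF (real-time factor) below is compute time divided by the
+    # duration of audio processed, e.g. 0.1 s of compute on 1.0 s of audio is an RTF
+    # of 0.1; values under 1.0 keep up with the microphone. TPS is LLM tokens per
+    # second, measured from the first streamed token.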
+    try:
+        while True:
+            if stop[0]:
+                break
+            elif not wake_word_detected:
+                pcm = mic.read()
+                porcupine_profiler.tick()
+                wake_word_detected = porcupine.process(pcm) == 0
+                porcupine_profiler.tock(pcm)
+                if wake_word_detected:
+                    log.debug(f"[Porcupine RTF: {porcupine_profiler.rtf():.3f}]")
+                    log.info("$ Wake word detected, utter your request or question ...\n")
+                    log.info("human > ", end='')
+            elif not endpoint_reached:
+                pcm = mic.read()
+                cheetah_profiler.tick()
+                partial_transcript, endpoint_reached = cheetah.process(pcm)
+                cheetah_profiler.tock(pcm)
+                log.info(partial_transcript, end='')
+                human_request += partial_transcript
+                if endpoint_reached:
+                    utterance_end_sec = time.time()
+                    cheetah_profiler.tick()
+                    remaining_transcript = cheetah.flush()
+                    cheetah_profiler.tock()
+                    human_request += remaining_transcript
+                    log.info(remaining_transcript, end='\n\n')
+                    log.debug(f"[Cheetah RTF: {cheetah_profiler.rtf():.3f}]")
+            else:
+                dialog.add_human_request(human_request)
+
+                picollm_profiler = TPSProfiler()
+
+                def llm_callback(text: str) -> None:
+                    picollm_profiler.tock()
+                    main_connection.send(
+                        {'command': 'synthesize', 'text': text, 'utterance_end_sec': utterance_end_sec})
+                    log.info(text, end='')
+
+                log.info("\nllm > ", end='')
+                res = pllm.generate(
+                    prompt=dialog.prompt(),
+                    completion_token_limit=picollm_completion_token_limit,
+                    presence_penalty=picollm_presence_penalty,
+                    frequency_penalty=picollm_frequency_penalty,
+                    temperature=picollm_temperature,
+                    top_p=picollm_top_p,
+                    stream_callback=llm_callback)
+                main_connection.send({'command': 'flush'})
+                log.info('\n')
+                dialog.add_llm_response(res.completion)
+                log.debug(f"[picoLLM TPS: {picollm_profiler.tps():.2f}]")
+
+                while not main_connection.poll():
+                    time.sleep(0.01)
+                message = main_connection.recv()
+                log.debug(f"[Orca RTF: {message['rtf']:.2f}]")
+                log.debug(f"[Delay: {message['delay']:.2f} sec]")
+                while not main_connection.poll():
+                    time.sleep(0.01)
+                assert main_connection.recv()['done']
+
+                wake_word_detected = False
+                human_request = ''
+                endpoint_reached = False
+                log.info("\n$ Say `Picovoice` ...")
+    finally:
+        main_connection.send({'command': 'close'})
+        mic.delete()
+        pllm.release()
+        cheetah.delete()
+        porcupine.delete()
+        orca_process.join()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/recipes/voice-llm/python/requirements.txt b/recipes/voice-llm/python/requirements.txt
new file mode 100644
index 0000000..4657101
--- /dev/null
+++ b/recipes/voice-llm/python/requirements.txt
@@ -0,0 +1,7 @@
+numpy
+picollm
+pvcheetah==2.0.1
+pvorca==0.2.1
+pvporcupine==3.0.2
+pvrecorder==1.2.2
+sounddevice
\ No newline at end of file