From f17bd3132765d5041152c6696887cd6a9d9354b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 21 Jun 2024 16:57:44 +0900 Subject: [PATCH] fix: model loading check (#391) --- ChatTTS/core.py | 8 ++++---- examples/cmd/run.py | 7 +++++-- examples/web/webui.py | 11 +++++++++-- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index f49139d55..6553ef839 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -1,5 +1,5 @@ -import os +import os, sys import json import logging from functools import partial @@ -63,7 +63,7 @@ def load_models( download_all_assets(tmpdir=tmp) if not check_all_assets(update=False): logging.error("counld not satisfy all assets needed.") - exit(1) + return False elif source == 'huggingface': hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface")) try: @@ -79,7 +79,7 @@ def load_models( self.logger.log(logging.INFO, f'Load from local: {custom_path}') download_path = custom_path - self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs) + return self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs) def _load( self, @@ -148,7 +148,7 @@ def _load( self.pretrain_models['tokenizer'] = tokenizer self.logger.log(logging.INFO, 'tokenizer loaded.') - self.check_model() + return self.check_model() def _infer( self, diff --git a/examples/cmd/run.py b/examples/cmd/run.py index 8793ec51f..7af340f94 100644 --- a/examples/cmd/run.py +++ b/examples/cmd/run.py @@ -31,8 +31,11 @@ def main(): chat = ChatTTS.Chat() print("Initializing ChatTTS...") - chat.load_models() - print("Models loaded successfully.") + if chat.load_models(): + print("Models loaded successfully.") + else: + print("Models load failed.") + sys.exit(1) texts = [text_input] print("Text prepared for inference:", texts) diff --git a/examples/web/webui.py b/examples/web/webui.py index d6896abc6..3be454ad9 100644 --- a/examples/web/webui.py +++ b/examples/web/webui.py @@ -182,10 +182,17 @@ def main(): chat = ChatTTS.Chat() if args.custom_path == None: - chat.load_models() + ret = chat.load_models() else: print('local model path:', args.custom_path) - chat.load_models('custom', custom_path=args.custom_path) + ret = chat.load_models('custom', custom_path=args.custom_path) + + if ret: + print("Models loaded successfully.") + else: + print("Models load failed.") + sys.exit(1) + demo.launch(server_name=args.server_name, server_port=args.server_port, root_path=args.root_path, inbrowser=True)