Skip to content

Commit

Permalink
fix: model loading check (#391)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jun 21, 2024
1 parent 0d6621d commit f17bd31
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
8 changes: 4 additions & 4 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

import os
import os, sys
import json
import logging
from functools import partial
Expand Down Expand Up @@ -63,7 +63,7 @@ def load_models(
download_all_assets(tmpdir=tmp)
if not check_all_assets(update=False):
logging.error("counld not satisfy all assets needed.")
exit(1)
return False
elif source == 'huggingface':
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
try:
Expand All @@ -79,7 +79,7 @@ def load_models(
self.logger.log(logging.INFO, f'Load from local: {custom_path}')
download_path = custom_path

self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
return self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)

def _load(
self,
Expand Down Expand Up @@ -148,7 +148,7 @@ def _load(
self.pretrain_models['tokenizer'] = tokenizer
self.logger.log(logging.INFO, 'tokenizer loaded.')

self.check_model()
return self.check_model()

def _infer(
self,
Expand Down
7 changes: 5 additions & 2 deletions examples/cmd/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ def main():

chat = ChatTTS.Chat()
print("Initializing ChatTTS...")
chat.load_models()
print("Models loaded successfully.")
if chat.load_models():
print("Models loaded successfully.")
else:
print("Models load failed.")
sys.exit(1)

texts = [text_input]
print("Text prepared for inference:", texts)
Expand Down
11 changes: 9 additions & 2 deletions examples/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,17 @@ def main():
chat = ChatTTS.Chat()

if args.custom_path == None:
chat.load_models()
ret = chat.load_models()
else:
print('local model path:', args.custom_path)
chat.load_models('custom', custom_path=args.custom_path)
ret = chat.load_models('custom', custom_path=args.custom_path)

if ret:
print("Models loaded successfully.")
else:
print("Models load failed.")
sys.exit(1)


demo.launch(server_name=args.server_name, server_port=args.server_port, root_path=args.root_path, inbrowser=True)

Expand Down

0 comments on commit f17bd31

Please sign in to comment.