forked from t41372/Open-LLM-VTuber
-
Notifications
You must be signed in to change notification settings - Fork 0
/
live2d_model.py
200 lines (158 loc) · 7.28 KB
/
live2d_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import json
import chardet
from loguru import logger
# This class only prepares the payload for the Live2D model.
# Sending the payload is the caller's job:
# this class is **not responsible** for sending the payload to the server.
class Live2dModel:
    """
    Hold the information of one Live2D model and provide emotion-tag helpers.

    This class only prepares and stores the information of the Live2D model and
    processes text for emotion tags such as `[joy]`. It does not send anything
    to the frontend or to a server; that is the caller's responsibility.

    Attributes:
        model_dict_path (str): The path to the model dictionary file.
        live2d_model_name (str): The name of the Live2D model.
        model_info (dict): The information of the Live2D model.
        emo_map (dict): The emotion map of the Live2D model (emotion keyword -> expression value).
        emo_str (str): The keys of `emo_map`, each written as `[key],` and joined with spaces.
    """

    model_dict_path: str
    live2d_model_name: str
    model_info: dict
    emo_map: dict
    emo_str: str

    def __init__(
        self, live2d_model_name: str, model_dict_path: str = "model_dict.json"
    ):
        """
        Store the model name and dictionary path, then load the model information.

        Parameters:
            live2d_model_name (str): The name of the live2d model to load.
            model_dict_path (str): The path to the model dictionary file.

        Raises:
            The same exceptions as `set_model`.
        """
        self.model_dict_path = model_dict_path
        self.live2d_model_name = live2d_model_name
        self.set_model(live2d_model_name)

    def set_model(self, model_name: str) -> None:
        """
        Set the model with its name and load the model information. This method will initialize the `self.model_info`, `self.emo_map`, and `self.emo_str` attributes.
        This method is called in the constructor.

        Parameters:
            model_name (str): The name of the live2d model.

        Returns:
            None

        Raises:
            The same exceptions as `_lookup_model_info`.
            KeyError if the matched model has no "emotionMap" entry.
        """
        self.model_info = self._lookup_model_info(model_name)
        self.emo_map = self.model_info["emotionMap"]
        # emo_str lists the keys of emo_map, each enclosed in square brackets and
        # followed by a comma (so the string ends with a trailing comma).
        # example: `"[fear], [anger], [disgust], [sadness], [joy], [neutral], [surprise],"`
        self.emo_str = " ".join(f"[{key}]," for key in self.emo_map)

    def _load_file_content(self, file_path: str) -> str:
        """
        Load the content of a text file with robust encoding handling.

        A fixed list of common encodings is tried first; if none of them can
        decode the file, the encoding is guessed with `chardet`.

        Parameters:
            file_path (str): The path of the file to read.

        Returns:
            str: The decoded content of the file.

        Raises:
            FileNotFoundError if the file does not exist.
            UnicodeError if the file cannot be decoded with any encoding.
        """
        # "utf-8-sig" has to come before "utf-8": plain UTF-8 decodes a file that
        # starts with a byte order mark without error but keeps the BOM character
        # in the text, and json.loads refuses input that starts with a BOM.
        encodings = ["utf-8-sig", "utf-8", "gbk", "gb2312", "ascii"]
        for encoding in encodings:
            try:
                with open(file_path, "r", encoding=encoding) as file:
                    return file.read()
            except UnicodeDecodeError:
                continue

        # If all common encodings fail, try to detect the encoding
        try:
            with open(file_path, "rb") as file:
                raw_data = file.read()
            detected = chardet.detect(raw_data)
            detected_encoding = detected["encoding"]
            if detected_encoding:
                try:
                    return raw_data.decode(detected_encoding)
                except UnicodeDecodeError:
                    # The guess was wrong; fall through to the UnicodeError below
                    pass
        except Exception as e:
            # Best effort only: log the problem and report a decoding failure
            logger.error(f"Error detecting encoding for {file_path}: {e}")

        raise UnicodeError(f"Failed to decode {file_path} with any encoding")

    def _lookup_model_info(self, model_name: str) -> dict:
        """
        Find the model information from the model dictionary and return the information about the matched model.

        `self.live2d_model_name` is set to `model_name` before the dictionary is
        read, so it is updated even when the lookup fails afterwards.

        Parameters:
            model_name (str): The name of the live2d model.

        Returns:
            dict: The dictionary with the information of the matched model.

        Raises:
            FileNotFoundError if the model dictionary file is not found.
            json.JSONDecodeError if the model dictionary file is not a valid JSON file.
            UnicodeError if the model dictionary file cannot be decoded.
            KeyError if the model name is not found in the model dictionary.
        """
        self.live2d_model_name = model_name

        try:
            file_content = self._load_file_content(self.model_dict_path)
            model_dict = json.loads(file_content)
        except FileNotFoundError:
            print(f"Model dictionary file not found at {self.model_dict_path}.")
            raise
        except json.JSONDecodeError:
            print(
                f"Error decoding JSON from model dictionary file at {self.model_dict_path}."
            )
            raise
        except UnicodeError:
            print(f"Error reading model dictionary file at {self.model_dict_path}.")
            raise
        except Exception:
            print(
                f"Error occurred while reading model dictionary file at {self.model_dict_path}."
            )
            raise

        # Find the first entry of the model dictionary whose "name" matches
        matched_model = next(
            (model for model in model_dict if model["name"] == model_name), None
        )

        if matched_model is None:
            print(f"Unable to find {model_name} in {self.model_dict_path}.")
            raise KeyError(
                f"{model_name} not found in model dictionary {self.model_dict_path}."
            )

        # The feature: "translate model url to full url if it starts with '/' " is no longer implemented here

        print("Model Information Loaded.")
        return matched_model

    @staticmethod
    def _find_tag(text: str, tag: str, start: int = 0) -> int:
        """Return the index of the first case-insensitive occurrence of the lowercase `tag` (which starts with "[") in `text` at or after `start`, or -1."""
        tag_len = len(tag)
        pos = text.find("[", start)
        while pos != -1:
            # Lowercase only the candidate slice so indices always refer to `text`
            if text[pos : pos + tag_len].lower() == tag:
                return pos
            pos = text.find("[", pos + 1)
        return -1

    def extract_emotion(self, str_to_check: str) -> list:
        """
        Check the input string for any emotion keywords and return a list of values (the expression index) of the emotions found in the string.

        Matching is case-insensitive and the tags are consumed from left to
        right without overlapping.

        Parameters:
            str_to_check (str): The string to check for emotions.

        Returns:
            list: A list of values of the emotions found in the string. An empty list is returned if no emotions are found.
        """
        lowered = str_to_check.lower()
        # Lowercase the tags as well (as remove_emotion_keywords does), otherwise
        # an emotion key containing capital letters could never match.
        tags = [(f"[{key}]".lower(), value) for key, value in self.emo_map.items()]

        expression_list = []
        i = lowered.find("[")
        while i != -1:
            step = 1
            for tag, value in tags:
                if lowered.startswith(tag, i):
                    expression_list.append(value)
                    # Continue right after the matched tag
                    step = len(tag)
                    break
            i = lowered.find("[", i + step)
        return expression_list

    def remove_emotion_keywords(self, target_str: str) -> str:
        """
        Remove the emotion keywords from the input string and return the cleaned string.

        Matching is case-insensitive. The keys of `emo_map` are processed one
        after another, and each key is removed until no occurrence of it is left.

        Parameters:
            target_str (str): The string to remove the emotion keywords from.

        Returns:
            str: The cleaned string with the emotion keywords removed.
        """
        for key in self.emo_map:
            tag = f"[{key}]".lower()
            tag_len = len(tag)
            pos = self._find_tag(target_str, tag)
            while pos != -1:
                target_str = target_str[:pos] + target_str[pos + tag_len :]
                # Cutting a tag out can join its neighbours into a new tag, so
                # resume slightly before the cut instead of after it.
                pos = self._find_tag(target_str, tag, max(0, pos - tag_len + 1))
        return target_str
if __name__ == "__main__":
    # Manual smoke test: load the "shizuku-local" model from the default model
    # dictionary, show what was loaded, then run both emotion helpers on a
    # sample string that mixes valid tags, unknown tags and unbalanced brackets.
    demo_model = Live2dModel("shizuku-local")
    for loaded_value in (
        demo_model.model_info,
        demo_model.emo_map,
        demo_model.emo_str,
    ):
        print(loaded_value)

    sample_text = "[*joins hands and smi]les* * [SmIrK]: HEHE, YOU THINK YOU CAN HANDLE THE TRUTH?[anger][anger] [anger] [smirk] [anger]["
    print(sample_text)
    for helper in (demo_model.extract_emotion, demo_model.remove_emotion_keywords):
        print(helper(sample_text))