Commit b9b7bf3: Fix the PyTorch loading issue when importing plugins (修复插件导入时的pytorch加载问题)

binary-husky committed Nov 12, 2023 (1 parent: 7e56ace)
Showing 6 changed files with 45 additions and 24 deletions.
46 changes: 33 additions & 13 deletions crazy_functional.py
@@ -1,4 +1,5 @@
 from toolbox import HotReload # HotReload means hot reload: after a function plugin is modified, the change takes effect without restarting the program
+from toolbox import trimmed_format_exc


 def get_crazy_functions():
@@ -292,6 +293,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -316,6 +318,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -331,6 +334,7 @@ def get_crazy_functions():
             },
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -346,23 +350,24 @@ def get_crazy_functions():
             },
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
-        from crazy_functions.图片生成 import 图片生成, 图片生成_DALLE3
+        from crazy_functions.图片生成 import 图片生成_DALLE2, 图片生成_DALLE3
         function_plugins.update({
-            "图片生成(先切换模型到openai或api2d)": {
+            "图片生成_DALLE2 (先切换模型到openai或api2d)": {
                 "Group": "对话",
                 "Color": "stop",
                 "AsButton": False,
                 "AdvancedArgs": True, # when invoked, bring up the advanced-argument input area (default False)
                 "ArgsReminder": "在这里输入分辨率, 如1024x1024(默认),支持 256x256, 512x512, 1024x1024", # hint shown in the advanced-argument input area
                 "Info": "使用DALLE2生成图片 | 输入参数字符串,提供图像的内容",
-                "Function": HotReload(图片生成)
+                "Function": HotReload(图片生成_DALLE2)
             },
         })
         function_plugins.update({
-            "图片生成_DALLE3(先切换模型到openai或api2d)": {
+            "图片生成_DALLE3 (先切换模型到openai或api2d)": {
                 "Group": "对话",
                 "Color": "stop",
                 "AsButton": False,
@@ -373,6 +378,7 @@ def get_crazy_functions():
             },
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -389,6 +395,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -403,6 +410,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -418,6 +426,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -433,6 +442,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -448,6 +458,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -461,6 +472,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -505,6 +517,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -522,6 +535,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -535,6 +549,7 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

     try:
@@ -548,17 +563,22 @@ def get_crazy_functions():
             }
         })
     except:
+        print(trimmed_format_exc())
         print('Load function plugin failed')

-    from crazy_functions.多智能体 import 多智能体终端
-    function_plugins.update({
-        "AutoGen多智能体终端(仅供测试)": {
-            "Group": "智能体",
-            "Color": "stop",
-            "AsButton": False,
-            "Function": HotReload(多智能体终端)
-        }
-    })
+    try:
+        from crazy_functions.多智能体 import 多智能体终端
+        function_plugins.update({
+            "AutoGen多智能体终端(仅供测试)": {
+                "Group": "智能体",
+                "Color": "stop",
+                "AsButton": False,
+                "Function": HotReload(多智能体终端)
+            }
+        })
+    except:
+        print(trimmed_format_exc())
+        print('Load function plugin failed')

     # try:
     #     from crazy_functions.chatglm微调工具 import 微调数据集生成
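Note: the pattern above is uniform across all plugin groups. Each optional plugin is imported inside try/except, and the new trimmed_format_exc() call prints the real import error instead of only the generic failure message. A minimal self-contained sketch of the idiom; the trimmed_format_exc body below is a stand-in, not the toolbox implementation:

import traceback

def trimmed_format_exc():
    # Stand-in for toolbox.trimmed_format_exc; the real helper presumably
    # also strips noisy path prefixes from the traceback.
    return traceback.format_exc()

function_plugins = {}
try:
    # Plugin modules may pull in heavy optional dependencies (e.g. PyTorch),
    # so a failed import must not crash the whole registry.
    from crazy_functions.多智能体 import 多智能体终端
    function_plugins.update({
        "AutoGen多智能体终端(仅供测试)": {"Group": "智能体", "Function": 多智能体终端}
    })
except Exception:
    print(trimmed_format_exc())           # new in this commit: show the actual cause
    print('Load function plugin failed')  # previous behavior: this message only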
9 changes: 3 additions & 6 deletions crazy_functions/批量翻译PDF文档_多线程.py
@@ -1,14 +1,13 @@
-from toolbox import CatchException, report_exception, get_log_folder, gen_time_str
+from toolbox import CatchException, report_exception, get_log_folder, gen_time_str, check_packages
 from toolbox import update_ui, promote_file_to_downloadzone, update_ui_lastest_msg, disable_auto_promotion
 from toolbox import write_history_to_file, promote_file_to_downloadzone
 from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
 from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
 from .crazy_utils import read_and_clean_pdf_text
 from .pdf_fns.parse_pdf import parse_pdf, get_avail_grobid_url, translate_pdf
 from colorful import *
 import copy
 import os
 import math


 @CatchException
 def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
@@ -22,9 +21,7 @@ def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):

     # Try to import the dependencies; if any are missing, suggest how to install them
     try:
-        import fitz
-        import tiktoken
-        import scipdf
+        check_packages(["fitz", "tiktoken", "scipdf"])
     except:
         report_exception(chatbot, history,
                          a=f"解析项目: {txt}",
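Note: the three hard imports are replaced by check_packages, the helper added in toolbox.py at the bottom of this commit, which probes for modules without executing them. A hedged sketch of the resulting idiom; the hint text is illustrative, and the PyPI names (pymupdf provides fitz, scipdf_parser provides scipdf) are the usual sources for these modules rather than something this diff states:

import importlib.util

def check_packages(packages=[]):
    # Mirrors the new toolbox helper: find_spec only consults the import
    # machinery, so nothing heavy (no torch) actually gets loaded.
    for p in packages:
        if importlib.util.find_spec(p) is None:
            raise ModuleNotFoundError(p)

try:
    check_packages(["fitz", "tiktoken", "scipdf"])
except ModuleNotFoundError as e:
    # Turn the missing module into an actionable installation hint.
    print(f"Missing dependency {e}; try: pip install --upgrade pymupdf tiktoken scipdf_parser")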
2 changes: 1 addition & 1 deletion request_llms/bridge_chatglm.py
@@ -2,7 +2,6 @@
 cmd_to_install = "`pip install -r request_llms/requirements_chatglm.txt`"


-from transformers import AutoModel, AutoTokenizer
 from toolbox import get_conf, ProxyNetworkActivate
 from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns

@@ -23,6 +22,7 @@ def load_model_and_tokenizer(self):
         import os, glob
         import os
         import platform
+        from transformers import AutoModel, AutoTokenizer
         LOCAL_MODEL_QUANT, device = get_conf('LOCAL_MODEL_QUANT', 'LOCAL_MODEL_DEVICE')

         if LOCAL_MODEL_QUANT == "INT4": # INT4
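Note: this is the core of the fix named in the commit message. The transformers import, which transitively loads PyTorch, moves from module scope into load_model_and_tokenizer, which only runs once a model is actually requested (in the worker process). A sketch of the pattern; DemoLLMHandle and the model name are illustrative, not the real LocalLLMHandle API:

class DemoLLMHandle:
    # Importing the module that defines this class is now cheap: no torch yet.
    def load_model_and_tokenizer(self):
        # Deferred import: executed only when the model is really needed,
        # so plugin discovery at startup no longer drags in PyTorch.
        from transformers import AutoModel, AutoTokenizer  # heavy import
        tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
        model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
        return model, tokenizer

bridge_chatglm3.py below receives the identical two-part change, and bridge_moss.py simply drops its top-level transformers import.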
2 changes: 1 addition & 1 deletion request_llms/bridge_chatglm3.py
@@ -2,7 +2,6 @@
 cmd_to_install = "`pip install -r request_llms/requirements_chatglm.txt`"


-from transformers import AutoModel, AutoTokenizer
 from toolbox import get_conf, ProxyNetworkActivate
 from .local_llm_class import LocalLLMHandle, get_local_llm_predict_fns

@@ -20,6 +19,7 @@ def load_model_info(self):

     def load_model_and_tokenizer(self):
         # 🏃‍♂️🏃‍♂️🏃‍♂️ executed in the child process
+        from transformers import AutoModel, AutoTokenizer
         import os, glob
         import os
         import platform
2 changes: 0 additions & 2 deletions request_llms/bridge_moss.py
@@ -1,8 +1,6 @@
-
-from transformers import AutoModel, AutoTokenizer
 import time
 import threading
 import importlib
 from toolbox import update_ui, get_conf
 from multiprocessing import Process, Pipe

8 changes: 7 additions & 1 deletion toolbox.py
@@ -1145,4 +1145,10 @@ def get_chat_default_kwargs():

 def get_max_token(llm_kwargs):
     from request_llms.bridge_all import model_info
-    return model_info[llm_kwargs['llm_model']]['max_token']
\ No newline at end of file
+    return model_info[llm_kwargs['llm_model']]['max_token']
+
+def check_packages(packages=[]):
+    import importlib.util
+    for p in packages:
+        spam_spec = importlib.util.find_spec(p)
+        if spam_spec is None: raise ModuleNotFoundError
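A quick illustration of why find_spec is the right probe here: it searches the import system without executing the target module, so checking for a torch-heavy package costs almost nothing. (The mutable default packages=[] is harmless in this helper, since the list is only iterated, never mutated.)

import importlib.util

# Both checks complete instantly; neither module is actually imported.
print(importlib.util.find_spec("json") is not None)        # True: stdlib module is present
print(importlib.util.find_spec("not_a_real_pkg") is None)  # True: missing, and nothing was loaded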