# -*- coding: utf-8 -*-
#
#from .zw_def import *
from .zw_tool import *
#from .zw_srt import *
#
#from icecream import ic as cpp
#cpp=print
#
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, set_seed  # set_seed is called in DeepGen.deep_xgen below
from folder_paths import models_dir
from comfy import model_management, model_patcher
#
#------------------
#
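# Checkpoints are looked up under ComfyUI's models/deepseek directory,
# which is created on import if it does not exist yet.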
deep_model_folder_path = Path(models_dir) / 'deepseek'
deep_model_folder_path.mkdir(parents=True, exist_ok=True)
#
#-------------
#
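# DeepModel bundles the raw Hugging Face model with its tokenizer and the
# ComfyUI ModelPatcher that manages device placement for it.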
class DeepModel:
    def __init__(self, model, patcher, tokenizer=None, processor=None):
        self.model = model
        self.tokenizer = tokenizer
        self.processor = processor
        self.patcher = patcher
        # Hook the model class's device property with a no-op setter so nothing
        # can relocate the model by assigning to .device; device moves are
        # handled by the ModelPatcher instead.
        def set_value(self, new_value):
            pass
        model.__class__.device = property(fget=model.__class__.device.fget, fset=set_value)

class DeepLoader:
    @classmethod
    def INPUT_TYPES(cls):
        # Offer every subfolder of models/deepseek that contains a config.json,
        # i.e. every Hugging Face-style checkpoint directory.
        model_lst = []
        for folder in deep_model_folder_path.iterdir():
            if folder.is_dir():
                config_file = folder / 'config.json'
                if config_file.is_file():
                    relative_path = str(folder.relative_to(deep_model_folder_path))
                    model_lst.append(relative_path)
        return {
            "required": {
                "model_name": (model_lst, {}),
            }
        }

    RETURN_TYPES = ("DEEP_MODEL", )
    RETURN_NAMES = ("model", )
    FUNCTION = "load_model"
    CATEGORY = "ComfyUI-DeepSeek-R1"

    def load_model(self, model_name):
        # Load the weights onto the CPU first; ComfyUI's ModelPatcher moves them
        # to the GPU on demand and offloads them again when VRAM is needed.
        offload_device = torch.device('cpu')
        load_device = model_management.get_torch_device()
        model = AutoModelForCausalLM.from_pretrained(
            deep_model_folder_path / model_name,
            device_map=offload_device,
            torch_dtype="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(deep_model_folder_path / model_name)
        patcher = model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
        return (DeepModel(model, patcher, tokenizer=tokenizer), )
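

# DeepGen renders system_instruction + user_prompt through the tokenizer's
# chat template and samples a completion from the loaded model.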
class DeepGen:
    @classmethod
    def INPUT_TYPES(cls):
        DEFAULT_INSTRUCT = ''
        return {
            "required": {
                "deep_model": ("DEEP_MODEL",),
                #"prompt": ("STRING", {"defaultInput": True,}),
                #
                "system_instruction": ("STRING", {"default": DEFAULT_INSTRUCT, "multiline": True}),
                "user_prompt": ("STRING", {"default": "", "multiline": True}),
                "seed": ("INT", {"default": 888, "min": 0, "max": 0xffffffffffffffff}),
                "max_tokens": ("INT", {"default": 512, "min": 0, "max": 0xffffffffffffffff}),
                "temperature": ("FLOAT", {"default": 1, "min": 0, "max": 2}),
                "top_k": ("INT", {"default": 50, "min": 0, "max": 101}),
                "top_p": ("FLOAT", {"default": 1, "min": 0, "max": 1}),
            },
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("STRING",)
    FUNCTION = "deep_xgen"
    CATEGORY = "ComfyUI-DeepSeek-R1"

    def deep_xgen(self, deep_model, system_instruction, user_prompt, seed=0, temperature=1.0, max_tokens=512, top_k=50, top_p=1.0, **kwargs):
        # Seed the RNGs for reproducible sampling.
        set_seed(seed % 9999999)
        # Extra keyword inputs can fill {placeholders} in the user prompt,
        # e.g. user_prompt="a photo of {subject}" with subject="a cat".
        messages = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": user_prompt.format(**kwargs)},
            #{"role": "user", "content": user_prompt},
        ]
        tokenizer = deep_model.tokenizer
        model = deep_model.model
        patcher = deep_model.patcher
        # Let ComfyUI move the model to the GPU (evicting other models if needed).
        model_management.load_model_gpu(patcher)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,  # sampling must be enabled for temperature/top_k/top_p to take effect
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        # Strip the prompt tokens so only the newly generated tokens are decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return (response, )
#
#----------------
#
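# Node registration: ComfyUI maps these keys to the node classes above.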
NODE_CLASS_MAPPINGS = {
    #
    'deep_load': DeepLoader,
    'deep_gen': DeepGen,
    #
}
#----------
#
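# Hypothetical wiring, for illustration only (node names taken from the
# mappings above): deep_load -> DEEP_MODEL -> deep_gen -> STRING, with the
# STRING output routed into any downstream text input such as a prompt encoder.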