just_chat.py
import bitsandbytes #not strictly necessary, used for 8bit inference
import torch
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
from drugs.nice_imports import efficiency_stuff
from drugs.dgenerate import DRUGS

model_id = "NousResearch/Llama-2-7b-chat-hf"
#model_id = "cognitivecomputations/dolphin-2.2.1-mistral-7b"
#model_id = "TheBloke/30B-Epsilon-AWQ"
#efficiency_stuff['load_in_8bit'] = False #when testing AWQ models

#load the base ("sober") model before any noise is injected
sober_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    **efficiency_stuff)
sober_model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
streamer = TextStreamer(tokenizer)

injection_depth = 0.4 #how deep to shove the needle in (fraction of the layer stack)
spread = 5/32 #how many layers to dose on either side of the injection site

#trapezoidal dose profile: ramps from 0 to full strength just inside the
#spread around the injection site, and back down to 0 just outside it
drugs = DRUGS()
drugs.set_K_dose_theta(0.5)
drugs.set_K_dose_shape([
    {'depth': (injection_depth-(spread*1.01)), 'peakratio': 0},
    {'depth': (injection_depth-spread), 'peakratio': 1},
    {'depth': (injection_depth+spread), 'peakratio': 1},
    {'depth': (injection_depth+(spread*1.01)), 'peakratio': 0}],
    'ceil')
model = drugs.inject(sober_model) #patch the model so Dgenerate applies the dose

initial_input = 'Hello' #str(input("\bAsk Something:"))
tokenized_start = tokenizer.apply_chat_template([
    {'role': 'system',
     'content': 'Respond as Alan Watts would.'},
    {'role': 'user',
     'content': initial_input}
], return_tensors='pt')

with torch.no_grad():
    while True:
        generated_tokens = model.Dgenerate(
            input_ids=tokenized_start,
            min_new_tokens=5,
            streamer=streamer)
        print("\n\nAsk Something:", end="")
        model.cold_shower(True) #sober up between prompts
        await_input = str(input(" "))
        #note: follow-up turns template only the newest message, so the
        #system prompt and earlier exchanges are not carried forward
        tokenized_start = tokenizer.apply_chat_template([{
            'role': 'user',
            'content': await_input}], return_tensors="pt")
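To get a feel for what the dose actually does, the chat loop can be swapped for a sweep over noise strengths. This is a minimal sketch, assuming (as the shared DRUGS instance suggests) that set_K_dose_theta can be retuned between Dgenerate calls; the theta values are illustrative, and only the calls already used above appear here:

#sketch: replaces the while-loop above; regenerates the same opening
#prompt at several dose strengths for side-by-side comparison
with torch.no_grad():
    for theta in (0.0, 0.25, 0.5, 1.0):
        drugs.set_K_dose_theta(theta)
        print(f"\n\n--- K dose theta = {theta} ---")
        model.Dgenerate(
            input_ids=tokenized_start,
            min_new_tokens=5,
            streamer=streamer)
        model.cold_shower(True) #reset before the next strength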