-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference.py
113 lines (94 loc) · 4.19 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import os
from path_config import SAVEPATH, CACHEDIR, model_path, map_config
from run import parse_path
def run(args):
    """Compose the evaluation command and replace this process with it.

    Selects a ``python -B -m src.test_*`` entry module (interactive,
    streaming, or sanity-check mode), appends ``key=value`` config
    overrides, and hands control over via ``os.execvp`` — on success this
    function never returns.

    Args:
        args: Parsed CLI namespace; reads ``model``, ``dataset``,
            ``eval_path``, ``load_path``, ``interactive`` and ``stream``.
    """
    # Pick the entry module from the requested mode.
    if args.interactive:
        module = "src.test_interact"
    elif args.stream:
        module = "src.test_stream"
    else:
        module = "src.test_case"
    # Build argv as a list instead of one big string + .split():
    # .split() would silently mangle any override value containing
    # whitespace (e.g. a SAVEPATH with spaces in it).
    cmd = ["python", "-B", "-m", module]

    # Config file: dataset config group + model checkpoint.
    # NOTE(review): the "+group=config" syntax looks like Hydra override
    # grammar — confirm against src.test_* before changing it.
    model = args.model
    cmd.append(f"+{args.dataset}={map_config(model)}")
    cmd.append(f"model.model_name_or_path={model_path(model)}")
    cmd.extend(["training.do_train=false", "wandb.log=false"])

    # Compression type: attention type and token count are encoded in the
    # adapter path name; parse_path recovers them.
    attn_type, n_tok = parse_path(args.eval_path)
    cmd.append("training.comp.comp_type=online")
    cmd.append(f"training.comp.attn_type={attn_type}")
    cmd.append(f"training.comp.num_comp_tokens={n_tok}")
    if "-sink" in args.eval_path:
        cmd.append("training.comp.sink=true")

    # Evaluation path: adapters live under SAVEPATH/<dataset>/.
    wandb_group = args.dataset
    eval_path = f"{SAVEPATH}/{wandb_group}/{args.eval_path}"
    cmd.append(f"training.eval_path={eval_path}")

    load_path = None
    if args.load_path != '':
        load_path = f"{SAVEPATH}/{wandb_group}/{args.load_path}"
        cmd.append(f"training.load_path={load_path}")

    if args.interactive:
        cmd.append("training.interactive=true")
    if args.stream:
        cmd.append("model.stream=true")
    cmd.append(f"model.cache_dir={CACHEDIR}")

    print(cmd)
    if load_path is not None:
        print(f"\nFinetuned base model adapter path : {load_path}")
        print(f"Compression adapter path : {eval_path}\n")
    # Replace the current process image; nothing below this line ever runs.
    os.execvp(cmd[0], cmd)
if __name__ == "__main__":
    # CLI entry point: choose a model/dataset and a compression adapter,
    # then hand off to run() (which exec's the real evaluation module).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", "-m",
        type=str,
        default="llama-7b",
        choices=["llama-7b", "llama-2-7b-chat", "mistral-7b-inst"])
    parser.add_argument(
        "--dataset", "-d",
        type=str,
        default="unified",
        choices=["unified", "pretrain", "metaicl", "dialog"])
    parser.add_argument("--eval_path", "-e", type=str, help="Compression adapter path")
    parser.add_argument(
        "--eval_name",
        type=str,
        default='',
        help="Shortcut for eval_path. Set --eval_path and --dataset to test other models.")
    parser.add_argument(
        "--interactive", "-i",
        action="store_true",
        help="Interactive mode. Turning off will run some sanity check test codes.")
    parser.add_argument("--stream", "-s", action="store_true", help="Evaluate streaming setting.")
    args = parser.parse_args()

    # llama-7b runs load a finetuned base-model adapter; other models don't.
    args.load_path = "llama-7b-no" if args.model == "llama-7b" else ''

    # Shortcut commands: expand --eval_name into a concrete adapter path.
    if args.eval_name != '':
        if args.model == "llama-7b":
            known = {
                "merge_recur": "finetune/llama-7b-no-online-merge_recur-ntok8",
                "concat_recur": "finetune/llama-7b-no-online-concat_recur-ntok2",
            }
            args.eval_path = known.get(args.eval_name, args.eval_path)
        elif args.model == "llama-2-7b-chat":
            args.dataset = "pretrain"
            print("Dataset is set to 'pretrain' for llama-2-7b-chat (check inference.py L93)")
            known = {
                "merge_recur": "llama-2-7b-chat-online-merge_recur-ntok8",
                "concat_recur": "llama-2-7b-chat-online-concat_recur-ntok2",
            }
            args.eval_path = known.get(args.eval_name, args.eval_path)
            # llama-2-7b-chat adapters are trained with an attention sink;
            # concat_recur additionally uses the RedPajama-trained variant.
            if args.dataset == "pretrain":
                args.eval_path += "-sink"
                if "concat_recur" in args.eval_path:
                    args.eval_path += "_pajama"
            # This version does not use Lmsys-Chat for compression training,
            # which makes models refuse to respond with personal information
            # or memories.

    # Streaming model: only one fixed configuration supports streaming,
    # so the user's model/dataset/eval_path choices are overridden here.
    if args.stream:
        args.model = "llama-7b"
        args.dataset = "pretrain"
        args.eval_path = "finetune/llama-7b-no-online-concat_recur-ntok2-sink"

    run(args)