-
Notifications
You must be signed in to change notification settings - Fork 0
/
llama3_base.py
58 lines (47 loc) · 2.03 KB
/
llama3_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk
from transformers.generation.utils import GenerationConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pandas as pd
import os
import argparse
def gpu_list(string):
    """Parse a comma-separated string of GPU ids into a list of ints.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        string: Text such as ``"0,1,3"``.

    Returns:
        The ids as a list of ``int``, in the order they were given.

    Raises:
        argparse.ArgumentTypeError: If any comma-separated piece is not a
            valid integer literal.
    """
    def _to_id(token):
        # Convert one piece; surface bad input as an argparse-level error so
        # the parser can report it as a usage problem.
        try:
            return int(token)
        except ValueError:
            raise argparse.ArgumentTypeError(f"Invalid GPU id: {token}")

    return [_to_id(token) for token in string.split(',')]
# Command-line interface: `--gpu` takes a comma-separated list of GPU ids
# (parsed by `gpu_list`) that this process is allowed to use.
parser = argparse.ArgumentParser(description='控制使用哪些GPU进行计算')
parser.add_argument('--gpu', type=gpu_list, default=[],
                    help='要使用的GPU id列表,用逗号分隔')
args = parser.parse_args()
visible_gpus = ','.join(map(str, args.gpu))
print(visible_gpus)
# Restrict device visibility only when ids were actually supplied. Exporting
# an empty CUDA_VISIBLE_DEVICES would hide every GPU from the process, so with
# no `--gpu` argument the environment is left as the caller configured it.
if args.gpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_gpus
# Local checkpoint directory for the base model and its tokenizer.
PATH = '/model/Meta-Llama-3-8B'
tokenizer = AutoTokenizer.from_pretrained(PATH)
# device_map="auto" lets the library decide weight placement; .eval() switches
# the model to inference mode.
model = AutoModelForCausalLM.from_pretrained(PATH, device_map="auto").eval()
def inference(sample):
    """Generate a model prediction for one dataset row.

    The prompt is the concatenation of the module-level ``instr`` and
    ``examples`` strings (assigned by the ``__main__`` driver) followed by
    ``sample['inputs']``.

    Args:
        sample: A mapping with an ``'inputs'`` string entry.

    Returns:
        ``{"pred": text}`` where ``text`` is the decoded first sequence
        returned by ``model.generate`` (at most 512 newly generated tokens).
    """
    prompt = instr + examples + sample['inputs']
    # Send the inputs to whatever device the model reports instead of a
    # hard-coded 'cuda', so the call also works when the model was placed on
    # CPU or on a specific GPU.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512)
    return {"pred": tokenizer.decode(outputs[0].tolist())}
if __name__ == "__main__":
    # Instruction text shared by every prompt. `instr` and `examples` are
    # module-level names read by `inference`, so they must stay global.
    with open('../common_instr.txt', 'r') as file:
        instr = file.read()
    examples = ""
    data_path = '/logiNumBench/datas/'
    data_names = ["D1", "D2", "D3", "D4", "D5", "D6",
                  "LD1", "LD2", "LD3", "LD4", "LD5", "LD6"]
    # Create the output directory before any inference runs, so a missing
    # folder cannot make the spreadsheet write fail after the expensive work.
    os.makedirs('./few', exist_ok=True)
    for datan in data_names:
        # Few-shot demonstrations specific to this dataset.
        with open("../shot-2/" + datan + ".txt", 'r') as file:
            examples = file.read()
        datap = data_path+datan+'/disk'
        print('---------------------'+datan+'---------------------')
        test_split = load_from_disk(datap)['test']
        # Evaluate the first 200 test rows; clamp to the split size so a
        # smaller split does not raise an out-of-range index error.
        sample_count = min(200, len(test_split))
        test_datasets = test_split.select(range(sample_count))
        # Gradients are not needed for generation.
        with torch.no_grad():
            test_datasets = test_datasets.map(inference)
        # One spreadsheet per dataset, holding the original columns plus the
        # "pred" column added by `inference`.
        df = pd.DataFrame(test_datasets)
        df.to_excel('./few/llama3-'+datan+'.xlsx',
                    index=False, engine='xlsxwriter')