forked from besarthoxhaj/lora
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_onion.py
40 lines (30 loc) · 1.21 KB
/
data_onion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from torch.utils.data import Dataset
import transformers as t
import datasets as d
import pandas as pd
# Instruction template filled per CSV row (see TrainDataset.prompt); the article
# body is concatenated after the "### Article:" header as the completion target.
TEMPLATE = "Below is a title for an article. Write an article that appropriately suits the title: \n\n### Title:\n{title}\n\n### Article:\n"
class TrainDataset(Dataset):
    """Title -> article training set built from a CSV of Onion headlines.

    Each example is the instruction TEMPLATE filled with the row's "Title",
    followed by the row's "Content", tokenized with the Llama-2 tokenizer.
    Items are dicts with ``input_ids``, ``attention_mask`` and ``labels``
    (labels are a copy of input_ids, i.e. plain causal-LM training).
    """

    def __init__(self, csv_path="onion.csv", model_name="NousResearch/Llama-2-7b-hf"):
        """Load the tokenizer and CSV, then pre-build all tokenized examples.

        Parameters (defaults preserve the original hard-coded behavior):
            csv_path: path to a CSV with at least "Title" and "Content" columns.
            model_name: Hugging Face model id for the tokenizer.
        """
        self.tokenizer = t.AutoTokenizer.from_pretrained(model_name)
        # Llama-2 ships without a pad token; reuse id 0 so batching works.
        self.tokenizer.pad_token_id = 0
        self.tokenizer.padding_side = "left"
        self.data_df = pd.read_csv(csv_path)
        # Publication time is irrelevant to title->article generation;
        # errors="ignore" keeps this working for CSVs without that column.
        self.data_df = self.data_df.drop(columns=["Published Time"], errors="ignore")
        self.ds_prompts = self.data_df.apply(self.prompt, axis=1)
        self.ds = self.ds_prompts.apply(self.tokenize)

    def __len__(self):
        """Number of examples in the dataset."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the pre-tokenized example at positional index ``idx``."""
        return self.ds.iloc[idx]

    def prompt(self, elm):
        """Build the full training text for one row: template + article body."""
        prompt = TEMPLATE.format(title=elm["Title"])
        # str() guards against non-string Content cells (e.g. NaN).
        prompt = prompt + str(elm["Content"])
        return {"prompt": prompt}

    def tokenize(self, elm):
        """Tokenize one prompt and append EOS so the model learns to stop."""
        res = self.tokenizer(elm["prompt"])
        res["input_ids"].append(self.tokenizer.eos_token_id)
        res["attention_mask"].append(1)
        # Standard causal-LM setup: labels mirror the inputs.
        res["labels"] = res["input_ids"].copy()
        return res

    def max_seq_len(self):
        """Longest tokenized sequence in the dataset (0 if the dataset is empty)."""
        # Generator form avoids materializing an intermediate list.
        return max((len(elm["input_ids"]) for elm in self.ds), default=0)