dataset.py
113 lines (97 loc) · 4.54 KB
"""
This file contains the dataset class for the training process.
"""
import torch
from torch.utils.data import Dataset


class BilingualDataset(Dataset):
    """
    Dataset class for the training process.
    """
    def __init__(self, dataset, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.dataset = dataset
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        # The special tokens to be used in the dataset
        # SOS: Start of sentence
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        # EOS: End of sentence
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        # PAD: Padding, used to make all the sentences the same length
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    # The length of the dataset is the number of sentence pairs it contains
    def __len__(self):
        return len(self.dataset)

    # This method is called when the dataset is indexed with dataset[idx],
    # where idx is an integer. It returns a dictionary containing the
    # encoder input, decoder input, encoder mask, decoder mask, and label.
    def __getitem__(self, idx):
        src_target_pair = self.dataset[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into token ids
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Work out how much padding is needed once the special tokens are added
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # We will add <s> and </s>
        # The decoder input gets only <s>; </s> is added only on the label
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Add <s> and </s> token
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len) & (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }


def causal_mask(size):
    """
    The causal mask is used to prevent the decoder from attending to future positions.
    This function returns a boolean tensor of shape (1, size, size) that is True on and
    below the diagonal and False in the strict upper triangle. In the attention layer,
    positions where the mask is False are filled with -inf before the softmax, so they
    receive zero attention weight.
    :param size: the size of the mask
    :return: the causal mask
    """
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
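

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). It shows
# how BilingualDataset and causal_mask can be exercised end to end, assuming
# the tokenizers are HuggingFace `tokenizers` WordLevel models that define the
# [UNK], [PAD], [SOS] and [EOS] special tokens. The toy vocabulary, the single
# sentence pair and seq_len=8 below are made-up values for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace

    def build_toy_tokenizer(words):
        # Stand-in for a trained tokenizer: a fixed WordLevel vocabulary with
        # the special tokens the dataset class expects.
        vocab = {"[UNK]": 0, "[PAD]": 1, "[SOS]": 2, "[EOS]": 3}
        vocab.update({w: i + 4 for i, w in enumerate(words)})
        tok = Tokenizer(WordLevel(vocab, unk_token="[UNK]"))
        tok.pre_tokenizer = Whitespace()
        return tok

    tokenizer_src = build_toy_tokenizer(["hello", "world"])
    tokenizer_tgt = build_toy_tokenizer(["hallo", "welt"])

    # A tiny in-memory dataset with the {'translation': {lang: text}} layout
    # that __getitem__ expects.
    raw_dataset = [{"translation": {"en": "hello world", "de": "hallo welt"}}]

    ds = BilingualDataset(raw_dataset, tokenizer_src, tokenizer_tgt,
                          src_lang="en", tgt_lang="de", seq_len=8)
    sample = ds[0]

    print(sample["encoder_input"])       # [SOS] hello world [EOS] followed by [PAD]s
    print(sample["encoder_mask"].shape)  # torch.Size([1, 1, 8])
    print(sample["decoder_mask"].shape)  # torch.Size([1, 8, 8])
    print(causal_mask(4))                # True on/below the diagonal, False above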