extract_BERT_embedding_arrhythmia.py
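# Extract BERT [CLS] embeddings for the Arrhythmia dataset: each numeric
# feature vector is serialized to a whitespace-joined string, encoded with
# bert-base-uncased, and the resulting embeddings are saved to a .mat file.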
import os
import numpy as np
import scipy.io
import torch
from sklearn.utils import shuffle
from transformers import BertTokenizer, BertModel
# Load the Arrhythmia dataset
data = scipy.io.loadmat("./data/arrhythmia.mat")
x_data = data['X'] # 518 samples
y_data = data['y'].astype(np.int32).reshape(-1)
# Initialize the BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
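model.eval()  # make sure dropout is disabled so the embeddings are deterministic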
# Convert numerical vectors to text
def vector_to_text(vector):
    return ' '.join(map(str, vector))
# Convert the numerical vectors to text
x_data_text = [vector_to_text(sample) for sample in x_data]
# Shuffle the data
x_data_text, y_data = shuffle(x_data_text, y_data, random_state=42)
# Function to get BERT embeddings for a document
def get_bert_embedding(text):
    # Serialized vectors longer than BERT's 512-token limit are truncated
    inputs = tokenizer(text, return_tensors='pt', max_length=512, truncation=True, padding='max_length')
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the embedding of the [CLS] token as the document-level representation
    cls_embedding = outputs.last_hidden_state[:, 0, :].numpy()
    return cls_embedding.flatten()
# List to store embeddings
embeddings = []
# Iterate over the data and get embeddings
for sample_text in x_data_text:
    embedding = get_bert_embedding(sample_text)
    embeddings.append(embedding)
# Convert list of embeddings to numpy array
embeddings_array = np.vstack(embeddings)
# Save the dataset with attributes "X" (BERT embeddings) and "y" (labels) in .mat format
scipy.io.savemat('./arrhythmia_bert.mat', {
    'X': embeddings_array,
    'y': y_data
})
print("Dataset saved successfully.")
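# --- Optional sanity check (not part of the original script; a minimal
# sketch assuming ./arrhythmia_bert.mat was just written by the code above) ---
saved = scipy.io.loadmat('./arrhythmia_bert.mat')
print(saved['X'].shape)  # (n_samples, 768): one 768-dim [CLS] embedding per row
print(saved['y'].shape)  # (1, n_samples): savemat stores 1-D arrays as row vectors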