# Kellogg_BERTopic.py
# Source listing: 129 lines (111 loc), 4.39 KB.
import pandas as pd
from bertopic import BERTopic
from datetime import datetime
from sentence_transformers import SentenceTransformer
import os
import plotly.io as pio
from umap import UMAP
# Record the wall-clock start time so the total run time can be reported at
# the end. NOTE: datetime.now() is naive *local* time, not UTC.
time_start = datetime.now()

# Root folder holding the input matrix; all outputs go to its 'results' subfolder.
folder_path = r"/Users/hjr7324/Desktop/Kellogg_Dissertations"
# exist_ok=True avoids the check-then-create race of exists()+mkdir() and also
# creates folder_path itself if it does not exist yet.
os.makedirs(os.path.join(folder_path, 'results'), exist_ok=True)

# Load the document-term matrix. From the code below, rows are keyed by 'GOID'
# and carry 'Year' and 'Department' columns; the remaining columns are treated
# as per-word counts (presumably one row per dissertation -- TODO confirm).
df = pd.read_csv(os.path.join(folder_path, 'matrix_full.csv'))
df['Department'] = df['Department'].str.strip()  # normalize stray whitespace in labels
df.set_index('GOID', inplace=True)
# Keep the year aside (currently unused), then drop it so it is not mistaken
# for a word column when documents are rebuilt.
year = df['Year']
df.drop(['Year'], axis=1, inplace=True)
department = df['Department']
unique_classes = department.unique()  # one topic model is fit per department label
bertopic_models = {}  # currently unused
from plotly.subplots import make_subplots  # currently unused


def _build_documents(word_counts):
    """Rebuild pseudo-text documents from a word-count matrix.

    Each row of *word_counts* becomes one string in which every column name
    (a word) is repeated as many times as its count in that row. Words whose
    count is not positive (including NaN, since ``NaN > 0`` is False) are
    omitted.

    Args:
        word_counts: DataFrame whose columns are words and whose cells are
            counts (assumed integer-valued; float cells are truncated by int()).

    Returns:
        list[str]: one document per row, in row order.
    """
    documents = []
    for _, row in word_counts.iterrows():
        # int(): pandas may load count columns as float (e.g. when the CSV has
        # blanks), and `str * float` raises TypeError.
        doc = ' '.join(f"{word} " * int(freq) for word, freq in row.items() if freq > 0)
        documents.append(doc.strip())
    return documents


# The embedding model is the same for every department, so load it once
# instead of once per loop iteration.
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# (department label, HTML fragment) pairs consumed by the HTML report below.
visualizations = []
# Save the topics of every department to a single text file.
with open(os.path.join(folder_path, 'results/bertopic_topics.txt'), 'w', encoding='utf-8') as topics_file:
    for class_label in unique_classes:
        print(f'Processing {class_label}')
        dep_df = df[df['Department'] == class_label]
        if dep_df.empty:
            # e.g. a NaN label: NaN never compares equal, so no rows match and
            # BERTopic cannot be fit on zero documents.
            print(f'Skipping {class_label}: no documents')
            continue

        # Build documents from THIS department's rows only. Drop the label
        # column by name rather than assuming it is the first column.
        documents = _build_documents(dep_df.drop(columns=['Department']))
        embeddings = embedding_model.encode(documents, show_progress_bar=True)

        # NOTE(review): BERTopic's default UMAP/HDBSCAN settings need a
        # reasonable number of documents; very small departments may fail or
        # produce only outliers -- consider a minimum-size threshold.
        topic_model = BERTopic(language='english', calculate_probabilities=True, verbose=True)
        topic_model.fit_transform(documents, embeddings)

        topic_info = topic_model.get_topic_info()
        topics_file.write(f'Department: {class_label}:\n')
        for _, topic_row in topic_info.iterrows():
            topics_file.write(f"Topic {topic_row['Topic']} - Count {topic_row['Count']} : {topic_row['Name']}\n")
            for word, score in topic_model.get_topic(topic_row['Topic']):
                topics_file.write(f" {word}: {score}\n")
            topics_file.write("\n---------------\n")
        topics_file.write("\n")

        plotly_fig = topic_model.visualize_barchart()
        # All fragments end up in one HTML page, so embed the (multi-MB)
        # plotly.js bundle only in the first fragment; later charts reuse it.
        html_str = pio.to_html(plotly_fig, full_html=False, include_plotlyjs=not visualizations)
        visualizations.append([class_label, html_str])
import html

# Assemble a single HTML report with one bar chart per department.
# The page template is a plain (non-f) string, so CSS braces must be single:
# doubled braces would be emitted literally and invalidate the style rules.
report_parts = [
    """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
.container {
display: flex;
flex-wrap: wrap;
gap: 20px;
justify-content: center;
}
.item {
flex: 1;
min-width: 300px;
max-width: 500px;
overflow-x: auto;
}
</style>
</head>
<body>
<div class="container">
"""
]
# Put each visualization in its own .item box so the flex rules apply per
# chart. Department names come from the CSV, so escape them for HTML.
# NOTE(review): BERTopic bar charts may be wider than max-width; overflow-x
# lets such a chart scroll inside its box instead of overlapping neighbours.
for cls, html_string in visualizations:
    report_parts.append(f"""
<div class="item">
<h2>BERTopic Visualization for {html.escape(str(cls))}</h2>
{html_string}
</div>
""")
report_parts.append("""
</div>
</body>
</html>
""")
html_content = ''.join(report_parts)

# Save the combined HTML file (UTF-8, matching the meta charset above).
with open(os.path.join(folder_path, 'results', 'BERTopic_vis_bar.html'), 'w', encoding='utf-8') as f:
    f.write(html_content)
print("BERTopic topics saved to bertopic_topics.txt")
# Record the end time and report the elapsed wall-clock duration of the run.
# NOTE: datetime.now() is naive local time (not UTC); it must stay naive so
# it can be subtracted from the naive time_start recorded at startup.
time_now = datetime.now()
print(f"Execution end: {time_now}")
print(f"Time taken: {time_now - time_start}")