-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
118 lines (92 loc) · 2.94 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import copy
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LogNorm, Normalize
def read_text(path):
    """Read an article from a text file.

    Parameters
    ----------
    path : str
        Path to file.

    Returns
    -------
    text : str or float
        Non-empty lines of the file, stripped and joined with single
        spaces; ``np.nan`` if the file could not be opened or decoded
        (the failure is reported to stdout and the file is skipped).
    """
    try:
        with open(path, 'r') as f:
            lines = [line.strip() for line in f]
    except (OSError, UnicodeDecodeError) as e:
        # Best-effort reader: report and skip unreadable files instead of
        # aborting the whole ingestion run. NaN lets pandas dropna() filter
        # these rows out later.
        print(f"skipping {path} due to: {e}")
        return np.nan
    # Drop blank lines, then collapse the article into a single string.
    return " ".join(line for line in lines if line)
class DataReader():
    """Collects article text files from topic folders into a Spark DataFrame."""

    def __init__(self, data_path, spark):
        """Init.

        Parameters
        ----------
        data_path : str
            Path to the data folder (one sub-folder per topic).
        spark : pyspark.sql.session.SparkSession
            Spark session used to build the resulting DataFrame.
        """
        self.data_path = data_path
        self.spark = spark

    def __call__(self, topic_names):
        """Read data.

        Parameters
        ----------
        topic_names : list
            List of topic folder names (sub-folders of ``data_path``).

        Returns
        -------
        df_data_all : spark.DataFrame
            DataFrame with columns: path (path to file), topic (topic of
            article, e.g. business), text (the text of the article),
            id (unique id of the article).
        """
        paths, topics, texts = [], [], []
        for topic_name in topic_names:
            topic_dir = os.path.join(self.data_path, topic_name)
            # Walk recursively so articles nested in sub-folders are found.
            for dirpath, _dirnames, filenames in os.walk(topic_dir):
                file_paths = [os.path.join(dirpath, fn) for fn in filenames]
                paths += file_paths
                topics += [topic_name] * len(file_paths)
                texts += [read_text(p) for p in file_paths]
        data = {
            "path": paths,
            "topic": topics,
            "text": texts,
            # Sequential surrogate ids assigned before dropna, so ids may
            # have gaps if unreadable files are dropped below.
            "id": list(range(len(texts))),
        }
        pd_df = pd.DataFrame(data=data)
        # read_text returns NaN for unreadable files; exclude those rows.
        pd_df.dropna(inplace=True)
        return self.spark.createDataFrame(pd_df, list(pd_df.columns.values))
def make_heatmap(rdd_collected, title, epoch_time):
    """Render a log-scaled n x n heatmap and save it under ``assets/``.

    Parameters
    ----------
    rdd_collected : list
        Collected RDD rows: each element is an iterable of
        ``(column_index, value)`` tuples forming one full matrix row.
        Each tuple's stored index must match its position in the row.
    title : str
        Figure title; also used as the output-filename prefix.
    epoch_time : int or float
        Timestamp appended to the filename so repeated runs do not
        overwrite each other.
    """
    # Copy before mutating: registry colormaps are shared objects.
    # (matplotlib.colormaps replaces plt.cm.get_cmap, removed in mpl 3.9.)
    cmap = matplotlib.colormaps["plasma"].copy()
    # Cells invalid under LogNorm (values <= 0) are painted black.
    cmap.set_bad((0, 0, 0))

    n = len(rdd_collected)
    heatmap = np.zeros((n, n))
    for row_idx, row in enumerate(rdd_collected):
        for col_idx, (idx, value) in enumerate(row):
            # Sanity check: stored index must agree with tuple position.
            assert col_idx == idx
            heatmap[row_idx, col_idx] = value

    plt.figure(figsize=(10, 10))
    sns.heatmap(heatmap, square=True, norm=LogNorm(), cmap=cmap)
    plt.title(title)
    fname = f"{title}_{epoch_time}.png"
    plt.savefig(f"assets/{fname}")
    # Close the figure to release memory; the original trailing plt.plot()
    # drew nothing and leaked one open figure per call.
    plt.close()