#!/usr/bin/env python3
# coding: utf-8
"""
title: prep.py
date: 2019-11-23
author: jskrable
description: Preprocessing pipeline for the headline dataset: cleans
headlines, builds n-grams, and vectorizes features for Spark ML.
"""
import os
import re
import sys
import json
from pyspark.sql import Row
from pyspark.sql.types import ArrayType, StringType
from pyspark.sql.functions import udf
from pyspark.ml.feature import NGram, CountVectorizer, CountVectorizerModel, StringIndexer


def progress(count, total, suffix=''):
    """
    Progress bar for CLI output.
    """
    bar_len = 60
    filled_len = int(round(bar_len * count / float(total)))
    percents = round(100.0 * count / float(total), 1)
    bar = '#' * filled_len + '-' * (bar_len - filled_len)
    sys.stdout.write('[%s] %s%s %s\r' % (bar, percents, '%', suffix))
    sys.stdout.flush()


def parse_data(file):
    """
    Read a pseudo-JSON file (one JSON object per line) and
    return a generator object to save CPU and memory.
    Wrap the result in list() if you need to index into it.
    """
    for l in open(file, 'r'):
        yield json.loads(l)


def get_source_data(base_dir):
    """
    Read a directory full of pseudo-JSON files and return a
    list of objects. Objects are structured as follows:
    article_link: http link to the original article
    headline: string headline, special characters intact
    is_sarcastic: int, 1 for sarcasm, 0 for serious
    """
    data = []
    for d, _, f in os.walk(base_dir):
        files = [os.path.join(d, file) for file in f]
        # flatten each file's generator into one list of records
        data += [record for file in files for record in parse_data(file)]
    return data
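
# Usage sketch (the './data' path is an assumption): each record comes
# back as a plain dict, e.g.
#   data = get_source_data('./data')
#   data[0]  # {'article_link': '...', 'headline': '...', 'is_sarcastic': 0}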


def count_vectorizer(df, col, train=False):
    """
    Take in a df of headlines and transform it to a word count
    vector. Simple bag-of-words method. Requires the headline to
    be a list of words. Returns a df with an additional vector
    column.
    """
    if train:
        cv = CountVectorizer(
            inputCol=col,
            outputCol='vector',
            vocabSize=50000)
        model = cv.fit(df)
        print('Saving count vectorizer model to disk...')
        model.save('./cv_model')
    else:
        model = CountVectorizerModel.load('./cv_model')
    df = model.transform(df)
    return df
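
# Usage sketch: fit and persist the vocabulary once at training time,
# then reload it at inference so feature indices stay consistent.
#   train_df = count_vectorizer(tokenized_df, 'gramList', train=True)
#   new_df = count_vectorizer(new_tokenized_df, 'gramList', train=False)
# (tokenized_df / new_tokenized_df are hypothetical names.)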


def label_indexer(df, col):
    """
    Take in a df, index the class label column, and return the
    df with a new indexedLabel column.
    """
    labelIndexer = StringIndexer(
        inputCol=col,
        outputCol="indexedLabel").fit(df)
    df = labelIndexer.transform(df)
    return df
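
# Note: StringIndexer orders labels by descending frequency by default,
# so the majority class is indexed 0.0.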


def n_grams(df, col, n=2):
    """
    Take in a df with a list of words and convert it to a list
    of n-grams.
    """
    ngram = NGram(
        n=n,
        inputCol=col,
        outputCol="ngrams")
    df = ngram.transform(df)
    return df
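
# For example, with the default n=2, NGram maps the token list
# ['the', 'quick', 'fox'] to the bigrams ['the quick', 'quick fox'].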


def preprocessing(sql, data, train=True):
    """
    Take in a list of dicts containing string headlines and
    return a df containing indexed labels and vectorized
    features.
    """
    # convert input data to a spark dataframe
    # print('Creating dataframe...')
    df = sql.createDataFrame(Row(**entry) for entry in data)
    # print('Cleaning headlines...')
    # allow only alphabetic characters, then lowercase and tokenize
    regex = re.compile('[^a-zA-Z]')
    clean_headline = udf(lambda x: regex.sub(' ', x).lower().split(),
                         ArrayType(StringType()))
    df = df.withColumn('cleanHeadline', clean_headline(df.headline))
    df = n_grams(df, 'cleanHeadline')
    # append the bigrams to the unigram list for a combined feature set
    concat = udf(lambda x, y: x + y, ArrayType(StringType()))
    df = df.withColumn('gramList', concat(df.cleanHeadline, df.ngrams))
    # print('Vectorizing headlines...')
    # get a sparse vector of dictionary word counts
    # choose between n-grams, the word list, or both here
    # df = count_vectorizer(df, 'cleanHeadline')
    # df = count_vectorizer(df, 'ngrams')
    df = count_vectorizer(df, 'gramList', train)
    if train:
        # index label column
        print('Indexing labels...')
        df = label_indexer(df, 'is_sarcastic')
        train, test = df.randomSplit([0.7, 0.3])
        return train, test
    else:
        return df
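

# Minimal driver sketch; an assumption, since the real entry point is
# main.py per the header. The SparkSession setup and './data' directory
# are illustrative, not part of the original module.
if __name__ == '__main__':
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.appName('sarcasm-prep').getOrCreate()
    records = get_source_data('./data')  # hypothetical data directory
    train_df, test_df = preprocessing(spark, records, train=True)
    train_df.select('indexedLabel', 'vector').show(5)
    spark.stop()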