-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpreprocess_data.py
53 lines (43 loc) · 2.02 KB
/
preprocess_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import uuid
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import udf, lit, instr, col
### Preprocess transaction data ###
# This preprocess job gives each transaction an ID,
# and creates two files for later graph construction.
#
# - transactions.parquet
# transaction data with UUID transaction ID
# - edges.parquet
# relationship between transactions and customers
###
# Reuse the active SparkSession if there is one, otherwise create it.
spark = SparkSession.builder.getOrCreate()

# input data file
dataFile = "/data/fraud/transactions.csv"

# preprocessed data file
transactionFile = "/data/fraud/preprocessed/transactions.parquet"
edgeFile = "/data/fraud/preprocessed/edges.parquet"

### UDFs ###
# UUID generator for each transaction.
# Takes no arguments and returns a new UUID string on every call, so it is
# explicitly marked non-deterministic: Spark assumes UDFs are deterministic
# by default and may otherwise eliminate or repeat invocations while
# optimizing the plan.
gen_uuid = udf(lambda: str(uuid.uuid1()), StringType()).asNondeterministic()
# CSV schema
# Explicit schema passed to the CSV reader for the input file. Every field
# is nullable (the StructField default, same as StructType.add()).
# NOTE(review): DecimalType(10, 2) tops out at 99,999,999.99 -- confirm no
# amount or balance in the input exceeds that.
schema = StructType([
    StructField("step", IntegerType()),
    StructField("type", StringType()),
    StructField("amount", DecimalType(10, 2)),
    StructField("nameOrig", StringType()),
    StructField("oldbalanceOrg", DecimalType(10, 2)),
    StructField("newbalanceOrig", DecimalType(10, 2)),
    StructField("nameDest", StringType()),
    StructField("oldbalanceDest", DecimalType(10, 2)),
    StructField("newbalanceDest", DecimalType(10, 2)),
    StructField("isFraud", IntegerType()),
    StructField("isFlaggedFraud", IntegerType()),
])
# Read the raw CSV with the explicit schema and attach a generated
# transaction ID to every row, then persist the result.
rawTransactions = (
    spark.read.csv(dataFile, header=True, schema=schema)
    .withColumn("tranId", gen_uuid())
)
rawTransactions.write.format("parquet").save(transactionFile)

# reload saved transactions to get the same transaction id
# (re-evaluating the plan above would call the UUID generator again and
# produce IDs that differ from the ones just written)
rawTransactions = spark.read.parquet(transactionFile)

# Two edges per transaction:
#   nameOrig -> tranId   labelled "from"
#   tranId   -> nameDest labelled "to"
# Built with DataFrame expressions instead of an RDD flatMap so rows are not
# round-tripped through Python and the column types come from the saved
# transactions rather than being inferred from the data (inference raises
# when the input is empty). union() matches columns by position, so both
# selects list src, dst, ~label, type in the same order.
fromEdges = rawTransactions.select(
    col("nameOrig").alias("src"),
    col("tranId").alias("dst"),
    lit("from").alias("~label"),
    col("type"),
)
toEdges = rawTransactions.select(
    col("tranId").alias("src"),
    col("nameDest").alias("dst"),
    lit("to").alias("~label"),
    col("type"),
)
rawEdges = fromEdges.union(toEdges)
rawEdges.write.format("parquet").save(edgeFile)