# -*- coding: utf-8 -*-
"""
This module joins the whole pipeline into separate tasks that can be executed
individually or chained together through command-line options.
Example:
    Dataset embedding and processing (with early stopping):
        $ python main.py -e -pS
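    Full pipeline (create CPGs, embed them, then train and test the model):
        $ python main.py -c -e -p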
"""
from gensim.models.word2vec import Word2Vec
from argparse import ArgumentParser
import argparse
import shutil
import gc
import src.utils.functions.cpg as cpg
import src.utils.log as logger
import src.prepare as prepare
import src.process as process
import src.data as data
# Load config
import configs
PATHS = configs.Paths()
FILES = configs.Files()
DEVICE = FILES.get_device()
def select(dataset, project_name=None, length=1200):
"""
    Filter `dataset` by `project_name` and by the `length` of its source code.
    :param dataset: raw DataFrame (from `data/raw/`)
    :param project_name: "FFmpeg" or "qemu"; if None (default), keep both projects
    :param length: keep only rows whose source code (`func`) is shorter than `length`
        characters, to speed up the training process (1200 by default)
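    Example (illustrative; `raw_df` stands for a DataFrame loaded from `data/raw/`):
        >>> small_qemu = select(raw_df, project_name="qemu", length=800)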
"""
# filter `project_name`
if project_name is None:
result = dataset
else:
result = dataset.loc[dataset["project"] == project_name]
# filter `length`
len_filter = result.func.str.len() < length
result = result.loc[len_filter]
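    # NOTE: only the first 200 rows of the filtered frame are returned,
    # presumably to keep experiments fast; raise this limit for full runs.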
return result.head(200)
def create_task():
"""
    Convert the raw source-code dataset into CPGs and save them as pickled dataset chunks.
"""
context = configs.Create()
##########################################
# Step 1: read + filter raw datasets #
##########################################
logger.log_info(logger_name="create_task", msg="Step 1: read + filter raw datasets.")
raw = data.read(PATHS.raw, FILES.raw)
filtered = data.apply_filter(raw, select)
filtered = data.clean(filtered)
data.drop(filtered, ["commit_id", "project"])
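    # split the filtered frame into fixed-size chunks so each one can be parsed by Joern separately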
slices = data.slice_frame(filtered, context.slice_size)
    slices = [(s, chunk.apply(lambda x: x)) for s, chunk in slices]  # identity apply materializes each chunk as its own DataFrame
    ########################################
    #    Step 2: create CPG binary files   #
    ########################################
    logger.log_info(logger_name="create_task", msg="Step 2: create CPG binary files.")
cpg_files = []
    for s, chunk in slices:
        data.to_files(chunk, PATHS.joern)
cpg_file = prepare.joern_parse(context.joern_cli_dir, PATHS.joern, PATHS.cpg, f"{s}_{FILES.cpg}")
cpg_files.append(cpg_file)
logger.log_info(logger_name="create_task",msg=f"Dataset {s} to cpg.")
shutil.rmtree(PATHS.joern)
    ####################################################
    #   Step 3: export CPG graphs to JSON files and    #
    #           build the CPG dataset chunks           #
    ####################################################
    logger.log_info(logger_name="create_task", msg="Step 3: export CPG graphs to JSON files and build the CPG dataset chunks.")
json_files = prepare.joern_create(context.joern_cli_dir, PATHS.cpg, PATHS.cpg, cpg_files)
    for (s, chunk), json_file in zip(slices, json_files):
graphs = prepare.json_process(PATHS.cpg, json_file)
if graphs is None:
logger.log_info(logger_name="create_task",msg=f"Dataset chunk {s} not processed.")
continue
dataset = data.create_with_index(graphs, ["Index", "cpg"])
        dataset = data.inner_join_by_index(chunk, dataset)
file_chunk = f"{s}_{FILES.cpg}.pkl"
logger.log_warning(logger_name="create_task",msg=f"Saving CPG dataset chunk {s} -> {file_chunk}.")
data.write(dataset, PATHS.cpg, file_chunk)
del dataset
gc.collect()
def embed_task():
"""
This task...
"""
context = configs.Embed()
# Tokenize source code into tokens
dataset_files = data.get_directory_files(PATHS.cpg)
w2vmodel = Word2Vec(**context.w2v_args)
w2v_init = True
for pkl_file in dataset_files:
file_name = pkl_file.split(".")[0]
cpg_dataset = data.load(PATHS.cpg, pkl_file)
tokens_dataset = data.tokenize(cpg_dataset)
data.write(tokens_dataset, PATHS.tokens, f"{file_name}_{FILES.tokens}")
# word2vec used to learn the initial embedding of each token
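        # the vocabulary is built from the first chunk and incrementally updated with every later one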
w2vmodel.build_vocab(tokens_dataset.tokens, update=not w2v_init)
w2vmodel.train(tokens_dataset.tokens, total_examples=w2vmodel.corpus_count, epochs=1)
if w2v_init:
w2v_init = False
# Embed cpg to node representation and pass to graph data structure
cpg_dataset["nodes"] = cpg_dataset.apply(lambda row: cpg.parse_to_nodes(row.cpg, context.nodes_dim), axis=1)
# remove rows with no nodes
cpg_dataset = cpg_dataset.loc[cpg_dataset.nodes.map(len) > 0]
cpg_dataset["input"] = cpg_dataset.apply(lambda row: prepare.nodes_to_input(row.nodes, row.target, context.nodes_dim,
w2vmodel.wv, context.edge_type), axis=1)
data.drop(cpg_dataset, ["nodes"])
logger.log_warning(logger_name="embed_task", msg=f"Saving input dataset {file_name} with size {len(cpg_dataset)}.")
data.write(cpg_dataset[["input", "target"]], PATHS.input, f"{file_name}_{FILES.input}")
del cpg_dataset
gc.collect()
logger.log_warning(logger_name="embed_task", msg="Saving w2vmodel.")
w2vmodel.save(f"{PATHS.w2v}/{FILES.w2v}")
def process_task(stopping):
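    """
    Train the Devign model on the embedded input dataset and evaluate it on the test split.
    :param stopping: if True, train with early stopping and reload the saved model before testing;
        otherwise train for the configured number of epochs and save the final weights.
    """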
context = configs.Process()
devign = configs.Devign()
model_path = PATHS.model + FILES.model
model = process.Devign(path=model_path, device=DEVICE, model=devign.model, learning_rate=devign.learning_rate,
weight_decay=devign.weight_decay,
loss_lambda=devign.loss_lambda)
train = process.Train(model, context.epochs)
input_dataset = data.loads(PATHS.input)
# split the dataset and pass to DataLoader with batch size
train_loader, val_loader, test_loader = list(
map(lambda x: x.get_loader(context.batch_size, shuffle=context.shuffle),
data.train_val_test_split(input_dataset, shuffle=context.shuffle)))
train_loader_step = process.LoaderStep("Train", train_loader, DEVICE)
val_loader_step = process.LoaderStep("Validation", val_loader, DEVICE)
test_loader_step = process.LoaderStep("Test", test_loader, DEVICE)
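    # train either with early stopping (patience taken from the config) or for the full number of epochs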
if stopping:
early_stopping = process.EarlyStopping(model, patience=context.patience)
train(train_loader_step, val_loader_step, early_stopping)
model.load()
else:
train(train_loader_step, val_loader_step)
model.save()
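    # evaluate the trained model on the held-out test split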
process.predict(model, test_loader_step)
if __name__ == "__main__":
"""
main function that executes tasks based on command-line options
"""
parser: ArgumentParser = argparse.ArgumentParser()
# parser.add_argument('-p', '--prepare', help='Prepare task', required=False)
    parser.add_argument('-c', '--create', action='store_true', help='create CPG dataset chunks from the raw dataset')
    parser.add_argument('-e', '--embed', action='store_true', help='tokenize, train word2vec and embed the CPGs into model inputs')
    parser.add_argument('-p', '--process', action='store_true', help='train and test the model')
    parser.add_argument('-pS', '--process_stopping', action='store_true', help='train with early stopping, then test the model')
args = parser.parse_args()
if args.create:
create_task()
if args.embed:
embed_task()
if args.process:
process_task(False)
if args.process_stopping:
process_task(True)