import json
import multiprocessing as mp
import queue
import re
import sys
import threading
from collections import defaultdict
from functools import partial
from typing import Dict, Iterator, List, Optional, Set, Tuple, Type, TypeVar

from datasets import Dataset
from datasketch import MinHash, MinHashLSH
from tqdm import tqdm

T = TypeVar("T")
class ThreadedIterator(Iterator[T]):
    """An iterator that computes its elements in a single background thread so they are ready to be consumed.

    The wrapped iterable must *not* yield `None`. Elements of the original iterable may be shuffled arbitrarily.
    """

    def __init__(self, original_iterator: Iterator[T], max_queue_size: int = 2, enabled: bool = True):
        self.__is_enabled = enabled
        if enabled:
            self.__queue = queue.Queue(maxsize=max_queue_size)  # type: queue.Queue[Optional[T]]
            self.__thread = threading.Thread(target=lambda: self.__worker(self.__queue, original_iterator), daemon=True)
            self.__thread.start()
        else:
            self.__original_iterator = original_iterator

    @staticmethod
    def __worker(queue: queue.Queue, original_iterator: Iterator[T]) -> None:
        try:
            for element in original_iterator:
                assert element is not None, "By convention, iterables wrapped in ThreadedIterator may not contain None."
                queue.put(element, block=True)
            queue.put(None, block=True)
        except Exception as e:
            _, __, tb = sys.exc_info()
            queue.put((e, tb), block=True)

    def __next__(self) -> T:
        next_element = self.__queue.get(block=True)
        if next_element is None:
            self.__thread.join()
            self.__queue.put(None)  # Remember that we are done in case __next__ is called once more.
            raise StopIteration
        if isinstance(next_element, tuple) and isinstance(next_element[0], Exception):
            raise next_element[0].with_traceback(next_element[1])
        return next_element

    def __iter__(self):
        if self.__is_enabled:
            return self
        else:
            return iter(self.__original_iterator)
NON_ALPHA = re.compile("[^A-Za-z_0-9]")
# parameters used in DuplicationIndex
MIN_NUM_TOKENS = 10
NUM_PERM = 256
def get_min_hash(tokens: List[str]) -> Optional[MinHash]:
    """Compute the MinHash of a code snippet."""
    if len(tokens) < MIN_NUM_TOKENS:
        return None
    min_hash = MinHash(num_perm=NUM_PERM)
    for token in set(tokens):
        min_hash.update(token.encode())
    return min_hash
def get_tokens(code: str) -> Set[str]:
    """Tokenize a code snippet."""
    return {t for t in NON_ALPHA.split(code) if len(t.strip()) > 0}
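

# A minimal illustration (not part of the original module) of how the two helpers above
# interact; the snippet and values below are hypothetical:
#
#   >>> get_tokens("assign out = a & b;") == {"assign", "out", "a", "b"}
#   True
#   >>> get_min_hash(["out", "a", "b"]) is None   # fewer than MIN_NUM_TOKENS tokens
#   True
#
# Snippets with at least MIN_NUM_TOKENS tokens get a NUM_PERM-permutation MinHash whose
# signature approximates the Jaccard similarity of their token sets.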
class DuplicationIndex:
    def __init__(
        self,
        *,
        duplication_jaccard_threshold: float = 0.85,
    ):
        self._duplication_jaccard_threshold = duplication_jaccard_threshold
        self._num_perm = NUM_PERM
        self._index = MinHashLSH(threshold=self._duplication_jaccard_threshold, num_perm=self._num_perm)
        self._duplicate_clusters = defaultdict(set)
    def add(self, code_key: Tuple, min_hash: MinHash) -> None:
        """Add a key to _index (MinHashLSH).

        The min_hash is used to query the closest matches based on the Jaccard threshold.
        The new key is either added to an existing cluster of one close match,
        or a new cluster is created. The clusters built this way depend on the order of add calls.

        Args:
            code_key (Tuple of (index, repo_name, path)):
                Theoretically any hashable key. Here we use a tuple to retrieve the information later.
            min_hash: MinHash of the code_key.
        """
        close_duplicates = self._index.query(min_hash)
        if code_key in self._index.keys:
            print(f"Duplicate key {code_key}")
            return

        self._index.insert(code_key, min_hash)
        if len(close_duplicates) > 0:
            for base_duplicate in close_duplicates:
                if base_duplicate in self._duplicate_clusters:
                    self._duplicate_clusters[base_duplicate].add(code_key)
                    break
            else:
                self._duplicate_clusters[close_duplicates[0]].add(code_key)
    def get_duplicate_clusters(self) -> List[List[Dict]]:
        """Export the duplicate clusters.

        For each cluster, the first element is the base element of the cluster.
        The base element has an estimated Jaccard similarity above the threshold with all the other elements.

        Returns:
            duplicate_clusters (List[List[Dict]]):
                List of duplicate clusters.
        """
        duplicate_clusters = []
        for base, duplicates in self._duplicate_clusters.items():
            cluster = [base] + list(duplicates)
            # reformat the cluster to be a list of dicts
            cluster = [{"base_index": el[0], "repo_name": el[1], "path": el[2]} for el in cluster]
            duplicate_clusters.append(cluster)
        return duplicate_clusters
    def save(self, filepath) -> None:
        duplicate_clusters = self.get_duplicate_clusters()
        with open(filepath, "w") as f:
            json.dump(duplicate_clusters, f)
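

# A hedged usage sketch (not part of the original module): feeding MinHashes into the
# index one by one; the keys and the `my_snippets` iterable below are hypothetical.
#
#   di = DuplicationIndex(duplication_jaccard_threshold=0.85)
#   for idx, snippet in enumerate(my_snippets):
#       min_hash = get_min_hash(list(get_tokens(snippet)))
#       if min_hash is not None:
#           di.add((idx, "norepo", "nopath"), min_hash)
#   clusters = di.get_duplicate_clusters()  # List[List[Dict]] of near-duplicate groups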
def _compute_min_hash(element):
    index, data = element
    min_hash = get_min_hash([t for t in NON_ALPHA.split(data["code"]) if len(t.strip()) > 0])
    if min_hash is not None:
        # repo_name and path can be supplied later; they are only book-kept by get_duplicate_clusters
        return (index, "norepo", "nopath"), min_hash
def minhash_iter(dataset_iterator: Type[Dataset]):
    with mp.Pool() as pool:
        for data in pool.imap_unordered(
            _compute_min_hash,
            ThreadedIterator(dataset_iterator, max_queue_size=10000),
            chunksize=100,
        ):
            if data is not None:
                yield data
def make_duplicate_clusters(dataset_iterator: Type[Dataset], jaccard_threshold: float):
    """Find duplicate clusters in the dataset in two steps:

    1. Compute the MinHash of each code snippet. MinHash is a tool for fast Jaccard similarity estimation.
       This step runs in an asynchronous multiprocessing pool, see minhash_iter.
    2. Find duplicate clusters. The computed MinHashes are added sequentially to the DuplicationIndex.
       This step cannot be parallelized, so the asynchronous threading in the previous step helps speed up the process.
    """
    di = DuplicationIndex(duplication_jaccard_threshold=jaccard_threshold)

    for filename, min_hash in tqdm(ThreadedIterator(minhash_iter(enumerate(dataset_iterator)), max_queue_size=100)):
        di.add(filename, min_hash)

    # Returns a list of clusters, where each cluster is a list of dicts with the keys
    # base_index, repo_name and path (see DuplicationIndex.get_duplicate_clusters).
    return di.get_duplicate_clusters()
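

# The returned cluster structure, sketched with hypothetical values (not part of the original module):
#
#   [
#       [  # one cluster of near-duplicate snippets
#           {"base_index": 0,  "repo_name": "norepo", "path": "nopath"},
#           {"base_index": 57, "repo_name": "norepo", "path": "nopath"},
#       ],
#       ...
#   ]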
def jaccard_similarity(code1: str, code2: str) -> float:
    """Compute the Jaccard similarity of two code snippets."""
    tokens1 = get_tokens(code1)
    tokens2 = get_tokens(code2)
    return len(tokens1 & tokens2) / len(tokens1 | tokens2)
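

# A worked example with hypothetical snippets (not part of the original module):
#
#   >>> jaccard_similarity("assign out = a & b;", "assign result = a & b;")
#   0.6
#
# The token sets are {assign, out, a, b} and {assign, result, a, b}; they share 3 tokens
# out of 5 distinct tokens, so the similarity is 3 / 5 = 0.6.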
_shared_dataset = None
def _find_cluster_extremes_shared(cluster, jaccard_threshold):
    """Find a reduced cluster such that each code in the original cluster is similar to at least one code in the reduced cluster.

    Two codes are similar if their Jaccard similarity is above the threshold.

    Args:
        cluster (List[dict]):
            cluster is a list of dicts, each containing the following keys:
                - base_index
                - repo_name
                - path
            This is a typical output of DuplicationIndex.get_duplicate_clusters()
        jaccard_threshold (float):
            threshold for Jaccard similarity.
            Two codes are similar if their Jaccard similarity is above the threshold.

    Returns:
        extremes (List[dict]):
            A reduced representation of the cluster. The field "copies" is added to each dict.
            The "copies" field indicates the number of similar codes in the cluster for an extreme.
    """
    extremes = []
    for element1 in cluster:
        code1 = _shared_dataset[element1["base_index"]]["code"]
        for element2 in extremes:
            code2 = _shared_dataset[element2["base_index"]]["code"]
            if jaccard_similarity(code1, code2) >= jaccard_threshold:
                element2["copies"] += 1
                break
        else:
            element1["copies"] = 1
            extremes.append(element1)
    return extremes
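

# A hedged illustration of the reduction (hypothetical cluster, not part of the original module):
# with a threshold of 0.85, a cluster [A, B, C] where sim(A, B) >= 0.85 but sim(A, C) and sim(B, C)
# are below 0.85 reduces to the extremes [A, C], with A carrying copies=2 (itself and B) and C copies=1.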
def find_extremes(cluster_list, dataset, jaccard_threshold):
    """Call the _find_cluster_extremes_shared function in a parallel fashion.

    Args:
        cluster_list (List[List[Dict]]):
            each cluster is a list of dicts with the key base_index,
            referring to the index of the base code in the dataset.
        dataset (Type[Dataset]):
            dataset is used to access the content of the code snippets,
            using the base_index from the cluster_list.
            dataset is shared between all the processes using a global variable (any other way to share the dataset?),
            otherwise the multiprocessing is not sped up.
        jaccard_threshold (float):
            the threshold for the Jaccard similarity (0.85 by default in deduplicate_dataset).

    Returns:
        extremes_list (List[Dict]):
            Each cluster is reduced to its extremes.
            See _find_cluster_extremes_shared for the definition of extremes.
    """
    global _shared_dataset
    _shared_dataset = dataset
    extremes_list = []
    f = partial(_find_cluster_extremes_shared, jaccard_threshold=jaccard_threshold)
    with mp.Pool() as pool:
        for extremes in tqdm(
            pool.imap_unordered(
                f,
                cluster_list,
            ),
            total=len(cluster_list),
        ):
            extremes_list.append(extremes)
    return extremes_list
def deduplicate_dataset(
    dataset: Type[Dataset], jaccard_threshold: float = 0.85
) -> Tuple[Type[Dataset], List[List[Dict]]]:
    """Deduplicate the dataset using MinHash and Jaccard similarity.

    This function first generates duplicate clusters, then each cluster
    is reduced to the extremes that are similar to the other elements in the cluster.
    Codes are considered similar if their Jaccard similarity is greater than jaccard_threshold (0.85 by default).

    Args:
        dataset (Type[Dataset]):
            The dataset to deduplicate.
        jaccard_threshold (float, default=0.85):
            Jaccard threshold to determine if two codes are similar.

    Returns:
        ds_dedup (Type[Dataset]):
            The deduplicated dataset.
        duplicate_clusters (List[List[Dict]]):
            The list of duplicate clusters.
            Each cluster is a list of dicts with the following keys:
            - base_index : int
                The index of the code in the original dataset.
            - repo_name : str
            - path : str
            - copies : int
                The number of copies of the code in the cluster (see _find_cluster_extremes_shared).
            - is_extreme : bool
                Whether the code is an extreme in the cluster.
            All the codes in the cluster are removed from the dataset except the extremes.

    Example:
        >>> from datasets import load_dataset
        >>> from minhash import deduplicate_dataset
        >>> ds = load_dataset("lvwerra/codeparrot-clean", split="train")
        >>> ds_dedup, duplicate_clusters = deduplicate_dataset(ds, jaccard_threshold=0.85)
    """
    duplicate_clusters = make_duplicate_clusters(dataset, jaccard_threshold)
    duplicate_indices = {x["base_index"] for cluster in duplicate_clusters for x in cluster}
    extreme_dict = {}
    extremes_clusters = find_extremes(duplicate_clusters, dataset, jaccard_threshold)
    for extremes in extremes_clusters:
        for element in extremes:
            extreme_dict[element["base_index"]] = element
    remove_indices = duplicate_indices - set(extreme_dict.keys())
    ds_filter = dataset.filter(lambda x, idx: idx not in remove_indices, with_indices=True)

    # update duplicate_clusters
    for cluster in duplicate_clusters:
        for element in cluster:
            element["is_extreme"] = element["base_index"] in extreme_dict
            if element["is_extreme"]:
                element["copies"] = extreme_dict[element["base_index"]]["copies"]

    print(f"Original dataset size: {len(dataset)}")
    print(f"Number of duplicate clusters: {len(duplicate_clusters)}")
    print(f"Files in duplicate cluster: {len(duplicate_indices)}")
    print(f"Unique files in duplicate cluster: {len(extreme_dict)}")
    print(f"Filtered dataset size: {len(ds_filter)}")

    return ds_filter, duplicate_clusters
if __name__ == "__main__":
    import argparse

    from datasets import load_dataset

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="GaTech-EIC/MG-Verilog")
    args = parser.parse_args()

    ds = load_dataset(args.dataset)
    # deduplicate based on the "code" key
    ds_test = ds["train"]
    ds_dedup, duplicate_clusters = deduplicate_dataset(ds_test, jaccard_threshold=0.85)
    # print(duplicate_clusters)
    print("Duplicate clusters: ", len(duplicate_clusters))
    print("Original dataset size: ", len(ds_test))
    print("Deduplicated dataset size: ", len(ds_dedup))