-
Notifications
You must be signed in to change notification settings - Fork 3
/
knn.py
113 lines (85 loc) · 3.53 KB
/
knn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import common
import cv2
import numpy as np
import random
from functools import reduce
import ocr
import neighbour
class KnnOCR(ocr.OCR):
    """OCR reader that classifies glyph masks with an OpenCV k-nearest-neighbours model.

    Instances are also (no-op) context managers, so they can be used in a
    ``with`` statement.
    """

    def __init__(self, train_set=None, dump=None, load=None, glyphs=ocr.GLYPHS, verbose=False):
        """Build and train the underlying KNN model.

        :param train_set: forwarded to ``ocr.OCR.get_train_set`` (the script at
            the bottom of this module passes the training-set size); ignored
            when ``load`` is given.
        :param dump: file path (without the ``.npz`` extension) where the
            prepared training data is saved, or None to skip saving.
        :param load: file path (without the ``.npz`` extension) of previously
            dumped training data to train from instead of generating it, or None.
        :param glyphs: passed to ``ocr.OCR.delabelize`` to map predicted labels
            back to glyphs.
        :param verbose: print diagnostic output while training and reading.
        """
        self.verbose = verbose
        self.glyphs = glyphs
        self.knn = self.__train(train_set, dump, load)

    def read(self, input, k=5):
        """Read the glyph shown in the input image.

        :param input: image to classify; it is normalised with
            ``neighbour.clean2`` exactly like the training samples.
        :param k: number of neighbours taking part in the KNN vote.
        :return: the entry of ``ocr.OCR.delabelize(self.glyphs)`` for the
            predicted label.
        """
        return self.__nearest(neighbour.clean2(input), self.knn, k, self.glyphs, verbose=self.verbose)

    def __unpackage(self, train_set):
        """Split an iterable of ``(label, sample)`` pairs into ``(data, labels)`` numpy arrays."""
        data, labels = list(), list()
        for (l, d) in train_set:
            data.append(d)
            labels.append(l)
        return np.array(data), np.array(labels)

    @common.showtime
    def __train(self, train_set=None, dump=None, load=None):
        """Train an OpenCV KNN model and return it.

        If ``load`` is a file path (without extension), the training data is
        read from ``<load>.npz`` and ``train_set`` is ignored. Otherwise a
        training set is produced by ``ocr.OCR.get_train_set(train_set)`` and
        every sample is normalised with ``neighbour.clean2``.
        If ``dump`` is a file path (without extension), the prepared data and
        labels are saved to ``<dump>.npz`` so they can later be passed as ``load``.

        :raises ValueError: if the training set is empty.
        :return: the trained ``cv2.ml.KNearest`` object.
        """
        if load is not None:
            # Dumped data is saved after cleaning and flattening (see below),
            # so it must not be passed through clean2 again.
            with np.load("{}.npz".format(load)) as save:
                data, labels = save["data"], save["labels"]
        else:
            t = ocr.OCR.get_train_set(train_set, verbose=self.verbose)
            data, labels = self.__unpackage(t)
            data = np.array([neighbour.clean2(d) for d in data])
        if len(data) == 0:
            raise ValueError("cannot train a KNN model on an empty training set")
        # One row per sample: flatten each sample into a float32 feature vector.
        data = data.reshape(len(data), -1).astype(np.float32)
        # OpenCV's ml module only accepts float32 or int32 responses; labels
        # built from Python ints would otherwise be int64 on most platforms.
        labels = np.asarray(labels, dtype=np.float32)
        knn = cv2.ml.KNearest_create()
        knn.train(data, cv2.ml.ROW_SAMPLE, labels)
        if dump is not None:
            np.savez("{}.npz".format(dump), data=data, labels=labels)
        return knn

    @common.showtime
    def __nearest(self, input, knn, k, glyphs, verbose=False):
        """Return the glyph predicted by the trained ``knn`` for a cleaned mask.

        :param input: cleaned mask; it is flattened into a single feature row,
            so it must contain as many values as each training sample.
        :param knn: trained ``cv2.ml.KNearest`` model.
        :param k: number of neighbours taking part in the vote.
        :param glyphs: passed to ``ocr.OCR.delabelize`` to map the label back.
        :param verbose: print the raw ``findNearest`` outputs.
        """
        samp = np.asarray(input, dtype=np.float32).reshape(1, -1)
        ret, res, neigh, dist = knn.findNearest(samp, k=k)
        if verbose:
            print("r:{}, res:{}, neigh:{}, dist:{}".format(ret, res, neigh, dist))
        # findNearest reports responses as float32; use an int so the label
        # works both as a dict key and as a sequence index.
        lbl = ocr.OCR.delabelize(glyphs)[int(res[0][0])]
        return lbl

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        # Nothing to release; exceptions are not suppressed.
        pass
if __name__ == "__main__":
    # Ad-hoc accuracy check: train a KnnOCR model, then read back rendered digits.
    import generator
    import extract

    gen = True  # if True, generate (and dump) a new training set; otherwise load the dumped one
    size = 1  # training-set size; per the assert below, the number of samples per glyph
    TOT = 30  # number of test reads

    if gen:
        d = "data_set"
        l = None
        # Sanity-check the training-set generator before training on it.
        s = ocr.OCR.get_train_set(size, verbose=True)
        expected = len(ocr.GLYPHS) * size
        assert len(s) == expected, "Must generate the correct number of element ({}, {})".format(len(s), expected)
    else:
        d = None
        l = "data_set"

    with KnnOCR(dump=d, load=l, train_set=size, verbose=True) as o:
        assert o is not None, "A new object must be created"
        # Training happens in the constructor, so report it only once it is done.
        print("Trained")
        res = 0
        for i in range(TOT):
            c = str(i % 10)
            t, _ = extract.get_optimal_mask(generator.get_all_tables(c)[c])
            r = o.read(t)
            assert r is not None, "ocr.read(...) must yield a result"
            print("{} {}".format(r, c))
            if r == c:
                res += 1
    print("Accuracy: {:.2f} ({}/{})".format(res / TOT, res, TOT))