-
Notifications
You must be signed in to change notification settings - Fork 0
/
example.py
59 lines (41 loc) · 1.34 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import fastImgClassifier as imgClassifier
import numpy as np
import pickle
from keras.datasets import mnist
from tqdm import tqdm
# Create a model with an input shape of 28x28 and 10 outputs.
model = imgClassifier.Classifier((28, 28), 10)

# Load the MNIST data set (train and test splits).
(trainX, trainY), (testX, testY) = mnist.load_data()

# Train the model one sample at a time; each label is the index of the
# output neuron that should fire for that image.
# zip() has no length, so pass `total` explicitly to let tqdm show a real
# progress bar (percentage / ETA) instead of a bare iteration counter.
for img, label in tqdm(zip(trainX, trainY), total=len(trainX)):
    model.train(img, label)

# Test the model: a prediction is the index of the strongest output.
corrects = 0
for img, label in zip(testX, testY):
    answer = np.argmax(model.classify(img))
    if label == answer:
        corrects += 1

# Print out the accuracy as a percentage of the test set.
accuracy = corrects / len(testX)
print(f"accuracy: {accuracy * 100}%")

file_path = "digitClassifier"

# Use the pickle library to save the model.
data = pickle.dumps(model)
with open(file_path, "wb") as f:
    f.write(data)

# Load the model back from the file.
# pickle.loads() deserializes the stored bytes into the model object
# (pickle.dumps() here would just re-pickle the raw bytes and leave `model`
# bound to a bytes blob).
# NOTE: unpickling executes arbitrary code -- only load files you trust.
with open(file_path, "rb") as f:
    model = pickle.loads(f.read())
# Same workflow as above, this time using the batch helpers that
# fastImgClassifier exposes instead of hand-written loops.
import fastImgClassifier

model = fastImgClassifier.Classifier((28, 28), 10)

# Train on the whole training split in a single call.
model.trainAll(trainX, trainY)

# evaluate() presumably returns the accuracy as a fraction in [0, 1]
# (it is scaled by 100 for display below) -- TODO confirm against the library.
accuracy = model.evaluate(testX, testY)
print(f"accuracy: {accuracy * 100}%")

# Persist the trained model with pickle ...
file_path = "digitClassifier"
data = pickle.dumps(model)
with open(file_path, "wb") as sink:
    sink.write(data)

# ... and restore it from the same file.
with open(file_path, "rb") as source:
    model = pickle.loads(source.read())