-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
119 lines (103 loc) · 3.55 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
np.random.seed(1)
from matplotlib import pyplot as plt
import skimage.data
from skimage.color import rgb2gray
from skimage.filters import threshold_mean
from skimage.transform import resize
import network
import os
# Helper Functions
def get_corrupted_input(input, corruption_level):
    """Return a copy of ``input`` with randomly chosen entries sign-flipped.

    Each entry (each element along the first axis) is negated independently
    with probability ``corruption_level``. The draws come from NumPy's global
    random generator via a single ``np.random.binomial`` call. The argument
    itself is never modified.

    Args:
        input: pattern vector (array-like). The name shadows the builtin but
            is kept so keyword callers keep working.
        corruption_level: flip probability, expected in ``[0, 1]``
            (``np.random.binomial`` raises ``ValueError`` otherwise).

    Returns:
        A new NumPy array with the same shape as ``input``.
    """
    corrupted = np.copy(input)
    # One Bernoulli draw per entry. This is the same call as before, so the
    # global RNG stream is consumed exactly as the per-element loop did.
    flip = np.random.binomial(n=1, p=corruption_level, size=len(input)).astype(bool)
    # Vectorised sign flip instead of a Python-level loop over every entry.
    corrupted[flip] = -1 * corrupted[flip]
    return corrupted
def reshape(data):
    """Reshape a flat pattern of length ``n`` into a square ``(d, d)`` array.

    Args:
        data: 1-D array-like whose length is a perfect square.

    Returns:
        ``np.ndarray`` of shape ``(d, d)`` where ``d * d == len(data)``.

    Raises:
        ValueError: if ``len(data)`` is not a perfect square.
    """
    n = len(data)
    dim = int(np.sqrt(n))
    # Fail with an explicit message rather than numpy's generic
    # "cannot reshape" error when the pattern is not square.
    if dim * dim != n:
        raise ValueError(
            "cannot reshape pattern of length %d into a square image" % n)
    return np.reshape(data, (dim, dim))
def split(l, n):
    """Lazily yield consecutive slices of ``l`` containing ``n`` items each.

    The last slice is shorter when ``len(l)`` is not a multiple of ``n``.
    Each slice keeps the sequence type of ``l`` (list -> list, str -> str).
    """
    starts = range(0, len(l), n)
    yield from (l[start:start + n] for start in starts)
def plot(data, test, predicted, figsize=(5, 6)):
    """Plot training, corrupted-input and recalled patterns side by side.

    Each flat pattern is turned into a square image with ``reshape``.
    Patterns are laid out in groups of up to four rows per figure, with
    columns train / input / output. Every figure is saved to
    ``results/result_<k>.png`` and then shown.

    Args:
        data: sequence of flattened training patterns.
        test: sequence of flattened corrupted patterns, index-aligned with
            ``data``.
        predicted: sequence of flattened network outputs, index-aligned with
            ``data``.
        figsize: matplotlib figure size ``(width, height)`` in inches, applied
            to every figure. Previously this argument was accepted but ignored.
    """
    data = [reshape(d) for d in data]
    test = [reshape(d) for d in test]
    predicted = [reshape(d) for d in predicted]
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('results', exist_ok=True)

    # Group pattern indices four at a time: one figure per group.
    for file_count, indices in enumerate(split(range(len(data)), 4)):
        # squeeze=False keeps axarr 2-D even for a single row, so one code
        # path handles both full groups and a trailing group of one.
        fig, axarr = plt.subplots(len(indices), 3, figsize=figsize, squeeze=False)
        axarr[0, 0].set_title('Train data')
        axarr[0, 1].set_title("Input data")
        axarr[0, 2].set_title('Output data')
        for row, idx in enumerate(indices):
            for col, image in enumerate((data[idx], test[idx], predicted[idx])):
                axarr[row, col].imshow(image)
                axarr[row, col].axis('off')
        plt.tight_layout()
        plt.savefig("results/result_" + str(file_count) + ".png")
        plt.show()
        # Release the figure; otherwise non-interactive backends keep every
        # figure open across loop iterations.
        plt.close(fig)
def preprocessing(img, w=128, h=128):
    """Convert an image into a flat bipolar (+1 / -1) pattern vector.

    The steps are:
    1. Resize the image to ``(w, h)`` with reflect padding.
    2. Binarise it against ``threshold_mean`` of the resized image.
    3. Map True/False to +1/-1 and flatten to length ``w * h``.

    NOTE(review): ``skimage.transform.resize`` interprets its shape argument
    as ``(rows, cols)``, so ``w`` is used as the row count here. This only
    matters when ``w != h``; confirm it is intended.

    Returns:
        1-D integer array of length ``w * h``.
    """
    resized = resize(img, (w, h), mode='reflect')
    above = resized > threshold_mean(resized)
    # Bool -> int, then {0, 1} -> {-1, +1}.
    bipolar = 2 * (above * 1) - 1
    return np.reshape(bipolar, (w * h))
def main():
    """Train a Hopfield network on the images in ``train/`` and show recall.

    The pipeline is:
    1. Load every readable file in ``train/`` whose name ends in ``g``,
       convert it to grayscale and binarise it with ``preprocessing``.
    2. Store the patterns via ``network.HopfieldNetwork.train_weights``.
    3. Corrupt each pattern with ``get_corrupted_input`` (flip probability
       0.3) and recall it with ``model.predict``.
    4. Plot and save the results with ``plot``.

    Raises:
        FileNotFoundError: if no readable image is found in ``train/``.
    """
    # Load data
    import cv2
    import glob
    img_dir = "train/"  # Directory containing all training images
    # sorted() makes pattern order reproducible. Order also decides which
    # random draws corrupt which pattern; raw glob order is filesystem-dependent.
    files = sorted(glob.glob(os.path.join(img_dir, '*g')))
    data = []
    for f1 in files:
        img = cv2.imread(f1)
        if img is None:
            # imread returns None instead of raising for unreadable files,
            # e.g. a '.svg' that also matches the '*g' pattern.
            print("Skipping unreadable file: " + f1)
            continue
        # OpenCV decodes colour images in BGR order; rgb2gray expects RGB.
        data.append(rgb2gray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
    if not data:
        raise FileNotFoundError("No readable images found in " + img_dir)

    # Preprocessing
    print("Start to data preprocessing...")
    data = [preprocessing(d) for d in data]

    # Create Hopfield Network Model
    model = network.HopfieldNetwork()
    model.train_weights(data)

    # Generate testset: each entry of each pattern flipped with probability 0.3
    test = [get_corrupted_input(d, 0.3) for d in data]
    # NOTE(review): threshold/asyn semantics are defined in network.py;
    # asyn=False presumably selects synchronous updates.
    predicted = model.predict(test, threshold=0, asyn=False)
    print("Show prediction results...")
    plot(data, test, predicted)

    # Optional: visualise the learned weight matrix with
    # model.plot_weights("results/")
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()