-
Notifications
You must be signed in to change notification settings - Fork 0
/
exploreflux.jl
100 lines (82 loc) · 2.34 KB
/
exploreflux.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
using Images
using Flux
using Flux: throttle
using FileIO
using DataFrames
using Statistics
using Serialization
"""
    normalize(filenm::AbstractString)

Load the image stored at `filenm` and prepare it for training.

The image is converted to grayscale `Float64` intensities, smoothed with a
Gaussian kernel whose sigma is the standard deviation of those intensities,
and resized to 100x100. The resized array is returned. A progress line is
printed for every file.
"""
function normalize(filenm::AbstractString)
    println("normalizing $filenm")
    gray = Float64.(Gray.(FileIO.load(filenm)))
    # The blur strength is taken from the image's own intensity spread.
    blurred = imfilter(gray, Kernel.gaussian(std(gray)))
    return imresize(blurred, 100, 100)
end
# Set up the training and test sets.
# Intended bookkeeping: a DataFrame with
#   - directory path
#   - image name (as a UUID)
#   - width, height, color/gray
#   - size
#   - keep = 0 | 1
# Per-image steps:
#   - load the image
#   - scale it down to 100x100
#   - pass it to Flux
# Data locations: the training ("tr") and test ("tst") image trees live in an
# "ml" directory under the parent of this script's directory.
basedir = dirname(@__DIR__)
trd = joinpath(basedir, "ml", "tr")
tstd = joinpath(basedir, "ml", "tst")
# mkdir(joinpath(basedir,"tr"))
# mkdir(joinpath(basedir,"tst"))

# Accumulators for the training data: preprocessed images and their Int8
# labels. They are filled in place or replaced from the cache files below.
trimgs = []
trl = Int8[]

# Serialized caches of the two accumulators, stored next to this script.
labelfile = joinpath(@__DIR__, "labels.bin")
imgfile = joinpath(@__DIR__, "imgs.bin")
"""
    process_img_dir()

Build the training set from the directory tree rooted at `trd`.

Every `.png` file found directly inside each subdirectory is preprocessed
with `normalize` and pushed onto the global `trimgs`. One label per image is
appended to the global `trl`: `1` for images in a directory named `"y"`,
`0` for any other directory. Afterwards both collections are serialized to
`imgfile` and `labelfile`.
"""
function process_img_dir()
    for (root, dirs, _) in walkdir(trd)
        for d in dirs
            # `d` is relative to `root`, not to `trd`; joining against `trd`
            # points at a non-existent path once the tree is nested.
            dirpath = joinpath(root, d)
            pngs = filter(f -> splitext(f)[2] == ".png", readdir(dirpath))
            for f in pngs
                push!(trimgs, normalize(joinpath(dirpath, f)))
            end
            # One label per image. `append!` adds the elements one by one;
            # `push!` would try to store the whole vector as a single Int8
            # and throw a MethodError.
            label = d == "y" ? one(Int8) : zero(Int8)
            append!(trl, fill(label, length(pngs)))
        end
    end
    Serialization.serialize(imgfile, trimgs)
    Serialization.serialize(labelfile, trl)
end
# Reuse the serialized caches only when BOTH files are present; otherwise
# rebuild them from the image directories. Checking the image cache alone
# would crash in `deserialize` whenever the label cache is missing.
if isfile(imgfile) && isfile(labelfile)
    trimgs = Serialization.deserialize(imgfile)
    trl = Serialization.deserialize(labelfile)
else
    process_img_dir()
end
println("trimgs has size $(size(trimgs))")
println("trl has size $(size(trl))")
# --- Prepare the data in the layout Flux expects -----------------------------
# `trimgs` holds 100x100 grayscale matrices. Conv layers need a 4-D WHCN array
# (width, height, channels, batch), so the images are stacked with a single
# channel. Float32 matches Flux's default parameter type.
nimgs = length(trimgs)
nimgs == length(trl) ||
    throw(DimensionMismatch("got $nimgs images but $(length(trl)) labels"))
nimgs > 0 || error("no training images loaded")
X = Float32.(reshape(reduce(hcat, vec.(trimgs)), 100, 100, 1, nimgs))
# crossentropy compares class probabilities, so the 0/1 labels are one-hot
# encoded into a 2 x nimgs matrix.
Y = Flux.onehotbatch(trl, 0:1)

# --- Model -------------------------------------------------------------------
# Two 3x3 conv layers (padding keeps the 100x100 spatial size), followed by a
# dense classification head that maps the flattened feature maps to the two
# classes. The first conv takes 1 input channel because the images are gray.
m = Chain(
    Conv((3,3), 1 => 64, relu, pad=(1, 1), stride=(1, 1)),
    BatchNorm(64),
    Conv((3,3), 64 => 64, relu, pad=(1, 1), stride=(1, 1)),
    BatchNorm(64),
    x -> reshape(x, :, size(x, 4)),
    Dense(100 * 100 * 64, 2),
    softmax
)

loss(x, y) = crossentropy(m(x), y)
# Compare predicted and true class indices; comparing raw probabilities to
# labels with `.==` is (almost) never true.
accuracy(x, y) = mean(Flux.onecold(m(x)) .== Flux.onecold(y))

# --- Training ----------------------------------------------------------------
# One pass over the data in mini-batches of (images, one-hot labels).
batchsize = 32
batches = [(X[:, :, :, idx], Y[:, idx]) for idx in Iterators.partition(1:nimgs, batchsize)]

evalcb = () -> @show(loss(X, Y))
opt = ADAM()
Flux.train!(loss, params(m), batches, opt, cb = throttle(evalcb, 10))
atr = accuracy(X, Y)
println("training accuracy is $atr")
# Ad-hoc sample for manual experiments. `a` was pushed to without ever being
# defined (UndefVarError), so it is created here. The sample is skipped when
# the hard-coded developer path does not exist on this machine.
# trimgs = []
a = []
samplefile = "/Users/doug/dev/automaton/colonies12p7-1row/colonies12p7-150dpi-rainbow--1-1366939139754.png"
if isfile(samplefile)
    push!(a, normalize(samplefile))
else
    @warn "sample image not found, skipping" samplefile
end
# push!(a, normalize("/Users/doug/dev/automaton/colonies15/colonies15-158302-1366076820082.png"))
# trl=[1,0]