-
Notifications
You must be signed in to change notification settings - Fork 0
/
nnIO.go
105 lines (98 loc) · 3.03 KB
/
nnIO.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package main
import (
	"bufio"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"strconv"
	"strings"
)
// Save writes net to the default save slot under the weights directory: the
// layer sizes and learning rate go to weights/default.network and the hidden
// and output weight matrices to weights/hidden.weights and
// weights/output.weights. The directory is created if it does not exist.
// Each file is written independently, so a failure on one does not stop the
// others; every failure is reported on stderr.
func Save(net Network) {
	saveNetworkFiles(net,
		filepath.Join("weights", "default.network"),
		filepath.Join("weights", "hidden.weights"),
		filepath.Join("weights", "output.weights"))
}

// SaveAs is like Save but writes to the slot named str:
// weights/<str>.network, "weights/<str> - hidden.weights" and
// "weights/<str> - output.weights".
func SaveAs(net Network, str string) {
	saveNetworkFiles(net,
		filepath.Join("weights", str+".network"),
		filepath.Join("weights", str+" - hidden.weights"),
		filepath.Join("weights", str+" - output.weights"))
}

// saveNetworkFiles writes the four-line header (inputs, hiddens, outputs,
// rate) to networkPath and the binary-marshalled weight matrices to
// hiddenPath and outputPath, reporting each failure on stderr.
func saveNetworkFiles(net Network, networkPath, hiddenPath, outputPath string) {
	// FormatFloat with precision -1 emits the shortest decimal string that
	// parses back to exactly net.rate. The previous "%f" verb rounded to six
	// decimals, so a rate such as 1e-7 was saved as 0.
	header := strconv.Itoa(net.inputs) + "\n" +
		strconv.Itoa(net.hiddens) + "\n" +
		strconv.Itoa(net.outputs) + "\n" +
		strconv.FormatFloat(net.rate, 'f', -1, 64) + "\n"

	// The three writes are evaluated in order; errors are collected and
	// reported afterwards so one failed file does not prevent the others.
	errs := []error{
		writeNetworkFile(networkPath, func(w io.Writer) error {
			_, err := io.WriteString(w, header)
			return err
		}),
		writeNetworkFile(hiddenPath, func(w io.Writer) error {
			_, err := net.hiddenWeights.MarshalBinaryTo(w)
			return err
		}),
		writeNetworkFile(outputPath, func(w io.Writer) error {
			_, err := net.outputWeights.MarshalBinaryTo(w)
			return err
		}),
	}
	for _, err := range errs {
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
		}
	}
}

// writeNetworkFile creates (or truncates) path, creating its parent directory
// first, and hands write a buffered writer over the file. It returns the
// first error from creating, writing, flushing or closing the file, so a
// failed save is never mistaken for a successful one.
func writeNetworkFile(path string, write func(io.Writer) error) (err error) {
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return fmt.Errorf("saving %s: %w", path, err)
	}
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("saving %s: %w", path, err)
	}
	defer func() {
		// Report a Close error too, unless an earlier error is already
		// being returned.
		if cerr := f.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("saving %s: %w", path, cerr)
		}
	}()

	// Buffer so an encoder that issues many small writes does not pay a
	// syscall per write; Flush must succeed for the data to reach the file.
	bw := bufio.NewWriter(f)
	if err := write(bw); err != nil {
		return fmt.Errorf("saving %s: %w", path, err)
	}
	if err := bw.Flush(); err != nil {
		return fmt.Errorf("saving %s: %w", path, err)
	}
	return nil
}
// Load restores net from the default save slot written by Save. It reads the
// layer sizes and learning rate from weights/default.network and the weight
// matrices from weights/hidden.weights and weights/output.weights. A summary
// is printed to stdout once the header has been loaded.
//
// Each file is loaded independently. A missing or malformed file is reported
// on stderr and the corresponding part of net is left unchanged. The one
// exception is a weight file that opens but fails to decode: its matrix has
// already been Reset, so it may be left empty or incomplete.
func Load(net *Network) {
	loadNetworkFiles(net, "Default",
		filepath.Join("weights", "default.network"),
		filepath.Join("weights", "hidden.weights"),
		filepath.Join("weights", "output.weights"))
}

// LoadAs is like Load but reads the slot named str, as written by SaveAs:
// weights/<str>.network, "weights/<str> - hidden.weights" and
// "weights/<str> - output.weights". The printed summary names str instead of
// "Default".
func LoadAs(net *Network, str string) {
	loadNetworkFiles(net, str,
		filepath.Join("weights", str+".network"),
		filepath.Join("weights", str+" - hidden.weights"),
		filepath.Join("weights", str+" - output.weights"))
}

// loadNetworkFiles loads the header and both weight matrices into net. It
// prints the summary (headed "Loaded <label>") when the header loads and
// reports every failure on stderr.
func loadNetworkFiles(net *Network, label, networkPath, hiddenPath, outputPath string) {
	if err := loadNetworkHeader(net, networkPath); err != nil {
		fmt.Fprintln(os.Stderr, err)
	} else {
		fmt.Printf("Loaded %s\nNeural Net: \nInputs : %d\nHiddens : %d\nOutputs : %d\nLearning Rate : %f\n", label, net.inputs, net.hiddens, net.outputs, net.rate)
	}

	// Each matrix is Reset before decoding, as the original code did, and
	// only once its file has opened successfully.
	// NOTE(review): presumably UnmarshalBinaryFrom requires an empty
	// receiver; confirm against the matrix type.
	errs := []error{
		readNetworkFile(hiddenPath, func(r io.Reader) error {
			net.hiddenWeights.Reset()
			_, err := net.hiddenWeights.UnmarshalBinaryFrom(r)
			return err
		}),
		readNetworkFile(outputPath, func(r io.Reader) error {
			net.outputWeights.Reset()
			_, err := net.outputWeights.UnmarshalBinaryFrom(r)
			return err
		}),
	}
	for _, err := range errs {
		if err != nil {
			fmt.Fprintln(os.Stderr, err)
		}
	}
}

// loadNetworkHeader reads the four-line header (inputs, hiddens, outputs,
// rate) from path. It stores the values in net only if all four parse, so a
// truncated or corrupt file can no longer panic or overwrite net with zeros.
// Whitespace around each value, including a trailing '\r', is ignored.
func loadNetworkHeader(net *Network, path string) error {
	raw, err := ioutil.ReadFile(path)
	if err != nil {
		return fmt.Errorf("loading %s: %w", path, err)
	}
	lines := strings.Split(string(raw), "\n")
	if len(lines) < 4 {
		return fmt.Errorf("loading %s: want 4 header lines, got %d", path, len(lines))
	}

	var dims [3]int // inputs, hiddens, outputs
	for i := range dims {
		if dims[i], err = strconv.Atoi(strings.TrimSpace(lines[i])); err != nil {
			return fmt.Errorf("loading %s: line %d: %w", path, i+1, err)
		}
	}
	rate, err := strconv.ParseFloat(strings.TrimSpace(lines[3]), 64)
	if err != nil {
		return fmt.Errorf("loading %s: line 4: %w", path, err)
	}

	net.inputs, net.hiddens, net.outputs, net.rate = dims[0], dims[1], dims[2], rate
	return nil
}

// readNetworkFile opens path and hands read a buffered reader over it,
// annotating any open or read error with the path.
func readNetworkFile(path string, read func(io.Reader) error) error {
	f, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("loading %s: %w", path, err)
	}
	defer f.Close() // opened read-only, so a Close error cannot lose data

	if err := read(bufio.NewReader(f)); err != nil {
		return fmt.Errorf("loading %s: %w", path, err)
	}
	return nil
}