-
Notifications
You must be signed in to change notification settings - Fork 0
/
Run_Label_Prediction.jl
62 lines (36 loc) · 1.48 KB
/
Run_Label_Prediction.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
include("API.jl")
include("Config.jl")
# Build the demo graph from a plain-text relation listing.
# Each non-empty line below reads like `<node> <kind> <relation> <node> <kind>`;
# the exact grammar is whatever `build_graph` accepts (presumably defined in
# API.jl; confirm there). The literal is passed through unchanged, including its
# leading and trailing newline.
graph = let relation_listing = "
fox animal likes dog animal
fox animal neutral human mammal
fox animal dislikes cat animal
dog animal likes fox animal
dog animal likes human mammal
dog animal dislikes cat animal
human mammal likes dog animal
human mammal likes cat animal
human mammal likes fox animal
cat animal neutral human mammal
cat animal dislikes dog animal
cat animal dislikes fox animal
"
    build_graph(relation_listing)
end
# Script-wide table mapping a graph node (as returned by `get_node`) to its
# target label. `Dict{Any,Any}()` is exactly what a bare `Dict()` constructs.
labels = Dict{Any,Any}()

"""
    set_label(graph, node, label)

Record `label` as the target for the node that `get_node(graph, node)` resolves
to, overwriting any previous entry in the global `labels` table. Returns `label`.
"""
function set_label(graph, node, label)
    resolved = get_node(graph, node)
    labels[resolved] = label
    return label
end

"""
    get_label(graph, node)

Return the label stored for the node that `get_node(graph, node)` resolves to.
Throws a `KeyError` when that node has no entry in `labels`.
"""
function get_label(graph, node)
    resolved = get_node(graph, node)
    return labels[resolved]
end
# Give each of the four nodes a random 1×2 target label. `tanh` is applied
# element-wise to the Gaussian draw, keeping every entry within [-1, 1].
# The names are visited in a fixed order so the sequence of `randn` calls is
# the same on every run of the script.
for node_name in ("fox", "dog", "human", "cat")
    set_label(graph, node_name, tanh.(randn(1, 2)))
end
# Attach a single-layer label predictor to the graph. Its second argument is the
# size of the last dimension of one stored label (the last one in the dict's
# iteration order). `message_size` is not defined in this file (presumably it
# comes from Config.jl; confirm there).
let sample_label = last(collect(values(labels)))
    graph.label_predictor = [FeedForward(message_size, last(size(sample_label)))]
end
# Report the test result once before any training; `@show` prints both the
# expression text and its value.
@show test_for_label_prediction(graph, keys(labels), values(labels))

# Run `hm_epochs` training passes, printing the value returned by each pass and,
# every `test_per_epoch` epochs, the current test result on the same nodes and
# labels. `hm_epochs`, `learning_rate` and `test_per_epoch` are not defined in
# this file (presumably Config.jl; confirm there).
for epoch in 1:hm_epochs
    local train_result = train_for_label_prediction!(
        graph, learning_rate, keys(labels), values(labels)
    )
    println("train: $(train_result)")

    if epoch % test_per_epoch == 0
        local test_result = test_for_label_prediction(graph, keys(labels), values(labels))
        println("\ttest: $(test_result)")
    end
end

# Trailing separator line after the training log.
println(" ")
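# Disabled snippets kept for reference; uncomment to experiment.
# predict_label(graph, "human")
# embed_node(graph, "fox")
# display_similarities(graph)
# Disabled loss expressions (presumably alternatives used with the label predictor; confirm in API.jl):
# binary_cross_entropy(label, prop(predictor, update_node_wrt_depths(node), act2=sigm))
# cross_entropy(label, softmax(prop(predictor, update_node_wrt_depths(node), act2=nothing)))
# mse(label, prop(predictor, update_node_wrt_depths(node)))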