-
Notifications
You must be signed in to change notification settings - Fork 2
/
predict.py
90 lines (70 loc) · 2.34 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import sys
import numpy as np
import tensorflow as tf
from pathlib import Path
from model import Network, Config
from data import Vocab
from itertools import chain
def main(argv):
    """Restore a trained network and print the tag for every predicted label.

    Builds the input pipeline from ``Vocab.examples``, restores the model
    variables from a checkpoint and runs ``net.predict`` batch by batch until
    the dataset is exhausted. For each batch the tag of every predicted label
    is printed on its own line, followed by one empty line.

    Args:
        argv: Command line in ``sys.argv`` layout. ``argv[1]`` is the path of
            a checkpoint file (e.g. the ``.meta`` graph file); ``argv[2]`` is
            passed to ``Vocab.examples`` to produce the examples to tag.

    Raises:
        SystemExit: With a usage message when fewer than two arguments follow
            the program name.
    """
    if len(argv) < 3:
        # Fail early with a readable message instead of an IndexError that
        # would otherwise only appear after the vocabulary has been loaded.
        raise SystemExit('usage: predict.py META_GRAPH EXAMPLES')
    meta_graph_arg, examples_arg = argv[1], argv[2]

    # NOTE(review): presumably these have to match the values the restored
    # checkpoint was trained with -- confirm against the training script.
    config = Config(
        batch_size=1,
        char_embed_size=25,
        conv_kernel=3,
        depth=1,
        dropout=0.5,
        h_size=256,
        learning_rate=0.01,
        pool_size=53,
    )

    vocab = Vocab('etc/samnorsk.300.skipgram.bin', 'etc/gazetteer.txt')

    output_types = (
        tf.float32, tf.float32, tf.float32,
        tf.int32, tf.int32, tf.int32, tf.int32,
    )
    output_shapes = (
        [None, vocab.n_words],       # Word embeddings for each word in the sentence
        [None, vocab.n_pos],         # One-hot encoded PoS for each word
        [None, vocab.n_categories],  # NE category memberships
        [None, None],                # The characters for each word
        [None],                      # The number of characters per word
        [None],                      # The labels for each word
        [],                          # The number of words in the sentence
    )

    examples = tf.data.Dataset.from_generator(
        vocab.examples(examples_arg),
        output_types=output_types,
        output_shapes=output_shapes,
    ).padded_batch(
        batch_size=config.batch_size,
        padded_shapes=output_shapes,
    )

    # A structure-based iterator is initialised explicitly with the dataset
    # inside the session below.
    iterator = tf.data.Iterator.from_structure(
        output_types=output_types,
        output_shapes=examples.output_shapes,
    )
    example_init = iterator.make_initializer(examples)
    batch = iterator.get_next(name="batch")

    net = Network(
        name='samnorsk',
        batch=batch,
        config=config,
        v_shape=vocab.shape,
    )

    # Saver.restore expects the checkpoint prefix, so drop the final suffix
    # of the given file name (``ckpt/model-100.meta`` -> ``ckpt/model-100``).
    meta_graph = Path(meta_graph_arg)
    model = meta_graph.with_name(meta_graph.stem)

    # The Saver is created after the network so that it sees its variables.
    saver = tf.train.Saver()

    with tf.Session() as session:
        saver.restore(session, str(model))
        session.run(example_init)
        try:
            while True:
                predicted = session.run(
                    net.predict,
                    # NOTE(review): 0.0 presumably disables dropout for
                    # inference (drop rate, not keep probability) -- confirm
                    # against Network.
                    feed_dict={net.dropout: 0.0},
                )
                for label in chain.from_iterable(predicted):
                    print(vocab.idx_tag[label])
                # Empty line after the labels of each batch.
                print()
        except tf.errors.OutOfRangeError:
            # The iterator signals the end of the dataset this way; it is the
            # normal way out of the prediction loop.
            pass


if __name__ == '__main__':
    main(sys.argv)