-
Notifications
You must be signed in to change notification settings - Fork 6
/
predict_len.py
78 lines (63 loc) · 2.48 KB
/
predict_len.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import random
from random import shuffle
import argparse
from sklearn.linear_model import LinearRegression
from sampling import ExperimentSampler
# Seed the module-level RNG. `shuffle` (imported from `random`) draws from this
# same global generator, so the train/test split in get_data_splits is
# reproducible from run to run.
random.seed(1234)
def predict_len(sampler, input_file, test_percent=0.2):
    """
    Measure how well sentence length can be linearly predicted from latent variables.

    The sentences in ``input_file`` are split into train/test portions. Both
    portions are encoded with ``sampler`` (train first, then test). A
    ``LinearRegression`` model is fit on the train portion and evaluated on
    the test portion.

    :param sampler: sampler instance used to obtain latent variables
    :param input_file: path to a file with one input sentence per line
    :param test_percent: fraction of sentences held out for evaluation
    :return: the value of ``LinearRegression.score`` on the held-out sentences
    """
    splits = get_data_splits(input_file, test_percent)
    # Encode train before test, the same order in which the splits are returned.
    (train_x, train_y), (test_x, test_y) = [get_x_y(sampler, part) for part in splits]
    regressor = LinearRegression()
    regressor.fit(train_x, train_y)
    return regressor.score(test_x, test_y)
def get_x_y(sampler, sentences):
    """
    Get latent variables and construct (features, labels) instances for sklearn.

    Labels are sentence lengths measured in whitespace-separated tokens.
    Features are whatever ``sampler.get_latent_variables`` returns for the same
    list of sentences. This presumably means one latent vector per sentence, in
    input order; confirm against the sampler implementation.

    :param sampler: sampler instance exposing ``get_latent_variables(sentences)``
    :param sentences: list of input sentences
    :return: features, labels
    """
    # str.split() with no separator collapses runs of whitespace (spaces, tabs)
    # and returns [] for an empty string. So "a  b" has length 2 and "" has
    # length 0, whereas split(' ') would report 3 and 1 respectively.
    y = [len(s.split()) for s in sentences]
    x = sampler.get_latent_variables(sentences)
    return x, y
def get_data_splits(input_file, test_percent=0.2, do_shuffle=True):
    """
    Split the sentences of a text file into a train and a test portion.

    Each line of the file is one sentence, with trailing whitespace (including
    the newline) stripped. After optional shuffling, the first
    ``int(len * (1 - test_percent))`` sentences form the train split and the
    rest form the test split. The sizes of both splits are printed.

    :param input_file: path to file of input sentences (read as UTF-8)
    :param test_percent: fraction of sentences to put in the test split, in [0, 1]
    :param do_shuffle: shuffle the sentences in place with ``random.shuffle``
        before splitting
    :return: train sentences, test sentences
    :raises ValueError: if test_percent is outside [0, 1]
    """
    if not 0 <= test_percent <= 1:
        # Outside [0, 1] the split index becomes negative or exceeds the list
        # length, and slicing would silently produce a meaningless split.
        raise ValueError(f"test_percent must be in [0, 1], got {test_percent!r}")
    # Explicit encoding: the platform default varies between systems and would
    # otherwise change how non-ASCII sentences are decoded.
    with open(input_file, encoding='utf-8') as f:
        sentences = [line.rstrip() for line in f]
    if do_shuffle:
        shuffle(sentences)
    split_idx = int(len(sentences) * (1 - test_percent))
    sentences_train = sentences[:split_idx]
    sentences_test = sentences[split_idx:]
    print(len(sentences_train))
    print(len(sentences_test))
    return sentences_train, sentences_test
if __name__ == '__main__':
    # Command-line entry point: build a sampler for an experiment directory and
    # print the held-out score of predicting sentence length from latent variables.
    parser = argparse.ArgumentParser(description='Predict the length of given sentences from the latent variable')
    parser.add_argument('exp_dir', type=str,
                        help='Path to experiment directory. Should contain code, config, vocab and model subfolders.')
    # Positional arguments are always required, so argparse never used the
    # former default=None; it has been dropped to avoid implying the file is optional.
    parser.add_argument('data_file', type=str,
                        help='File to read in.')
    parser.add_argument('-batch_size', type=int, default=32,
                        help='Batch size.')
    # Previously hard-coded to 0.2; exposed here with the same default.
    parser.add_argument('-test_percent', type=float, default=0.2,
                        help='Fraction of sentences held out for evaluation.')
    args = parser.parse_args()
    sampler = ExperimentSampler(args.exp_dir, args.batch_size)
    score = predict_len(sampler, args.data_file, test_percent=args.test_percent)
    print(score)