-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfold_train_test.py
56 lines (49 loc) · 1.51 KB
/
fold_train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import io
import os
import sys

from ILPformat import parse_json as parse_inp
# Usage: <script> <fold-number> <input-file>
#
# The fold file selected by <fold-number> becomes the test fold; every other
# fold file found in the data directory is used for training.
if len(sys.argv) < 3:
    sys.exit("usage: {} <fold-number> <input-file>".format(sys.argv[0]))

datadir = "data/"

# Fold files are the directory entries whose name starts with 'i'.
folds = [name for name in os.listdir(datadir) if name.startswith('i')]

# The test fold is the first listed file whose name contains "<fold-number>.".
# NOTE(review): this is a substring match, so "1." would also match a name
# such as "i11.txt" -- confirm the fold naming scheme rules that out.
fold_tag = str(sys.argv[1]) + '.'
matching = [name for name in folds if fold_tag in name]
if not matching:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit("no fold file in {!r} contains {!r}".format(datadir, fold_tag))
test = matching[0]
folds.remove(test)
# Disabled legacy parser, kept as a bare string literal (a no-op at runtime).
# It read a plain-text file in repeating groups of three lines, appending the
# first line of each group to `q`, the second to `e` and the third to `a`,
# and returned the tuple (q, a, e).
'''
def parse_inp(inp):
q=[]
a=[]
e=[]
with open(inp) as f:
f = f.readlines()
i=0
while i<len(f):
q.append(f[i])
i+=1
e.append(f[i])
i+=1
a.append(f[i])
i+=1
return (q,a,e)
'''
# Call form that matched the legacy three-value parser above:
#q,a,e = parse_inp(sys.argv[2])
# Live call: `parse_inp` is ILPformat.parse_json (see the imports). Its result
# is unpacked into two values; presumably `q` and `a` are parallel sequences
# of questions and answers -- confirm against ILPformat.
q,a = parse_inp(sys.argv[2])
def _read_indices(path):
    """Return the integers listed in *path*, one per line, each reduced by one.

    Blank lines are skipped so a trailing empty line does not crash int().
    The context manager closes the file as soon as it has been read.
    """
    with open(path) as handle:
        return [int(line) - 1 for line in handle if line.strip()]


# Training indices come from every remaining fold file, in listing order.
train = []
for x in folds:
    train.extend(_read_indices(datadir + x))
test = _read_indices(datadir + test)

# An index present in both sets would leak test examples into training:
# print the offending indices (in training order) and stop with a failing
# exit status so calling scripts notice.
test_lookup = set(test)
overlap = [i for i in train if i in test_lookup]
if overlap:
    print(overlap)
    sys.exit(1)
# Set copies make each membership test O(1); `train` and `test` themselves
# are left untouched.
train_ids = set(train)
test_ids = set(test)

# Split questions and answers by position, keeping their original order.
# The third list of each split holds the selected positions in `q`, one
# newline-terminated text line per example.
trainq = [item for i, item in enumerate(q) if i in train_ids]
traina = [item for i, item in enumerate(a) if i in train_ids]
traine = [str(i) + '\n' for i in range(len(q)) if i in train_ids]

testq = [item for i, item in enumerate(q) if i in test_ids]
testa = [item for i, item in enumerate(a) if i in test_ids]
teste = [str(i) + '\n' for i in range(len(q)) if i in test_ids]

# Report the size of the test split.
print(len(testq))
def _write_split(path, questions, id_lines, answers):
    """Write one split to *path* as UTF-8 text, three lines per example.

    Each example is emitted as its question, its id line (already
    newline-terminated) and its answer, in that order.

    Raises:
        ValueError: if the three sequences differ in length.
    """
    if not len(questions) == len(id_lines) == len(answers):
        raise ValueError("questions, id lines and answers differ in length")
    # The file object does the encoding, so text is written directly; writing
    # pre-encoded bytes to a text-mode file raises TypeError on Python 3.
    with io.open(path, 'w', encoding='utf8') as out:
        for question, id_line, answer in zip(questions, id_lines, answers):
            out.write(u"{}\n{}{}\n".format(question, id_line, answer))


# Output files sit next to the fold files: <datadir>train<N> and <datadir>test<N>.
_write_split(datadir + 'train' + str(sys.argv[1]), trainq, traine, traina)
_write_split(datadir + 'test' + str(sys.argv[1]), testq, teste, testa)