-
Notifications
You must be signed in to change notification settings - Fork 379
/
multi_training.py
60 lines (46 loc) · 1.97 KB
/
multi_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os, random
from midi_to_statematrix import *
from data import *
import cPickle as pickle
import signal
# Geometry of the training segments cut from each loaded piece.
batch_width = 10    # how many segments are stacked into one training batch
batch_len = 128     # rows of the note-state matrix per segment (16 * 8)
division_len = 16   # spacing between candidate segment start rows in a piece
def loadPieces(dirpath):
    """Load every MIDI file directly inside ``dirpath`` as a note-state matrix.

    Files are selected by a case-insensitive ``.mid`` extension, so ``.mid``,
    ``.MID`` and mixed-case variants such as ``.Mid`` are all picked up.
    Sub-directories are not scanned (``os.listdir`` only lists one level).
    A piece with fewer than ``batch_len`` rows is skipped because no full
    training segment can be cut from it.

    Args:
        dirpath: directory to scan.

    Returns:
        dict mapping each file name without its extension to the matrix
        returned by ``midiToNoteStateMatrix`` for that file.
    """
    pieces = {}
    for fname in os.listdir(dirpath):
        name, ext = os.path.splitext(fname)
        # Case-insensitive match: the old ('.mid', '.MID') test skipped
        # mixed-case extensions silently.
        if ext.lower() != '.mid':
            continue
        outMatrix = midiToNoteStateMatrix(os.path.join(dirpath, fname))
        if len(outMatrix) < batch_len:
            continue
        pieces[name] = outMatrix
        # Call form prints the same text on Python 2 and also runs on Python 3.
        print("Loaded {}".format(name))
    return pieces
def getPieceSegment(pieces):
    """Cut one random training segment from a randomly chosen piece.

    A piece is chosen uniformly from ``pieces``. A window of ``batch_len``
    rows is then taken, starting at a random multiple of ``division_len``.
    Every such start ``s`` with ``s + batch_len <= len(piece)`` is eligible,
    including a window that ends exactly on the piece's last row.

    Args:
        pieces: non-empty dict of name -> note-state matrix, as built by
            ``loadPieces``; each matrix needs at least ``batch_len`` rows.

    Returns:
        (seg_in, seg_out): the segment converted by
        ``noteStateMatrixToInputForm``, and the raw segment itself.

    Raises:
        IndexError: if ``pieces`` is empty.
        ValueError: if the chosen piece has fewer than ``batch_len`` rows.
    """
    # list(...) gives random.choice an indexable sequence on Python 3 too,
    # where dict.values() is a view rather than a list.
    piece_output = random.choice(list(pieces.values()))
    # randrange's stop bound is exclusive, so +1 keeps the final aligned start
    # (len - batch_len) eligible. Without it, a piece of exactly batch_len rows
    # (which loadPieces accepts) produced an empty range and a ValueError.
    start = random.randrange(0, len(piece_output) - batch_len + 1, division_len)
    seg_out = piece_output[start:start + batch_len]
    seg_in = noteStateMatrixToInputForm(seg_out)
    return seg_in, seg_out
def getPieceBatch(pieces):
    """Draw ``batch_width`` independent segments and stack them into arrays.

    Args:
        pieces: dict of name -> note-state matrix, passed to
            ``getPieceSegment``.

    Returns:
        (inputs, outputs) numpy arrays built from the per-segment input form
        and raw segment; entry k of both arrays comes from the same draw.
    """
    draws = []
    for _ in range(batch_width):
        draws.append(getPieceSegment(pieces))
    # Transpose [(in, out), ...] into (ins, outs); this raises ValueError if
    # no segments were drawn.
    inputs, outputs = zip(*draws)
    return numpy.array(inputs), numpy.array(outputs)
def _writeSampleAndParams(model, pieces, i):
    """Write a generated MIDI sample and pickle the model config for iteration i."""
    xIpt, xOpt = map(numpy.array, getPieceSegment(pieces))
    # Output = the segment's first row followed by the model's prediction.
    # NOTE(review): predict_fun is called as (batch_len, 1, first input row);
    # the meaning of the middle argument is not visible here -- confirm.
    seed_row = numpy.expand_dims(xOpt[0], 0)
    generated = model.predict_fun(batch_len, 1, xIpt[0])
    noteStateMatrixToMidi(numpy.concatenate((seed_row, generated), axis=0),
                          'output/sample{}'.format(i))
    # `with` closes the file; the old code left the handle open.
    with open('output/params{}.p'.format(i), 'wb') as f:
        pickle.dump(model.learned_config, f)


def trainPiece(model, pieces, epochs, start=0):
    """Run ``epochs`` training steps on random batches drawn from ``pieces``.

    Iterations are numbered ``start`` .. ``start + epochs - 1``. During
    training, SIGINT (Ctrl-C) is intercepted: it only sets a flag, and the
    loop stops before starting the next iteration. The caller's previous
    SIGINT handler is restored on return, including when a step or save raises.

    Per iteration ``i``:
      * if ``i`` is a multiple of 100, print the value returned by
        ``model.update_fun``;
      * if ``i`` is a multiple of 500, or a multiple of 100 below 1000, write
        a sample to ``output/sample{i}`` and pickle ``model.learned_config``
        to ``output/params{i}.p``. ``open`` does not create the ``output``
        directory, so it must already exist.

    Args:
        model: object providing ``update_fun``, ``predict_fun`` and
            ``learned_config`` (as used here).
        pieces: dict of name -> note-state matrix, as built by ``loadPieces``.
        epochs: number of iterations to run.
        start: number of the first iteration (e.g. to continue numbering).
    """
    # A one-element list lets the nested handler set the flag without
    # `nonlocal`, which Python 2 lacks.
    stopflag = [False]

    def signal_handler(signum, frame):
        stopflag[0] = True

    old_handler = signal.signal(signal.SIGINT, signal_handler)
    try:
        for i in range(start, start + epochs):
            if stopflag[0]:
                break
            error = model.update_fun(*getPieceBatch(pieces))
            if i % 100 == 0:
                print("epoch {}, error={}".format(i, error))
            if i % 500 == 0 or (i % 100 == 0 and i < 1000):
                _writeSampleAndParams(model, pieces, i)
    finally:
        # Without this, an exception mid-training left Ctrl-C disabled for
        # the rest of the process.
        signal.signal(signal.SIGINT, old_handler)