-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer.py
107 lines (79 loc) · 3.17 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import math
import sys
from torch.autograd import Variable
from dataset import KshDataset
from torch.utils.data import DataLoader
from net.model import voltexNet
import music_processer as mp
import torch.nn.functional as F
from torch.optim import lr_scheduler
from tqdm import tqdm
import os
def _predict_probs(model, features, batch):
    """Return per-frame class probabilities as a list of Python lists.

    Every frame is scored, including a trailing partial batch, and softmax is
    applied over dim 1 of every batch so all rows are comparable with the
    same probability threshold. The model is switched to eval mode under
    ``torch.no_grad()`` for the pass and its previous train/eval mode is
    restored afterwards, so the caller's model is left as it was handed in.
    """
    was_training = model.training
    model.eval()
    rows = []
    try:
        with torch.no_grad():
            for start in range(0, features.shape[0], batch):
                logits = model(features[start:start + batch])
                rows.extend(torch.softmax(logits, dim=1).cpu().tolist())
    finally:
        model.train(was_training)
    return rows


def _pick_timestamps(probs, threshold, min_gap=2):
    """Turn per-frame class probabilities into note / FX onset frame indices.

    For each frame the most probable class (first one on ties) is taken:
    class 1 records a note, class 2 records an FX, class 3 records both, and
    any other class records nothing. A frame only fires when that top
    probability is strictly above ``threshold`` and at least ``min_gap``
    frames have passed since the previous firing frame. The very first
    firing frame is not gap-restricted.

    Returns:
        ``(note_timestamps, fx_timestamps)`` as lists of frame indices.
    """
    note_timestamps = []
    fx_timestamps = []
    last_fired = None
    for index, row in enumerate(probs):
        peak = max(row)
        label = row.index(peak)
        if label not in (1, 2, 3) or peak <= threshold:
            continue
        if last_fired is not None and index - last_fired < min_gap:
            continue
        if label in (1, 3):
            note_timestamps.append(index)
        if label in (2, 3):
            fx_timestamps.append(index)
        last_fired = index
    return note_timestamps, fx_timestamps


def infer(model, device, batch, filename, savename, *,
          audio_filename="./data_test/songs/badapple_nomico_alreco/nofx.ogg",
          threshold=0.2):
    """Predict note/FX onset frames for a song and render them as audio.

    Args:
        model: network called on slices of the loaded features; softmax is
            taken over dim 1 of its output, which is assumed to be
            (frames, classes) -- TODO confirm against voltexNet.
        device: torch device the loaded features are moved to.
        batch: number of frames passed to the model per forward call.
        filename: feature file handed to ``KshDataset.music_load``.
        savename: output path passed to ``song.save``.
        audio_filename: audio the detected onsets are synthesized over.
            Defaults to the path that used to be hard-coded here, so existing
            callers get the same audio; pass the song that matches
            ``filename`` when inferring on anything else.
        threshold: a frame fires only if its top class probability is
            strictly greater than this value.

    Returns:
        ``(note_timestamps, fx_timestamps)`` -- lists of frame indices.
    """
    features = KshDataset.music_load(filename)
    features = features.to(device, dtype=torch.float)

    probs = _predict_probs(model, features, batch)
    note_timestamps, fx_timestamps = _pick_timestamps(probs, threshold)

    song = mp.Audio(filename=audio_filename,
                    note_timestamp=note_timestamps,
                    fx_timestamp=fx_timestamps)
    song.synthesize(diff='ka')
    song.save(filename=savename)
    return note_timestamps, fx_timestamps
def main():
    """Load the trained checkpoint and render one inference example to disk."""
    # Choose the device before loading: without map_location, a checkpoint
    # saved from a CUDA device cannot be deserialized on a CPU-only machine.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = voltexNet()
    model.load_state_dict(torch.load("./model/model_278_.pth", map_location=device))
    model.to(device)

    batch = 256
    infer(model, device, batch,
          "./cache/badapple_nomico_alreco.npy", "./test_Output/infer3.wav")


# Run inference only when executed as a script, not when imported.
if __name__ == "__main__":
    main()