-
Notifications
You must be signed in to change notification settings - Fork 20
/
istftnettorch.cpp
78 lines (50 loc) · 1.69 KB
/
istftnettorch.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include "istftnettorch.h"
#include <windows.h>
bool iSTFTNetTorch::Initialize(const std::string &VocoderPath)
{
torch::Device device(torch::kCPU);
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
std::string VCP = VocoderPath + ".pt";
Model = torch::jit::load(VCP,device);
}
catch (const c10::Error& e) {
QMessageBox::critical(nullptr,"r",e.what_without_backtrace());
return false;
}
try{
std::string PostPath = VocoderPath + "-post.pt";
Post = torch::jit::load(PostPath,device);
PostLoaded = true;
}
catch (const c10::Error& e){
PostLoaded = false;
}
return true;
}
/// Runs the loaded vocoder module on a mel spectrogram.
///
/// @param InMel Input tensor. Its flat Data is reshaped to InMel.Shape and axes 1 and 2
///              are then swapped before the model is called.
/// @return The model output (after the optional Post module) copied into a
///         TFTensor<float>, or a default-constructed TFTensor<float> if the model call
///         throws a std::exception (a warning box is shown in that case).
TFTensor<float> iSTFTNetTorch::DoInference(const TFTensor<float> &InMel)
{
    // without this memory consumption is 4x
    torch::NoGradGuard no_grad;

    torch::Device device(torch::kCPU);

    // Built outside the try block: an exception raised here propagates to the caller.
    auto TorchMel = torch::tensor(InMel.Data,device).reshape(InMel.Shape).transpose(1,2); // [1, frames, n_mels] -> [1, n_mels, frames]

    try {
        at::Tensor Output = Model({TorchMel}).toTensor().squeeze(); // (audio frames)

        // Only run the second module if Initialize() managed to load "<path>-post.pt".
        if (PostLoaded)
            Output = Post({Output.unsqueeze(0).toType(at::ScalarType::Float)}).toTensor();

        return VoxUtil::CopyTensor<float>(Output);
    }
    catch (const std::exception& e) {
        // Keep the wide string alive in a named local for the duration of the call.
        const std::wstring Message = QString::fromStdString(e.what()).toStdWString();

        // NOTE(review): a Win32 box is used here rather than QMessageBox, presumably
        // because this function may be called off the GUI thread - confirm.
        // The user's choice is not acted upon, so a single OK button is offered.
        MessageBoxW(nullptr, Message.c_str(), L"Vocoder inference error", MB_ICONWARNING | MB_OK);

        return TFTensor<float>();
    }
}
/// Default constructor: does no work of its own; the modules are loaded by Initialize().
iSTFTNetTorch::iSTFTNetTorch() = default;