# temp.py
import os  # needed for os.path.join below
import pprint

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# config is expected to provide DATA_DIR, CONFIG, TRAIN_PATH and TRAIN_SPEC_PATH.
from config import *
from data_processing import UkrVoiceDataset, ukr_lang_chars_handle
from model import Conformer as con

device = "cpu"

# Load the trained Conformer checkpoint and switch to evaluation mode.
PATH = os.path.join(DATA_DIR, "model_1.pt")
model = con(n_encoders=CONFIG["n_encoders"], n_decoders=CONFIG["n_decoders"], device=device)
model.load_state_dict(torch.load(PATH, map_location=device))  # map checkpoint tensors to the chosen device
model.eval()

# Build the training dataset/dataloader and pull a single sample to inspect.
ds = UkrVoiceDataset(TRAIN_PATH, TRAIN_SPEC_PATH)
train_dataloader = DataLoader(ds, shuffle=True, batch_size=1)

with torch.no_grad():
    X, tgt = next(iter(train_dataloader))
    tgt = tgt["text"]
    X = X.to(device)
    print("Target:", tgt)
    print("X shape:", X.shape)

    # Encode the target sentence as one-hot vectors (vocabulary size 152).
    # tgt = ("",)
    tgt_one_hots = ukr_lang_chars_handle.sentences_to_one_hots(tgt, 152)
    print("tgt to one_hots shape:", tgt_one_hots.shape)
    print("tgt to one_hots:", ukr_lang_chars_handle.one_hots_to_sentences(tgt_one_hots))

    # Forward pass: the model returns two outputs; convert both to log-probabilities.
    emb, out_data = model(X, tgt_one_hots.to(device))
    emb = F.log_softmax(emb, dim=-1)
    emb = emb.cpu()
    out_data = F.log_softmax(out_data, dim=-1)
    out_data = out_data.cpu()

    print("\n\nOutput data shape:", out_data.shape)
    print("output:", out_data)

    # Decode out_data back into text for inspection.
    out_data = out_data.transpose(-1, -2).contiguous()
    result = ukr_lang_chars_handle.one_hots_to_sentences(out_data)
    pprint.pprint(len(result))
    pprint.pprint(result)
    pprint.pprint(out_data.shape)

    # Decode emb as well.
    print(emb.shape)
    result = ukr_lang_chars_handle.one_hots_to_sentences(emb)
    pprint.pprint(len(result))
    pprint.pprint(result)
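
# --- Optional greedy (argmax) inspection sketch (not part of the original script) ---
# Minimal illustrative addition: it assumes `emb` holds per-step log-probabilities
# with the 152-class vocabulary (the size passed to sentences_to_one_hots above)
# on the last dimension. It prints the raw argmax index sequence for each batch
# element, which can be compared against the decoded sentences printed above.
greedy_ids = emb.argmax(dim=-1)
print("Greedy argmax indices shape:", greedy_ids.shape)
pprint.pprint(greedy_ids.tolist())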