Singing Voice model

JuanFMontesinos committed Jan 3, 2023
1 parent dfd11c2 commit 330b849
Showing 6 changed files with 221 additions and 5 deletions.
2 changes: 1 addition & 1 deletion inference_artificial_mix.py
@@ -15,7 +15,7 @@
vovit.utils.np_int2float(read('demo_samples/3r23tdRALns/00029.wav')[1][:16384 * 2 - 1])).to(device)
audio1 /= audio1.abs().max()
mixture = (audio1 + audio2).unsqueeze(0) / 2
model = vovit.End2EndVoViT(model_name='VoViT_speech', debug={}).to(device)
model = vovit.SpeechVoViT().to(device)
model.eval()
with torch.no_grad():
pred = model(mixture, tgt_face)
8 changes: 7 additions & 1 deletion inference_interview.py
@@ -20,11 +20,17 @@

mixture = torch.from_numpy(read(f'{path}/audio.wav')[1]).to(device)

model = vovit.End2EndVoViT(model_name='VoViT_speech', extract_landmarks=compute_landmarks, debug={}).to(device)
print('Creating model instance...')
model = vovit.SpeechVoViT(extract_landmarks=compute_landmarks).to(device)
model.eval()
print('Done')

with torch.no_grad():
print('Forwarding speaker1...')
pred_s1 = model.forward_unlimited(mixture, speaker1_face)
print('Forwarding speaker2...')
pred_s2 = model.forward_unlimited(mixture, speaker2_face)

wav_s1 = pred_s1['ref_est_wav'].squeeze().cpu().numpy()
wav_s2 = pred_s2['ref_est_wav'].squeeze().cpu().numpy()
vd.plot_spectrogram(pred_s1['ref_est_sp'].squeeze(), 16384, 256, remove_labels=True)
37 changes: 37 additions & 0 deletions singing_voice_inference.py
@@ -0,0 +1,37 @@
import os
import torch

import vovit.display as vd
import vovit

from scipy.io.wavfile import write
import matplotlib.pyplot as plt

data_path = '/home/jfm/singing_voice_sep_demo/splits'
dst_path = 'demo_samples/singing_voice_sep_results'
device = 'cuda:0'


sampler = vd.singing_voice_demo.DemoDataLoader(25, 16384, data_path, vd.t_dict)
N = len(sampler)
model = vovit.SingingVoiceVoViT(debug={}).to(device)
model.eval()

for idx in range(N):
with torch.no_grad():
key, kwargs = next(sampler)
path = os.path.join(dst_path, key)
if not os.path.exists(path):
os.makedirs(path)
mixture = sampler.load_audio(key, **kwargs).to(device)
landmarks = sampler.load_landmarks(key, **kwargs).to(device)
outputs = model.forward_unlimited(mixture, landmarks)

# Dumping the results
wav = outputs['estimated_wav'].squeeze().cpu().numpy()
write(os.path.join(dst_path, key, 'estimated.wav'), 16384, wav)
estimated_sp = torch.view_as_complex(outputs['estimated_sp']).squeeze().cpu().numpy()
vd.plot_spectrogram(estimated_sp, 16384, 256, remove_labels=True)
plt.tight_layout()
plt.savefig(os.path.join(dst_path, key, 'estimated_sp.png'))
print(f'[{idx}/{N}], {key}')
71 changes: 68 additions & 3 deletions vovit/__init__.py
@@ -11,13 +11,13 @@
from . import utils


class End2EndVoViT(torch.nn.Module):
def __init__(self, *, model_name: str, debug: dict, pretrained: bool = True,
class SpeechVoViT(torch.nn.Module):
def __init__(self, debug: dict = {}, pretrained: bool = True,
extract_landmarks: bool = False, detect_faces: bool = False):
super().__init__()
self.extract_landmarks = extract_landmarks
self.detect_faces = detect_faces
self.vovit = VoViT(model_name=model_name, debug=debug, pretrained=pretrained)
self.vovit = VoViT(model_name='VoViT_speech', debug=debug, pretrained=pretrained)

if self.extract_landmarks:
from .core.landmark_estimator.TDDFA_GPU import TDDFA
@@ -96,3 +96,68 @@ def forward_unlimited(self, mixture, visuals):
if v.ndim == 2: # Waveforms
pred_unraveled[k] = v.flatten()
return pred_unraveled


class SingingVoiceVoViT(torch.nn.Module):
def __init__(self, *, debug: dict, pretrained: bool = True,
extract_landmarks: bool = False, detect_faces: bool = False):
super().__init__()
self.extract_landmarks = extract_landmarks
self.detect_faces = detect_faces
self.vovit = VoViT(model_name='vovit_singing_voice', debug=debug, pretrained=pretrained)

if self.extract_landmarks:
raise NotImplementedError('Landmark extraction is not implemented for singing voice')

def forward(self, mixture, visuals, extract_landmarks=False):
"""
:param mixture: torch.Tensor of shape (B,N)
:param visuals: landmark tensor of shape (B, T, 68, C), with C spatial coordinates per landmark (C = 2 here)
:return:
"""
if self.detect_faces:
raise NotImplementedError

ld = visuals

ld = rearrange(ld, 'b t j c ->b c t j').unsqueeze(-1).float()
mixture = cast_dtype(mixture, raise_error=True) # Cast integers to float
mixture /= mixture.abs().max()

return self.vovit(mixture, ld)

def forward_unlimited(self, mixture, visuals):
"""
Allows running inference on samples of unlimited duration (up to GPU memory constraints).
The results will be trimmed to multiples of 4 seconds (e.g. if your audio is 8.5 seconds long,
the result will be trimmed to 8 seconds)
Args:
visuals: raw video if self.extract_landmarks is True, precomputed landmarks otherwise.
landmarks are uint16 tensors of shape (T, 68, 2)
raw video is a uint8 RGB tensor of shape (T, H, W, 3) (values between 0-255)
mixture: tensor of shape (N)
"""
fps = VIDEO_FRAMERATE
length = self.vovit.avse.ap._audio_length
n_chunks = visuals.shape[0] // (fps * 4)
visuals = visuals[:n_chunks * fps * 4].view(n_chunks, fps * 4, 68, 2)
mixture = mixture[:n_chunks * length].view(n_chunks, -1)
pred = self.forward(mixture, visuals)
pred_unraveled = {}
for k, v in pred.items():
if v is None:
continue
if v.is_complex(): # Complex spectrogram
pred_unraveled[k] = rearrange(v, 'b f t -> f (b t)')
if v.ndim == 4: # Two-channels mask
idx = v.shape[1:].index(2) + 1
if idx == 1:
string = 'b c f t -> c f (b t)'
elif idx == 3:
string = 'b f t c-> f (b t) c'
else:
raise ValueError('Unknown shape')
pred_unraveled[k] = rearrange(v, string)
if v.ndim == 2: # Waveforms
pred_unraveled[k] = v.flatten()
return pred_unraveled
1 change: 1 addition & 0 deletions vovit/display/__init__.py
@@ -4,6 +4,7 @@
import librosa
import torch

from .singing_voice_demo import DemoDataLoader, t_dict

def plot_spectrogram(spectrogram, sr: int, hop_length,
title=None, remove_labels=False, remove_axis=False,
107 changes: 107 additions & 0 deletions vovit/display/singing_voice_demo.py
@@ -0,0 +1,107 @@
import os
import numpy as np
import torch

from scipy.io.wavfile import read


t_dict = {'pRh9rKd2j64_0_15_to_0_55': {'initial_time': 15, 'sample_name': 'pRh9rKd2j64_0_15_to_0_55', 'n': 3}, # n=3
'sEnTMgzw8ow_1_29_to_1_47': {'initial_time': 0, 'sample_name': 'lead_vocals', 'n': 1},
'sEnTMgzw8ow_1_5_to_2_07': {'initial_time': 0, 'sample_name': 'lead_vocals', 'n': 1},
'sEnTMgzw8ow_2_11_to_2_33': {'initial_time': 0, 'sample_name': 'lead_vocals', 'n': 1},
'sEnTMgzw8ow_2_38_to_2_53': {'initial_time': 0, 'sample_name': 'lead_vocals', 'n': 1},
'sEnTMgzw8ow_0_34_to_0_39': {'initial_time': 0, 'sample_name': 'lead_vocals', 'n': 1},
'Dyo7jzaCUhk_0_02_to_5_2': {'initial_time': 42, 'sample_name': '2_2', 'n': 3},
# 'cttFanV0o7c_0_07_to_2_44': {'initial_time': 46, 'sample_name': 'top_right','n':1}, # 44-onwards llcp fails 8s
'cttFanV0o7c_0_07_to_2_44': {'initial_time': 32, 'sample_name': 'bottom_left', 'n': 1},

# Separates good from acmt + better than audio
'vyu3HU3XWi4_0_3_to_0_4': {'initial_time': 0, 'sample_name': 'vyu3HU3XWi4_0_3_to_0_4', 'n': 1},
'vyu3HU3XWi4_2_04_to_2_14': {'initial_time': 0, 'sample_name': 'vyu3HU3XWi4_2_04_to_2_14', 'n': 2}, # n=2
'vyu3HU3XWi4_1_46_to_1_51': {'initial_time': 0, 'sample_name': 'vyu3HU3XWi4_1_46_to_1_51', 'n': 1}, # n=1
'it6Ud6PDPes_2_22_to_2_27': {'initial_time': 0, 'sample_name': 'lead_vocals', 'n': 1}, # n=1
'SNgnylGkerE_0_15_to_0_2': {'initial_time': 0, 'sample_name': 'male_voice', 'n': 1},
# n=1 # Unison appearance matters
'WikcPREx0DM_2_12_to_2_17': {'initial_time': 0, 'sample_name': 'beatbox', 'n': 1}, # n=1
'the_circle_of_life': {'initial_time': 2 * 60 + 55, 'sample_name': 'rafiki', 'n': 3}, # n=3
'q9vqt-lwy3I_0_29_to_0_36': {'initial_time': 2, 'sample_name': 'lead_vocals', 'n': 1}, # n=1
'q9vqt-lwy3I_1_33_to_1_43': {'initial_time': 2, 'sample_name': 'lead_vocals', 'n': 2}, # n=2
'Gayh_GrCKgU_5_11_to_5_35': {'initial_time': 2, 'sample_name': 'lead_vocals', 'n': 1},
# n=1, LLCP doesn't separate at all vocals not used
'BtuwsjeN7Pk_4_22_to_4_28': {'initial_time': 2, 'sample_name': 'lead_vocals', 'n': 1}, # n=1
'kce_zDH-OVA_0_43_to_0_5': {'initial_time': 3, 'sample_name': 'lead_vocals', 'n': 1}, # n=1
'kce_zDH-OVA_1_42_to_1_5': {'initial_time': 0, 'sample_name': 'kce_zDH-OVA_1_42_to_1_5', 'n': 2}, # n=2
'kce_zDH-OVA_2_09_to_2_3': {'initial_time': 8, 'sample_name': 'kce_zDH-OVA_2_09_to_2_3', 'n': 3}, # n=3
'hWCkCSO8h9I_0_4_to_0_45': {'initial_time': 0, 'sample_name': 'lead_vocals', 'n': 1}, # n=1
'6Ws1WKA4z2k_0_35_to_0_48': {'initial_time': 0, 'sample_name': 'lead_vocals', 'n': 3}, # n=3
}
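# Each entry appears to describe one demo excerpt: 'initial_time' is an offset in seconds into the
# stored frames/audio/landmarks, 'sample_name' picks the face/landmark track (or audio stem) within
# the clip folder, and 'n' seems to be the number of 4-second chunks to load (see DemoDataLoader below).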

class DemoDataLoader:
def __init__(self, framerate: int, audiorate: int, data_path: str, dictionary={}):
self.fps = framerate
self.arate = audiorate
self.data_path = data_path
assert os.path.exists(data_path), f'The directory {data_path} does not exist'
self.core = dictionary
self.generator = self._generator()

def av_faces(self, video_id):
path = os.path.join(self.data_path, 'frames', video_id)
return os.listdir(path)

def load(self, video_id, sample_name, initial_time, elements, n):
output = {}
for el in elements:
loader = getattr(self, f'load_{el}')
key = el if el != 'audio' else 'mixture'
output[key] = loader(video_id, sample_name, initial_time, n)
return output

def load_frames(self, video_id, sample_name, initial_time, n):
video_path = os.path.join(self.data_path, 'frames', video_id, sample_name) + '.npy'
video = np.load(video_path)
video = video[initial_time * self.fps:initial_time * self.fps + self.fps * 4 * n]
return video

def load_audio(self, video_id, sample_name, initial_time, n, reshape=False):
audio_path = os.path.join(self.data_path, 'audio', video_id) + '.wav'
audio = read(audio_path)[1][initial_time * self.arate:initial_time * self.arate + (self.arate * 4 - 1) * n]
audio = torch.from_numpy(audio)
audio = audio / audio.abs().max()
if reshape:
return audio.view(n, -1)
return audio

def load_landmarks(self, video_id, sample_name, initial_time, n, reshape=False):
landmarks_path = os.path.join(self.data_path, 'landmarks', video_id, sample_name) + '.npy'
landmarks = np.load(landmarks_path)[initial_time * self.fps:initial_time * self.fps + self.fps * 4 * n]

landmarks = torch.from_numpy(landmarks)

if not reshape:
return landmarks
landmarks = landmarks.reshape(n, -1, *landmarks.shape[1:])
landmarks = landmarks.permute(0, 3, 1, 2)
return landmarks.unsqueeze(-1).float()

def get_sample(self, *args):
key, kwargs = next(self)
inputs = self.load(video_id=key, **kwargs, elements=args)
return inputs, (key, kwargs)

def _generator(self):
for key, items in self.core.items():
yield key, items

def __len__(self):
return len(self.core)

def __iter__(self):
return self

def __next__(self):
try:
return next(self.generator)
except StopIteration:
self.generator = self._generator()
raise StopIteration
