Skip to content

Commit

Permalink
add option to generate continuation from input audio
Browse files Browse the repository at this point in the history
  • Loading branch information
zhvng committed Apr 18, 2023
1 parent 175be81 commit 4ee53d9
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 10 deletions.
84 changes: 74 additions & 10 deletions open_musiclm/open_musiclm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
eval_decorator, exists, float32_to_int16,
generate_mask_with_prob, get_embeds, gumbel_sample,
int16_to_float32, mask_out_after_eos_id,
round_down_nearest_multiple, top_k)
round_down_nearest_multiple, top_k, prepare_audio)


@dataclass
Expand Down Expand Up @@ -855,6 +855,7 @@ def forward(
*,
text: Optional[List[str]] = None,
prime_wave=None,
prime_wave_sample_hz=None,
output_seconds=8,
semantic_window_seconds=10,
coarse_window_seconds=4,
Expand All @@ -871,8 +872,53 @@ def forward(

clap_token_ids = get_or_compute_clap_token_ids(None, self.clap, conditioning_audio=None, conditioning_text=text)

# compute everything we need for audio continuation

all_audio_condition_coarse_token_ids = None
all_audio_condition_fine_token_ids = None
audio_condition_semantic_token_ids = None
audio_condition_coarse_token_ids = None
audio_condition_fine_token_ids = None
semantic_token_adjustment = 0 # used to crop generated semantic tokens so first sequence lines up with first coarse tokens sequence
coarse_token_adjustment = 0
fine_token_adjustment = 0
if exists(prime_wave):
assert exists(prime_wave_sample_hz)
prime_wave_wav2vec = prepare_audio(
prime_wave,
prime_wave_sample_hz,
self.wav2vec.target_sample_hz,
normalize=True,
target_length_seconds=semantic_window_seconds)
prime_wave_encodec = prepare_audio(
prime_wave,
prime_wave_sample_hz,
self.neural_codec.sample_rate,
normalize=False,
target_length_seconds=semantic_window_seconds)

condition_semantic_token_ids = get_or_compute_semantic_token_ids(None, prime_wave_wav2vec, self.wav2vec)
condition_coarse_token_ids, condition_fine_token_ids = get_or_compute_acoustic_token_ids(None, None, prime_wave_encodec, self.neural_codec, self.coarse.transformer_wrapper.token_sequences[2].num_quantizers)
condition_semantic_length = int(semantic_steps_per_second * semantic_window_seconds * (1 - semantic_sliding_window_step_percent))
condition_coarse_length = int(acoustic_steps_per_second * coarse_window_seconds * (1 - coarse_sliding_window_step_percent))
condition_fine_length = int(acoustic_steps_per_second * fine_window_seconds * (1 - fine_sliding_window_step_percent))

all_audio_condition_coarse_token_ids = condition_coarse_token_ids
all_audio_condition_fine_token_ids = condition_fine_token_ids

audio_condition_semantic_token_ids = condition_semantic_token_ids[:, -condition_semantic_length:] if condition_semantic_token_ids.shape[1] >= condition_semantic_length else condition_semantic_token_ids
audio_condition_coarse_token_ids = condition_coarse_token_ids[:, -condition_coarse_length:]
audio_condition_fine_token_ids = condition_fine_token_ids[:, -condition_fine_length:] if condition_fine_length > 0 else None

semantic_token_adjustment = condition_semantic_length - int(semantic_steps_per_second * coarse_window_seconds * (1 - coarse_sliding_window_step_percent))
coarse_token_adjustment = condition_coarse_length - int(acoustic_steps_per_second * fine_window_seconds * (1 - fine_sliding_window_step_percent))
fine_token_adjustment = condition_fine_length

# semantic stage

all_semantic_token_ids = self.semantic.generate(
clap_token_ids=clap_token_ids,
semantic_token_ids=audio_condition_semantic_token_ids,
max_time_steps=int(min(output_seconds, semantic_window_seconds) * semantic_steps_per_second),
include_eos_in_output=False,
append_eos_to_conditioning_tokens=True,
Expand All @@ -891,18 +937,23 @@ def forward(
pred_semantic_token_ids = pred_semantic_token_ids[:, condition_length:]
all_semantic_token_ids = torch.cat([all_semantic_token_ids, pred_semantic_token_ids], dim=1)

# sliding windows of coarse window size
# crop semantic tokens to line up with coarse tokens
all_semantic_token_ids = all_semantic_token_ids[:, semantic_token_adjustment:]

# coarse stage

window_size = int(coarse_window_seconds * semantic_steps_per_second - 1)
step_size = int(window_size * coarse_sliding_window_step_percent)
all_semantic_token_ids = all_semantic_token_ids.unfold(1, window_size, step_size)
all_semantic_token_ids = rearrange(all_semantic_token_ids, 'b n q w -> n b w q')

all_coarse_token_ids = None
for semantic_token_ids in all_semantic_token_ids:
# TODO: pad to coarse_window_seconds if needed

condition_length = int(coarse_window_seconds * acoustic_steps_per_second * (1 - coarse_sliding_window_step_percent))
condition_coarse_token_ids = all_coarse_token_ids[:, -condition_length:] if exists(all_coarse_token_ids) else None
if exists(all_coarse_token_ids):
condition_length = int(coarse_window_seconds * acoustic_steps_per_second * (1 - coarse_sliding_window_step_percent))
condition_coarse_token_ids = all_coarse_token_ids[:, -condition_length:]
else:
condition_coarse_token_ids = audio_condition_coarse_token_ids

pred_coarse_token_ids = self.coarse.generate(
clap_token_ids=clap_token_ids,
Expand All @@ -926,18 +977,24 @@ def forward(
wave = rearrange(wave, 'b 1 n -> b n')
return wave

# crop to fine window length and iterate
# crop coarse tokens to line up with fine tokens
all_coarse_token_ids = all_coarse_token_ids[:, coarse_token_adjustment:]

# fine stage

fine_window_size = int(fine_window_seconds * acoustic_steps_per_second)
fine_step_size = int(fine_window_size * fine_sliding_window_step_percent)
all_coarse_token_ids_unfolded = all_coarse_token_ids.unfold(1, fine_window_size, fine_step_size)
all_coarse_token_ids_unfolded = rearrange(all_coarse_token_ids_unfolded, 'b n q w -> n b w q')

all_fine_token_ids = None
for coarse_token_ids in all_coarse_token_ids_unfolded:
if exists(all_fine_token_ids):
condition_length = int(fine_window_size * (1 - fine_sliding_window_step_percent))
condition_fine_token_ids = all_fine_token_ids[:, -condition_length:] if condition_length > 0 else None
else:
condition_fine_token_ids = audio_condition_fine_token_ids

condition_length = int(fine_window_size * (1 - fine_sliding_window_step_percent))
condition_fine_token_ids = all_fine_token_ids[:, -condition_length:] if exists(
all_fine_token_ids) and condition_length > 0 else None
pred_fine_token_ids = self.fine.generate(
clap_token_ids=clap_token_ids,
coarse_token_ids=coarse_token_ids,
Expand All @@ -954,6 +1011,13 @@ def forward(
pred_fine_token_ids = pred_fine_token_ids[:, condition_length:]
all_fine_token_ids = torch.cat([all_fine_token_ids, pred_fine_token_ids], dim=1)

# crop fine tokens to remove conditioning audio
all_fine_token_ids = all_fine_token_ids[:, fine_token_adjustment:]

if exists(all_audio_condition_coarse_token_ids) and exists(all_audio_condition_fine_token_ids):
all_fine_token_ids = torch.cat([all_audio_condition_fine_token_ids, all_fine_token_ids], dim=1)
all_coarse_token_ids = torch.cat([all_audio_condition_coarse_token_ids, all_coarse_token_ids], dim=1)

all_acoustic_token_ids = torch.cat([all_coarse_token_ids, all_fine_token_ids], dim=-1)
wave = self.neural_codec.decode_from_codebook_indices(all_acoustic_token_ids)
wave = rearrange(wave, 'b 1 n -> b n')
Expand Down
12 changes: 12 additions & 0 deletions open_musiclm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
import shutil
import os
from torchaudio.functional import resample

from einops import rearrange, repeat, reduce

Expand Down Expand Up @@ -153,6 +154,17 @@ def float32_to_int16(x):
def zero_mean_unit_var_norm(x):
    """Normalize x along its last dimension to zero mean and (near-)unit variance.

    A small epsilon inside the square root guards against division by zero
    when a row is constant.
    """
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True)
    return (x - mu) / (var + 1e-7).sqrt()

def prepare_audio(data, sample_hz, target_sample_hz, normalize=True, target_length_seconds=None):
    """Prepare a raw waveform for a downstream audio tokenizer.

    Steps, in order:
      1. downmix multi-channel audio to mono by averaging channels,
      2. optionally normalize to zero mean / unit variance,
      3. crop to at most ``target_length_seconds`` (keeps the beginning of
         the clip — NOTE(review): for continuation conditioning one might
         expect the end; confirm this is intentional),
      4. resample from ``sample_hz`` to ``target_sample_hz``,
      5. round-trip through int16 to mimic 16-bit quantization applied
         elsewhere in the pipeline.

    Args:
        data: waveform tensor of shape (channels, samples).
        sample_hz: sample rate of ``data``.
        target_sample_hz: sample rate expected by the tokenizer.
        normalize: apply zero-mean / unit-variance normalization when True.
        target_length_seconds: optional maximum clip length before resampling.

    Returns:
        Mono waveform tensor at ``target_sample_hz``.
    """
    if data.shape[0] > 1:
        # average channels down to a single mono channel
        data = data.mean(dim=0, keepdim=True)
    if normalize:
        data = zero_mean_unit_var_norm(data)
    if exists(target_length_seconds):
        max_samples = int(target_length_seconds * sample_hz)
        if data.shape[1] > max_samples:
            data = data[:, :max_samples]
    resampled = resample(data, sample_hz, target_sample_hz)
    return int16_to_float32(float32_to_int16(resampled))

# helper for saving config

def copy_file_to_folder(file_path: str, folder_path: str):
Expand Down
9 changes: 9 additions & 0 deletions scripts/infer_top_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
parser.add_argument('prompt', help='prompts to generate audio for', type=str, nargs='+')
parser.add_argument('--num_samples', default=4, type=int)
parser.add_argument('--num_top_matches', default=1, type=int)
parser.add_argument('--input_audio', default=None, type=str, help='input audio to condition on and generate continuations from')
parser.add_argument('--model_config', default='./configs/model/musiclm_small.json', help='path to model config')
parser.add_argument('--semantic_path', required=True, help='path to semantic stage checkpoint')
parser.add_argument('--coarse_path', required=True, help='path to coarse stage checkpoint')
Expand All @@ -35,6 +36,7 @@
semantic_path = args.semantic_path
coarse_path = args.coarse_path
fine_path = args.fine_path
input_audio = args.input_audio
return_coarse_wave = args.return_coarse_wave
duration = args.duration
kmeans_path = args.kmeans_path
Expand All @@ -59,8 +61,15 @@

print(f'prompt: {args.prompt}')

prime_wave, prime_wave_sample_hz = None, None
if input_audio is not None:
prime_wave, prime_wave_sample_hz = torchaudio.load(input_audio)
prime_wave = prime_wave.to(device)

generated_wave, similarities = musiclm.generate_top_match(
text=args.prompt,
prime_wave=prime_wave,
prime_wave_sample_hz=prime_wave_sample_hz,
num_samples=args.num_samples,
num_top_matches=args.num_top_matches,
output_seconds=duration,
Expand Down

0 comments on commit 4ee53d9

Please sign in to comment.