diff --git a/requirements-arm64.txt b/requirements-arm64.txt
index c0c8b79..6935ec9 100644
--- a/requirements-arm64.txt
+++ b/requirements-arm64.txt
@@ -49,7 +49,7 @@ PyYAML>=4.2b1
 rsa==4.7
 scipy<1.11.0
 scikit-learn<1.2.0
-setuptools>=41.0.0
+setuptools>=41.0.0,<65.0.0
 six~=1.15.0
 tensorflow-macos~=2.12.0
 termcolor==1.1.0
diff --git a/requirements-stretch.txt b/requirements-stretch.txt
index 728c565..8f97d64 100644
--- a/requirements-stretch.txt
+++ b/requirements-stretch.txt
@@ -1 +1,2 @@
-aeneas~=1.7.3.0
\ No newline at end of file
+aeneas~=1.7.3.0
+dtw-python~=1.5.3
\ No newline at end of file
diff --git a/subaligner/__main__.py b/subaligner/__main__.py
index cf0a58d..804d84d 100755
--- a/subaligner/__main__.py
+++ b/subaligner/__main__.py
@@ -270,6 +270,7 @@ def main():
     if FLAGS.stretch_on or FLAGS.mode == "script":
         try:
             import aeneas
+            import dtw
         except ModuleNotFoundError:
             print('ERROR: Alignment has been configured to use extra features. Please install "subaligner[stretch]" and run your command again.')
             sys.exit(21)
diff --git a/subaligner/lib/language.py b/subaligner/lib/language.py
index b2e2ca8..a828f0f 100644
--- a/subaligner/lib/language.py
+++ b/subaligner/lib/language.py
@@ -406,3 +406,64 @@ class Language(object):
 
     CODE_TO_HUMAN_LIST = sorted([u"%s\t%s" % (k, v) for k, v in CODE_TO_HUMAN.items()])
     """ List of all language codes with their human-readable names """
+
+    # Map from language code to the voice code passed to "espeak -v" when synthesising cue text
+    LANGUAGE_TO_VOICE_CODE = {
+        AFR: "af",
+        ARG: "an",
+        BOS: "bs",
+        BUL: "bg",
+        CAT: "ca",
+        CES: "cs",
+        CMN: "zh",
+        CYM: "cy",
+        DAN: "da",
+        DEU: "de",
+        ELL: "el",
+        ENG: "en",
+        EPO: "eo",
+        EST: "et",
+        FAS: "fa",
+        FIN: "fi",
+        FRA: "fr",
+        GLE: "ga",
+        GRC: "grc",
+        HIN: "hi",
+        HRV: "hr",
+        HUN: "hu",
+        HYE: "hy",
+        IND: "id",
+        ISL: "is",
+        ITA: "it",
+        JBO: "jbo",
+        KAN: "kn",
+        KAT: "ka",
+        KUR: "ku",
+        LAT: "la",
+        LAV: "lv",
+        LFN: "lfn",
+        LIT: "lt",
+        MAL: "ml",
+        MKD: "mk",
+        MSA: "ms",
+        NEP: "ne",
+        NLD: "nl",
+        NOR: "no",
+        PAN: "pa",
+        POL: "pl",
+        POR: "pt",
+        RON: "ro",
+        RUS: "ru",
+        SLK: "sk",
+        SPA: "es",
+        SQI: "sq",
+        SRP: "sr",
+        SWA: "sw",
+        SWE: "sv",
+        TAM: "ta",
+        TUR: "tr",
+        UKR: "uk",
+        VIE: "vi",
+        YUE: "zh-yue",
+        ZHO: "zh",
+    }
diff --git a/subaligner/predictor.py b/subaligner/predictor.py
index 394fb6c..323780b 100644
--- a/subaligner/predictor.py
+++ b/subaligner/predictor.py
@@ -6,19 +6,23 @@
 import gc
 import math
 import logging
+import tempfile
+import librosa
 import numpy as np
 import multiprocessing as mp
+import soundfile as sf
 from typing import Tuple, List, Optional, Dict, Any, Iterable, Union
+from copy import deepcopy
 from pysrt import SubRipTime, SubRipItem, SubRipFile
 from sklearn.metrics import log_loss
-from copy import deepcopy
 from .network import Network
 from .embedder import FeatureEmbedder
 from .media_helper import MediaHelper
 from .subtitle import Subtitle
 from .hyperparameters import Hyperparameters
-from .exception import TerminalException
-from .exception import NoFrameRateException
+from .lib.language import Language
+from .utils import Utils
+from .exception import TerminalException, NoFrameRateException
 from .logger import Logger
 
 
@@ -445,7 +449,7 @@ def _predict_in_multithreads(
             gc.collect()
 
             if stretch:
-                subs_new = self.__adjust_durations(subs_new, audio_file_path, stretch_in_lang, lock)
+                subs_new = self.__compress_and_stretch(subs_new, audio_file_path, stretch_in_lang, lock)
                 self.__LOGGER.info("[{}] Segment {} stretched".format(os.getpid(), segment_index))
             return subs_new
         except Exception as e:
@@ -715,6 +719,129 @@ def __adjust_durations(self, subs: List[SubRipItem], audio_file_path: str, stret
             if task.sync_map_file_path_absolute is not None and os.path.exists(task.sync_map_file_path_absolute):
                 os.remove(task.sync_map_file_path_absolute)
 
+    def __compress_and_stretch(self, subs: List[SubRipItem], audio_file_path: str, stretch_in_lang: str, lock: threading.RLock) -> List[SubRipItem]:
+        """Re-time the cues of one segment by aligning synthesised speech with the segment audio.
+
+        Every cue text is synthesised with espeak, the MFCCs of the concatenated synthetic speech are
+        aligned with the MFCCs of the extracted audio segment by dynamic time warping, and each cue is
+        given the time span its synthetic frames are warped onto.
+
+        Arguments:
+            subs {list} -- The cues of one segment.
+            audio_file_path {string} -- The path to the audio file the segment is extracted from.
+            stretch_in_lang {string} -- The language code used to look up the espeak voice.
+            lock {threading.RLock} -- The lock held while the audio segment is extracted.
+
+        Returns:
+            list -- The re-timed cues carrying the indices of the input cues.
+
+        Raises:
+            TerminalException -- Thrown when interrupted by the user.
+            KeyError -- Thrown when no espeak voice is mapped to the language code.
+            ValueError -- Thrown when the alignment indices or the sample rates are inconsistent.
+            subprocess.CalledProcessError -- Thrown when espeak or ffmpeg exits with an error.
+        """
+        import subprocess
+        from dtw import dtw
+
+        if not subs:
+            return subs
+
+        # Fail fast on an unsupported language, before any audio is extracted
+        voice_code = Language.LANGUAGE_TO_VOICE_CODE[stretch_in_lang]
+        try:
+            with lock:
+                # NOTE(review): the extracted segment is not removed here — confirm the caller owns it
+                segment_path, _ = self.__media_helper.extract_audio_from_start_to_end(
+                    audio_file_path,
+                    str(subs[0].start),
+                    str(subs[len(subs) - 1].end),
+                )
+
+            sample_rate = self.__feature_embedder.frequency
+            hop_length = self.__feature_embedder.hop_len
+            n_mfcc = self.__feature_embedder.n_mfcc
+
+            script_lines = [sub.text for sub in subs]
+            with tempfile.TemporaryDirectory() as temp_dir:
+                wav_data = []
+                durations = []
+                for i, line in enumerate(script_lines):
+                    espeak_output_path = os.path.join(temp_dir, f"espeak_part_{i}.wav")
+                    # Argument lists keep the cue text away from a shell, so it cannot inject commands
+                    synthesised = subprocess.run(
+                        ["espeak", "-v", voice_code, "--stdout", "--", line],
+                        stdout=subprocess.PIPE,
+                        check=True,
+                    )
+                    subprocess.run(
+                        ["ffmpeg", "-y", "-i", "-", "-af", f"aresample={sample_rate}", espeak_output_path],
+                        input=synthesised.stdout,
+                        check=True,
+                    )
+                    y, sr = librosa.load(espeak_output_path, sr=None)
+                    wav_data.append(y)
+                    durations.append(librosa.get_duration(y=y, sr=sr))
+                data = np.concatenate(wav_data)
+                sf.write(os.path.join(temp_dir, "espeak-all.wav"), data, sr)
+
+                y_query, sr_query = librosa.load(os.path.join(temp_dir, "espeak-all.wav"), sr=None)
+                query_mfcc_features = librosa.feature.mfcc(y=y_query, sr=sr_query, n_mfcc=n_mfcc, hop_length=hop_length).T
+                y_reference, sr_reference = librosa.load(segment_path, sr=sample_rate)
+                reference_mfcc_features = librosa.feature.mfcc(y=y_reference, sr=sr_reference, n_mfcc=n_mfcc, hop_length=hop_length).T
+
+                alignment = dtw(query_mfcc_features, reference_mfcc_features, keep_internals=False)
+                # Explicit checks instead of asserts, which are stripped under "python -O"
+                if len(alignment.index1) != len(alignment.index2):
+                    raise ValueError("Mismatch in lengths of alignment indices")
+                if sr_query != sr_reference:
+                    raise ValueError("Mismatch in sample rates of the synthesised and the reference audio")
+                frame_duration = hop_length / sr_query
+
+                query_indices = np.asarray(alignment.index1)
+                reference_indices = np.asarray(alignment.index2)
+                mapped_times = []
+                query_start_frame = 0
+                for index, duration in enumerate(durations):
+                    query_end_frame = query_start_frame + int(np.ceil(duration / frame_duration)) - 1
+                    in_cue = (query_indices >= query_start_frame) & (query_indices <= query_end_frame)
+                    mapped_frames = reference_indices[in_cue]
+                    if mapped_frames.size > 0:
+                        new_reference_start_time = int(mapped_frames.min()) * frame_duration
+                        new_reference_end_time = (int(mapped_frames.max()) + 1) * frame_duration
+                    else:
+                        # No frame of this cue is on the warping path: keep its timing relative to the segment
+                        new_reference_start_time = (subs[index].start - subs[0].start).ordinal / 1000.0
+                        new_reference_end_time = (subs[index].end - subs[0].start).ordinal / 1000.0
+                    mapped_times.append((new_reference_start_time, new_reference_end_time))
+                    query_start_frame = query_end_frame + 1
+
+                synced_srt_path = os.path.join(temp_dir, "synced_subtitles.srt")
+                # Written with the encoding the file is read back with below
+                with open(synced_srt_path, "w", encoding="utf-8") as f:
+                    for index, (start_time, end_time) in enumerate(mapped_times):
+                        f.write(f"{index + 1}\n")
+                        f.write(f"{Utils.format_timestamp(start_time)} --> {Utils.format_timestamp(end_time)}\n")
+                        f.write(f"{script_lines[index]}\n")
+                        f.write("\n")
+
+                adjusted_subs = Subtitle._get_srt_subs(
+                    subrip_file_path=synced_srt_path,
+                    encoding="utf-8"
+                )
+
+            for sub_new_loaded, sub in zip(adjusted_subs, subs):
+                sub_new_loaded.index = sub.index
+
+            adjusted_subs.shift(
+                seconds=self.__media_helper.get_duration_in_seconds(
+                    start=None, end=str(subs[0].start)
+                )
+            )
+            return adjusted_subs
+        except KeyboardInterrupt:
+            raise TerminalException("Subtitle compress and stretch interrupted by the user")
+
     def __predict(
         self,
         video_file_path: Optional[str],
diff --git a/subaligner/subaligner_2pass/__main__.py b/subaligner/subaligner_2pass/__main__.py
index d4271a3..1e4e7e7 100755
--- a/subaligner/subaligner_2pass/__main__.py
+++ b/subaligner/subaligner_2pass/__main__.py
@@ -171,6 +171,7 @@ def main():
     if FLAGS.stretch_on:
         try:
             import aeneas
+            import dtw
         except ModuleNotFoundError:
             print('ERROR: Alignment has been configured to use extra features. Please install "subaligner[stretch]" and run your command again.')
             sys.exit(21)