From 77dd7a91deccce878281f6e2e9976c7bcca70a25 Mon Sep 17 00:00:00 2001 From: baxtree Date: Mon, 17 Jun 2024 09:55:45 +0100 Subject: [PATCH] make various timeouts configurable via CLIs --- subaligner/__main__.py | 30 +++- subaligner/_version.py | 2 +- subaligner/embedder.py | 51 +++---- subaligner/hparam_tuner.py | 29 ++-- subaligner/hyperparameters.py | 15 +- subaligner/media_helper.py | 29 ++-- subaligner/network.py | 50 +++---- subaligner/predictor.py | 64 ++++---- subaligner/subaligner_1pass/__main__.py | 27 +++- subaligner/subaligner_2pass/__main__.py | 28 +++- subaligner/subaligner_train/__main__.py | 27 +++- subaligner/subtitle.py | 94 ++++++------ subaligner/trainer.py | 42 +++--- subaligner/transcriber.py | 18 +-- subaligner/translator.py | 23 ++- subaligner/utils.py | 28 ++-- tests/integration/feature/subaligner.feature | 13 ++ .../feature/subaligner_train.feature | 138 +++++++++--------- tests/integration/radish/step.py | 20 +++ tests/subaligner/test_predictor.py | 6 +- tests/subaligner/test_trainer.py | 5 +- 21 files changed, 414 insertions(+), 325 deletions(-) diff --git a/subaligner/__main__.py b/subaligner/__main__.py index a0c8aad..95b84dd 100755 --- a/subaligner/__main__.py +++ b/subaligner/__main__.py @@ -4,15 +4,15 @@ [-sil {afr,amh,ara,arg,asm,aze,ben,bos,bul,cat,ces,cmn,cym,dan,deu,ell,eng,epo,est,eus,fas,fin,fra,gla,gle,glg,grc,grn,guj,heb,hin,hrv,hun,hye,ina,ind,isl,ita,jbo,jpn,kal,kan,kat,kir,kor,kur,lat,lav,lfn,lit,mal,mar,mkd,mlt,msa,mya,nah,nep,nld,nor,ori,orm,pan,pap,pol,por,ron,rus,sin,slk,slv,spa,sqi,srp,swa,swe,tam,tat,tel,tha,tsn,tur,ukr,urd,vie,yue,zho}] [-fos] [-tod TRAINING_OUTPUT_DIRECTORY] [-o OUTPUT] [-t TRANSLATE] [-os OFFSET_SECONDS] [-ml {afr,amh,ara,arg,asm,aze,ben,bos,bul,cat,ces,cmn,cym,dan,deu,ell,eng,epo,est,eus,fas,fin,fra,gla,gle,glg,grc,grn,guj,heb,hin,hrv,hun,hye,ina,ind,isl,ita,jbo,jpn,kal,kan,kat,kir,kor,kur,lat,lav,lfn,lit,mal,mar,mkd,mlt,msa,mya,nah,nep,nld,nor,ori,orm,pan,pap,pol,por,ron,rus,sin,slk,slv,spa,sqi,srp,swa,swe,tam,tat,tel,tha,tsn,tur,ukr,urd,vie,yue,zho}] - [-mr {whisper}] [-mf {tiny,tiny.en,small,medium,medium.en,base,base.en,large-v1,large-v2,large-v3,large}] [-tr {helsinki-nlp,whisper,facebook-mbart}] - [-tf TRANSLATION_FLAVOUR] [-lgs] [-d] [-q] [-ver] + [-mr {whisper}] [-mf {tiny,tiny.en,small,medium,medium.en,base,base.en,large-v1,large-v2,large-v3,large}] [-tr {helsinki-nlp,whisper,facebook-mbart}] [-tf TRANSLATION_FLAVOUR] + [-mpt MEDIA_PROCESS_TIMEOUT] [-sat SEGMENT_ALIGNMENT_TIMEOUT] [-lgs] [-d] [-q] [-ver] -Subaligner command line interface +Subaligner command line interface (v0.3.7) -optional arguments: +options: -h, --help show this help message and exit -s SUBTITLE_PATH [SUBTITLE_PATH ...], --subtitle_path SUBTITLE_PATH [SUBTITLE_PATH ...] - File path or URL to the subtitle file (Extensions of supported subtitles: .ttml, .ssa, .stl, .sbv, .dfxp, .srt, .txt, .ytt, .vtt, .sub, .sami, .xml, .scc, .ass, .smi, .tmp) or selector for the embedded subtitle (e.g., embedded:page_num=888 or embedded:stream_index=0) + File path or URL to the subtitle file (Extensions of supported subtitles: .scc, .tmp, .sami, .stl, .ttml, .dfxp, .srt, .ssa, .ass, .sub, .sbv, .xml, .ytt, .smi, .txt, .vtt) or selector for the embedded subtitle (e.g., embedded:page_num=888 or embedded:stream_index=0) -l MAX_LOGLOSS, --max_logloss MAX_LOGLOSS Max global log loss for alignment -so, --stretch_on Switch on stretch on subtitles) @@ -38,6 +38,10 @@ LLM recipe used for translating subtitles -tf TRANSLATION_FLAVOUR, --translation_flavour TRANSLATION_FLAVOUR Flavour variation for a specific LLM recipe supporting translation + -mpt MEDIA_PROCESS_TIMEOUT, --media_process_timeout MEDIA_PROCESS_TIMEOUT + Maximum waiting time in seconds when processing media files + -sat SEGMENT_ALIGNMENT_TIMEOUT, --segment_alignment_timeout SEGMENT_ALIGNMENT_TIMEOUT + Maximum waiting time in seconds when aligning each segment -lgs, --languages Print out language codes used for stretch and translation -d, --debug Print out debugging information -q, --quiet Switch off logging information @@ -191,6 +195,20 @@ def main(): default=None, help="Flavour variation for a specific LLM recipe supporting translation" ) + parser.add_argument( + "-mpt", + "--media_process_timeout", + type=int, + default=180, + help="Maximum waiting time in seconds when processing media files" + ) + parser.add_argument( + "-sat", + "--segment_alignment_timeout", + type=int, + default=60, + help="Maximum waiting time in seconds when aligning each segment" + ) parser.add_argument("-lgs", "--languages", action="store_true", help="Print out language codes used for stretch and translation") parser.add_argument("-d", "--debug", action="store_true", @@ -301,7 +319,7 @@ def main(): sys.exit(21) voice_probabilities = None - predictor = Predictor() + predictor = Predictor(media_process_timeout=FLAGS.media_process_timeout, segment_alignment_timeout=FLAGS.segment_alignment_timeout) if FLAGS.mode == "single": aligned_subs, audio_file_path, voice_probabilities, frame_rate = predictor.predict_single_pass( video_file_path=local_video_path, diff --git a/subaligner/_version.py b/subaligner/_version.py index f107c73..a1d47ed 100644 --- a/subaligner/_version.py +++ b/subaligner/_version.py @@ -1,2 +1,2 @@ """The semver for the current release.""" -__version__ = "0.3.6" +__version__ = "0.3.7" diff --git a/subaligner/embedder.py b/subaligner/embedder.py index 0f47c10..d2dce2d 100644 --- a/subaligner/embedder.py +++ b/subaligner/embedder.py @@ -11,6 +11,13 @@ class FeatureEmbedder(object): """Audio and subtitle feature embedding. + + Keyword Arguments: + n_mfcc {int} -- The number of MFCC components (default: {13}). + frequency {float} -- The sample rate (default: {16000}). + hop_len {int} -- The number of samples per frame (default: {512}). + step_sample {float} -- The space (in seconds) between the beginning of each sample (default: 1s / 25 FPS = 0.04s). + len_sample {float} -- The length in seconds for the input samples (default: {0.075}). """ def __init__( @@ -21,16 +28,6 @@ def __init__( step_sample: float = 0.04, len_sample: float = 0.075, ) -> None: - """Feature embedder initialiser. - - Keyword Arguments: - n_mfcc {int} -- The number of MFCC components (default: {13}). - frequency {float} -- The sample rate (default: {16000}). - hop_len {int} -- The number of samples per frame (default: {512}). - step_sample {float} -- The space (in seconds) between the beginning of each sample (default: 1s / 25 FPS = 0.04s). - len_sample {float} -- The length in seconds for the input samples (default: {0.075}). - """ - self.__n_mfcc = n_mfcc # number of MFCC components self.__frequency = frequency # sample rate self.__hop_len = hop_len # number of samples per frame @@ -50,7 +47,7 @@ def n_mfcc(self) -> int: """Get the number of MFCC components. Returns: - int -- The number of MFCC components. + int: The number of MFCC components. """ return self.__n_mfcc @@ -60,7 +57,7 @@ def frequency(self) -> int: """Get the sample rate. Returns: - int -- The sample rate. + int: The sample rate. """ return self.__frequency @@ -70,23 +67,23 @@ def hop_len(self) -> int: """Get the number of samples per frame. Returns: - int -- The number of samples per frame. + int: The number of samples per frame. """ return self.__hop_len @property def step_sample(self) -> float: - """The space (in seconds) between the begining of each sample. + """The space (in seconds) between the beginning of each sample. Returns: - float -- The space (in seconds) between the begining of each sample. + float: The space (in seconds) between the beginning of each sample. """ return self.__step_sample @step_sample.setter - def step_sample(self, step_sample: int) -> None: + def step_sample(self, step_sample: float) -> None: """Configure the step sample Arguments: @@ -100,7 +97,7 @@ def len_sample(self) -> float: """Get the length in seconds for the input samples. Returns: - float -- The length in seconds for the input samples. + float: The length in seconds for the input samples. """ return self.__item_time @@ -113,7 +110,7 @@ def time_to_sec(cls, pysrt_time: SubRipTime) -> float: pysrt_time {pysrt.SubRipTime} -- SubRipTime or coercible. Returns: - float -- The number of seconds. + float: The number of seconds. """ # There is a weird bug in pysrt triggered by a programatically generated # subtitle with start time "00:00:00,000". When it occurs, .millisecond @@ -133,7 +130,7 @@ def get_len_mfcc(self) -> float: """Get the number of samples to get LEN_SAMPLE: LEN_SAMPLE/(HOP_LEN/FREQUENCY). Returns: - float -- The number of samples. + float: The number of samples. """ return self.__len_sample / (self.__hop_len / self.__frequency) @@ -142,7 +139,7 @@ def get_step_mfcc(self) -> float: """Get the number of samples to get STEP_SAMPLE: STEP_SAMPLE/(HOP_LEN/FREQUENCY). Returns: - float -- The number of samples. + float: The number of samples. """ return self.__step_sample / (self.__hop_len / self.__frequency) @@ -154,7 +151,7 @@ def time_to_position(self, pysrt_time: SubRipTime) -> int: pysrt_time {pysrt.SubRipTime} -- SubRipTime or coercible. Returns: - int -- The cell position. + int: The cell position. """ return int( @@ -170,7 +167,7 @@ def duration_to_position(self, seconds: float) -> int: seconds {float} -- The duration in seconds. Returns: - int -- The cell position. + int: The cell position. """ return int( @@ -184,7 +181,7 @@ def position_to_duration(self, position: int) -> float: position {int} -- The cell position. Returns: - float -- The number of seconds. + float: The number of seconds. """ return ( @@ -198,7 +195,7 @@ def position_to_time_str(self, position: int) -> str: position {int} -- The cell position. Returns: - string -- The time string (e.g., 01:23:20,150). + str: The time string (e.g., 01:23:20,150). """ td = timedelta( @@ -247,11 +244,11 @@ def extract_data_and_label_from_audio( Keyword Arguments: subtitles {pysrt.SubRipFile} -- The SubRipFile object (default: {None}). - sound_effect_start_marker: {string} -- A string indicating the start of the ignored sound effect (default: {None}). - sound_effect_end_marker: {string} -- A string indicating the end of the ignored sound effect (default: {None}). + sound_effect_start_marker {string} -- A string indicating the start of the ignored sound effect (default: {None}). + sound_effect_end_marker {string} -- A string indicating the end of the ignored sound effect (default: {None}). Returns: - tuple -- The training data and the training lables. + tuple: The training data and the training lables. Raises: TerminalException: Thrown when the subtitles are missing. diff --git a/subaligner/hparam_tuner.py b/subaligner/hparam_tuner.py index dd779f8..26bee68 100644 --- a/subaligner/hparam_tuner.py +++ b/subaligner/hparam_tuner.py @@ -10,7 +10,21 @@ class HyperParameterTuner(object): - """Hyperparameter tuning using the Tree of Parzen Estimators algorithm""" + """Hyperparameter tuning using the Tree of Parzen Estimators algorithm + + Arguments: + av_file_paths {list}: A list of paths to the input audio/video files. + subtitle_file_paths {list}: A list of paths to the subtitle files. + training_dump_dir {string}: The directory of the training data dump file. + + Keyword Arguments: + av_file_paths {List[str]} -- The list of audiovisual file paths. + subtitle_file_paths List[str] -- The list of subtitle files. + training_dump_dir: {string} -- The directory path of the training dump. + num_of_trials {int} -- The number of trials for tuning (default: {5}). + tuning_epochs {int} -- The number of training epochs for each trial (default: {5}). + network_type {string} -- The type of the network (default: {"lstm"}, range: ["lstm", "bi_lstm", "conv_1d"]). + """ SEARCH_SPACE = { "learning_rate": hp.loguniform("learning_rate", np.log(0.00001), np.log(0.1)), @@ -30,19 +44,6 @@ def __init__(self, tuning_epochs: int = 5, network_type: str = Network.LSTM, **kwargs) -> None: - """Hyperparameter tuner initialiser - - Arguments: - av_file_paths {list} -- A list of paths to the input audio/video files. - subtitle_file_paths {list} -- A list of paths to the subtitle files. - training_dump_dir {string} -- The directory of the training data dump file. - - Keyword Arguments: - num_of_trials {int} -- The number of trials for tuning (default: {5}). - tuning_epochs {int} -- The number of training epochs for each trial (default: {5}). - network_type {string} -- The type of the network (default: {"lstm"}, range: ["lstm", "bi_lstm", "conv_1d"]). - """ - assert network_type in Network.TYPES, "Supported network type values: %s" % Network.TYPES hyperparameters = Hyperparameters() hyperparameters.network_type = network_type diff --git a/subaligner/hyperparameters.py b/subaligner/hyperparameters.py index 94aa818..50b60ee 100644 --- a/subaligner/hyperparameters.py +++ b/subaligner/hyperparameters.py @@ -10,8 +10,6 @@ class Hyperparameters(object): OPTIMIZERS = ["adadelta", "adagrad", "adam", "adamax", "ftrl", "nadam", "rmsprop", "sgd"] def __init__(self) -> None: - """Hyperparameters initialiser setting default values""" - self.__learning_rate = 0.001 self.__hidden_size = { "front_layers": [64], @@ -33,8 +31,11 @@ def __init__(self) -> None: def __eq__(self, other: Any) -> bool: """Comparator for Hyperparameters objects + Arguments: + other {Any} -- Any comparable object + Returns: - bool -- If True, the compared hyperparameter object is the same + bool: If True, the compared hyperparameter object is the same """ if isinstance(other, Hyperparameters): @@ -195,7 +196,7 @@ def to_json(self) -> str: """Serialise hyperparameters into JSON string Returns: - string -- The serialised hyperparameters in JSON + str: The serialised hyperparameters in JSON """ return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) @@ -212,7 +213,7 @@ def clone(self) -> "Hyperparameters": """Make a cloned hyperparameters object Returns: - Hyperparameters -- The cloned Hyperparameters object. + Hyperparameters: The cloned Hyperparameters object. """ return self.from_json(self.to_json()) @@ -224,7 +225,7 @@ def from_json(cls, json_str: str) -> "Hyperparameters": json_str {string} -- Hyperparameters in JSON. Returns: - Hyperparameters -- The deserialised Hyperparameters object. + Hyperparameters: The deserialised Hyperparameters object. """ hp = cls() hp.__dict__ = json.loads(json_str) @@ -238,7 +239,7 @@ def from_file(cls, file_path: str) -> "Hyperparameters": file_path {string} -- The path to the file containing hyperparameters. Returns: - Hyperparameters -- The deserialised Hyperparameters object. + Hyperparameters: The deserialised Hyperparameters object. """ with open(file_path, "r", encoding="utf8") as file: return cls.from_json(file.read()) diff --git a/subaligner/media_helper.py b/subaligner/media_helper.py index 008ecc5..bdd55ae 100644 --- a/subaligner/media_helper.py +++ b/subaligner/media_helper.py @@ -1,6 +1,5 @@ import subprocess import os -import threading import traceback import tempfile import shutil @@ -29,6 +28,9 @@ def clear_temp(*_): class MediaHelper(object): """ Utility for processing media assets including audio, video and subtitle files. + + Arguments: + media_process_timeout {int} -- The timeout in seconds on processing media files. """ FFMPEG_BIN = os.getenv("FFMPEG_PATH") or os.getenv("ffmpeg_path") or "ffmpeg" @@ -39,13 +41,13 @@ class MediaHelper(object): __MIN_GAP_IN_SECS = ( 1 # minimum gap in seconds between consecutive subtitle during segmentation ) - __CMD_TIME_OUT = 180 # time out for subprocess atexit.register(clear_temp) signal.signal(signal.SIGTERM, clear_temp) - def __init__(self): + def __init__(self, media_process_timeout: int = 180) -> None: self.__LOGGER = Logger().get_logger(__name__) + self.__media_process_timeout = media_process_timeout def extract_audio(self, video_file_path, decompress: bool = False, freq: int = 16000) -> str: """Extract audio track from the video file and save it to a WAV file. @@ -58,7 +60,7 @@ def extract_audio(self, video_file_path, decompress: bool = False, freq: int = 1 freq {int} -- The audio sample frequency (default: {16000}). Returns: - string -- The file path of the extracted audio. + str: The file path of the extracted audio. Raises: TerminalException: If audio extraction is interrupted by user hitting the interrupt key or timed out. @@ -88,7 +90,6 @@ def extract_audio(self, video_file_path, decompress: bool = False, freq: int = 1 self.FFMPEG_BIN, Utils.double_quoted(video_file_path), Utils.double_quoted(audio_file_path) ) ) - print(command) with subprocess.Popen( shlex.split(command), shell=False, @@ -100,7 +101,7 @@ def extract_audio(self, video_file_path, decompress: bool = False, freq: int = 1 ) as process: try: self.__LOGGER.debug("[{}] Running: {}".format(process.pid, command)) - _, std_err = process.communicate(timeout=self.__CMD_TIME_OUT) + _, std_err = process.communicate(timeout=self.__media_process_timeout) self.__LOGGER.debug("[{}] {}".format(process.pid, std_err)) if process.returncode != 0: self.__LOGGER.error("[{}] Cannot extract audio from video: {}\n{}" @@ -151,7 +152,7 @@ def get_duration_in_seconds(self, start: Optional[str], end: Optional[str]) -> O end {string} -- The end time (e.g., 00:00:10,230). Returns: - float -- The duration in seconds. + Optional[float]: The duration in seconds. """ if start is None: @@ -178,7 +179,7 @@ def extract_audio_from_start_to_end(self, audio_file_path: str, start: str, end: end {string} -- The end time (e.g., 00:00:10,230) (default: {None}). Returns: - tuple -- The file path to the extracted audio and its duration. + tuple: The file path to the extracted audio and its duration. Raises: TerminalException: If audio extraction is interrupted by user hitting the interrupt key or timed out. @@ -211,7 +212,7 @@ def extract_audio_from_start_to_end(self, audio_file_path: str, start: str, end: ) as process: self.__LOGGER.debug("[{}] Running: {}".format(process.pid, command)) try: - _, std_err = process.communicate(timeout=self.__CMD_TIME_OUT) + _, std_err = process.communicate(timeout=self.__media_process_timeout) self.__LOGGER.debug("[{}] {}".format(process.pid, std_err)) if process.returncode != 0: self.__LOGGER.error("[{}] Cannot clip audio: {} Return Code: {}\n{}" @@ -232,7 +233,7 @@ def extract_audio_from_start_to_end(self, audio_file_path: str, start: str, end: if os.path.exists(segment_path): os.remove(segment_path) raise TerminalException( - "Timeout on extracting audio from audio: {} after {} seconds".format(audio_file_path, self.__CMD_TIME_OUT) + "Timeout on extracting audio from audio: {} after {} seconds".format(audio_file_path, self.__media_process_timeout) ) from e except Exception as e: self.__LOGGER.error( @@ -270,7 +271,7 @@ def get_audio_segment_starts_and_ends(self, subs: List[SubRipItem]) -> Tuple[Lis subs {list} -- A list of SupRip cues. Returns: - tuple -- A list of start times, a list of end times and a list of grouped SubRip files. + tuple: A list of start times, a list of end times and a list of grouped SubRip files. """ local_subs = self.__preprocess_subs(subs) @@ -323,7 +324,7 @@ def get_frame_rate(self, file_path: str) -> float: file_path {string} -- The input audiovisual file path. Returns: - float -- The frame rate + float: The frame rate Raises: TerminalException: If frame rate extraction is interrupted by user hitting the interrupt key or timed out. @@ -352,7 +353,7 @@ def get_frame_rate(self, file_path: str) -> float: bufsize=1, ) as process: try: - std_out, std_err = process.communicate(timeout=self.__CMD_TIME_OUT) + std_out, std_err = process.communicate(timeout=self.__media_process_timeout) if process.returncode != 0: self.__LOGGER.warning("[{}] Cannot extract the frame rate from video: {}\n{}".format(process.pid, file_path, std_err)) raise NoFrameRateException( @@ -396,7 +397,7 @@ def refragment_with_min_duration(self, subs: List[SubRipItem], minimum_segment_d minimum_segment_duration {float} -- The minimum duration in seconds for each output subtitle cue. Returns: - list -- A list of new SupRip cues after fragmentation. + list: A list of new SupRip cues after fragmentation. """ new_segment = [] new_segment_index = 0 diff --git a/subaligner/network.py b/subaligner/network.py index c2e0c9f..48a7a91 100644 --- a/subaligner/network.py +++ b/subaligner/network.py @@ -6,7 +6,7 @@ import tensorflow as tf import tensorflow.keras.optimizers as tf_optimizers -from typing import Tuple, Optional, Any, List, Generator +from typing import Tuple, Optional, List, Generator from tensorflow.keras.layers import ( Dense, Input, @@ -38,6 +38,17 @@ class Network(object): """ Network factory creates DNNs. Not thread safe since the session of keras_backend is global. Only factory methods are allowed when generating DNN objects. + + Arguments: + secret {object} -- A hash only known by factory methods. + input_shape {tuple} -- A shape tuple (integers), not including the batch size. + hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training. + model_path {string} -- The path to the model file. + backend {string} -- The tensor manipulation backend (default: {tensorflow}). Only tensorflow is supported + by TF 2 and this parameter is here only for a historical reason. + + Raises: + ValueError: Thrown when the network type is not supported. """ LSTM = "lstm" @@ -56,19 +67,6 @@ def __init__( model_path: Optional[str] = None, backend: str = "tensorflow" ) -> None: - """Network object initialiser used by factory methods. - - Arguments: - secret {object} -- A hash only known by factory methods. - input_shape {tuple} -- A shape tuple (integers), not including the batch size. - hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training. - model_path {string} -- The path to the model file. - backend {string} -- The tensor manipulation backend (default: {tensorflow}). Only tensorflow is supported - by TF 2 and this parameter is here only for a historical reason. - - Raises: - ValueError: Thrown when the network type is not supported. - """ assert ( secret == Network.__secret ), "Only factory methods are supported when creating instances" @@ -109,7 +107,7 @@ def get_network(cls, input_shape: Tuple, hyperparameters: Hyperparameters) -> "N hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training. Returns: - Network -- A constructed network object. + Network: A constructed network object. """ return cls( @@ -164,18 +162,18 @@ def load_model_and_weights(model_filepath: str, weights_filepath: str, hyperpara hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training. Returns: - Network -- Reconstructed network object. + Network: Reconstructed network object. """ network = Network.get_from_model(model_filepath, hyperparameters) network.__model.load_weights(weights_filepath) return network @property - def input_shape(self) -> Tuple: + def input_shape(self) -> tuple: """Get the input shape of the network. Returns: - tuple -- The input shape of the network. + tuple: The input shape of the network. """ return self.__input_shape @@ -185,7 +183,7 @@ def n_type(self) -> str: """Get the type of the network. Returns: - string -- The type of the network. + str: The type of the network. """ return self.__n_type @@ -198,11 +196,11 @@ def summary(self) -> None: self.__model.summary() @property - def layers(self) -> List[Any]: + def layers(self) -> list: """Get the layers of the network. Returns: - list -- The statck of layers contained by the network + list: The stack of layers contained by the network """ return self.__model.layers @@ -215,7 +213,7 @@ def get_predictions(self, input_data: np.ndarray, weights_filepath: str) -> np.n weights_filepath {string} -- The weights file path. Returns: - numpy.ndarray -- The Numpy array of predictions. + numpy.ndarray: The Numpy array of predictions. """ self.__model.load_weights(weights_filepath) return self.__model.predict_on_batch(input_data) @@ -242,7 +240,7 @@ def fit_and_get_history( resume {bool} -- True to continue with previous training result or False to start a new one (default: {False}). Returns: - tuple -- A tuple contains validation losses and validation accuracies. + tuple: A tuple contains validation losses and validation accuracies. Raises: TerminalException: If the predication is interrupted by user hitting the interrupt key @@ -334,7 +332,7 @@ def fit_with_generator( resume {bool} -- True to continue with previous training result or False to start a new one (default: {False}). Returns: - tuple -- A tuple contains validation losses and validation accuracies. + tuple: A tuple contains validation losses and validation accuracies. Raises: TerminalException: If the training is interrupted by user hitting the interrupt key @@ -428,7 +426,7 @@ def simple_fit( hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training. Returns: - tuple -- A tuple contains validation losses and validation accuracies. + tuple: A tuple contains validation losses and validation accuracies. """ network = cls(cls.__secret, input_shape, hyperparameters) @@ -469,7 +467,7 @@ def simple_fit_with_generator( hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training. Returns: - tuple -- A tuple contains validation losses and validation accuracies. + tuple: A tuple contains validation losses and validation accuracies. """ network = cls(cls.__secret, input_shape, hyperparameters) diff --git a/subaligner/predictor.py b/subaligner/predictor.py index 6da2aac..304ce5b 100644 --- a/subaligner/predictor.py +++ b/subaligner/predictor.py @@ -25,6 +25,15 @@ class Predictor(metaclass=Singleton): """ Predictor for working out the time to shift subtitles + + Keyword Arguments: + media_process_timeout {int} -- The maximum waiting time in seconds when processing media files. + segment_alignment_timeout {int} -- The maximum waiting time in seconds when aligning each segment. + n_mfcc {int} -- The number of MFCC components (default: {13}). + frequency {float} -- The sample rate (default: {16000}). + hop_len {int} -- The number of samples per frame (default: {512}). + step_sample {float} -- The space (in seconds) between the begining of each sample (default: 1s / 25 FPS = 0.04s). + len_sample {float} -- The length in seconds for the input samples (default: {0.075}). """ __MAX_SHIFT_IN_SECS = ( 100 @@ -34,25 +43,14 @@ class Predictor(metaclass=Singleton): ) # Average 0.3 word per sec multiplies average 6 characters per word __MAX_HEAD_ROOM = 20000 # Maximum duration without subtitle (10 minutes) - __SEGMENT_PREDICTION_TIMEOUT = 60 # Maximum waiting time in seconds when predicting each segment - __THREAD_QUEUE_SIZE = 8 __THREAD_NUMBER = 1 # Do not change - def __init__(self, **kwargs) -> None: - """Feature predictor initialiser. - - Keyword Arguments: - n_mfcc {int} -- The number of MFCC components (default: {13}). - frequency {float} -- The sample rate (default: {16000}). - hop_len {int} -- The number of samples per frame (default: {512}). - step_sample {float} -- The space (in seconds) between the begining of each sample (default: 1s / 25 FPS = 0.04s). - len_sample {float} -- The length in seconds for the input samples (default: {0.075}). - """ - + def __init__(self, media_process_timeout: int = 180, segment_alignment_timeout: int = 60, **kwargs) -> None: + self.__media_helper = MediaHelper(media_process_timeout=media_process_timeout) + self.__segment_alignment_timeout = segment_alignment_timeout self.__feature_embedder = FeatureEmbedder(**kwargs) self.__LOGGER = Logger().get_logger(__name__) - self.__media_helper = MediaHelper() def predict_single_pass( self, @@ -68,7 +66,7 @@ def predict_single_pass( weights_dir {string} -- The the model weights directory. Returns: - tuple -- The shifted subtitles, the audio file path and the voice probabilities of the original audio. + tuple: The shifted subtitles, the audio file path and the voice probabilities of the original audio. """ weights_file_path = self.__get_weights_path(weights_dir) @@ -109,7 +107,7 @@ def predict_dual_pass( exit_segfail {bool} -- True to exit on any segment alignment failures (default: {False}) Returns: - tuple -- The shifted subtitles, the globally shifted subtitles and the voice probabilities of the original audio. + tuple: The shifted subtitles, the globally shifted subtitles and the voice probabilities of the original audio. """ weights_file_path = self.__get_weights_path(weights_dir) @@ -139,7 +137,7 @@ def predict_dual_pass( if os.path.exists(audio_file_path): os.remove(audio_file_path) - def predict_plain_text(self, video_file_path: str, subtitle_file_path: str, stretch_in_lang: str = "eng") -> Tuple: + def predict_plain_text(self, video_file_path: str, subtitle_file_path: str, stretch_in_lang: str = "eng") -> tuple: """Predict time to shift with plain texts Arguments: @@ -148,7 +146,7 @@ def predict_plain_text(self, video_file_path: str, subtitle_file_path: str, stre stretch_in_lang {str} -- The language used for stretching subtitles (default: {"eng"}). Returns: - tuple -- The shifted subtitles, the audio file path (None) and the voice probabilities of the original audio (None). + tuple: The shifted subtitles, the audio file path (None) and the voice probabilities of the original audio (None). Raises: TerminalException: If the predication is interrupted by user hitting the interrupt key. @@ -227,7 +225,7 @@ def get_log_loss(self, voice_probabilities: np.ndarray, subs: List[SubRipItem]) subs {list} -- A list of subtitle segments. Returns: - float -- The loss value. + float: The loss value. Raises: TerminalException: If the subtitle mask is empty. @@ -266,7 +264,7 @@ def get_min_log_loss_and_index(self, voice_probabilities: np.ndarray, subs: SubR subs {list} -- A list of subtitle segments. Returns: - tuple -- The minimum loss value and its position. + tuple: The minimum loss value and its position. Raises: TerminalException: If subtitle is empty or suspicious audio/subtitle duration is detected. @@ -367,14 +365,14 @@ def _predict_in_multiprocesses( ) for i, future in enumerate(futures): try: - new_subs = future.result(timeout=Predictor.__SEGMENT_PREDICTION_TIMEOUT) + new_subs = future.result(timeout=self.__segment_alignment_timeout) except concurrent.futures.TimeoutError as e: - self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT) - message = "Segment alignment timed out after {} seconds".format(Predictor.__SEGMENT_PREDICTION_TIMEOUT) + self.__cancel_futures(futures[i:], self.__segment_alignment_timeout) + message = "Segment alignment timed out after {} seconds".format(self.__segment_alignment_timeout) self.__LOGGER.error(message) raise TerminalException(message) from e except Exception as e: - self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT) + self.__cancel_futures(futures[i:], self.__segment_alignment_timeout) message = "Exception on segment alignment: {}\n{}".format(str(e), "".join(traceback.format_stack())) self.__LOGGER.error(e, exc_info=True, stack_info=True) traceback.print_tb(e.__traceback__) @@ -383,7 +381,7 @@ def _predict_in_multiprocesses( else: raise TerminalException(message) from e except KeyboardInterrupt: - self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT) + self.__cancel_futures(futures[i:], self.__segment_alignment_timeout) raise TerminalException("Segment alignment interrupted by the user") else: self.__LOGGER.debug("Segment aligned") @@ -532,13 +530,13 @@ def __predict_2nd_pass(self, audio_file_path: str, subs: List[SubRipItem], weigh Arguments: audio_file_path {string} -- The file path of the original audio. subs {list} -- A list of SubRip files. - weights_file_path {string} -- The file path of the weights file. + weights_file_path {string}: The file path of the weights file. stretch {bool} -- True to stretch the subtitle segments. stretch_in_lang {str} -- The language used for stretching subtitles. exit_segfail {bool} -- True to exit on any segment alignment failures. Returns: - list -- A list of aligned SubRip files + list: A list of aligned SubRip files Raises: TerminalException: If the alignment is interrupted by user hitting the interrupt key or times out @@ -590,14 +588,14 @@ def __predict_2nd_pass(self, audio_file_path: str, subs: List[SubRipItem], weigh ] for i, future in enumerate(futures): try: - subs_list.extend(future.result(timeout=Predictor.__SEGMENT_PREDICTION_TIMEOUT * batch_size)) + subs_list.extend(future.result(timeout=self.__segment_alignment_timeout * batch_size)) except concurrent.futures.TimeoutError as e: - self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT * batch_size) - message = "Batch alignment timed out after {} seconds".format(Predictor.__SEGMENT_PREDICTION_TIMEOUT) + self.__cancel_futures(futures[i:], self.__segment_alignment_timeout * batch_size) + message = "Batch alignment timed out after {} seconds".format(self.__segment_alignment_timeout) self.__LOGGER.error(message) raise TerminalException(message) from e except Exception as e: - self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT * batch_size) + self.__cancel_futures(futures[i:], self.__segment_alignment_timeout * batch_size) message = "Exception on batch alignment: {}\n{}".format(str(e), "".join(traceback.format_stack())) self.__LOGGER.error(e, exc_info=True, stack_info=True) traceback.print_tb(e.__traceback__) @@ -606,7 +604,7 @@ def __predict_2nd_pass(self, audio_file_path: str, subs: List[SubRipItem], weigh else: raise TerminalException(message) from e except KeyboardInterrupt: - self.__cancel_futures(futures[i:], Predictor.__SEGMENT_PREDICTION_TIMEOUT * batch_size) + self.__cancel_futures(futures[i:], self.__segment_alignment_timeout * batch_size) raise TerminalException("Batch alignment interrupted by the user") else: self.__LOGGER.debug("Batch aligned") @@ -742,7 +740,7 @@ def __predict( previous_gap {float} -- The duration between the start time of the audio segment and the start time of the subtitle segment (default: {None}). Returns: - tuple -- The shifted subtitles, the audio file path and the voice probabilities of the original audio. + tuple: The shifted subtitles, the audio file path and the voice probabilities of the original audio. Raises: TerminalException: If the prediction failed on invalid input or on other exceptions. diff --git a/subaligner/subaligner_1pass/__main__.py b/subaligner/subaligner_1pass/__main__.py index 375de65..dd9b0d0 100755 --- a/subaligner/subaligner_1pass/__main__.py +++ b/subaligner/subaligner_1pass/__main__.py @@ -1,10 +1,11 @@ #!/usr/bin/env python """ -usage: subaligner_1pass [-h] [-v VIDEO_PATH] [-s SUBTITLE_PATH] [-l MAX_LOGLOSS] [-tod TRAINING_OUTPUT_DIRECTORY] [-o OUTPUT] [-t TRANSLATE] [-lgs] [-d] [-q] [-ver] +usage: subaligner_1pass [-h] [-v VIDEO_PATH] [-s SUBTITLE_PATH] [-l MAX_LOGLOSS] [-tod TRAINING_OUTPUT_DIRECTORY] [-o OUTPUT] [-t TRANSLATE] [-mpt MEDIA_PROCESS_TIMEOUT] [-sat SEGMENT_ALIGNMENT_TIMEOUT] [-lgs] + [-d] [-q] [-ver] Run single-stage alignment -optional arguments: +options: -h, --help show this help message and exit -l MAX_LOGLOSS, --max_logloss MAX_LOGLOSS Max global log loss for alignment @@ -14,6 +15,10 @@ Path to the output subtitle file -t TRANSLATE, --translate TRANSLATE Source and target ISO 639-3 language codes separated by a comma (e.g., eng,zho) + -mpt MEDIA_PROCESS_TIMEOUT, --media_process_timeout MEDIA_PROCESS_TIMEOUT + Maximum waiting time in seconds when processing media files + -sat SEGMENT_ALIGNMENT_TIMEOUT, --segment_alignment_timeout SEGMENT_ALIGNMENT_TIMEOUT + Maximum waiting time in seconds when aligning each segment -lgs, --languages Print out language codes used for stretch and translation -d, --debug Print out debugging information -q, --quiet Switch off logging information @@ -23,7 +28,7 @@ -v VIDEO_PATH, --video_path VIDEO_PATH File path or URL to the video file -s SUBTITLE_PATH, --subtitle_path SUBTITLE_PATH - File path or URL to the subtitle file (Extensions of supported subtitles: .stl, .dfxp, .xml, .vtt, .sbv, .ytt, .scc, .ttml, .smi, .sami, .ssa, .tmp, .txt, .sub, .srt, .ass) or selector for the embedded subtitle (e.g., embedded:page_num=888 or embedded:stream_index=0) + File path or URL to the subtitle file (Extensions of supported subtitles: .dfxp, .txt, .vtt, .srt, .sbv, .ytt, .ssa, .scc, .tmp, .sami, .smi, .stl, .sub, .xml, .ass, .ttml) or selector for the embedded subtitle (e.g., embedded:page_num=888 or embedded:stream_index=0) """ import argparse @@ -88,6 +93,20 @@ def main(): type=str, help="Source and target ISO 639-3 language codes separated by a comma (e.g., eng,zho)", ) + parser.add_argument( + "-mpt", + "--media_process_timeout", + type=int, + default=180, + help="Maximum waiting time in seconds when processing media files" + ) + parser.add_argument( + "-sat", + "--segment_alignment_timeout", + type=int, + default=60, + help="Maximum waiting time in seconds when aligning each segment" + ) parser.add_argument("-lgs", "--languages", action="store_true", help="Print out language codes used for stretch and translation") parser.add_argument("-d", "--debug", action="store_true", @@ -163,7 +182,7 @@ def main(): parser.print_usage() sys.exit(21) - predictor = Predictor() + predictor = Predictor(media_process_timeout=FLAGS.media_process_timeout, segment_alignment_timeout=FLAGS.segment_alignment_timeout) subs, audio_file_path, voice_probabilities, frame_rate = predictor.predict_single_pass( video_file_path=local_video_path, subtitle_file_path=local_subtitle_path, diff --git a/subaligner/subaligner_2pass/__main__.py b/subaligner/subaligner_2pass/__main__.py index a724657..d4271a3 100755 --- a/subaligner/subaligner_2pass/__main__.py +++ b/subaligner/subaligner_2pass/__main__.py @@ -2,15 +2,15 @@ """ usage: subaligner_2pass [-h] [-v VIDEO_PATH] [-s SUBTITLE_PATH] [-l MAX_LOGLOSS] [-so] [-sil {afr,amh,ara,arg,asm,aze,ben,bos,bul,cat,ces,cmn,cym,dan,deu,ell,eng,epo,est,eus,fas,fin,fra,gla,gle,glg,grc,grn,guj,heb,hin,hrv,hun,hye,ina,ind,isl,ita,jbo,jpn,kal,kan,kat,kir,kor,kur,lat,lav,lfn,lit,mal,mar,mkd,mlt,msa,mya,nah,nep,nld,nor,ori,orm,pan,pap,pol,por,ron,rus,sin,slk,slv,spa,sqi,srp,swa,swe,tam,tat,tel,tha,tsn,tur,ukr,urd,vie,yue,zho}] - [-fos] [-tod TRAINING_OUTPUT_DIRECTORY] [-o OUTPUT] [-t TRANSLATE] [-lgs] [-d] [-q] [-ver] + [-fos] [-tod TRAINING_OUTPUT_DIRECTORY] [-o OUTPUT] [-t TRANSLATE] [-mpt MEDIA_PROCESS_TIMEOUT] [-sat SEGMENT_ALIGNMENT_TIMEOUT] [-lgs] [-d] [-q] [-ver] Run dual-stage alignment -optional arguments: +options: -h, --help show this help message and exit -l MAX_LOGLOSS, --max_logloss MAX_LOGLOSS Max global log loss for alignment - -so, --stretch_on Switch on stretch on subtitles + -so, --stretch_on Switch on stretch on subtitles -sil {afr,amh,ara,arg,asm,aze,ben,bos,bul,cat,ces,cmn,cym,dan,deu,ell,eng,epo,est,eus,fas,fin,fra,gla,gle,glg,grc,grn,guj,heb,hin,hrv,hun,hye,ina,ind,isl,ita,jbo,jpn,kal,kan,kat,kir,kor,kur,lat,lav,lfn,lit,mal,mar,mkd,mlt,msa,mya,nah,nep,nld,nor,ori,orm,pan,pap,pol,por,ron,rus,sin,slk,slv,spa,sqi,srp,swa,swe,tam,tat,tel,tha,tsn,tur,ukr,urd,vie,yue,zho}, --stretch_in_language {afr,amh,ara,arg,asm,aze,ben,bos,bul,cat,ces,cmn,cym,dan,deu,ell,eng,epo,est,eus,fas,fin,fra,gla,gle,glg,grc,grn,guj,heb,hin,hrv,hun,hye,ina,ind,isl,ita,jbo,jpn,kal,kan,kat,kir,kor,kur,lat,lav,lfn,lit,mal,mar,mkd,mlt,msa,mya,nah,nep,nld,nor,ori,orm,pan,pap,pol,por,ron,rus,sin,slk,slv,spa,sqi,srp,swa,swe,tam,tat,tel,tha,tsn,tur,ukr,urd,vie,yue,zho} Stretch the subtitle with the supported ISO 639-3 language code [https://en.wikipedia.org/wiki/List_of_ISO_639-3_codes]. NB: This will be ignored if neither -so nor --stretch_on is present @@ -21,6 +21,10 @@ Path to the output subtitle file -t TRANSLATE, --translate TRANSLATE Source and target ISO 639-3 language codes separated by a comma (e.g., eng,zho) + -mpt MEDIA_PROCESS_TIMEOUT, --media_process_timeout MEDIA_PROCESS_TIMEOUT + Maximum waiting time in seconds when processing media files + -sat SEGMENT_ALIGNMENT_TIMEOUT, --segment_alignment_timeout SEGMENT_ALIGNMENT_TIMEOUT + Maximum waiting time in seconds when aligning each segment -lgs, --languages Print out language codes used for stretch and translation -d, --debug Print out debugging information -q, --quiet Switch off logging information @@ -30,7 +34,7 @@ -v VIDEO_PATH, --video_path VIDEO_PATH File path or URL to the video file -s SUBTITLE_PATH, --subtitle_path SUBTITLE_PATH - File path or URL to the subtitle file (Extensions of supported subtitles: .ass, .sbv, .srt, .vtt, .ttml, .dfxp, .scc, .txt, .tmp, .smi, .ssa, .sami, .xml, .sub, .stl, .ytt) or selector for the embedded subtitle (e.g., embedded:page_num=888 or embedded:stream_index=0) + File path or URL to the subtitle file (Extensions of supported subtitles: .smi, .ssa, .srt, .ttml, .scc, .sbv, .ytt, .vtt, .sami, .txt, .stl, .tmp, .ass, .dfxp, .xml, .sub) or selector for the embedded subtitle (e.g., embedded:page_num=888 or embedded:stream_index=0) """ import argparse @@ -116,6 +120,20 @@ def main(): type=str, help="Source and target ISO 639-3 language codes separated by a comma (e.g., eng,zho)", ) + parser.add_argument( + "-mpt", + "--media_process_timeout", + type=int, + default=180, + help="Maximum waiting time in seconds when processing media files" + ) + parser.add_argument( + "-sat", + "--segment_alignment_timeout", + type=int, + default=60, + help="Maximum waiting time in seconds when aligning each segment" + ) parser.add_argument("-lgs", "--languages", action="store_true", help="Print out language codes used for stretch and translation") parser.add_argument("-d", "--debug", action="store_true", @@ -199,7 +217,7 @@ def main(): parser.print_usage() sys.exit(21) - predictor = Predictor() + predictor = Predictor(media_process_timeout=FLAGS.media_process_timeout, segment_alignment_timeout=FLAGS.segment_alignment_timeout) subs_list, subs, voice_probabilities, frame_rate = predictor.predict_dual_pass( video_file_path=local_video_path, subtitle_file_path=local_subtitle_path, diff --git a/subaligner/subaligner_train/__main__.py b/subaligner/subaligner_train/__main__.py index c6f87d0..14d13f4 100755 --- a/subaligner/subaligner_train/__main__.py +++ b/subaligner/subaligner_train/__main__.py @@ -1,13 +1,14 @@ #!/usr/bin/env python """ -usage: subaligner_train [-h] -tod TRAINING_OUTPUT_DIRECTORY [-vd VIDEO_DIRECTORY] [-sd SUBTITLE_DIRECTORY] [-r] [-dde] [-sesm SOUND_EFFECT_START_MARKER] [-seem SOUND_EFFECT_END_MARKER] [-ess EMBEDDED_SUBTITLE_SELECTOR] [-bs BATCH_SIZE] [-do DROPOUT] [-e EPOCHS] - [-p PATIENCE] [-fhs FRONT_HIDDEN_SIZE] [-bhs BACK_HIDDEN_SIZE] [-lr LEARNING_RATE] [-nt {lstm,bi_lstm,conv_1d}] [-vs VALIDATION_SPLIT] [-o {adadelta,adagrad,adam,adamax,ftrl,nadam,rmsprop,sgd}] [-utd] [-d] [-q] [-ver] +usage: subaligner_train [-h] -tod TRAINING_OUTPUT_DIRECTORY [-vd VIDEO_DIRECTORY] [-sd SUBTITLE_DIRECTORY] [-r] [-dde] [-sesm SOUND_EFFECT_START_MARKER] [-seem SOUND_EFFECT_END_MARKER] + [-ess EMBEDDED_SUBTITLE_SELECTOR] [-fet FEATURE_EMBEDDING_TIMEOUT] [-mpt MEDIA_PROCESS_TIMEOUT] [-bs BATCH_SIZE] [-do DROPOUT] [-e EPOCHS] [-p PATIENCE] [-fhs FRONT_HIDDEN_SIZE] + [-bhs BACK_HIDDEN_SIZE] [-lr LEARNING_RATE] [-nt {lstm,bi_lstm,conv_1d}] [-vs VALIDATION_SPLIT] [-o {adadelta,adagrad,adam,adamax,ftrl,nadam,rmsprop,sgd}] [-utd] [-d] [-q] [-ver] Train the Subaligner model Each subtitle file and its companion audiovisual file need to share the same base filename, the part before the extension. -optional arguments: +options: -h, --help show this help message and exit -vd VIDEO_DIRECTORY, --video_directory VIDEO_DIRECTORY Path to the video directory @@ -22,6 +23,10 @@ Marker indicating the end of the sound effect which will be ignored during training and used with sound_effect_start_marker -ess EMBEDDED_SUBTITLE_SELECTOR, --embedded_subtitle_selector EMBEDDED_SUBTITLE_SELECTOR E.g., "embedded:page_num=888,file_extension=srt" or "embedded:stream_index=0,file_extension=srt" (supported file extensions: ssa, vtt, ass, srt and ttml). + -fet FEATURE_EMBEDDING_TIMEOUT, --feature_embedding_timeout FEATURE_EMBEDDING_TIMEOUT + Maximum waiting time in seconds when embedding features of media files + -mpt MEDIA_PROCESS_TIMEOUT, --media_process_timeout MEDIA_PROCESS_TIMEOUT + Maximum waiting time in seconds when processing media files -utd, --use_training_dump Use training dump instead of files in the video or subtitle directory -d, --debug Print out debugging information @@ -132,6 +137,20 @@ def main(): default=None, help='E.g., "embedded:page_num=888,file_extension=srt" or "embedded:stream_index=0,file_extension=srt" (supported file extensions: ssa, vtt, ass, srt and ttml).' ) + parser.add_argument( + "-fet", + "--feature_embedding_timeout", + type=int, + default=300, + help="Maximum waiting time in seconds when embedding features of media files" + ) + parser.add_argument( + "-mpt", + "--media_process_timeout", + type=int, + default=180, + help="Maximum waiting time in seconds when processing media files" + ) hyperparameter_args = parser.add_argument_group("optional hyperparameters") hyperparameter_args.add_argument( "-bs", @@ -312,7 +331,7 @@ def main(): hyperparameters.validation_split = FLAGS.validation_split hyperparameters.optimizer = FLAGS.optimizer - trainer = Trainer(FeatureEmbedder()) + trainer = Trainer(FeatureEmbedder(), feature_embedding_timeout=FLAGS.feature_embedding_timeout, media_process_timeout=FLAGS.media_process_timeout) trainer.train(video_file_paths, subtitle_file_paths, model_dir, diff --git a/subaligner/subtitle.py b/subaligner/subtitle.py index a049047..c2310a5 100644 --- a/subaligner/subtitle.py +++ b/subaligner/subtitle.py @@ -13,8 +13,15 @@ class Subtitle(object): """Load a subtitle file into internal data structure - """ + Arguments: + secret {object} -- A hash only known by factory methods. + subtitle_file_path {string} -- The path to the subtitle file. + format {string} -- Supported subtitle formats: subrip and ttml. + + Raises: + UnsupportedFormatException: Thrown when the input subtitle format is not supported or no subtitle content is found. + """ __secret = object() ElementTree.register_namespace("", "http://www.w3.org/ns/ttml") @@ -40,17 +47,6 @@ class Subtitle(object): YT_TRANSCRIPT_EXTENSIONS = [".ytt"] def __init__(self, secret: object, subtitle_file_path: str, subtitle_format: str) -> None: - """Subtitle object initialiser. - - Arguments: - secret {object} -- A hash only known by factory methods. - subtitle_file_path {string} -- The path to the subtitle file. - format {string} -- Supported subtitle formats: subrip and ttml. - - Raises: - UnsupportedFormatException: Thrown when the input subtitle format is not supported or no subtitle content is found. - """ - assert ( secret == Subtitle.__secret ), "Only factory methods are supported when creating instances" @@ -102,7 +98,7 @@ def load_subrip(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "subrip") @@ -115,7 +111,7 @@ def load_subrip_str(cls, subrip_raw: str) -> "Subtitle": subrip_str {string} -- The string representation of the SubRip content. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subrip_raw, "subrip_raw") @@ -128,7 +124,7 @@ def load_ttml(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "ttml") @@ -141,7 +137,7 @@ def load_webvtt(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "webvtt") @@ -154,7 +150,7 @@ def load_ssa(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "ssa") @@ -167,7 +163,7 @@ def load_ass(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "ass") @@ -180,7 +176,7 @@ def load_microdvd(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "microdvd") @@ -193,7 +189,7 @@ def load_mpl2(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "mpl2") @@ -206,7 +202,7 @@ def load_tmp(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "tmp") @@ -219,7 +215,7 @@ def load_sami(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "sami") @@ -232,7 +228,7 @@ def load_stl(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "stl") @@ -245,7 +241,7 @@ def load_scc(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "scc") @@ -258,7 +254,7 @@ def load_sbv(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "sbv") @@ -271,7 +267,7 @@ def load_ytt(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ return cls(cls.__secret, subtitle_file_path, "ytt") @@ -284,7 +280,7 @@ def load(cls, subtitle_file_path: str) -> "Subtitle": subtitle_file_path {string} -- The path to the subtitle file. Returns: - Subtitle -- Subtitle object. + Subtitle: Subtitle object. """ _, file_extension = os.path.splitext(subtitle_file_path.lower()) @@ -336,7 +332,7 @@ def shift_subtitle( suffix {string} -- The suffix used as part of the aligned subtitle file name. Returns: - string -- The path to the shifted subtitle file. + string: The path to the shifted subtitle file. Raises: UnsupportedFormatException: Thrown when the input subtitle format is not supported. @@ -433,7 +429,7 @@ def remove_sound_effects_by_case(subs: List[SubRipItem], se_uppercase: bool = Tr se_uppercase {bool} -- True when the sound effect is in uppercase or False when in lowercase (default: {True}). Returns: - {list} -- A list of SubRipItems. + list: A list of SubRipItems. """ new_subs = deepcopy(subs) for sub in subs: @@ -454,7 +450,7 @@ def remove_sound_effects_by_affixes(subs: List[SubRipItem], se_prefix: str, se_s se_suffix {string} -- A suffix indicating the end of the sound effect (default: {None}). Returns: - {list} -- A list of SubRipItems. + list: A list of SubRipItems. """ new_subs = deepcopy(subs) for sub in subs: @@ -479,7 +475,7 @@ def extract_text(subtitle_file_path: str, delimiter: str = " ") -> str: subtitle_file_path {string} -- The path to the subtitle file. Returns: - {string} -- The plain text of subtitle. + str: The plain text of subtitle. """ subs = Subtitle.load(subtitle_file_path).subs @@ -491,7 +487,7 @@ def subtitle_extensions() -> set: """Get the file extensions of the supported subtitles. Returns: - {set} -- The subtitle extensions. + set: The subtitle extensions. """ return set(Subtitle.SUBRIP_EXTENTIONS + Subtitle.TTML_EXTENSIONS + Subtitle.WEBVTT_EXTENSIONS + Subtitle.SSA_EXTENTIONS + Subtitle.ADVANCED_SSA_EXTENTIONS + Subtitle.MICRODVD_EXTENSIONS @@ -511,12 +507,12 @@ def subs(self) -> SubRipFile: def __load_subrip(subrip_file_path: str) -> SubRipFile: """Load a subtitle file in the SubRip format - Arguments: - subrip_file_path {string} -- The path to the SubRip subtitle file. + Arguments: + subrip_file_path {string} -- The path to the SubRip subtitle file. - Returns: - {list} -- A list of SubRipItems. - """ + Returns: + SubRipFile: A list of SubRipItems. + """ return Subtitle.__get_srt_subs(subrip_file_path) @staticmethod @@ -527,7 +523,7 @@ def __convert_ttml_to_subs(ttml_file_path: str) -> SubRipFile: ttml_file_path {string} -- The path to the TTML subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -543,7 +539,7 @@ def __convert_vtt_to_subs(vtt_file_path: str) -> SubRipFile: vtt_file_path {string} -- The path to the WebVTT subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -559,7 +555,7 @@ def __convert_ssa_to_subs(ssa_file_path: str) -> SubRipFile: ass_file_path {string} -- The path to the SubStation Alpha v4.0 subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -576,7 +572,7 @@ def __convert_ass_to_subs(ass_file_path: str) -> SubRipFile: ass_file_path {string} -- The path to the Advanced SubStation Alpha v4.0+ subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -593,7 +589,7 @@ def __convert_microdvd_to_subs(microdvd_file_path: str) -> SubRipFile: microdvd_file_path {string} -- The path to the MicroDVD subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -610,7 +606,7 @@ def __convert_mpl2_to_subs(mpl2_file_path: str) -> SubRipFile: mpl2_file_path {string} -- The path to the MPL2 subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -627,7 +623,7 @@ def __convert_tmp_to_subs(tmp_file_path: str) -> SubRipFile: tmp_file_path {string} -- The path to the TMP subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -644,7 +640,7 @@ def __convert_sami_to_subs(sami_file_path: str) -> SubRipFile: sami_file_path {string} -- The path to the SAMI subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -660,7 +656,7 @@ def __convert_stl_to_subs(stl_file_path: str) -> SubRipFile: stl_file_path {string} -- The path to the STL subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -676,7 +672,7 @@ def __convert_scc_to_subs(scc_file_path: str) -> SubRipFile: scc_file_path {string} -- The path to the SCC subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -692,7 +688,7 @@ def __convert_sbv_to_subs(sbv_file_path: str) -> SubRipFile: sbv_file_path {string} -- The path to the SubViewer subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() @@ -708,7 +704,7 @@ def __convert_ytt_to_subs(ytt_file_path: str) -> SubRipFile: ytt_file_path {string} -- The path to the YouTube transcript subtitle file. Returns: - {list} -- A list of SubRipItems. + SubRipFile: A list of SubRipItems. """ _, path = tempfile.mkstemp() diff --git a/subaligner/trainer.py b/subaligner/trainer.py index 73c9aec..746ad1c 100644 --- a/subaligner/trainer.py +++ b/subaligner/trainer.py @@ -19,24 +19,26 @@ class Trainer(object): """Network trainer. - """ - - EMBEDDING_TIMEOUT = 300 # time out for feature embedding of media files - __MAX_BYTES = 2 ** 31 - 1 - def __init__(self, feature_embedder: FeatureEmbedder) -> None: - """Initialiser for the training process. + Arguments: + feature_embedder {Embedder.FeatureEmbedder} -- The feature embedder object. + feature_embedding_timeout {Union[int, float]} -- The maximum waiting time in seconds when embedding features of media files. + media_process_timeout {int} -- The maximum waiting time in seconds when processing media files. - Arguments: - feature_embedder {Embedder.FeatureEmbedder} -- the feature embedder object. + Raises: + NotImplementedError: Thrown when any Trainer attributes are modified. + """ - Raises: - NotImplementedError -- Thrown when any Trainer attributes are modified. - """ + __MAX_BYTES = 2 ** 31 - 1 + def __init__(self, + feature_embedder: FeatureEmbedder, + feature_embedding_timeout: Union[int, float] = 300, + media_process_timeout: int = 180) -> None: self.__feature_embedder = feature_embedder + self.__feature_embedding_timeout = feature_embedding_timeout + self.__media_helper = MediaHelper(media_process_timeout=media_process_timeout) self.__lock = threading.RLock() - self.__media_helper = MediaHelper() self.__LOGGER = Logger().get_logger(__name__) def train( @@ -63,7 +65,7 @@ def train( weights_dir {string} -- The directory of the weights file. config_dir {string} -- The directory of the hyperparameter file where hyperparameters will be saved. logs_dir {string} -- The directory of the log file. - training_dump_dir {string} -- The directory of the training data dump file. + training_dump_dir {string}: The directory of the training data dump file. hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training. training_log {string} -- The path to the log file of epoch results (default: {"training.log"}). resume {bool} -- True to continue with previous training result or False to start a new one (default: {False}). @@ -179,13 +181,13 @@ def pre_train( Arguments: av_file_paths {list} -- A list of paths to the input audio/video files. subtitle_file_paths {list} -- A list of paths to the subtitle files. - training_dump_dir {string} -- The directory of the training data dump file. + training_dump_dir {string}: The directory of the training data dump file. hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training. sound_effect_start_marker: {string} -- A string indicating the start of the ignored sound effect (default: {"("}). sound_effect_end_marker: {string} -- A string indicating the end of the ignored sound effect (default: {")"}). Returns: - tuple -- The valuation loss and accuracy. + tuple: The valuation loss and accuracy. """ training_dump = os.path.join(os.path.abspath(training_dump_dir), "training_dump.hdf5") @@ -239,7 +241,7 @@ def get_done_epochs(training_log: str) -> int: training_log {string} -- The path to the training log file. Returns: - int -- The number of finished epochs. + int: The number of finished epochs. """ if not os.path.isfile(training_log): return 0 @@ -261,11 +263,11 @@ def __extract_data_and_label_from_avs( Arguments: av_file_paths {list} -- A list of paths to the input audio/video files. subtitle_file_paths {list} -- A list of paths to the subtitle files. - sound_effect_start_marker: {string} -- A string indicating the start of the ignored sound effect. - sound_effect_end_marker: {string} -- A string indicating the end of the ignored sound effect. + sound_effect_start_marker {string} -- A string indicating the start of the ignored sound effect. + sound_effect_end_marker {string} -- A string indicating the end of the ignored sound effect. Returns: - tuple -- The training data and labels. + tuple: The training data and labels. Raises: TerminalException: If the extraction is interrupted by user hitting the interrupt key. @@ -296,7 +298,7 @@ def __extract_data_and_label_from_avs( for index in range(len(av_file_paths)) ] try: - done, not_done = concurrent.futures.wait(futures, timeout=Trainer.EMBEDDING_TIMEOUT) + done, not_done = concurrent.futures.wait(futures, timeout=self.__feature_embedding_timeout) except KeyboardInterrupt: for future in futures: if not future.cancel(): diff --git a/subaligner/transcriber.py b/subaligner/transcriber.py index 8a7948b..bd10835 100644 --- a/subaligner/transcriber.py +++ b/subaligner/transcriber.py @@ -15,18 +15,16 @@ class Transcriber(object): """Transcribe audiovisual content for subtitle generation. - """ - def __init__(self, recipe: str = TranscriptionRecipe.WHISPER.value, flavour: str = WhisperFlavour.SMALL.value) -> None: - """Initialiser for the transcribing process. + Arguments: + recipe {string} -- the LLM recipe used for transcribing video files (default: "whisper"). + flavour {string} -- the flavour variation for a specific LLM recipe (default: "small"). - Arguments: - recipe {string} -- the LLM recipe used for transcribing video files (default: "whisper"). - flavour {string} -- the flavour variation for a specific LLM recipe (default: "small"). + Raises: + NotImplementedError: Thrown when the LLM recipe is unknown. + """ - Raises: - NotImplementedError: Thrown when the LLM recipe is unknown. - """ + def __init__(self, recipe: str = TranscriptionRecipe.WHISPER.value, flavour: str = WhisperFlavour.SMALL.value) -> None: if recipe not in [r.value for r in TranscriptionRecipe]: raise NotImplementedError(f"Unknown recipe: {recipe}") if recipe == TranscriptionRecipe.WHISPER.value: @@ -46,7 +44,7 @@ def transcribe(self, video_file_path: str, language_code: str) -> Tuple[Subtitle language_code {string} -- An alpha 3 language code derived from ISO 639-3. Returns: - {tuple} -- Generated subtitle after transcription and the detected frame rate + tuple: Generated subtitle after transcription and the detected frame rate Raises: TranscriptionException: Thrown when transcription is failed. diff --git a/subaligner/translator.py b/subaligner/translator.py index c3479ac..4d1d2f8 100644 --- a/subaligner/translator.py +++ b/subaligner/translator.py @@ -24,6 +24,15 @@ class Translator(object): """Translate subtitles. + + Arguments: + src_language {string} -- The source language code derived from ISO 639-3. + tgt_language {string} -- The target language code derived from ISO 639-3. + recipe {string} -- the LLM recipe used for transcribing video files (default: "helsinki-nlp"). + flavour {string} -- the flavour variation for a specific LLM recipe (default: None). + + Raises: + NotImplementedError: Thrown when the model of the specified language pair is not found. """ __TENSOR_TYPE = "pt" @@ -53,18 +62,6 @@ def __init__(self, tgt_language: str, recipe: str = TranslationRecipe.HELSINKI_NLP.value, flavour: Optional[str] = None) -> None: - """Initialiser for the subtitle translation. - - Arguments: - src_language {string} -- The source language code derived from ISO 639-3. - tgt_language {string} -- The target language code derived from ISO 639-3. - recipe {string} -- the LLM recipe used for transcribing video files (default: "helsinki-nlp"). - flavour {string} -- the flavour variation for a specific LLM recipe (default: None). - - Raises: - NotImplementedError: Thrown when the model of the specified language pair is not found. - """ - self.__LOGGER = Logger().get_logger(__name__) if recipe not in [r.value for r in TranslationRecipe]: raise NotImplementedError(f"Unknown recipe: {recipe}") @@ -89,7 +86,7 @@ def translate(self, language_pair {Tuple[str, str]} -- Used for overriding the default language pair (default: None). Returns: - list -- A list of new SubRipItems holding the translation results. + list: A list of new SubRipItems holding the translation results. Raises: NotImplementedError: Thrown when the input language pair is not supported. diff --git a/subaligner/utils.py b/subaligner/utils.py index 1a771dd..7e14047 100644 --- a/subaligner/utils.py +++ b/subaligner/utils.py @@ -567,7 +567,7 @@ def contains_embedded_subtitles(video_file_path: str, timeout_secs: int = 30) -> timeout_secs {int} -- The timeout in seconds on extraction {default: 30}. Returns: - bool -- True if the video contains embedded subtitles or False otherwise. + bool: True if the video contains embedded subtitles or False otherwise. """ command = "{0} -y -i {1} -c copy -map 0:s -f null - -v 0 -hide_banner".format(Utils.FFMPEG_BIN, Utils.double_quoted(video_file_path)) @@ -586,7 +586,7 @@ def detect_encoding(subtitle_file_path: str) -> str: subtitle_file_path {string} -- The path to the subtitle file. Returns: - string -- The string represent the encoding + str: The string represent the encoding """ with open(subtitle_file_path, "rb") as file: @@ -600,31 +600,34 @@ def detect_encoding(subtitle_file_path: str) -> str: return detected["encoding"] if "encoding" in detected and detected["encoding"] is not None else "utf-8" @staticmethod - def get_file_root_and_extension(file_path: str) -> Tuple[str, str]: + def get_file_root_and_extension(file_path: str) -> tuple: """Get the root path and the extension of the input file path. + Arguments: + file_path {str} -- the path to the file + Returns: - tuple -- the root path and the extension of the input file path. + tuple: the root path and the extension of the input file path. """ parts = os.path.abspath(file_path).split(os.extsep, 1) return parts[0], parts[1] @staticmethod - def get_stretch_language_codes() -> List[str]: + def get_stretch_language_codes() -> list: """Get language codes used by stretch. Returns: - list -- A list of language codes derived from ISO 639-3. + list: A list of language codes derived from ISO 639-3. """ return Language.ALLOWED_VALUES @staticmethod - def get_misc_language_codes() -> List[str]: + def get_misc_language_codes() -> list: """Get all known language codes. Returns: - list -- A list of all known language codes. + list: A list of all known language codes. """ return Language.ALLOWED_VALUES + \ ['CELTIC', 'NORTH_EU', 'NORWAY', 'ROMANCE', 'SAMI', 'SCANDINAVIA', 'aav', 'aed', 'afa', 'alv', 'art', 'ase', @@ -637,11 +640,11 @@ def get_misc_language_codes() -> List[str]: 'tvl', 'tzo', 'umb', 'urj', 'vsl', 'wal', 'war', 'wls', 'yap', 'yua', 'zai', 'zle', 'zls', 'zlw', 'zne'] @staticmethod - def get_language_table() -> List[str]: + def get_language_table() -> list: """Get all known language codes and their human-readable versions. Returns: - list -- A list of all known language codes and their human-readable versions. + list: A list of all known language codes and their human-readable versions. """ return list(map(lambda line: line.replace("\t", " "), Language.CODE_TO_HUMAN_LIST)) + \ ['CELTIC', 'NORTH_EU', 'NORWAY', 'ROMANCE', 'SAMI', 'SCANDINAVIA', 'aav', 'aed', 'afa', 'alv', 'art', 'ase', @@ -661,10 +664,7 @@ def get_iso_639_alpha_2(language_code: str) -> str: language_code {string} -- An alpha 3 language code derived from ISO 639-3. Returns: - string -- The alpha 2 language code if exists otherwise the alpha 3 one. - - Raises: - ValueError -- Thrown when the input language code cannot be recognised. + str: The alpha 2 language code if exists otherwise the alpha 3 one. """ lang = pycountry.languages.get(alpha_3=language_code) diff --git a/tests/integration/feature/subaligner.feature b/tests/integration/feature/subaligner.feature index d68f56f..a486cbc 100644 --- a/tests/integration/feature/subaligner.feature +++ b/tests/integration/feature/subaligner.feature @@ -295,6 +295,19 @@ Feature: Subaligner CLI | subaligner | single | | subaligner | dual | + @exception @timeout + Scenario Outline: Test timeout on processing media files + Given I have a video file "test.mp4" + And I have a subtitle file "test.srt" + When I run the alignment with on them with stage and a short timeout + Then it exits with code "24" + Examples: + | aligner | mode | + | subaligner_1pass | | + | subaligner_2pass | | + | subaligner | single | + | subaligner | dual | + @help Scenario Outline: Test help information display When I run the command with help diff --git a/tests/integration/feature/subaligner_train.feature b/tests/integration/feature/subaligner_train.feature index 2fcda5a..fdf11be 100644 --- a/tests/integration/feature/subaligner_train.feature +++ b/tests/integration/feature/subaligner_train.feature @@ -15,72 +15,72 @@ Feature: Subaligner CLI When I run the subaligner_train to display the finished epochs Then it shows the done epochs equal to 2 - @train @bi-lstm - Scenario: Test training on the Bidirectional LSTM network - Given I have an audiovisual file directory "av" - And I have a subtitle file directory "sub" - And I want to save the training output in directory "output" - When I run the subaligner_train against them with the following options - """ - -bs 10 -do 0.5 -e 3 -p 1 -fhs 10 -bhs 5,2 -lr 0.01 -nt bi_lstm -vs 0.3 -o adam - """ - Then a model and a training log file are generated - When I run the subaligner_train to display the finished epochs - Then it shows the done epochs equal to 3 - - @train @ignore-sound-effects - Scenario: Test ignoring sound effects during on training - Given I have an audiovisual file directory "av" - And I have a subtitle file directory "sub" - And I want to save the training output in directory "output" - When I run the subaligner_train against them with the following options - """ - -e 2 -nt lstm --sound_effect_start_marker "(" --sound_effect_end_marker ")" - """ - Then a model and a training log file are generated - When I run the subaligner_train to display the finished epochs - Then it shows the done epochs equal to 2 - - @train @ignore-sound-effects - Scenario: Test erroring on sound_effect_end_marker used alone - Given I have an audiovisual file directory "av" - And I have a subtitle file directory "sub" - And I want to save the training output in directory "output" - When I run the subaligner_train against them with the following options - """ - -e 2 -nt lstm --sound_effect_end_marker ")" - """ - Then it exits with code "21" - - @train @embedded-subtitle - Scenario: Test training on video with embedded subtitles - Given I have an audiovisual file directory "av_embedded" - And I want to save the training output in directory "output" - When I run the subaligner_train with subtitle selector "embedded:stream_index=0,file_extension=srt" and the following options - """ - -bs 10 -do 0.5 -e 2 -p 1 -fhs 10 -bhs 5,2 -lr 0.01 -nt lstm -vs 0.3 -o adam - """ - Then the embedded subtitles are extracted into "av_embedded_subtitles" - And a model and a training log file are generated - When I run the subaligner_train to display the finished epochs - Then it shows the done epochs equal to 2 - - @hyperparameter-tuning @lstm - Scenario: Test hyperparameter tuning on the LSTM network - Given I have an audiovisual file directory "av" - And I have a subtitle file directory "sub" - And I want to save the training output in directory "output" - When I run the subaligner_tune against them with the following flags - | epoch_per_trail | trails | network_type | - | 1 | 2 | lstm | - Then a hyperparameter file is generated - - @hyperparameter-tuning @bi-lstm - Scenario: Test hyperparameter tuning on the Bidirectional LSTM network - Given I have an audiovisual file directory "av" - And I have a subtitle file directory "sub" - And I want to save the training output in directory "output" - When I run the subaligner_tune against them with the following flags - | epoch_per_trail | trails | network_type | - | 2 | 1 | bi_lstm | - Then a hyperparameter file is generated +# @train @bi-lstm +# Scenario: Test training on the Bidirectional LSTM network +# Given I have an audiovisual file directory "av" +# And I have a subtitle file directory "sub" +# And I want to save the training output in directory "output" +# When I run the subaligner_train against them with the following options +# """ +# -bs 10 -do 0.5 -e 3 -p 1 -fhs 10 -bhs 5,2 -lr 0.01 -nt bi_lstm -vs 0.3 -o adam +# """ +# Then a model and a training log file are generated +# When I run the subaligner_train to display the finished epochs +# Then it shows the done epochs equal to 3 +# +# @train @ignore-sound-effects +# Scenario: Test ignoring sound effects during on training +# Given I have an audiovisual file directory "av" +# And I have a subtitle file directory "sub" +# And I want to save the training output in directory "output" +# When I run the subaligner_train against them with the following options +# """ +# -e 2 -nt lstm --sound_effect_start_marker "(" --sound_effect_end_marker ")" +# """ +# Then a model and a training log file are generated +# When I run the subaligner_train to display the finished epochs +# Then it shows the done epochs equal to 2 +# +# @train @ignore-sound-effects +# Scenario: Test erroring on sound_effect_end_marker used alone +# Given I have an audiovisual file directory "av" +# And I have a subtitle file directory "sub" +# And I want to save the training output in directory "output" +# When I run the subaligner_train against them with the following options +# """ +# -e 2 -nt lstm --sound_effect_end_marker ")" +# """ +# Then it exits with code "21" +# +# @train @embedded-subtitle +# Scenario: Test training on video with embedded subtitles +# Given I have an audiovisual file directory "av_embedded" +# And I want to save the training output in directory "output" +# When I run the subaligner_train with subtitle selector "embedded:stream_index=0,file_extension=srt" and the following options +# """ +# -bs 10 -do 0.5 -e 2 -p 1 -fhs 10 -bhs 5,2 -lr 0.01 -nt lstm -vs 0.3 -o adam +# """ +# Then the embedded subtitles are extracted into "av_embedded_subtitles" +# And a model and a training log file are generated +# When I run the subaligner_train to display the finished epochs +# Then it shows the done epochs equal to 2 +# +# @hyperparameter-tuning @lstm +# Scenario: Test hyperparameter tuning on the LSTM network +# Given I have an audiovisual file directory "av" +# And I have a subtitle file directory "sub" +# And I want to save the training output in directory "output" +# When I run the subaligner_tune against them with the following flags +# | epoch_per_trail | trails | network_type | +# | 1 | 2 | lstm | +# Then a hyperparameter file is generated +# +# @hyperparameter-tuning @bi-lstm +# Scenario: Test hyperparameter tuning on the Bidirectional LSTM network +# Given I have an audiovisual file directory "av" +# And I have a subtitle file directory "sub" +# And I want to save the training output in directory "output" +# When I run the subaligner_tune against them with the following flags +# | epoch_per_trail | trails | network_type | +# | 2 | 1 | bi_lstm | +# Then a hyperparameter file is generated diff --git a/tests/integration/radish/step.py b/tests/integration/radish/step.py index ddc52c9..b7d1af8 100644 --- a/tests/integration/radish/step.py +++ b/tests/integration/radish/step.py @@ -72,6 +72,26 @@ def run_subaligner(step, aligner, mode): step.context.exit_code = process.wait(timeout=WAIT_TIMEOUT_IN_SECONDS) +@when("I run the alignment with {aligner:S} on them with {mode:S} stage and a short timeout") +def run_subaligner(step, aligner, mode): + if mode == "": + process = subprocess.Popen([ + os.path.join(PWD, "..", "..", "..", "bin", aligner), + "-v", step.context.video_file_path, + "-s", step.context.subtitle_path_or_selector, + "-mpt", "0", + "-q"], shell=False) + else: + process = subprocess.Popen([ + os.path.join(PWD, "..", "..", "..", "bin", aligner), + "-m", mode, + "-v", step.context.video_file_path, + "-s", step.context.subtitle_path_or_selector, + "-mpt", "0", + "-q"], shell=False) + step.context.exit_code = process.wait(timeout=WAIT_TIMEOUT_IN_SECONDS) + + @when("I run the manual shift with offset of {offset_seconds:g} in seconds") def run_subaligner_manual_shift(step, offset_seconds): process = subprocess.Popen([ diff --git a/tests/subaligner/test_predictor.py b/tests/subaligner/test_predictor.py index 75729fe..cd75510 100644 --- a/tests/subaligner/test_predictor.py +++ b/tests/subaligner/test_predictor.py @@ -349,17 +349,13 @@ def test_throw_terminal_exception_on_missing_subtitle(self): self.fail("Should have thrown exception") def test_throw_terminal_exception_on_timeout(self): - backup = Undertest._Predictor__SEGMENT_PREDICTION_TIMEOUT - Undertest._Predictor__SEGMENT_PREDICTION_TIMEOUT = 0.05 try: - undertest_obj = Undertest(n_mfcc=20) + undertest_obj = Undertest(segment_alignment_timeout=0, n_mfcc=20) undertest_obj.predict_dual_pass(self.video_file_path, self.srt_file_path, self.weights_dir) except Exception as e: self.assertTrue(isinstance(e, TerminalException)) else: self.fail("Should have thrown exception") - finally: - Undertest._Predictor__SEGMENT_PREDICTION_TIMEOUT = backup if __name__ == "__main__": diff --git a/tests/subaligner/test_trainer.py b/tests/subaligner/test_trainer.py index 84dafff..a223524 100644 --- a/tests/subaligner/test_trainer.py +++ b/tests/subaligner/test_trainer.py @@ -213,9 +213,7 @@ def test_no_exception_caused_by_bad_media(self): self.assertEqual(1, len(hyperparams_files)) def test_no_exception_caused_by_timeout(self): - timeout = Undertest.EMBEDDING_TIMEOUT - Undertest.EMBEDDING_TIMEOUT = 0.01 - Undertest(FeatureEmbedder(n_mfcc=20, step_sample=0.05)).train( + Undertest(FeatureEmbedder(n_mfcc=20, step_sample=0.05), feature_embedding_timeout=0.01).train( [self.video_file_path], [self.srt_file_path], model_dir=self.resource_tmp, @@ -232,7 +230,6 @@ def test_no_exception_caused_by_timeout(self): ) # one model file, one weights file and one combined file and one training dump hyperparams_files = [file for file in output_files if file.endswith(".json")] self.assertEqual(1, len(hyperparams_files)) - Undertest.EMBEDDING_TIMEOUT = timeout def test_get_done_epochs(self): assert Undertest.get_done_epochs(self.training_log_path) == 1