Skip to content

Commit

Permalink
add docs linting
Browse files Browse the repository at this point in the history
  • Loading branch information
baxtree committed May 10, 2024
1 parent 658cfdd commit f7f55ba
Show file tree
Hide file tree
Showing 13 changed files with 111 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci-pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ jobs:
- name: Linting
run: |
pycodestyle subaligner tests examples misc bin/subaligner bin/subaligner_1pass bin/subaligner_2pass bin/subaligner_batch bin/subaligner_convert bin/subaligner_train bin/subaligner_tune setup.py --ignore=E203,E501,W503 --exclude="subaligner/lib"
- name: Linting docstring
run: |
darglint -v 2 subaligner
- name: Unit tests and coverage
run: |
coverage run -m unittest discover
Expand Down
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ types-setuptools==57.4.9
typing-extensions==4.5.0
parameterized==0.8.1
pylint~=2.17.2
pygments==2.7.4
pygments==2.7.4
darglint~=1.8.1
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Markdown==2.6.11
mccabe==0.6.1
networkx>=2.5.1
numba>=0.50.0
numpy<1.24.0
numpy<1.27.0
oauthlib==3.1.0
pbr==4.0.2
pkgconfig~=1.5.5
Expand All @@ -49,7 +49,7 @@ pystack-debugger==0.8.0
pytz==2018.4
PyYAML>=4.2b1
rsa==4.7
scipy<1.11.0
scipy<1.12.0
scikit-learn<1.2.0
six~=1.15.0
tblib==1.3.2
Expand Down
3 changes: 3 additions & 0 deletions subaligner/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def extract_data_and_label_from_audio(
Returns:
tuple -- The training data and the training lables.
Raises:
TerminalException: Thrown when the subtitles are missing.
"""

len_mfcc = self.get_len_mfcc()
Expand Down
6 changes: 5 additions & 1 deletion subaligner/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def __init__(self) -> None:
self.__network_type = "lstm"

def __eq__(self, other: Any) -> bool:
"""Comparator for Hyperparameters objects"""
"""Comparator for Hyperparameters objects
Returns:
bool -- If True, the compared hyperparameter object is the same
"""

if isinstance(other, Hyperparameters):
return all([
Expand Down
17 changes: 17 additions & 0 deletions subaligner/media_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,17 @@ def extract_audio(self, video_file_path, decompress: bool = False, freq: int = 1
Arguments:
video_file_path {string} -- The input video file path.
Keyword Arguments:
decompress {bool} -- Extract WAV if True otherwise extract AAC (default: {False}).
freq {int} -- The audio sample frequency (default: {16000}).
Returns:
string -- The file path of the extracted audio.
Raises:
TerminalException: If audio extraction is interrupted by user hitting the interrupt key or timed out.
Exception: Thrown when any other exceptions occur.
"""

basename = os.path.basename(video_file_path)
Expand Down Expand Up @@ -173,6 +179,10 @@ def extract_audio_from_start_to_end(self, audio_file_path: str, start: str, end:
Returns:
tuple -- The file path to the extracted audio and its duration.
Raises:
TerminalException: If audio extraction is interrupted by user hitting the interrupt key or timed out.
Exception: Thrown when any other exceptions occur.
"""
segment_duration = self.get_duration_in_seconds(start, end)
basename = os.path.basename(audio_file_path)
Expand Down Expand Up @@ -311,8 +321,14 @@ def get_frame_rate(self, file_path: str) -> float:
Arguments:
file_path {string} -- The input audiovisual file path.
Returns:
float -- The frame rate
Raises:
TerminalException: If frame rate extraction is interrupted by user hitting the interrupt key or timed out.
NoFrameRateException: If no frame rate is detected on the input audiovisual file.
Exception: Thrown when any other exceptions occur.
"""

discarded = "NUL:" if os.name == "nt" else "/dev/null"
Expand Down Expand Up @@ -378,6 +394,7 @@ def refragment_with_min_duration(self, subs: List[SubRipItem], minimum_segment_d
Arguments:
subs {list} -- A list of SupRip cues.
minimum_segment_duration {float} -- The minimum duration in seconds for each output subtitle cue.
Returns:
list -- A list of new SupRip cues after fragmentation.
"""
Expand Down
17 changes: 15 additions & 2 deletions subaligner/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
model_path: Optional[str] = None,
backend: str = "tensorflow"
) -> None:
""" Network object initialiser used by factory methods.
"""Network object initialiser used by factory methods.
Arguments:
secret {object} -- A hash only known by factory methods.
Expand All @@ -65,8 +65,9 @@ def __init__(
model_path {string} -- The path to the model file.
backend {string} -- The tensor manipulation backend (default: {tensorflow}). Only tensorflow is supported
by TF 2 and this parameter is here only for a historical reason.
Raises:
NotImplementedError -- Thrown when any network attributes are modified.
ValueError: Thrown when the network type is not supported.
"""
assert (
secret == Network.__secret
Expand Down Expand Up @@ -124,6 +125,9 @@ def get_from_model(cls, model_path: str, hyperparameters: Hyperparameters) -> "N
Arguments:
model_path {string} -- The path to the model file.
hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training.
Returns:
Network: The model network.
"""

hp = hyperparameters.clone()
Expand Down Expand Up @@ -236,8 +240,12 @@ def fit_and_get_history(
logs_dir {string} -- The TensorBoard log file directory.
training_log {string} -- The path to the log file of epoch results.
resume {bool} -- True to continue with previous training result or False to start a new one (default: {False}).
Returns:
tuple -- A tuple contains validation losses and validation accuracies.
Raises:
TerminalException: If the predication is interrupted by user hitting the interrupt key
"""

csv_logger = (
Expand Down Expand Up @@ -324,8 +332,12 @@ def fit_with_generator(
logs_dir {string} -- The TensorBoard log file directory.
training_log {string} -- The path to the log file of epoch results.
resume {bool} -- True to continue with previous training result or False to start a new one (default: {False}).
Returns:
tuple -- A tuple contains validation losses and validation accuracies.
Raises:
TerminalException: If the training is interrupted by user hitting the interrupt key
"""

initial_epoch = 0
Expand Down Expand Up @@ -455,6 +467,7 @@ def simple_fit_with_generator(
train_data_raw {list} -- The HDF5 raw training data.
labels_raw {list} -- The HDF5 raw training labels.
hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training.
Returns:
tuple -- A tuple contains validation losses and validation accuracies.
"""
Expand Down
47 changes: 34 additions & 13 deletions subaligner/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ def predict_dual_pass(
"""Predict time to shift with single pass
Arguments:
video_file_path {string} -- The input video file path.
subtitle_file_path {string} -- The path to the subtitle file.
weights_dir {string} -- The the model weights directory.
stretch {bool} -- True to stretch the subtitle segments (default: {False})
stretch_in_lang {str} -- The language used for stretching subtitles (default: {"eng"}).
exit_segfail {bool} -- True to exit on any segment alignment failures (default: {False})
video_file_path {string} -- The input video file path.
subtitle_file_path {string} -- The path to the subtitle file.
weights_dir {string} -- The the model weights directory.
stretch {bool} -- True to stretch the subtitle segments (default: {False})
stretch_in_lang {str} -- The language used for stretching subtitles (default: {"eng"}).
exit_segfail {bool} -- True to exit on any segment alignment failures (default: {False})
Returns:
tuple -- The shifted subtitles, the globally shifted subtitles and the voice probabilities of the original audio.
tuple -- The shifted subtitles, the globally shifted subtitles and the voice probabilities of the original audio.
"""

weights_file_path = self.__get_weights_path(weights_dir)
Expand Down Expand Up @@ -143,12 +143,15 @@ def predict_plain_text(self, video_file_path: str, subtitle_file_path: str, stre
"""Predict time to shift with plain texts
Arguments:
video_file_path {string} -- The input video file path.
subtitle_file_path {string} -- The path to the subtitle file.
stretch_in_lang {str} -- The language used for stretching subtitles (default: {"eng"}).
video_file_path {string} -- The input video file path.
subtitle_file_path {string} -- The path to the subtitle file.
stretch_in_lang {str} -- The language used for stretching subtitles (default: {"eng"}).
Returns:
tuple -- The shifted subtitles, the audio file path (None) and the voice probabilities of the original audio (None).
tuple -- The shifted subtitles, the audio file path (None) and the voice probabilities of the original audio (None).
Raises:
TerminalException: If the predication is interrupted by user hitting the interrupt key.
"""
from aeneas.executetask import ExecuteTask
from aeneas.task import Task
Expand Down Expand Up @@ -223,8 +226,11 @@ def get_log_loss(self, voice_probabilities: np.ndarray, subs: List[SubRipItem])
voice_probabilities {list} -- A list of probabilities of audio chunks being speech.
subs {list} -- A list of subtitle segments.
Returns:
float -- The loss value.
Returns:
float -- The loss value.
Raises:
TerminalException: If the subtitle mask is empty.
"""

subtitle_mask = Predictor.__get_subtitle_mask(self, subs)
Expand Down Expand Up @@ -258,8 +264,12 @@ def get_min_log_loss_and_index(self, voice_probabilities: np.ndarray, subs: SubR
Arguments:
voice_probabilities {list} -- A list of probabilities of audio chunks being speech.
subs {list} -- A list of subtitle segments.
Returns:
tuple -- The minimum loss value and its position.
Raises:
TerminalException: If subtitle is empty or suspicious audio/subtitle duration is detected.
"""

local_subs = deepcopy(subs)
Expand Down Expand Up @@ -526,6 +536,13 @@ def __predict_2nd_pass(self, audio_file_path: str, subs: List[SubRipItem], weigh
stretch {bool} -- True to stretch the subtitle segments.
stretch_in_lang {str} -- The language used for stretching subtitles.
exit_segfail {bool} -- True to exit on any segment alignment failures.
Returns:
list -- A list of aligned SubRip files
Raises:
TerminalException: If the alignment is interrupted by user hitting the interrupt key or times out
Exception: Thrown when any other exceptions occur.
"""

segment_starts, segment_ends, subs = self.__media_helper.get_audio_segment_starts_and_ends(subs)
Expand Down Expand Up @@ -726,6 +743,10 @@ def __predict(
Returns:
tuple -- The shifted subtitles, the audio file path and the voice probabilities of the original audio.
Raises:
TerminalException: If the prediction failed on invalid input or on other exceptions.
ValueError: Thrown when no subtitle is passed in.
"""
if network is None:
network = self.__initialise_network(os.path.dirname(weights_file_path), self.__LOGGER)
Expand Down
5 changes: 4 additions & 1 deletion subaligner/subtitle.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, secret: object, subtitle_file_path: str, subtitle_format: str
format {string} -- Supported subtitle formats: subrip and ttml.
Raises:
NotImplementedError -- Thrown when any subtitle attributes are modified.
UnsupportedFormatException: Thrown when the input subtitle format is not supported or no subtitle content is found.
"""

assert (
Expand Down Expand Up @@ -337,6 +337,9 @@ def shift_subtitle(
Returns:
string -- The path to the shifted subtitle file.
Raises:
UnsupportedFormatException: Thrown when the input subtitle format is not supported.
"""
_, file_extension = os.path.splitext(subtitle_file_path)
if shifted_subtitle_file_path is None:
Expand Down
9 changes: 9 additions & 0 deletions subaligner/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ def pre_train(
hyperparameters {Hyperparameters} -- A configuration for hyperparameters used for training.
sound_effect_start_marker: {string} -- A string indicating the start of the ignored sound effect (default: {"("}).
sound_effect_end_marker: {string} -- A string indicating the end of the ignored sound effect (default: {")"}).
Returns:
tuple -- The valuation loss and accuracy.
"""

training_dump = os.path.join(os.path.abspath(training_dump_dir), "training_dump.hdf5")
Expand Down Expand Up @@ -234,6 +237,9 @@ def get_done_epochs(training_log: str) -> int:
Arguments:
training_log {string} -- The path to the training log file.
Returns:
int -- The number of finished epochs.
"""
if not os.path.isfile(training_log):
return 0
Expand All @@ -260,6 +266,9 @@ def __extract_data_and_label_from_avs(
Returns:
tuple -- The training data and labels.
Raises:
TerminalException: If the extraction is interrupted by user hitting the interrupt key.
"""

train_data, labels = (
Expand Down
11 changes: 8 additions & 3 deletions subaligner/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def __init__(self, recipe: str = TranscriptionRecipe.WHISPER.value, flavour: str
Arguments:
recipe {string} -- the LLM recipe used for transcribing video files (default: "whisper").
flavour {string} -- the flavour variation for a specific LLM recipe (default: "small").
Raises:
NotImplementedError -- Thrown when the LLM recipe is unknown.
NotImplementedError: Thrown when the LLM recipe is unknown.
"""
if recipe not in [r.value for r in TranscriptionRecipe]:
raise NotImplementedError(f"Unknown recipe: {recipe}")
Expand All @@ -43,9 +44,13 @@ def transcribe(self, video_file_path: str, language_code: str) -> Tuple[Subtitle
Arguments:
video_file_path {string} -- The input video file path.
language_code {string} -- An alpha 3 language code derived from ISO 639-3.
Returns:
{tuple} -- Generated subtitle after transcription and the detected frame rate
Raises:
TranscriptionException -- Thrown when transcription is failed.
NotImplementedError -- Thrown when the LLM recipe is not supported.
TranscriptionException: Thrown when transcription is failed.
NotImplementedError: Thrown when the LLM recipe is not supported.
"""
if self.__recipe == "whisper":
lang = Utils.get_iso_639_alpha_2(language_code)
Expand Down
8 changes: 6 additions & 2 deletions subaligner/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self,
flavour {string} -- the flavour variation for a specific LLM recipe (default: None).
Raises:
NotImplementedError -- Thrown when the model of the specified language pair is not found.
NotImplementedError: Thrown when the model of the specified language pair is not found.
"""

self.__LOGGER = Logger().get_logger(__name__)
Expand All @@ -89,7 +89,11 @@ def translate(self,
language_pair {Tuple[str, str]} -- Used for overriding the default language pair (default: None).
Returns:
{list} -- A list of new SubRipItems holding the translation results.
list -- A list of new SubRipItems holding the translation results.
Raises:
NotImplementedError: Thrown when the input language pair is not supported.
TranslationException: Thrown when the source or the target language is not supported.
"""

if self.__recipe == TranslationRecipe.HELSINKI_NLP.value:
Expand Down
3 changes: 3 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ envlist =
skipsdist=True
skip_missing_interpreters = True

[darglint]
ignore=DAR101

[testenv:py36]
basepython = python3.6
whitelist_externals = /bin/bash
Expand Down

0 comments on commit f7f55ba

Please sign in to comment.