From ab832fc3736ee0c8d30b0adb0e10d88fb8bf0088 Mon Sep 17 00:00:00 2001 From: Hendrik Schreiber Date: Thu, 17 Oct 2024 17:04:14 +0200 Subject: [PATCH] Added type hints. --- .github/workflows/python-package.yml | 5 ++- CHANGES.rst | 1 + tempocnn/classifier.py | 49 +++++++++++++++----------- tempocnn/commands.py | 51 +++++++++++++++------------- tempocnn/feature.py | 23 ++++++++++--- 5 files changed, 81 insertions(+), 48 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 07bbb10..cab6241 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -34,12 +34,15 @@ jobs: sudo apt-get install libsndfile1 sudo apt-get install ffmpeg python -m pip install --upgrade pip setuptools wheel - pip install ruff pytest + pip install ruff pytest mypy pip install .[testing] - name: Lint with ruff run: | ruff check tempocnn test ruff format --check tempocnn test + - name: Type check with mypy + run: | + mypy --ignore-missing-imports --check-untyped-defs tempocnn test - name: Test with pytest run: | coverage run --source ./tempocnn -m pytest --verbose --junitxml=pytest_report${{ matrix.python-version }}.xml diff --git a/CHANGES.rst b/CHANGES.rst index 63ca0f2..b8bff02 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,7 @@ Changes - Moved to TensorFlow 2.17.0 and Python 3.9/3.10/3.11. - Made local cache version dependent. - Migrated code to Pathlib. + - Added type hints. 0.0.7: - Added DOIs to bibtex entries. diff --git a/tempocnn/classifier.py b/tempocnn/classifier.py index 4e63fc0..0133147 100644 --- a/tempocnn/classifier.py +++ b/tempocnn/classifier.py @@ -3,6 +3,7 @@ import sys import urllib.request from pathlib import Path +from typing import Optional from urllib.error import HTTPError import numpy as np @@ -12,7 +13,7 @@ logger = logging.getLogger("tempocnn.classifier") -def std_normalizer(data): +def std_normalizer(data: np.ndarray) -> np.ndarray: """ Normalizes data to zero mean and unit variance. Used by Mazurka models. @@ -29,7 +30,7 @@ def std_normalizer(data): return data.astype(np.float16) -def max_normalizer(data): +def max_normalizer(data: np.ndarray) -> np.ndarray: """ Divides by max. Used as normalization by older models. @@ -47,7 +48,7 @@ class TempoClassifier: Classifier that can estimate musical tempo in different formats. """ - def __init__(self, model_name="fcn"): + def __init__(self, model_name: str = "fcn"): """ Initializes this classifier with a Keras model. @@ -94,7 +95,7 @@ def __init__(self, model_name="fcn"): logger.debug(f"Loading model {model_name} from {file}") self.model = load_model(file, compile=False) - def estimate(self, data): + def estimate(self, data: np.ndarray) -> np.ndarray: """ Estimate a tempo distribution. Probabilities are indexed, starting with 30 BPM and ending with 286 BPM. @@ -118,7 +119,9 @@ def estimate(self, data): return self.model.predict(norm_data, norm_data.shape[0]) @staticmethod - def quad_interpol_argmax(y, x=None): + def quad_interpol_argmax( + y: np.ndarray, x: Optional[int] = None + ) -> tuple[float, float]: """ Find argmax for quadratic interpolation around argmax of y. @@ -127,16 +130,16 @@ def quad_interpol_argmax(y, x=None): :return: float (index) of interpolated max, strength """ if x is None: - x = np.argmax(y) + x = int(np.argmax(y)) if x == 0 or x == y.shape[0] - 1: - return x, y[x] + return float(x), float(y[x]) z = np.polyfit([x - 1, x, x + 1], [y[x - 1], y[x], y[x + 1]], 2) # find (float) x value for max argmax = -z[1] / (2.0 * z[0]) height = z[2] - (z[1] ** 2.0) / (4.0 * z[0]) - return argmax, height + return float(argmax), float(height) - def estimate_tempo(self, data, interpolate=False): + def estimate_tempo(self, data: np.ndarray, interpolate: bool = False) -> float: """ Estimates the pre-dominant global tempo. @@ -150,10 +153,12 @@ def estimate_tempo(self, data, interpolate=False): if interpolate: index, _ = self.quad_interpol_argmax(averaged_prediction) else: - index = np.argmax(averaged_prediction) + index = int(np.argmax(averaged_prediction)) return self.to_bpm(index) - def estimate_mirex(self, data, interpolate=False): + def estimate_mirex( + self, data: np.ndarray, interpolate: bool = False + ) -> tuple[float, float, float]: """ Estimates the two dominant tempi along with a salience value. @@ -165,8 +170,8 @@ def estimate_mirex(self, data, interpolate=False): prediction = self.estimate(data) - def find_index_peaks(distribution): - p = [] + def find_index_peaks(distribution: np.ndarray) -> list[tuple[float, float]]: + p: list[tuple[float, float]] = [] last_index = 0 for index in range(256): height = distribution[index] @@ -181,7 +186,7 @@ def find_index_peaks(distribution): ) = self.quad_interpol_argmax(distribution, x=index) p.append((interpolated_index, interpolated_height)) else: - p.append((index, height)) + p.append((float(index), float(height))) last_index = index # sort peaks by height, descending return sorted(p, key=lambda element: element[1], reverse=True) @@ -189,6 +194,10 @@ def find_index_peaks(distribution): averaged_prediction = np.average(prediction, axis=0) peaks = find_index_peaks(averaged_prediction) + s1: float + t1: float + t2: float + if len(peaks) == 0: s1 = 1.0 t1 = 0.0 @@ -226,7 +235,7 @@ class MeterClassifier: Classifier that can estimate musical meter """ - def __init__(self, model_name="fcn"): + def __init__(self, model_name: str = "fcn"): """ Initializes this classifier with a Keras model. @@ -254,7 +263,7 @@ def __init__(self, model_name="fcn"): raise e self.model = load_model(file) - def estimate(self, data): + def estimate(self, data: np.ndarray) -> np.ndarray: """ Estimate a meter distribution. Probabilities are indexed, starting with 2. Only the meter numerator is given (e.g. 2 for 2/4). @@ -277,7 +286,7 @@ def estimate(self, data): norm_data = self.normalize(data) return self.model.predict(norm_data, norm_data.shape[0]) - def estimate_meter(self, data): + def estimate_meter(self, data: np.ndarray) -> int: """ Estimates the pre-dominant global meter. @@ -290,7 +299,7 @@ def estimate_meter(self, data): return self._to_meter(index) -def _to_model_resource(model_name): +def _to_model_resource(model_name: str) -> str: file = model_name if not model_name.endswith(".h5"): file = file + ".h5" @@ -299,7 +308,7 @@ def _to_model_resource(model_name): return file -def _extract_from_package(resource): +def _extract_from_package(resource: str) -> str: # check local cache cache_path = Path(Path.home(), ".tempocnn", package_version, resource) if cache_path.exists(): @@ -326,7 +335,7 @@ def _extract_from_package(resource): return str(cache_path) -def _load_model_from_github(resource): +def _load_model_from_github(resource: str): url = f"https://raw.githubusercontent.com/hendriks73/tempo-cnn/main/tempocnn/{resource}" logger.info(f"Attempting to download model file from main branch {url}") try: diff --git a/tempocnn/commands.py b/tempocnn/commands.py index 723eccf..a7f8a48 100644 --- a/tempocnn/commands.py +++ b/tempocnn/commands.py @@ -1,6 +1,7 @@ import argparse import sys from pathlib import Path +from typing import Union, Optional import jams import librosa @@ -98,10 +99,10 @@ def tempo(): output_format = parser.add_mutually_exclusive_group() output_format.add_argument( - "--mirex", help="use MIREX format for output", action="store_true" + "--mirex", help="use MIREX format for output", action="store_true", type=bool ) output_format.add_argument( - "--jams", help="use JAMS format for output", action="store_true" + "--jams", help="use JAMS format for output", action="store_true", type=bool ) parser.add_argument( @@ -132,6 +133,7 @@ def tempo(): create_jam = args.jams create_mirex = args.mirex + result: Union[str, jams.JAMS] if create_mirex or create_jam: t1, t2, s1 = classifier.estimate_mirex( features, interpolate=args.interpolate @@ -169,14 +171,14 @@ def tempo(): def _write_tempo_result( - result, - input_file=None, - output_dir=None, - output_list=None, - index=0, - append_extension=None, - replace_extension=None, - create_jam=False, + result: Union[str, jams.JAMS], + input_file: str, + output_dir: Optional[str] = None, + output_list: Optional[list[str]] = None, + index: int = 0, + append_extension: Optional[str] = None, + replace_extension: Optional[str] = None, + create_jam: bool = False, ): """ Write the tempo analysis results to a file. @@ -207,7 +209,7 @@ def _write_tempo_result( output_file = Path(output_list[index]) # actually writing the output - if create_jam: + if create_jam and isinstance(result, jams.JAMS): result.save(str(output_file)) elif output_file is None: print("\n" + result) @@ -216,7 +218,9 @@ def _write_tempo_result( file_name.write(result + "\n") -def _create_tempo_jam(input_file, model, s1, t1, t2): +def _create_tempo_jam( + input_file: Union[str, Path], model: str, s1: float, t1: float, t2: float +) -> jams.JAMS: result = jams.JAMS() y, sr = librosa.load(input_file) track_duration = librosa.get_duration(y=y, sr=sr) @@ -378,7 +382,8 @@ def tempogram(): frame_length = (fft_hop_length / sr) * hop_length fig = plt.figure() - fig.canvas.manager.set_window_title("tempogram: " + file) + if fig.canvas.manager is not None: + fig.canvas.manager.set_window_title("tempogram: " + file) if args.png: fig.set_size_inches(5, 2) @@ -422,7 +427,7 @@ def tempogram(): print("\nDone") -def _norm_tempogram_frames(predictions=None, norm_frame=None): +def _norm_tempogram_frames(predictions: np.ndarray, norm_frame: str) -> np.ndarray: norm_order = np.inf if "max" == norm_frame.lower(): norm_order = np.inf @@ -439,14 +444,14 @@ def _norm_tempogram_frames(predictions=None, norm_frame=None): def _write_tempogram_as_csv( - predictions=None, - classifier=None, - file=None, - frame_length=None, - log_scale=False, - min_bpm=30, - max_bpm=256, - sharpen=False, + predictions: np.ndarray, + classifier: TempoClassifier, + file: str, + frame_length: int, + log_scale: bool = False, + min_bpm: int = 30, + max_bpm: int = 256, + sharpen: bool = False, ): csv_file_name = file + ".csv" if sharpen: @@ -479,7 +484,7 @@ def _write_tempogram_as_csv( ) -def _get_tempogram_limits(log_scale): +def _get_tempogram_limits(log_scale: bool) -> tuple[int, int, int]: if log_scale: min_bpm = 50 max_bpm = 500 diff --git a/tempocnn/feature.py b/tempocnn/feature.py index b721e15..26564f8 100644 --- a/tempocnn/feature.py +++ b/tempocnn/feature.py @@ -5,11 +5,24 @@ 20 to 5000 Hz. """ +import os +from pathlib import Path +from typing import Union, Any, BinaryIO + +import audioread import librosa as librosa import numpy as np +import soundfile as sf -def read_features(file, frames=256, hop_length=128, zero_pad=False): +def read_features( + file: Union[ + str, Path, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO + ], + frames: int = 256, + hop_length: int = 128, + zero_pad: bool = False, +) -> np.ndarray: """ Resample file to 11025 Hz, then transform using STFT with length 1024 and hop size 512. Convert resulting linear spectrum to mel spectrum @@ -56,13 +69,13 @@ def read_features(file, frames=256, hop_length=128, zero_pad=False): return _to_sliding_window(data, frames, hop_length) -def _ensure_length(data, length): +def _ensure_length(data: np.ndarray, length: int) -> np.ndarray: padded_data = np.zeros((1, data.shape[1], length, 1), dtype=data.dtype) padded_data[0, :, 0 : data.shape[2], 0] = data[0, :, :, 0] return padded_data -def _add_zeros(data, zeros): +def _add_zeros(data: np.ndarray, zeros: int) -> np.ndarray: padded_data = np.zeros( (1, data.shape[1], data.shape[2] + zeros, 1), dtype=data.dtype ) @@ -70,7 +83,9 @@ def _add_zeros(data, zeros): return padded_data -def _to_sliding_window(data, window_length, hop_length): +def _to_sliding_window( + data: np.ndarray, window_length: int, hop_length: int +) -> np.ndarray: total_frames = data.shape[2] windowed_data = [] for offset in range(