-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Mohamed Mejri
committed
Jan 23, 2023
1 parent
ad0e528
commit bc773bb
Showing
23 changed files
with
1,179 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
name: download_data | ||
conda_env: conda.yml | ||
|
||
entry_points: | ||
main: | ||
parameters: | ||
hydra_options: | ||
description: Hydra parameters to override | ||
type: str | ||
default: '' | ||
command: >- | ||
python main.py $(echo {hydra_options}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
name: exercise_9 | ||
conda_env: conda.yml | ||
|
||
entry_points: | ||
main: | ||
parameters: | ||
reference_artifact: | ||
description: Fully-qualitied name for the artifact to be used as reference dataset | ||
type: str | ||
sample_artifact: | ||
description: Fully-qualitied name for the artifact to be used as new data sample | ||
type: str | ||
ks_alpha: | ||
description: Threshold for the (pre-trial) p-value for the KS test | ||
type: float | ||
# NOTE: the -s flag is necessary, otherwise pytest will capture all the output and it | ||
# will not be uploaded to W&B. Hence, the log in W&B will be empty. | ||
command: >- | ||
pytest -s -vv . --reference_artifact {reference_artifact} \ | ||
--sample_artifact {sample_artifact} \ | ||
--ks_alpha {ks_alpha} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
name: download_data | ||
channels: | ||
- conda-forge | ||
- defaults | ||
dependencies: | ||
- pandas=1.2.3 | ||
- pip=20.3.3 | ||
- pytest=6.2.2 | ||
- scipy=1.6.1 | ||
- pip: | ||
- wandb==0.10.21 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pytest | ||
import pandas as pd | ||
import wandb | ||
|
||
|
||
run = wandb.init(job_type="data_tests") | ||
|
||
|
||
def pytest_addoption(parser): | ||
parser.addoption("--reference_artifact", action="store") | ||
parser.addoption("--sample_artifact", action="store") | ||
parser.addoption("--ks_alpha", action="store") | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def data(request): | ||
|
||
reference_artifact = request.config.option.reference_artifact | ||
|
||
if reference_artifact is None: | ||
pytest.fail("--reference_artifact missing on command line") | ||
|
||
sample_artifact = request.config.option.sample_artifact | ||
|
||
if sample_artifact is None: | ||
pytest.fail("--sample_artifact missing on command line") | ||
|
||
local_path = run.use_artifact(reference_artifact).file() | ||
sample1 = pd.read_csv(local_path) | ||
|
||
local_path = run.use_artifact(sample_artifact).file() | ||
sample2 = pd.read_csv(local_path) | ||
|
||
return sample1, sample2 | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def ks_alpha(request): | ||
ks_alpha = request.config.option.ks_alpha | ||
|
||
if ks_alpha is None: | ||
pytest.fail("--ks_threshold missing on command line") | ||
|
||
return float(ks_alpha) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import scipy.stats | ||
import pandas as pd | ||
|
||
|
||
def test_column_presence_and_type(data): | ||
|
||
# Disregard the reference dataset | ||
_, data = data | ||
|
||
required_columns = { | ||
"time_signature": pd.api.types.is_integer_dtype, | ||
"key": pd.api.types.is_integer_dtype, | ||
"danceability": pd.api.types.is_float_dtype, | ||
"energy": pd.api.types.is_float_dtype, | ||
"loudness": pd.api.types.is_float_dtype, | ||
"speechiness": pd.api.types.is_float_dtype, | ||
"acousticness": pd.api.types.is_float_dtype, | ||
"instrumentalness": pd.api.types.is_float_dtype, | ||
"liveness": pd.api.types.is_float_dtype, | ||
"valence": pd.api.types.is_float_dtype, | ||
"tempo": pd.api.types.is_float_dtype, | ||
"duration_ms": pd.api.types.is_integer_dtype, # This is integer, not float as one might expect | ||
"text_feature": pd.api.types.is_string_dtype, | ||
"genre": pd.api.types.is_string_dtype | ||
} | ||
|
||
# Check column presence | ||
assert set(data.columns.values).issuperset(set(required_columns.keys())) | ||
|
||
for col_name, format_verification_funct in required_columns.items(): | ||
|
||
assert format_verification_funct(data[col_name]), f"Column {col_name} failed test {format_verification_funct}" | ||
|
||
|
||
def test_class_names(data): | ||
|
||
# Disregard the reference dataset | ||
_, data = data | ||
|
||
# Check that only the known classes are present | ||
known_classes = [ | ||
"Dark Trap", | ||
"Underground Rap", | ||
"Trap Metal", | ||
"Emo", | ||
"Rap", | ||
"RnB", | ||
"Pop", | ||
"Hiphop", | ||
"techhouse", | ||
"techno", | ||
"trance", | ||
"psytrance", | ||
"trap", | ||
"dnb", | ||
"hardstyle", | ||
] | ||
|
||
assert data["genre"].isin(known_classes).all() | ||
|
||
|
||
def test_column_ranges(data): | ||
|
||
# Disregard the reference dataset | ||
_, data = data | ||
|
||
ranges = { | ||
"time_signature": (1, 5), | ||
"key": (0, 11), | ||
"danceability": (0, 1), | ||
"energy": (0, 1), | ||
"loudness": (-35, 5), | ||
"speechiness": (0, 1), | ||
"acousticness": (0, 1), | ||
"instrumentalness": (0, 1), | ||
"liveness": (0, 1), | ||
"valence": (0, 1), | ||
"tempo": (50, 250), | ||
"duration_ms": (20000, 1000000), | ||
} | ||
|
||
for col_name, (minimum, maximum) in ranges.items(): | ||
|
||
assert data[col_name].dropna().between(minimum, maximum).all(), ( | ||
f"Column {col_name} failed the test. Should be between {minimum} and {maximum}, " | ||
f"instead min={data[col_name].min()} and max={data[col_name].max()}" | ||
) | ||
|
||
|
||
def test_kolmogorov_smirnov(data, ks_alpha): | ||
|
||
sample1, sample2 = data | ||
|
||
columns = [ | ||
"danceability", | ||
"energy", | ||
"loudness", | ||
"speechiness", | ||
"acousticness", | ||
"instrumentalness", | ||
"liveness", | ||
"valence", | ||
"tempo", | ||
"duration_ms" | ||
] | ||
|
||
# Bonferroni correction for multiple hypothesis testing | ||
# (see my blog post on this topic to see where this comes from: | ||
# https://towardsdatascience.com/precision-and-recall-trade-off-and-multiple-hypothesis-testing-family-wise-error-rate-vs-false-71a85057ca2b) | ||
alpha_prime = 1 - (1 - ks_alpha)**(1 / len(columns)) | ||
|
||
for col in columns: | ||
|
||
ts, p_value = scipy.stats.ks_2samp(sample1[col], sample2[col]) | ||
|
||
# NOTE: as always, the p-value should be interpreted as the probability of | ||
# obtaining a test statistic (TS) equal or more extreme that the one we got | ||
# by chance, when the null hypothesis is true. If this probability is not | ||
# large enough, this dataset should be looked at carefully, hence we fail | ||
assert p_value > alpha_prime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
name: download_data | ||
channels: | ||
- conda-forge | ||
- defaults | ||
dependencies: | ||
- requests=2.24.0 | ||
- pip=20.3.3 | ||
- mlflow=1.14.1 | ||
- hydra-core=1.0.6 | ||
- pip: | ||
- wandb==0.10.21 | ||
- hydra-joblib-launcher==1.1.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
main: | ||
project_name: exercise_14 | ||
experiment_name: dev | ||
execute_steps: | ||
- download | ||
- preprocess | ||
- check_data | ||
- segregate | ||
- random_forest | ||
- evaluate | ||
# This seed will be used to seed the random number generator | ||
# to ensure repeatibility of the data splits and other | ||
# pseudo-random operations | ||
random_seed: 42 | ||
data: | ||
file_url: "https://github.com/udacity/nd0821-c2-build-model-workflow-exercises/blob/master/lesson-2-data-exploration-and-preparation/exercises/exercise_4/starter/genres_mod.parquet?raw=true" | ||
reference_dataset: "exercise_14/preprocessed_data.csv:latest" | ||
# Threshold for Kolomorov-Smirnov test | ||
ks_alpha: 0.05 | ||
test_size: 0.3 | ||
val_size: 0.3 | ||
# Stratify according to the target when splitting the data | ||
# in train/test or in train/val | ||
stratify: genre | ||
random_forest_pipeline: | ||
random_forest: | ||
n_estimators: 100 | ||
criterion: 'gini' | ||
max_depth: 13 | ||
min_samples_split: 2 | ||
min_samples_leaf: 1 | ||
min_weight_fraction_leaf: 0.0 | ||
max_features: 'auto' | ||
max_leaf_nodes: null | ||
min_impurity_decrease: 0.0 | ||
min_impurity_split: null | ||
bootstrap: true | ||
oob_score: false | ||
n_jobs: null | ||
# This is a different random seed than main.random_seed, | ||
# because this is used only within the RandomForest | ||
random_state: 42 | ||
verbose: 0 | ||
warm_start: false | ||
class_weight: "balanced" | ||
ccp_alpha: 0.0 | ||
max_samples: null | ||
tfidf: | ||
max_features: 10 | ||
features: | ||
numerical: | ||
- "danceability" | ||
- "energy" | ||
- "loudness" | ||
- "speechiness" | ||
- "acousticness" | ||
- "instrumentalness" | ||
- "liveness" | ||
- "valence" | ||
- "tempo" | ||
- "duration_ms" | ||
categorical: | ||
- "time_signature" | ||
- "key" | ||
nlp: | ||
- "text_feature" | ||
export_artifact: "model_export" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
name: download_data | ||
conda_env: conda.yml | ||
|
||
entry_points: | ||
main: | ||
parameters: | ||
file_url: | ||
description: URL of the file to download | ||
type: uri | ||
artifact_name: | ||
description: Name for the W&B artifact that will be created | ||
type: str | ||
artifact_type: | ||
description: Type of the artifact to create | ||
type: str | ||
default: raw_data | ||
artifact_description: | ||
description: Description for the artifact | ||
type: str | ||
|
||
command: >- | ||
python download_data.py --file_url {file_url} \ | ||
--artifact_name {artifact_name} \ | ||
--artifact_type {artifact_type} \ | ||
--artifact_description {artifact_description} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
name: download_data | ||
channels: | ||
- conda-forge | ||
- defaults | ||
dependencies: | ||
- requests=2.24.0 | ||
- pip=20.3.3 | ||
- pip: | ||
- wandb==0.10.21 |
Oops, something went wrong.