diff --git a/.pt_tmp/exp_version_manager.yml b/.pt_tmp/exp_version_manager.yml index edafd308..cde6119e 100644 --- a/.pt_tmp/exp_version_manager.yml +++ b/.pt_tmp/exp_version_manager.yml @@ -1 +1 @@ -classification: 1 +classification: 4 diff --git a/benchmark/pytorch_tabular_benchmark.py b/benchmark/pytorch_tabular_benchmark.py index 2a049c7d..9adc7d22 100644 --- a/benchmark/pytorch_tabular_benchmark.py +++ b/benchmark/pytorch_tabular_benchmark.py @@ -1,17 +1,31 @@ import argparse import os.path as osp -import numpy as np -import pandas as pd import torch from pytorch_tabular import TabularModel from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig from pytorch_tabular.models.common.heads import LinearHeadConfig from pytorch_tabular.models.tab_transformer import TabTransformerConfig +from sklearn.metrics import roc_auc_score from torch_frame import TaskType, stype from torch_frame.datasets import DataFrameBenchmark + +def roc_auc(y_hat, y): + r"""Calculate the Area Under the ROC Curve (AUC) + for the given predictions and true labels. + + Parameters: + y_hat (array-like): Predicted probabilities or scores. + y (array-like): True binary labels. + + Returns: + float: AUC score. + """ + return roc_auc_score(y, y_hat) + + parser = argparse.ArgumentParser() parser.add_argument( '--task_type', type=str, choices=[ @@ -38,25 +52,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') torch.manual_seed(args.seed) - -def load_classification_data(df, target_col, test_size): - torch_data = np.array(df.drop(target_col, axis=1)) - torch_labels = np.array(df[target_col]) - data = np.hstack([torch_data, torch_labels.reshape(-1, 1)]) - gen_names = [f"feature_{i}" for i in range(data.shape[-1])] - col_names = gen_names - col_names[-1] = "target" - data = pd.DataFrame(data, columns=col_names) - cat_col_names = [x for x in gen_names[:-1] if len(data[x].unique()) < 10] - num_col_names = [ - x for x in gen_names[:-1] if x not in [target_col] + cat_col_names - ] - test_idx = data.sample(int(test_size * len(data)), random_state=42).index - test = data[data.index.isin(test_idx)] - train = data[~data.index.isin(test_idx)] - return (train, test, ["target"], cat_col_names, num_col_names) - - # Prepare datasets path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data') dataset = DataFrameBenchmark(root=path, task_type=TaskType(args.task_type), @@ -65,15 +60,15 @@ def load_classification_data(df, target_col, test_size): dataset = dataset.shuffle() train_dataset, val_dataset, test_dataset = dataset.split() -train, test, target_col, cat_col_names, num_col_names = ( - train_dataset.df, test_dataset.df, dataset.target_col, +train_df, val_df, test_df, target_col, cat_col_names, num_col_names = ( + train_dataset.df, val_dataset.df, test_dataset.df, dataset.target_col, dataset.tensor_frame.col_names_dict[stype.categorical], dataset.tensor_frame.col_names_dict[stype.numerical]) data_config = DataConfig( target=[target_col], - continuous_cols=cat_col_names, - categorical_cols=num_col_names, + continuous_cols=num_col_names, + categorical_cols=cat_col_names, ) trainer_config = TrainerConfig( @@ -103,5 +98,5 @@ def load_classification_data(df, target_col, test_size): optimizer_config=optimizer_config, trainer_config=trainer_config, ) -tabular_model.fit(train=train) -tabular_model.evaluate(test) +tabular_model.fit(train=train_df, validation=val_df, metrics=roc_auc) +tabular_model.evaluate(test_df) diff --git a/lightning_logs/version_0/events.out.tfevents.1715586909.ip-10-0-153-205.1676053.0 b/lightning_logs/version_0/events.out.tfevents.1715586909.ip-10-0-153-205.1676053.0 new file mode 100644 index 00000000..5851387c Binary files /dev/null and b/lightning_logs/version_0/events.out.tfevents.1715586909.ip-10-0-153-205.1676053.0 differ diff --git a/lightning_logs/version_0/events.out.tfevents.1715586916.ip-10-0-153-205.1676053.1 b/lightning_logs/version_0/events.out.tfevents.1715586916.ip-10-0-153-205.1676053.1 new file mode 100644 index 00000000..511af502 Binary files /dev/null and b/lightning_logs/version_0/events.out.tfevents.1715586916.ip-10-0-153-205.1676053.1 differ diff --git a/lightning_logs/version_0/events.out.tfevents.1715586967.ip-10-0-153-205.1676053.2 b/lightning_logs/version_0/events.out.tfevents.1715586967.ip-10-0-153-205.1676053.2 new file mode 100644 index 00000000..6a1ae393 Binary files /dev/null and b/lightning_logs/version_0/events.out.tfevents.1715586967.ip-10-0-153-205.1676053.2 differ diff --git a/lightning_logs/version_0/hparams.yaml b/lightning_logs/version_0/hparams.yaml new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/lightning_logs/version_0/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/lightning_logs/version_1/events.out.tfevents.1715587174.ip-10-0-153-205.1680263.0 b/lightning_logs/version_1/events.out.tfevents.1715587174.ip-10-0-153-205.1680263.0 new file mode 100644 index 00000000..d512e9a6 Binary files /dev/null and b/lightning_logs/version_1/events.out.tfevents.1715587174.ip-10-0-153-205.1680263.0 differ diff --git a/lightning_logs/version_1/events.out.tfevents.1715587181.ip-10-0-153-205.1680263.1 b/lightning_logs/version_1/events.out.tfevents.1715587181.ip-10-0-153-205.1680263.1 new file mode 100644 index 00000000..aedb15d1 Binary files /dev/null and b/lightning_logs/version_1/events.out.tfevents.1715587181.ip-10-0-153-205.1680263.1 differ diff --git a/lightning_logs/version_1/events.out.tfevents.1715587243.ip-10-0-153-205.1680263.2 b/lightning_logs/version_1/events.out.tfevents.1715587243.ip-10-0-153-205.1680263.2 new file mode 100644 index 00000000..3bdf2002 Binary files /dev/null and b/lightning_logs/version_1/events.out.tfevents.1715587243.ip-10-0-153-205.1680263.2 differ diff --git a/lightning_logs/version_1/hparams.yaml b/lightning_logs/version_1/hparams.yaml new file mode 100644 index 00000000..0967ef42 --- /dev/null +++ b/lightning_logs/version_1/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/saved_models/classification-3_epoch=9-valid_loss=0.36.ckpt b/saved_models/classification-3_epoch=9-valid_loss=0.36.ckpt new file mode 100644 index 00000000..cca5849d Binary files /dev/null and b/saved_models/classification-3_epoch=9-valid_loss=0.36.ckpt differ diff --git a/saved_models/classification-4_epoch=9-valid_loss=0.33.ckpt b/saved_models/classification-4_epoch=9-valid_loss=0.33.ckpt new file mode 100644 index 00000000..2ef8dd23 Binary files /dev/null and b/saved_models/classification-4_epoch=9-valid_loss=0.33.ckpt differ