Commit
fix script
yiweny committed May 13, 2024
1 parent d13073f commit 52311f8
Showing 12 changed files with 25 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .pt_tmp/exp_version_manager.yml
@@ -1 +1 @@
-classification: 1
+classification: 4
49 changes: 22 additions & 27 deletions benchmark/pytorch_tabular_benchmark.py
@@ -1,17 +1,31 @@
import argparse
import os.path as osp

import numpy as np
import pandas as pd
import torch
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.models.tab_transformer import TabTransformerConfig
from sklearn.metrics import roc_auc_score

from torch_frame import TaskType, stype
from torch_frame.datasets import DataFrameBenchmark


def roc_auc(y_hat, y):
    r"""Calculate the Area Under the ROC Curve (AUC) for the given
    predictions and true labels.

    Parameters:
        y_hat (array-like): Predicted probabilities or scores.
        y (array-like): True binary labels.

    Returns:
        float: AUC score.
    """
    return roc_auc_score(y, y_hat)
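
A quick toy check of the helper above, as a minimal sketch (the label and score arrays are illustrative, not from the commit, and roc_auc is assumed to be the function defined in this script):

import numpy as np

y_true = np.array([0, 0, 1, 1])            # true binary labels
y_score = np.array([0.1, 0.3, 0.7, 0.9])   # predicted scores, perfectly ranked
print(roc_auc(y_score, y_true))            # -> 1.0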


parser = argparse.ArgumentParser()
parser.add_argument(
    '--task_type', type=str, choices=[
@@ -38,25 +52,6 @@
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(args.seed)


def load_classification_data(df, target_col, test_size):
    torch_data = np.array(df.drop(target_col, axis=1))
    torch_labels = np.array(df[target_col])
    data = np.hstack([torch_data, torch_labels.reshape(-1, 1)])
    gen_names = [f"feature_{i}" for i in range(data.shape[-1])]
    col_names = gen_names
    col_names[-1] = "target"
    data = pd.DataFrame(data, columns=col_names)
    cat_col_names = [x for x in gen_names[:-1] if len(data[x].unique()) < 10]
    num_col_names = [
        x for x in gen_names[:-1] if x not in [target_col] + cat_col_names
    ]
    test_idx = data.sample(int(test_size * len(data)), random_state=42).index
    test = data[data.index.isin(test_idx)]
    train = data[~data.index.isin(test_idx)]
    return (train, test, ["target"], cat_col_names, num_col_names)


# Prepare datasets
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
dataset = DataFrameBenchmark(root=path, task_type=TaskType(args.task_type),
@@ -65,15 +60,15 @@ def load_classification_data(df, target_col, test_size):
dataset = dataset.shuffle()
train_dataset, val_dataset, test_dataset = dataset.split()

-train, test, target_col, cat_col_names, num_col_names = (
-    train_dataset.df, test_dataset.df, dataset.target_col,
+train_df, val_df, test_df, target_col, cat_col_names, num_col_names = (
+    train_dataset.df, val_dataset.df, test_dataset.df, dataset.target_col,
    dataset.tensor_frame.col_names_dict[stype.categorical],
    dataset.tensor_frame.col_names_dict[stype.numerical])

data_config = DataConfig(
    target=[target_col],
-    continuous_cols=cat_col_names,
-    categorical_cols=num_col_names,
+    continuous_cols=num_col_names,
+    categorical_cols=cat_col_names,
)

trainer_config = TrainerConfig(
@@ -103,5 +98,5 @@ def load_classification_data(df, target_col, test_size):
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
-tabular_model.fit(train=train)
-tabular_model.evaluate(test)
+tabular_model.fit(train=train_df, validation=val_df, metrics=roc_auc)
+tabular_model.evaluate(test_df)
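
For reference, a minimal end-to-end sketch of how the corrected pieces fit together; it reuses the splits, column lists, and roc_auc helper from this script, while the model and trainer settings shown here are illustrative assumptions, not values taken from the collapsed parts of the diff:

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.tab_transformer import TabTransformerConfig

# numerical columns feed continuous_cols and categorical columns feed
# categorical_cols -- the mapping this commit fixes
data_config = DataConfig(
    target=[target_col],
    continuous_cols=num_col_names,
    categorical_cols=cat_col_names,
)

model_config = TabTransformerConfig(task='classification')  # illustrative config
tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=1),  # illustrative trainer settings
)

# train on the train split, validate on the held-out split with the custom
# ROC-AUC metric, then evaluate on the test split
tabular_model.fit(train=train_df, validation=val_df, metrics=roc_auc)
tabular_model.evaluate(test_df)

The key point is that continuous_cols now receives the numerical column names and categorical_cols the categorical ones, matching the stype-based lists produced by torch_frame.
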
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_0/hparams.yaml
@@ -0,0 +1 @@
{}
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions lightning_logs/version_1/hparams.yaml
@@ -0,0 +1 @@
{}
Binary file not shown.
Binary file not shown.
