Skip to content

Commit

Permalink
Make DataFrameTextBenchmark script pos_weight optional (#379)
Browse files Browse the repository at this point in the history
Make `DataFrameTextBenchmark` script `pos_weight` optional as some
datasets may not need this (such as `kick`)
  • Loading branch information
zechengz authored Mar 15, 2024
1 parent 0cb4d7f commit 5345eea
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions benchmark/data_frame_text_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"--idx",
type=int,
default=0,
help="The index of the dataset within DataFrameBenchmark",
help="The index of the dataset within DataFrameTextBenchmark",
)
parser.add_argument(
"--model_type",
Expand All @@ -87,6 +87,12 @@
],
)
parser.add_argument("--finetune", action="store_true")
parser.add_argument(
"--pos_weight",
action="store_true",
help=("Whether to set `pos_weight` in `BCEWithLogitsLoss` "
"for the binary classification task."),
)
parser.add_argument('--result_path', type=str, default='')
parser.add_argument("--api_key", type=str, default=None)
args = parser.parse_args()
Expand Down Expand Up @@ -464,8 +470,12 @@ def main_torch(

if dataset.task_type == TaskType.BINARY_CLASSIFICATION:
out_channels = 1
label_imbalance = sum(train_tensor_frame.y) / len(train_tensor_frame.y)
loss_fun = BCEWithLogitsLoss(pos_weight=1 / label_imbalance)
if args.pos_weight:
label_imbalance = sum(train_tensor_frame.y) / len(
train_tensor_frame.y)
loss_fun = BCEWithLogitsLoss(pos_weight=1 / label_imbalance)
else:
loss_fun = BCEWithLogitsLoss()
metric_computer = AUROC(task='binary').to(device)
higher_is_better = True
elif dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
Expand Down

0 comments on commit 5345eea

Please sign in to comment.