Commit
Handle label imbalance in binary classification tasks on text benchmark (#376)

Labels in the text benchmarks are imbalanced, and weighting the positive
labels improves performance.
Experiments were run on the `fake` dataset (5% positive labels) with
`text_embedded` and `RoBERTa` encodings:

- `ResNet` result changes 91.1% -> 93.4% 
- `FTTransformer` result remains unchanged
- `Trompt` result changes 95.2% -> 95.8%

The differences were even starker with DistilRoBERTa, but since we
aren't reporting those results anywhere, I didn't note them down.

More results are pending.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
vid-koci and pre-commit-ci[bot] authored Mar 11, 2024
1 parent 893678f commit 96bdf12
Showing 1 changed file with 2 additions and 1 deletion.
benchmark/data_frame_text_benchmark.py

```diff
@@ -457,7 +457,8 @@ def main_torch(
 
     if dataset.task_type == TaskType.BINARY_CLASSIFICATION:
         out_channels = 1
-        loss_fun = BCEWithLogitsLoss()
+        label_imbalance = sum(train_tensor_frame.y) / len(train_tensor_frame.y)
+        loss_fun = BCEWithLogitsLoss(pos_weight=1 / label_imbalance)
         metric_computer = AUROC(task='binary').to(device)
         higher_is_better = True
     elif dataset.task_type == TaskType.MULTICLASS_CLASSIFICATION:
```
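
For reference, here is a minimal standalone sketch of what `pos_weight` does in `BCEWithLogitsLoss`; the label tensor is illustrative (mirroring the 5% positive rate of the `fake` dataset), not taken from the benchmark:

```python
import torch
from torch.nn import BCEWithLogitsLoss

# Illustrative labels: 5 positives out of 100, matching a 5% positive rate.
y = torch.zeros(100)
y[:5] = 1.0

# Fraction of positive labels. Using tensor.sum() keeps this a 0-dim tensor,
# so 1 / label_imbalance is also a tensor, as pos_weight requires.
label_imbalance = y.sum() / len(y)  # tensor(0.05)

# pos_weight = 1 / 0.05 = 20: each positive example contributes 20x
# to the loss, roughly balancing the two classes.
loss_fun = BCEWithLogitsLoss(pos_weight=1 / label_imbalance)

logits = torch.randn(100)  # raw model outputs, pre-sigmoid
loss = loss_fun(logits, y)
```

With 5% positive labels, `1 / label_imbalance` is 20, which is close to the conventional `num_negatives / num_positives` weighting (19 here); the two coincide in the limit where positives are rare.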
