More configs and a new metric
Old-Shatterhand committed Aug 13, 2024
1 parent a048970 commit 4189fbd
Showing 8 changed files with 91 additions and 105 deletions.
51 changes: 27 additions & 24 deletions configs/downstream/all.yaml
@@ -22,41 +22,44 @@ datasets:
     task: multilabel
 pre-transforms:
 model:
-# - name: rf
-#   n_estimators: 500
-#   n_jobs: -1
-#   random_state: 42
-# - name: svm
-#   random_state: 42
-# - name: xgb
-#   random_state: 42
-# - name: mlp
-#   hidden_dim: 1024
-#   batch_size: 256
-#   num_layers: 3
-#   epochs: 100
-#   patience: 30
-#   learning_rate: 0
-#   optimizer: Adam
+  - name: rf
+    n_estimators: 500
+    n_jobs: -1
+    random_state: 42
+  - name: svm
+    random_state: 42
+  - name: xgb
+    random_state: 42
+  - name: mlp
+    hidden_dim: 1024
+    batch_size: 256
+    num_layers: 3
+    epochs: 100
+    patience: 30
+    learning_rate: 0
+    optimizer: Adam
   - name: sweetnet
     hidden_dim: 768
     batch_size: 256
     epochs: 50
     patience: 30
     learning_rate: 0.001
     optimizer: Adam
-# - name: gnngly
-#   hidden_dim: 14
-#   batch_size: 256
-#   num_layers: 5
-#   epochs: 200
-#   patience: 30
-#   learning_rate: 0
-#   optimizer: Adam
+    suffix: _768
+  - name: gnngly
+    hidden_dim: 14
+    batch_size: 256
+    num_layers: 5
+    epochs: 200
+    patience: 30
+    learning_rate: 0
+    optimizer: Adam
+    suffix: _5_14
   - name: gifflar
     hidden_dim: 768
     batch_size: 32
     num_layers: 6
     epochs: 100
     learning_rate: 0.001
     optimizer: Adam
+    suffix: _768_6
16 changes: 0 additions & 16 deletions configs/downstream/check.yaml

This file was deleted.

32 changes: 1 addition & 31 deletions configs/downstream/dev.yaml
@@ -24,42 +24,12 @@ datasets:
     task: multilabel
 pre-transforms:
 model:
-# - name: rf
-#   n_estimators: 500
-#   n_jobs: -1
-#   random_state: 42
-# - name: svm
-#   random_state: 42
-# - name: xgb
-#   random_state: 42
-# - name: mlp
-#   hidden_dim: 1024
-#   batch_size: 256
-#   num_layers: 4
-#   epochs: 1
-#   patience: 30
-#   learning_rate: 0
-#   optimizer: Adam
-# - name: sweetnet
-#   hidden_dim: 1024
-#   batch_size: 256
-#   epochs: 100
-#   patience: 30
-#   learning_rate: 0.001
-#   optimizer: Adam
-# - name: gnngly
-#   hidden_dim: 14
-#   batch_size: 256
-#   num_layers: 5
-#   epochs: 1
-#   patience: 30
-#   learning_rate: 0
-#   optimizer: Adam
   - name: gifflar
     hidden_dim: 1024
     batch_size: 32
     num_layers: 12
     epochs: 100
     learning_rate: 0.001
     optimizer: Adam
+    suffix: _1024_12

51 changes: 29 additions & 22 deletions configs/downstream/cuda.yaml → configs/downstream/dl.yaml
@@ -2,42 +2,49 @@ seed: 42
 data_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/
 root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/data
 datasets:
-  - name: class-1
-    label: class
-    task: classification
+  - name: Taxonomy_Phylum
+    task: multilabel
 pre-transforms:
 model:
-  - name: rf
-    n_estimators: 500
-    n_jobs: -1
-    random_state: 42
-  - name: mlp
+  - name: sweetnet
     hidden_dim: 1024
     batch_size: 256
-    num_layers: 3
-    epochs: 100
+    epochs: 1
     patience: 30
-    learning_rate: 0
+    learning_rate: 0.001
     optimizer: Adam
-  - name: sweetnet
-    hidden_dim: 128
+    suffix: _1024
+  - name: gnngly
+    hidden_dim: 1024
     batch_size: 256
-    epochs: 50
+    num_layers: 16
+    epochs: 1
     patience: 30
-    learning_rate: 0.001
+    learning_rate: 0
     optimizer: Adam
+    suffix: _1024_16
   - name: gnngly
-    hidden_dim: 14
+    hidden_dim: 1024
     batch_size: 256
-    num_layers: 5
-    epochs: 200
+    num_layers: 16
+    epochs: 1
     patience: 30
     learning_rate: 0
     optimizer: Adam
+    suffix: _1024_16
+  - name: gifflar
+    hidden_dim: 1024
+    batch_size: 256
+    num_layers: 8
+    epochs: 1
+    learning_rate: 0.001
+    optimizer: Adam
+    suffix: _1024_8
   - name: gifflar
-    hidden_dim: 128
-    batch_size: 32
-    num_layers: 3
-    epochs: 100
+    hidden_dim: 1024
+    batch_size: 256
+    num_layers: 8
+    epochs: 1
     learning_rate: 0.001
     optimizer: Adam
+    suffix: _1024_8
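In dl.yaml, the two `gnngly` entries and the two `gifflar` entries carry identical `name` and `suffix` values. Because train.py builds the logger name as `name + suffix`, such entries map to the same logger name. The following is a minimal sketch of a check for that situation; `duplicate_run_names` is a hypothetical helper that is not part of the repository, and the model list is copied by hand from the config above rather than loaded from YAML.

```python
from collections import Counter


def duplicate_run_names(models):
    """Return logger names (name + optional suffix) that occur more than once.

    Hypothetical helper for illustration; it mirrors the string
    concatenation used for the CSVLogger name in gifflar/train.py.
    """
    counts = Counter(m["name"] + m.get("suffix", "") for m in models)
    return sorted(name for name, count in counts.items() if count > 1)


# Model entries as they appear in configs/downstream/dl.yaml after this commit
# (only the fields relevant for the logger name are kept).
models = [
    {"name": "sweetnet", "suffix": "_1024"},
    {"name": "gnngly", "suffix": "_1024_16"},
    {"name": "gnngly", "suffix": "_1024_16"},
    {"name": "gifflar", "suffix": "_1024_8"},
    {"name": "gifflar", "suffix": "_1024_8"},
]

dups = duplicate_run_names(models)
print(dups)
```

Whether repeated names are intended (for example, to run the same configuration twice) is not stated in the commit, so the check only reports them.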
18 changes: 9 additions & 9 deletions configs/downstream/gnngly.yaml
@@ -8,20 +8,20 @@ datasets:
 #   task: classification
 # - name: Taxonomy_Domain
 #   task: multilabel
-  - name: Taxonomy_Kingdom
-    task: multilabel
+# - name: Taxonomy_Kingdom
+#   task: multilabel
 # - name: Taxonomy_Phylum
 #   task: multilabel
 # - name: Taxonomy_Class
 #   task: multilabel
 # - name: Taxonomy_Order
 #   task: multilabel
-  - name: Taxonomy_Family
-    task: multilabel
-  - name: Taxonomy_Genus
-    task: multilabel
-  - name: Taxonomy_Species
-    task: multilabel
+# - name: Taxonomy_Family
+#   task: multilabel
+# - name: Taxonomy_Genus
+#   task: multilabel
+# - name: Taxonomy_Species
+#   task: multilabel
 pre-transforms:
 model:
   - name: gnngly
@@ -32,4 +32,4 @@ model:
     patience: 30
     learning_rate: 0
     optimizer: Adam
-
+    suffix: _5_14
20 changes: 20 additions & 0 deletions gifflar/metrics.py
@@ -0,0 +1,20 @@
+import torch
+import torchmetrics
+
+
+class Sensitivity(torchmetrics.Metric):
+    def __init__(self, threshold=0.5, dist_sync_on_step=False, **kwargs):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+        self.threshold = threshold
+        self.add_state("true_positives", default=torch.tensor(0), dist_reduce_fx="sum")
+        self.add_state("false_negatives", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor):
+        preds = (preds > self.threshold).int()
+        target = target.int()
+
+        self.true_positives += torch.sum(preds * target)
+        self.false_negatives += torch.sum((1 - preds) * target)
+
+    def compute(self):
+        return self.true_positives.float() / (self.true_positives + self.false_negatives)
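The new metric thresholds the predictions, counts true positives and false negatives, and returns TP / (TP + FN). The arithmetic can be followed in a plain-Python sketch without torch; the function name `sensitivity` and the sample values are illustrative assumptions, not code from the repository.

```python
def sensitivity(preds, targets, threshold=0.5):
    """Plain-Python sketch of TP / (TP + FN) for binary targets.

    Illustrative only: it follows the same steps as Sensitivity.update
    and Sensitivity.compute, but on lists of floats and 0/1 labels.
    """
    true_positives = 0
    false_negatives = 0
    for score, label in zip(preds, targets):
        predicted_positive = score > threshold  # same strict '>' as the metric
        if label == 1 and predicted_positive:
            true_positives += 1
        elif label == 1:
            false_negatives += 1
    if true_positives + false_negatives == 0:
        # No positive targets: the ratio is undefined, so return NaN.
        return float("nan")
    return true_positives / (true_positives + false_negatives)


scores = [0.9, 0.2, 0.7, 0.4]
labels = [1, 1, 0, 1]
result = sensitivity(scores, labels)  # one of three positives is recovered
print(result)
```

The NaN branch is spelled out here because the tensor division in `compute` has no explicit guard for batches without positive targets; the sketch assumes NaN is the desired result in that case.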
2 changes: 1 addition & 1 deletion gifflar/train.py
@@ -36,7 +36,7 @@ def setup(**kwargs):
         pre_transform=get_pretransforms(**(kwargs["pre-transforms"] or {})), **data_config,
     )
     data_config["num_classes"] = datamodule.train.dataset_args["num_classes"]
-    logger = CSVLogger("logs", name=kwargs["model"]["name"])
+    logger = CSVLogger("logs", name=kwargs["model"]["name"] + kwargs["model"].get("suffix", ""))
     kwargs["dataset"]["filepath"] = str(data_config["filepath"])
     logger.log_hyperparams(kwargs)
     metrics = get_metrics(data_config["task"], data_config["num_classes"])
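The changed line appends the optional `suffix` from the model config to the logger name, so configs without a suffix keep their old name. A minimal sketch of that lookup follows; `run_name` is a hypothetical helper used only to isolate the expression, and the dictionaries stand in for parsed model entries.

```python
def run_name(model_cfg):
    """Build the logger name from a model config entry.

    Hypothetical helper: it isolates the expression passed as `name`
    to CSVLogger in the changed line of gifflar/train.py.
    """
    # dict.get returns "" when no suffix is configured, leaving the name unchanged.
    return model_cfg["name"] + model_cfg.get("suffix", "")


with_suffix = run_name({"name": "gifflar", "hidden_dim": 768, "suffix": "_768_6"})
without_suffix = run_name({"name": "rf", "n_estimators": 500})
print(with_suffix, without_suffix)
```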
6 changes: 4 additions & 2 deletions gifflar/utils.py
@@ -85,8 +85,9 @@ def get_metrics(
     - the R2Score
     For everything else (single- and multilabel classification) we monitor
     - the Accuracy,
-    - the AUROC, and
-    - the MCC
+    - the AUROC,
+    - the MCC, and
+    - the Sensitivity, i.e. the TPR
     Args:
         task: The type of the prediction task
@@ -113,6 +114,7 @@ def get_metrics(
             Accuracy(**metric_args),
             AUROC(**metric_args),
             MatthewsCorrCoef(**metric_args),
+            Sensitivity(**metric_args),
         ])
     return {"train": m.clone(prefix="train/"), "val": m.clone(prefix="val/"), "test": m.clone(prefix="test/")}
