Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(examples) Update whisper finetuning example #4158

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/whisper-federated-finetuning/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
processed_partitions/
264 changes: 140 additions & 124 deletions examples/whisper-federated-finetuning/README.md

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
import argparse
import random

import numpy as np
import torch
from datasets import concatenate_datasets, load_dataset
from torch.utils.data import DataLoader, WeightedRandomSampler
from transformers import WhisperForConditionalGeneration, WhisperProcessor

from utils import (
from torch.utils.data import DataLoader
from transformers import WhisperProcessor
from whisper_example.dataset import get_encoding_fn, prepare_silences_dataset
from whisper_example.model import (
construct_balanced_sampler,
eval_model,
get_encoding_fn,
get_model,
prepare_silences_dataset,
remove_cols,
train_one_epoch,
)

from datasets import concatenate_datasets, load_dataset

# Seed Python's `random` module once, at import time, with a fixed value.
random.seed(1989)
torch.set_float32_matmul_precision(
    "high"
)  # If "high" or "medium" is set then TensorFloat32 is used for float32 matmuls
# Number of target classes.
# NOTE(review): the visible `get_model(device, num_classes=12)` call passes a
# literal 12 rather than this constant — confirm the two are meant to agree.
NUM_CLASSES = 12
# Dataset columns passed as `remove_columns=` to the `.map(...)` encoding calls.
REMOVE_COLS = ["file", "audio", "label", "is_unknown", "speaker_id", "utterance_id"]
parser = argparse.ArgumentParser(description="Whisper centralised")

parser.add_argument("--checkpoint", type=str, help="path to classifier`s checkpoint")
Expand Down Expand Up @@ -56,10 +55,10 @@ def main():
torch.set_num_threads(
1
) # not clear to me why we need this in order to be able to use `num_proc > 1 for .map`
train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols)
val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols)
train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS)
val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS)
test_encoded = sc_test.map(
prepare_dataset_fn, num_proc=4, remove_columns=remove_cols
prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS
)

# create and pre-process the dataset of silences
Expand All @@ -68,26 +67,19 @@ def main():
# ! needed each time you run the code. Alternatively, this silence generation could be
# ! implemented as part of a `collate_fn` in the standard PyTorch dataloader...
encoded_silences = silences_dataset.map(
prepare_dataset_fn, num_proc=4, remove_columns=remove_cols
prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS
)
full_train_dataset = concatenate_datasets([train_encoded, encoded_silences])

torch.set_num_threads(og_threads)

lbls = set(full_train_dataset["targets"])
print(f"{lbls = }")
hist = np.histogram(full_train_dataset["targets"], bins=12)
print(f"{[int(count) for count in hist[0]]}")

# make balanced batches with a WeightedRandomSampler
w_per_class = (
len(full_train_dataset) / hist[0]
) # doesn't have to add up to 1 (relative is what matters)
print(f"{w_per_class = }")
w_ss = [w_per_class[t] for t in full_train_dataset["targets"]]
sampler = WeightedRandomSampler(w_ss, len(w_ss))

# prepare dataloaders
# Construct a balanced sampler so batches roughly contain the same number
# of samples from each class
sampler = construct_balanced_sampler(full_train_dataset)

# Prepare dataloaders
train_dataset = full_train_dataset.with_format("torch", columns=["data", "targets"])
train_loader = DataLoader(
train_dataset, batch_size=64, shuffle=False, num_workers=4, sampler=sampler
Expand All @@ -97,7 +89,7 @@ def main():
test_dataset = test_encoded.with_format("torch", columns=["data", "targets"])
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4)

# model to cuda, set criterion, classification layer to train and optimiser
# Model to cuda, set criterion, classification layer to train and optimiser
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder, classifier = get_model(device, num_classes=12)
criterion = torch.nn.CrossEntropyLoss()
Expand All @@ -113,7 +105,7 @@ def main():
classifier_head_params = sum(p.numel() for p in classifier.parameters())
print(f"{classifier_head_params = }")

# eval initial model
# Eval initial model
loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device)
print(f"Initial (loss, acc): {loss = }, {accuracy = }")
best = [-float("inf"), None]
Expand Down
185 changes: 0 additions & 185 deletions examples/whisper-federated-finetuning/client.py

This file was deleted.

60 changes: 60 additions & 0 deletions examples/whisper-federated-finetuning/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import argparse
from multiprocessing import Pool
from time import time

import tomli
from whisper_example.dataset import load_data

from datasets import load_dataset

# Command-line interface of the pre-processing script.
parser = argparse.ArgumentParser(description="Whisper preprocessing")

# Optional integer id of a single partition to process. When the flag is
# omitted, argparse leaves `args.partition_id` as None.
parser.add_argument(
    "--partition-id", type=int, help="The partition to create and save."
)

# Parsed at module level, so `args` is a module global read by the code below.
args = parser.parse_args()


# Parse pyproject.toml from the current working directory. The file is opened
# in binary mode because that is what `tomli.load` expects.
with open("pyproject.toml", "rb") as file:
    _pyproject = tomli.load(file)

# Everything this script needs lives under the [tool.flwr] table.
flwr_config = _pyproject["tool"]["flwr"]

# Display
print(flwr_config)

# Columns to drop during pre-processing (forwarded to `load_data`).
_app_config = flwr_config["app"]["config"]
remove_cols = _app_config["remove-cols"]

# Number of partitions, taken from the `local-sim` federation options.
_local_sim_options = flwr_config["federations"]["local-sim"]["options"]
num_supernodes = _local_sim_options["num-supernodes"]

# If a single partition was requested, only that one is processed and saved
# (relative to the current working directory); otherwise every partition is
# processed. Compare against None explicitly: partition id 0 is valid but
# falsy, so a plain truthiness test would silently ignore `--partition-id 0`.
if args.partition_id is not None:
    print(f"Pre-processing partition {args.partition_id} only.")
else:
    print(f"Pre-processing dataset into {num_supernodes} partitions.")


def process_one_partition(partition_id: int, save: bool = False):
    """Pre-process one partition and optionally save it to disk.

    The partition is obtained from `load_data(partition_id, remove_cols)`.

    Args:
        partition_id: Id of the partition to process.
        save: When True, the result is written with `save_to_disk` to a
            directory named `partition_<partition_id>`, relative to the
            current working directory, and a confirmation line is printed.

    Returns:
        None.
    """
    pp = load_data(partition_id, remove_cols)
    if save:
        file_name = f"partition_{partition_id}"
        pp.save_to_disk(file_name)
        print(f"Saved partition to disk: {file_name}")


if __name__ == "__main__":

    # Download train set
    _ = load_dataset("speech_commands", "v0.02", split="train", token=False)

    # Parallelize the processing of each partition in the dataset
    t_start = time()
    num_proc = None  # set it if you want to limit the number of processes

    # `is not None` (rather than truthiness) so that partition 0 is honoured.
    if args.partition_id is not None:
        process_one_partition(args.partition_id, True)

    else:
        # One task per partition id in [0, num_supernodes). Each call uses the
        # default `save=False`, so this function itself writes nothing here.
        # NOTE(review): presumably `load_data` caches its result — confirm.
        with Pool(num_proc) as pool:
            pool.map(process_one_partition, range(num_supernodes))
        print(
            f"Pre-processing {num_supernodes} partitions took: {time() - t_start:.2f} s"
        )
Loading