Skip to content

Commit

Permalink
added pruning to hyper param opt and parallelized the trials
Browse files Browse the repository at this point in the history
  • Loading branch information
johanos1 committed Oct 9, 2024
1 parent 96b9a22 commit b56584b
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
12 changes: 8 additions & 4 deletions experiments/tune_hyperparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

import pandas as pd

import sys
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(root_dir)

from flib.train import centralized, federated, isolated, HyperparamTuner
import hyperparams

Expand All @@ -13,9 +17,9 @@ def main():
parser.add_argument('--clients', nargs='+', help='Types of clients to train.', default=['LogRegClient', 'DecisionTreeClient', 'RandomForestClient', 'GradientBoostingClient', 'SVMClient', 'KNNClient']) # LogRegClient, DecisionTreeClient, RandomForestClient, GradientBoostingClient, SVMClient, KNNClient
parser.add_argument('--settings', nargs='+', help='Types of settings to use. Can be "isolated", "centralized" or "federated".', default=['centralized', 'federated', 'isolated'])
parser.add_argument('--traindata_files', nargs='+', help='Paths to trainsets.', default=[
'/home/edvin/Desktop/flib/experiments/data/3_banks_homo_mid/preprocessed/a_nodes_train.csv',
'/home/edvin/Desktop/flib/experiments/data/3_banks_homo_mid/preprocessed/b_nodes_train.csv',
'/home/edvin/Desktop/flib/experiments/data/3_banks_homo_mid/preprocessed/c_nodes_train.csv'
'/home/johan/project/flib/data/3_banks_homo_mid/preprocessed/a_nodes_test.csv',
'/home/johan/project/flib/data/3_banks_homo_mid/preprocessed/b_nodes_test.csv',
'/home/johan/project/flib/data/3_banks_homo_mid/preprocessed/c_nodes_test.csv'
])
parser.add_argument('--valdata_files', nargs='+', help='Paths to valsets', default=[
None,
Expand All @@ -28,7 +32,7 @@ def main():
parser.add_argument('--device', type=str, help='Device for computations. Can be "cpu" or cuda device, e.g. "cuda:0".', default="cuda:0")
parser.add_argument('--seed', type=int, help='Seed.', default=42)
parser.add_argument('--n_trials', type=int, help='Number of trials.', default=10)
parser.add_argument('--results_dir', type=str, default='/home/edvin/Desktop/flib/experiments/results/3_banks_homo_mid/')
parser.add_argument('--results_dir', type=str, default='/home/johan/project/flib/data/3_banks_homo_mid')

args = parser.parse_args()

Expand Down
7 changes: 6 additions & 1 deletion flib/train/federated.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ def objective(self, trial):
return avg_loss

def optimize(self, n_trials=10):
study = optuna.create_study(storage=self.storage, sampler=optuna.samplers.TPESampler(), study_name='study', directions=['minimize'], load_if_exists=True)
study = optuna.create_study(storage=self.storage,
sampler=optuna.samplers.TPESampler(),
study_name='study',
directions=['minimize'],
load_if_exists=True,
pruner=optuna.pruners.HyperbandPruner())
study.optimize(self.objective, n_trials=n_trials)
with open(self.results_file, 'a') as f:
f.write(f'\n\n{time.ctime()}\n')
Expand Down
2 changes: 1 addition & 1 deletion flib/train/isolated.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def isolated(train_dfs, val_dfs=[], test_dfs=[], seed=42, n_workers=3, client='L
set_random_seed(seed)

try:
mp.set_start_method('spawn')
mp.set_start_method('spawn', force=True)
except RuntimeError:
pass

Expand Down
33 changes: 28 additions & 5 deletions flib/train/tune_hyperparams.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import optuna
import inspect
from flib.train import Clients
import multiprocessing

class HyperparamTuner():
def __init__(self, study_name, obj_fn, train_dfs, val_dfs, seed=42, device='cpu', n_workers=1, storage=None, client=None, params=None):
Expand Down Expand Up @@ -58,9 +59,31 @@ def objective(self, trial: optuna.Trial):
avg_loss += results[client][round]['val']['loss'] / len(results)
return avg_loss

def optimize(self, n_trials=10):
study = optuna.create_study(storage=self.storage, sampler=optuna.samplers.TPESampler(), study_name=self.study_name, direction='minimize', load_if_exists=True)
study.optimize(self.objective, n_trials=n_trials)
return study.best_trials
# def optimize(self, n_trials=10):
# study = optuna.create_study(storage=self.storage, sampler=optuna.samplers.TPESampler(), study_name=self.study_name, direction='minimize', load_if_exists=True)
# study.optimize(self.objective, n_trials=n_trials, show_progress_bar=True)
# return study.best_trials


def optimize(self, n_trials=100, n_jobs=10):
# Create the study with RDB storage for parallel processing
study = optuna.create_study(storage=self.storage,
sampler=optuna.samplers.TPESampler(),
study_name=self.study_name,
direction='minimize',
load_if_exists=True,
pruner = optuna.pruners.HyperbandPruner())


def run_study():
study.optimize(self.objective, n_trials=n_trials // n_jobs)

processes = []
for _ in range(n_jobs):
p = multiprocessing.Process(target=run_study)
p.start()
processes.append(p)

for p in processes:
p.join()

return study.best_trials

0 comments on commit b56584b

Please sign in to comment.