-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathensemble_cifar_models.py
56 lines (44 loc) · 1.78 KB
/
ensemble_cifar_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import utils as myutils
import sys
PATH_TO_CIFAR = "./cifar/"
sys.path.append(PATH_TO_CIFAR)
import train as cifar_train
import hyperparameters.vgg11_cifar10_baseline as vgg_hyperparams
import wasserstein_ensemble
import baseline
import parameters
import torch
# Root directory holding one subdirectory per ensemble experiment.
ensemble_root_dir = "./cifar_models/"
# ensemble_experiment = "exp_2019-04-23_18-08-48/"
ensemble_experiment = "exp_2019-04-24_02-20-26"
# Experiment directory: expected to contain model_0/, model_1/, ... subdirs,
# each with '<checkpoint_type>.checkpoint' files (see main()).
ensemble_dir = ensemble_root_dir + ensemble_experiment
# NOTE(review): declared but not referenced in this file — presumably used by
# the ensembling modules or kept for symmetry with other scripts; confirm.
output_root_dir = "./cifar_models_ensembled/"
checkpoint_type = 'final' # which checkpoint to use for ensembling (either 'best' or 'final')
def main():
    """Load every model checkpoint of an ensemble experiment and run three
    ensembling strategies on them:

    1. geometric (Wasserstein barycenter) ensembling,
    2. prediction-based (output averaging) ensembling,
    3. naive averaging of the weights.

    Reads module-level configuration (``ensemble_dir``, ``checkpoint_type``)
    and the VGG11/CIFAR-10 hyperparameters; prints progress to stdout.
    """
    # torch.cuda.empty_cache()
    config = vgg_hyperparams.config

    # The experiment directory contains one entry per model (model_0/,
    # model_1/, ...); checkpoints are addressed by index below, so only the
    # count of entries is needed.
    num_models = len(os.listdir(ensemble_dir))

    train_loader, test_loader = cifar_train.get_dataset(config)

    models = []
    for idx in range(num_models):
        print("Path is ", ensemble_dir)
        print("loading model with idx {} and checkpoint_type is {}".format(idx, checkpoint_type))
        # Build the path component-wise (the original embedded a '/' inside a
        # single os.path.join argument, which defeats its portability).
        checkpoint_path = os.path.join(
            ensemble_dir,
            'model_{}'.format(idx),
            '{}.checkpoint'.format(checkpoint_type),
        )
        models.append(
            cifar_train.get_pretrained_model(config, checkpoint_path, parameters.gpu_id)
        )
    print("Done loading all the models")

    # run geometric aka wasserstein ensembling
    print("------- Geometric Ensembling -------")
    wasserstein_ensemble.geometric_ensembling_modularized(models, train_loader, test_loader)

    # run baseline
    print("------- Prediction based ensembling -------")
    baseline.prediction_ensembling(models, train_loader, test_loader)
    print("------- Naive ensembling of weights -------")
    baseline.naive_ensembling(models, train_loader, test_loader)
if __name__ == '__main__':
    main()