-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
153 lines (115 loc) · 6.42 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import argparse
import numpy as np
from src.data import load_data
from src.methods.dummy_methods import DummyClassifier
from src.methods.logistic_regression import LogisticRegression
from src.methods.linear_regression import LinearRegression
from src.methods.knn import KNN
from src.utils import normalize_fn, append_bias_term, accuracy_fn, macrof1_fn, mse_fn
import os
# Seed NumPy's legacy global RNG once, at import time, so that every
# np.random.* call made later in this script is reproducible across runs.
_RANDOM_SEED = 100
np.random.seed(seed=_RANDOM_SEED)
def _load_dataset(args):
    """Load the dataset selected by ``args.data_type``.

    Arguments:
        args (Namespace): parsed CLI arguments; reads ``data_type`` and, for the
                          original images, ``data_path``.
    Returns:
        tuple: (xtrain, xtest, ytrain, ytest, ctrain, ctest), in that order.
    Raises:
        ValueError: if ``args.data_type`` is neither "features" nor "original".
    """
    if args.data_type == "features":
        # Extracted-features dataset, read relative to the current working directory.
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
        # can execute arbitrary code -- only load a trusted features.npz.
        # The `with` block closes the .npz file handle; indexing the archive reads
        # each array into memory, so the returned arrays remain valid afterwards.
        with np.load('features.npz', allow_pickle=True) as feature_data:
            return tuple(feature_data[key] for key in
                         ('xtrain', 'xtest', 'ytrain', 'ytest', 'ctrain', 'ctest'))
    if args.data_type == "original":
        # Original image dataset (MS2).
        data_dir = os.path.join(args.data_path, 'dog-small-64')
        return load_data(data_dir)
    raise ValueError("Invalid choice of data type! Only support features / original (MS2)")


def _make_validation_split(xtrain, ytrain, ctrain, val_ratio=0.2):
    """Randomly hold out part of the training data as a validation set.

    Shuffles with NumPy's global RNG (seeded at the top of this file), so the
    split is reproducible between runs.

    Arguments:
        xtrain, ytrain, ctrain (np.ndarray): training features, breed labels and
            center targets; assumed aligned sample-by-sample along axis 0.
        val_ratio (float): fraction of the training samples moved to validation.
    Returns:
        tuple: (xtrain, xval, ytrain, yval, ctrain, cval) -- the dataset-tuple
            order, with the validation arrays taking the "test" slots.
    Raises:
        ValueError: if there are fewer than two training samples.
    """
    n_samples = len(xtrain)
    if n_samples < 2:
        raise ValueError("Need at least two training samples to build a validation set.")
    # Keep at least one sample on each side of the split.
    n_val = min(max(int(round(n_samples * val_ratio)), 1), n_samples - 1)
    perm = np.random.permutation(n_samples)
    val_idx, train_idx = perm[:n_val], perm[n_val:]
    return (xtrain[train_idx], xtrain[val_idx],
            ytrain[train_idx], ytrain[val_idx],
            ctrain[train_idx], ctrain[val_idx])


def _normalize_and_add_bias(xtrain, xtest):
    """Standardize features with training statistics, then append a bias term.

    Returns:
        tuple: (xtrain, xtest) normalized and with the bias term appended.
    """
    # Per-feature statistics computed on the training split only, so nothing
    # about the evaluation split leaks into preprocessing.
    mean = np.mean(xtrain, axis=0, keepdims=True)
    std = np.std(xtrain, axis=0, keepdims=True)
    # Constant features have zero spread; divide them by 1 instead of 0 (no NaNs).
    std = np.where(std == 0, 1.0, std)
    # NOTE(review): assumes normalize_fn computes (data - mean) / std with
    # NumPy broadcasting -- confirm in src/utils.
    xtrain = normalize_fn(xtrain, mean, std)
    xtest = normalize_fn(xtest, mean, std)
    # Append the bias only after normalizing, so the bias column is not shifted
    # and rescaled like a regular feature.
    return append_bias_term(xtrain), append_bias_term(xtest)


def _build_method(args):
    """Instantiate the model selected by ``args.method`` for ``args.task``.

    Raises:
        NotImplementedError: for "nn" (MS2).
        ValueError: for an unknown method, or a method unsuited to the task.
    """
    if args.method == "nn":
        raise NotImplementedError("This will be useful for MS2.")
    if args.method == "dummy_classifier":
        return DummyClassifier(arg1=1, arg2=2)
    if args.method == "knn":
        task_kind = "classification" if args.task == "breed_identifying" else "regression"
        return KNN(k=args.K, task_kind=task_kind)
    if args.method == "linear_regression":
        if args.task == "breed_identifying":
            raise ValueError("Linear regression is not suitable for classification task.")
        return LinearRegression(lmda=args.lmda)
    if args.method == "logistic_regression":
        if args.task == "center_locating":
            raise ValueError("Logistic regression is not suitable for regression task.")
        return LogisticRegression(lr=args.lr, max_iters=args.max_iters)
    raise ValueError("Invalid choice of method! Only support dummy_classifier / knn / linear_regression/ logistic_regression / nn (MS2)")


def main(args):
    """
    The main function of the script. Do not hesitate to play with it
    and add your own code, visualization, prints, etc!

    Pipeline: validate the configuration -> load data -> (unless --test) carve a
    validation set out of the training data -> normalize + add bias -> train and
    evaluate the selected method on the selected task.

    Arguments:
        args (Namespace): arguments that were parsed from the command line (see at the end
                          of this file). Their value can be accessed as "args.argument".
    Returns:
        tuple: (train_loss, loss) as computed by mse_fn for "center_locating";
               (accuracy, macro-F1) on the evaluation set for "breed_identifying".
    Raises:
        ValueError: for an unknown task, data type or method, or a method that
                    does not suit the task.
        NotImplementedError: for the "nn" method (MS2).
    """
    # Fail fast on configuration errors, before any (possibly slow) data loading.
    if args.task not in ("center_locating", "breed_identifying"):
        raise ValueError("Invalid choice of task! Only support center_locating and breed_identifying!")
    method_obj = _build_method(args)

    ## 1. Load the data.
    # ctrain/ctest are the regression targets (center_locating);
    # ytrain/ytest are the classification labels (breed_identifying).
    xtrain, xtest, ytrain, ytest, ctrain, ctest = _load_dataset(args)

    ## 2. Prepare it: validation split (unless --test), normalization, bias.
    if args.test:
        eval_name = "Test"
    else:
        # The held-out validation set replaces the test set for evaluation.
        xtrain, xtest, ytrain, ytest, ctrain, ctest = _make_validation_split(xtrain, ytrain, ctrain)
        eval_name = "Validation"
    xtrain, xtest = _normalize_and_add_bias(xtrain, xtest)

    ## 3. Train and evaluate the method.
    if args.task == "center_locating":
        method_obj.fit(xtrain, ctrain)
        train_pred = method_obj.predict(xtrain)
        preds = method_obj.predict(xtest)
        train_loss = mse_fn(train_pred, ctrain)
        loss = mse_fn(preds, ctest)
        print(f"\nTrain loss = {train_loss:.3f} - {eval_name} loss = {loss:.3f}")
        return train_loss, loss

    # breed_identifying: fit returns the predictions on the training data.
    preds_train = method_obj.fit(xtrain, ytrain)
    preds = method_obj.predict(xtest)
    acc = accuracy_fn(preds_train, ytrain)
    macrof1 = macrof1_fn(preds_train, ytrain)
    print(f"\nTrain set: accuracy = {acc:.3f}% - F1-score = {macrof1:.6f}")
    acc = accuracy_fn(preds, ytest)
    macrof1 = macrof1_fn(preds, ytest)
    print(f"{eval_name} set: accuracy = {acc:.3f}% - F1-score = {macrof1:.6f}")
    return acc, macrof1
if __name__ == '__main__':
# Definition of the arguments that can be given through the command line (terminal).
# If an argument is not given, it will take its default value as defined below.
parser = argparse.ArgumentParser()
parser.add_argument('--task', default="center_locating", type=str, help="center_locating / breed_identifying")
parser.add_argument('--method', default="dummy_classifier", type=str, help="dummy_classifier / knn / linear_regression/ logistic_regression / nn (MS2)")
parser.add_argument('--data_path', default="data", type=str, help="path to your dataset")
parser.add_argument('--data_type', default="features", type=str, help="features/original(MS2)")
parser.add_argument('--lmda', type=float, default=10, help="lambda of linear/ridge regression")
parser.add_argument('--K', type=int, default=1, help="number of neighboring datapoints used for knn")
parser.add_argument('--lr', type=float, default=1e-5, help="learning rate for methods with learning rate")
parser.add_argument('--max_iters', type=int, default=100, help="max iters for methods which are iterative")
parser.add_argument('--test', action="store_true", help="train on whole training data and evaluate on the test data, otherwise use a validation set")
# Feel free to add more arguments here if you need!
# MS2 arguments
parser.add_argument('--nn_type', default="cnn", help="which network to use, can be 'Transformer' or 'cnn'")
parser.add_argument('--nn_batch_size', type=int, default=64, help="batch size for NN training")
# "args" will keep in memory the arguments and their values,
# which can be accessed as "args.data", for example.
args = parser.parse_args()
main(args)