################################################
# Decision Tree Classification: CART
################################################
# CART (Classification and Regression Tree)
# 1. Exploratory Data Analysis
# 2. Data Preprocessing & Feature Engineering
# 3. Modeling using CART
# 4. Hyperparameter Optimization with GridSearchCV
# 5. Final Model
# 6. Feature Importance
# 7. Analyzing Model Complexity with Learning Curves (BONUS)
# 8. Visualizing the Decision Tree
# 9. Extracting Decision Rules
# 10. Extracting Python/SQL/Excel Codes of Decision Rules
# 11. Prediction using Python Codes
# 12. Saving and Loading Model
import warnings
import joblib
import pydotplus
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.tree import DecisionTreeClassifier, export_graphviz, export_text
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import train_test_split, GridSearchCV, cross_validate, validation_curve
from skompiler import skompile
import graphviz
# In a decision tree, the most important variable is the one at the top of the tree.
# The points where the independent variables are split are called internal nodes (split points).
# When branching, we run into two problems:
# 1. How deep should we branch (1 level, 3 levels, how far down do we go)?
# 2. Suppose a branch has formed and only 3 observations are left in it:
#    should we keep splitting or stop?
# Problem 1 --> max_depth
# Problem 2 --> min_samples_split
# Tree models perform well but are very prone to overfitting:
# they learn the training data well but lose their ability to generalize.
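# A minimal illustrative sketch (not part of the original flow; the values are
# arbitrary): both knobs are plain constructor arguments, and the proper values
# are found with GridSearchCV in section 4.
example_tree = DecisionTreeClassifier(max_depth=3,           # grow at most 3 levels deep
                                      min_samples_split=10,  # a node needs >= 10 samples to be split
                                      random_state=17)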
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 500)
warnings.simplefilter(action='ignore', category=Warning)
################################################
# 1. Exploratory Data Analysis
################################################
################################################
# 2. Data Preprocessing & Feature Engineering
################################################
################################################
# 3. Modeling using CART
################################################
df = pd.read_csv("datasets/diabetes.csv")
y = df["Outcome"]
X = df.drop(["Outcome"], axis=1)
cart_model = DecisionTreeClassifier(random_state=1).fit(X, y)
# y_pred for the confusion matrix:
y_pred = cart_model.predict(X)
# y_prob for AUC:
y_prob = cart_model.predict_proba(X)[:, 1]
# Confusion matrix
print(classification_report(y, y_pred))
# AUC
roc_auc_score(y, y_prob)
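# Note: evaluating on the same data the tree was fit on is misleading here,
# because an unconstrained tree can memorize the training set and score
# (near-)perfectly. This motivates the holdout and CV evaluations below.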
#####################
# Performance Evaluation with the Holdout Method
#####################
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=45)
cart_model = DecisionTreeClassifier(random_state=17).fit(X_train, y_train)
# Train Error
y_pred = cart_model.predict(X_train)
y_prob = cart_model.predict_proba(X_train)[:, 1]
print(classification_report(y_train, y_pred))
roc_auc_score(y_train, y_prob)
# Test Error
y_pred = cart_model.predict(X_test)
y_prob = cart_model.predict_proba(X_test)[:, 1]
print(classification_report(y_test, y_pred))
roc_auc_score(y_test, y_prob)
#####################
# Performance Evaluation with Cross-Validation
#####################
cart_model = DecisionTreeClassifier(random_state=17).fit(X, y)
cv_results = cross_validate(cart_model,
                            X, y,
                            cv=5,
                            scoring=["accuracy", "f1", "roc_auc"])
cv_results['test_accuracy'].mean()
# 0.7058568882098294
cv_results['test_f1'].mean()
# 0.5710621194523633
cv_results['test_roc_auc'].mean()
# 0.6719440950384347
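# These are the untuned baseline scores; the hyperparameter tuning below
# aims to improve on them.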
################################################
# 4. Hyperparameter Optimization with GridSearchCV
################################################
cart_model.get_params()
cart_params = {'max_depth': range(1, 11),
               "min_samples_split": range(2, 20)}
cart_best_grid = GridSearchCV(cart_model,
                              cart_params,
                              cv=5,
                              n_jobs=-1,
                              verbose=1).fit(X, y)
# GridSearchCV gives us the best hyperparameters here
cart_best_grid.best_params_
cart_best_grid.best_score_
random = X.sample(1, random_state=45)
cart_best_grid.predict(random)
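# With the default refit=True, GridSearchCV refits the best estimator on the
# full dataset, so predict() above already uses the tuned model. The fitted
# estimator itself is available as:
cart_best_grid.best_estimator_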
################################################
# 5. Final Model
################################################
# Two equivalent ways to build the final model with the best hyperparameters:
cart_final = DecisionTreeClassifier(**cart_best_grid.best_params_, random_state=17).fit(X, y)
cart_final.get_params()
# ...or update the existing model's parameters in place:
cart_final = cart_model.set_params(**cart_best_grid.best_params_).fit(X, y)
cv_results = cross_validate(cart_final,
                            X, y,
                            cv=5,
                            scoring=["accuracy", "f1", "roc_auc"])
cv_results['test_accuracy'].mean()
cv_results['test_f1'].mean()
cv_results['test_roc_auc'].mean()
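# With the tuned max_depth and min_samples_split, these CV scores should come
# out above the untuned baseline from section 3.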
################################################
# 6. Feature Importance
################################################
cart_final.feature_importances_
def plot_importance(model, features, num=len(X), save=False):
    feature_imp = pd.DataFrame({'Value': model.feature_importances_, 'Feature': features.columns})
    plt.figure(figsize=(10, 10))
    sns.set(font_scale=1)
    sns.barplot(x="Value", y="Feature", data=feature_imp.sort_values(by="Value",
                                                                     ascending=False)[0:num])
    plt.title('Features')
    plt.tight_layout()
    if save:
        plt.savefig('importances.png')  # save before show(), or the saved figure is blank
    plt.show()
plot_importance(cart_final, X, num=5)
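# A quick textual alternative (sketch): rank the importances without plotting.
pd.Series(cart_final.feature_importances_, index=X.columns).sort_values(ascending=False)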
################################################
# 7. Analyzing Model Complexity with Learning Curves (BONUS)
################################################
train_score, test_score = validation_curve(cart_final, X, y,
                                           param_name="max_depth",
                                           param_range=range(1, 11),
                                           scoring="roc_auc",
                                           cv=10)
mean_train_score = np.mean(train_score, axis=1)
mean_test_score = np.mean(test_score, axis=1)
plt.plot(range(1, 11), mean_train_score,
         label="Training Score", color='b')
plt.plot(range(1, 11), mean_test_score,
         label="Validation Score", color='g')
plt.title("Validation Curve for CART")
plt.xlabel("Number of max_depth")
plt.ylabel("AUC")
plt.tight_layout()
plt.legend(loc='best')
plt.show()
def val_curve_params(model, X, y, param_name, param_range, scoring="roc_auc", cv=10):
    train_score, test_score = validation_curve(
        model, X=X, y=y, param_name=param_name, param_range=param_range, scoring=scoring, cv=cv)
    mean_train_score = np.mean(train_score, axis=1)
    mean_test_score = np.mean(test_score, axis=1)
    plt.plot(param_range, mean_train_score,
             label="Training Score", color='b')
    plt.plot(param_range, mean_test_score,
             label="Validation Score", color='g')
    plt.title(f"Validation Curve for {type(model).__name__}")
    plt.xlabel(f"Number of {param_name}")
    plt.ylabel(f"{scoring}")
    plt.tight_layout()
    plt.legend(loc='best')
    plt.show(block=True)
val_curve_params(cart_final, X, y, "max_depth", range(1, 11), scoring="f1")
cart_val_params = [["max_depth", range(1, 11)], ["min_samples_split", range(2, 20)]]
for i in range(len(cart_val_params)):
    val_curve_params(cart_model, X, y, cart_val_params[i][0], cart_val_params[i][1])
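# Reading the curves: the point where the training score keeps rising while
# the validation score starts to drop marks where the model begins to overfit.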
################################################
# 8. Visualizing the Decision Tree
################################################
# conda install graphviz
# import graphviz
def tree_graph(model, col_names, file_name):
    tree_str = export_graphviz(model, feature_names=col_names, filled=True, out_file=None)
    graph = pydotplus.graph_from_dot_data(tree_str)
    graph.write_png(file_name)
tree_graph(model=cart_final, col_names=X.columns, file_name="cart_final.png")
cart_final.get_params()
################################################
# 9. Extracting Decision Rules
################################################
tree_rules = export_text(cart_final, feature_names=list(X.columns))
print(tree_rules)
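# export_text also accepts a max_depth argument if the full rule list is too
# long to read (sketch; the cutoff of 3 is arbitrary):
print(export_text(cart_final, feature_names=list(X.columns), max_depth=3))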
################################################
# 10. Extracting Python/SQL/Excel Codes of Decision Rules
################################################
# This works with scikit-learn version 0.23.1:
# pip install scikit-learn==0.23.1
print(skompile(cart_final.predict).to('python/code'))
print(skompile(cart_final.predict).to('sqlalchemy/sqlite'))
print(skompile(cart_final.predict).to('excel'))
################################################
# 11. Prediction using Python Codes
################################################
def predict_with_rules(x):
    return ((((((0 if x[6] <= 0.671999990940094 else 1 if x[6] <= 0.6864999830722809 else
0) if x[0] <= 7.5 else 1) if x[5] <= 30.949999809265137 else ((1 if x[5
] <= 32.45000076293945 else 1 if x[3] <= 10.5 else 0) if x[2] <= 53.0 else
((0 if x[1] <= 111.5 else 0 if x[2] <= 72.0 else 1 if x[3] <= 31.0 else
0) if x[2] <= 82.5 else 1) if x[4] <= 36.5 else 0) if x[6] <=
0.5005000084638596 else (0 if x[1] <= 88.5 else (((0 if x[0] <= 1.0 else
1) if x[1] <= 98.5 else 1) if x[6] <= 0.9269999861717224 else 0) if x[1
] <= 116.0 else 0 if x[4] <= 166.0 else 1) if x[2] <= 69.0 else ((0 if
x[2] <= 79.0 else 0 if x[1] <= 104.5 else 1) if x[3] <= 5.5 else 0) if
x[6] <= 1.098000019788742 else 1) if x[5] <= 45.39999961853027 else 0 if
x[7] <= 22.5 else 1) if x[7] <= 28.5 else (1 if x[5] <=
9.649999618530273 else 0) if x[5] <= 26.350000381469727 else (1 if x[1] <=
28.5 else ((0 if x[0] <= 11.5 else 1 if x[5] <= 31.25 else 0) if x[1] <=
94.5 else (1 if x[5] <= 36.19999885559082 else 0) if x[1] <= 97.5 else
0) if x[6] <= 0.7960000038146973 else 0 if x[0] <= 3.0 else (1 if x[6] <=
0.9614999890327454 else 0) if x[3] <= 20.0 else 1) if x[1] <= 99.5 else
((1 if x[5] <= 27.649999618530273 else 0 if x[0] <= 5.5 else (((1 if x[
0] <= 7.0 else 0) if x[1] <= 103.5 else 0) if x[1] <= 118.5 else 1) if
x[0] <= 9.0 else 0) if x[6] <= 0.19999999552965164 else ((0 if x[5] <=
36.14999961853027 else 1) if x[1] <= 113.0 else 1) if x[0] <= 1.5 else
(1 if x[6] <= 0.3620000034570694 else 1 if x[5] <= 30.050000190734863 else
0) if x[2] <= 67.0 else (((0 if x[6] <= 0.2524999976158142 else 1) if x
[1] <= 120.0 else 1 if x[6] <= 0.23899999260902405 else 1 if x[7] <=
30.5 else 0) if x[2] <= 83.0 else 0) if x[5] <= 34.45000076293945 else
1 if x[1] <= 101.0 else 0 if x[5] <= 43.10000038146973 else 1) if x[6] <=
0.5609999895095825 else ((0 if x[7] <= 34.5 else 1 if x[5] <=
33.14999961853027 else 0) if x[4] <= 120.5 else (1 if x[3] <= 47.5 else
0) if x[4] <= 225.0 else 0) if x[0] <= 6.5 else 1) if x[1] <= 127.5 else
(((((1 if x[1] <= 129.5 else ((1 if x[6] <= 0.5444999933242798 else 0) if
x[2] <= 56.0 else 0) if x[2] <= 71.0 else 1) if x[2] <= 73.0 else 0) if
x[5] <= 28.149999618530273 else (1 if x[1] <= 135.0 else 0) if x[3] <=
21.0 else 1) if x[4] <= 132.5 else 0) if x[1] <= 145.5 else 0 if x[7] <=
25.5 else ((0 if x[1] <= 151.0 else 1) if x[5] <= 27.09999942779541 else
((1 if x[0] <= 6.5 else 0) if x[6] <= 0.3974999934434891 else 0) if x[2
] <= 82.0 else 0) if x[7] <= 61.0 else 0) if x[5] <= 29.949999809265137
else ((1 if x[2] <= 61.0 else (((((0 if x[6] <= 0.18299999833106995 else
1) if x[0] <= 0.5 else 1 if x[5] <= 32.45000076293945 else 0) if x[2] <=
73.0 else 0) if x[0] <= 4.5 else 1 if x[6] <= 0.6169999837875366 else 0
) if x[6] <= 1.1414999961853027 else 1) if x[5] <= 41.79999923706055 else
1 if x[6] <= 0.37299999594688416 else 1 if x[1] <= 142.5 else 0) if x[7
] <= 30.5 else (((1 if x[6] <= 0.13649999350309372 else 0 if x[5] <=
32.45000076293945 else 1 if x[5] <= 33.05000114440918 else (0 if x[6] <=
0.25599999725818634 else (0 if x[1] <= 130.5 else 1) if x[0] <= 8.5 else
0) if x[0] <= 13.5 else 1) if x[2] <= 92.0 else 1) if x[5] <=
45.54999923706055 else 1) if x[6] <= 0.4294999986886978 else (1 if x[5] <=
40.05000114440918 else 0 if x[5] <= 40.89999961853027 else 1) if x[4] <=
333.5 else 1 if x[2] <= 64.0 else 0) if x[1] <= 157.5 else ((((1 if x[7
] <= 25.5 else 0 if x[4] <= 87.5 else 1 if x[5] <= 45.60000038146973 else
0) if x[7] <= 37.5 else 1 if x[7] <= 56.5 else 0 if x[6] <=
0.22100000083446503 else 1) if x[6] <= 0.28849999606609344 else 0) if x
[6] <= 0.3004999905824661 else 1 if x[7] <= 44.0 else (0 if x[7] <=
51.0 else 1 if x[6] <= 1.1565000414848328 else 0) if x[0] <= 6.5 else 1
) if x[4] <= 629.5 else 1 if x[6] <= 0.4124999940395355 else 0)
X.columns
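# The indices in x follow the order of X.columns in the diabetes dataset:
# x[0]=Pregnancies, x[1]=Glucose, x[2]=BloodPressure, x[3]=SkinThickness,
# x[4]=Insulin, x[5]=BMI, x[6]=DiabetesPedigreeFunction, x[7]=Age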
x = [12, 13, 20, 23, 4, 55, 12, 7]
predict_with_rules(x)
x = [6, 148, 70, 35, 0, 30, 0.62, 50]
predict_with_rules(x)
################################################
# 12. Saving and Loading Model
################################################
joblib.dump(cart_final, "cart_final.pkl")
cart_model_from_disc = joblib.load("cart_final.pkl")
x = [12, 13, 20, 23, 4, 55, 12, 7]
cart_model_from_disc.predict(pd.DataFrame(x).T)
# the input must be converted to a DataFrame when calling predict
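# Sketch: building the one-row input as a DataFrame with the original column
# names avoids sklearn's feature-name mismatch warning (values are arbitrary).
sample = pd.DataFrame([x], columns=X.columns)
cart_model_from_disc.predict(sample)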