-
Notifications
You must be signed in to change notification settings - Fork 0
/
decision_tree.qmd
383 lines (265 loc) · 13.5 KB
/
decision_tree.qmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
---
title: "decision tree"
format: html
---
## decision tree
一个不错的数据集搜集的网站。
> https://archive-beta.ics.uci.edu/
> https://scikit-learn.org/stable/modules/tree.html
决策树,classical trees & conditional inference trees.
决策树的一般逻辑过程为,首先是特征选择,选择和分类结果相关性较高的特征,一般通过信息增益等;然后,选择好特征后从根节点出发,对节点计算所有的信息增益,选择信息增益最大的特征作为节点特征,根据该特征的不同取值建立子节点;对每个子节点使用相同的方式生成新的子节点,直到信息增益很小或者没有特征可以选择为止。
根节点如何选择呢?选择信息增益最大或Gini系数最小的分类特征作为根节点以及后续的子节点。还需要确定最优的分类变量和分类阈值等。信息增益是指使用某个特征来分割数据集后,能够使得分类结果的不确定性减少的程度;基尼指数是指使用某个特征来分割数据集后,随机抽取两个样本,它们被错误分类的概率。
在数据集中随机选择一个数据点,并随机分配给它一个数据集中存在的标签,分配错误的概率即为Gini impurity。
*3种典型的决策树算法为*,ID3、C4.5、CART, CART 算法使用了基尼系数取代了信息熵模型。常见的有三种不同的不纯度计算的方法,
包括Gini impurity(G), Entropy(H), classification error(E).
*决策树常用到的超参数:*
在决策树算法中,有一些重要的超参数需要进行调整以优化模型性能。以下是一些常用的决策树超参数:
最大深度(max_depth):决策树的最大深度,用于控制决策树的复杂度和过拟合风险。
最小样本数(min_samples_split):分裂一个内部节点需要的最小样本数,用于控制决策树的复杂度和过拟合风险。
叶节点最小样本数(min_samples_leaf):叶节点需要的最小样本数,用于控制决策树的复杂度和过拟合风险。
最大叶节点数(max_leaf_nodes):限制决策树的最大叶节点数,用于控制决策树的复杂度和过拟合风险。
分裂特征的最大数量(max_features):在每个节点上评估分裂的特征数量,用于控制决策树的复杂度和过拟合风险。
这些超参数的最佳取值通常需要通过交叉验证等方法进行调整,以找到最佳的模型性能。
```{r, include=FALSE}
library(rpart)
library(rpart.plot)
library(rattle)
library(partykit)
```
```{r}
# 9和自变量的小数据集
breast <- read.table('./datasets/breast-cancer-wisconsin.data', sep = ',',
header = FALSE,
na.strings = '?'
)
names(breast) <- c("ID", "clumpThickness", "sizeUniformity", "shapeUniformity", "maginalAdhesion",
"singleEpithelialCellSize", "bareNuclei",
"blandChromatin", "normalNucleoli","mitosis", "class")
df <- breast[-1]
df$class <- factor(df$class, levels=c(2,4),
labels=c("benign", "malignant"))
df <- na.omit(df)
```
```{r}
set.seed(1234)
index <- sample(nrow(df), 0.7*nrow(df))
train_df <- df[index, ]
test_df <- df[-index, ]
```
### 利用`rpart`包进行决策树的构建。
`rpart`包,是一个 R 语言中用于构建决策树模型的包,其名称来源于 Recursive Partitioning and Regression Trees(递归分区和回归树)。rpart 包提供了一种灵活、高效的方法来构建分类和回归树模型,能够处理多分类问题和连续型预测变量,同时还支持剪枝和交叉验证等技术。
```{r}
# 决策树分类问题
# 注意method的选择
# control 参数指定决策树的其他参数,例如最小叶节点大小和最大深度等
dtree <- rpart::rpart(formula = class ~ .,
data = train_df,
method = 'class',
parms=list(split="information"),
control = rpart.control(minsplit = 5, maxdepth = 3)
)
# 输出很多
# summary(dtree)
```
```{r}
# 使用交叉验证进行超参数调优
tree_model <- caret::train(Species ~ ., data = train_data,
method = "rpart",
trControl = trainControl(method = "cv", number = 5),
tuneLength = 10)
# cv.rpart() 函数用于交叉验证
```
CP(complexity parameter)复杂度参数)是决策树模型中的一个超参数,用于控制模型的复杂度。在决策树模型中,复杂度参数通常用来控制树的大小,即叶子节点的数量或深度。
在 rpart 包中,复杂度参数用来控制决策树的生长过程。当复杂度参数较小时,决策树会生长得更深,拟合训练数据的能力更强,但容易出现过拟合的问题;当复杂度参数较大时,决策树会生长得更浅,泛化能力更强,但拟合训练数据的能力会降低。因此,复杂度参数的选择需要在拟合能力和泛化能力之间进行权衡。
在 rpart 包中,复杂度参数的默认值为 0.01,可以通过调整 cp 参数来控制复杂度。当 cp 参数的值小于默认值时,决策树会生长得更深;当 cp 参数的值大于默认值时,决策树会生长得更浅。可以使用 plotcp() 函数绘制复杂度参数和交叉验证误差之间的关系曲线,以找到最佳的复杂度参数。
```{r}
dtree$cptable
# 对于
```
*interpret: *To choose a final tree size, examine the cptable component of the list returned by rpart(). It contains data about the prediction error for various tree sizes. The complexity parameter (cp) is used to penalize larger trees. Tree size is defined by the number of branch splits (nsplit). A tree with n splits has n + 1 terminal nodes. The rel error column contains the error rate for a tree of a given size in the training sample. The cross-validated error (xerror) is based on 10-fold cross-validation (also using the training sample). The xstd column contains the standard error of the cross-validation error.
CP:复杂度参数的值。
nsplit:对应的树的节点数(非叶子节点)。
rel error:相对误差,即与根节点相比的误差。
xerror:交叉验证误差,即使用交叉验证计算的误差。
xstd:交叉验证误差的标准差。
minerr:最小误差,即误差最小的节点的误差。
mince:最小分类误差,即误差最小的节点的分类误差。
yval:叶子节点的目标变量值,用于分类问题时为类别标签,用于回归问题时为预测值。
plotcp() 函数用于绘制交叉验证误差曲线
```{r}
plotcp(dtree)
```
*interpret: *选择多少个split作为, Y轴值越小越好,从图中可以考虑第三个点,即`cp = 0.01705`。
```{r}
plot(dtree)
text(dtree, use.n = TRUE, all = TRUE, cex = 0.8)
```
依据上面的分析,得到nsplit个数所对应的CP值,带入进行修剪。
剪枝是一种常用的模型正则化技术,其基本思想是通过删除某些叶子节点,使得决策树的复杂度降低,从而提高模型的泛化能力。在 rpart 包中,剪枝可以通过调整复杂度参数来实现。
prune() 函数会返回一个新的 rpart 对象,即剪枝后的决策树模型。可以使用 plot() 和 text() 函数可视化剪枝后的决策树模型,以评估剪枝的效果。
```{r}
dtree.prune <- prune(dtree, cp = 0.01705)
```
`dtree.prune`对象,可以理解为进一步优化后的模型。应该和网格搜索的结果做一番对比。
```{r}
printcp(dtree.prune)
```
```{r}
# 可视化剪枝后的决策树模型
plot(dtree.prune)
text(dtree.prune, use.n = TRUE, all = TRUE, cex = 0.8)
```
The color scheme ranges from dark red (most impure) to light green (most pure),with yellow representing intermediate levels of impurity.The gradient represents the accuracy of that node.
即是对`print(dtree.prune)`的可视化。
```{r}
rattle::fancyRpartPlot(dtree.prune, sub="Classification Tree")
```
*interpret the results: *整幅图为决策树模型,每一个node为分类的概率和某一变量branch,node上的数字为,correspond to the branch numbers in the textual representation of the trees as generated by the default print() method.
The "jumps" result from rpart() tuning the tree to remove some of the branches and those pruned branches do not appear in the final tree.
百分数为,the percentage of observations used at that node. 最后的叶节点的百分值加在一起约为100.
一个类似的图,
```{r}
plot(as.party(dtree.prune), gp = gpar(fontsize = 6))
```
*interpret the result: *最下面的柱状图代表每一个分类的概率。
*预测测试集的结果*
```{r}
dtree.pred <- predict(dtree.prune, test_df, type="class")
# 混淆矩阵
dtree.perf <- table(test_df$class, dtree.pred,
dnn=c("Actual", "Predicted"))
dtree.perf
# 计算模型的准确率
accuracy <- sum(dtree.pred == test_df$class) / nrow(test_df)
print(paste("Accuracy:", accuracy))
```
### 构建一个Conditional inference trees
此处先只做简单的介绍,超参数优化等步骤以后再增加。
*Conditional inference trees* are similar to traditional trees, but variables and splits are selected based on significance tests rather than purity/homogeneity measures. The significance tests are permutation tests.
Note that pruning isn’t required for conditional inference trees, and the process is somewhat more automated.
```{r}
# conditional inference tree
fit.ctree <- partykit::ctree(class~., data=train_df)
```
```{r}
plot(fit.ctree, main = 'Conditional Inference Tree',
gp = gpar(fontsize = 6)
)
```
ctree的混淆矩阵,和上面的结果对比,准确性还稍微高一些,
```{r}
ctree.pred <- predict(fit.ctree, test_df, type="response")
ctree.perf <- table(test_df$class, ctree.pred,
dnn=c("Actual", "Predicted"))
ctree.perf
```
### tidymodels中的决策树构建
具体参见`random_forest_tidymodels.Rmd`文件。
#### tidymodels做决策树分类问题
首先构建一个决策树模型,以帮助确定超参数的取值。
本基因表达数据集outcome为二分类结果。两种结局结果较为均衡,一个50例一个52例。
决策树模型的种类繁多,此处选择`rpart`包。
```{r}
class_tree_spec <- decision_tree() %>%
set_engine('rpart') %>%
set_mode("classification")
# 初步模型, 进行模型的一些探索性分析,为最终模型确定一些参数
class_tree_fit <- class_tree_spec %>%
fit(class ~ ., data = df_train)
# 在训练集数据上的一些模型指标
augment(class_tree_fit, new_data = df_train) %>%
accuracy(truth = class, estimate = .pred_class)
augment(class_tree_fit, new_data = df_train) %>%
conf_mat(truth = class, estimate = .pred_class)
```
`rpart.plot` 可以对rpart数进行可视化,此处做一个展示:
```{r}
class_tree_fit %>%
extract_fit_engine() %>%
rpart.plot::rpart.plot()
```
交叉验证和网格搜索确定最后的模型,
决策树模型需要确认nsplit等参数,一般会通过CP值分析得到,不过在网格搜索中得到的nsplit值。
在决策树模型中,超参数常见的设置为`cost_complexity` 和 `tree_depth`.
# 设置超参数范围
param_grid = {
'max_depth': [3, 5, 7],
'min_samples_split': [2, 5, 10],
'min_samples_leaf': [1, 2, 4],
'max_features': ['sqrt', 'log2']
}
```{r}
class_tree_wf <- workflow() %>%
add_model(class_tree_spec %>%
set_args(cost_complexity = tune(),
tree_depth = tune()
)) %>% # 选择网格搜索的参数
add_formula(class ~ .)
# k-fold cross-validation
set.seed(42)
df_fold <- vfold_cv(df_train)
# 利用dials包,进行超参数的设定
param_grid <- grid_regular(cost_complexity(range = c(-3, -1)),
tree_depth(),
levels = 10)
dials_para <- dials::grid_random(cost_complexity(),
tree_depth(),
size = 5
)
# 运行时间也太长了。。
tune_res <- tune_grid(
class_tree_wf,
resamples = df_fold,
grid = dials_para,
metrics = metric_set(accuracy)
)
autoplot(tune_res)
```
简单看下两个超参数在两个不同matrices的表现:
```{r}
tree_res %>%
collect_metrics()
tree_res %>%
collect_metrics() %>%
mutate(tree_depth = factor(tree_depth)) %>%
ggplot(aes(cost_complexity, mean, color = tree_depth)) +
geom_line(linewidth = 1.5, alpha = 0.6) +
geom_point(size = 2) +
facet_wrap(~ .metric, scales = "free", nrow = 2) +
scale_x_log10(labels = scales::label_number()) +
scale_color_viridis_d(option = "plasma", begin = .9, end = 0)
```
```{r}
best_complexity <- select_best(tune_res)
class_tree_final <- finalize_workflow(class_tree_wf, best_complexity)
class_tree_final_fit <- fit(class_tree_final, data = df_train)
class_tree_final_fit
```
```{r}
class_tree_final_fit %>%
extract_fit_engine() %>%
rpart.plot()
```
在调整完参数之后,图咋没啥改变呢。。
```{r}
class_tree_final_fit %>%
collect_predictions() %>%
roc_curve(class, .pred_PS) %>%
autoplot()
```
We can use the function last_fit() with our finalized model; this function fits the finalized model on the full training data set and evaluates the finalized model on the testing data.
```{r}
final_fit <-
final_wf %>%
last_fit(df_split)
final_fit %>%
collect_metrics()
```
```{r}
final_tree <- extract_workflow(final_fit)
# 应该和class_tree_final_fit是一样的,不过跑的太慢了,没有测试
final_tree
final_tree %>%
extract_fit_parsnip() %>%
vip::vip()
```