---
title: "Chapter_8_labs_trees"
author: "Russ Conte"
date: "8/7/2021"
output: html_document
---
## 8.3 Lab: Decision Trees
## 8.3.1 Fitting Classification Trees
We first use classification trees to analyze the Carseats data set. In these data, Sales is a continuous variable, and so we begin by recoding it as a binary variable. We use the ifelse() function to create a variable, called High, which takes on a value of Yes if the Sales variable exceeds 8, and takes on a value of No otherwise.
```{r}
library(tree)
library(ISLR)
attach(Carseats)
High = as.factor(ifelse(Sales <= 8, "No", "Yes"))
Carseats = data.frame(Carseats, High)
str(Carseats)
View(Carseats)
```
We now use the tree() function to fit a classification tree in order to predict High using all variables but Sales. The syntax of the tree() function is quite similar to that of the lm() function.
```{r}
tree.carseats = tree(High ~. - Sales, data = Carseats)
summary(tree.carseats) # training misclassification rate of about 9%, i.e. roughly 91% training accuracy
```
One of the most attractive properties of trees is that they can be graphically displayed. We use the plot() function to display the tree structure, and the text() function to display the node labels.
```{r}
par(mfrow = c(1,1))
plot(tree.carseats)
text(tree.carseats, pretty = 0)
tree.carseats # this is how we can see the splits made by the tree function
```
In order to properly evaluate the performance of a classification tree on these data, we must estimate the test error rather than simply computing the training error. We split the observations into a training set and a test set, build the tree using the training set, and evaluate its performance on the test data. The predict() function can be used for this purpose. In the case of a classification tree, the argument type="class" instructs R to return the actual class prediction. This approach leads to correct predictions for around 77% of the observations in the test data set.
```{r}
set.seed(2)
train = sample(1:nrow(Carseats), 200)
Carseats.test = Carseats[-train,]
High.test = High[-train]
tree.carseats = tree(High~.-Sales, Carseats, subset = train)
tree.pred = predict(tree.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
154/200 # 77% test accuracy (correct predictions on the diagonal of the table, divided by 200)
```
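Rather than hard-coding the counts from the confusion matrix, we can compute the test accuracy directly from the predictions. A minimal sketch (our own addition to the lab):
```{r}
# Proportion of test observations whose class is predicted correctly
mean(tree.pred == High.test)
```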
Next, we consider whether pruning the tree might lead to improved results. The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity; cost complexity pruning is used in order to select a sequence of trees for consideration. We use the argument FUN=prune.misclass in order to indicate that we want the classification error rate to guide the cross-validation and pruning process, rather than the default for the cv.tree() function, which is deviance.
```{r}
set.seed(3)
cv.carseats = cv.tree(tree.carseats, FUN = prune.misclass)
names(cv.carseats)
cv.carseats
```
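Rather than reading the best size off the printed output, we can extract it programmatically. A minimal sketch; best.size is a helper name introduced here, not part of the original lab:
```{r}
# Tree size (number of terminal nodes) with the lowest cross-validation error
best.size <- cv.carseats$size[which.min(cv.carseats$dev)]
best.size
```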
Note that, despite the name, dev corresponds to the cross-validation error rate in this instance. The tree with 9 terminal nodes results in the lowest cross-validation error rate, i.e. the smallest value in the dev column of the output above. We plot the error rate as a function of both size and k.
```{r}
par(mfrow = c(1, 2))
plot(cv.carseats$size, cv.carseats$dev, type = "b")
plot(cv.carseats$k, cv.carseats$dev, type = "b")
```
How well does this pruned tree perform on the test data set? Once again, we apply the predict() function.
```{r}
prune.carseats = prune.misclass(tree.carseats, best = 9)
plot(prune.carseats)
text(prune.carseats, pretty = 0)
tree.pred = predict(prune.carseats, Carseats.test, type = "class")
table(tree.pred, High.test)
(97 + 58) / 200 # 77.5% of the test observations are correctly classified by the pruned tree
```
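For comparison, we could prune less aggressively and check whether test accuracy changes. A hedged sketch; best = 14 is an arbitrary illustrative size and the object names are our own:
```{r}
# A larger pruned tree usually fits the training data better but need not predict better
prune.carseats.14 <- prune.misclass(tree.carseats, best = 14)
tree.pred.14 <- predict(prune.carseats.14, Carseats.test, type = "class")
table(tree.pred.14, High.test)
mean(tree.pred.14 == High.test)
```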
## 8.3.2 Fitting Regression Trees
Here we fit a regression tree to the Boston data set. First, we create a training set, and fit the tree to the training data.
```{r}
library(MASS)
set.seed(1)
train = sample(1:nrow(Boston), nrow(Boston)/2)
tree.boston = tree(medv ~., Boston, subset = train)
summary(tree.boston)
plot(tree.boston)
text(tree.boston, pretty = 0)
```
Now we use the cv.tree() function to see whether pruning the tree will improve performance.
```{r}
cv.boston = cv.tree(tree.boston)
plot(cv.boston$size, cv.boston$dev, type = "b")
```
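If we did want a smaller tree, the prune.tree() function lets us request a given number of terminal nodes. A short sketch; best = 5 is an illustrative choice:
```{r}
# Prune the regression tree to 5 terminal nodes and display it
prune.boston <- prune.tree(tree.boston, best = 5)
plot(prune.boston)
text(prune.boston, pretty = 0)
```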
In keeping with the cross-validation results, we use the unpruned tree to make predictions on the test set.
```{r}
yhat = predict(tree.boston, newdata = Boston[-train,])
boston.test = Boston[-train, "medv"]
plot(yhat, boston.test)
abline(0, 1)
mean((yhat - boston.test)^2) # 35.28688
sqrt(mean((yhat - boston.test)^2)) # 5.940276; since medv is recorded in thousands of dollars, test predictions are on average within about $5,940 of the true median home value
```
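To put that test MSE in context, a simple sanity check (not part of the original lab) is the test MSE of always predicting the training-set mean of medv:
```{r}
# Naive baseline: predict the training mean for every test observation
mean((mean(Boston[train, "medv"]) - boston.test)^2)
```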
## 8.3.3 Bagging and Random Forests
Here we apply bagging and random forests to the Boston data, using the randomForest package in R.
```{r}
library(randomForest)
set.seed(1)
bag.boston = randomForest(medv~., data = Boston, subset = train, mtry = 13, importance = TRUE)
bag.boston
```
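The printed output includes an out-of-bag (OOB) estimate of the MSE. The fitted object also stores the OOB MSE after each tree is added, so the final entry is the OOB estimate for the full forest; a minimal sketch, assuming the standard mse component of a regression randomForest fit:
```{r}
# OOB mean squared error after all trees have been grown
tail(bag.boston$mse, 1)
```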
The argument mtry=13 indicates that all 13 predictors should be considered for each split of the tree—in other words, that bagging should be done. How well does this bagged model perform on the test set?
```{r}
yhat.bag = predict(bag.boston, newdata = Boston[-train,])
plot(yhat.bag, boston.test)
abline(0,1)
mean((yhat.bag - boston.test)^2) # 23.59273, a substantial improvement over the single regression tree
```
We could change the number of trees grown by randomForest() using the ntree argument:
```{r}
bag.boston = randomForest(medv~., data = Boston, subset = train, mtry = 13, ntree = 25)
yhat.bag = predict(bag.boston, newdata = Boston[-train,])
mean((yhat.bag - boston.test)^2) # 23.66716
```
Growing a random forest proceeds in exactly the same way, except that we use a smaller value of the mtry argument. By default, randomForest() uses p/3 variables when building a random forest of regression trees, and √p variables when building a random forest of classification trees. Here we use mtry = 6.
```{r}
set.seed(1)
rf.boston = randomForest(medv~., data = Boston, subset = train, mtry = 6, importance = TRUE)
yhat.rf = predict(rf.boston, newdata = Boston[-train,])
mean((yhat.rf - boston.test)^2) # 19.47538
```
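To see how the choice of mtry affects accuracy, we could refit the forest for each value of mtry and record the test MSE. A hedged sketch (our own addition; test.mse and rf.fit are helper names, and the loop refits 13 forests, so it takes a little while to run):
```{r}
set.seed(1)
test.mse <- sapply(1:13, function(m) {
  # Fit a forest using m candidate predictors at each split and measure its test MSE
  rf.fit <- randomForest(medv ~ ., data = Boston, subset = train, mtry = m)
  yhat <- predict(rf.fit, newdata = Boston[-train, ])
  mean((yhat - boston.test)^2)
})
plot(1:13, test.mse, type = "b", xlab = "mtry", ylab = "Test MSE")
```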
Using the importance() function, we can view the importance of each variable. Two measures are reported: %IncMSE, the mean increase in out-of-bag MSE when the values of a given variable are permuted, and IncNodePurity, the total decrease in node impurity (measured by RSS for regression trees) resulting from splits on that variable, averaged over all trees.
```{r}
importance(rf.boston)
```
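The same importance measures can also be displayed graphically with varImpPlot():
```{r}
# Dot charts of %IncMSE and IncNodePurity for each predictor
varImpPlot(rf.boston)
```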
## 8.3.4 Boosting
Here we use the gbm package, and within it the gbm() function, to fit boosted regression trees to the Boston data set. We run gbm() with the option distribution="gaussian" since this is a regression problem; if it were a binary classification problem, we would use distribution="bernoulli". The argument n.trees=5000 indicates that we want 5000 trees, and the option interaction.depth=4 limits the depth of each tree.
```{r}
library(gbm)
set.seed(1)
boost.boston = gbm(medv~., data = Boston[train,], distribution = "gaussian", n.trees = 5000, interaction.depth = 4)
# The summary() function produces a relative influence plot and also outputs the relative influence statistics.
summary(boost.boston)
```
We see that lstat and rm are by far the most important variables. We can also produce partial dependence plots for these two variables. These plots illustrate the marginal effect of the selected variables on the response after integrating out the other variables.
```{r}
par(mfrow = c(1,2))
plot(boost.boston, i = "rm")
plot(boost.boston, i = "lstat")
```
We now use the boosted model to predict medv on the test set:
```{r}
yhat.boost = predict(boost.boston, newdata = Boston[-train,], n.trees = 5000)
mean((yhat.boost - boston.test) ^2) # 18.84709
```