-
Notifications
You must be signed in to change notification settings - Fork 0
/
day six random forests.R
109 lines (75 loc) · 2.67 KB
/
day six random forests.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
library(rsample)
library(tidymodels)
library(dplyr)
library(ggplot2)
# Diamonds data ----
# Draw a reproducible 1,000-row sample of ggplot2's `diamonds` data and split
# it 80/20 into training and validation sets. The seed is set *before* any
# random draw: the later set.seed() call only runs after the split, so
# without this the sample and the split changed on every run.
set.seed(97331)
diamonds2 <- diamonds %>%
  slice_sample(n = 1000)
data_split <- diamonds2 %>%
  initial_split(prop = 0.8)
training_data <- training(data_split)
validation_data <- testing(data_split)
# Random-forest classifier specification: 200 trees, min_n = 5, fitted with
# the randomForest engine. rand_forest() can do either classification or
# regression, so the mode has to be set explicitly.
rf_cls_spec <- set_engine(
  set_mode(rand_forest(trees = 200, min_n = 5), "classification"),
  "randomForest"
)
rf_cls_spec
# Fit the classifier: predict diamond colour from clarity and cut.
# NOTE: exploratory only -- there is no particular reason to expect these
# predictors to produce meaningful results for colour.
set.seed(97331)
rf_cls_fit <- fit(rf_cls_spec, color ~ clarity + cut, data = training_data)
rf_cls_fit
# Score the validation set: hard class predictions (.pred_class) plus one
# probability column per class, bound column-wise.
predicted <- bind_cols(
  predict(rf_cls_fit, validation_data),
  predict(rf_cls_fit, validation_data, type = "prob")
)
# Append the predictions to the validation rows and rename the true label.
# The rename is done by name rather than by position (previously column 3),
# so it cannot silently hit the wrong column if the column order changes.
result <- bind_cols(validation_data, predicted) %>%
  rename(diamond_color = color)
# Predicted vs. true colour, highlighting which predictions were correct.
result %>%
  mutate(correct = if_else(
    as.character(.pred_class) == as.character(diamond_color), "yes", "no"
  )) %>%
  ggplot(aes(.pred_class, diamond_color, colour = correct)) +
  geom_jitter()
# Full diamonds data: carat against clarity, coloured by colour grade.
# Points are jittered because clarity is a discrete axis.
ggplot(diamonds, aes(x = carat, y = clarity, colour = color)) +
  geom_jitter()
# Regression ----
# Same training/validation data as above, but now the target is numeric,
# so the random-forest specification is rebuilt in regression mode
# (200 trees, min_n = 5, randomForest engine).
rf_reg_spec <- rand_forest(trees = 200, min_n = 5)
rf_reg_spec <- set_mode(rf_reg_spec, "regression")
rf_reg_spec <- set_engine(rf_reg_spec, "randomForest")
rf_reg_spec
# Fit the regression forest: predict price from colour, clarity and depth.
# The seed is reset first so the fitted forest is reproducible.
set.seed(97331)
rf_reg_fit <- rf_reg_spec %>%
  fit(
    formula = price ~ color + clarity + depth,
    data = training_data
  )
rf_reg_fit
# Predict price on the validation rows, then plot actual price (x) against
# predicted price (y), coloured by colour grade and sized by carat.
result <- predict(rf_reg_fit, validation_data)
diamond_errors <- data.frame(validation_data, result)
ggplot(
  diamond_errors,
  aes(x = price, y = .pred, colour = color, size = carat)
) +
  geom_point()
# Full diamonds data: price against carat, one panel per clarity grade.
# The x-axis labels are rotated 90 degrees so they remain legible in the
# narrow facets.
ggplot(diamonds, aes(x = price, y = carat, color = color)) +
  geom_jitter() +
  facet_grid(~clarity) +
  theme(
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)
  )
# TV ratings ----
# A small cultural data set.
# NOTE(review): `TV` is not provided by any package this script loads; it is
# presumably already in the session (loaded by hand) -- confirm.
# Keep rows with 1978 < Year < 1989 (i.e. 1979-1988 if Year holds whole years).
TV3 <- TV %>%
  filter(Year > 1978, Year < 1989)
# Seed before splitting so the train/validation split is reproducible; the
# set.seed() before the model fit only runs after the split has been drawn.
# (rsample is already attached at the top of the script.)
set.seed(97331)
data_split <- initial_split(TV3)
training_data <- training(data_split)
validation_data <- testing(data_split)
# Regression forest specification for the TV ratings model: same settings as
# the diamonds regression (200 trees, min_n = 5, randomForest engine).
rf_reg_spec <- rand_forest(trees = 200, min_n = 5) %>%
  set_mode(mode = "regression") %>%
  set_engine(engine = "randomForest")
rf_reg_spec
# Fit the TV ratings forest: Rating from Network and Type, both coerced to
# factors inside the formula. Seed reset for a reproducible forest.
set.seed(97331)
rf_reg_fit <- fit(
  rf_reg_spec,
  Rating ~ as.factor(Network) + as.factor(Type),
  data = training_data
)
rf_reg_fit
# Score the validation rows and compute each row's residual
# (difference = actual Rating minus predicted .pred), then plot actual vs.
# predicted rating coloured by that residual.
result <- predict(rf_reg_fit, validation_data)
ratings <- data.frame(result, validation_data)
ratingsb <- mutate(ratings, difference = Rating - .pred)
ggplot(ratingsb, aes(x = Rating, y = .pred, colour = difference)) +
  geom_point()
# Explore the residuals: where was the model good, and where was it bad?
# View() needs an interactive session with a data viewer, so it is only
# called interactively; otherwise (e.g. under Rscript) the rows are printed
# with the largest absolute residuals first.
if (interactive()) {
  View(ratingsb)
} else {
  print(arrange(ratingsb, desc(abs(difference))))
}