# Introduction to supervised classification
<!--
- overview of supervised classification, model structure
- refer to the ISLR language
- basics of training/testing for model building
- model components needed from all
- common tasks: checking clusters relative to class labels, looking at boundaries
-->
Methods for supervised classification originated in the field of Statistics in the early twentieth century, under the moniker *discriminant analysis* (see, for example, @Fi36). The growth in data collection and database storage in the late twentieth century inspired a growing desire to extract knowledge from data, particularly to be able to accurately predict class labels. This has contributed to an explosion of research on new methods, especially on algorithms that focus on accurate prediction of new data based on training samples.
\index{classification!supervised}
In contrast to unsupervised classification, in supervised classification the class label (categorical response variable) is known for the training sample. The training sample is used to build the prediction model, and also to estimate the accuracy (or, conversely, the error) of the model for future data. It is also important to understand and interpret the model, so that we know how predictions are made. High-dimensional visualisation can help with this, and helps to tackle questions like:
- Are the classes well separated in the data space, so that they correspond to distinct clusters? If so, what are the shapes of the clusters? Is each cluster sufficiently ellipsoidal that we can assume the data arise from a mixture of multivariate normal distributions? Do the clusters exhibit characteristics that suggest one algorithm in preference to others?
- Where does the boundary between classes fall? Are the classes
linearly separable, or does the difference between classes suggest
a non-linear boundary? How do changes in the input parameters affect these boundaries? How do the boundaries generated by different methods vary?
- What cases are misclassified, or have more uncertain predictions? Are there places in the data space where predictions are especially good or bad?
- Which predictors contribute most to the model predictions? Is it possible to reduce the set of explanatory variables?
Addressing these types of questions also motivates the emerging field of explainable artificial intelligence (XAI), which goes beyond predictive accuracy to more completely satisfy the *desire to extract knowledge from data*.
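The dual role of the training sample described above, building the model and estimating its error on future data, can be sketched in a few lines of base R. This is a minimal illustration under assumptions that are not part of the text: the `iris` data stand in for a real problem, only two of its species are kept, logistic regression (`glm()`) plays the role of the classifier, and a random two-thirds/one-third split separates the training and test samples. Any other classifier could be substituted.

```r
# Hedged sketch: estimate training and test error for a classifier.
# Assumptions (not from the text): iris stands in for real data, two
# species are kept, and logistic regression is the classifier.
set.seed(1)
d <- droplevels(subset(iris, Species != "setosa"))
lv <- levels(d$Species)               # "versicolor", "virginica"

# Randomly assign two thirds of the rows to the training sample
train_idx <- sample(nrow(d), size = round(2 / 3 * nrow(d)))
d_train <- d[train_idx, ]
d_test  <- d[-train_idx, ]

# Build the model using the training sample only
fit <- glm(Species ~ ., data = d_train, family = binomial)

# glm() models the probability of the second factor level,
# so predict that level when the fitted probability exceeds 0.5
predict_class <- function(model, newdata) {
  p <- predict(model, newdata = newdata, type = "response")
  factor(ifelse(p > 0.5, lv[2], lv[1]), levels = lv)
}

# Error = proportion of cases whose predicted class differs from the truth
train_error <- mean(predict_class(fit, d_train) != d_train$Species)
test_error  <- mean(predict_class(fit, d_test) != d_test$Species)
print(c(train = train_error, test = test_error))
```

The training error tends to be optimistic because the model has already seen those cases; the error on the held-out test sample is the better guide to performance on future data.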
Although we focus on a categorical response, some of the techniques here can be modified or adapted for problems with a numeric, or continuous, response variable. With a categorical response and numerical predictors, we map colour to the response variable and use the tour to examine the relationships between the predictors and the different classes.
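Each frame of a tour is a 2D projection of the predictors, drawn with colour mapped to the class. The sketch below builds a single such frame by hand. It assumes the `iris` data as a stand-in and uses a random orthonormal basis obtained from a QR decomposition; a tour (for example `tourr::animate_xy()`, used later in this chapter) would instead show a smooth sequence of these projections.

```r
# Hedged sketch of one tour frame: project standardised predictors onto
# a random 2D orthonormal basis and colour the points by class.
# Assumption: iris (base R) stands in for the chapter's data.
set.seed(2)
X  <- scale(as.matrix(iris[, 1:4]))   # numeric predictors, standardised
cl <- iris$Species                    # categorical response

# Random orthonormal basis: QR decomposition of a 4 x 2 Gaussian matrix
basis <- qr.Q(qr(matrix(rnorm(ncol(X) * 2), ncol = 2)))

proj <- X %*% basis                   # n x 2 projected data

# Map colour to the response variable
plot(proj, col = as.integer(cl), pch = 19, asp = 1,
     xlab = "P1", ylab = "P2")
legend("topright", legend = levels(cl),
       col = seq_along(levels(cl)), pch = 19)
```

Changing the seed gives a different projection; looking at many of them, as a tour does, is what reveals whether the classes separate in some direction.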
```{r}
#| label: fig-sup-example
#| fig-cap: "Examples of supervised classification patterns: (a) linearly separable, (b) linear but not completely separable, (c) non-linearly separable, (d) non-linear, but not completely separable."
#| echo: false
#| message: false
library(ggplot2)
library(dplyr)
library(colorspace)
library(patchwork)
set.seed(524)
x1 <- runif(176) + 0.5
x1[1:61] <- x1[1:61] - 1.2
x2 <- 1 + 2*x1 + rnorm(176)
x2[1:61] <- 2 - 3*x1[1:61] + rnorm(61)
x3 <- runif(176) + 0.5
x3[1:61] <- x3[1:61] - 0.5
x4 <- 0.25 - x3 + rnorm(176)
x4[1:61] <- -0.25 + 3*x3[1:61] + rnorm(61)
cl <- factor(c(rep("A", 61), rep("B", 176-61)))
df <- data.frame(x1, x2, x3, x4, cl)
class1 <- ggplot(df, aes(x=x1, y=x2, colour = cl)) +
  geom_point(alpha=0.7) +
  scale_colour_discrete_divergingx(
    palette = "Zissou 1", nmax = 2, rev = TRUE) +
  annotate("text", -0.65, 6.6, label="a") +
  theme(aspect.ratio=1,
        legend.position = "none",
        axis.text = element_blank(),
        axis.title = element_blank(),
        axis.ticks = element_blank(),
        panel.background = element_rect("white"),
        panel.border = element_rect("black", fill=NA,
                                    linewidth = 0.5))
class2 <- ggplot(df, aes(x=x3, y=x4, colour = cl)) +
  geom_point(alpha=0.7) +
  scale_colour_discrete_divergingx(
    palette = "Zissou 1", nmax = 2, rev = TRUE) +
  annotate("text", 0.05, 4.1, label="b") +
  theme(aspect.ratio=1,
        legend.position = "none",
        axis.text = element_blank(),
        axis.title = element_blank(),
        axis.ticks = element_blank(),
        panel.background = element_rect("white"),
        panel.border = element_rect("black", fill=NA,
                                    linewidth = 0.5))
set.seed(826)
x5 <- 2*(runif(176) - 0.5)
x6 <- case_when(x5 < -0.4 ~ -1.2 - 3 * x5,
                x5 > 0.2 ~ 2.4 - 3 * x5,
                .default = 1.2 + 3 * x5)
x5 <- 2*x5
x6 <- x6 + rnorm(176) * 0.25
x6[1:83] <- x6[1:83] - 1.5
x7 <- 2*(runif(176) - 0.5)
x8 <- case_when(x7 < -0.4 ~ -1.2 - 3 * x7,
                x7 > 0.2 ~ 2.4 - 3 * x7,
                .default = 1.2 + 3 * x7)
x7 <- 2*x7
x8[x7 < -0.1] <- x8[x7 < -0.1] + rnorm(length(x8[x7 < -0.1])) * 0.25
x8[x7 >= -0.1] <- x8[x7 >= -0.1] + rnorm(length(x8[x7 >= -0.1])) * 0.5
x8[1:83] <- x8[1:83] - 1.5
cl2 <- factor(c(rep("A", 83), rep("B", 176-83)))
df2 <- data.frame(x5, x6, x7, x8, cl2)
class3 <- ggplot(df2, aes(x=x5, y=x6, colour = cl2)) +
  geom_point(alpha=0.7) +
  scale_colour_discrete_divergingx(
    palette = "Zissou 1", nmax = 2, rev = TRUE) +
  annotate("text", -1.95, 2.15, label="c") +
  theme(aspect.ratio=1,
        legend.position = "none",
        axis.text = element_blank(),
        axis.title = element_blank(),
        axis.ticks = element_blank(),
        panel.background = element_rect("white"),
        panel.border = element_rect("black", fill=NA,
                                    linewidth = 0.5))
class4 <- ggplot(df2, aes(x=x7, y=x8, colour = cl2)) +
  geom_point(alpha=0.7) +
  scale_colour_discrete_divergingx(
    palette = "Zissou 1", nmax = 2, rev = TRUE) +
  annotate("text", 1.95, 1.9, label="d") +
  theme(aspect.ratio=1,
        legend.position = "none",
        axis.text = element_blank(),
        axis.title = element_blank(),
        axis.ticks = element_blank(),
        panel.background = element_rect("white"),
        panel.border = element_rect("black", fill=NA,
                                    linewidth = 0.5))
print(class1 + class2 + class3 + class4 + plot_layout(ncol=2))
```
@fig-sup-example shows some 2D examples where the two classes are (a) linearly separable, (b) not completely separable but linearly different, (c) non-linearly separable and (d) not completely separable but with a non-linear difference. In (a), only the horizontal variable would be important for the model, because the two classes are completely separable in this direction. Although the classes in (c) are separable, most models would have difficulty capturing the separation. This is why it is important to understand the boundary between classes produced by a fitted model. In each of (b), (c) and (d), it is likely that some observations would be misclassified. Identifying these cases, and inspecting where they lie in the data space, is important for understanding the model's future performance.
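The observations about panels (a) and (b) can be checked numerically. The sketch below regenerates the simulated data exactly as in the figure code, confirms that in (a) the classes do not overlap on the horizontal variable `x1`, and then, as an assumed illustration (logistic regression is not a method prescribed by the text), fits a linear boundary to (b) and highlights the cases it misclassifies so that their location in the data space can be inspected.

```r
# Regenerate the data shown in panels (a) and (b), as in the figure code
set.seed(524)
x1 <- runif(176) + 0.5
x1[1:61] <- x1[1:61] - 1.2
x2 <- 1 + 2*x1 + rnorm(176)
x2[1:61] <- 2 - 3*x1[1:61] + rnorm(61)
x3 <- runif(176) + 0.5
x3[1:61] <- x3[1:61] - 0.5
x4 <- 0.25 - x3 + rnorm(176)
x4[1:61] <- -0.25 + 3*x3[1:61] + rnorm(61)
cl <- factor(c(rep("A", 61), rep("B", 176-61)))

# Panel (a): the classes are completely separated on the horizontal
# variable x1 if every class A value lies below every class B value
gap_a <- min(x1[cl == "B"]) - max(x1[cl == "A"])
separable_a <- gap_a > 0

# Panel (b): fit a linear boundary (logistic regression, chosen here
# for illustration) and flag the cases it misclassifies
df_b <- data.frame(x3, x4, cl)
fit_b <- glm(cl ~ x3 + x4, data = df_b, family = binomial)
pred_b <- ifelse(fitted(fit_b) > 0.5, "B", "A")   # fitted = P(class B)
miscl_b <- which(pred_b != as.character(df_b$cl))

# Inspect where the misclassified cases lie in the data space
plot(x3, x4, col = ifelse(cl == "A", "grey30", "grey70"), pch = 19,
     xlab = "x3", ylab = "x4")
points(x3[miscl_b], x4[miscl_b], col = "red", pch = 1, cex = 1.8)
print(c(gap_a = gap_a, n_miscl_b = length(miscl_b)))
```

A positive `gap_a` is what "completely separable in this direction" means for (a): a single threshold on `x1` classifies every case correctly, so the vertical variable adds nothing.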
## Exercises {-}
1. For the penguins data, use the tour to decide whether the species are separable, and whether the boundaries between species are linear or non-linear.
2. For the bushfires data, using just the variables `se`, `maxt`, `mint` and `log_dist_road`, and only the fires whose `cause` is "accident" or "lightning", use the tour to decide whether the two classes are separable, and whether the boundary might be linear or non-linear.
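```{r eval=FALSE}
#| echo: false
library(mulgar)   # bushfires data
library(dplyr)
library(tourr)    # animate_xy(), guided_tour(), lda_pp()
# Keep four predictors and the two causes of interest
b_sub <- bushfires |>
  select(se, maxt, mint, log_dist_road, cause) |>
  filter(cause %in% c("accident", "lightning")) |>
  rename(ldr = log_dist_road) |>
  mutate(cause = factor(cause))
# Grand tour, with colour mapped to cause
animate_xy(b_sub[,-5], col=b_sub$cause, rescale=TRUE)
# Guided tour using the LDA projection pursuit index
animate_xy(b_sub[,-5], guided_tour(lda_pp(b_sub$cause)),
           col=b_sub$cause, rescale=TRUE)
```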
::: {.content-hidden}
Q1 answer: Gentoo and others are separable. Chinstrap and Adelie are not separable. All boundaries are linear.
Q2 answer: Not separable, but boundary could be linear.
:::