---
title: "Handwritten Digits with ggplot2"
author: "James Hirschorn"
date: "May 26, 2017"
output: html_document
---
There was some interest on [Kaggle](https://www.kaggle.com) in our kernel for plotting handwritten digits using the highly popular `R` package [`ggplot2`](http://ggplot2.tidyverse.org). That kernel was forked from Ben Hamner's [kernel](https://www.kaggle.com/benhamner/digit-recognizer/example-handwritten-digits/code).

The purpose of this blog post is to demonstrate plotting handwritten digits using `ggplot2`, in a much cleaner manner than our original kernel. `ggplot2` provides a powerful graphics grammar for creating plots. The handwritten digits come from the [MNIST Database](http://yann.lecun.com/exdb/mnist), one of the most widely used datasets in machine learning. We have found many examples of `R` code displaying this dataset, e.g. [@Chivers2012], [@Lam2015], [@Liu2014] and others. However, none of these (besides the Kaggle kernels) makes use of `ggplot2`. Our contribution is this tutorial on how to do so. A follow-up post is planned, in which some basic machine learning is applied and `ggplot2` is used to visualize the misclassifications.
We load all of the necessary libraries here.
```{r, message=FALSE}
library(data.table)
library(dtplyr)
library(dplyr)
library(foreach)
library(ggplot2)
library(grid)
```
The MNIST data is available from [Kaggle](https://www.kaggle.com/c/digit-recognizer/data) in CSV format, or the following code will automatically download it for you if you do not already have the training data. It contains 42,000 scanned images of handwritten digits. We load the training data from this dataset, and use the first 49 digits as our initial sample.
```{r}
# Download the training data only if it is not already present
if (!file.exists('train.csv')) {
  download.file(file.path('https://github.com/quantitative-technologies',
                          'handwritten-digits-ggplot2/releases/download/v1.0.0/train.csv'),
                'train.csv')
}
train <- fread('train.csv')
# The first 49 digits fill a 7 x 7 grid of panels
sample <- train[1:49,]
```
Each row of data has `28 * 28 = 784` values giving the darkness of each pixel in the image of a digit. This is massaged into a form usable by `ggplot2`, where each pixel is an observation of the form `(digit_id, label, x, y, gray_level)`, with the `x` and `y` values normalized to the interval $[0,1)$ (unnecessary but convenient). After converting `gray_level` to an actual `R` colour, the digits in this format can be drawn with a `geom_raster` layer.
```{r}
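transform_images <- function(data, inverted = FALSE) {
  # Pixel columns only; the label column is read separately below
  data.row <- select(data, -label)
  # Pixels are stored row by row: x cycles through the 28 columns,
  # while y runs from the top row (27) down to the bottom row (0)
  x_values <- rep(0:27, 28) / 28
  y_values <- rep(27:0, each = 28) / 28
  foreach(i = seq_len(nrow(data)), .combine = 'rbind') %do% {
    pixels <- as.numeric(unlist(data.row[i]))
    # Scale to [0, 1]; unless inverted, flip so that ink is dark on a light background
    gray_level <- (if (inverted) pixels else 255 - pixels) / 255
    data.table(digit_id = i, label = data[i, label],
               x = x_values, y = y_values, gray_level = gray_level)
  }
}
digits <- transform_images(sample)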
```
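To make the coordinate layout concrete, here is a minimal sketch of the same mapping on a hypothetical 3 x 3 image. The names `toy_n`, `toy_values` and `toy_image` are ours, chosen for illustration only. It assumes, as above, that pixels are stored row by row starting from the top-left corner.

```{r}
# Hypothetical 3 x 3 image, flattened row by row: a vertical bar in the middle column
toy_n <- 3
toy_values <- c(0, 255, 0,
                0, 255, 0,
                0, 255, 0)
toy_image <- data.frame(
  x = rep(0:(toy_n - 1), toy_n) / toy_n,          # column index, cycling fastest
  y = rep((toy_n - 1):0, each = toy_n) / toy_n,   # top row gets the largest y
  gray_level = (255 - toy_values) / 255           # 0 = black ink, 1 = white paper
)
toy_image
```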
The `facet_wrap` function draws a sequence of panels that wraps around at the end of the horizontal space. Below, one panel is drawn for each `digit_id` value, and the layer with the digits is then added. Setting the `interpolate` parameter of `geom_raster` to `TRUE` smooths the image, improving its visual appeal. We need to select the identity scale for the `fill` aesthetic so that `ggplot2` uses the actual unscaled colour values from the dataset. `coord_fixed()` preserves (or "fixes") the aspect ratio, i.e. one unit on either the x-axis or the y-axis has the same length on the output device. Thus each 28 x 28 pixel image appears as a square on your screen, for example.
```{r}
digits_plot <- ggplot(digits) +
  facet_wrap(~digit_id, ncol = 7) +
  geom_raster(aes(x, y, fill = gray(gray_level)), interpolate = TRUE) +
  scale_fill_identity() +
  coord_fixed()
digits_plot
```
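As a small illustration of why the identity scale is needed, the sketch below builds a hypothetical 2 x 2 image (the `toy_pixels` and `toy_plot` names are ours). `gray()` turns each level into a hex colour string, and `scale_fill_identity()` asks `ggplot2` to use those strings as colours directly instead of treating them as categories to be mapped onto a palette.

```{r}
library(ggplot2)

# Hypothetical 2 x 2 image with gray levels running from black (0) to white (1)
toy_pixels <- data.frame(
  x = c(0, 1, 0, 1),
  y = c(1, 1, 0, 0),
  gray_level = c(0, 1/3, 2/3, 1)
)
# gray() converts a level in [0, 1] to a hex colour string
toy_pixels$fill <- gray(toy_pixels$gray_level)

# The fill column already holds colours, so the identity scale is used
toy_plot <- ggplot(toy_pixels) +
  geom_raster(aes(x, y, fill = fill)) +
  scale_fill_identity() +
  coord_fixed()
toy_plot
```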
Next we add the labels with a `geom_text` layer. Note that only one observation should be supplied for each label: if `digits` is used as is, without `select` and `unique`, each label gets plotted 784 times. The choice of coordinates `(x = 0, y = 1)` places the label in the upper-left corner. After this, the title is added, and a `theme` is used to format the title, add a border to each panel and remove the extra items (facet strips, axis text, ticks and titles) from the plot above. This completes our plot:
```{r}
p <- digits_plot +
  # One row per digit, so that each label is drawn only once
  geom_text(aes(x = 0, y = 1, label = label),
            unique(select(digits, digit_id, label)),
            hjust = 'inward', vjust = 'inward',
            size = 5, colour = 'darkgreen') +
  ggtitle('Example Handwritten Digits') +
  theme(strip.text = element_blank(),
        axis.text = element_blank(),
        axis.ticks = element_blank(),
        axis.title = element_blank(),
        # hjust must be numeric: 0.5 centres the title
        plot.title = element_text(size = 18, hjust = 0.5),
        panel.border = element_rect(fill = NA, colour = 'blue'))
p
```
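The effect of `select` and `unique` on the label data can be checked on a hypothetical long-format table with two digits of four pixels each. The sketch uses base `R` subsetting rather than `select`, and the `toy_long` and `toy_labels` names are ours.

```{r}
# Hypothetical long-format data: 2 digits with 4 pixels each
toy_long <- data.frame(
  digit_id = rep(1:2, each = 4),
  label = rep(c(7, 3), each = 4),
  gray_level = c(0, 1, 1, 0, 1, 0, 0, 1)
)
# One row per pixel: used directly, each label would be drawn 4 times
nrow(toy_long)
# One row per digit: each label is drawn exactly once
toy_labels <- unique(toy_long[, c("digit_id", "label")])
toy_labels
```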
The above plotting commands are parameterized and encapsulated in the following function `plot_labelled_images`.
```{r}
plot_labelled_images <- function(images, title,
                                 smoothed = TRUE, inverted = FALSE, ncol = 7) {
  # Pick colours that stay readable on the chosen background
  label_colour <- ifelse(inverted, 'green', 'darkgreen')
  background_colour <- ifelse(inverted, 'black', 'white')
  if (inverted) {
    images <- invert_images(images)
  }
  p <- ggplot(images) +
    facet_wrap(~digit_id, ncol = ncol) +
    geom_raster(aes(x, y, fill = gray(gray_level)), interpolate = smoothed) +
    scale_fill_identity() +
    coord_fixed() +
    geom_text(aes(x = 0, y = 1, label = label),
              unique(select(images, digit_id, label)),
              hjust = 'inward', vjust = 'inward',
              size = 5, colour = label_colour) +
    ggtitle(title) +
    theme(strip.text = element_blank(),
          axis.text = element_blank(),
          axis.ticks = element_blank(),
          axis.title = element_blank(),
          # hjust must be numeric: 0.5 centres the title
          plot.title = element_text(size = 18, hjust = 0.5),
          panel.grid = element_blank(),
          panel.background = element_rect(fill = background_colour),
          panel.border = element_rect(fill = NA, colour = 'blue'))
  plot(p)
}
# Swap light and dark by reflecting the gray levels about 0.5
invert_images <- function(images) {
  mutate(images, gray_level = 1 - gray_level)
}
```
Using the `plot_labelled_images` function, the previous plot can be created with:
```{r, eval=FALSE}
plot_labelled_images(digits, 'Example Handwritten Digits')
```
Our function can be used with a couple of lines of code to present the average images over the whole data set, as was done in [@Chivers2012].
```{r}
averages <- train %>% group_by(label) %>% summarise_all(mean) %>% setkey(label)
average_digits <- transform_images(averages)
plot_labelled_images(average_digits, 'Averaged Handwritten Digits',
                     smoothed = FALSE, inverted = TRUE, ncol = 5)
```
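To see what the averaging step computes, the sketch below applies the same idea to a hypothetical table of three two-pixel "images". It swaps in base `R`'s `aggregate` for the `dplyr` pipeline above, purely to keep the example self-contained, and the `toy_train` and `toy_averages` names are ours. Each pixel column is averaged within each label, giving one row (one average image) per label.

```{r}
# Hypothetical data: three "images" of two pixels each, two of them labelled 0
toy_train <- data.frame(
  label = c(0, 0, 1),
  pixel0 = c(0, 100, 255),
  pixel1 = c(50, 150, 0)
)
# Mean of every pixel column within each label
toy_averages <- aggregate(. ~ label, data = toy_train, FUN = mean)
toy_averages
```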
The following function `examine_all_digits` can be run interactively to scan through all 42,000 digits, 49 at a time, by running `examine_all_digits(train)`.
```{r eval=FALSE}
examine_all_digits <- function(data) {
  # Step through the rows in batches of 49 (one 7 x 7 grid per batch)
  for (i in seq(1, nrow(data), 49)) {
    last_digit_id <- min(nrow(data), i + 48)
    digits <- transform_images(data[i:last_digit_id])
    plot_labelled_images(digits, paste('Handwritten Digits', i, 'to', last_digit_id))
    # Named `response` rather than `c`, which would mask the base function c()
    response <- readline(prompt = "Press [Enter] to continue, or 'q' to quit.")
    if (tolower(response) == 'q')
      break
  }
}
```
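The batch boundaries used by the loop can be inspected on their own. The sketch below assumes a hypothetical data set of 100 rows (`toy_n_rows`), chosen so that the final batch is shorter than 49. The other names are ours as well.

```{r}
# Hypothetical data set size, chosen so that the last batch is short
toy_n_rows <- 100
toy_batch_size <- 49
# First row of each batch, exactly as in the for loop above
toy_starts <- seq(1, toy_n_rows, toy_batch_size)
# Last row of each batch, capped at the end of the data
toy_ends <- pmin(toy_n_rows, toy_starts + toy_batch_size - 1)
data.frame(start = toy_starts, end = toy_ends)
```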
Finally, we use the `plot_labelled_images` function to show 49 randomly chosen digits.
```{r}
random_sample <- sample_n(train, 49)
random_digits <- transform_images(random_sample)
plot_labelled_images(random_digits, 'Random Selection', smoothed = TRUE, inverted = TRUE)
```

---

The source for this article is available on [GitHub](https://github.com/quantitative-technologies/handwritten-digits-ggplot2).

---

---
link-citations: true
references:
- id: Chivers2012
  title: The essence of a handwritten digit
  author:
  - family: Chivers
    given: Corey
  URL: 'https://www.r-bloggers.com/the-essence-of-a-handwritten-digit'
  type: post-weblog
  issued:
    year: 2012
    month: 8
    day: 13
- id: Lam2015
  title: A little H2O deeplearning experiment on the MNIST data set
  author:
  - family: Lam
    given: Longhow
  URL: 'https://www.r-bloggers.com/a-little-h2o-deeplearning-experiment-on-the-mnist-data-set'
  type: post-weblog
  issued:
    year: 2015
    month: 11
    day: 25
- id: Liu2014
  title: 'R: Classifying Handwritten Digits (MNIST) using Random Forests'
  author:
  - family: Liu
    given: Wayne
  URL: 'http://beyondvalence.blogspot.ca/2014/01/r-classifying-handwritten-digits-mnist.html'
  type: post-weblog
  issued:
    year: 2014
    month: 1
    day: 16
---