forked from iceberg-project/Seals
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_comparison.R
55 lines (46 loc) · 2 KB
/
plot_comparison.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Script to draw a scatter plot comparing seal detection models on two chosen metric columns
# load packages
library(ggplot2)
library(reshape2)
library(argparse)
# Command-line interface. Every option is read as a plain string; --facet
# defaults to the literal string "NULL", which this script later compares
# against to decide whether to facet the plot.
parser <- ArgumentParser(
  description = "R script to get pooled data csv and plot model comparisons"
)
parser$add_argument("--input_file", type = "character",
                    help = ".csv file with precision and recall")
parser$add_argument("--output_file", type = "character",
                    help = "filename for the .png plot")
parser$add_argument("--x", type = "character",
                    help = "name of the column which will be used as the X axis")
parser$add_argument("--y", type = "character",
                    help = "name of the column which will be used as the Y axis")
parser$add_argument("--facet", default = "NULL", type = "character",
                    help = "name of the column which will be used to determine facets")
# Parse the command line and unpack the options used by the rest of the
# script.
args <- parser$parse_args()
# Presumably argparse yields NULL for an option that was not supplied and has
# no default. Every option except --facet is needed to draw the plot, so
# report all missing ones up front instead of failing later inside read.csv()
# or the column lookup with an unrelated error.
required_opts <- c("input_file", "output_file", "x", "y")
is_missing <- vapply(required_opts,
                     function(opt) is.null(args[[opt]]),
                     logical(1))
if (any(is_missing)) {
  stop("missing required argument(s): ",
       paste0("--", required_opts[is_missing], collapse = ", "),
       call. = FALSE)
}
# [[ ]] rather than $ so option names are matched exactly, never partially
inp_file <- args[["input_file"]]
out_file <- args[["output_file"]]
x_label <- args[["x"]]
y_label <- args[["y"]]
facet <- args[["facet"]]
# Read the pooled metrics table; each row becomes one point in the plot.
pooled_data <- read.csv(inp_file, stringsAsFactors = FALSE)

# Stop with a readable message unless `column` names exactly one column of
# `data`. Without this check, which() returned integer(0) for a misspelt or
# absent column and the script failed later with an unrelated data.frame
# error. `what` tells the user which input the column name came from.
require_column <- function(data, column, what) {
  found <- is.character(column) && length(column) == 1L &&
    column %in% colnames(data)
  if (!found) {
    stop("column ", deparse(column), " (", what, ") not found in input; ",
         "available columns: ", paste(colnames(data), collapse = ", "),
         call. = FALSE)
  }
  invisible(column)
}

require_column(pooled_data, x_label, "--x")
require_column(pooled_data, y_label, "--y")
# the plot colours points by model_name, so that column must exist as well
require_column(pooled_data, "model_name", "colour grouping")

# Copy the chosen columns under fixed names so the plot mapping can be static.
pooled_data[["x"]] <- pooled_data[[x_label]]
pooled_data[["y"]] <- pooled_data[[y_label]]
# "NULL" is the --facet sentinel for "no faceting". In that case no column
# matches, so a facet column is only copied when a real name was given.
if (facet != "NULL") {
  require_column(pooled_data, facet, "--facet")
  pooled_data[["facet"]] <- pooled_data[[facet]]
}

# Axis limits: both axes start at 0 and run to the largest observed value.
# NAs are dropped rather than propagated into the limit vector.
x_lim <- c(0, max(pooled_data[["x"]], na.rm = TRUE))
y_lim <- c(0, max(pooled_data[["y"]], na.rm = TRUE))
# Assemble the comparison scatter plot, one layer at a time: one point per
# row of pooled_data, placed at the selected x/y columns and coloured by
# model_name. Axis titles reuse the original column names, and both axes are
# pinned to the x_lim / y_lim ranges computed above.
label_plot <- ggplot(data = pooled_data,
                     mapping = aes(x = x, y = y, color = model_name))
label_plot <- label_plot + theme_minimal(base_size = 15)
label_plot <- label_plot + geom_point(size = 4, alpha = 0.8)
label_plot <- label_plot + labs(x = x_label, y = y_label)
label_plot <- label_plot + xlim(x_lim) + ylim(y_lim)
# theme() must follow theme_minimal() so these overrides are not replaced
label_plot <- label_plot +
  theme(axis.title = element_text(face = "bold", size = 18),
        strip.text.x = element_text(size = 18, face = "italic"))
label_plot <- label_plot + scale_colour_brewer(palette = "Set1")

# "NULL" is the --facet sentinel meaning "single panel"; otherwise split the
# plot into a two-column grid of panels keyed on the facet column.
# NOTE(review): xlim/ylim fix the limits for every panel, so scales = "free"
# likely has no visible effect on the axis ranges -- confirm intent.
if (facet != "NULL") {
  label_plot <- label_plot + facet_wrap(~facet, ncol = 2, scales = "free")
}
# Render the comparison plot into a 1200 x 800 pixel PNG at out_file.
# dev.off() runs in `finally`, so the png device is closed (and the file
# finalised) even if rendering the plot raises an error. invisible() stops
# the top level from auto-printing the returned ggplot object, which would
# otherwise draw it again on a fresh default device.
png(out_file, width = 1200, height = 800)
invisible(tryCatch(print(label_plot), finally = dev.off()))