Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add convert oml dataset to mlr3 #444

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions R/convertOMLDataSetToMlr3.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#' @title Convert an OpenML data set to mlr3 task.
#'
#' @description
#' Converts an \code{\link{OMLDataSet}} to a \code{\link[mlr3]{Task}}.
#'
#' @param obj [\code{\link{OMLDataSet}}]\cr
#' The object that should be converted.
#' @param mlr.task.id [\code{character(1)}]\cr
#' Id string for \code{\link[mlr3]{Task}} object.
#' The strings \code{<oml.data.name>}, \code{<oml.data.id>} and \code{<oml.data.version>}
#' will be replaced by their respective values contained in the \code{\link{OMLDataSet}} object.
#' Default is \code{<oml.data.name>}.
#' @param task.type [\code{character(1)}]\cr
#' As we only pass the data set, we need to define the task type manually.
#' Possible are: \dQuote{Supervised Classification}, \dQuote{Supervised Regression},
#' \dQuote{Survival Analysis}.
#' Default is \code{NULL} which means to guess it from the target column in the
#' data set. If that is a factor or a logical, we choose classification.
#' If it is numeric we choose regression. In all other cases an error is thrown.
#' @param target [\code{character}]\cr
#' The target for the classification/regression task.
#' Default is the \code{default.target.attribute} of the \code{\link{OMLDataSetDescription}}.
#' @param ignore.flagged.attributes [\code{logical(1)}]\cr
#' Should those features that are listed in the data set description slot \dQuote{ignore.attribute}
#' be removed?
#' Default is \code{TRUE}.
#' @param drop.levels [\code{logical(1)}]\cr
#' Should empty factor levels be dropped in the data?
#' Default is \code{TRUE}.
#' @param fix.colnames [\code{logical(1)}]\cr
#' Should colnames of the data be fixed using \code{\link[base]{make.names}}?
#' Default is \code{TRUE}.
#' @template arg_verbosity
#' @return [\code{\link[mlr3]{Task}}].
#' @family data set-related functions
#' @example /inst/examples/convertOMLDataSetToMlr3.R
#' @export
convertOMLDataSetToMlr3 = function(
obj,
mlr.task.id = "<oml.data.name>",
task.type = NULL,
target = obj$desc$default.target.attribute,
ignore.flagged.attributes = TRUE,
drop.levels = TRUE,
fix.colnames = TRUE,
verbosity = NULL) {

assertClass(obj, "OMLDataSet")
assertSubset(target, obj$colnames.new)
assertFlag(ignore.flagged.attributes)
assertFlag(drop.levels)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some assertions missing (mlr.task.id, fix.colnames, verbosity)


data = obj$data
desc = obj$desc

# no task type? guess it by looking at target
if (is.null(task.type))
task.type = guessTaskType(data[, target])
assertChoice(task.type, getValidTaskTypes())

# remove ignored attributes from data
if (any(!is.na(desc$ignore.attribute)) & ignore.flagged.attributes) {
keep.cols = obj$colnames.old %nin% desc$ignore.attribute
data = data[, keep.cols, drop = FALSE]
}

# drop levels
if (drop.levels)
data = droplevels(data)

# fix colnames using make.names
if (fix.colnames) {
colnames(data) = make.names(colnames(data), unique = TRUE)
target = make.names(target, unique = TRUE)
}

# get fixup verbose setting for mlr
if (is.null(verbosity))
verbosity = getOMLConfig()$verbosity
fixup = ifelse(verbosity == 0L, "quiet", "warn")

mlr.task = switch(task.type,
"Supervised Classification" = mlr3::TaskClassif$new(id = desc$name, backend = data, target = target),
"Supervised Regression" = mlr3::TaskRegr$new(id = desc$name, backend = data, target = target),
"Survival Analysis" = mlr3survival::TaskSurv$new(id = desc$name, backend = data, target = target),
stopf("Encountered currently unsupported task type: %s", task.type)
)

if (!is.null(mlr.task.id))
mlr.task$id = replaceOMLDataSetString(mlr.task.id, obj)

return(mlr.task)
}

replaceOMLDataSetString = function(string, data.set) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function already exists in convertOMLDataSetToMlr.R. unless we plan to remove mlr support this should not be duplicated. Same with guessTaskType and possibly others.

string = stri_replace_all_fixed(string, "<oml.data.id>", data.set$desc$id)
string = stri_replace_all_fixed(string, "<oml.data.name>", data.set$desc$name)
stri_replace_all_fixed(string, "<oml.data.version>", data.set$desc$version)
}

# @title Helper to guess task type from target column format.
#
# @param target [vector]
# Vector of target values.
# @return [character(1)]
guessTaskType = function(target) {
if (inherits(target, "data.frame")) {
assertDataFrame(target, types = "logical")
return("Multilabel")
} else {
if (is.factor(target) | is.logical(target))
return("Supervised Classification")
if (is.numeric(target))
return("Supervised Regression")
}

stopf("Cannot guess task.type from data!")
}

getValidTaskTypes = function() {
c("Supervised Classification", "Supervised Regression", "Survival Analysis", "Multilabel")
}

24 changes: 24 additions & 0 deletions R/convertOMLSplitsToMlr3.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
convertOMLSplitsToMlr3 = function(estim.proc, mlr.task, predict = "both") {
type = estim.proc$type
n.repeats = estim.proc$parameters[["number_repeats"]]
n.folds = estim.proc$parameters[["number_folds"]]
percentage = as.numeric(estim.proc$parameters[["percentage"]])
data.splits = estim.proc$data.splits
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The data splits need to be stored

stratified = estim.proc$parameters[["stratified_sampling"]]
stratified = ifelse(is.null(stratified), FALSE, stratified == "true")

if (type == "crossvalidation") {
if (n.repeats == 1L)
mlr.rdesc = mlr3::rsmp("cv", folds = n.folds, stratify = stratified)
else
mlr.rdesc = mlr3::rsmp("repeated_cv", reps = n.repeats, folds = n.folds, stratify = stratified)
mlr.rin = mlr.rdesc$instantiate(mlr.task)
} else if (type == "holdout") {
mlr.rdesc = mlr3::rsmp("holdout")
mlr.rin = mlr.rdesc$instantiate(task = mlr.task)
n.folds = 1
} else {
stopf("Unsupported estimation procedure type: %s", type)
}
return(mlr.rin)
}
5 changes: 5 additions & 0 deletions inst/examples/convertOMLDataSetToMlr3.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# \dontrun{
# library("mlr3")
# autosOML = getOMLDataSet(data.id = 9)
# autosMlr3 = convertOMLDataSetToMlr3(autosOML)
# }
45 changes: 45 additions & 0 deletions tests/testthat/test_local_convertOMLDataSetToMlr3.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
context("convertOMLDataSetToMlr3")

test_that("convertOMLDataSetToMlr3", {
with_test_cache({
ds = getOMLDataSet(10)

expect_is_mlr_task = function(mlr.task, ds) {
expect_equal(mlr.task$task_type, "classif")
expect_equal(mlr.task$nrow, nrow(ds$data))
expect_equal(ds$desc$default.target.attribute, mlr.task$target_names)
}

# now create the task
mlr.task = convertOMLDataSetToMlr3(ds)
expect_equal(mlr.task$task_type, "classif")

# now modify dataset by hand (no more server calls) to check
# ignore attributes stuff:
# Define the first two attributes as ignored attributes
ds$desc$ignore.attribute = colnames(ds$data[, 1:2])

mlr.task = convertOMLDataSetToMlr3(ds, ignore.flagged.attributes = TRUE)
expect_is_mlr_task(mlr.task, ds)
# we removed two attributes (and the target column is not considered here)
#expect_equal(sum(mlr.task$task.desc$n.feat), ncol(ds$data) - 3L)
expect_equal(mlr.task$ncol, ncol(ds$data) - 2L)

# pass faulty parameters
expect_error(convertOMLDataSetToMlr3(ds, task.type = "Nonexistent task type"), "element of")

# check setting mlr task id
expect_equal(convertOMLDataSetToMlr3(ds)$id, ds$desc$name)
expect_equal(convertOMLDataSetToMlr3(ds, mlr.task.id = "<oml.data.name>.<oml.data.id>")$id,
sprintf("%s.%s", ds$desc$name, ds$desc$id))
expect_equal(convertOMLDataSetToMlr3(ds, mlr.task.id = "test")$id, "test")
expect_equal(convertOMLDataSetToMlr3(ds, mlr.task.id = "<oml.data.id>")$id, as.character(ds$desc$id))
expect_equal(convertOMLDataSetToMlr3(ds, mlr.task.id = "<oml.data.name>")$id, as.character(ds$desc$name))
expect_equal(convertOMLDataSetToMlr3(ds, mlr.task.id = "<oml.data.version>")$id, as.character(ds$desc$version))
expect_equal(convertOMLDataSetToMlr3(ds, mlr.task.id = "<oml.task.id>")$id, "<oml.task.id>")

# check if conversion to regression task works
ds$desc$target.features = ds$desc$default.target.attribute = "no_of_nodes_in"
expect_equal(convertOMLDataSetToMlr3(ds)$task_type, "regr")
})
})
25 changes: 25 additions & 0 deletions tests/testthat/test_local_convertOMLSplitsToMlr3.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
context("convertOMLSplitsToMlr3")

test_that("convertOMLSplitsToMlr3", {
with_test_cache({
task = getOMLTask(59)
mlr.task = convertOMLTaskToMlr3(task)$mlr.task

oml.types = c("crossvalidation", "holdout")
mlr.types = c("cv", "holdout")

for (i in seq_along(oml.types)) {
task$input$estimation.procedure$type = oml.types[i]
if (oml.types[i] == "holdout") {
task$input$estimation.procedure$parameters$percentage = "50"
}
splits = convertOMLSplitsToMlr3(task$input$estimation.procedure, mlr.task)
expect_is(splits, "Resampling")
expect_equal(splits$id, mlr.types[i])
}

# pass invalid estim.proc
task$input$estimation.procedure$type = "blabla"
expect_error(convertOMLSplitsToMlr3(task$input$estimation.procedure, mlr.task), "Unsupported estimation procedure type: blabla")
})
})