Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #159 #160

Merged
merged 4 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Imports:
withr,
zeallot
Suggests:
cli,
knitr,
modeldata,
patchwork,
Expand Down
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,36 @@ S3method(tabnet_pretrain,recipe)
S3method(update,tabnet)
export("%>%")
export(attention_width)
export(cat_emb_dim)
export(check_compliant_node)
export(checkpoint_epochs)
export(decision_width)
export(drop_last)
export(encoder_activation)
export(feature_reusage)
export(lr_scheduler)
export(mask_type)
export(mlp_activation)
export(mlp_hidden_multiplier)
export(momentum)
export(nn_prune_head.tabnet_fit)
export(nn_prune_head.tabnet_pretrain)
export(node_to_df)
export(num_independent)
export(num_independent_decoder)
export(num_shared)
export(num_shared_decoder)
export(num_steps)
export(optimizer)
export(penalty)
export(tabnet)
export(tabnet_config)
export(tabnet_explain)
export(tabnet_fit)
export(tabnet_nn)
export(tabnet_pretrain)
export(verbose)
export(virtual_batch_size)
importFrom(dplyr,filter)
importFrom(dplyr,last_col)
importFrom(dplyr,mutate)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Bugfixes

* improve function documentation consistency before translation
* fix ".... is not an exported object from 'namespace:dials'" error when using tune() on tabnet parameters. (#160 @cphaarmeyer)


# tabnet 0.6.0
Expand Down
108 changes: 86 additions & 22 deletions R/dials.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ check_dials <- function() {
stop("Package \"dials\" needed for this function to work. Please install it.", call. = FALSE)
}

# Guard for the optional {cli} dependency: abort with an actionable message
# when the package is not available, otherwise return invisibly.
check_cli <- function() {
  cli_is_available <- requireNamespace("cli", quietly = TRUE)
  if (!cli_is_available) {
    stop("Package \"cli\" needed for this function to work. Please install it.", call. = FALSE)
  }
}



#' Parameters for the tabnet model
Expand All @@ -17,56 +22,70 @@ check_dials <- function() {
#' @rdname tabnet_params
#' @return A `dials` parameter to be used when tuning TabNet models.
#' @export
# Tunable dials parameter: width of the attention embedding for each mask
# (integer, default range 8..64). Requires the optional {dials} package.
attention_width <- function(range = c(8L, 64L), trans = NULL) {
  check_dials()
  dials::new_quant_param(
    type = "integer",
    range = range,
    inclusive = c(TRUE, TRUE),  # both bounds are attainable
    trans = trans,
    label = c(attention_width = "Width of the attention embedding for each mask"),
    finalize = NULL
  )
}

#' @rdname tabnet_params
#' @export
# Tunable dials parameter: width of the decision prediction layer
# (integer, default range 8..64). Requires the optional {dials} package.
decision_width <- function(range = c(8L, 64L), trans = NULL) {
  check_dials()
  dials::new_quant_param(
    type = "integer",
    range = range,
    inclusive = c(TRUE, TRUE),  # both bounds are attainable
    trans = trans,
    label = c(decision_width = "Width of the decision prediction layer"),
    finalize = NULL
  )
}


#' @rdname tabnet_params
#' @export
# Tunable dials parameter: coefficient for feature reusage in the masks
# (double, default range 1..2). Requires the optional {dials} package.
feature_reusage <- function(range = c(1, 2), trans = NULL) {
  check_dials()
  dials::new_quant_param(
    type = "double",
    range = range,
    inclusive = c(TRUE, TRUE),  # both bounds are attainable
    trans = trans,
    label = c(feature_reusage = "Coefficient for feature reusage in the masks"),
    finalize = NULL
  )
}

#' @rdname tabnet_params
#' @export
# Tunable dials parameter: momentum for batch normalization
# (double, default range 0.01..0.4). Requires the optional {dials} package.
momentum <- function(range = c(0.01, 0.4), trans = NULL) {
  check_dials()
  dials::new_quant_param(
    type = "double",
    range = range,
    inclusive = c(TRUE, TRUE),  # both bounds are attainable
    trans = trans,
    label = c(momentum = "Momentum for batch normalization"),
    finalize = NULL
  )
}


#' @rdname tabnet_params
#' @export
# Tunable dials parameter: choice of the final feature-selector layer,
# either "sparsemax" or "entmax". Requires the optional {dials} package.
mask_type <- function(values = c("sparsemax", "entmax")) {
  check_dials()
  param_label <- c(mask_type = "Final layer of feature selector, either sparsemax or entmax")
  dials::new_qual_param(
    type = "character",
    values = values,
    label = param_label,
    finalize = NULL
  )
}
Expand Down Expand Up @@ -101,28 +120,73 @@ num_shared <- function(range = c(1L, 5L), trans = NULL) {

#' @rdname tabnet_params
#' @export
# Tunable dials parameter: number of steps in the TabNet architecture
# (integer, default range 3..10). Requires the optional {dials} package.
num_steps <- function(range = c(3L, 10L), trans = NULL) {
  check_dials()
  dials::new_quant_param(
    type = "integer",
    range = range,
    inclusive = c(TRUE, TRUE),  # both bounds are attainable
    trans = trans,
    label = c(num_steps = "Number of steps in the architecture"),
    finalize = NULL
  )
}


#' Non-tunable parameters for the tabnet model
#'
#' @param range unused
#' @param trans unused
#' @rdname tabnet_non_tunable
#' @return No return value; always aborts, as these parameters cannot be
#'   used with `tune()`.
#' @export
cat_emb_dim <- function(range = NULL, trans = NULL) {
  check_cli()
  # This body is shared by many aliases (checkpoint_epochs, drop_last, ...),
  # so report the name the user actually called rather than hard-coding
  # "cat_emb_dim" in the error message.
  called_as <- deparse(sys.call()[[1]])
  cli::cli_abort("{.var {called_as}} cannot be used as a {.fun tune} parameter yet.")
}

# The following tabnet_config() arguments are not tunable. Each exported
# name below is an alias of cat_emb_dim(), which aborts with an informative
# error when the parameter is used with tune().
#' @rdname tabnet_non_tunable
#' @export
checkpoint_epochs <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
drop_last <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
encoder_activation <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
lr_scheduler <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
mlp_activation <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
mlp_hidden_multiplier <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
num_independent_decoder <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
num_shared_decoder <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
optimizer <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
penalty <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
verbose <- cat_emb_dim

#' @rdname tabnet_non_tunable
#' @export
virtual_batch_size <- cat_emb_dim
Loading
Loading