Skip to content

Commit

Permalink
pass minimal tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dajmcdon committed Jan 21, 2025
1 parent a236706 commit 967f16a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 10 deletions.
17 changes: 7 additions & 10 deletions R/growth_rate.R
Original file line number Diff line number Diff line change
Expand Up @@ -254,39 +254,36 @@ growth_rate <- function(y, x = seq_along(y), x0 = x,
cv <- params$cv
single_lambda <- is.numeric(params$lambda) & length(params$lambda) == 1L
lambda_seq <- is.numeric(params$lambda) & length(params$lambda) > 1L
if (is.null(params$df) && params$lambda) {
if (is.null(params$df) && lambda_seq) {
cv <- TRUE
params$df <- "min"
} else if (is.numeric(params$df) && single_lambda) {
cli_abort("Only one of {.val lambda} or {.val df} may be specified.")
}

# Estimate growth rate and return
if (cv) {
if (is.numeric(params$df)) params$df <- "min"
if (!is.character(params$df)) params$df <- "min"
if (length(params$lambda) == 1L) params$lambda <- NULL
which_lambda <- rlang::arg_match0(df, c("min", "1se"))
lam <- paste0("lambda_", which_lambda)
lam <- rlang::arg_match0(params$df, c("min", "1se"))
which_lambda <- paste0("lambda_", lam)
obj <- trendfilter::cv_trendfilter(
y, x, k = params$k, error_measure = params$error_measure,
nfolds = params$nfolds, family = params$family, lambda = params$lambda,
nlambda = params$nlambda, lambda_max = params$lambda_max,
lambda_min = params$lambda_min, lambda_min_ratio = params$lambda_min_ratio
)
f <- stats::predict(obj, newx = x0, which_lambda = which_lambda)
} else {
if (!single_lambda || !is.numeric(params$df)) {
cli_abort("If a sequence of `lambda` is used, `df` must be specified.")
}
obj <- trendfilter::trendfilter(
y, x,
k = params$k, family = params$family, lambda = params$lambda,
nlambda = params$nlambda, lambda_max = params$lambda_max,
lambda_min = params$lambda_min, lambda_min_ratio = params$lambda_min_ratio
)
lam <- if (single_lambda) obj$lambda else obj$lambda[which.min(abs(df - obj$dof))]
lam <- ifelse(single_lambda, obj$lambda, obj$lambda[which.min(abs(params$df - obj$dof))])
f <- stats::predict(obj, newx = x0, lambda = lam)
}

f <- stats::predict(obj, newx = x0, which_lambda = lam)
d <- diff(f) / diff(x0)
# Extend by one element
d <- c(d, d[length(d)])
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/_snaps/growth_rate.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,12 @@
Error in `growth_rate()`:
! "smooth_spline" requires 1 `lambda` but more were used.

# trendfilter growth_rate implementation

Code
growth_rate(y = 1:20, method = "trend_filter", params = growth_rate_global_params(
lambda = 1, df = 4))
Condition
Error in `growth_rate()`:
! Only one of "lambda" or "df" may be specified.

52 changes: 52 additions & 0 deletions tests/testthat/test-growth_rate.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,57 @@ test_that("new setup args and warnings are as expected", {

test_that("trendfilter growth_rate implementation", {
skip_if_not_installed("trendfilter", "0.0.2")
# tf with multiple lambdas, no df
expect_length(
growth_rate(y = 1:20, method = "trend_filter",
params = growth_rate_global_params(lambda = 20:1)),
20L
)
# specifying lambda seq and df (numeric) is ok
expect_length(
growth_rate(y = 1:20, method = "trend_filter",
params = growth_rate_global_params(lambda = 20:1, df = 4)),
20L
)
# single lambda and fixed df is bad
expect_snapshot(
error = TRUE,
growth_rate(y = 1:20, method = "trend_filter",
params = growth_rate_global_params(lambda = 1, df = 4))
)


# other tf args give output (correctness not checked)
z <- rnorm(30)
expect_length(growth_rate(z, method = "trend_filter"), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(lambda = 10)
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(df = 14)
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(cv = TRUE)
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(k = 3)
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(nlambda = 10)
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(lambda_max = 10)
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(lambda_min = 10)
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(lambda_min_ratio = .1)
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(error_measure = "mse")
), 30L)
expect_length(growth_rate(
z, method = "trend_filter", params = growth_rate_global_params(nfolds = 3)
), 30L)
})

0 comments on commit 967f16a

Please sign in to comment.