Skip to content

Commit

Permalink
no_fit does prevent fit now
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Sep 21, 2023
1 parent fa548fa commit 9e5e39c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 75 deletions.
6 changes: 4 additions & 2 deletions R/orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ orsf <- function(data,
tree_type_R = 3,
tree_seeds = as.integer(tree_seeds),
loaded_forest = list(),
n_tree = n_tree,
n_tree = if(no_fit) 0 else n_tree,
mtry = mtry,
vi_type_R = switch(importance,
"none" = 0,
Expand Down Expand Up @@ -745,7 +745,9 @@ orsf <- function(data,
'user' = 2),
oobag_eval_every = oobag_eval_every,
n_thread = n_thread,
write_forest = !no_fit)
write_forest = TRUE)

# browser()

# if someone says no_fit and also says don't attach the data,
# give them a warning but also do the right thing for them.
Expand Down
78 changes: 6 additions & 72 deletions scratch.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,87 +2,21 @@ library(tidyverse)
library(riskRegression)
library(survival)

sink("orsf-output.txt")
# sink("orsf-output.txt")
fit <- orsf(pbc_orsf, Surv(time, status) ~ . - id,
n_tree = 100,
n_thread = 5,
# control = orsf_control_net(),
oobag_pred_type = 'risk',
oobag_pred_type = 'none',
split_rule = 'cstat',
split_min_stat = .5)
sink()
split_min_stat = .5,
no_fit = TRUE)
# sink()

fit$eval_oobag


.pbc_orsf <- pbc_orsf %>%
mutate(stage = factor(stage, ordered = F))

x <- model.matrix(~. -1, data = select(.pbc_orsf, -time, -status, -id))
y <- as.matrix(.pbc_orsf[, c('time', 'status')])

# .flchain <- flchain |>
# rename(time = futime, status = death) |>
# select(-chapter) |>
# tidyr::drop_na()
# x <- model.matrix(~. -1, data = select(.flchain, -time, -status))
# y <- as.matrix(.flchain[, c('time', 'status')])

w <- rep(1, nrow(x))

sorted <-
collapse::radixorder(y[, 1], # order this way for risk sets
-y[, 2]) # order this way for oob C-statistic.

y <- y[sorted, ]
x <- x[sorted, ]
w <- w[sorted]

f <- function(x, y, w){
matrix(runif(ncol(x)), ncol=1)
}

pred_horizon <- 200 # median(y[, 'time'])

# sink("orsf-output.txt")

orsf_tree = aorsf:::orsf_cpp(x,
y,
w,
tree_type_R = 3,
tree_seeds = 1:500,
loaded_forest = list(),
n_tree = 500,
mtry = 3,
vi_type_R = 2,
vi_max_pvalue = 0.01,
lincomb_R_function = f,
oobag_R_function = f,
leaf_min_events = 5,
leaf_min_obs = 5,
split_rule_R = 1,
split_min_events = 5,
split_min_obs = 10,
split_min_stat = 0,
split_max_cuts = 5,
split_max_retry = 3,
lincomb_type_R = 1,
lincomb_eps = 1e-9,
lincomb_iter_max = 1,
lincomb_scale = TRUE,
lincomb_alpha = 1,
lincomb_df_target = 1,
lincomb_ties_method = 1,
pred_type_R = 1,
pred_mode = FALSE,
pred_horizon = c(pred_horizon, 2*pred_horizon),
oobag = TRUE,
oobag_eval_every = 100,
n_thread = 6)

# sink()

orsf_tree$forest[-1] |>
fit$forest[-1] |>
as_tibble() |>
slice(1) |>
unnest(everything()) |>
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-orsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ add_noise <- function(x, eps = .Machine$double.eps){

}

change_scale <- function(x, mult_by = 10){
change_scale <- function(x, mult_by = 1/2){
x * mult_by
}

Expand Down

0 comments on commit 9e5e39c

Please sign in to comment.