Skip to content

Commit

Permalink
various changes in squintability
Browse files Browse the repository at this point in the history
  • Loading branch information
huizezhang-sherry committed May 29, 2024
1 parent d34048b commit 617fd50
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 16 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ export(explore_space_start)
export(explore_space_tour)
export(explore_trace_interp)
export(explore_trace_search)
export(fit_ks)
export(fit_nls)
export(flip_sign)
export(format_label)
export(get_anchor)
Expand Down
58 changes: 46 additions & 12 deletions R/calc-squintability.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#' @param proj_dist_threshold only for squintability, the threshold for projection
#' distance for the random basis to be considered in sampling
#' @param return_early logical, whether to return the data frame of all the
#' bases, with their index values, before fitting the kernel or non-linear
#' least squares. This can be useful if the index value evaluation is
#' time-consuming and the user wants to save a copy before the fitting step.
#' @param method only for squintability, the method to calculate squintability,
#' either through kernel smoothing ("ks") or non-linear least square ("nls")
#' @param step only for squintability, the step size for interpolation,
Expand All @@ -8,16 +12,20 @@
#' number of observations for
#' applying binning before fitting the kernel
#' @param bin_width only for squintability, the bin size for binning the data
#' before fitting the kernel
#' before fitting the kernel; it is recommended to set it equal to the
#' \code{step} parameter
#' @param sampling_seed the seed used for sampling the random basis
#' @param basis_df a basis data frame, returned from \code{calc_squintability
#' (..., return_early = TRUE)}
#' @param nls_params a list of additional arguments (e.g. \code{start}) used
#' when fitting the nls model, see \code{stats::nls()}
#' @rdname optim
#' @export
calc_squintability <- function(idx, data = sine1000,
method = c("ks", "nls"), n_basis = 100, n = 6, d = 2,
proj_dist_threshold = 1.5, step = 0.02,
calc_squintability <- function(idx, data = sine1000, return_early = FALSE,
method = c("ks", "nls"), n_basis = 50, n = 6, d = 2,
proj_dist_threshold = 1.5, step = 0.005,
best = matrix(c(0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 1), nrow = 6),
bin_nobs_threshold = 5000, bin_width = 0.02,
bin_nobs_threshold = 5000, bin_width = 0.005,
sampling_seed = 123
){
cli::cli_inform("sample random bases ...")
Expand All @@ -26,6 +34,10 @@ calc_squintability <- function(idx, data = sine1000,
seed <- sample(1: 10000, size = 1000)
bb_lst <- list()
i <- 1
if (!all(dim(best) == c(n, d))){
cli::cli_abort("sampled bases and the best basis must have the same dimension,
check the parameter {.field n}, {.field d}, and {.field best}.")
}
while (length(bb_lst) < n_basis){
set.seed(seed[i])
bb <- tourr::basis_random(n = n, d = d)
Expand All @@ -50,13 +62,20 @@ calc_squintability <- function(idx, data = sine1000,
data |> dplyr::mutate(!!rlang::sym(idx) := get(idx)()(org_data %*% basis[[1]]))
}

idx_sym <- rlang::sym(idx)
cli::cli_inform("calculate index values for interpolated bases ...")
pb <- progress::progress_bar$new(total = nrow(basis_df))
basis_df <- basis_df |>
dplyr::group_split(aa = dplyr::row_number()) |>
purrr::map_dfr(~df_add_idx_val(.x, idx, data)) |>
dplyr::mutate(!!idx_sym := if (idx %in% c("TIC")) {
(!!idx_sym - min(!!idx_sym)) / (max(!!idx_sym) - min(!!idx_sym))
} else {
!!idx_sym
}) |>
dplyr::select(-aa)

if (return_early) return(basis_df)

cli::cli_inform("fit kernel smoothing or non-linear least square ...")
res <- switch(method,
Expand Down Expand Up @@ -89,7 +108,9 @@ interp_bb_best <- function(bb, best, step = 0.02){
tibble::tibble(basis = tt_mat, dist = dist)
}

fit_ks <- function(basis_df, idx, bin_nobs_threshold, bin_width){
#' @export
#' @rdname optim
fit_ks <- function(basis_df, idx, bin_nobs_threshold = 5000, bin_width = 0.005){
if (nrow(basis_df) > bin_nobs_threshold){
cli::cli_inform("apply binning: bin_width = {bin_width}")
dist_bin <- ceiling(basis_df$dist / bin_width) * bin_width
Expand All @@ -110,23 +131,36 @@ fit_ks <- function(basis_df, idx, bin_nobs_threshold, bin_width){

}

fit_nls <- function(basis_df, idx, bin_nobs_threshold, bin_width){

#' @export
#' @rdname optim
fit_nls <- function(basis_df, idx, bin_nobs_threshold = 5000, bin_width = 0.005,
nls_params = list(start = list(theta1 = 1, theta2 = 1, theta3 = 50, theta4 = 0))){
if (nrow(basis_df) > bin_nobs_threshold){
cli::cli_inform("apply binning: bin_width = {bin_width}")
dist_bin <- ceiling(basis_df$dist / bin_width) * bin_width
basis_df <- basis_df |>
dplyr::bind_cols(dist_bin = dist_bin / pi * 180) |>
dplyr::bind_cols(dist_bin = dist_bin ) |>
dplyr::group_by(dist_bin) |>
dplyr::summarise(idx := mean(!!rlang::sym(idx)))
} else{
dist_bin <- ceiling(basis_df$dist / bin_width) * bin_width
basis_df <- basis_df |> dplyr::bind_cols(dist_bin = dist_bin / pi * 180) |>
basis_df <- basis_df |> dplyr::bind_cols(dist_bin = dist_bin) |>
dplyr::rename(idx = !!dplyr::sym(idx))
}
# Logistic-shaped curve in x: equals 0.5 at x = theta2, with theta3 scaling
# the exponent (x - theta2).
ff <- function(x, theta2, theta3){
  shifted <- theta3 * (x - theta2)
  1 / (1 + exp(shifted))
}
# Rescale ff() so that the curve equals 1 at x = 0 and 0 at x = max(x).
ff_ratio <- function(x, theta2, theta3){
  at_zero <- ff(0, theta2, theta3)
  at_max <- ff(max(x), theta2, theta3)
  (ff(x, theta2, theta3) - at_max) /
    (at_zero - at_max)
}

nls_prms <- list(
formula = idx ~ (theta1 - theta4) * ff_ratio(dist_bin, theta2, theta3) + theta4,
data = basis_df) |> append(nls_params)

model <- do.call("nls", nls_prms)

model = stats::nls(idx ~ theta1/(1 + exp(-theta2 + theta3 * dist_bin)),
data = basis_df, start = list(theta1 = 1, theta2 = 5, theta3 = 0.1))
theta_params <- stats::coef(model)
tibble::tibble(idx = idx) |> dplyr::bind_cols(tibble::as_tibble_row(theta_params))
}
Expand Down
32 changes: 28 additions & 4 deletions man/optim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 617fd50

Please sign in to comment.