Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
huizezhang-sherry committed May 24, 2024
1 parent abf1ea9 commit 5783022
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 36 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ Imports:
ggforce,
tidyr,
GpGp,
cli
cli,
progress
RoxygenNote: 7.3.1
Depends:
R (>= 2.10)
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ importFrom(GpGp,fit_model)
importFrom(cli,cli_abort)
importFrom(ggplot2,"%+replace%")
importFrom(magrittr,"%>%")
importFrom(progress,progress_bar)
importFrom(rlang,.data)
importFrom(rlang,`:=`)
importFrom(tidyr,unnest)
Expand Down
8 changes: 4 additions & 4 deletions R/calc-smoothness.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ calc_smoothness <- function(idx, data = sine1000, n_basis = 300, n = 6, d = 2,
best = matrix(c(0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 1), nrow = 6),
start_parms = c(0.001, 0.5, 2, 2),
other_gp_params = list(NULL)
other_gp_params = NULL
){

# sample basis
idx <- dplyr::sym(idx)
set.seed(123)
seed <- sample(1: 10000, size = n_basis)
basis_df <- tibble::tibble(basis = lapply(1:n_basis, function(i){
set.seed(seed[i]); tourr::basis_random(n = n, d = d)}))
set.seed(seed[i]); tourr::basis_random(n = n, d = d)})) |>
dplyr::rowwise() |>
dplyr::mutate(proj_dist = tourr::proj_dist(best, basis),
index_val = get(idx)()(as.matrix(data) %*% basis))
Expand All @@ -34,13 +34,13 @@ calc_smoothness <- function(idx, data = sine1000, n_basis = 300, n = 6, d = 2,
start_parms = start_parms, covfun_name = "matern_isotropic",
other_gp_params
)
fit <- do.call("GpGp::fit_model", gp_params)
fit <- do.call("fit_model", gp_params)
cov_params <- tibble::as_tibble_row(fit$covparms, .name_repair = "unique")
colnames(cov_params) <- c("variance", "range", "smoothness", "nugget", "convergence")
cov_params <- cov_params |> dplyr::mutate(convergence = fit$conv, idx = as.character(idx))

# return
list(basis_df = basis_df, gp_res = fit, measure = cov_params)
list(basis_df = basis_df, gp_res = list(fit), measure = cov_params)
}


Expand Down
75 changes: 49 additions & 26 deletions R/calc-squintability.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,26 @@
#' @param bin_nobs_threshold numeric, only for squintability, the threshold
#' number of observations for
#' applying binning before fitting the kernel
#' @param bin_size only for squintability, the bin size for binning the data
#' @param bin_width only for squintability, the bin size for binning the data
#' before fitting the kernel
#' @param sampling_seed the seed used for sampling the random basis
#' @rdname optim
#' @export
calc_squintability <- function(idx, data = sine1000,
method = c("ks", "nls"), n_basis, n = 6, d = 2,
method = c("ks", "nls"), n_basis = 100, n = 6, d = 2,
proj_dist_threshold = 1.5, step = 0.02,
best = matrix(c(0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 1), nrow = 6),
bin_nobs_threshold = 5000, bin_size = 0.02
bin_nobs_threshold = 5000, bin_width = 0.02,
sampling_seed = 123
){

cli::cli_inform("sample random bases ...")
## sample basis
set.seed(123)
set.seed(sampling_seed)
seed <- sample(1: 10000, size = 1000)
basis_lst <- list()
bb_lst <- list()
i <- 1
while (length(basis_lst) < n_basis){
while (length(bb_lst) < n_basis){
set.seed(seed[i])
bb <- tourr::basis_random(n = n, d = d)
if (tourr::proj_dist(best, bb) > proj_dist_threshold){
Expand All @@ -35,21 +37,37 @@ calc_squintability <- function(idx, data = sine1000,

## interpolate between the best and the random basis
## TODO: progress bar here
cli::cli_inform("interpolate between the best and the random bases ...")
basis_df <- tibble::tibble(id = 1:n_basis) |>
dplyr::mutate(res = lapply(bb_lst, function(bb){
interp_bb_best(bb = bb, best = best, step = step)
})) |>
tidyr::unnest(res) |>
dplyr::rowwise() |>
dplyr::mutate(!!dplyr::sym(idx) := get(idx)()(data %*% basis)) |>
dplyr::ungroup()
unnest(dist)

df_add_idx_val <- function(data, idx, org_data){
pb$tick()
data |> dplyr::mutate(!!rlang::sym(idx) := get(idx)()(org_data %*% basis[[1]]))
}

cli::cli_inform("calculate index values for interpolated bases ...")
pb <- progress::progress_bar$new(total = nrow(basis_df))
basis_df <- basis_df |>
dplyr::group_split(aa = dplyr::row_number()) |>
purrr::map_dfr(~df_add_idx_val(.x, idx, data)) |>
dplyr::select(-aa)


cli::cli_inform("fit kernel smoothing or non-linear least square ...")
res <- switch(method,
ks = fit_ks(basis_df, idx = idx, bin_nobs_threshold, bin_size),
nls = fit_nls(basis_df, idx = idx, bin_nobs_threshold, bin_size)
ks = fit_ks(basis_df, idx = idx, bin_width = bin_width,
bin_nobs_threshold = bin_nobs_threshold),
nls = fit_nls(basis_df, idx = idx, bin_width = bin_width,
bin_nobs_threshold = bin_nobs_threshold)
)

list(basis_df = basis_df, measure = res)
tibble::tibble(basis_df = list(basis_df), measure = res) |>
unnest(measure)

}

Expand All @@ -71,12 +89,14 @@ interp_bb_best <- function(bb, best, step = 0.02){
tibble::tibble(basis = tt_mat, dist = dist)
}

fit_ks <- function(basis_df, idx, bin_nobs_threshold, bin_size){

fit_ks <- function(basis_df, idx, bin_nobs_threshold, bin_width){
if (nrow(basis_df) > bin_nobs_threshold){
cli::cli_abort("apply binning before fitting the kernel smoother with bin_size = {bin_size}")
cli::cli_inform("apply binning: bin_width = {bin_width}")
dist_bin <- ceiling(basis_df$dist / bin_width) * bin_width
basis_df <- basis_df |>
dplyr::mutate(dist_bin = ceiling(dist / bin_size) * bin_size)
dplyr::bind_cols(dist_bin = dist_bin) |>
dplyr::group_by(dist_bin) |>
dplyr::summarise(!!rlang::sym(idx) := mean(!!rlang::sym(idx)))
} else{
basis_df <- basis_df |> dplyr::mutate(dist_bin = dist)
}
Expand All @@ -90,22 +110,25 @@ fit_ks <- function(basis_df, idx, bin_nobs_threshold, bin_size){

}

fit_nls <- function(basis_df, idx, bin_nobs_threshold, bin_size){
fit_nls <- function(basis_df, idx, bin_nobs_threshold, bin_width){

if (nrow(basis_df) > bin_nobs_threshold){
cli::cli_abort("apply binning before fitting the kernel smoother with bin_size = {bin_size}")
cli::cli_inform("apply binning: bin_width = {bin_width}")
dist_bin <- ceiling(basis_df$dist / bin_width) * bin_width
basis_df <- basis_df |>
dplyr::mutate(dist_bin = ceiling(dist / bin_size) * bin_size,
dist_bin = dist_bin / pi * 180)
dplyr::bind_cols(dist_bin = dist_bin / pi * 180) |>
dplyr::group_by(dist_bin) |>
dplyr::summarise(idx := mean(!!rlang::sym(idx)))
} else{
basis_df <- basis_df |> dplyr::mutate(dist_bin = dist / pi * 180)
dist_bin <- ceiling(basis_df$dist / bin_width) * bin_width
basis_df <- basis_df |> dplyr::bind_cols(dist_bin = dist_bin / pi * 180) |>
dplyr::rename(idx = !!dplyr::sym(idx))
}

model = stats::nls(idx ~ theta1/(1 + exp(-theta2 + theta3 * dist_bin)),
start = list(theta1 = 1, theta2 = 5, theta3 = 0.1))
data = basis_df, start = list(theta1 = 1, theta2 = 5, theta3 = 0.1))
theta_params <- stats::coef(model)
colnames(theta_params) <- paste0("theta", 1:length(theta_params))
tibble::tibble(idx = idx) |> dplyr::bind_cols(theta_params)
tibble::tibble(idx = idx) |> dplyr::bind_cols(tibble::as_tibble_row(theta_params))
}

globalVariables(c("dist", "dist_bin", "dist", "y", "x", "dy", "max_dev", "max_x"))
globalVariables(c("dist", "dist_bin", "dist", "y", "x", "dy", "max_dev", "max_x", "aa", "measure"))
1 change: 1 addition & 0 deletions R/ferrn-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
#' @importFrom cli cli_abort
#' @importFrom tidyr unnest
#' @importFrom rlang `:=`
#' @importFrom progress progress_bar
"_PACKAGE"
2 changes: 1 addition & 1 deletion man/data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions man/optim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5783022

Please sign in to comment.