Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adfun hooks #214

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions R/basis_functions.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
make_gaussian = \(s)\(m)\(x)exp(-((x-m)^2)/(2*s^2))
make_locations = \(t)\(d)seq(from=0,to=t-1,length=d)

# g = (ss
# lapply(make_gaussian)
# mapply(FUN = \(f, x) lapply(x, f), mm, SIMPLIFY = FALSE)
# unlist(recursive = FALSE)
# lapply(do.call, args = list(tt))
# )

rbf_base = function(times, locations, scales) {
gaussians = lapply(locations, make_gaussian(scales))
gaussian_outputs = lapply(gaussians, do.call, list(times))
do.call(cbind, gaussian_outputs)
}


#' Radial Basis Functions
#'
#' Compute a set of radial basis functions (`dimension` of them).
Expand All @@ -11,9 +28,14 @@
#'
#' @export
rbf = function(time_steps, dimension, scale = time_steps / dimension) {
s = scale
make_gaussian = \(m)\(x)exp(-((x-m)^2)/(2*s^2))
locations = seq(from = 0, to = time_steps - 1, length = dimension)
gaussians = lapply(locations, make_gaussian)
do.call(cbind, lapply(gaussians, do.call, list(0:(time_steps - 1))))
locations = make_locations(time_steps)(dimension)
times = seq_len(time_steps) - 1L
rbf_base(times, locations, scale)
}


## experimental
rbf_heterogeneous = function(time_steps, locations, scales) {
times = seq_len(time_steps) - 1L
rbf_base(times, locations, scales)
}
18 changes: 0 additions & 18 deletions R/index_matrices.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,3 @@ sparse_matrix_notation = function(M, zero_based = TRUE, tol = 1e-4) {
sparse_rbf_notation = function(time_steps, dimension, zero_based = TRUE, tol = 1e-2) {
rbf(time_steps, dimension) |> sparse_matrix_notation(zero_based, tol)
}
#
# bb = 80
# x = sparse_rbf_notation(100, bb, FALSE)
# b = rnorm(bb)
#
# times = x$M %*% b
# times_approx = tapply(
# x$values * b[x$col_index]
# , x$row_index
# , sum
# )
# plot(times, times_approx)
# length(x$values) / prod(dim(x$M))
# matplot(x$M, type = "l")
# matplot(x$Msparse, type = "l")
# plot(x$M %*% b, type = "n")
# lines(x$M %*% b)
# lines(x$Msparse %*% b, col = "red")
3 changes: 3 additions & 0 deletions R/mp_tmb_model_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ TMBModelSpec = function(
, must_not_save = character()
, sim_exprs = character()
, state_update = c("euler", "rk4", "euler_multinomial", "hazard")
, sdreport = TRUE
) {
self = Base()
before = force_expr_list(before)
Expand All @@ -24,6 +25,7 @@ TMBModelSpec = function(
self$must_save = must_save
self$must_not_save = must_not_save
self$sim_exprs = sim_exprs
self$sdreport = sdreport

self$expr_list = function() {
ExprList(
Expand Down Expand Up @@ -146,6 +148,7 @@ TMBModelSpec = function(
int_vecs = do.call(IntVecs, self$all_integers())
)
, time_steps = Time(as.integer(time_steps))
, do_pred_sdreport = self$sdreport
)
}
self$simulator_fresh = function(
Expand Down
50 changes: 31 additions & 19 deletions R/tmb_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,10 @@ TMBModel = function(
random = self$random$vector()
)

## FIXME: need a dummy parameter if the model has not
## need a dummy parameter if the model has not
## yet been parameterized. is there a more TMB-ish
## way to do this?
if (length(p$params) == 0L) {
p$params = 0
}
if (length(p$params) == 0L) p$params = 0
p
}
self$random_arg = function() {
Expand All @@ -194,21 +192,29 @@ TMBModel = function(
, verbose = getOption("macpan2_verbose")
) {
params = self$param_arg()
if (getOption("macpan2_tmb_type") == "Fun") params$params = numeric()
list(
mp_args = list(
data = self$data_arg()
, parameters = params
, random = self$random_arg()
, DLL = tmb_cpp
, silent = !verbose
)
tmb_args = getOption("macpan2_tmb_adfun_args")

## FIXME: deprecate old less flexible argument options.
if (!"DLL" %in% names(tmb_args)) tmb_args$DLL = tmb_cpp
if (!"silent" %in% names(tmb_args)) tmb_args$silent = !verbose

## catch modifications to mp_args that must be made in response
## to certain choices in tmb_args. surely there will be more
## to add here.
if (identical(tmb_args$type, "Fun")) mp_args$parameters$params = numeric()

return(c(mp_args, tmb_args))
}
self$ad_fun = function(
tmb_cpp = getOption("macpan2_dll")
, verbose = getOption("macpan2_verbose")
, derivs = getOption("macpan2_tmb_derivs")
) {
do.call(TMB::MakeADFun, self$make_ad_fun_arg(tmb_cpp))
do.call(TMB::MakeADFun, self$make_ad_fun_arg(tmb_cpp, verbose))
}

self$simulator = function(
Expand Down Expand Up @@ -426,6 +432,7 @@ mp_trajectory_sd = function(model, conf.int = FALSE, conf.level = 0.95) {
UseMethod("mp_trajectory_sd")
}


#' @param n Number of samples used in `mp_trajectory_ensemble`.
#' @param probs What quantiles should be returned by `mp_trajectory_ensemble`.
#' @describeIn mp_trajectory Simulate a trajectory that includes uncertainty
Expand All @@ -442,14 +449,18 @@ mp_trajectory_ensemble = function(model, n, probs = c(0.025, 0.975)) {
#' @importFrom stats qnorm
#' @export
mp_trajectory_sd.TMBSimulator = function(model, conf.int = FALSE, conf.level = 0.95) {
alpha = (1 - conf.level) / 2
r = model$report_with_sd()
if (conf.int) {
r$conf.low = r$value + r$sd * qnorm(alpha)
r$conf.high = r$value + r$sd * qnorm(1 - alpha)
}
if (conf.int) r = normal_quantiles(r, conf.level)
r
}
}

normal_quantiles = function(report_with_sd, conf.level = 0.95) {
alpha = (1 - conf.level) / 2
r = report_with_sd
r$conf.low = r$value + r$sd * qnorm(alpha)
r$conf.high = r$value + r$sd * qnorm(1 - alpha)
return(r)
}

#' @export
mp_trajectory_sd.TMBCalibrator = function(model, conf.int = FALSE, conf.level = 0.95) {
Expand Down Expand Up @@ -630,9 +641,10 @@ TMBSimulationUtils = function() {
if (compute_sd) r$values = cbind(r$values, self$sdreport()$sd)
if (.values_only) return(r$values)
s = self$.simulation_formatter(r, .phases)
if (.sort) {
s = s[order(s$time), , drop = FALSE] ## TODO: move sorting by time to the c++ side
}

## TODO: move sorting by time to the c++ side
if (.sort) s = s[order(s$time), , drop = FALSE]

reset_rownames(s)
}
return_object(self, "TMBSimulationFormatter")
Expand Down
7 changes: 4 additions & 3 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
macpan2_dll = "macpan2"
, macpan2_verbose = TRUE
, macpan2_default_loss = c("clamped_poisson", "poisson", "sum_of_squares", "neg_bin")
, macpan2_tmb_type = "ADFun"
, macpan2_tmb_check = TRUE
, macpan2_tmb_adfun_args = list()

## FIXME: macpan2_vec_by is old and not relevant i think
, macpan2_vec_by = c("state", "flow_rates", "trans_rates") |> self_named_vector()
#, macpan2_memoise = TRUE

# functions that cannot be called unless their
# first argument has a saved simulation history
# (TODO: read this off the c++ file)
, macpan2_time_dep_funcs = c(
"convolution"
,"rbind_lag"
Expand All @@ -22,6 +22,7 @@
# functions that cannot be called repeatedly
# _within_ a single time-step (as would
# happen for example with RK4 state updates)
# (TODO: read this off the c++ file)
, macpan2_non_iterable_funcs = c(
"time_var"
, "rbinom"
Expand Down
4 changes: 2 additions & 2 deletions inst/starter_models/si/tmb.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
library(macpan2)

spec = mp_tmb_model_spec(
before = list(S ~ N - 1)
, during = list(mp_per_capita_flow("S", "I", infection ~ beta * I / N))
before = S ~ N - 1
, during = mp_per_capita_flow("S", "I", "beta * I / N", "infection")
, default = list(N = 100, beta = 0.2, I = 1)
)
specs = list(
Expand Down
3 changes: 2 additions & 1 deletion man/mp_tmb_model_spec.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 2 additions & 15 deletions tests/testthat/test-expr-parser.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,8 @@ test_that("parse_tables ...", {
form = ~ log(x) + exp(y)
environment(form) = eval_env

f = make_expr_parser("f", finalizer = finalizer_char)
g = make_expr_parser("g", finalizer = finalizer_index)

## work around testthat calling stuff from elsewhere.
## the issue is that make_expr_parser assumes that you are going to be
## calling from the same environment (or at least from an environment
## that can reach the environment in which the function was made) -- it
## is a recursive function. the idea of make_expr_parser is that it gets
## used at package loading time to create macpan2:::parse_expr, which gets
## used only by package functions. this ensures that it will be in the macpan2
## namespace, which can be easily accessed by other functions in the
## namespace. here we need to contrive an unrealistic case so that the
## test and coverage infrastructure works.
assign("f", f, envir = .GlobalEnv)
assign("g", g, envir = .GlobalEnv)
f = make_expr_parser(finalizer = finalizer_char)
g = make_expr_parser(finalizer = finalizer_index)

g_table = g(form)
f_table = f(form)
Expand Down
Loading