Skip to content

Commit

Permalink
Add tests and fixes for rxS() needed for nlmixr udf
Browse files Browse the repository at this point in the history
  • Loading branch information
mattfidler committed Nov 1, 2023
1 parent 5c7ff03 commit 116db18
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
20 changes: 19 additions & 1 deletion R/symengine.R
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,11 @@ rxD <- function(name, derivatives) {
#' @export
rxToSE <- function(x, envir = NULL, progress = FALSE,
promoteLinSens = TRUE, parent = parent.frame()) {
if (!rxode2parse::.udfEnvLock(NULL)) {
rxode2parse::.udfEnvSet(parent)
rxode2parse::.udfEnvLock(TRUE)
on.exit(rxode2parse::.udfEnvLock(FALSE))
}
.rxToSE.envir$parent <- parent
assignInMyNamespace(".promoteLinB", promoteLinSens)
assignInMyNamespace(".rxIsLhs", FALSE)
Expand Down Expand Up @@ -1352,6 +1357,7 @@ rxToSE <- function(x, envir = NULL, progress = FALSE,
call. = FALSE
)
} else if (length(.ret0) == length(.f)) {
assign(.fun, .rxFunction(.fun), envir = envir)
.ret0 <- unlist(.ret0)
.ret <- paste0(.fun, "(",paste(.ret0, collapse=", "), ")")
} else {
Expand Down Expand Up @@ -1399,6 +1405,11 @@ rxToSE <- function(x, envir = NULL, progress = FALSE,
rxFromSE <- function(x, unknownDerivatives = c("forward", "central", "error"),
parent=parent.frame()) {
rxReq("symengine")
if (!rxode2parse::.udfEnvLock(NULL)) {
rxode2parse::.udfEnvSet(parent)
rxode2parse::.udfEnvLock(TRUE)
on.exit(rxode2parse::.udfEnvLock(FALSE))
}
.rxFromSE.envir$parent <- parent
.unknown <- c("central" = 2L, "forward" = 1L, "error" = 0L)
assignInMyNamespace(".rxFromNumDer", .unknown[match.arg(unknownDerivatives)])
Expand Down Expand Up @@ -2182,7 +2193,12 @@ rxFromSE <- function(x, unknownDerivatives = c("forward", "central", "error"),
#' @return rxode2/symengine environment
#' @author Matthew Fidler
#' @export
rxS <- function(x, doConst = TRUE, promoteLinSens = FALSE) {
rxS <- function(x, doConst = TRUE, promoteLinSens = FALSE, envir=parent.frame()) {
if (!rxode2parse::.udfEnvLock(NULL)) {
rxode2parse::.udfEnvSet(envir)
rxode2parse::.udfEnvLock(TRUE)
on.exit(rxode2parse::.udfEnvLock(FALSE))
}
rxReq("symengine")
.cnst <- names(.rxSEreserved)
.env <- new.env(parent = loadNamespace("symengine"))
Expand All @@ -2198,12 +2214,14 @@ rxS <- function(x, doConst = TRUE, promoteLinSens = FALSE) {
.env$..doConst <- doConst
.rxD <- rxode2parse::rxode2parseD()
for (.f in c(
ls(rxode2parse::.symengineFs()),
ls(.rxD), "linCmtA", "linCmtB", "rxEq", "rxNeq", "rxGeq", "rxLeq", "rxLt",
"rxGt", "rxAnd", "rxOr", "rxNot", "rxTBS", "rxTBSd", "rxTBSd2", "lag", "lead",
"rxTBSi"
)) {
assign(.f, .rxFunction(.f), envir = .env)
}

for (.v in seq_along(.rxSEreserved)) {
assign(names(.rxSEreserved)[.v], .rxSEreserved[[.v]], envir = .env)
}
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-udf.R
Original file line number Diff line number Diff line change
Expand Up @@ -292,5 +292,14 @@ rxTest({
expect_error(suppressWarnings(rxSolve(f, qd)), NA)
})

test_that("symengine load", {

mod <- "tke=THETA[1];\nprop.sd=THETA[2];\neta.ke=ETA[1];\nke=gg(tke,exp(eta.ke));\nipre=gg(10,exp(-ke*t));\nlipre=log(ipre);\nrx_yj_~2;\nrx_lambda_~1;\nrx_low_~0;\nrx_hi_~1;\nrx_pred_f_~ipre;\nrx_pred_~rx_pred_f_;\nrx_r_~(rx_pred_f_*prop.sd)^2;\n"
gg <- function(x, y) {
PreciseSums::psProd(c(x, y))
}

rxS(mod, TRUE, TRUE)

})
})

0 comments on commit 116db18

Please sign in to comment.