From 116db18a55a27c297d3f6b33a405c4ee9c9ddd41 Mon Sep 17 00:00:00 2001 From: "Matthew L. Fidler" Date: Wed, 1 Nov 2023 17:43:36 -0500 Subject: [PATCH] Add tests and fixes for rxS() needed for nlmixr udf --- R/symengine.R | 20 +++++++++++++++++++- tests/testthat/test-udf.R | 9 +++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/R/symengine.R b/R/symengine.R index 13db5ab85..de69e48d1 100644 --- a/R/symengine.R +++ b/R/symengine.R @@ -497,6 +497,11 @@ rxD <- function(name, derivatives) { #' @export rxToSE <- function(x, envir = NULL, progress = FALSE, promoteLinSens = TRUE, parent = parent.frame()) { + if (!rxode2parse::.udfEnvLock(NULL)) { + rxode2parse::.udfEnvSet(parent) + rxode2parse::.udfEnvLock(TRUE) + on.exit(rxode2parse::.udfEnvLock(FALSE)) + } .rxToSE.envir$parent <- parent assignInMyNamespace(".promoteLinB", promoteLinSens) assignInMyNamespace(".rxIsLhs", FALSE) @@ -1352,6 +1357,7 @@ rxToSE <- function(x, envir = NULL, progress = FALSE, call. = FALSE ) } else if (length(.ret0) == length(.f)) { + assign(.fun, .rxFunction(.fun), envir = envir) .ret0 <- unlist(.ret0) .ret <- paste0(.fun, "(",paste(.ret0, collapse=", "), ")") } else { @@ -1399,6 +1405,11 @@ rxToSE <- function(x, envir = NULL, progress = FALSE, rxFromSE <- function(x, unknownDerivatives = c("forward", "central", "error"), parent=parent.frame()) { rxReq("symengine") + if (!rxode2parse::.udfEnvLock(NULL)) { + rxode2parse::.udfEnvSet(parent) + rxode2parse::.udfEnvLock(TRUE) + on.exit(rxode2parse::.udfEnvLock(FALSE)) + } .rxFromSE.envir$parent <- parent .unknown <- c("central" = 2L, "forward" = 1L, "error" = 0L) assignInMyNamespace(".rxFromNumDer", .unknown[match.arg(unknownDerivatives)]) @@ -2182,7 +2193,12 @@ rxFromSE <- function(x, unknownDerivatives = c("forward", "central", "error"), #' @return rxode2/symengine environment #' @author Matthew Fidler #' @export -rxS <- function(x, doConst = TRUE, promoteLinSens = FALSE) { +rxS <- function(x, doConst = TRUE, promoteLinSens = FALSE, envir=parent.frame()) { + if (!rxode2parse::.udfEnvLock(NULL)) { + rxode2parse::.udfEnvSet(envir) + rxode2parse::.udfEnvLock(TRUE) + on.exit(rxode2parse::.udfEnvLock(FALSE)) + } rxReq("symengine") .cnst <- names(.rxSEreserved) .env <- new.env(parent = loadNamespace("symengine")) @@ -2198,12 +2214,14 @@ rxS <- function(x, doConst = TRUE, promoteLinSens = FALSE) { .env$..doConst <- doConst .rxD <- rxode2parse::rxode2parseD() for (.f in c( + ls(rxode2parse::.symengineFs()), ls(.rxD), "linCmtA", "linCmtB", "rxEq", "rxNeq", "rxGeq", "rxLeq", "rxLt", "rxGt", "rxAnd", "rxOr", "rxNot", "rxTBS", "rxTBSd", "rxTBSd2", "lag", "lead", "rxTBSi" )) { assign(.f, .rxFunction(.f), envir = .env) } + for (.v in seq_along(.rxSEreserved)) { assign(names(.rxSEreserved)[.v], .rxSEreserved[[.v]], envir = .env) } diff --git a/tests/testthat/test-udf.R b/tests/testthat/test-udf.R index 7360f0f87..32dabdd7e 100644 --- a/tests/testthat/test-udf.R +++ b/tests/testthat/test-udf.R @@ -292,5 +292,14 @@ rxTest({ expect_error(suppressWarnings(rxSolve(f, qd)), NA) }) + test_that("symengine load", { + mod <- "tke=THETA[1];\nprop.sd=THETA[2];\neta.ke=ETA[1];\nke=gg(tke,exp(eta.ke));\nipre=gg(10,exp(-ke*t));\nlipre=log(ipre);\nrx_yj_~2;\nrx_lambda_~1;\nrx_low_~0;\nrx_hi_~1;\nrx_pred_f_~ipre;\nrx_pred_~rx_pred_f_;\nrx_r_~(rx_pred_f_*prop.sd)^2;\n" + gg <- function(x, y) { + PreciseSums::psProd(c(x, y)) + } + + rxS(mod, TRUE, TRUE) + + }) })