Skip to content

Commit

Permalink
Some symengine fixes and expansions
Browse files Browse the repository at this point in the history
  • Loading branch information
mattfidler committed Nov 2, 2023
1 parent 0b4436f commit f1f90ab
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 24 deletions.
128 changes: 112 additions & 16 deletions R/symengine.R
Original file line number Diff line number Diff line change
Expand Up @@ -366,9 +366,21 @@ rxFun <- function(name, args, cCode) {
if (missing(args) && missing(cCode)) {
.funName <- as.character(substitute(name))
.lst <- rxFun2c(name, name=.funName)
.ret <- do.call(rxode2parse::rxFunParse, .lst)
message("converted R function '", .lst$name, "' to C with code:")
message(.lst$cCode)
.env <- new.env(parent=emptyenv())
.env$d <- list()
lapply(seq_along(.lst), function(i) {
.cur <- .lst[[i]]
do.call(rxode2parse::rxFunParse, .cur[1:3])
message("converted R function '", .cur$name, "' to C (will now use in rxode2)")
## message(.cur$cCode)
if (length(.cur) == 4L) {
.env$d <- c(.env$d, list(.cur[[4]]))
}
})
if (length(.env$d) > 0) {
message("Added derivative table for '", .lst[[1]]$name, "'")
rxD(.lst[[1]]$name, .env$d)
}
return(invisible())
}
rxode2parse::rxFunParse(name, args, cCode)
Expand Down Expand Up @@ -1346,6 +1358,9 @@ rxToSE <- function(x, envir = NULL, progress = FALSE,
}
} else {
.udf <- try(get(.fun, envir = .rxToSE.envir$parent, mode="function"), silent =TRUE)
if (inherits(.udf, "try-error")) {
.udf <- try(get(.fun, envir = rxode2parse::.udfEnvSet(NULL), mode="function"), silent =TRUE)
}
if (inherits(.udf, "try-error")) {
stop(sprintf(gettext("function '%s' or its derivatives are not supported in rxode2"), .fun),
call. = FALSE
Expand All @@ -1357,7 +1372,9 @@ rxToSE <- function(x, envir = NULL, progress = FALSE,
call. = FALSE
)
} else if (length(.ret0) == length(.f)) {
assign(.fun, .rxFunction(.fun), envir = envir)
if (is.environment(envir)){
assign(.fun, .rxFunction(.fun), envir = envir)
}
.ret0 <- unlist(.ret0)
.ret <- paste0(.fun, "(",paste(.ret0, collapse=", "), ")")
} else {
Expand Down Expand Up @@ -2142,6 +2159,10 @@ rxFromSE <- function(x, unknownDerivatives = c("forward", "central", "error"),
} else {
.fun <- paste(.ret0[[1]])
.g <- try(get(.fun, envir=.rxFromSE.envir$parent, mode="function"), silent=TRUE)
if (inherits(.g, "try-error")) {
.g <- try(get(.fun, envir=rxode2parse::.udfEnvSet(NULL),
mode="function"), silent=TRUE)
}
if (inherits(.g, "try-error")) {
stop(sprintf(gettext("'%s' not supported in symengine->rxode2"), .fun),
call. = FALSE
Expand Down Expand Up @@ -3153,6 +3174,10 @@ rxSupportedFuns <- function() {
}

.rxFunEq <- c(
"Rx_pow_di"=2,
"Rx_pow"=2,
"R_pow_di"=2,
"R_pow"=2,
"lgamma" = 1,
"abs" = 1,
"acos" = 1,
Expand Down Expand Up @@ -3305,10 +3330,10 @@ rxSupportedFuns <- function() {
}
}
.rxFun2cAssignOperators <- function(x, envir = envir) {
if (identical(x[[1]], quote(`~`))) {
stop("formulas or other expressions with '~` are not supported in translation",
call.=FALSE)
}
## if (!envir$isRx && identical(x[[1]], quote(`~`))) {
## stop("formulas or other expressions with '~` are not supported in translation",
## call.=FALSE)
## }
if (as.character(x[[2]]) %in% envir$args) {
stop("cannot assign argument '", as.character(x[[2]]),
"' in functions converted to C",
Expand All @@ -3320,8 +3345,14 @@ rxSupportedFuns <- function() {
}
envir$didAssign <- TRUE
.pre <- paste0(rep(" ", envir$n), collapse="")
return(paste0(.pre, "_lastValue = ", .lhs, " = ",
.rxFun2c(x[[3]], envir=envir), ";\n"))
if (envir$isRx) {
paste0(.lhs, " <- ",
.rxFun2c(x[[3]], envir=envir), "\n",
"rxLastValue <-", .lhs, "\n")
} else {
paste0(.pre, "_lastValue = ", .lhs, " = ",
.rxFun2c(x[[3]], envir=envir), ";\n")
}
}

.rxFun2cSquareBracket <- function(x, envir) {
Expand Down Expand Up @@ -3377,7 +3408,11 @@ rxSupportedFuns <- function() {
.cur <- .rxFun2c(.cur, envir=envir)
if(!envir$didAssign && !envir$isExpr) {
.pre <- paste0(rep(" ", envir$n), collapse="")
return(paste0(.pre, "_lastValue = ", .cur, ";\n"))
if (envir$isRx) {
return(paste0(.pre, "rxLastValue <- ", .cur, "\n"))
} else {
return(paste0(.pre, "_lastValue = ", .cur, ";\n"))
}
}
.cur
}, character(1), USE.NAMES = FALSE),
Expand Down Expand Up @@ -3419,6 +3454,7 @@ rxSupportedFuns <- function() {
} else {
# supported functions
if (identical(x[[1]], quote(`return`))) {
envir$hasReturn <- TRUE
.pre <- paste0(rep(" ", envir$n), collapse="")
envir$didAssign <- TRUE
return(paste0(.pre, "return (", .rxFun2c(x[[2]], envir=envir), ");\n"))
Expand Down Expand Up @@ -3453,7 +3489,7 @@ rxSupportedFuns <- function() {
}
}

rxFun2c <- function(fun, name) {
rxFun2c <- function(fun, name, onlyF=FALSE) {
.env <- new.env(parent=emptyenv())
.env$vars <- character(0)
if (!missing(name)) {
Expand All @@ -3465,6 +3501,8 @@ rxFun2c <- function(fun, name) {
.env$args <- names(.f)
.env$n <- 2
.env$isExpr <- FALSE
.env$isRx <- FALSE
.env$hasReturn <- FALSE
if (any(.env$args == "...")) {
stop("functions with ... in them are not supported",
call. =FALSE)
Expand All @@ -3480,7 +3518,11 @@ rxFun2c <- function(fun, name) {
.cur <- .rxFun2c(.cur, envir=.env)
if(!.env$didAssign && !.env$isExpr) {
.pre <- paste0(rep(" ", .env$n), collapse="")
return(paste0(.pre, "_lastValue = ", .cur, ";\n"))
if (.env$isRx) {
return(paste0(.pre, "rxLastValue <- ", .cur, "\n"))
} else {
return(paste0(.pre, "_lastValue = ", .cur, ";\n"))
}
}
.env$isExpr <- FALSE
.cur
Expand All @@ -3493,7 +3535,61 @@ rxFun2c <- function(fun, name) {
";\n"))
.stop <- " return _lastValue;\n}\n"
.cCode <- paste0(.start, .body, .stop)
list(name=.funName,
args=.env$args,
cCode=.cCode)
.ret <- list(name=.funName,
args=.env$args,
cCode=.cCode)

if (onlyF) {
return(.ret)
}
if (!.env$hasReturn) {
# Can calculate derivatives
# Firs create an rxode2 like model:
.env <- new.env(parent=emptyenv())
.env$isRx <- TRUE
.env$args <- names(.f)
.env$n <- 2
.env$isExpr <- FALSE
.env$hasReturn <- FALSE
.body <- as.list(body(fun))
.body <- paste(vapply(seq_along(.body)[-1], function(i) {
.extra <- .extra2 <- ""
.cur <- .body[[i]]
.env$didAssign <- FALSE
.cur <- .rxFun2c(.cur, envir=.env)
if(!.env$didAssign && !.env$isExpr) {
.pre <- paste0(rep(" ", .env$n), collapse="")
return(paste0(.pre, "rxLastValue = ", .cur, ";\n"))
}
.env$isExpr <- FALSE
.cur
},
character(1), USE.NAMES=FALSE), collapse="")
# take out if/else
.body <- rxPrune(.body)
.s <- rxS(.body)
.lastValue <- .s$rxLastValue
return(c(list(.ret),
lapply(.env$args, function(v){
.v <- symengine::D(.lastValue, symengine::S(v))
.v <- paste0("function(", paste(.env$args, collapse=", "), ") {\n", rxOptExpr(paste0("rxLastValue=", rxFromSE(.v)), msg=paste0("d(", .funName, ")/d(", v, ")")),
"\nrxLastValue}")
.v <- eval(str2lang(.v))
.dName <- paste0("rx_", .funName, "_d_", v)
.v <- rxFun2c(.v, .dName, onlyF=TRUE)
#' function(a, b, c) {
#' paste0("2*", a, "+", b)
#' },
#'
.v2 <- paste0("function(", paste(.env$args, collapse=", "), "){\n",
"paste0(\"", .dName, "(\", ",
paste(.env$args, collapse=", \", \", "), ", \")\")",
"}")
.v2 <- eval(str2lang(.v2))
c(.v, list(.v2))
})))
} else {
message("function contains return statement; derivatives not calculated")
}
return(list(.ret))
}
31 changes: 23 additions & 8 deletions tests/testthat/test-udf.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
rxTest({

e <- et(1:10) |> as.data.frame()

e$x <- 1:10
Expand Down Expand Up @@ -147,23 +148,23 @@ rxTest({
a + b
}

expect_true(grepl("R_pow_di[(]", rxFun2c(udf)$cCode))
expect_true(grepl("R_pow_di[(]", rxFun2c(udf)[[1]]$cCode))

udf <- function(x, y) {
a <- x + y
b <- a ^ x
a + b
}

expect_true(grepl("R_pow[(]", rxFun2c(udf)$cCode))
expect_true(grepl("R_pow[(]", rxFun2c(udf)[[1]]$cCode))

udf <- function(x, y) {
a <- x + y
b <- cos(a) + x
a + b
}

expect_true(grepl("cos[(]", rxFun2c(udf)$cCode))
expect_true(grepl("cos[(]", rxFun2c(udf)[[1]]$cCode))

udf <- function(x, y) {
if (a < b) {
Expand All @@ -172,7 +173,7 @@ rxTest({
a + b
}

expect_true(grepl("if [(]", rxFun2c(udf)$cCode))
expect_true(grepl("if [(]", rxFun2c(udf)[[1]]$cCode))


udf <- function(x, y) {
Expand All @@ -185,7 +186,7 @@ rxTest({
}
}

expect_true(grepl("else [{]", rxFun2c(udf)$cCode))
expect_true(grepl("else [{]", rxFun2c(udf)[[1]]$cCode))

udf <- function(x, y) {
a <- x
Expand All @@ -198,7 +199,7 @@ rxTest({
a ^ 2 + b ^ 2
}

expect_true(grepl("else if [(]", rxFun2c(udf)$cCode))
expect_true(grepl("else if [(]", rxFun2c(udf)[[1]]$cCode))

udf <- function(x, y) {
a <- x
Expand All @@ -215,15 +216,29 @@ rxTest({
a ^ 2 + b ^ 2
}

expect_true(grepl("else if [(]", rxFun2c(udf)$cCode))
expect_true(grepl("else if [(]", rxFun2c(udf)[[1]]$cCode))

udf <- function(x, y) {
a <- x + y
x <- a ^ 2
x
}

expect_error(rxFun2c(udf)$cCode)
expect_error(rxFun2c(udf)[[1]]$cCode)


udf <- function(x, y) {
a <- x
b <- x ^ 2 + a
if (a < b) {
b ^ 2
} else {
a + b
}
}

rxFun(udf)

})

test_that("udf with model functions", {
Expand Down

0 comments on commit f1f90ab

Please sign in to comment.