Skip to content

Commit

Permalink
Merge pull request #777 from nlmixr2/777-seed-same-solve
Browse files Browse the repository at this point in the history
Use same random number management when the solver is not `liblsoda`
  • Loading branch information
mattfidler authored Aug 31, 2024
2 parents b1bf8b0 + cda0e0f commit 8efcfa4
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
cache-version: 8
cache-version: 9
pak-version: stable
extra-packages: |
any::rcmdcheck
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
cache-version: 7
cache-version: 8
extra-packages: |
any::pkgdown
nlmixr2/dparser-R
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
cache-version: 7
cache-version: 8
extra-packages: |
any::covr
nlmixr2/dparser-R
Expand Down
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Package: rxode2
Version: 2.1.3.9000
Version: 3.0.0
Title: Facilities for Simulating from ODE-Based Models
Authors@R: c(
person("Matthew L.","Fidler", role = c("aut", "cre"), email = "matthew.fidler@gmail.com", comment=c(ORCID="0000-0001-8538-6691")),
Expand Down Expand Up @@ -82,7 +82,7 @@ Imports:
checkmate,
ggplot2 (>= 3.4.0),
inline,
lotri (>= 0.5.0),
lotri (>= 1.0.0),
magrittr,
memoise,
methods,
Expand Down Expand Up @@ -116,7 +116,7 @@ RoxygenNote: 7.3.2
Biarch: true
LinkingTo:
sitmo,
lotri (>= 0.5.0),
lotri (>= 1.0.0),
PreciseSums (>= 0.7),
Rcpp,
RcppArmadillo (>= 0.9.300.2.0),
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
will also change. This is a more conservative protection mechanism
than was applied previously.

- Random numbers from `rxode2` are different when using `dop853`,
`lsoda` or `indLin` methods. These now seed the random numbers in
the same way as `liblsoda`, so the random number provided will be
the same with different solving methods.

## Possible breaking changes (though unlikely)

- `iCov` is no longer merged to the event dataset. This makes solving
Expand Down
100 changes: 92 additions & 8 deletions R/piping.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,7 @@
})
} else if (inherits(.cur, "matrix")) {
.cur2 <- .cur
if (!inherits(.cur, "lotriFix")) {
class(.cur2) <- c("lotriFix", class(.cur))
}
.unlistedBrackets <- as.list(as.expression(.cur2)[[-1]])[-1]
.unlistedBrackets <- as.list(lotri::lotriAsExpression(.cur2, plusNames=TRUE)[[-1]])[-1]
} else if (inherits(.cur, "character") && !is.null(names(.cur))) {
.unlistedBrackets <- lapply(paste(names(.cur),"=", setNames(.cur, NULL)),
str2lang)
Expand Down Expand Up @@ -238,8 +235,96 @@
.expandedForm
}

.nsEnv <- new.env(parent=emptyenv())
#' This function collapses the lotri line form to the plus form
#'
#' @param expressionList Expression list that is input to change into
#' matrix expression form the new line expressions to the classic
#' plus expressions.
#' @return expression list where lotri line for covariance matrices
#' are translated to classic plus form.
#' @author Matthew L. Fidler
#' @noRd
#' @examples
#'
#' tmp <- list(str2lang("d ~ 1"),
#' str2lang("e ~ c(0.5, 3)"),
#' str2lang("cp ~ add(add.sd)"),
#' str2lang("cp ~ add(add.sd) + prop(prop.sd)"),
#' str2lang("cp ~ + add(add.sd)"))
#'
#' .collapseLotriLineFormToPlusForm(tmp)
#'
.collapseLotriLineFormToPlusForm <- function(expressionList) {
.env <- new.env(parent=emptyenv())
.env$ret <- expressionList
.env$lst <- list()
.env$last <- NA_integer_

.f <- function() {
if (!is.na(.env$last)) {
.val <- as.call(c(list(quote(`{`)), .env$lst))
.val <- as.call(c(str2lang("lotri::lotri"), .val))
.val <- suppressMessages(try(eval(.val), silent=TRUE))
if (inherits(.val, "try-error")) {
for (.j in seq_along(.env$lst)) {
.env$ret[[.env$last + .j - 1L]] <- .env$lst[[.j]]
}
} else {
.val <- lotri::lotriAsExpression(.val, plusNames=TRUE)
.val <- lapply(seq_along(.val)[-1],
function(i){
.val[[i]]
})[[1]]
.val <- lapply(seq_along(.val)[-1],
function(i){
.val[[i]]
})
for (.j in seq_along(.val)) {
.env$ret[[.env$last + .j - 1L]] <- .val[[.j]]
}
}
.env$lst <- list()
.env$last <- NA_integer_
}
}
for (.i in seq_along(.env$ret)) {
.cur <- .env$ret[[.i]]
if (is.call(.cur) && identical(.cur[[1]], quote(`~`)) &&
length(.cur) == 3L &&
length(.cur[[2]]) == 1L # excludes ll(cp) ~ 1
) {
.isLotri <- TRUE
# Check to see if this is an error call
if (is.call(.cur[[3]])) {
.call <- deparse1(.cur[[3]][[1]])
if (.call == "+" &&
length(.cur[[3]]) >= 2 &&
is.call(.cur[[3]][[2]])) {
.call <- deparse1(.cur[[3]][[2]][[1]])
}
if (.call %in% names(.errDist)) {
.isLotri <- FALSE
}
}
if (.isLotri) {
if (is.na(.env$last)) {
.env$last <- .i
}
.env$ret[[.i]] <- NA
.env$lst <- c(.env$lst, .cur)
}
} else {
.f()
}
}
.f()
.w <- which(vapply(seq_along(.env$ret), function(i) {
!(length(.env$ret[[i]]) == 1L && is.na(.env$ret[[i]]))
}, logical(1), USE.NAMES=FALSE))
lapply(.w, function(i) { .env$ret[[i]]})
}

.nsEnv <- new.env(parent=emptyenv())

.nsEnv$.quoteCallInfoLinesAppend <- NULL
#' Returns quoted call information
Expand Down Expand Up @@ -329,8 +414,7 @@
}
.ret[[i]]
})

.ret[vapply(seq_along(.ret), function(i) {
.collapseLotriLineFormToPlusForm(.ret[vapply(seq_along(.ret), function(i) {
!is.null(.ret[[i]])
}, logical(1), USE.NAMES=FALSE)]
}, logical(1), USE.NAMES=FALSE)])
}
4 changes: 3 additions & 1 deletion R/ui.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@
#' @export
ini <- function(x, ..., envir = parent.frame(), append = NULL) {
if (is(substitute(x), "{")) {
.ini <- eval(bquote(lotri(.(substitute(x)))), envir=envir)
.ini <- eval(bquote(lotri::lotri(.(substitute(x)),
cov=TRUE, rcm=TRUE)),
envir=envir)
assignInMyNamespace(".lastIni", .ini)
assignInMyNamespace(".lastIniQ", bquote(.(substitute(x))))
return(invisible(.ini))
Expand Down
30 changes: 19 additions & 11 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,32 @@ reference:
- rxCbindStudyIndividual
- title: Functions for working with nlmixr2/rxode2 functions
contents:
- as.ini
- as.model
- as.rxUi
- assertCompartmentExists
- assertCompartmentName
- assertCompartmentNew
- assertRxUi
- assertVariableExists
- assertVariableNew
- ini
- ini<-
- model
- model<-
- modelExtract
- ini
- zeroRe
- assertRxUi
- rxAppendModel
- rxFixPop
- rxRename
- update.rxUi
- as.rxUi
- ini<-
- model<-
- rxode2<-
- rxSetCovariateNamesForPiping
- rxSetPipingAuto
- rxUiDecompress
- rxUiCompress
- as.ini
- as.model
- rxUiDecompress
- rxode2<-
- testRxLinCmt
- testRxUnbounded
- update.rxUi
- zeroRe
- title: ggplot2/plot support functions
contents:
- stat_cens
Expand Down Expand Up @@ -164,6 +171,7 @@ reference:
- .rxLinCmtGen
- .rxWithOptions
- .rxWithWd
- .rxode2ptrs
- .toClassicEvid
- .vecDf
- invWR1d
Expand Down
36 changes: 25 additions & 11 deletions src/par_solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#define isSameTimeOp(xout, xp) (op->stiff == 0 ? isSameTimeDop(xout, xp) : isSameTime(xout, xp))
// dop853 is same time

extern "C" uint32_t getRxSeed1(int ncores);
extern "C" void setSeedEng1(uint32_t seed);
extern "C" void setRxSeedFinal(uint32_t seed);

extern "C" {
#include "dop853.h"
#include "common.h"
Expand Down Expand Up @@ -2147,14 +2151,17 @@ extern "C" void par_indLin(rx_solve *rx){
// It was buggy due to Rprint. Use REprint instead since Rprint calls the interrupt every so often....
int abort = 0;
// FIXME parallel
uint32_t seed0 = getRxSeed1(1);
for (int solveid = 0; solveid < nsim*nsub; solveid++){
if (abort == 0){
setSeedEng1(seed0 + solveid - 1 );
ind_indLin(rx, solveid, update_inis, ME, IndF);
if (displayProgress){ // Can only abort if it is long enough to display progress.
curTick = par_progress(solveid, nsim*nsub, curTick, 1, t0, 0);
}
}
}
setRxSeedFinal(seed0 + nsim*nsub);
if (abort == 1){
op->abort = 1;
/* yp0 = NULL; */
Expand Down Expand Up @@ -2290,10 +2297,6 @@ extern "C" void ind_liblsoda0(rx_solve *rx, rx_solving_options *op, struct lsoda
ind->solveTime += ((double)(clock() - t0))/CLOCKS_PER_SEC;
}

extern "C" uint32_t getRxSeed1(int ncores);
extern "C" void setSeedEng1(uint32_t seed);
extern "C" void setRxSeedFinal(uint32_t seed);

extern "C" void ind_liblsoda(rx_solve *rx, int solveid,
t_dydt_liblsoda dydt, t_update_inis u_inis){
rx_solving_options *op = &op_global;
Expand Down Expand Up @@ -2422,11 +2425,13 @@ extern "C" void par_liblsoda(rx_solve *rx){
// http://permalink.gmane.org/gmane.comp.lang.r.devel/27627
// It was buggy due to Rprint. Use REprint instead since Rprint calls the interrupt every so often....
int abort = 0;
uint32_t seed0 = getRxSeed1(cores);
#ifdef _OPENMP
#pragma omp parallel for num_threads(op->cores)
#endif
for (int solveid = 0; solveid < nsim*nsub; solveid++){
if (abort == 0){
setSeedEng1(seed0 + rx->ordId[solveid] - 1);
ind_liblsoda0(rx, op, opt, solveid, dydt_liblsoda, update_inis);
if (displayProgress){
#pragma omp critical
Expand All @@ -2443,6 +2448,7 @@ extern "C" void par_liblsoda(rx_solve *rx){
}
}
}
setRxSeedFinal(seed0 + nsim*nsub);
if (abort == 1){
op->abort = 1;
/* yp0 = NULL; */
Expand Down Expand Up @@ -2705,17 +2711,22 @@ extern "C" void par_lsoda(rx_solve *rx){

int curTick = 0;
int abort = 0;
uint32_t seed0 = getRxSeed1(1);
for (int solveid = 0; solveid < nsim*nsub; solveid++){
ind_lsoda0(rx, &op_global, solveid, neq, rwork, lrw, iwork, liw, jt,
dydt_lsoda_dum, update_inis, jdum_lsoda);
if (displayProgress){ // Can only abort if it is long enough to display progress.
curTick = par_progress(solveid, nsim*nsub, curTick, 1, t0, 0);
if (checkInterrupt()){
abort =1;
break;
if (abort == 0){
setSeedEng1(seed0 + solveid - 1 );
ind_lsoda0(rx, &op_global, solveid, neq, rwork, lrw, iwork, liw, jt,
dydt_lsoda_dum, update_inis, jdum_lsoda);
if (displayProgress){ // Can only abort if it is long enough to display progress.
curTick = par_progress(solveid, nsim*nsub, curTick, 1, t0, 0);
if (checkInterrupt()){
abort =1;
break;
}
}
}
}
setRxSeedFinal(seed0 + nsim*nsub);
if (abort == 1){
op_global.abort = 1;
} else {
Expand Down Expand Up @@ -2932,15 +2943,18 @@ void par_dop(rx_solve *rx){

int curTick = 0;
int abort = 0;
uint32_t seed0 = getRxSeed1(1);
for (int solveid = 0; solveid < nsim*nsub; solveid++){
if (abort == 0){
setSeedEng1(seed0 + solveid - 1 );
ind_dop0(rx, &op_global, solveid, neq, dydt, update_inis);
if (displayProgress && abort == 0){
if (checkInterrupt()) abort =1;
}
if (displayProgress) curTick = par_progress(solveid, nsim*nsub, curTick, 1, t0, 0);
}
}
setRxSeedFinal(seed0 + nsim*nsub);
if (abort == 1){
op->abort = 1;
} else {
Expand Down
Loading

0 comments on commit 8efcfa4

Please sign in to comment.