Skip to content

Commit

Permalink
Merge pull request #295 from mrc-ide/mrc-4318
Browse files Browse the repository at this point in the history
Support for differentiation of parameters in DSL
  • Loading branch information
weshinsley authored Jun 29, 2023
2 parents 4f9f69e + 5c02f09 commit fdc2be1
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 17 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.5.3
Version: 1.5.4
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
74 changes: 58 additions & 16 deletions R/ir_parse.R
Original file line number Diff line number Diff line change
Expand Up @@ -479,38 +479,67 @@ ir_parse_packing_internal <- function(names, rank, len, variables,
## few different places. It might be worth trying to shift more of
## this classification into the initial equation parsing.
ir_parse_features <- function(eqs, debug, config, source) {
is_update <- vlapply(eqs, function(x) identical(x$lhs$special, "update"))
is_deriv <- vlapply(eqs, function(x) identical(x$lhs$special, "deriv"))
is_output <- vlapply(eqs, function(x) identical(x$lhs$special, "output"))
is_dim <- vlapply(eqs, function(x) identical(x$lhs$special, "dim"))
is_lhs_update <- vlapply(eqs, function(x) identical(x$lhs$special, "update"))
is_lhs_deriv <- vlapply(eqs, function(x) identical(x$lhs$special, "deriv"))
is_lhs_output <- vlapply(eqs, function(x) identical(x$lhs$special, "output"))
is_lhs_dim <- vlapply(eqs, function(x) identical(x$lhs$special, "dim"))
is_user <- vlapply(eqs, function(x) !is.null(x$user))
is_delay <- vlapply(eqs, function(x) !is.null(x$delay))
is_interpolate <- vlapply(eqs, function(x) !is.null(x$interpolate))
is_stochastic <- vlapply(eqs, function(x) isTRUE(x$stochastic))
is_data <- vlapply(eqs, function(x) !is.null(x$data))
is_compare <- vlapply(eqs, function(x) identical(x$lhs$special, "compare"))
is_lhs_compare <- vlapply(eqs,
function(x) identical(x$lhs$special, "compare"))
is_user_differentiate <- vlapply(eqs,
function(x) isTRUE(x$user$differentiate))

## We'll support other debugging bits later, I imagine.
is_debug_print <- vlapply(debug, function(x) x$type == "print")

if (!any(is_update | is_deriv)) {
if (!any(is_lhs_update | is_lhs_deriv)) {
ir_parse_error("Did not find a deriv() or an update() call",
NULL, NULL)
}

list(continuous = any(is_deriv),
discrete = any(is_update),
mixed = any(is_update) && any(is_deriv),
has_array = any(is_dim),
has_output = any(is_output),
continuous <- any(is_lhs_deriv)
has_compare <- any(is_lhs_compare)
has_array <- any(is_lhs_dim)
has_derivative <- any(is_user_differentiate)

## Most of these constraints go away later, might as well throw them
## early though; we could put it into a preliminary check for
## differentiability but in some ways thast just complicates things.
if (has_derivative) {
if (!has_compare) {
## (this one is fundamental; this just can't be done!
ir_parse_error("You need a compare expression to differentiate!",
ir_parse_error_lines(eqs[is_user_differentiate]), source)
}
if (continuous) {
ir_parse_error("Can't use differentiate with continuous time models",
ir_parse_error_lines(eqs[is_user_differentiate]), source)
}
if (has_array) {
ir_parse_error(
"Can't use differentiate with models that use arrays",
ir_parse_error_lines(eqs[is_user_differentiate | is_lhs_dim]), source)
}
}

list(continuous = continuous,
discrete = any(is_lhs_update),
mixed = any(is_lhs_update) && continuous,
has_array = has_array,
has_output = any(is_lhs_output),
has_user = any(is_user),
has_delay = any(is_delay),
has_interpolate = any(is_interpolate),
has_stochastic = any(is_stochastic),
has_data = any(is_data),
has_compare = any(is_compare),
has_compare = has_compare,
has_include = !is.null(config$include),
has_debug = any(is_debug_print),
has_derivative = has_derivative,
initial_time_dependent = NULL)
}

Expand Down Expand Up @@ -1040,7 +1069,9 @@ ir_parse_expr_rhs_user <- function(rhs, line, source) {
ir_parse_error("Only first argument to user() may be unnamed", line, source)
}

m <- match.call(function(default, integer, min, max, ...) NULL, rhs, FALSE)
m <- match.call(
function(default, integer, min, max, differentiate, ...) NULL,
rhs, FALSE)
extra <- m[["..."]]
if (!is.null(extra)) {
ir_parse_error(sprintf("Unknown %s to user(): %s",
Expand All @@ -1063,12 +1094,23 @@ ir_parse_expr_rhs_user <- function(rhs, line, source) {
if (length(deps$variables) > 0L) {
ir_parse_error("user() call must not reference variables", line, source)
}
## TODO: the 'dim' part here is not actually known yet!

integer <- m$integer %||% FALSE
differentiate <- m$differentiate %||% FALSE

if (differentiate && integer) {
ir_parse_error("Can't differentiate integer parameters",
line, source)
}

## NOTE: the 'dim' part here is not actually known yet!
user <- list(default = m$default,
dim = FALSE,
integer = m$integer %||% FALSE,
integer = integer,
min = m$min,
max = m$max)
max = m$max,
differentiate = differentiate)

list(user = user)
}

Expand Down
1 change: 1 addition & 0 deletions inst/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
"has_stochastic": { "type": "boolean" },
"has_include": { "type": "boolean" },
"has_debug": { "type": "boolean" },
"has_derivative": { "type": "boolean" },
"initial_time_dependent": { "type": "boolean" }
},
"required": ["discrete", "has_array", "has_output", "has_user",
Expand Down
64 changes: 64 additions & 0 deletions tests/testthat/test-parse2-differentiate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
test_that("Can parse with differentiable parameters", {
ir <- odin_parse({
initial(x) <- 1
update(x) <- rnorm(0, 0.1)
d <- data()
compare(d) ~ normal(0, scale)
scale <- user(differentiate = TRUE)
})

d <- ir_deserialise(ir)
expect_true(d$features$has_derivative)
})


test_that("can't differentiate integer parameters", {
expect_error(odin_parse({
initial(x) <- 1
update(x) <- rnorm(0, 0.1)
d <- data()
compare(d) ~ normal(x, scale)
scale <- user(differentiate = TRUE, integer = TRUE)
}),
"Can't differentiate integer parameters\\s+scale <-")
})


test_that("can't differentiate without compare", {
expect_error(
odin_parse({
initial(x) <- 1
update(x) <- rnorm(x, scale)
scale <- user(differentiate = TRUE)
}),
"You need a compare expression to differentiate!\\s+scale <-")
})


test_that("can't differentiate continuous time models", {
expect_error(
odin_parse({
initial(x) <- 1
deriv(x) <- 1
d <- data()
compare(d) ~ normal(x, scale)
scale <- user(differentiate = TRUE)
}),
"Can't use differentiate with continuous time models\\s+scale <-")
})


test_that("can't differentiate models with arrays", {
err <- expect_error(
odin_parse({
initial(x[]) <- 1
update(x[]) <- rnorm(x, 1)
dim(x) <- 5
d <- data()
compare(d) ~ normal(sum(x), scale)
scale <- user(differentiate = TRUE)
}),
"Can't use differentiate with models that use arrays")
expect_match(err$message, "dim(x) <-", fixed = TRUE)
expect_match(err$message, "scale <-", fixed = TRUE)
})

0 comments on commit fdc2be1

Please sign in to comment.