Skip to content

Commit

Permalink
Merge pull request #133 from mrc-ide/mrc-6133
Browse files Browse the repository at this point in the history
Support for differentiation through arrays
  • Loading branch information
weshinsley authored Jan 10, 2025
2 parents 7dc0478 + 7e68bb6 commit 47b9656
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 1 deletion.
138 changes: 137 additions & 1 deletion R/dsl-differentiate-expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,28 @@ derivative <- list(
`(` = function(expr, name) {
differentiate(expr[[2]], name)
},
`[` = function(expr, name) {
target <- as.character(expr[[2]])
if (target != name) {
return(0)
}
index <- as.list(expr[-(1:2)])

## Assume odin's indexes for now
idx <- c("i", "j", "k", "l", "i5", "i6", "i7", "i8")[seq_along(index)]
i <- Map(maths$is_same, index, lapply(idx, as.name))

if (any(vlapply(i, isFALSE))) {
return(0)
}
j <- vlapply(i, isTRUE)
if (all(j)) {
return(1)
}

## Have to resort to some actual calculation here, sadly:
call("if", maths$fold("&&", i[!j]), 1, 0)
},
exp = function(expr, name) {
a <- maths$rewrite(expr[[2]])
maths$times(differentiate(a, name), call("exp", a))
Expand Down Expand Up @@ -244,6 +266,49 @@ derivative <- list(
b <- differentiate(call("lfactorial", k), name)
c <- differentiate(call("lfactorial", maths$minus(n, k)), name)
maths$minus(maths$minus(a, b), c)
},
sum = function(expr, name) {
target <- expr[[2]]
if (is.symbol(target)) {
return(if (as.character(target) == name) 1 else 0)
}
stopifnot(rlang::is_call(target, "["))

index <- as.list(target[-(1:2)])
target <- target[[2]]
if (as.character(target) != name) {
return(0)
}

## Assuming that this is all "reasonable" following odin:
index_full <- vlapply(index, rlang::is_missing)
index_range <- vlapply(index, function(i) ":" %in% all.names(i))
index_at <- !(index_full | index_range)

idx <- c("i", "j", "k", "l", "i5", "i6", "i7", "i8")[seq_along(index)]

i <- Map(maths$is_same, index[index_at], lapply(idx[index_at], as.name))
if (any(vlapply(i, isFALSE))) {
return(0)
}

if (any(index_range)) {
i <- c(
i,
Map(function(a, b) call(">=", a, b),
lapply(idx[index_range], as.name),
lapply(index[index_range], function(e) e[[2]])),
Map(function(a, b) call("<=", a, b),
lapply(idx[index_range], as.name),
lapply(index[index_range], function(e) e[[3]])))
}

j <- vlapply(i, isTRUE)
if (all(j)) {
return(1)
}

call("if", maths$fold("&&", i[!j]), 1, 0)
}
)

Expand All @@ -255,9 +320,10 @@ maths <- local({
return(.parentheses_except(x[[2]], except))
}
pass <- grepl("^[A-Za-z]", fn) ||
fn == "[" ||
(length(except) > 0 && fn %in% except) ||
"unary_minus" %in% except && .is_unary_minus(x)
if (pass) {
if (fn != "if" && pass) {
return(x)
}
call("(", x)
Expand Down Expand Up @@ -422,6 +488,76 @@ maths <- local({
ret
}
}
fold <- function(fn, x) {
stopifnot(length(x) > 0)
if (length(x) == 1) {
x[[1]]
} else {
ret <- x[[1]]
for (el in x[-1]) {
ret <- call(fn, ret, el)
}
ret
}
}
as_sum_of_parts <- function(expr) {
if (rlang::is_call(expr, c("-", "+"), 2)) {
if (rlang::is_call(expr, "-")) {
parts <- lapply(expr[-1], as_sum_of_parts)
uminus <- monty::monty_differentiation()$uminus
parts[[2]] <- lapply(parts[[2]], uminus)
unlist(parts, FALSE)
} else {
unlist(lapply(expr[-1], as_sum_of_parts), FALSE)
}
} else {
list(expr)
}
}
factorise_parts <- function(parts) {
f <- function(el) {
if (is.numeric(el)) {
list(el, 1, "")
} else if (rlang::is_call(el, "-", 1)) {
ret <- f(el[[2]])
ret[[1]] <- -1 * ret[[1]]
ret
} else {
list(1, el, rlang::hash(el))
}
}
parts <- lapply(parts, f)
id <- vcapply(parts, "[[", 3)
ret <- lapply(unname(split(parts, id)), function(el) {
n <- sum(vnapply(el, "[[", 1))
times(n, el[[1]][[2]])
})
plus_fold(ret)
}
factorise <- function(x) {
factorise_parts(as_sum_of_parts(x))
}
is_same <- function(a, b) {
if (is.numeric(a) && is.numeric(b)) {
return(a == b)
}
if (identical(a, b)) {
return(TRUE)
}
if (!is.recursive(a) && !is.recursive(b)) {
return(call("==", a, b))
}

a_parts <- as_sum_of_parts(a)
b_parts <- lapply(as_sum_of_parts(b), uminus)
ab <- factorise_parts(c(a_parts, b_parts))

if (is.numeric(ab)) {
return(ab == 0)
}

call("==", ab, 0)
}
rewrite <- function(expr) {
if (is.recursive(expr)) {
fn <- as.character(expr[[1]])
Expand Down
71 changes: 71 additions & 0 deletions tests/testthat/test-dsl-differentiation.R
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,74 @@ test_that("can diferentiate basic trig functions", {
differentiate(quote(tan(x)), "x"),
quote(1 / cos(x)^2))
})


test_that("differentiate expressions with arrays", {
expect_identical(differentiate(quote(x[i] + y[i]), "x"), 1)
expect_identical(differentiate(quote(x[i] + y[i]), "z"), 0)
expect_identical(differentiate(quote(x[i]^2), "x"), quote(2 * x[i]))

expect_identical(differentiate(quote((x[i] - x[i + 1])^2), "x"),
quote(2 * (x[i] - x[1 + i])))
expect_identical(differentiate(quote(x[i] - x[i + 1]), "x"), 1)
expect_identical(differentiate(quote(3 * (x[i] - x[2])), "x"),
quote(3 * (1 - (if (2 == i) 1 else 0))))

expect_identical(differentiate(quote(x[i] + x[j]), "x"),
quote(1 + if (j == i) 1 else 0))
})


test_that("differentiate complete sums with arrays", {
expect_identical(differentiate(quote(sum(x)), "x"), 1)
expect_identical(differentiate(quote(sum(x)), "y"), 0)
})


test_that("differentiate partial sums with arrays", {
expect_equal(differentiate(quote(sum(x[i, ])), "x"), 1)
expect_equal(differentiate(quote(sum(x[, i])), "x"),
quote(if (i == j) 1 else 0))
expect_equal(differentiate(quote(sum(x[i, , 3])), "x"),
quote(if (3 == k) 1 else 0))
expect_equal(differentiate(quote(sum(x[i, , a:b])), "x"),
quote(if (k >= a && k <= b) 1 else 0))
expect_equal(differentiate(quote(sum(x[, i, a:b])), "x"),
quote(if (i == j && k >= a && k <= b) 1 else 0))
})


test_that("test sameness", {
expect_true(maths$is_same(1, 1))
expect_false(maths$is_same(1, 0))
expect_true(maths$is_same(quote(i), quote(i)))
expect_true(maths$is_same(quote(j), quote(j)))
expect_equal(maths$is_same(quote(i), quote(j)), quote(i == j))
expect_equal(maths$is_same(quote(i), quote(2)), quote(i == 2))
expect_false(maths$is_same(quote(i), quote(i + 1)))
})


test_that("decompose an expression into sum of parts", {
expect_equal(maths$as_sum_of_parts(1), list(1))
expect_equal(maths$as_sum_of_parts(quote(x)), list(quote(x)))
expect_equal(maths$as_sum_of_parts(quote(x + y)), list(quote(x), quote(y)))
expect_equal(maths$as_sum_of_parts(quote(x + y + z)),
list(quote(x), quote(y), quote(z)))
expect_equal(maths$as_sum_of_parts(quote(x + 2 * y + z)),
list(quote(x), quote(2 * y), quote(z)))
expect_equal(maths$as_sum_of_parts(quote(x - y)), list(quote(x), quote(-y)))
expect_equal(maths$as_sum_of_parts(quote(x - y - z)),
list(quote(x), quote(-y), quote(-z)))
})


## Lots that this does not cover yet, it's limited to support what
## tends to happen in odin index calculations which are necessarily
## simple
test_that("factorise an expression", {
expect_equal(maths$factorise(quote(1)), quote(1))
expect_equal(maths$factorise(quote(a)), quote(a))
expect_equal(maths$factorise(quote(a + a)), quote(2 * a))
expect_equal(maths$factorise(quote(1 + 2)), quote(3))
})

0 comments on commit 47b9656

Please sign in to comment.