Skip to content

Commit

Permalink
Overhauled counterfactuals design and code
Browse files Browse the repository at this point in the history
  • Loading branch information
rvlenth committed Dec 29, 2024
1 parent 15f129c commit 254488e
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 79 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: emmeans
Type: Package
Title: Estimated Marginal Means, aka Least-Squares Means
Version: 1.10.6-090001
Date: 2024-12-18
Version: 1.10.6-090002
Date: 2024-12-28
Authors@R: c(person("Russell V.", "Lenth", role = c("aut", "cre", "cph"),
email = "russell-lenth@uiowa.edu"),
person("Balazs", "Banfai", role = "ctb"),
Expand Down
10 changes: 10 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ title: "NEWS for the emmeans package"

## emmeans 1.10-6-090xxx
* Spelling changes in several vignettes
* We have completely revamped the design of reference grids involving
counterfactuals. Now, if we specify counterfactuals `A` and `B`, the
reference grid comprises combinations of `A`, `B`, `actual_A`, and `actual_B`
the latter two used to track the original settings of `A` and `B` in the dataset.
We always average over combinations of these factors. The previous code was
a memory hog, and we have made it much more efficient for large datasets.
* `emmeans()` has also been revised to do special handling of counterfactual
reference grids. Whenever we average over a counterfactual `B`, we only
use the cases where `B == actual_B`, thus obtaining the same results as
would be obtained when `B` is not regarded as a counterfactual.


## emmeans 1.10.6
Expand Down
5 changes: 5 additions & 0 deletions R/emmGrid-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ str.emmGrid <- function(object, ...) {
cat(paste(tmp, collapse = ", "))
cat("\n")
}
if (length(cf <- object@roles$counterfactuals) > 0) {
cat("Counterfactuals\n ")
cat(paste(cf, collapse = ", "))
cat("\n")
}
if(!is.null(object@model.info$nesting)) {
cat("Nesting structure: ")
cat(.fmt.nest(object@model.info$nesting))
Expand Down
46 changes: 44 additions & 2 deletions R/emmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,20 @@ emmeans.list = function(object, specs, ...) {
#' \code{"cells"}, except nonempty cells are weighted equally and empty cells
#' are ignored.
#'
#' @section Counterfactuals:
#' Counterfactual reference grids (see the documentation for \code{\link{ref_grid}})
#' contain pairs of imputed and actual factor levels, and are handled in a special way.
#' For starters, the \code{weights} argument is ignored and we always use
#' \code{"cells"} weights.
#' Our understanding is that if factors \code{A, B} are specified as counterfactuals,
#' the marginal means for \code{A} should still be the same as if \code{A} were the only
#' counterfactual. Accordingly, in computing these marginal means, we
#' exclude all cases where \code{B != actual_B}, because if \code{A} were the only
#' counterfactual, \code{B} will stay at its actual level.
#' We also take special pains to "remember" information about actual and
#' imputed levels of counterfactuals so that appropriate results are obtained when
#' \code{emmeans} is applied to a previous \code{emmeans} result.
#'
#'
#' @section Offsets:
#' Unlike in \code{ref_grid}, an offset need not be scalar. If not enough values
Expand Down Expand Up @@ -315,6 +329,32 @@ emmeans = function(object, specs, by = NULL,
warning("emmeans() results may be corrupted by removal of a nesting structure")
}

## Handle counterfactuals...
# cf.grid is either NULL, logical, or a "parent" grid that should replace RG
if(!is.null(cf.grid <- RG@misc$cf.grid)) {
weights = "cells" # always use cells weights with counterfactuals
if (inherits(cf.grid, "emmGrid"))
RG = cf.grid
# which counterfactuals do we average over?
cf.ao = intersect(setdiff(names(RG@levels), facs), RG@roles$counterfactuals)
# zero-out weights for any cases where cf.ao levels differ from actual levels
for(f in cf.ao) {
fc = paste0("actual_", f)
excl = (RG@grid[ , f] != RG@grid[, fc])
RG@grid[excl, ".wgt."] = 0
}
# Fix up the returned grid to play along
cf.surv = intersect(facs, RG@roles$counterfactuals)
if(length(cf.surv) == 0) # no cfs remain
RG@misc$cf.grid = RG@roles$counterfactuals = NULL
else {
if(length(intersect(facs, paste0("actual_", cf.surv))) == length(cf.surv))
RG@misc$cf.grid = TRUE
else
RG@misc$cf.grid = RG # save this as "parent" grid
}
}

# Ensure object is in standard order
ord = .std.order(RG@grid, RG@levels) ###do.call(order, unname(RG@grid[rev(names(RG@levels))]))
if(any(ord != seq_along(ord)))
Expand Down Expand Up @@ -445,8 +485,10 @@ emmeans = function(object, specs, by = NULL,
RG@misc$by.vars = by
RG@misc$avgd.over = union(RG@misc$avgd.over, avgd.over)
RG@misc$methDesc = "emmeans"
RG@roles$predictors = names(levs)
### Pass up 'new' as we're not changing its class result = new("emmGrid", RG, linfct = linfct, levels = levs, grid = combs)
RG@roles$predictors = setdiff(names(levs), RG@roles$multresp)
if ((length(RG@roles$multresp) > 0) && !(RG@roles$multresp %in% names(levs)))
RG@roles$multresp = character(0)

result = as.emmGrid(RG)
result@linfct = linfct
result@levels = levs
Expand Down
215 changes: 163 additions & 52 deletions R/ref-grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,13 @@
#' \code{sigma(object)}, if available, and \code{NULL} otherwise.
#' Note: This applies only when the family is \code{"gaussian"}; for other families,
#' \code{sigma} is set to \code{NA} and cannot be overridden.
#' @param counterfactuals,wt.counter,avg.counter \code{counterfactuals} specifies character
#' @param counterfactuals \code{counterfactuals} specifies character
#' names of counterfactual factors. If this is non-missing, a reference grid
#' is created consisting of combinations of counterfactual levels and a constructed
#' factor \code{.obs.no.} having a level for each observation in the dataset.
#' By default, this grid is re-gridded with the response transformation
#' and averaged over \code{.obs.no.} (by default, with equal weights, but
#' a vector of weights may be specified in \code{wt.counter}; it must be
#' of length equal to the number of observations in the dataset).
#' If \code{avg.counter} is set to \code{FALSE}, this averaging is disabled.
#' See the section below on counterfactuals.
#' is created consisting of combinations of counterfactual levels
#' and the actual levels of those same factors.
#' This grid is always converted to the response transformation scale
#' and averaged over the actual factor levels. See the section below
#' on counterfactuals.
#' @param nuisance,non.nuisance,wt.nuis If \code{nuisance} is a vector of predictor names,
#' those predictors are omitted from the reference grid. Instead, the result
#' will be as if we had averaged over the levels of those factors, with either
Expand Down Expand Up @@ -318,23 +315,29 @@
#'
#' @section Counterfactuals:
#' If \code{counterfactuals} is specified, the rows of the entire dataset
#' become a factor in the reference grid, and the other reference levels are
#' become part of the reference grid, and the other reference levels are
#' confined to those named in \code{counterfactuals}. In this type of analysis
#' (called G-computation), we substitute each combination of counterfactual
#' (called G-computation), we substitute (or impute) each combination of counterfactual
#' levels into the entire dataset. Thus, predictions from this grid are those
#' of each observation under each of the counterfactual levels. For this to
#' make sense, we require an assumption of exchangeability of these levels.
#'
#' By default, this grid is converted to the response scale (unless otherwise
#' specified in \code{regrid}) and averaged over the observations in the dataset.
#' Averaging can be disabled by setting \code{avg.counter = FALSE}, but
#' be warned that the resulting reference grid is potentially huge -- the
#' number of observations in the dataset times the number of counterfactual
#' combinations, times the number of multivariate levels.
#' This grid is always converted to the response scale, as G-computation on
#' the linear-predictor scale produces the same results as ordinary weighted EMMs.
#' If we have counterfactual factors \code{A, B}, the reference grid also includes
#' factors \code{actual_A, actual_B} which are used to track which observations
#' originally had the \code{A, B} levels before they were changed by the
#' counterfactuals code. We average the response-scale predictions for each
#' combination of actual levels and imputed levels (and multivariate levels,
#' if any). See additional discussion of how \code{\link{emmeans}} handles
#' counterfactuals under that documentation.
#'
#' The counterfactuals code is still fairly rudimentary and we can't guarantee
#' it will always work, such as in cases of nested models. Sometimes, an error
#' can be averted by specifying \code{avg.counter = FALSE}.
#' Currently, counterfactuals are not supported when the reference grid
#' requires post-processing (e.g., ordinal models with \code{mode = "prob"}).
#' Cases where we have nested factor levels can be complicated if mixed-in with counterfactuals,
#' and we make no guarantees.
#' Note that past implementations included arguments \code{wt.counter} and \code{avg.counter},
#' which are now deprecated and are just ignored if specified.
#'
#' @section Optional side effect: If the \code{save.ref_grid} option is set to
#' \code{TRUE} (see \code{\link{emm_options}}),
Expand Down Expand Up @@ -440,10 +443,16 @@ ref_grid <- function(object, at, cov.reduce = mean, cov.keep = get_emm_option("c
mult.names, mult.levs,
options = get_emm_option("ref_grid"), data, df, type,
regrid, nesting, offset, sigma,
counterfactuals, wt.counter, avg.counter = TRUE,
counterfactuals, ## wt.counter, avg.counter = TRUE,
nuisance = character(0), non.nuisance, wt.nuis = "equal",
rg.limit = get_emm_option("rg.limit"), ...)
{
if(!missing(counterfactuals)) { # route this to a different routine
cl = match.call()
cl[[1]] = as.name(".cf.refgrid") # internal function for counterfactuals
return(eval(cl))
}

# hack to ignore 'tran' in dots arguments and interpret 'transform' as `regrid` :
.foo = function(t,tr,tra,tran, transform = NULL, ...) transform
.bar = .foo(...)
Expand Down Expand Up @@ -598,17 +607,17 @@ ref_grid <- function(object, at, cov.reduce = mean, cov.keep = get_emm_option("c

# Now create the reference grid
if(no.nuis <- (length(nuisance) == 0)) {
if (!missing(counterfactuals)) {
cfac = intersect(counterfactuals, names(ref.levels))
ref.levels = ref.levels[cfac]
ref.levels$.obs.no. = seq_len(nrow(data))
.check.grid(ref.levels, rg.limit)
grid = .setup.cf(ref.levels, data)
}
else {
.check.grid(ref.levels, rg.limit)
grid = do.call(expand.grid, ref.levels)
}
# if (!missing(counterfactuals)) {
# cfac = intersect(counterfactuals, names(ref.levels))
# ref.levels = ref.levels[cfac]
# ref.levels$.obs.no. = seq_len(nrow(data))
# .check.grid(ref.levels, rg.limit)
# grid = .setup.cf(ref.levels, data)
# }
# ## else {
.check.grid(ref.levels, rg.limit)
grid = do.call(expand.grid, ref.levels)
##}
}
else {
nuis.info = .setup.nuis(nuisance, ref.levels, trms, rg.limit)
Expand Down Expand Up @@ -671,19 +680,20 @@ ref_grid <- function(object, at, cov.reduce = mean, cov.keep = get_emm_option("c

# we've added args `misc` and `options` so emm_basis methods can access and use these if they want
basis = emm_basis(object, trms, xl, grid, misc = attr(data, "misc"), options = options, ...)

environment(basis$dffun) = baseenv() # releases unnecessary storage
if(length(basis$bhat) != ncol(basis$X))
stop("Something went wrong:\n",
" Non-conformable elements in reference grid.",
call. = TRUE)

collapse = NULL
if (!missing(counterfactuals)) {
grid = do.call(expand.grid, ref.levels)
if (missing(regrid))
regrid = "response"
if (avg.counter) collapse = ".obs.no."
}
# if (!missing(counterfactuals)) {
# grid = do.call(expand.grid, ref.levels)
# if (missing(regrid))
# regrid = "response"
# if (avg.counter) collapse = ".obs.no."
# }

if(!no.nuis) {
basis = .basis.nuis(basis, nuis.info, wt.nuis, ref.levels, data, grid, ref.levels)
Expand Down Expand Up @@ -904,7 +914,7 @@ ref_grid <- function(object, at, cov.reduce = mean, cov.keep = get_emm_option("c
post.beta = matrix(NA)

predictors = intersect(attr(data, "predictors"), names(grid))
if(!missing(counterfactuals)) predictors = c(predictors, ".obs.no.")
# if(!missing(counterfactuals)) predictors = c(predictors, ".obs.no.")

simp.tbl = environment(trms)$.simplify.names.
if (! is.null(simp.tbl)) {
Expand Down Expand Up @@ -962,7 +972,7 @@ ref_grid <- function(object, at, cov.reduce = mean, cov.keep = get_emm_option("c
result = hook(result, ...)
}
if(!missing(regrid)) {
if(missing(wt.counter)) wt.counter = 1
# if(missing(wt.counter)) wt.counter = 1
result = regrid(result, transform = regrid, sigma = sigma,
.collapse = collapse, wt.counter = wt.counter, ...)
if(!is.null(collapse))
Expand Down Expand Up @@ -1145,18 +1155,119 @@ ref_grid <- function(object, at, cov.reduce = mean, cov.keep = get_emm_option("c
basis
}

# Internal function to do reference grid for counterfactuals
.cf.refgrid = function(object, counterfactuals, data, options = list(), ...) {
if(missing(data))
data = recover_data(object, ...)
# Start with just the ordinary reference grid
rg = ref_grid(object, data = data, ...)
cfac = intersect(counterfactuals, names(rg@levels))
clevs = rg@levels[cfac]
cgrid = do.call(expand.grid, clevs)

# Get the stuff we need for each main dataset step
link = .get.link(rg@misc)
if(is.null(link))
link = make.link("identity")
trms = rg@model.info$terms
xlev = rg@model.info$xlev
misc = list()
offset = .get.offset(object, grid = data)
k = ifelse(length(mr <- rg@roles$multresp) == 0, 1, length(rg@levels[[mr]])) # grid expansion factor

# Index sets for combinations of factors
cidx = apply(cgrid, 1, function(x) {
flag = data[[cfac[1]]] == x[1]
if(length(x) > 1)
for (col in 2:length(x))
flag = flag & data[[cfac[col]]] == x[col]
which(flag)
}, simplify = FALSE)
# account for any NAs in bhat
nonNA = !is.na(rg@bhat)
# ensure we include all levels of cfacs with data
all.active = sort(unlist(cidx))
deadrows = sapply(cidx, function(x) x[1])
offset = c(offset, rep(mean(offset), length(deadrows)))
data = rbind(data, data[deadrows, ])
n = nrow(data)
mymean = function(x) ifelse(is.null(x), NA, mean(x))

## Compile the averaged results for delta method
DL = matrix(nrow = 0, ncol = sum(nonNA))
bh = numeric(0)
for (i in seq_len(nrow(cgrid))) {
g = data
for(j in cfac)
g[all.active, j] = cgrid[i, j]
bas = emm_basis(object, trms, xlev, g, ...)
if(!is.null(bas$misc$postGridHook))
stop("Sorry, we do not support counterfactuals for this situation.")
X = bas$X[, nonNA, drop = FALSE]
eta = X %*% bas$bhat[nonNA]
yhat = link$linkinv(eta + offset)
d = link$mu.eta(eta)
X = sweep(X, 1, d, "*")

pos = 0
XX = matrix(nrow = 0, ncol = ncol(X))
for(I in 1:k) {
XX = sapply(cidx, \(i) apply(X[pos + i, , drop = FALSE], 2, mymean))
DL = rbind(DL, t(XX))
yy = sapply(cidx, \(i) ifelse(length(i) == 0, NA, mean(yhat[i + pos])))
bh = c(bh, yy)
pos = pos + n
}
}
RG = rg
RG@bhat = bh
nonNA = !is.na(bh)
RG@linfct = diag(nrow(DL))
RG@V = DL %*% rg@V %*% t(DL)
levs = rg@levels
levs[cfac] = NULL
alevs = clevs
names(alevs) = paste0("actual_", cfac)
levs = c(alevs, clevs)
if (k > 1)
levs = c(levs, rg@levels[length(rg@levels)])
RG@levels = levs
wts = sapply(cidx, length)
RG@grid = do.call("expand.grid", levs)
RG@grid$.wgt. = rep(wts, length(bh)/length(wts))
misc = rg@misc
if(!is.null(misc$inv.lbl))
misc$estName = misc$inv.lbl
misc[c("tran", "inv.lbl", "sigma")] = NULL
RG@misc = c(misc, famSize = length(bh), cf.grid = TRUE)
RG@model.info$model.matrix = NULL
RG@roles$predictors = c(names(alevs), names(clevs))
RG@roles$counterfactuals = cfac
if (all(nonNA))
RG@nbasis = estimability::all.estble
else {
RG@nbasis = matrix(0, ncol = sum(!nonNA), nrow = length(bh))
idx = which(!nonNA)
for (j in seq_len(ncol(RG@nbasis)))
RG@nbasis[idx[j], j] = 1
RG@V = RG@V[nonNA, nonNA]
}
RG
}


## OLD CODE...
# Set up grid for counterfactuals - i.e., copies of whole dataset with
# counterfactual levels substituted, with obs index varying the slowest
.setup.cf = function(levs, data) {
lv = arg = levs[-length(levs)] # omit .obs.no.
arg$stringsAsFactors = FALSE
g = do.call(expand.grid, arg)
idx = rep(seq_len(nrow(data)), each = nrow(g))
xdata = data[idx, , drop = FALSE]
idx = matrix(seq_len(nrow(xdata)), nrow = nrow(g))
for (i in seq_along(g[, 1]))
for (j in names(lv))
xdata[idx[i, ], j] = g[i, j]
xdata
}
# .setup.cf = function(levs, data) {
# lv = arg = levs[-length(levs)] # omit .obs.no.
# arg$stringsAsFactors = FALSE
# g = do.call(expand.grid, arg)
# idx = rep(seq_len(nrow(data)), each = nrow(g))
# xdata = data[idx, , drop = FALSE]
# idx = matrix(seq_len(nrow(xdata)), nrow = nrow(g))
# for (i in seq_along(g[, 1]))
# for (j in names(lv))
# xdata[idx[i, ], j] = g[i, j]
# xdata
# }
Loading

0 comments on commit 254488e

Please sign in to comment.