diff --git a/DESCRIPTION b/DESCRIPTION index def22610..06316c5f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: monty Title: Monte Carlo Models -Version: 0.3.19 +Version: 0.3.20 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Wes", "Hinsley", role = "aut"), diff --git a/R/progress.R b/R/progress.R index 3ebc567f..d5b017da 100644 --- a/R/progress.R +++ b/R/progress.R @@ -52,13 +52,13 @@ progress_bar_simple <- function(n_steps, every_s = 1, min_updates = 20) { progress_bar_fancy <- function(n_chains, n_steps, show_overall, single_chain = FALSE) { ## We're expecting to take a while, so we show immediately, if enabled: - oo <- options(cli.progress_show_after = 0) - on.exit(options(oo)) + oo <- options(cli.progress_show_after = 0, + cli.spinner = monty_spinner()) e <- new.env() e$n <- rep(0, n_chains) overall <- progress_overall(n_chains, n_steps, show_overall, single_chain) - fmt <- paste("Sampling {overall(e$n)} {cli::pb_bar} |", + fmt <- paste("{cli::pb_spin} Sampling {overall(e$n)} {cli::pb_bar} |", "{cli::pb_percent} ETA: {cli::pb_eta}") fmt_done <- paste( "{cli::col_green(cli::symbol$tick)} Sampled {cli::pb_total} steps", @@ -78,15 +78,16 @@ progress_bar_fancy <- function(n_chains, n_steps, show_overall, update <- function(chain_id, at) { ## Avoid writing into a closed progress bar, it will cause an ## error. We do this by checking to see if progress has changed - ## from last time we tried updating. - changed <- any(e$n[chain_id] != at, na.rm = TRUE) - if (changed) { + ## from last time we tried updating, or if we're simply + ## incomplete. + if (any(at < n_steps | at > e$n)) { e$n[chain_id] <- at cli::cli_progress_update(id = id, set = sum(e$n)) } } fail <- function() { + options(oo) cli::cli_progress_done(id, result = "failed") } @@ -155,3 +156,10 @@ with_progress_fail_on_error <- function(progress, code) { error = function(e) progress$fail(), interrupt = function(e) progress$fail()) } + + +monty_spinner <- function(date = Sys.Date()) { + getOption( + "cli.spinner", + if (format(date, "%m") == "12") "christmas" else "dots12") +} diff --git a/tests/testthat/test-progress.R b/tests/testthat/test-progress.R index b22ca716..b2901b84 100644 --- a/tests/testthat/test-progress.R +++ b/tests/testthat/test-progress.R @@ -150,3 +150,18 @@ test_that("can fail progress bar nicely", { expect_s3_class(res$result, "simpleError") expect_equal(conditionMessage(res$result), "some error") }) + + +test_that("can get default spinner", { + withr::with_options(list(cli.spinner = NULL), { + expect_equal(monty_spinner(), monty_spinner(Sys.Date())) + expect_equal(monty_spinner(as.Date("2025-03-01")), "dots12") + expect_equal(monty_spinner(as.Date("2025-12-05")), "christmas") + }) + + withr::with_options(list(cli.spinner = "dots"), { + expect_equal(monty_spinner(), "dots") + expect_equal(monty_spinner(as.Date("2025-03-01")), "dots") + expect_equal(monty_spinner(as.Date("2025-12-05")), "dots") + }) +})