Skip to content

Commit 2ea6f04

Browse files
committed
Add split-chain option to rank overlay plots
Related to #333
1 parent 20910f5 commit 2ea6f04

File tree

4 files changed

+111
-4
lines changed

4 files changed

+111
-4
lines changed

Diff for: R/mcmc-traces.R

+30-4
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
277277
#' of rank-normalized MCMC samples. Defaults to `20`.
278278
#' @param ref_line For the rank plots, whether to draw a horizontal line at the
279279
#' average number of ranks per bin. Defaults to `FALSE`.
280+
#' @param split_chains Logical indicating whether to split each chain into two parts.
281+
#' If `TRUE`, each chain is split into its first and second halves, which are labeled with "_1" and "_2" suffixes.
282+
#' Defaults to `FALSE`.
280283
#' @export
281284
mcmc_rank_overlay <- function(x,
282285
pars = character(),
@@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x,
285288
facet_args = list(),
286289
...,
287290
n_bins = 20,
288-
ref_line = FALSE) {
291+
ref_line = FALSE,
292+
split_chains = FALSE) {
289293
check_ignored_arguments(...)
290294
data <- mcmc_trace_data(
291295
x,
@@ -294,7 +298,26 @@ mcmc_rank_overlay <- function(x,
294298
transformations = transformations
295299
)
296300

297-
n_chains <- unique(data$n_chains)
301+
# Split chains if requested
302+
if (split_chains) {
303+
# Calculate a single midpoint from the number of unique iterations (shared by all chains)
304+
n_samples <- length(unique(data$iteration))
305+
midpoint <- n_samples/2
306+
307+
# Create new data frame with split chains
308+
data <- data %>%
309+
group_by(.data$chain) %>%
310+
mutate(
311+
chain = ifelse(
312+
iteration <= midpoint,
313+
paste0(.data$chain, "_1"),
314+
paste0(.data$chain, "_2")
315+
)
316+
) %>%
317+
ungroup()
318+
}
319+
320+
n_chains <- length(unique(data$chain))
298321
n_param <- unique(data$n_parameters)
299322

300323
# We have to bin and count the data ourselves because
@@ -319,6 +342,7 @@ mcmc_rank_overlay <- function(x,
319342
bin_start = unique(histobins$bin_start),
320343
stringsAsFactors = FALSE
321344
))
345+
322346
d_bin_counts <- all_combos %>%
323347
left_join(d_bin_counts, by = c("parameter", "chain", "bin_start")) %>%
324348
mutate(n = dplyr::if_else(is.na(n), 0L, n))
@@ -331,7 +355,9 @@ mcmc_rank_overlay <- function(x,
331355
mutate(bin_start = right_edge) %>%
332356
dplyr::bind_rows(d_bin_counts)
333357

334-
scale_color <- scale_color_manual("Chain", values = chain_colors(n_chains))
358+
# Update legend title based on split_chains
359+
legend_title <- if (split_chains) "Split Chains" else "Chain"
360+
scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chains))
335361

336362
layer_ref_line <- if (ref_line) {
337363
geom_hline(
@@ -352,7 +378,7 @@ mcmc_rank_overlay <- function(x,
352378
}
353379

354380
ggplot(d_bin_counts) +
355-
aes(x = .data$bin_start, y = .data$n, color = .data$chain) +
381+
aes(x = .data$bin_start, y = .data$n, color = .data$chain) +
356382
geom_step() +
357383
layer_ref_line +
358384
facet_call +
Loading

Diff for: tests/testthat/data-for-mcmc-tests.R

+7
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,11 @@ vdiff_dframe_rank_overlay_bins_test <- posterior::as_draws_df(
8080
)
8181
)
8282

83+
vdiff_dframe_rank_overlay_split_chain_test <- posterior::as_draws_df(
84+
list(
85+
list(theta = -2 + 0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5)),
86+
list(theta = 1 + -0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5))
87+
)
88+
)
89+
8390
set.seed(seed = NULL)

Diff for: tests/testthat/test-mcmc-traces.R

+7
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ test_that("mcmc_rank_overlay renders correctly", {
157157
# https://github.com/stan-dev/bayesplot/issues/331
158158
p_not_all_bins_exist <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_bins_test)
159159

160+
# https://github.com/stan-dev/bayesplot/issues/333
161+
p_split_chains <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_split_chain_test,
162+
split_chains = TRUE)
163+
160164
vdiffr::expect_doppelganger("mcmc_rank_overlay (default)", p_base)
161165
vdiffr::expect_doppelganger(
162166
"mcmc_rank_overlay (reference line)",
@@ -170,6 +174,9 @@ test_that("mcmc_rank_overlay renders correctly", {
170174

171175
# https://github.com/stan-dev/bayesplot/issues/331
172176
vdiffr::expect_doppelganger("mcmc_rank_overlay (not all bins)", p_not_all_bins_exist)
177+
178+
# https://github.com/stan-dev/bayesplot/issues/333
179+
vdiffr::expect_doppelganger("mcmc_rank_overlay (split chains)", p_split_chains)
173180
})
174181

175182
test_that("mcmc_rank_hist renders correctly", {

0 commit comments

Comments
 (0)