Skip to content

Commit e9025a1

Browse files
committed
Add split-chain option to rank ecdf plots
Related to #333
1 parent 2ea6f04 commit e9025a1

File tree

4 files changed

+109
-5
lines changed

4 files changed

+109
-5
lines changed

Diff for: R/mcmc-traces.R

+29-3
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,9 @@ mcmc_rank_hist <- function(x,
483483
#' @param plot_diff For `mcmc_rank_ecdf()`, a boolean specifying if the
484484
#' difference between the observed rank ECDFs and the theoretical expectation
485485
#' should be drawn instead of the unmodified rank ECDF plots.
486+
#' @param split_chains Logical indicating whether to split each chain into two parts.
487+
#' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes.
488+
#' Defaults to `FALSE`.
486489
#' @export
487490
mcmc_rank_ecdf <-
488491
function(x,
@@ -494,7 +497,8 @@ mcmc_rank_ecdf <-
494497
facet_args = list(),
495498
prob = 0.99,
496499
plot_diff = FALSE,
497-
interpolate_adj = NULL) {
500+
interpolate_adj = NULL,
501+
split_chains = FALSE) {
498502
check_ignored_arguments(...,
499503
ok_args = c("K", "pit", "prob", "plot_diff", "interpolate_adj", "M")
500504
)
@@ -505,8 +509,28 @@ mcmc_rank_ecdf <-
505509
transformations = transformations,
506510
highlight = 1
507511
)
512+
513+
# Split chains if requested
514+
if (split_chains) {
515+
data$n_chains = data$n_chains/2
516+
data$n_iterations = data$n_iterations/2
517+
n_samples <- length(unique(data$iteration))
518+
midpoint <- n_samples/2
519+
520+
data <- data %>%
521+
group_by(.data$chain) %>%
522+
mutate(
523+
chain = ifelse(
524+
iteration <= midpoint,
525+
paste0(.data$chain, "_1"),
526+
paste0(.data$chain, "_2")
527+
)
528+
) %>%
529+
ungroup()
530+
}
531+
508532
n_iter <- unique(data$n_iterations)
509-
n_chain <- unique(data$n_chains)
533+
n_chain <- length(unique(data$chain))
510534
n_param <- unique(data$n_parameters)
511535

512536
x <- if (is.null(K)) {
@@ -559,7 +583,9 @@ mcmc_rank_ecdf <-
559583
group = .data$chain
560584
)
561585

562-
scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain))
586+
# Update legend title based on split_chains
587+
legend_title <- if (split_chains) "Split Chains" else "Chain"
588+
scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chain))
563589

564590
facet_call <- NULL
565591
if (n_param == 1) {

0 commit comments

Comments
 (0)