@@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) {
277
277
# ' of rank-normalized MCMC samples. Defaults to `20`.
278
278
# ' @param ref_line For the rank plots, whether to draw a horizontal line at the
279
279
# ' average number of ranks per bin. Defaults to `FALSE`.
280
+ # ' @param split_chains Logical indicating whether to split each chain into two parts.
281
+ # ' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes.
282
+ # ' Defaults to `FALSE`.
280
283
# ' @export
281
284
mcmc_rank_overlay <- function (x ,
282
285
pars = character (),
@@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x,
285
288
facet_args = list (),
286
289
... ,
287
290
n_bins = 20 ,
288
- ref_line = FALSE ) {
291
+ ref_line = FALSE ,
292
+ split_chains = FALSE ) {
289
293
check_ignored_arguments(... )
290
294
data <- mcmc_trace_data(
291
295
x ,
@@ -294,7 +298,26 @@ mcmc_rank_overlay <- function(x,
294
298
transformations = transformations
295
299
)
296
300
297
- n_chains <- unique(data $ n_chains )
301
+ # Split chains if requested
302
+ if (split_chains ) {
303
+ # Calculate midpoint for each chain
304
+ n_samples <- length(unique(data $ iteration ))
305
+ midpoint <- n_samples / 2
306
+
307
+ # Create new data frame with split chains
308
+ data <- data %> %
309
+ group_by(.data $ chain ) %> %
310
+ mutate(
311
+ chain = ifelse(
312
+ iteration < = midpoint ,
313
+ paste0(.data $ chain , " _1" ),
314
+ paste0(.data $ chain , " _2" )
315
+ )
316
+ ) %> %
317
+ ungroup()
318
+ }
319
+
320
+ n_chains <- length(unique(data $ chain ))
298
321
n_param <- unique(data $ n_parameters )
299
322
300
323
# We have to bin and count the data ourselves because
@@ -319,6 +342,7 @@ mcmc_rank_overlay <- function(x,
319
342
bin_start = unique(histobins $ bin_start ),
320
343
stringsAsFactors = FALSE
321
344
))
345
+
322
346
d_bin_counts <- all_combos %> %
323
347
left_join(d_bin_counts , by = c(" parameter" , " chain" , " bin_start" )) %> %
324
348
mutate(n = dplyr :: if_else(is.na(n ), 0L , n ))
@@ -331,7 +355,9 @@ mcmc_rank_overlay <- function(x,
331
355
mutate(bin_start = right_edge ) %> %
332
356
dplyr :: bind_rows(d_bin_counts )
333
357
334
- scale_color <- scale_color_manual(" Chain" , values = chain_colors(n_chains ))
358
+ # Update legend title based on split_chains
359
+ legend_title <- if (split_chains ) " Split Chains" else " Chain"
360
+ scale_color <- scale_color_manual(legend_title , values = chain_colors(n_chains ))
335
361
336
362
layer_ref_line <- if (ref_line ) {
337
363
geom_hline(
@@ -352,7 +378,7 @@ mcmc_rank_overlay <- function(x,
352
378
}
353
379
354
380
ggplot(d_bin_counts ) +
355
- aes(x = .data $ bin_start , y = .data $ n , color = .data $ chain ) +
381
+ aes(x = .data $ bin_start , y = .data $ n , color = .data $ chain ) +
356
382
geom_step() +
357
383
layer_ref_line +
358
384
facet_call +
0 commit comments