Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function for intra-leaf sampling #43

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export(forge)
export(impute)
export(lik)
export(rarf)
export(sample_from_leaves)
import(data.table)
import(ranger)
import(stringr)
Expand Down
32 changes: 1 addition & 31 deletions R/adversarial_rf.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,37 +153,7 @@ adversarial_rf <- function(
converged <- FALSE
while (!isTRUE(converged)) { # Adversarial loop begins...
# Create synthetic data by sampling from intra-leaf marginals
nodeIDs <- stats::predict(rf0, x_real, type = 'terminalNodes')$predictions
tmp <- data.table('tree' = rep(seq_len(num_trees), each = n),
'leaf' = as.integer(nodeIDs))
tmp2 <- tmp[sample(.N, n, replace = TRUE)]
tmp2 <- unique(tmp2[, cnt := .N, by = .(tree, leaf)])
draw_from <- rbindlist(lapply(seq_len(num_trees), function(b) {
x_real_b <- cbind(x_real, tmp[tree == b])
x_real_b[, factor_cols] <- lapply(x_real_b[, factor_cols, drop = FALSE], as.numeric)
merge(tmp2, x_real_b, by = c('tree', 'leaf'),
sort = FALSE)[, N := .N, by = .(tree, leaf)]
}))
rm(nodeIDs, tmp, tmp2)
draw_params_within <- unique(draw_from, by = c('tree','leaf'))[, .(cnt, N)]
adj_absolut_col <- rep(c(0, draw_params_within[-.N, cumsum(N)]),
times = draw_params_within$cnt)
adj_absolut <- rep(adj_absolut_col, d) + rep(seq(0, d - 1) * nrow(draw_from), each = n)
idx_drawn_within <- ceiling(runif(n * d, 0, rep(draw_params_within$N, draw_params_within$cnt)))
idx_drawn <- idx_drawn_within + adj_absolut
draw_from_stacked <- unlist(draw_from[, -c('tree', 'leaf', 'cnt', 'N')],
use.names = FALSE)
values_drawn_stacked <- data.table('col_id' = rep(seq_len(d), each = n),
'values' = draw_from_stacked[idx_drawn])
x_synth <- as.data.table(split(values_drawn_stacked, by = 'col_id', keep.by = FALSE))
setnames(x_synth, names(x_real))
if (any(factor_cols)) {
x_synth[, names(which(factor_cols))] <- lapply(names(which(factor_cols)), function(j) {
lvls[[j]][x_synth[[j]]]
})
}
rm(draw_from, draw_params_within, adj_absolut_col,
adj_absolut, idx_drawn_within, idx_drawn, draw_from_stacked)
x_synth <- sample_from_leaves(rf0, x_real, factor_cols, lvls, prep = FALSE)
# Concatenate real and synthetic data
dat <- rbind(data.frame(y = 1L, x_real),
data.frame(y = 0L, x_synth))
Expand Down
63 changes: 63 additions & 0 deletions R/sample_from_leaves.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@

#' Generate synthetic data by sampling from leaves of a random forest
#'
#' @param arf Random forest to sample from
#' @param x_real Data to sample from
#' @param factor_cols Logical vector indicating which columns are factors (optional)
#' @param lvls List of factor levels (optional)
#' @param prep Whether to prepare the data before sampling
#'
#' @returns A data.table of synthetic data
#' @export
#'
#' @examples
#' arf <- adversarial_rf(iris)
#' sample_from_leaves(arf, iris)
sample_from_leaves <- function(arf, x_real, factor_cols = NULL, lvls = NULL, prep = TRUE) {
# To avoid data.table check issues
i <- b <- cnt <- obs <- tree <- leaf <- N <- . <- NULL
# Prep data
if (prep) {
x_real <- prep_x(x_real, verbose = FALSE)
}
n <- nrow(x_real)
d <- ncol(x_real)
if (is.null(factor_cols)) {
factor_cols <- sapply(x_real, is.factor)
}
if (is.null(lvls)) {
lvls <- lapply(x_real[factor_cols], levels)
}
# Sample leaves and get values from other observations in the same leaf
nodeIDs <- stats::predict(arf, x_real, type = 'terminalNodes')$predictions
tmp <- data.table('tree' = rep(seq_len(arf$num.trees), each = n),
'leaf' = as.integer(nodeIDs))
tmp2 <- tmp[sample(.N, n, replace = TRUE)]
tmp2 <- unique(tmp2[, cnt := .N, by = .(tree, leaf)])
draw_from <- rbindlist(lapply(seq_len(arf$num.trees), function(b) {
x_real_b <- cbind(x_real, tmp[tree == b])
x_real_b[, factor_cols] <- lapply(x_real_b[, factor_cols, drop = FALSE], as.numeric)
merge(tmp2, x_real_b, by = c('tree', 'leaf'),
sort = FALSE)[, N := .N, by = .(tree, leaf)]
}))
# Draw new observations by sampling marginally from those leaves
draw_params_within <- unique(draw_from, by = c('tree','leaf'))[, .(cnt, N)]
adj_absolut_col <- rep(c(0, draw_params_within[-.N, cumsum(N)]),
times = draw_params_within$cnt)
adj_absolut <- rep(adj_absolut_col, d) + rep(seq(0, d - 1) * nrow(draw_from), each = n)
idx_drawn_within <- ceiling(runif(n * d, 0, rep(draw_params_within$N, draw_params_within$cnt)))
idx_drawn <- idx_drawn_within + adj_absolut
draw_from_stacked <- unlist(draw_from[, -c('tree', 'leaf', 'cnt', 'N')],
use.names = FALSE)
values_drawn_stacked <- data.table('col_id' = rep(seq_len(d), each = n),
'values' = draw_from_stacked[idx_drawn])
# Return synthetic data
x_synth <- as.data.table(split(values_drawn_stacked, by = 'col_id', keep.by = FALSE))
setnames(x_synth, names(x_real))
if (any(factor_cols)) {
x_synth[, names(which(factor_cols))] <- lapply(names(which(factor_cols)), function(j) {
lvls[[j]][x_synth[[j]]]
})
}
x_synth
}
29 changes: 29 additions & 0 deletions man/sample_from_leaves.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading