Skip to content

Commit

Permalink
Merge pull request #20 from bips-hb/conditional_sampling_jk
Browse files Browse the repository at this point in the history
Faster conditional sampling for multiple rows
  • Loading branch information
mnwright authored May 24, 2024
2 parents 678a1ad + fb5822d commit 9ff0f5c
Show file tree
Hide file tree
Showing 21 changed files with 898 additions and 452 deletions.
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: arf
Title: Adversarial Random Forests
Version: 0.2.0
Date: 2024-01-24
Version: 0.2.2
Date: 2024-05-24
Authors@R:
c(person(given = "Marvin N.",
family = "Wright",
Expand Down Expand Up @@ -41,10 +41,11 @@ Imports:
data.table,
ranger,
foreach,
stringr,
truncnorm
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.0
RoxygenNote: 7.3.1
Suggests:
ggplot2,
doParallel,
Expand Down
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@ export(forge)
export(lik)
import(data.table)
import(ranger)
import(stringr)
importFrom(foreach,"%do%")
importFrom(foreach,"%dopar%")
importFrom(foreach,foreach)
importFrom(foreach,getDoParWorkers)
importFrom(stats,dunif)
importFrom(stats,predict)
importFrom(stats,punif)
importFrom(stats,rbinom)
importFrom(stats,runif)
importFrom(truncnorm,dtruncnorm)
importFrom(truncnorm,etruncnorm)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# arf 0.2.2
* Faster and vectorized conditional sampling
* Use min.bucket argument from ranger to avoid pruning if possible
* Option to sample NAs in generated data if original data contains NAs
* Stepsize in forge() to reduce memory usage
* Option for local and global finite bounds

# arf 0.2.0
* Vectorized adversarial resampling
Expand Down
41 changes: 23 additions & 18 deletions R/adversarial_rf.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,18 @@ adversarial_rf <- function(
dat <- rbind(data.frame(y = 1L, x_real),
data.frame(y = 0L, x_synth))
if (isTRUE(parallel)) {
rf0 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
num.trees = num_trees, min.node.size = 2L * min_node_size,
respect.unordered.factors = TRUE, ...)
num.threads <- NULL
} else {
rf0 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
num.trees = num_trees, min.node.size = 2L * min_node_size,
respect.unordered.factors = TRUE, num.threads = 1L, ...)
num.threads <- 1L
}
if (utils::packageVersion("ranger") >= "0.16.1") {
min.bucket <- c(min_node_size, 0)
} else {
min.bucket <- min_node_size
}
rf0 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
num.trees = num_trees, min.bucket = min.bucket,
respect.unordered.factors = TRUE, num.threads = num.threads, ...)

# Recurse
iters <- 0L
Expand Down Expand Up @@ -159,15 +163,9 @@ adversarial_rf <- function(
dat <- rbind(data.frame(y = 1L, x_real),
data.frame(y = 0L, x_synth))
# Train discriminator
if (isTRUE(parallel)) {
rf1 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
num.trees = num_trees, min.node.size = 2 * min_node_size,
respect.unordered.factors = TRUE, ...)
} else {
rf1 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
num.trees = num_trees, min.node.size = 2 * min_node_size,
respect.unordered.factors = TRUE, num.threads = 1, ...)
}
rf1 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
num.trees = num_trees, min.bucket = min.bucket,
respect.unordered.factors = TRUE, num.threads = num.threads, ...)
# Evaluate
acc0 <- 1 - rf1$prediction.error
acc <- c(acc, acc0)
Expand All @@ -191,22 +189,29 @@ adversarial_rf <- function(
if (isTRUE(prune)) {
pred <- stats::predict(rf0, x_real, type = 'terminalNodes')$predictions + 1L
prune <- function(tree) {
# Nodes to prune are leaves which contain fewer than min_node_size real samples
out <- rf0$forest$child.nodeIDs[[tree]]
leaves <- which(out[[1]] == 0L)
to_prune <- leaves[!(leaves %in% which(tabulate(pred[, tree]) >= min_node_size))]
while(length(to_prune) > 0) {
if (1 %in% to_prune) {
# Never prune the root
break
}
for (tp in to_prune) {
# Find parents
# Find parent
parent <- which((out[[1]] + 1L) == tp)
if (length(parent) > 0) {
# Left child
# If node to prune (tp) is the left child of parent, replace left child with right child
out[[1]][parent] <- out[[2]][parent]
} else {
# Right child
# If node to prune (tp) is the right child of parent, replace right child with left child
parent <- which((out[[2]] + 1L) == tp)
out[[2]][parent] <- out[[1]][parent]
}
}
# If both children of a parent are to be pruned, prune the parent in the next round
# This happens if both children have been pruned
to_prune <- which((out[[1]] + 1L) %in% to_prune)
}
return(out)
Expand Down
19 changes: 9 additions & 10 deletions R/expct.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#' arf <- adversarial_rf(iris)
#' psi <- forde(arf, iris)
#'
#' # What is the expected value Sepal.Length?
#' # What is the expected value of Sepal.Length?
#' expct(psi, query = "Sepal.Length")
#'
#' # What if we condition on Species = "setosa"?
Expand All @@ -61,21 +61,20 @@ expct <- function(

# To avoid data.table check issues
variable <- tree <- f_idx <- cvg <- wt <- V1 <- value <- val <- family <-
mu <- sigma <- obs <- prob <- . <- NULL
mu <- sigma <- obs <- prob <- f_idx_uncond <- . <- NULL

# Prep evidence
conj <- FALSE
if (!is.null(evidence)) {
evidence <- prep_evi(params, evidence)
if (!all(c('f_idx', 'wt') %in% colnames(evidence))) {
conj <- TRUE
}
}
if (!is.null(evidence) && !(ncol(evidence) == 2 && all(c("f_idx", "wt") %in% colnames(evidence)))) {
evidence_variable <- prep_cond(evidence, params, "or")$variable
conj <- TRUE
}


# Check query
if (is.null(query)) {
if (isTRUE(conj)) {
query <- setdiff(params$meta$variable, evidence$variable)
query <- setdiff(params$meta$variable, evidence_variable)
} else {
query <- params$meta$variable
if (!is.null(evidence)) {
Expand All @@ -97,7 +96,7 @@ expct <- function(
omega[, wt := cvg / num_trees]
omega[, cvg := NULL]
} else if (conj) {
omega <- leaf_posterior(params, evidence)
omega <- cforde(params, evidence, "or")$forest[, .(f_idx = f_idx_uncond, wt = cvg)]
} else {
omega <- evidence
}
Expand Down
Loading

0 comments on commit 9ff0f5c

Please sign in to comment.