Faster conditional sampling for multiple rows #20

Merged: 67 commits, May 24, 2024
Commits
db852b7
new conditional sampling
jkapar Feb 10, 2024
d8e8f6b
fix typo
jkapar Feb 10, 2024
a44fb20
doSeq
jkapar Feb 10, 2024
ba6c1b9
some cleanup
jkapar Feb 11, 2024
3183da7
change parallelization
jkapar Feb 11, 2024
7dc6aa8
bugfix parallelization
jkapar Feb 12, 2024
89f59de
avoid par for or conditions
jkapar Feb 12, 2024
c30b60e
sample from unique leaf
jkapar Feb 13, 2024
f47f203
adapt warning for zero lik leaves
jkapar Feb 13, 2024
fd72b83
fix empty cnt conds
jkapar Feb 13, 2024
d7316ea
forde for 2 obs
jkapar Feb 13, 2024
32cc9d0
apply alpha to non-observed categories
jkapar Feb 14, 2024
af2e691
fix 'wt=1 fix'
jkapar Feb 14, 2024
f123bad
merge main
mnwright Feb 15, 2024
b14e097
add unique version number
mnwright Feb 15, 2024
ff122ce
only consider colnames in condition
mnwright Feb 15, 2024
9af0a8f
speedup single var cond
jkapar Feb 21, 2024
bf5d4ca
change finite.bounds to local in forde
jkapar Feb 27, 2024
ad35463
fix all-NA condition rows bug
jkapar Mar 1, 2024
0da0b5a
fix all NA condition rows bug
jkapar Mar 1, 2024
d7b8296
change default stepsize
jkapar Mar 4, 2024
8b89082
speedup partial conditions
jkapar Mar 6, 2024
fd2bec5
speed up partial conditions
jkapar Mar 6, 2024
752f24b
fix by.x/by.y merge bug
jkapar Mar 6, 2024
6498c4a
fix wts
jkapar Apr 29, 2024
b7dd59b
fix wts
jkapar Apr 29, 2024
4fc71d2
Merge branch 'conditional_sampling_jk' into local_finite_bounds
jkapar Apr 29, 2024
32b83d7
Merge pull request #22 from bips-hb/local_finite_bounds
jkapar Apr 29, 2024
21387c6
rename function argument condition to evidence
jkapar Apr 29, 2024
f682920
some commenting
jkapar Apr 29, 2024
ae4084b
update forde help and Rd files
mnwright May 3, 2024
77f3b25
fix checks etc.
mnwright May 3, 2024
26303aa
Merge branch 'main' into conditional_sampling_jk
mnwright May 3, 2024
0831ee1
code style etc in forde
mnwright May 3, 2024
f2b8284
docs and style in forge
mnwright May 3, 2024
ef43edb
typo
mnwright May 3, 2024
e47393b
cforde docs and style
mnwright May 3, 2024
76b0cad
add smaller/larger conditions
mnwright May 6, 2024
2ed6f25
add logical not conditions and some tests and docs
mnwright May 6, 2024
02f4d4e
data.table check issues
mnwright May 6, 2024
e95c43a
add evidence with leaves and weights
mnwright May 6, 2024
73d932d
docs etc
mnwright May 6, 2024
6696c2d
remove unoverlap_hyperrectangles (move to branch)
mnwright May 7, 2024
d9cd184
fix evidence colnames check
mnwright May 7, 2024
e49dfdc
adapt renaming in row_mode or for partial evi
jkapar May 7, 2024
464bab9
adapt expct.R to cforde
jkapar May 7, 2024
a1a40bf
bugfix expct
jkapar May 7, 2024
72ec434
small fix for or
jkapar May 7, 2024
97e2716
adapt lik.R to cforde
jkapar May 7, 2024
0141966
delete leaf_posterior and prep_evi
jkapar May 7, 2024
8ce4c15
fix checks, examples, docs
mnwright May 8, 2024
1d5ceee
use resample function to avoid problems with sampling from length 1 v…
mnwright May 8, 2024
e25368f
round before converting numeric to integer
mnwright May 8, 2024
1f4a4db
merge instead of cbind
mnwright May 8, 2024
ddc8136
handle case where no leaf is found
mnwright May 8, 2024
6101365
adjust vignette
mnwright May 8, 2024
2ecf6a0
don't try to prune the root
mnwright May 16, 2024
3da7eda
add comments to pruning
mnwright May 16, 2024
537ce53
within-row OR for mode separate
jkapar May 16, 2024
a744fe8
fix scalar and zero leafs conditions
jkapar May 16, 2024
de9cc61
respect parallel=FALSE
mnwright May 17, 2024
a57be4f
fix data.table check issues (again)
mnwright May 17, 2024
24505a6
fix check with NAs
mnwright May 17, 2024
00b77ee
fix random leaf assignment
mnwright May 17, 2024
73873a9
use class-wise min.bucket when available to avoid pruning
mnwright May 21, 2024
3c5a63f
Merge pull request #23 from bips-hb/min_bucket
mnwright May 22, 2024
fb5822d
update version and news
mnwright May 24, 2024
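
One commit above switches to a resample function to avoid problems when sampling from a length-one vector. The helper used in the PR is not shown in this excerpt; the sketch below follows the idiom from base R's `?sample` help page, and its body is an assumption rather than the PR's code.

```r
# Sketch only: follows the idiom documented in base R's ?sample help page.
# The helper actually used in the PR is not shown in this excerpt.
resample <- function(x, ...) x[sample.int(length(x), ...)]

# A single candidate value, e.g. the index of the only matching leaf.
leaf_ids <- 10L

# sample() treats a length-one numeric x >= 1 as "sample from 1:x",
# so these draws come from 1..10 rather than from the value 10 itself.
set.seed(1)
naive <- sample(leaf_ids, size = 5, replace = TRUE)

# resample() indexes into x, so every draw is an element of x.
safe <- resample(leaf_ids, size = 5, replace = TRUE)
```

With a single candidate, `safe` contains only that candidate, while `naive` can contain any integer from 1 to 10.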
7 changes: 4 additions & 3 deletions DESCRIPTION
@@ -1,7 +1,7 @@
 Package: arf
 Title: Adversarial Random Forests
-Version: 0.2.0
-Date: 2024-01-24
+Version: 0.2.2
+Date: 2024-05-24
 Authors@R:
     c(person(given = "Marvin N.",
              family = "Wright",
@@ -41,10 +41,11 @@ Imports:
     data.table,
     ranger,
     foreach,
+    stringr,
     truncnorm
 Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.3.0
+RoxygenNote: 7.3.1
 Suggests:
     ggplot2,
     doParallel,
5 changes: 5 additions & 0 deletions NAMESPACE
@@ -7,10 +7,15 @@ export(forge)
export(lik)
import(data.table)
import(ranger)
import(stringr)
importFrom(foreach,"%do%")
importFrom(foreach,"%dopar%")
importFrom(foreach,foreach)
importFrom(foreach,getDoParWorkers)
importFrom(stats,dunif)
importFrom(stats,predict)
importFrom(stats,punif)
importFrom(stats,rbinom)
importFrom(stats,runif)
importFrom(truncnorm,dtruncnorm)
importFrom(truncnorm,etruncnorm)
6 changes: 6 additions & 0 deletions NEWS.md
@@ -1,3 +1,9 @@
+# arf 0.2.2
+* Faster and vectorized conditional sampling
+* Use min.bucket argument from ranger to avoid pruning if possible
+* Option to sample NAs in generated data if original data contains NAs
+* Stepsize in forge() to reduce memory usage
+* Option for local and global finite bounds
+
 # arf 0.2.0
 * Vectorized adversarial resampling
41 changes: 23 additions & 18 deletions R/adversarial_rf.R
@@ -104,14 +104,18 @@ adversarial_rf <- function(
   dat <- rbind(data.frame(y = 1L, x_real),
                data.frame(y = 0L, x_synth))
   if (isTRUE(parallel)) {
-    rf0 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
-                  num.trees = num_trees, min.node.size = 2L * min_node_size,
-                  respect.unordered.factors = TRUE, ...)
+    num.threads <- NULL
   } else {
-    rf0 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
-                  num.trees = num_trees, min.node.size = 2L * min_node_size,
-                  respect.unordered.factors = TRUE, num.threads = 1L, ...)
+    num.threads <- 1L
   }
+  if (utils::packageVersion("ranger") >= "0.16.1") {
+    min.bucket <- c(min_node_size, 0)
+  } else {
+    min.bucket <- min_node_size
+  }
+  rf0 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
+                num.trees = num_trees, min.bucket = min.bucket,
+                respect.unordered.factors = TRUE, num.threads = num.threads, ...)

   # Recurse
   iters <- 0L
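
The hunk above chooses `min.bucket` by ranger version: a two-element vector from 0.16.1 on, a single value before that. A minimal sketch of the check, pulled into a hypothetical helper `choose_min_bucket()` (not part of the PR) that takes the version as a string so it runs without ranger installed. The diff does not say which element of the vector applies to which class, so the sketch makes no claim about that.

```r
# Hypothetical helper (not part of the PR): isolates the version check from
# the hunk above. The version is passed in as a string for illustration;
# the PR itself calls utils::packageVersion("ranger").
choose_min_bucket <- function(min_node_size, ranger_version) {
  if (package_version(ranger_version) >= "0.16.1") {
    # Two-element (class-wise) form, exactly as written in the diff
    c(min_node_size, 0)
  } else {
    # Single value for older ranger versions
    min_node_size
  }
}

old_style <- choose_min_bucket(5L, "0.16.0")  # single value
new_style <- choose_min_bucket(5L, "0.16.1")  # two-element vector
```

Comparing a `package_version` object against a plain string works because the comparison operators for version objects coerce the string before comparing.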
@@ -159,15 +163,9 @@ adversarial_rf <- function(
     dat <- rbind(data.frame(y = 1L, x_real),
                  data.frame(y = 0L, x_synth))
     # Train discriminator
-    if (isTRUE(parallel)) {
-      rf1 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
-                    num.trees = num_trees, min.node.size = 2 * min_node_size,
-                    respect.unordered.factors = TRUE, ...)
-    } else {
-      rf1 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
-                    num.trees = num_trees, min.node.size = 2 * min_node_size,
-                    respect.unordered.factors = TRUE, num.threads = 1, ...)
-    }
+    rf1 <- ranger(y ~ ., dat, keep.inbag = TRUE, classification = TRUE,
+                  num.trees = num_trees, min.bucket = min.bucket,
+                  respect.unordered.factors = TRUE, num.threads = num.threads, ...)
     # Evaluate
     acc0 <- 1 - rf1$prediction.error
     acc <- c(acc, acc0)
@@ -191,22 +189,29 @@ adversarial_rf <- function(
   if (isTRUE(prune)) {
     pred <- stats::predict(rf0, x_real, type = 'terminalNodes')$predictions + 1L
     prune <- function(tree) {
+      # Nodes to prune are leaves which contain fewer than min_node_size real samples
       out <- rf0$forest$child.nodeIDs[[tree]]
       leaves <- which(out[[1]] == 0L)
       to_prune <- leaves[!(leaves %in% which(tabulate(pred[, tree]) >= min_node_size))]
       while(length(to_prune) > 0) {
+        if (1 %in% to_prune) {
+          # Never prune the root
+          break
+        }
         for (tp in to_prune) {
-          # Find parents
+          # Find parent
           parent <- which((out[[1]] + 1L) == tp)
           if (length(parent) > 0) {
-            # Left child
+            # If node to prune (tp) is the left child of parent, replace left child with right child
             out[[1]][parent] <- out[[2]][parent]
           } else {
-            # Right child
+            # If node to prune (tp) is the right child of parent, replace right child with left child
             parent <- which((out[[2]] + 1L) == tp)
             out[[2]][parent] <- out[[1]][parent]
           }
         }
+        # If both children of a parent are to be pruned, prune the parent in the next round
+        # This happens if both children have been pruned
         to_prune <- which((out[[1]] + 1L) %in% to_prune)
       }
       return(out)
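
The pruning loop in the last hunk can be traced on a toy tree. The sketch below copies the loop into a standalone function, `prune_tree()`, a hypothetical name that is not part of the PR. The toy layout assumes the `child.nodeIDs` convention the diff relies on: two vectors of zero-based child IDs, with 0 marking a leaf. Node numbers in the comments are one-based, matching `pred + 1L` in the diff.

```r
# Hypothetical standalone copy of the pruning loop from the hunk above.
# child_ids: list(left, right) of zero-based child IDs, 0 marks a leaf
# pred_leaf: one-based leaf ID for each real sample
prune_tree <- function(child_ids, pred_leaf, min_node_size) {
  out <- child_ids
  leaves <- which(out[[1]] == 0L)
  # Leaves holding fewer than min_node_size real samples get pruned
  to_prune <- leaves[!(leaves %in% which(tabulate(pred_leaf) >= min_node_size))]
  while (length(to_prune) > 0) {
    if (1 %in% to_prune) break            # never prune the root
    for (tp in to_prune) {
      parent <- which((out[[1]] + 1L) == tp)
      if (length(parent) > 0) {
        # tp is a left child: redirect the left pointer to the right child
        out[[1]][parent] <- out[[2]][parent]
      } else {
        # tp is a right child: redirect the right pointer to the left child
        parent <- which((out[[2]] + 1L) == tp)
        out[[2]][parent] <- out[[1]][parent]
      }
    }
    # Parents whose children were all pruned are handled in the next round
    to_prune <- which((out[[1]] + 1L) %in% to_prune)
  }
  out
}

# Toy tree: node 1 (root) -> nodes 2 and 3; node 2 -> nodes 4 and 5.
toy <- list(c(1L, 3L, 0L, 0L, 0L),   # left children (zero-based)
            c(2L, 4L, 0L, 0L, 0L))   # right children (zero-based)

# Three real samples in leaf 3, three in leaf 4, one in leaf 5.
pred_leaf <- c(3L, 3L, 3L, 4L, 4L, 4L, 5L)
pruned <- prune_tree(toy, pred_leaf, min_node_size = 2L)
```

In this example only leaf 5 falls below the threshold, so node 2 ends up with both child pointers on node 4 and every sample routed through node 2 lands in that leaf.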
19 changes: 9 additions & 10 deletions R/expct.R
@@ -34,7 +34,7 @@
 #' arf <- adversarial_rf(iris)
 #' psi <- forde(arf, iris)
 #'
-#' # What is the expected value Sepal.Length?
+#' # What is the expected value of Sepal.Length?
 #' expct(psi, query = "Sepal.Length")
 #'
 #' # What if we condition on Species = "setosa"?
@@ -61,21 +61,20 @@ expct <- function(

   # To avoid data.table check issues
   variable <- tree <- f_idx <- cvg <- wt <- V1 <- value <- val <- family <-
-    mu <- sigma <- obs <- prob <- . <- NULL
+    mu <- sigma <- obs <- prob <- f_idx_uncond <- . <- NULL

   # Prep evidence
   conj <- FALSE
-  if (!is.null(evidence)) {
-    evidence <- prep_evi(params, evidence)
-    if (!all(c('f_idx', 'wt') %in% colnames(evidence))) {
-      conj <- TRUE
-    }
-  }
+  if (!is.null(evidence) && !(ncol(evidence) == 2 && all(c("f_idx", "wt") %in% colnames(evidence)))) {
+    evidence_variable <- prep_cond(evidence, params, "or")$variable
+    conj <- TRUE
+  }
+

   # Check query
   if (is.null(query)) {
     if (isTRUE(conj)) {
-      query <- setdiff(params$meta$variable, evidence$variable)
+      query <- setdiff(params$meta$variable, evidence_variable)
     } else {
       query <- params$meta$variable
       if (!is.null(evidence)) {
@@ -97,7 +96,7 @@ expct <- function(
     omega[, wt := cvg / num_trees]
     omega[, cvg := NULL]
   } else if (conj) {
-    omega <- leaf_posterior(params, evidence)
+    omega <- cforde(params, evidence, "or")$forest[, .(f_idx = f_idx_uncond, wt = cvg)]
   } else {
     omega <- evidence
   }
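
The new evidence check in `expct()` separates two kinds of input: a two-column table of leaf indices and weights (`f_idx`, `wt`), which is used as is, and anything else, which is treated as variable conditions and passed on to `prep_cond()` and `cforde()`. A minimal base-R sketch of that predicate, using hypothetical helper names (`is_leaf_weights`, `needs_prep`) that do not appear in the PR:

```r
# Hypothetical helpers (not in the PR) restating the condition from the diff.

# TRUE when evidence is exactly a two-column table of leaf indices and weights.
is_leaf_weights <- function(evidence) {
  ncol(evidence) == 2 && all(c("f_idx", "wt") %in% colnames(evidence))
}

# TRUE when evidence holds variable conditions that still need preparing.
needs_prep <- function(evidence) {
  !is.null(evidence) && !is_leaf_weights(evidence)
}

no_evidence   <- needs_prep(NULL)                                  # FALSE
leaf_evidence <- needs_prep(data.frame(f_idx = 1:3,
                                       wt = c(0.2, 0.3, 0.5)))     # FALSE
var_evidence  <- needs_prep(data.frame(Species = "setosa"))        # TRUE
```

Compared with the removed code, the new condition also requires exactly two columns, so a table that contains `f_idx` and `wt` alongside other columns is treated as variable conditions.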