Skip to content

Commit

Permalink
add impute() functions for imputation
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed Nov 7, 2024
1 parent 0d5ccc6 commit 2e1c363
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 3 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: arf
Title: Adversarial Random Forests
Version: 0.2.2
Date: 2024-05-29
Version: 0.2.3
Date: 2024-11-07
Authors@R:
c(person(given = "Marvin N.",
family = "Wright",
Expand Down Expand Up @@ -45,7 +45,7 @@ Imports:
truncnorm
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
Suggests:
ggplot2,
doParallel,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export(earf)
export(expct)
export(forde)
export(forge)
export(impute)
export(lik)
export(rarf)
import(data.table)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# arf 0.2.3
* Add impute() function for direct missing data imputation with ARF
* Add one-line functions darf(), earf(), rarf()

# arf 0.2.2
* Faster and vectorized conditional sampling
* Use min.bucket argument from ranger to avoid pruning if possible
Expand Down
101 changes: 101 additions & 0 deletions R/impute.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@

#' Missing value imputation with ARF
#'
#' Imputed a dataset with missing values using adversarial random forests (ARF).
#' Calls \code{adversarial_rf}, \code{forde} and \code{expct}/\code{forge}.
#'
#' @param x Input data.
#' @param m Number of multiple imputations.
#' @param expectation Return expected value instead of multiple imputations. By default, for single imputation (\code{m=1}), the expected value is returned.
#' @param num_trees Number of trees in ARF.
#' @param min_node_size Minimum node size in ARF.
#' @param round Round imputed values to their respective maximum precision in the original data set?
#' @param finite_bounds Impose finite bounds on all continuous variables? See \code{\link{forde}}.
#' @param epsilon Slack parameter on empirical bounds; see \code{\link{forde}}.
#' @param verbose Print progress for \code{adversarial_rf}?
#' @param ... Extra parameters to be passed to \code{adversarial_rf}, \code{forde}
#' and \code{expct}/\code{forge}.
#'
#' @return Imputed data. A single data table is returned for \code{m=1} and a list of data table for \code{m > 1}.
#' @export
#'
#' @examples
#' # Generate some missings
#' iris_na <- iris
#' for (j in 1:ncol(iris)) {
#' iris_na[sample(1:nrow(iris), 5), j] <- NA
#' }
#'
#' # Parallelization with doParallel
#' doParallel::registerDoParallel(cores = 2)
#'
#' # Single imputation
#' iris_imputed <- arf::impute(iris_na, m = 1)
#'
#' # Multiple imputation
#' iris_imputed <- arf::impute(iris_na, m = 20)
impute <- function(x,
m = 20,
expectation = ifelse(m==1, TRUE, FALSE),
num_trees = 100L,
min_node_size = 10L,
round = TRUE,
finite_bounds = "local",
epsilon = 1e-14,
verbose = FALSE,
...) {

# To avoid data.table check issues
idx <- . <- NULL

if (m > 1 & expectation) {
stop("Multiple imputation with expectation is not possible.")
}
if (!anyNA(x)) {
message("No missing values found. Returning input data.")
return(x)
}

# Separate ... arguments for each function
arg_names <- list(arf = names(as.list(args(adversarial_rf))),
forde = names(as.list(args(forde))),
forge = names(as.list(args(forge))),
expct = names(as.list(args(expct))))
dot_args <- list(...)
arf_args <- dot_args[names(dot_args) %in% arg_names$arf]
forde_args <- dot_args[names(dot_args) %in% arg_names$forde]
forge_args <- dot_args[names(dot_args) %in% arg_names$forge]
expct_args <- dot_args[names(dot_args) %in% arg_names$expct]

# ARF and FORDE
arf <- do.call(adversarial_rf, c(x = list(x),
verbose = list(verbose),
num_trees = list(num_trees),
min_node_size = list(min_node_size),
arf_args))
psi <- do.call(forde, c(arf = list(arf),
x = list(x),
finite_bounds = list(finite_bounds),
epsilon = list(epsilon),
forde_args))

if (expectation) {
# Expected value
x_imputed <- do.call(expct, c(params = list(psi),
evidence = list(x),
round = list(round),
expct_args))
} else {
# Multiple imputation
x_synth <- do.call(forge, c(params = list(psi),
n_synth = list(m),
evidence = list(x),
round = list(round),
forge_args))
x_synth <- as.data.table(x_synth)
x_synth[, idx := rep(1:m, nrow(x))]
x_imputed <- split(x_synth, by = "idx")
x_imputed <- lapply(x_imputed, function(x) x[, idx := NULL])
}
x_imputed
}
64 changes: 64 additions & 0 deletions man/impute.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions tests/testthat/test-impute.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

# Generate some missings
iris_na <- iris
for (j in 1:ncol(iris)) {
iris_na[sample(1:nrow(iris), 5), j] <- NA
}

test_that("impute returns same data with message if no missings", {
expect_message(iris_imputed <- arf::impute(iris, parallel = FALSE), "No missing values found\\. Returning input data\\.")
expect_equal(iris, iris_imputed)
})

test_that("Imputation fills missing values", {
# Single imputation
iris_imputed <- arf::impute(iris_na, m = 1, parallel = FALSE)
expect_s3_class(iris_imputed, "data.frame")
expect_true(!anyNA(iris_imputed))

# Multiple imputation
iris_imputed <- arf::impute(iris_na, parallel = FALSE)
expect_type(iris_imputed, "list")
expect_length(iris_imputed, 20)
expect_true(all(sapply(iris_imputed, function(x) !anyNA(x))))
})

0 comments on commit 2e1c363

Please sign in to comment.