Skip to content

Commit

Permalink
change default stepsize
Browse files Browse the repository at this point in the history
  • Loading branch information
jkapar committed Mar 4, 2024
1 parent 0da0b5a commit d7b8296
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 91 deletions.
39 changes: 31 additions & 8 deletions R/forge.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#' @param condition_row_mode Interpretation of rows in multi-row conditions.
#' @param sample_NAs Sample NAs respecting the probability for missing values in the original data.
#' @param stepsize Stepsize defining number of condition rows handled in one for each step.
#' Defaults to nrow(condition)/num_registered_workers for parallel == TRUE.
#' @param parallel Compute in parallel? Must register backend beforehand, e.g.
#' via \code{doParallel}.
#' @param n_synth Number of synthetic samples to generate.
Expand Down Expand Up @@ -76,7 +77,7 @@
#'
#' @export
#' @import data.table
#' @importFrom foreach foreach %dopar%
#' @importFrom foreach foreach %dopar% getDoPar getDoParWorkers
#' @importFrom truncnorm rtruncnorm
#' @importFrom stats rbinom
#'
Expand All @@ -87,11 +88,17 @@ forge <- function(
condition = NULL,
condition_row_mode = c("separate", "or"),
sample_NAs = F,
stepsize = 200,
stepsize = 0,
parallel = TRUE) {

condition_row_mode <- match.arg(condition_row_mode)

doParRegistered <- getDoParRegistered()
num_workers <- getDoParWorkers()
if(!parallel & doParRegistered & (num_workers > 1)) {
registerDoSEQ()
}

# To avoid data.table check issues
tree <- cvg <- leaf <- idx <- family <- mu <- sigma <- prob <- dat <-
variable <- relation <- wt <- j <- f_idx <- val <- . <- step_ <- c_idx <-
Expand All @@ -102,10 +109,19 @@ forge <- function(
if(!is.null(condition)) {
condition <- as.data.table(condition)
if(parallel & condition_row_mode == "separate") {
stepsize_foreach <- stepsize
step_no <- ceiling(nrow(condition)/stepsize_foreach)
if(stepsize == 0) {
stepsize <- ceiling(nrow(condition)/getDoParWorkers())
}
stepsize_cforde <- 0
parallel_cforde = F
step_no <- ceiling(nrow(condition)/stepsize)
} else {
stepsize_foreach <- nrow(condition)
if(stepsize == 0) {
stepsize = nrow(condition)
}
stepsize_cforde <- stepsize
parallel_cforde <- parallel
stepsize <- nrow(condition)
step_no <- 1
}
} else {
Expand All @@ -116,10 +132,10 @@ forge <- function(

# Prepare the event space
if (!is.null(condition)) {
index_start <- (step_-1)*stepsize_foreach + 1
index_end <- min(step_*stepsize_foreach, nrow(condition))
index_start <- (step_-1)*stepsize + 1
index_end <- min(step_*stepsize, nrow(condition))
condition_part <- condition[index_start:index_end,]
cparams <- cforde(params, condition_part, condition_row_mode, stepsize)
cparams <- cforde(params, condition_part, condition_row_mode, stepsize_cforde, parallel_cforde)
if(is.null(cparams)) {
n_synth <- n_synth * nrow(condition_part)
}
Expand Down Expand Up @@ -198,6 +214,9 @@ forge <- function(

# Combine, optionally impose constraint(s)
x_synth <- cbind(synth_cnt, synth_cat)
if(length(x_synth) == 0) {
x_synth <- condition_part[F,]
}

# Clean up, export
x_synth <- post_x(x_synth, params)
Expand Down Expand Up @@ -230,6 +249,10 @@ forge <- function(
x_synth
}

if(!parallel & doParRegistered & (num_workers > 1)) {
registerDoParallel(num_workers)
}

return(x_synth_)
}

Expand Down
183 changes: 100 additions & 83 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,16 @@ post_x <- function(x, params) {
#' @importFrom stats dunif punif
#'

cforde <- function(params, condition, row_mode = c("separate", "or"), stepsize = 200) {
cforde <- function(params, condition, row_mode = c("separate", "or"), stepsize = 0, parallel = TRUE) {

row_mode <- match.arg(row_mode)

doParRegistered <- getDoParRegistered()
num_workers <- getDoParWorkers()
if(!parallel & doParRegistered & (num_workers > 1)) {
registerDoSEQ()
}

# To avoid data.table check issues
. <- c_idx <- cvg <- cvg_arf <- cvg_factor <- f_idx <- f_idx_uncond <- i.max <-
i.min <- leaf <- max.x <- max.y <- min.x <- min.y <- mu <- prob <- sigma <-
Expand All @@ -361,6 +367,7 @@ cforde <- function(params, condition, row_mode = c("separate", "or"), stepsize =

# format c, calculate DNF and output disjoint hyperrectangles
condition_long <- prep_cond(condition,params, row_mode)
setkey(condition_long, c_idx)

if(nrow(condition_long) == 0){
return(NULL)
Expand All @@ -377,56 +384,58 @@ cforde <- function(params, condition, row_mode = c("separate", "or"), stepsize =

conds_conditioned <- condition_long[,unique(c_idx)]

# store conditions for cat and cnt separately
cat_conds <- condition_long[variable %in% cat_cols,c("c_idx","variable","val")][,variable := factor(variable)]
cnt_conds <- condition_long[variable %in% cnt_cols,c("c_idx","variable","min", "max","val")][,`:=` (variable = factor(variable),
val = as.numeric(val))]

#
if (nrow(cat_conds) != 0) {
cat_relevant <- cat[,.(.(f_idx)), by=.(variable,val)]
setkey(cat_relevant, variable, val)
setkey(cat_conds, variable, val)
step_no <- ceiling(nconds_conditioned/stepsize)
cat_relevant <- cat_conds[cat_relevant, on = .(variable, val),nomatch = NULL]
setkey(cat_relevant,c_idx)

relevant_leaves_changed_cat <- foreach(step_ = 1:step_no, .combine = "rbind") %do% {
index_start <- conds_conditioned[(step_ - 1)*stepsize + 1]
index_end <- conds_conditioned[min(step_ * stepsize, nconds_conditioned)]
cat_relevant[.(index_start:index_end), Reduce(intersect,V1),by = c_idx][,.(c_idx, f_idx = V1)]
if(stepsize == 0) {
if(parallel) {
stepsize <- ceiling(nconds_conditioned/getDoParWorkers())
} else {
stepsize <- nconds_conditioned
}

setorder(relevant_leaves_changed_cat)
conditions_unchanged_cat <- (1:nconds)[!(1:nconds %in% cat_conds[,c_idx])]
relevant_leaves_unchanged_cat <- data.table(c_idx = rep(conditions_unchanged_cat, each = nrow(forest) ), f_idx = rep(forest[,f_idx],length(conditions_unchanged_cat)))
relevant_leaves_cat <- rbind(relevant_leaves_changed_cat, relevant_leaves_unchanged_cat)
relevant_leaves_cat_list <- relevant_leaves_cat[,.(f_idx = .(f_idx)),by=c_idx]
} else {
relevant_leaves_cat <- data.table(c_idx = integer(), f_idx = integer())
}


if (nrow(cnt_conds) != 0) {
cnt_conds_compact <- copy(cnt_conds)
cnt_conds_compact[!is.na(val), `:=`(min = val, max = val)][,val := NULL]
step_no <- ceiling(nconds_conditioned/stepsize)
updates_relevant_leaves <- foreach(step_ = 1:step_no, .combine = "rbind") %dopar% {

cnt_relevant <- cnt[,.(min = .(min), max = .(max)),by = variable]
cnt_relevant <- cnt_conds_compact[cnt_relevant,on = .(variable), nomatch = NULL]
index_start <- conds_conditioned[(step_ - 1)*stepsize + 1]
index_end <- conds_conditioned[min(step_ * stepsize, nconds_conditioned)]
condition_long_step <- condition_long[.(index_start:index_end),nomatch = NULL]

# store conditions for cat and cnt separately
cat_conds <- condition_long_step[variable %in% cat_cols,c("c_idx","variable","val")][,variable := factor(variable)]
cnt_conds <- condition_long_step[variable %in% cnt_cols,c("c_idx","variable","min", "max","val")][,`:=` (variable = factor(variable),
val = as.numeric(val))]

if (nrow(cat_conds) != 0) {
cnt_relevant <- cnt_relevant[relevant_leaves_cat_list, on = .(c_idx)]
cat_relevant <- cat[,.(.(f_idx)), by=.(variable,val)]
setkey(cat_relevant, variable, val)
setkey(cat_conds, variable, val)
cat_relevant <- cat_conds[cat_relevant, on = .(variable, val), nomatch = NULL]
setkey(cat_relevant, c_idx)
relevant_leaves_changed_cat <- cat_relevant[, Reduce(intersect,V1),by = c_idx][,.(c_idx, f_idx = V1)]
setorder(relevant_leaves_changed_cat)

conditions_unchanged_cat <- setdiff(condition_long_step[, c_idx], cat_conds[, c_idx])
relevant_leaves_unchanged_cat <- data.table(c_idx = rep(conditions_unchanged_cat, each = nrow(forest) ), f_idx = rep(forest[,f_idx],length(conditions_unchanged_cat)))
relevant_leaves_cat <- rbind(relevant_leaves_changed_cat, relevant_leaves_unchanged_cat, fill = T)
relevant_leaves_cat_list <- relevant_leaves_cat[,.(f_idx = .(f_idx)),by=c_idx]
} else {
cnt_relevant[,f_idx := NA]
relevant_leaves_cat <- data.table(c_idx = integer(), f_idx = integer())
}

setkey(cnt_relevant,c_idx)
step_no <- ceiling(nconds_conditioned/stepsize)

relevant_leaves_changed_cnt <- foreach(step_ = 1:step_no, .combine = "rbind") %do% {
index_start <- conds_conditioned[(step_ - 1)*stepsize + 1]
index_end <- conds_conditioned[min(step_ * stepsize, nconds_conditioned)]
cnt_relevant[.(index_start:index_end), .(
if (nrow(cnt_conds) != 0) {
cnt_conds_compact <- copy(cnt_conds)
cnt_conds_compact[!is.na(val), `:=`(min = val, max = val)][,val := NULL]

cnt_relevant <- cnt[,.(min = .(min), max = .(max)),by = variable]
cnt_relevant <- cnt_conds_compact[cnt_relevant, on = .(variable), nomatch = NULL]
setkey(cnt_relevant,c_idx)

if (nrow(cat_conds) != 0) {
cnt_relevant <- cnt_relevant[relevant_leaves_cat_list, on = .(c_idx)]
} else {
cnt_relevant[, f_idx := NA]
}

relevant_leaves_changed_cnt <- cnt_relevant[, .(
c_idx,
variable,
f_idx = Map(\(f_idx, min, max, i.min, i.max) {
Expand All @@ -443,40 +452,51 @@ cforde <- function(params, condition, row_mode = c("separate", "or"), stepsize =
}
intersect(rel_min,rel_max)
}, f_idx = f_idx, min = min, max = max, i.min = i.min, i.max = i.max))
][, Reduce(intersect,f_idx),by = c_idx][,.(c_idx, f_idx = V1)]
}

conditions_unchanged_cnt <- (1:nconds)[!(1:nconds %in% cnt_conds[,c_idx])]
relevant_leaves_unchanged_cnt <- data.table(c_idx = rep(conditions_unchanged_cnt, each = nrow(forest) ), f_idx = rep(forest[,f_idx],length(conditions_unchanged_cnt)))
relevant_leaves_cnt <- rbind(relevant_leaves_changed_cnt, relevant_leaves_unchanged_cnt)

cnt_new <- merge(merge(relevant_leaves_cnt, cnt_conds, by = "c_idx", allow.cartesian = T, sort = F), cnt, by = c("f_idx", "variable"), all.x = T, allow.cartesian = T, sort = F)
cnt_new[!is.na(val),`:=` (min = min.y,
max = max.y)]
cnt_new[is.na(val),`:=` (min = pmax(min.x, min.y, na.rm = T),
max = pmin(max.x, max.y, na.rm = T))]
cnt_new[,cvg_factor := NA_real_]
if (family == "truncnorm") {
cnt_new[!is.na(val), cvg_factor := dtruncnorm(val, a=min.y, b=max.y, mean=mu, sd=sigma)*(val != min.y)]
cnt_new[is.na(val) & (min == min.y) & (max == max.y), cvg_factor := 1]
cnt_new[is.na(val) & is.na(cvg_factor), cvg_factor := ptruncnorm(max, a=min.y, b=max.y, mean=mu, sd=sigma) - ptruncnorm(min, a=min.y, max.y, mean=mu,sd=sigma)]
} else if (family == "unif") {
cnt_new[!is.na(val), cvg_factor := dunif(val, min=min.y, max=max.y)*(val != min.y)]
cnt_new[is.na(val) & (min == min.y) & (max == max.y), cvg_factor := 1]
cnt_new[is.na(val) & is.na(cvg_factor), cvg_factor := punif(max, min=min.y, max=max.y) - punif(min, min=min.y, max.y)]
}
cnt_new[,c("min.x","max.x","min.y","max.y") := NULL]
if (nrow(cat_conds) > 0) {
relevant_leaves <- merge(relevant_leaves_cnt, relevant_leaves_cat, by = c("c_idx", "f_idx"))[,.(c_idx, f_idx = .I, f_idx_uncond =f_idx)]
][, Reduce(intersect,f_idx),by = c_idx][,.(c_idx, f_idx = V1)]

conditions_unchanged_cnt <- setdiff(condition_long_step[, c_idx], cnt_conds[, c_idx])
relevant_leaves_unchanged_cnt <- data.table(c_idx = rep(conditions_unchanged_cnt, each = nrow(forest)), f_idx = rep(forest[,f_idx],length(conditions_unchanged_cnt)))
relevant_leaves_cnt <- rbind(relevant_leaves_changed_cnt, relevant_leaves_unchanged_cnt)

cnt_new <- merge(merge(relevant_leaves_cnt, cnt_conds, by = "c_idx", allow.cartesian = T, sort = F), cnt, by = c("f_idx", "variable"), all.x = T, allow.cartesian = T, sort = F)
cnt_new[!is.na(val),`:=` (min = min.y,
max = max.y)]
cnt_new[is.na(val),`:=` (min = pmax(min.x, min.y, na.rm = T),
max = pmin(max.x, max.y, na.rm = T))]
cnt_new[,cvg_factor := NA_real_]
if (family == "truncnorm") {
cnt_new[!is.na(val), cvg_factor := dtruncnorm(val, a=min.y, b=max.y, mean=mu, sd=sigma)*(val != min.y)]
cnt_new[is.na(val) & (min == min.y) & (max == max.y), cvg_factor := 1]
cnt_new[is.na(val) & is.na(cvg_factor), cvg_factor := ptruncnorm(max, a=min.y, b=max.y, mean=mu, sd=sigma) - ptruncnorm(min, a=min.y, max.y, mean=mu,sd=sigma)]
} else if (family == "unif") {
cnt_new[!is.na(val), cvg_factor := dunif(val, min=min.y, max=max.y)*(val != min.y)]
cnt_new[is.na(val) & (min == min.y) & (max == max.y), cvg_factor := 1]
cnt_new[is.na(val) & is.na(cvg_factor), cvg_factor := punif(max, min=min.y, max=max.y) - punif(min, min=min.y, max.y)]
}
cnt_new[,c("min.x","max.x","min.y","max.y") := NULL]
if (nrow(cat_conds) > 0) {
relevant_leaves <- merge(relevant_leaves_cnt, relevant_leaves_cat, by = c("c_idx", "f_idx"))[,.(c_idx, f_idx)]
} else {
relevant_leaves <- relevant_leaves_cnt[,.(c_idx, f_idx)]
}
} else {
relevant_leaves <- relevant_leaves_cnt[,.(c_idx, f_idx = .I, f_idx_uncond =f_idx)]
relevant_leaves <- relevant_leaves_cat[,.(c_idx, f_idx)]
cnt_new <- cbind(cnt[F,], data.table(cvg_factor = numeric(), c_idx = integer(), val = numeric()))
}
cnt_new <- setcolorder(merge(relevant_leaves, cnt_new, by.x = c("c_idx", "f_idx_uncond"), by.y = c("c_idx", "f_idx"), sort = F),c("f_idx","c_idx","variable","min","max","val","cvg_factor"))[]
} else {
relevant_leaves <- relevant_leaves_cat[,.(c_idx, f_idx = .I, f_idx_uncond =f_idx)]
cnt_new <- cbind(cnt[F,], data.table(f_idx_uncond = integer(), cvg_factor = numeric(), c_idx = integer(), val = numeric()))
cat_new <- merge(merge(relevant_leaves, cat_conds, by = "c_idx", allow.cartesian = T), cat, by = c("f_idx","variable", "val"))
cat_new[,`:=` (cvg_factor = prob, prob = 1)]

list(cnt_new = cnt_new, cat_new = cat_new, relevant_leaves = relevant_leaves)
}

if(is.matrix(updates_relevant_leaves)) {
updates_relevant_leaves <- lapply(as.data.table(updates_relevant_leaves), rbindlist)
}

relevant_leaves <- updates_relevant_leaves$relevant_leaves[,`:=` (f_idx = .I, f_idx_uncond = f_idx)][]
cnt_new <- setcolorder(merge(relevant_leaves, updates_relevant_leaves$cnt_new, by.x = c("c_idx", "f_idx_uncond"), by.y = c("c_idx", "f_idx"), sort = F), c("f_idx","c_idx","variable","min","max","val","cvg_factor"))[]
cat_new <- setcolorder(merge(relevant_leaves, updates_relevant_leaves$cat_new, by = c("c_idx", "f_idx"), sort = F), c("f_idx","c_idx","variable","val","prob","cvg_factor"))[]

if(relevant_leaves[,uniqueN(c_idx)] < nconds_conditioned) {
if(relevant_leaves[,uniqueN(c_idx)] == 0 & row_mode == "or") {
stop("For all entered evidence rows, no matching leaves could be found. This is probably because evidence lies outside of the distribution calculated by FORDE. For continuous data, consider setting epsilon>0 or finite_bounds=FALSE in forde(). For categorical data, consider setting alpha>0 in forde()")
Expand All @@ -487,11 +507,6 @@ cforde <- function(params, condition, row_mode = c("separate", "or"), stepsize =
}
}

cat_new <- merge(merge(relevant_leaves, cat_conds, by = "c_idx", allow.cartesian = T), cat, by.x = c("f_idx_uncond","variable", "val"), by.y = c("f_idx","variable", "val"))
cat_new[,`:=` (cvg_factor = prob, prob = 1)]
setcolorder(cat_new,c("f_idx","c_idx","f_idx_uncond","variable","val","prob","cvg_factor"))


forest_new <- merge(relevant_leaves,forest, by.x = "f_idx_uncond", by.y = "f_idx", all.x = T, sort = F)
setnames(forest_new,"cvg","cvg_arf")

Expand All @@ -516,10 +531,9 @@ cforde <- function(params, condition, row_mode = c("separate", "or"), stepsize =
cvg_new <- cvg_new[cvg > 0,][,cvg := cvg / sum(cvg)]
}
} else {
if(any(cvg_new[,all(cvg == -Inf), by = c_idx][,V1])) {
cvg_new[,all(cvg == -Inf), by = c_idx][V1 ==T]
if(any(cvg_new[, leaf_zero_lik := all(cvg == -Inf), by = c_idx][, leaf_zero_lik])) {
warning("All leaves have zero likelihood for some entered evidence rows. This is probably because evidence contains an (almost) impossible combination.")
cvg_new[, cvg := 1/.N, by = c_idx]
cvg_new[leaf_zero_lik == T, cvg := 1/.N, by = c_idx][, leaf_zero_lik := NULL]
} else {
cvg_new[, cvg := exp(cvg - max(cvg)), by = c_idx]
cvg_new <- cvg_new[cvg > 0,][,cvg := cvg / sum(cvg), by = c_idx]
Expand All @@ -540,15 +554,18 @@ cforde <- function(params, condition, row_mode = c("separate", "or"), stepsize =
}
if(row_mode == "separate" & (nconds != nconds_conditioned)) {
conds_unconditioned <- (1:nconds)[!(1:nconds) %in% conds_conditioned]
forest_new <- forest_new[!(c_idx %in% conds_unconditioned),]
forest_new_unconditioned <- copy(forest)
forest_new_unconditioned <- rbindlist(replicate(length(conds_unconditioned), forest, simplify = F))
forest_new_unconditioned[, `:=` (c_idx = rep(conds_unconditioned,each = nrow(forest)), f_idx_uncond = f_idx, cvg_arf = cvg)]
forest_new <- rbind(forest_new, forest_new_unconditioned)
}

setorder(setcolorder(forest_new,c("f_idx","c_idx","f_idx_uncond","tree","leaf","cvg_arf","cvg")), c_idx, f_idx, f_idx_uncond, tree, leaf)

if(!parallel & doParRegistered & (num_workers > 1)) {
registerDoParallel(num_workers)
}

list(condition_input = condition, condition_prepped = condition_long, cnt = cnt_new, cat = cat_new, forest = forest_new)
}

Expand Down Expand Up @@ -627,7 +644,7 @@ prep_cond <- function(condition, params, row_mode) {
condition_long <- condition_long[!(min == -Inf & max == Inf & is.na(val))]
}
setorder(condition_long, c_idx)
condition_long
condition_long[]
}


Expand Down

0 comments on commit d7b8296

Please sign in to comment.