Skip to content

Commit 852598f

Browse files
committed
fix
1 parent cbfce7f commit 852598f

File tree

6 files changed

+73
-22
lines changed

6 files changed

+73
-22
lines changed

R/expct.R

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,24 @@ expct <- function(
246246
x_synth <- cbind(synth_cnt, synth_cat)
247247
x_synth <- post_x(x_synth, params, round)
248248

249+
if (evidence_row_mode == "separate" & any(omega[, is.na(f_idx)])) {
250+
setDT(x_synth)
251+
indices_na <- cparams$forest[is.na(f_idx), c_idx]
252+
indices_sampled <- cparams$forest[!is.na(f_idx), unique(c_idx)]
253+
rows_na <- dcast(rbind(data.table(c_idx = 0, variable = params$meta[,variable]),
254+
cparams$evidence_prepped[c_idx == indices_na,],
255+
fill = T),
256+
c_idx ~ variable, value.var = "val")[c_idx != 0,]
257+
if (nomatch == "force") {
258+
rows_na_sampled <- expct(params, parallel = parallel, stepsize = stepsize)
259+
rows_na[is.na(rows_na)] <- rows_na_sampled[is.na(rows_na[,-1])]
260+
}
261+
x_synth[, c_idx := indices_sampled]
262+
x_synth <- rbind(x_synth, rows_na, fill = T)
263+
setorder(x_synth, c_idx)[, c_idx := NULL]
264+
x_synth <- post_x(x_synth, params, round)
265+
}
266+
249267
x_synth
250268
}
251269
if (isTRUE(parallel)) {

R/forge.R

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -275,24 +275,25 @@ forge <- function(
275275
NA_share <- rbind(NA_share_cnt, NA_share_cat)
276276
setorder(NA_share[,variable := factor(variable, levels = params$meta[,variable])], variable, idx)
277277
NA_share[,dat := rbinom(.N, 1, prob = NA_share)]
278-
x_synth[dcast(NA_share,formula = idx ~ variable, value.var = "dat")[,-"idx"] == 1] <- NA
278+
x_synth[dcast(NA_share, formula = idx ~ variable, value.var = "dat")[,-"idx"] == 1] <- NA
279279
x_synth <- post_x(x_synth, params, round)
280280
}
281-
282281
if (evidence_row_mode == "separate" & any(omega[, is.na(f_idx)])) {
283282
setDT(x_synth)
284283
indices_na <- cparams$forest[is.na(f_idx), c_idx]
285284
indices_sampled <- cparams$forest[!is.na(f_idx), unique(c_idx)]
286-
evidence_part_long <- dcast(rbind(data.table(c_idx = 0, variable = params$meta[,variable]),
287-
cparams$evidence_prepped,
285+
rows_na <- dcast(rbind(data.table(c_idx = 0, variable = params$meta[,variable]),
286+
cparams$evidence_prepped[c_idx == indices_na,],
288287
fill = T),
289-
c_idx ~ variable, value.var = "val")[c_idx != 0,-"c_idx"]
290-
rows_na <- evidence_part_long[indices_na, ]
291-
rows_na[, idx := indices_na]
288+
c_idx ~ variable, value.var = "val")[c_idx != 0,]
292289
rows_na <- rbindlist(replicate(n_synth, rows_na, simplify = FALSE))
293-
x_synth[, idx := rep(indices_sampled, each = n_synth)]
290+
if (nomatch == "force") {
291+
rows_na_sampled <- forge(params, n_synth = nrow(rows_na), sample_NAs = sample_NAs, parallel = parallel, stepsize = stepsize)
292+
rows_na[is.na(rows_na)] <- rows_na_sampled[is.na(rows_na[,-1])]
293+
}
294+
x_synth[, c_idx := rep(indices_sampled, each = n_synth)]
294295
x_synth <- rbind(x_synth, rows_na, fill = T)
295-
setorder(x_synth, idx)[, idx := NULL]
296+
setorder(x_synth, c_idx)[, c_idx := NULL]
296297
x_synth <- post_x(x_synth, params, round)
297298
}
298299
x_synth

R/utils.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -527,12 +527,12 @@ cforde <- function(params,
527527
}
528528

529529
# Add all leaves for all-NA conditions to forest
530-
if ((nomatch == "force" & length(conds_impossible) > 0) | (row_mode == "separate" & nconds != nconds_conditioned)) {
531-
conds_unconditioned <- c(conds_impossible, (1:nconds)[!(1:nconds) %in% conds_conditioned])
530+
if (row_mode == "separate" & nconds != nconds_conditioned) {
531+
conds_unconditioned <- (1:nconds)[!(1:nconds) %in% conds_conditioned]
532532
forest_new_unconditioned <- copy(forest)
533533
forest_new_unconditioned <- rbindlist(replicate(length(conds_unconditioned), forest, simplify = F))
534534
forest_new_unconditioned[, `:=` (c_idx = rep(conds_unconditioned,each = nrow(forest)), f_idx_uncond = f_idx, cvg_arf = cvg)]
535-
forest_new <- rbind(forest_new, forest_new_unconditioned)[!is.na(f_idx), ]
535+
forest_new <- rbind(forest_new, forest_new_unconditioned)
536536
}
537537

538538
setorder(setcolorder(forest_new,c("f_idx","c_idx","f_idx_uncond","tree","leaf","cvg_arf","cvg")), c_idx, f_idx, f_idx_uncond, tree, leaf)
@@ -543,7 +543,7 @@ cforde <- function(params,
543543

544544
#' Preprocess conditions
545545
#'
546-
#' This function prepares conditions for computing conditional circuit parameters via cforde
546+
547547
#'
548548
#' @param params Circuit parameters learned via \code{\link{forde}}.
549549
#' @param evidence Optional set of conditioning events.

man/prep_cond.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-conditions.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ test_that("if nomatch='force' and verbose=TRUE, run through with a warning", {
3232
nomatch = "force", verbose = TRUE,
3333
n_synth = 10, parallel = FALSE),
3434
"All leaves have zero likelihood for some entered evidence rows\\. This is probably because evidence contains an \\(almost\\) impossible combination\\. Sampling from all possible leaves with equal probability \\(can be changed with 'nomatch' argument\\)\\.")
35-
expect_true(all(!is.na(x_synth)))
35+
expect_true(all(!is.na(x_synth) & x_synth$Sepal.Length == 100))
3636

3737
# No matching leaf case (finite bounds)
3838
psi_global <- forde(arf, iris, finite_bounds = "global", parallel = FALSE)
3939
expect_warning(x_synth <- forge(psi_global, evidence = data.frame(Sepal.Length = 100),
4040
nomatch = "force", verbose = TRUE,
4141
n_synth = 10, parallel = FALSE),
4242
"For some entered evidence rows, no matching leaves could be found\\. This is probably because evidence lies outside of the distribution calculated by FORDE\\. For continuous data, consider setting epsilon>0 or finite_bounds='no' in forde\\(\\)\\. For categorical data, consider setting alpha>0 in forde\\(\\)\\. Sampling from all leaves with equal probability \\(can be changed with 'nomatch' argument\\)\\.")
43-
expect_true(all(!is.na(x_synth)))
43+
expect_true(all(!is.na(x_synth) & x_synth$Sepal.Length == 100))
4444
})
4545

4646
test_that("if nomatch='force' and verbose=FALSE, run through without a warning", {
@@ -49,14 +49,14 @@ test_that("if nomatch='force' and verbose=FALSE, run through without a warning",
4949
expect_silent(x_synth <- forge(psi_no, evidence = data.frame(Sepal.Length = 100),
5050
nomatch = "force", verbose = FALSE,
5151
n_synth = 10, parallel = FALSE))
52-
expect_true(all(!is.na(x_synth)))
52+
expect_true(all(!is.na(x_synth) & x_synth$Sepal.Length == 100))
5353

5454
# No matching leaf case (finite bounds)
5555
psi_global <- forde(arf, iris, finite_bounds = "global", parallel = FALSE)
5656
expect_silent(x_synth <- forge(psi_global, evidence = data.frame(Sepal.Length = 100),
5757
nomatch = "force", verbose = FALSE,
5858
n_synth = 10, parallel = FALSE))
59-
expect_true(all(!is.na(x_synth)))
59+
expect_true(all(!is.na(x_synth) & x_synth$Sepal.Length == 100))
6060
})
6161

6262
test_that("if nomatch='na' and verbose=TRUE, run through with a warning and return NA", {
@@ -66,15 +66,15 @@ test_that("if nomatch='na' and verbose=TRUE, run through with a warning and retu
6666
nomatch = "na", verbose = TRUE,
6767
n_synth = 10, parallel = FALSE),
6868
"All leaves have zero likelihood for some entered evidence rows\\. This is probably because evidence contains an \\(almost\\) impossible combination\\. Returning NA for those rows \\(can be changed with 'nomatch' argument\\)\\.")
69-
expect_true(all(is.na(x_synth[, -1])))
69+
expect_true(all(is.na(x_synth[, -1]) & x_synth$Sepal.Length == 100))
7070

7171
# No matching leaf case (finite bounds)
7272
psi_global <- forde(arf, iris, finite_bounds = "global", parallel = FALSE)
7373
expect_warning(x_synth <- forge(psi_global, evidence = data.frame(Sepal.Length = 100),
7474
nomatch = "na", verbose = TRUE,
7575
n_synth = 10, parallel = FALSE),
7676
"For some entered evidence rows, no matching leaves could be found\\. This is probably because evidence lies outside of the distribution calculated by FORDE\\. For continuous data, consider setting epsilon>0 or finite_bounds='no' in forde\\(\\)\\. For categorical data, consider setting alpha>0 in forde\\(\\)\\. Returning NA for those rows \\(can be changed with 'nomatch' argument\\)\\.")
77-
expect_true(all(is.na(x_synth[, -1])))
77+
expect_true(all(is.na(x_synth[, -1]) & x_synth$Sepal.Length == 100))
7878
})
7979

8080
test_that("if nomatch='na' and verbose=FALSE, run through without a warning and return NA", {
@@ -83,13 +83,13 @@ test_that("if nomatch='na' and verbose=FALSE, run through without a warning and
8383
expect_silent(x_synth <- forge(psi_no, evidence = data.frame(Sepal.Length = 100),
8484
nomatch = "na", verbose = FALSE,
8585
n_synth = 10, parallel = FALSE))
86-
expect_true(all(is.na(x_synth[, -1])))
86+
expect_true(all(is.na(x_synth[, -1]) & x_synth$Sepal.Length == 100))
8787

8888
# No matching leaf case (finite bounds)
8989
psi_global <- forde(arf, iris, finite_bounds = "global", parallel = FALSE)
9090
expect_silent(x_synth <- forge(psi_global, evidence = data.frame(Sepal.Length = 100),
9191
nomatch = "na", verbose = FALSE,
9292
n_synth = 10, parallel = FALSE))
93-
expect_true(all(is.na(x_synth[, -1])))
93+
expect_true(all(is.na(x_synth[, -1]) & x_synth$Sepal.Length == 100))
9494
})
9595

tests/testthat/test_expct.R

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,35 @@ test_that("expct works with partial sample", {
5151
expect_equal(ncol(res), ncol(iris) - ncol(evi))
5252
expect_equal(colnames(res), colnames(iris)[3:5])
5353
})
54+
55+
test_that("if nomatch='force', do not return NA", {
56+
# Zero likelihood case (no finite bounds)
57+
psi_no <- forde(arf, iris, finite_bounds = "no", parallel = FALSE)
58+
x_synth <- expct(psi_no, evidence = data.frame(Sepal.Length = 100),
59+
nomatch = "force", verbose = FALSE,
60+
parallel = FALSE)
61+
expect_true(all(!is.na(x_synth) & x_synth$Sepal.Length == 100))
62+
63+
# No matching leaf case (finite bounds)
64+
psi_global <- forde(arf, iris, finite_bounds = "global", parallel = FALSE)
65+
x_synth <- expct(psi_global, evidence = data.frame(Sepal.Length = 100),
66+
nomatch = "force", verbose = FALSE,
67+
parallel = FALSE)
68+
expect_true(all(!is.na(x_synth) & x_synth$Sepal.Length == 100))
69+
})
70+
71+
test_that("if nomatch='na', return NA", {
72+
# Zero likelihood case (no finite bounds)
73+
psi_no <- forde(arf, iris, finite_bounds = "no", parallel = FALSE)
74+
x_synth <- expct(psi_no, evidence = data.frame(Sepal.Length = 100),
75+
nomatch = "na", verbose = FALSE,
76+
parallel = FALSE)
77+
expect_true(all(is.na(x_synth[, -1]) & x_synth$Sepal.Length == 100))
78+
79+
# No matching leaf case (finite bounds)
80+
psi_global <- forde(arf, iris, finite_bounds = "global", parallel = FALSE)
81+
x_synth <- expct(psi_global, evidence = data.frame(Sepal.Length = 100),
82+
nomatch = "na", verbose = FALSE,
83+
parallel = FALSE)
84+
expect_true(all(is.na(x_synth[, -1]) & x_synth$Sepal.Length == 100))
85+
})

0 commit comments

Comments
 (0)