Skip to content

Commit 81a15bc

Browse files
Merge pull request #118 from tidymodels/partykit-classification-prediction
update `partykit_tree_info()` to handle classification outputs
2 parents dde75bd + ae6e973 commit 81a15bc

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ S3method(tidypredict_test,party)
3939
S3method(tidypredict_test,randomForest)
4040
S3method(tidypredict_test,ranger)
4141
S3method(tidypredict_test,xgb.Booster)
42+
export(.extract_partykit_classprob)
4243
export(.extract_xgb_trees)
4344
export(acceptable_formula)
4445
export(as_parsed_model)

R/model-partykit.R

+77-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11
partykit_tree_info <- function(model) {
22
model_nodes <- map(seq_along(model), ~ model[[.x]])
33
is_split <- map_lgl(model_nodes, ~ class(.x$node[1]) == "partynode")
4-
# non-cat model
5-
mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"]))
6-
prediction <- ifelse(!is_split, mean_resp, NA)
4+
if (is.numeric(model_nodes[[1]]$fitted[["(response)"]])) {
5+
mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"]))
6+
prediction <- ifelse(!is_split, mean_resp, NA)
7+
} else {
8+
# Return the most frequent value of `x` as a character string.
#
# `x` is tabulated with `table()`, so for a factor every level is counted
# (including levels with zero observations) and `NA`s are ignored.
# Returns `NA_character_` when there is nothing to tabulate.
stat_mode <- function(x) {
  counts <- table(x)
  # An empty table has no mode; return a typed NA instead of failing on
  # an out-of-bounds lookup.
  if (length(counts) == 0) {
    return(NA_character_)
  }
  # `which.max()` returns the first maximum, so ties resolve to the earliest
  # entry in table order (i.e. factor level order). This also works when the
  # table holds a single category, where comparing the top two counts would
  # index past the end of the table.
  names(which.max(counts))
}
16+
mode_resp <- map_chr(model_nodes, ~ stat_mode(.x$fitted[, "(response)"]))
17+
prediction <- ifelse(!is_split, mode_resp, NA)
18+
}
19+
720
party_nodes <- map(seq_along(model), ~ partykit::nodeapply(model, .x))
821

922
kids <- map(party_nodes, ~ {
@@ -88,3 +101,64 @@ tidypredict_fit.party <- function(model) {
88101
parsedmodel <- parse_model(model)
89102
build_fit_formula_rf(parsedmodel)[[1]]
90103
}
104+
105+
# For {orbital}
106+
#' @keywords internal
107+
#' @export
108+
.extract_partykit_classprob <- function(model) {
  # Builds one fitted-value expression per response class. For every node of
  # `model` the weighted class proportions are computed; then, for each class,
  # those proportions are used as the node predictions and turned into an
  # expression via `build_fit_formula_rf()`.
  #
  # Returns a list with one element per class, in the order of the response
  # factor levels.

  # Weighted class proportions for the observations in a single node.
  extract_classprob <- function(node) {
    node_data <- node$fitted
    response <- node_data[["(response)"]]
    weights <- node_data[["(weights)"]]
    # No weights column stored with the fitted data: treat every observation
    # as having unit weight so proportions reduce to plain class frequencies.
    if (is.null(weights)) {
      weights <- rep(1, length(response))
    }

    weights_sum <- tapply(weights, response, sum)
    # Levels with no observations in this node come back as NA from tapply().
    weights_sum[is.na(weights_sum)] <- 0
    res <- weights_sum / sum(weights)
    names(res) <- levels(response)
    res
  }

  # Rows are nodes (in `seq_along(model)` order), columns are classes.
  preds <- map(seq_along(model), function(node_id) {
    extract_classprob(model[[node_id]])
  })
  preds <- matrix(
    unlist(preds),
    nrow = length(preds),
    byrow = TRUE,
    dimnames = list(NULL, names(preds[[1]]))
  )

  # Turn a tree_info table (with its `prediction` column already set) into a
  # single fitted expression.
  generate_one_tree <- function(tree_info) {
    # `[[` extracts the column as a vector regardless of whether tree_info
    # is a data.frame or a tibble.
    terminal_ids <- tree_info$nodeID[tree_info[["terminal"]]]
    paths <- map(terminal_ids, function(node_id) {
      prediction <- tree_info$prediction[tree_info$nodeID == node_id]
      if (is.null(prediction)) cli::cli_abort("Prediction column not found")
      if (is.factor(prediction)) prediction <- as.character(prediction)
      list(
        prediction = prediction,
        path = get_ra_path(node_id, tree_info, FALSE)
      )
    })

    pm <- list()
    pm$general$model <- "party"
    pm$general$type <- "tree"
    pm$general$version <- 2
    pm$trees <- list(paths)
    parsedmodel <- as_parsed_model(pm)

    build_fit_formula_rf(parsedmodel)[[1]]
  }

  tree_info <- partykit_tree_info(model)

  # One expression per class: swap that class's probabilities in as the
  # prediction column and rebuild the tree expression.
  lapply(seq_len(ncol(preds)), function(class_idx) {
    tree_info$prediction <- preds[, class_idx]
    generate_one_tree(tree_info)
  })
}

0 commit comments

Comments
 (0)