Skip to content

Commit

Permalink
Add plotLIFT and plotCGain
Browse files Browse the repository at this point in the history
  • Loading branch information
agosiewska committed Mar 22, 2018
1 parent d6d71a3 commit ca79d16
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 2 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Imports: car,
ggplot2,
hnp,
plotROC,
ROCR,
tseries
RoxygenNote: 6.0.1
Suggests: aods3,
Expand Down
6 changes: 5 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method(plot,modelAudit)
export(audit)
export(plot.modelAudit)
export(plotACF)
export(plotAutocorrelation)
export(plotCGains)
export(plotCook)
export(plotHalfNormal)
export(plotLIFT)
export(plotREC)
export(plotROC)
export(plotRROC)
Expand All @@ -17,6 +19,8 @@ export(scoreDW)
export(scoreGQ)
export(scoreHalfNormal)
export(scoreRuns)
import(ROCR)
import(dplyr)
import(ggplot2)
import(plotROC)
importFrom(car,durbinWatsonTest)
Expand Down
71 changes: 71 additions & 0 deletions R/plotCGains.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#' @title Cumulative Gains Chart
#'
#' @description Cumulative Gains Chartis a plot of the rate of positive prediction against true positive rate for the different thresholds.
#' It is useful for measuring and comparing the accuracy of the classificators.
#' @param object An object of class ModelAudit
#' @param newdata optionally, a data frame in which to look for variables with which to plot CGains curve. If omitted, the data used to build model will be used.
#' @param newy optionally, required if newdata used. Response vector for new data.
#' @param ... other modelAudit objects to be plotted together
#'
#' @return ggplot object
#'
#' @seealso \code{\link{plot.modelAudit}}
#'
#' @import ggplot2
#' @import ROCR
#'
#' @examples
#' library(auditor)
#' library(mlbench)
#' library(randomForest)
#' data("PimaIndiansDiabetes")
#'
#' model_rf <- randomForest(diabetes~., data=PimaIndiansDiabetes)
#' au_rf <- audit(model_rf, label="rf")
#' plotCGains(au_rf)
#'
#' model_glm <- glm(diabetes~., family=binomial, data=PimaIndiansDiabetes)
#' au_glm <- audit(model_glm)
#' plotCGains(au_rf, au_glm)
#'
#' @export


plotCGains <- function(object, ..., newdata = NULL, newy){
if(class(object)!="modelAudit") stop("plotCGains requires object class modelAudit.")
rpp <- tpr <- label <- NULL
df <- getCGainsDF(object, newdata, newy)

dfl <- list(...)
if (length(dfl) > 0) {
for (resp in dfl) {
if(class(resp)=="modelAudit"){
df <- rbind( df, getCGainsDF(resp, newdata, newy) )
}
}
}

ggplot(df, aes(x = rpp, y = tpr, color = label)) +
geom_line() +
xlab("Rate of Positive Prediction") +
ylab("True Positive Rate") +
theme_light()
}

getCGainsDF <- function(object, newdata, newy){
if (is.null(newdata)) {
predictions <- object$fitted.values
y <- object$y
} else {
if(is.null(newy)) stop("newy must be provided.")
predictions <- object$predict.function(object$model, newdata)
y <- newy
}

pred <- prediction(predictions, y)
gain <- performance(pred, "tpr", "rpp")

res <- data.frame(rpp = gain@x.values[[1]], tpr = gain@y.values[[1]], alpha = gain@alpha.values[[1]],
label = object$label)
return(res)
}
84 changes: 84 additions & 0 deletions R/plotLift.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#' @title Lift Chart
#'
#' @description Lift Chart shows the ratio of a model to a random guess.
#'
#' @param object An object of class ModelAudit
#' @param newdata optionally, a data frame in which to look for variables with which to plot CGains curve. If omitted, the data used to build model will be used.
#' @param newy optionally, required if newdata used. Response vector for new data.
#' @param groups number of groups
#' @param cumulative boolean. If TRUE cumulative lift curve will be plotted.
#' @param ... other modelAudit objects to be plotted together
#'
#' @return ggplot object
#'
#' @seealso \code{\link{plot.modelAudit}}
#'
#' @import ggplot2
#' @import dplyr
#'
#' @examples
#' library(auditor)
#' library(mlbench)
#' library(randomForest)
#' data("PimaIndiansDiabetes")
#'
#' model_rf <- randomForest(diabetes~., data=PimaIndiansDiabetes)
#' au_rf <- audit(model_rf, label="rf")
#' plotLIFT(au_rf)
#'
#' model_glm <- glm(diabetes~., family=binomial, data=PimaIndiansDiabetes)
#' au_glm <- audit(model_glm)
#' plotLIFT(au_rf, au_glm)
#'
#' @export


plotLIFT <- function(object, ..., newdata = NULL, newy, groups = 10, cumulative = TRUE){
if(class(object)!="modelAudit") stop("plotCGains requires object class modelAudit.")
depth <- lift <- label <- NULL
df <- getLIFTDF(object, newdata, newy, groups, cumulative)

dfl <- list(...)
if (length(dfl) > 0) {
for (resp in dfl) {
if(class(resp)=="modelAudit"){
df <- rbind( df, getLIFTDF(resp, newdata, newy, groups, cumulative) )
}
}
}

ggplot(df, aes(x = depth, y = lift, color = label)) +
geom_line() +
xlab("Percentage of observations") +
ylab("Lift") +
theme_light()
}

getLIFTDF <- function(object, newdata, newy, n.groups, cumulative = TRUE){
pred <- NULL
if (is.null(newdata)) {
predictions <- object$fitted.values
y <- as.numeric(as.character(object$y))
} else {
if(is.null(newy)) stop("newy must be provided.")
predictions <- object$predict.function(object$model, newdata)
y <- as.numeric(as.character(newy))
}
df <- data.frame(pred=predictions, y=y)
df <- arrange(df, desc(pred))

group <- ceiling(seq_along(df[,2])/floor(nrow(df)/n.groups))

cap <- floor(nrow(df)/n.groups) * n.groups
df <- stats::aggregate(df[1:cap,2], by=list(group[1:cap]), mean)

if (cumulative==TRUE) {
df[,2] <- cumsum(df[,2])/seq_along(df[,2])
}
colnames(df) <- c("depth", "lift")
df$lift <- df$lift/mean(y)
df$depth <- 100* df$depth / n.groups
df$label <- object$label
return(df)
}

2 changes: 1 addition & 1 deletion man/plot.modelAudit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions man/plotCGains.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 46 additions & 0 deletions man/plotLIFT.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ca79d16

Please sign in to comment.