Skip to content

Commit

Permalink
Merge pull request #217 from ven-com/GL-3551-parallel-execution
Browse files Browse the repository at this point in the history
- Multi-core plotting of Pareto-optimum models
- Replace doFuture with doParallel for improved performance
  • Loading branch information
gufengzhou authored Nov 25, 2021
2 parents eb7b086 + 2253025 commit 73b39b7
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 60 deletions.
2 changes: 0 additions & 2 deletions R/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@ Depends:
Imports:
data.table,
doParallel,
doFuture,
doRNG,
foreach,
future,
ggplot2,
ggridges,
glmnet,
Expand Down
6 changes: 1 addition & 5 deletions R/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,13 @@ export(robyn_save)
export(saturation_hill)
import(data.table)
import(ggplot2)
importFrom(doFuture,registerDoFuture)
importFrom(doParallel,registerDoParallel)
importFrom(doParallel,stopImplicitCluster)
importFrom(doRNG,"%dorng%")
importFrom(foreach,"%dopar%")
importFrom(foreach,foreach)
importFrom(foreach,getDoParWorkers)
importFrom(foreach,registerDoSEQ)
importFrom(future,availableCores)
importFrom(future,multicore)
importFrom(future,plan)
importFrom(future,sequential)
importFrom(ggridges,geom_density_ridges)
importFrom(glmnet,cv.glmnet)
importFrom(glmnet,glmnet)
Expand Down
6 changes: 2 additions & 4 deletions R/R/imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
#' @author Antonio Prada (aprada@@fb.com)
#' @author Igor Skokan (igorskokan@@fb.com)
#' @import data.table
#' @importFrom doFuture registerDoFuture
#' @importFrom doRNG %dorng%
#' @importFrom doParallel registerDoParallel
#' @importFrom doParallel registerDoParallel stopImplicitCluster
#' @importFrom foreach foreach %dopar% getDoParWorkers registerDoSEQ
#' @importFrom future multicore plan sequential availableCores
#' @import ggplot2
#' @importFrom ggridges geom_density_ridges
#' @importFrom glmnet cv.glmnet glmnet
Expand Down Expand Up @@ -56,7 +54,7 @@ dt_vars <- c(
"optmResponseUnitTotalLift", "optmSpendUnit", "optmSpendUnitTotalDelta", "param",
"perc", "percentage", "pos", "predicted", "refreshStatus", "response", "rn", "robynPareto",
"roi", "roi_mean", "roi_total", "rsq_lm", "rsq_nls", "rsq_train", "s0", "scale_shape_halflife",
"season", "sequential", "shape", "solID", "spend", "spend_share", "spend_share_refresh",
"season", "shape", "solID", "spend", "spend_share", "spend_share_refresh",
"theta", "theta_halflife", "total_spend", "trend", "trial", "type", "value", "variable",
"weekday", "x", "xDecompAgg", "xDecompMeanNon0", "xDecompMeanNon0Perc",
"xDecompMeanNon0PercRF", "xDecompMeanNon0RF", "xDecompPerc", "xDecompPercRF", "y", "yhat",
Expand Down
138 changes: 89 additions & 49 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,24 @@ robyn_run <- function(InputCollect,

t0 <- Sys.time()

# enable parallelisation of main modelling loop for MacOS and Linux only
parallel_processing <- .Platform$OS.type == "unix"
if (parallel_processing) {
message(paste(
"Using", InputCollect$adstock, "adstocking with",
length(InputCollect$hyperparameters),
"hyperparameters & 10-fold ridge x-validation on",
InputCollect$cores, "cores"
))
} else {
message(paste(
"Using", InputCollect$adstock, "adstocking with",
length(InputCollect$hyperparameters),
"hyperparameters & 10-fold ridge x-validation on 1 core (Windows fallback)"
))
}

# ng_collect <- list()
message(paste(
"Using", InputCollect$adstock, "adstocking with",
length(InputCollect$hyperparameters),
"hyperparameters & 10-fold ridge x-validation on",
InputCollect$cores, "cores"
))
model_output_collect <- list()

message(paste(
Expand Down Expand Up @@ -240,11 +251,10 @@ robyn_run <- function(InputCollect,
#decompSpendDist <- decompSpendDist[xDecompAgg[rn %in% InputCollect$paid_media_vars, .(rn, xDecompAgg, solID)], on = c("rn", "solID")]

## get mean_response
registerDoFuture()
if (.Platform$OS.type == "unix") {
plan(multicore, workers = InputCollect$cores)
if (parallel_processing) {
registerDoParallel(InputCollect$cores)
} else {
plan(sequential)
registerDoSEQ()
}

# if (hyper_fixed == FALSE) {pb <- txtProgressBar(min=1, max = length(decompSpendDist$rn), style = 3)}
Expand All @@ -270,6 +280,7 @@ robyn_run <- function(InputCollect,
return(dt_resp)
}
#if (hyper_fixed == FALSE) close(pb)
stopImplicitCluster()
registerDoSEQ()
getDoParWorkers()

Expand Down Expand Up @@ -454,9 +465,22 @@ robyn_run <- function(InputCollect,
#####################################
#### Plot each pareto solution

# ggplot doesn't work with process forking on MacOS
# however it works fine on Linux and Windows
parallel_plotting <- Sys.info()["sysname"] != "Darwin"

if (plot_pareto) {
message(paste(">>> Plotting", num_pareto123, "Pareto optimum models..."))
pbplot <- txtProgressBar(max = num_pareto123, style = 3)
if (parallel_plotting) {
message(paste(">>> Plotting", num_pareto123, "Pareto optimum models on", InputCollect$cores, "cores..."))
} else {
message(paste(">>> Plotting", num_pareto123, "Pareto optimum models on 1 core (MacOS fallback)..."))
}
}

if (parallel_plotting) {
registerDoParallel(InputCollect$cores)
} else {
registerDoSEQ()
}

all_fronts <- unique(xDecompAgg$robynPareto)
Expand All @@ -466,6 +490,8 @@ robyn_run <- function(InputCollect,
}

cnt <- 0
pbplot <- txtProgressBar(max = num_pareto123, style = 3)

mediaVecCollect <- list()
xDecompVecCollect <- list()
meanResponseCollect <- list()
Expand All @@ -474,11 +500,10 @@ robyn_run <- function(InputCollect,
plotWaterfall <- xDecompAgg[robynPareto == pf]
uniqueSol <- plotMediaShare[, unique(solID)]

for (j in 1:length(uniqueSol)) {
cnt <- cnt + 1
parallelResult <- foreach(sid = uniqueSol) %dorng% {

## plot spend x effect share comparison
plotMediaShareLoop <- plotMediaShare[solID == uniqueSol[j]]
plotMediaShareLoop <- plotMediaShare[solID == sid]
rsq_train_plot <- plotMediaShareLoop[, round(unique(rsq_train), 4)]
nrmse_plot <- plotMediaShareLoop[, round(unique(nrmse), 4)]
decomp_rssd_plot <- plotMediaShareLoop[, round(unique(decomp.rssd), 4)]
Expand Down Expand Up @@ -517,7 +542,7 @@ robyn_run <- function(InputCollect,
)

## plot waterfall
plotWaterfallLoop <- plotWaterfall[solID == uniqueSol[j]][order(xDecompPerc)]
plotWaterfallLoop <- plotWaterfall[solID == sid][order(xDecompPerc)]
plotWaterfallLoop[, end := cumsum(xDecompPerc)]
plotWaterfallLoop[, end := 1 - end]
plotWaterfallLoop[, ":="(start = shift(end, fill = 1, type = "lag"),
Expand Down Expand Up @@ -549,7 +574,7 @@ robyn_run <- function(InputCollect,

## plot adstock rate

resultHypParamLoop <- resultHypParam[solID == uniqueSol[j]]
resultHypParamLoop <- resultHypParam[solID == sid]
hypParam <- unlist(resultHypParamLoop[, local_name, with = FALSE])

if (InputCollect$adstock == "geometric") {
Expand Down Expand Up @@ -738,7 +763,7 @@ robyn_run <- function(InputCollect,
dt_scurvePlotMean[channel == get_med, mean_response := get_response * coef]
dt_scurvePlotMean[channel == get_med, next_unit_response := get_response_marginal * coef - mean_response]
}
dt_scurvePlotMean[, solID := uniqueSol[j]]
dt_scurvePlotMean[, solID := sid]

p4 <- ggplot(data = dt_scurvePlot[channel %in% InputCollect$paid_media_vars], aes(x = spend, y = response, color = channel)) +
geom_line() +
Expand Down Expand Up @@ -766,7 +791,7 @@ robyn_run <- function(InputCollect,
col_order <- c("ds", "dep_var", InputCollect$all_ind_vars)
setcolorder(dt_transformDecomp, neworder = col_order)

xDecompVec <- dcast.data.table(xDecompAgg[solID == uniqueSol[j], .(rn, coef, solID)], solID ~ rn, value.var = "coef")
xDecompVec <- dcast.data.table(xDecompAgg[solID == sid, .(rn, coef, solID)], solID ~ rn, value.var = "coef")
if (!("(Intercept)" %in% names(xDecompVec))) {
xDecompVec[, "(Intercept)" := 0]
}
Expand All @@ -780,7 +805,7 @@ robyn_run <- function(InputCollect,
coefs = xDecompVec[, !c("solID", "(Intercept)")]
))
xDecompVec[, intercept := intercept]
xDecompVec[, ":="(depVarHat = rowSums(xDecompVec), solID = uniqueSol[j])]
xDecompVec[, ":="(depVarHat = rowSums(xDecompVec), solID = sid)]
xDecompVec <- cbind(dt_transformDecomp[, .(ds, dep_var)], xDecompVec)

xDecompVecPlot <- xDecompVec[, .(ds, dep_var, depVarHat)]
Expand Down Expand Up @@ -810,7 +835,7 @@ robyn_run <- function(InputCollect,

## save and aggregate one-pager plots

onepagerTitle <- paste0("Model one-pager, on pareto front ", pf, ", ID: ", uniqueSol[j])
onepagerTitle <- paste0("Model one-pager, on pareto front ", pf, ", ID: ", sid)

pg <- wrap_plots(p2, p5, p1, p4, p3, p6, ncol = 2) +
plot_annotation(title = onepagerTitle, theme = theme(plot.title = element_text(hjust = 0.5)))
Expand All @@ -819,12 +844,10 @@ robyn_run <- function(InputCollect,
# grid.draw(pg)
if (plot_pareto) {
ggsave(
filename = paste0(plot_folder, "/", plot_folder_sub, "/", uniqueSol[j], ".png"),
filename = paste0(plot_folder, "/", plot_folder_sub, "/", sid, ".png"),
plot = pg,
dpi = 600, width = 18, height = 18
)

setTxtProgressBar(pbplot, cnt)
}

## prepare output
Expand All @@ -834,21 +857,37 @@ robyn_run <- function(InputCollect,
dt_transformSaturationSpendReverse[, (InputCollect$organic_vars) := NA]
}

return(list(
mediaVecCollect = rbind(
dt_transformPlot[, ":="(type = "rawMedia", solID = sid)],
dt_transformSpend[, ":="(type = "rawSpend", solID = sid)],
dt_transformSpendMod[, ":="(type = "predictedExposure", solID = sid)],
dt_transformAdstock[, ":="(type = "adstockedMedia", solID = sid)],
dt_transformSaturation[, ":="(type = "saturatedMedia", solID = sid)],
dt_transformSaturationSpendReverse[, ":="(type = "saturatedSpendReversed", solID = sid)],
dt_transformSaturationDecomp[, ":="(type = "decompMedia", solID = sid)]
),
xDecompVecCollect = xDecompVec,
meanResponseCollect = dt_scurvePlotMean
))
} # end solution loop

mediaVecCollect[[cnt]] <- rbind(
dt_transformPlot[, ":="(type = "rawMedia", solID = uniqueSol[j])],
dt_transformSpend[, ":="(type = "rawSpend", solID = uniqueSol[j])],
dt_transformSpendMod[, ":="(type = "predictedExposure", solID = uniqueSol[j])],
dt_transformAdstock[, ":="(type = "adstockedMedia", solID = uniqueSol[j])],
dt_transformSaturation[, ":="(type = "saturatedMedia", solID = uniqueSol[j])],
dt_transformSaturationSpendReverse[, ":="(type = "saturatedSpendReversed", solID = uniqueSol[j])],
dt_transformSaturationDecomp[, ":="(type = "decompMedia", solID = uniqueSol[j])]
)
cnt <- cnt + length(uniqueSol)
setTxtProgressBar(pbplot, cnt)

xDecompVecCollect[[cnt]] <- xDecompVec
meanResponseCollect[[cnt]] <- dt_scurvePlotMean
} # end solution loop
# append parallel run results
mediaVecCollect <- append(mediaVecCollect, lapply(parallelResult, function (x) x$mediaVecCollect))
xDecompVecCollect <- append(xDecompVecCollect, lapply(parallelResult, function (x) x$xDecompVecCollect))
meanResponseCollect <- append(meanResponseCollect, lapply(parallelResult, function (x) x$meanResponseCollect))
} # end pareto front loop

close(pbplot)

if (parallel_plotting) {
# stop cluster to avoid memory leaks
stopImplicitCluster()
}

mediaVecCollect <- rbindlist(mediaVecCollect)
xDecompVecCollect <- rbindlist(xDecompVecCollect)
meanResponseCollect <- rbindlist(meanResponseCollect)
Expand Down Expand Up @@ -1076,6 +1115,17 @@ robyn_mmm <- function(hyper_collect,
}
# assign("InputCollect", InputCollect, envir = .GlobalEnv) # adding this to enable InputCollect reading during parallel
# opts <- list(progress = function(n) setTxtProgressBar(pb, n))

# enable parallelisation of main modelling loop for MacOS and Linux only
parallel_processing <- .Platform$OS.type == "unix"

# create cluster before big for-loop to minimize overhead for parallel backend registering
if (parallel_processing) {
registerDoParallel(InputCollect$cores)
} else {
registerDoSEQ()
}

sysTimeDopar <- system.time({
for (lng in 1:iterNG) { # lng = 1
nevergrad_hp <- list()
Expand Down Expand Up @@ -1120,17 +1170,7 @@ robyn_mmm <- function(hyper_collect,
nrmse.collect <- c()
decomp.rssd.collect <- c()
best_mape <- Inf
# registerDoParallel(cores) #registerDoParallel(cores=InputCollect$cores)

registerDoFuture()
if (.Platform$OS.type == "unix") {
plan(multicore, workers = cores)
} else {
plan(sequential)
}

# nbrOfWorkers()
getDoParWorkers()
doparCollect <- suppressPackageStartupMessages(
foreach(i = 1:iterPar) %dorng% { # i = 1
t1 <- Sys.time()
Expand Down Expand Up @@ -1406,13 +1446,10 @@ robyn_mmm <- function(hyper_collect,
}
) # end foreach parallel

# stopImplicitCluster()

nrmse.collect <- sapply(doparCollect, function(x) x$nrmse)
decomp.rssd.collect <- sapply(doparCollect, function(x) x$decomp.rssd)
mape.lift.collect <- sapply(doparCollect, function(x) x$mape.lift)


#####################################
#### Nevergrad tells objectives

Expand All @@ -1436,6 +1473,9 @@ robyn_mmm <- function(hyper_collect,

message("\n Finished in ", round(sysTimeDopar[3] / 60, 2), " mins")

# stop cluster to avoid memory leaks
stopImplicitCluster()

if (hyper_fixed == FALSE) close(pb)
registerDoSEQ()
getDoParWorkers()
Expand Down

0 comments on commit 73b39b7

Please sign in to comment.