This repository was archived by the owner on May 12, 2021. It is now read-only.

Commit b0b8a6f

Author: Chris Wewerka
Parent: 998938b

[PIO-193] Async support for predict and storage access, blocking code wrapped in blocking construct

File tree: 18 files changed, +597 -544 lines

core/src/main/scala/org/apache/predictionio/controller/LAlgorithm.scala (+8 -3)

@@ -24,6 +24,9 @@ import org.apache.predictionio.workflow.PersistentModelManifest
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 
+import scala.concurrent.duration._
+import scala.concurrent.{Await, ExecutionContext, Future, blocking}
+import scala.language.postfixOps
 import scala.reflect._
 
 /** Base class of a local algorithm.
@@ -72,11 +75,13 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
     val glomQs: RDD[Array[(Long, Q)]] = qs.glom()
     val cartesian: RDD[(M, Array[(Long, Q)])] = mRDD.cartesian(glomQs)
     cartesian.flatMap { case (m, qArray) =>
-      qArray.map { case (qx, q) => (qx, predict(m, q)) }
+      qArray.map {
+        case (qx, q) => (qx, Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes))
+      }
     }
   }
 
-  def predictBase(localBaseModel: Any, q: Q): P = {
+  def predictBase(localBaseModel: Any, q: Q)(implicit ec: ExecutionContext): Future[P] = {
     predict(localBaseModel.asInstanceOf[M], q)
   }
 
@@ -87,7 +92,7 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
    * @param q An input query.
    * @return A prediction.
    */
-  def predict(m: M, q: Q): P
+  def predict(m: M, q: Q)(implicit ec: ExecutionContext): Future[P]
 
   /** :: DeveloperApi ::
    * Engine developers should not use this directly (read on to see how local
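
For engine developers, the practical effect is that LAlgorithm.predict now takes an implicit ExecutionContext and returns a Future. Below is a minimal sketch of what an implementation might look like after this commit, with blocking work wrapped in the blocking construct as the commit message describes. The MyTrainingData/MyModel/MyQuery/MyPrediction types and the score call are illustrative placeholders, not part of this commit.

import org.apache.predictionio.controller.LAlgorithm
import scala.concurrent.{ExecutionContext, Future, blocking}

// Illustrative engine types (placeholders, not from this commit).
class MyTrainingData extends Serializable
class MyModel extends Serializable { def score(userId: String): Double = 0.0 }
case class MyQuery(userId: String)
case class MyPrediction(score: Double)

class MyAlgorithm extends LAlgorithm[MyTrainingData, MyModel, MyQuery, MyPrediction] {

  def train(pd: MyTrainingData): MyModel = new MyModel

  // New contract: return a Future and accept an implicit ExecutionContext.
  // Blocking calls (storage lookups, JDBC, etc.) are wrapped in `blocking`
  // so the thread pool can compensate for the stalled thread.
  def predict(m: MyModel, q: MyQuery)(implicit ec: ExecutionContext): Future[MyPrediction] =
    Future {
      blocking {
        MyPrediction(m.score(q.userId))
      }
    }
}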

core/src/main/scala/org/apache/predictionio/controller/P2LAlgorithm.scala (+6 -3)

@@ -25,6 +25,9 @@ import org.apache.spark.SparkContext
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 
+import scala.concurrent.duration._
+import scala.concurrent.{Await, ExecutionContext, Future, blocking}
+import scala.language.postfixOps
 import scala.reflect._
 
 /** Base class of a parallel-to-local algorithm.
@@ -67,10 +70,10 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
    * @return Batch of predicted results
    */
   def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] = {
-    qs.mapValues { q => predict(m, q) }
+    qs.mapValues { q => Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) }
   }
 
-  def predictBase(bm: Any, q: Q): P = predict(bm.asInstanceOf[M], q)
+  def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] = predict(bm.asInstanceOf[M], q)
 
   /** Implement this method to produce a prediction from a query and trained
    * model.
@@ -79,7 +82,7 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
    * @param query An input query.
    * @return A prediction.
    */
-  def predict(model: M, query: Q): P
+  def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P]
 
   /** :: DeveloperApi ::
    * Engine developers should not use this directly (read on to see how
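
Because predict now returns a Future, an implementation that does no I/O does not have to schedule work on the execution context at all: a purely in-memory P2LAlgorithm can satisfy the new signature with Future.successful. A hedged sketch with placeholder types (ItemModel/ItemQuery/ItemScore are illustrative, not from this commit):

import scala.concurrent.{ExecutionContext, Future}

// Illustrative placeholder types.
class ItemModel(scores: Map[String, Double]) extends Serializable {
  def lookup(id: String): Double = scores.getOrElse(id, 0.0)
}
case class ItemQuery(itemId: String)
case class ItemScore(score: Double)

object InMemoryPredictSketch {
  // Inside a P2LAlgorithm[PD, ItemModel, ItemQuery, ItemScore] subclass this
  // would be the predict override; no thread is blocked, the Future is
  // already completed when it is returned.
  def predict(model: ItemModel, query: ItemQuery)(implicit ec: ExecutionContext): Future[ItemScore] =
    Future.successful(ItemScore(model.lookup(query.itemId)))
}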

core/src/main/scala/org/apache/predictionio/controller/PAlgorithm.scala (+4 -2)

@@ -24,6 +24,8 @@ import org.apache.predictionio.workflow.PersistentModelManifest
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 
+import scala.concurrent.{ExecutionContext, Future}
+
 /** Base class of a parallel algorithm.
  *
  * A parallel algorithm can be run in parallel on a cluster and produces a
@@ -72,7 +74,7 @@ abstract class PAlgorithm[PD, M, Q, P]
   def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] =
     throw new NotImplementedError("batchPredict not implemented")
 
-  def predictBase(baseModel: Any, query: Q): P = {
+  def predictBase(baseModel: Any, query: Q)(implicit ec: ExecutionContext): Future[P] = {
     predict(baseModel.asInstanceOf[M], query)
   }
 
@@ -83,7 +85,7 @@ abstract class PAlgorithm[PD, M, Q, P]
    * @param query An input query.
    * @return A prediction.
    */
-  def predict(model: M, query: Q): P
+  def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P]
 
   /** :: DeveloperApi ::
    * Engine developers should not use this directly (read on to see how parallel
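
On the calling side, code that still needs a plain value, such as the batchPredict helpers above that run inside Spark closures, forces the Future with Await.result and a long timeout, mirroring the 60-minute timeout this commit uses. A reduced standalone sketch of that pattern (the stand-in predict below is illustrative, not the real API):

import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.postfixOps

object AwaitSketch {
  implicit val ec: ExecutionContext = scala.concurrent.ExecutionContext.global

  // Stand-in for any Future-returning predict.
  def predict(q: String): Future[Int] = Future.successful(q.length)

  // Inside an RDD transformation that must produce a plain value, the Future
  // is awaited with a generous timeout, as the updated batchPredict methods do.
  def predictSync(q: String): Int = Await.result(predict(q), 60 minutes)
}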

core/src/main/scala/org/apache/predictionio/core/BaseAlgorithm.scala (+3 -1)

@@ -26,6 +26,8 @@ import net.jodah.typetools.TypeResolver
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 
+import scala.concurrent.{ExecutionContext, Future}
+
 /** :: DeveloperApi ::
   * Base trait with default custom query serializer, exposed to engine developer
   * via [[org.apache.predictionio.controller.CustomQuerySerializer]]
@@ -90,7 +92,7 @@ abstract class BaseAlgorithm[PD, M, Q, P]
    * @return Predicted result
    */
   @DeveloperApi
-  def predictBase(bm: Any, q: Q): P
+  def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P]
 
   /** :: DeveloperApi ::
    * Engine developers should not use this directly. Prepare a model for

core/src/main/scala/org/apache/predictionio/workflow/BatchPredict.scala (+22 -14)

@@ -32,7 +32,12 @@ import org.apache.predictionio.workflow.CleanupFunctions
 import org.apache.spark.rdd.RDD
 import org.json4s._
 import org.json4s.native.JsonMethods._
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.concurrent.blocking
+import scala.concurrent.{Await, Future}
 import scala.language.existentials
+import scala.concurrent.ExecutionContext.Implicits.global
 
 case class BatchPredictConfig(
   inputFilePath: String = "batchpredict-input.json",
@@ -207,23 +212,26 @@ object BatchPredict extends Logging {
       // Deploy logic. First call Serving.supplement, then Algo.predict,
       // finally Serving.serve.
       val supplementedQuery = serving.supplementBase(query)
-      // TODO: Parallelize the following.
-      val predictions = algorithms.zip(models).map { case (a, m) =>
+      val predictionsFuture = Future.sequence(algorithms.zip(models).map { case (a, m) =>
         a.predictBase(m, supplementedQuery)
-      }
+      })
       // Notice that it is by design to call Serving.serve with the
       // *original* query.
-      val prediction = serving.serveBase(query, predictions)
-      // Combine query with prediction, so the batch results are
-      // self-descriptive.
-      val predictionJValue = JsonExtractor.toJValue(
-        jsonExtractorOption,
-        Map("query" -> query,
-          "prediction" -> prediction),
-        algorithms.head.querySerializer,
-        algorithms.head.gsonTypeAdapterFactories)
-      // Return JSON string
-      compact(render(predictionJValue))
+      val predFutureRdds = predictionsFuture.map {
+        predictions =>
+          val prediction = serving.serveBase(query, predictions)
+          // Combine query with prediction, so the batch results are
+          // self-descriptive.
+          val predictionJValue = JsonExtractor.toJValue(
+            jsonExtractorOption,
+            Map("query" -> query,
+              "prediction" -> prediction),
+            algorithms.head.querySerializer,
+            algorithms.head.gsonTypeAdapterFactories)
+          // Return JSON string
+          compact(render(predictionJValue))
+      }
+      Await.result(predFutureRdds, 60 minutes)
     }
 
     predictionsRDD.saveAsTextFile(config.outputFilePath)
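
The per-query flow above now fans out one Future per algorithm, combines them with Future.sequence, maps serving and JSON rendering over the combined Future, and blocks only at the very end. A reduced sketch of the same shape with stand-in types (predictBase and the rendered JSON string below are simplified placeholders, not the real PredictionIO serving/JSON API):

import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{Await, Future}
import scala.language.postfixOps

object SequenceSketch {
  // Stand-in for an algorithm whose predictBase now returns a Future.
  def predictBase(algoId: Int, query: String): Future[Double] =
    Future.successful(algoId + query.length.toDouble)

  def servePerQuery(query: String, algoIds: Seq[Int]): String = {
    // One Future per algorithm, combined into a single Future of all results.
    val predictionsFuture: Future[Seq[Double]] =
      Future.sequence(algoIds.map(id => predictBase(id, query)))
    // Serving and rendering run once all predictions have arrived.
    val rendered: Future[String] = predictionsFuture.map { predictions =>
      s"""{"query":"$query","prediction":${predictions.sum}}"""
    }
    // The surrounding RDD transformation still needs a plain String, so block here.
    Await.result(rendered, 60 minutes)
  }
}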
