Skip to content

Commit a1d4522

Browse files
author
Chris Wewerka
committed
support for Future-based predict-method, see apache/predictionio#495
1 parent 5445b93 commit a1d4522

File tree

4 files changed

+167
-126
lines changed

4 files changed

+167
-126
lines changed

build.sbt

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ scalaVersion in ThisBuild := "2.11.11"
1414

1515
val mahoutVersion = "0.13.0"
1616

17-
val pioVersion = "0.12.0-incubating"
17+
val pioVersion = "0.14.0-SNAPSHOT"
1818

1919
val elasticsearchVersion = "5.5.2"
2020

src/main/scala/EsClient.scala

+18-12
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ import java.util
2222

2323
import grizzled.slf4j.Logger
2424
import org.apache.http.util.EntityUtils
25-
import org.apache.predictionio.data.storage.{ DataMap, Storage, StorageClientConfig }
25+
import org.apache.predictionio.data.storage.{DataMap, Storage, StorageClientConfig}
2626
import org.apache.predictionio.workflow.CleanupFunctions
2727
import org.apache.spark.SparkContext
2828
import org.apache.spark.rdd.RDD
2929
import org.elasticsearch.client.RestClient
3030
import org.apache.http.HttpHost
31-
import org.apache.http.auth.{ AuthScope, UsernamePasswordCredentials }
31+
import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials}
3232
import org.apache.http.entity.ContentType
3333
import org.apache.http.entity.StringEntity
3434
import org.apache.http.impl.client.BasicCredentialsProvider
@@ -42,6 +42,9 @@ import org.elasticsearch.spark._
4242
import org.json4s.JValue
4343
import org.json4s.DefaultFormats
4444
import org.json4s.JsonAST.JString
45+
import ScalaRestClient.ExtendedScalaRestClient
46+
47+
import scala.concurrent.{ExecutionContext, Future}
4548
// import org.json4s.native.Serialization.writePretty
4649
import com.actionml.helpers.{ ItemID, ItemProps }
4750

@@ -367,20 +370,23 @@ object EsClient {
367370
* @param indexName the index to search
368371
* @return a [PredictedResults] collection
369372
*/
370-
def search(query: String, indexName: String): Option[JValue] = {
373+
def search(query: String, indexName: String)(implicit ec: ExecutionContext): Future[Option[JValue]] = {
371374
logger.info(s"Query:\n${query}")
372-
val response = client.performRequest(
375+
val responseFuture = client.performRequestFuture(
373376
"POST",
374377
s"/$indexName/_search",
375-
Map.empty[String, String].asJava,
378+
Map.empty[String, String],
376379
new StringEntity(query, ContentType.APPLICATION_JSON))
377-
response.getStatusLine.getStatusCode match {
378-
case 200 =>
379-
logger.info(s"Got source from query: ${query}")
380-
Some(parse(EntityUtils.toString(response.getEntity)))
381-
case _ =>
382-
logger.info(s"Query: ${query}\nproduced status code: ${response.getStatusLine.getStatusCode}")
383-
None
380+
responseFuture.map {
381+
response =>
382+
response.getStatusLine.getStatusCode match {
383+
case 200 =>
384+
logger.info(s"Got source from query: ${query}")
385+
Some(parse(EntityUtils.toString(response.getEntity)))
386+
case _ =>
387+
logger.info(s"Query: ${query}\nproduced status code: ${response.getStatusLine.getStatusCode}")
388+
None
389+
}
384390
}
385391
}
386392

src/main/scala/ScalaRestClient.scala

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.actionml
2+
3+
import org.apache.http.{Header, HttpEntity}
4+
import org.elasticsearch.client.{Response, ResponseListener, RestClient}
5+
import scala.collection.JavaConverters._
6+
import scala.concurrent.{Future, Promise}
7+
8+
object ScalaRestClient {
9+
10+
implicit class ExtendedScalaRestClient(restClient: RestClient) {
11+
12+
def performRequestFuture(method: String, endpoint: String, params: Map[String, String],
13+
entity: HttpEntity, headers: Header*): Future[Response] = {
14+
val promise: Promise[Response] = Promise()
15+
val responseListener = new ResponseListener {
16+
override def onSuccess(response: Response): Unit = promise.success(response)
17+
override def onFailure(exception: Exception): Unit = promise.failure(exception)
18+
}
19+
restClient.performRequestAsync(method, endpoint, params.asJava, entity, responseListener, headers: _*)
20+
promise.future
21+
}
22+
}
23+
}

src/main/scala/URAlgorithm.scala

+125-113
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ package com.actionml
2020
import java.util
2121

2222
import grizzled.slf4j.Logger
23-
import org.apache.predictionio.controller.{ P2LAlgorithm, Params }
24-
import org.apache.predictionio.data.storage.{ DataMap, Event, NullModel, PropertyMap }
23+
import org.apache.predictionio.controller.{P2LAlgorithm, Params}
24+
import org.apache.predictionio.data.storage.{DataMap, Event, NullModel, PropertyMap}
2525
import org.apache.predictionio.data.store.LEventStore
26-
import org.apache.mahout.math.cf.{ DownsamplableCrossOccurrenceDataset, SimilarityAnalysis }
26+
import org.apache.mahout.math.cf.{DownsamplableCrossOccurrenceDataset, SimilarityAnalysis}
2727
import org.apache.mahout.sparkbindings.indexeddataset.IndexedDatasetSpark
2828
import org.apache.spark.SparkContext
2929
import org.apache.spark.rdd.RDD
@@ -34,10 +34,12 @@ import org.json4s.JsonAST._
3434
import org.json4s.JsonDSL._
3535
import org.json4s.jackson.JsonMethods._
3636
import com.actionml.helpers._
37-
37+
import scala.concurrent.ExecutionContext
3838
import scala.collection.JavaConverters._
39+
import scala.concurrent.Future
3940
import scala.concurrent.duration.Duration
40-
import scala.language.{ implicitConversions, postfixOps }
41+
import scala.language.{implicitConversions, postfixOps}
42+
import ScalaRestClient.ExtendedScalaRestClient
4143

4244
/** Available value for algorithm param "RecsModel" */
4345
object RecsModels { // todo: replace this with rankings
@@ -481,51 +483,58 @@ class URAlgorithm(val ap: URAlgorithmParams)
481483
* @todo Need to prune that query to minimum required for data include, for instance no need for the popularity
482484
* ranking if no PopModel is being used, same for "must" clause and dates.
483485
*/
484-
def predict(model: NullModel, query: Query): PredictedResult = {
486+
def predict(model: NullModel, query: Query)(implicit ec: ExecutionContext): Future[PredictedResult] = {
485487

486488
queryEventNames = query.eventNames.getOrElse(modelEventNames) // eventNames in query take precedence
487489

488-
val (queryStr, blacklist) = buildQuery(ap, query, rankingFieldNames)
489-
// old es1 query
490-
// val searchHitsOpt = EsClient.search(queryStr, esIndex, queryEventNames)
491-
val searchHitsOpt = EsClient.search(queryStr, esIndex)
492-
493-
val withRanks = query.withRanks.getOrElse(false)
494-
val predictedResults = searchHitsOpt match {
495-
case Some(searchHits) =>
496-
val hits = (searchHits \ "hits" \ "hits").extract[Seq[JValue]]
497-
val recs = hits.map { hit =>
498-
if (withRanks) {
499-
val source = hit \ "source"
500-
val ranks: Map[String, Double] = rankingsParams map { backfillParams =>
501-
val backfillType = backfillParams.`type`.getOrElse(DefaultURAlgoParams.BackfillType)
502-
val backfillFieldName = backfillParams.name.getOrElse(PopModel.nameByType(backfillType))
503-
backfillFieldName -> (source \ backfillFieldName).extract[Double]
504-
} toMap
505-
506-
ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double],
507-
ranks = if (ranks.nonEmpty) Some(ranks) else None)
508-
} else {
509-
ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double])
510-
}
511-
}.toArray
512-
logger.info(s"Results: ${hits.length} retrieved of a possible ${(searchHits \ "hits" \ "total").extract[Long]}")
513-
PredictedResult(recs)
514-
515-
case _ =>
516-
logger.info(s"No results for query ${parse(queryStr)}")
517-
PredictedResult(Array.empty[ItemScore])
490+
val queryStrBlacklistFuture = buildQuery(ap, query, rankingFieldNames)
491+
492+
queryStrBlacklistFuture.flatMap {
493+
case (queryStr, blacklist) =>
494+
// old es1 query
495+
// val searchHitsOpt = EsClient.search(queryStr, esIndex, queryEventNames)
496+
val searchHitsOptFuture = EsClient.search(queryStr, esIndex)
497+
498+
val withRanks = query.withRanks.getOrElse(false)
499+
searchHitsOptFuture.map {
500+
searchHitsOpt =>
501+
val predictedResults = searchHitsOpt match {
502+
case Some(searchHits) =>
503+
val hits = (searchHits \ "hits" \ "hits").extract[Seq[JValue]]
504+
val recs = hits.map { hit =>
505+
if (withRanks) {
506+
val source = hit \ "source"
507+
val ranks: Map[String, Double] = rankingsParams map { backfillParams =>
508+
val backfillType = backfillParams.`type`.getOrElse(DefaultURAlgoParams.BackfillType)
509+
val backfillFieldName = backfillParams.name.getOrElse(PopModel.nameByType(backfillType))
510+
backfillFieldName -> (source \ backfillFieldName).extract[Double]
511+
} toMap
512+
513+
ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double],
514+
ranks = if (ranks.nonEmpty) Some(ranks) else None)
515+
} else {
516+
ItemScore((hit \ "_id").extract[String], (hit \ "_score").extract[Double])
517+
}
518+
}.toArray
519+
logger.info(s"Results: ${hits.length} retrieved of a possible ${(searchHits \ "hits" \ "total").extract[Long]}")
520+
PredictedResult(recs)
521+
522+
case _ =>
523+
logger.info(s"No results for query ${parse(queryStr)}")
524+
PredictedResult(Array.empty[ItemScore])
525+
}
526+
527+
// todo: is this needed to remove ranked items from recs?
528+
//if (recsModel == RecsModels.CF) {
529+
// PredictedResult(predictedResults.filter(_.score != 0.0))
530+
//} else PredictedResult(predictedResults)
531+
532+
// should have all blacklisted items excluded
533+
// todo: need to add dithering, mean, sigma, seed required, make a seed that only changes on some fixed time
534+
// period so the recs ordering stays fixed for that time period.
535+
predictedResults
536+
}
518537
}
519-
520-
// todo: is this needed to remove ranked items from recs?
521-
//if (recsModel == RecsModels.CF) {
522-
// PredictedResult(predictedResults.filter(_.score != 0.0))
523-
//} else PredictedResult(predictedResults)
524-
525-
// should have all blacklisted items excluded
526-
// todo: need to add dithering, mean, sigma, seed required, make a seed that only changes on some fixed time
527-
// period so the recs ordering stays fixed for that time period.
528-
predictedResults
529538
}
530539

531540
/** Calculate all fields and items needed for ranking.
@@ -563,56 +572,60 @@ class URAlgorithm(val ap: URAlgorithmParams)
563572
def buildQuery(
564573
ap: URAlgorithmParams,
565574
query: Query,
566-
backfillFieldNames: Seq[String] = Seq.empty): (String, Seq[Event]) = {
575+
backfillFieldNames: Seq[String] = Seq.empty)(implicit ec: ExecutionContext): Future[(String, Seq[Event])] = {
567576

568577
logger.info(s"Got query: \n${query}")
569578

570579
val startPos = query.from.getOrElse(0)
571580
logger.info(s"from: ${startPos}")
572581

573-
try {
574-
// create a list of all query correlators that can have a bias (boost or filter) attached
575-
val (boostable, events) = getBiasedRecentUserActions(query)
576-
logger.info(s"getBiasedRecentUserActions returned boostable: ${boostable} and events: ${events}")
577-
578-
// since users have action history and items have correlators and both correspond to the same "actions" like
579-
// purchase or view, we'll pass both to the query if the user history or items correlators are empty
580-
// then metadata or backfill must be relied on to return results.
581-
val numRecs = if (query.num.isDefined) query.num.get else limit // num in query orerrides num in config
582-
logger.info(s"UR query num = ${query.num}")
583-
logger.info(s"query.num.getOrElse returned numRecs: ${numRecs}")
584-
585-
val should = buildQueryShould(query, boostable)
586-
logger.info(s"buildQueryShould returned should: ${should}")
587-
val must = buildQueryMust(query, boostable)
588-
logger.info(s"buildQueryMust returned must: ${must}")
589-
val mustNot = buildQueryMustNot(query, events)
590-
logger.info(s"buildQueryMustNot returned mustNot: ${mustNot}")
591-
val sort = buildQuerySort()
592-
logger.info(s"buildQuerySort returned sort: ${sort}")
593-
594-
val json =
595-
("from" -> startPos) ~
596-
("size" -> numRecs) ~
597-
("query" ->
598-
("bool" ->
599-
("should" -> should) ~
600-
("must" -> must) ~
601-
("must_not" -> mustNot) ~
602-
("minimum_should_match" -> 1))) ~
603-
("sort" -> sort)
604-
605-
logger.info(s"json is: ${json}")
606-
val compactJson = compact(render(json))
607-
logger.info(s"compact json is: ${compactJson}")
608-
609-
//logger.info(s"Query:\n$compactJson")
610-
(compactJson, events)
611-
} catch {
612-
case e: IllegalArgumentException => {
613-
logger.warn("whoops, IllegalArgumentException for something in buildQuery.")
614-
("", Seq.empty[Event])
615-
}
582+
// create a list of all query correlators that can have a bias (boost or filter) attached
583+
val biasedRecentUserActionsFuture = getBiasedRecentUserActions(query)
584+
585+
biasedRecentUserActionsFuture.map {
586+
case (boostable, events) =>
587+
try {
588+
logger.info(s"getBiasedRecentUserActions returned boostable: ${boostable} and events: ${events}")
589+
590+
// since users have action history and items have correlators and both correspond to the same "actions" like
591+
// purchase or view, we'll pass both to the query if the user history or items correlators are empty
592+
// then metadata or backfill must be relied on to return results.
593+
val numRecs = if (query.num.isDefined) query.num.get else limit // num in query orerrides num in config
594+
logger.info(s"UR query num = ${query.num}")
595+
logger.info(s"query.num.getOrElse returned numRecs: ${numRecs}")
596+
597+
val should = buildQueryShould(query, boostable)
598+
logger.info(s"buildQueryShould returned should: ${should}")
599+
val must = buildQueryMust(query, boostable)
600+
logger.info(s"buildQueryMust returned must: ${must}")
601+
val mustNot = buildQueryMustNot(query, events)
602+
logger.info(s"buildQueryMustNot returned mustNot: ${mustNot}")
603+
val sort = buildQuerySort()
604+
logger.info(s"buildQuerySort returned sort: ${sort}")
605+
606+
val json =
607+
("from" -> startPos) ~
608+
("size" -> numRecs) ~
609+
("query" ->
610+
("bool" ->
611+
("should" -> should) ~
612+
("must" -> must) ~
613+
("must_not" -> mustNot) ~
614+
("minimum_should_match" -> 1))) ~
615+
("sort" -> sort)
616+
617+
logger.info(s"json is: ${json}")
618+
val compactJson = compact(render(json))
619+
logger.info(s"compact json is: ${compactJson}")
620+
621+
//logger.info(s"Query:\n$compactJson")
622+
(compactJson, events)
623+
} catch {
624+
case e: IllegalArgumentException => {
625+
logger.warn("whoops, IllegalArgumentException for something in buildQuery.")
626+
("", Seq.empty[Event])
627+
}
628+
}
616629
}
617630
}
618631

@@ -792,10 +805,10 @@ class URAlgorithm(val ap: URAlgorithmParams)
792805
}
793806

794807
/** Get recent events of the user on items to create the recommendations query from */
795-
def getBiasedRecentUserActions(query: Query): (Seq[BoostableCorrelators], Seq[Event]) = {
808+
def getBiasedRecentUserActions(query: Query)(implicit ec: ExecutionContext): Future[(Seq[BoostableCorrelators], Seq[Event])] = {
796809

797-
val recentEvents = try {
798-
LEventStore.findByEntity(
810+
val recentEventsFuture =
811+
LEventStore.findByEntityAsync(
799812
appName = appName,
800813
// entityType and entityId is specified for fast lookup
801814
entityType = "user",
@@ -806,13 +819,9 @@ class URAlgorithm(val ap: URAlgorithmParams)
806819
// targetEntityType = None,
807820
// limit = Some(maxQueryEvents), // this will get all history then each action can be limited before using in
808821
// the query
809-
latest = true,
810-
// set time limit to avoid super long DB access
811-
timeout = Duration(200, "millis")).toSeq
812-
} catch {
813-
case e: scala.concurrent.TimeoutException =>
814-
logger.error(s"Timeout when reading recent events. Empty list is used. $e")
815-
Seq.empty[Event]
822+
latest = true).map(_.toSeq)
823+
824+
val recoveredRecentEventsFuture = recentEventsFuture.recover {
816825
case e: NoSuchElementException =>
817826
logger.info("No user id for recs, returning item-based recs if an item is specified in the query.")
818827
Seq.empty[Event]
@@ -821,21 +830,24 @@ class URAlgorithm(val ap: URAlgorithmParams)
821830
Seq.empty[Event]
822831
}
823832

824-
val userEventBias = query.userBias.getOrElse(userBias)
825-
val userEventsBoost = if (userEventBias > 0 && userEventBias != 1) Some(userEventBias) else None
826-
val rActions = queryEventNames.map { action =>
827-
var items = Seq.empty[String]
828-
829-
for (event <- recentEvents) { // todo: use indidatorParams for each indicator type
830-
if (event.event == action && items.size < indicatorParams(action).maxItemsPerUser) {
831-
items = event.targetEntityId.get +: items
832-
// todo: may throw exception and we should ignore the event instead of crashing
833+
recoveredRecentEventsFuture.map {
834+
recentEvents =>
835+
val userEventBias = query.userBias.getOrElse(userBias)
836+
val userEventsBoost = if (userEventBias > 0 && userEventBias != 1) Some(userEventBias) else None
837+
val rActions = queryEventNames.map { action =>
838+
var items = Seq.empty[String]
839+
840+
for (event <- recentEvents) { // todo: use indidatorParams for each indicator type
841+
if (event.event == action && items.size < indicatorParams(action).maxItemsPerUser) {
842+
items = event.targetEntityId.get +: items
843+
// todo: may throw exception and we should ignore the event instead of crashing
844+
}
845+
// userBias may be None, which will cause no JSON output for this
846+
}
847+
BoostableCorrelators(action, items.distinct, userEventsBoost)
833848
}
834-
// userBias may be None, which will cause no JSON output for this
835-
}
836-
BoostableCorrelators(action, items.distinct, userEventsBoost)
849+
(rActions, recentEvents)
837850
}
838-
(rActions, recentEvents)
839851
}
840852

841853
/** get all metadata fields that potentially have boosts (not filters) */

0 commit comments

Comments
 (0)