
Commit f4363e9
Add documentation and examples to GoldiLocksFirstTry (#4)
1 parent: e6cd103

1 file changed: +141 -67 lines changed

src/main/scala/com/high-performance-spark-examples/GoldiLocks/GoldiLocksFirstTry.scala

@@ -1,19 +1,19 @@
 package com.highperspark.examples.goldilocks
 
-import org.apache.spark.Partition
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-
+import scala.collection.Map;
+import scala.collection.mutable.MutableList;
 
 object GoldiLocksGroupByKey {
   //tag::groupByKey[]
   def findRankStatistics(
     pairRDD: RDD[(Int, Double)],
-    ranks: List[Long]): scala.collection.Map[Int, List[Double]] = {
+    ranks: List[Long]): Map[Int, List[Double]] = {
     pairRDD.groupByKey().mapValues(iter => {
       val ar = iter.toArray.sorted
       ranks.map(n => ar(n.toInt))
@@ -25,42 +25,93 @@ object GoldiLocksGroupByKey {
 //tag::firstTry[]
 object GoldiLocksFirstTry {
 
-  def findQuantiles( dataFrame: DataFrame, targetRanks: List[Long] ) = {
-    val n = dataFrame.schema.length
-    val valPairs: RDD[(Double, Int)] = getPairs(dataFrame)
-    val sorted = valPairs.sortByKey()
-    sorted.persist(StorageLevel.MEMORY_AND_DISK)
-    val parts : Array[Partition] = sorted.partitions
-    val map1 = getTotalsForeachPart(sorted, parts.length, n )
-    val map2 = getLocationsOfRanksWithinEachPart(targetRanks, map1, n)
-    val result = findElementsIteratively(sorted, map2)
+  /**
+   * Find nth target rank for every column.
+   *
+   * For example:
+   *
+   * dataframe:
+   *   (0.0, 4.5, 7.7, 5.0)
+   *   (1.0, 5.5, 6.7, 6.0)
+   *   (2.0, 5.5, 1.5, 7.0)
+   *   (3.0, 5.5, 0.5, 7.0)
+   *   (4.0, 5.5, 0.5, 8.0)
+   *
+   * targetRanks:
+   *   1, 3
+   *
+   * The output will be:
+   *   0 -> (0.0, 2.0)
+   *   1 -> (4.5, 5.5)
+   *   2 -> (7.7, 1.5)
+   *   3 -> (5.0, 7.0)
+   *
+   * @param dataFrame dataframe of doubles
+   * @param targetRanks the required ranks for every column
+   *
+   * @return map of (column index, list of target ranks)
+   */
+  def findQuantiles(dataFrame: DataFrame, targetRanks: List[Long]):
+    Map[Int, Iterable[Double]] = {
+
+    val valueColumnPairs: RDD[(Double, Int)] = getValueColumnIndexPairs(dataFrame)
+    val sortedValueColumnPairs = valueColumnPairs.sortByKey()
+    sortedValueColumnPairs.persist(StorageLevel.MEMORY_AND_DISK)
+
+    val numOfColumns = dataFrame.schema.length
+    val partitionColumnsFreq = getColumnFreqPerPartition(sortedValueColumnPairs, numOfColumns)
+    val ranksLocations = getLocationsOfRanksWithinEachPart(targetRanks, partitionColumnsFreq, numOfColumns)
+    val result = findElementsIteratively(sortedValueColumnPairs, ranksLocations)
     result.groupByKey().collectAsMap()
   }
 
   /**
-   * Step 1. Map the rows to pairs of (value, colIndex)
-   * @param dataFrame of double columns to compute the rank satistics for
-   * @return
+   * Step 1. Map the rows to pairs of (value, column Index).
+   *
+   * For example:
+   *
+   * dataFrame:
+   *   1.5, 1.25, 2.0
+   *   5.25, 2.5, 1.5
+   *
+   * The output RDD will be:
+   *   (1.5, 0) (1.25, 1) (2.0, 2) (5.25, 0) (2.5, 1) (1.5, 2)
+   *
+   * @param dataFrame dateframe of doubles
+   *
+   * @return RDD of pairs (value, column Index)
    */
-  private def getPairs(dataFrame : DataFrame ): RDD[(Double, Int )] ={
-    dataFrame.flatMap( row => row.toSeq.zipWithIndex.map{ case (v, index ) =>
-      (v.toString.toDouble, index )})
+  private def getValueColumnIndexPairs(dataFrame : DataFrame): RDD[(Double, Int)] = {
+    dataFrame.flatMap(row => row.toSeq.zipWithIndex.map{ case (v, index) =>
+      (v.toString.toDouble, index)})
   }
 
   /**
-   * Step 2. Find the number of elements for each column in each partition
-   * @param sorted - the sorted (value, colIndex) pairs
-   * @param numPartitions
-   * @param n the number of columns
-   * @return an RDD the length of the number of partitions, where each row contains
-   *         - the partition index
-   *         - an array, totalsPerPart where totalsPerPart(i) = the number of elements in column
-   *           i on this partition
+   * Step 2. Find the number of elements for each column in each partition.
+   *
+   * For Example:
+   *
+   * sortedValueColumnPairs:
+   *   Partition 1: (1.5, 0) (1.25, 1) (2.0, 2) (5.25, 0)
+   *   Partition 2: (7.5, 1) (9.5, 2)
+   *
+   * numOfColumns: 3
+   *
+   * The output will be:
+   *   [(0, [2, 1, 1]), (1, [0, 1, 1])]
+   *
+   * @param sortedValueColumnPairs - sorted RDD of (value, column Index) pairs
+   * @param numOfColumns the number of columns
+   *
+   * @return Array that contains (partition index, number of elements from every column on this partition)
    */
-  private def getTotalsForeachPart(sorted: RDD[(Double, Int)], numPartitions: Int, n : Int ) = {
-    val zero = Array.fill[Long](n)(0)
-    sorted.mapPartitionsWithIndex((partitionIndex : Int, it : Iterator[(Double, Int)]) => {
-      val totalsPerPart : Array[Long] = it.aggregate(zero)(
+  private def getColumnFreqPerPartition(sortedValueColumnPairs: RDD[(Double, Int)], numOfColumns : Int):
+    Array[(Int, Array[Long])] = {
+
+    val zero = Array.fill[Long](numOfColumns)(0)
+
+    def aggregateColumnFrequencies (partitionIndex : Int, valueColumnPairs : Iterator[(Double, Int)]) = {
+      val totalsPerPart : Array[Long] = valueColumnPairs.aggregate(zero)(
         (a : Array[Long], v : (Double ,Int)) => {
           val (value, colIndex) = v
           a(colIndex) = a(colIndex) + 1L
@@ -70,62 +121,85 @@ object GoldiLocksFirstTry {
           require(a.length == b.length)
           a.zip(b).map{ case(aVal, bVal) => aVal + bVal}
         })
+
       Iterator((partitionIndex, totalsPerPart))
-    }).collect()
+    }
+
+    sortedValueColumnPairs.mapPartitionsWithIndex(aggregateColumnFrequencies).collect()
   }
+
   /**
    * Step 3: For each Partition determine the index of the elements that are desired rank statistics
-   * @param partitionMap- the result of the previous method
-   * @return an Array, the length of the number of partitions where each row contains
-   *         - the partition index
-   *         - a list, relevantIndexList where relevantIndexList(i) = the index of an element on this
-   *           partition that matches one of the target ranks
+   *
+   * For Example:
+   * targetRanks: 5
+   * partitionColumnsFreq: [(0, [2, 3]), (1, [4, 1]), (2, [5, 2])]
+   * numOfColumns: 2
+   *
+   * The output will be:
+   * [(0, []), (1, [(0, 3)]), (2, [(1, 1)])]
+   *
+   * @param partitionColumnsFreq Array of (partition index, columns frequencies per this partition)
+   *
+   * @return Array that contains (partition index, relevantIndexList where relevantIndexList(i) = the index
+   *         of an element on this partition that matches one of the target ranks)
    */
   private def getLocationsOfRanksWithinEachPart(targetRanks : List[Long],
-    partitionMap : Array[(Int, Array[Long])], n : Int ) : Array[(Int, List[(Int, Long)])] = {
-    val runningTotal = Array.fill[Long](n)(0)
-    partitionMap.sortBy(_._1).map { case (partitionIndex, totals)=>
-      val relevantIndexList = new scala.collection.mutable.MutableList[(Int, Long)]()
-      totals.zipWithIndex.foreach{ case (colCount, colIndex) => {
+    partitionColumnsFreq : Array[(Int, Array[Long])], numOfColumns : Int) : Array[(Int, List[(Int, Long)])] = {
+
+    val runningTotal = Array.fill[Long](numOfColumns)(0)
+
+    partitionColumnsFreq.sortBy(_._1).map { case (partitionIndex, columnsFreq) =>
+      val relevantIndexList = new MutableList[(Int, Long)]()
+
+      columnsFreq.zipWithIndex.foreach{ case (colCount, colIndex) => {
         val runningTotalCol = runningTotal(colIndex)
+        val ranksHere: List[Long] = targetRanks.filter(rank => (runningTotalCol < rank && runningTotalCol + colCount >= rank))
+
+        // for each of the rank statistics present add this column index and the index it will be at
+        // on this partition (the rank - the running total)
+        relevantIndexList ++= ranksHere.map(rank => (colIndex, rank - runningTotalCol))
+
         runningTotal(colIndex) += colCount
-        val ranksHere = targetRanks.filter(rank =>
-          runningTotalCol <= rank && runningTotalCol + colCount >= rank
-        )
-        //for each of the rank statistics present add this column index and the index it will be
-        //at on this partition (the rank - the running total)
-        ranksHere.foreach(rank => {
-          relevantIndexList += ((colIndex, rank-runningTotalCol))
-        })
       }}
+
       (partitionIndex, relevantIndexList.toList)
     }
   }
 
   /**
-   * Step4: Using the results of the previous method, scan the data and return the elements
-   * which correspond to the rank statistics we are looking for in each column
-   */
-  private def findElementsIteratively(sorted : RDD[(Double, Int)], locations : Array[(Int, List[(Int, Long)])]) = {
-    sorted.mapPartitionsWithIndex((index : Int, it : Iterator[(Double, Int)]) => {
-      val targetsInThisPart = locations(index)._2
-      val len = targetsInThisPart.length
-      if (len > 0) {
-        val partMap = targetsInThisPart.groupBy(_._1).mapValues(_.map(_._2))
-        val keysInThisPart = targetsInThisPart.map(_._1).distinct
+   * Finds rank statistics elements using ranksLocations.
+   *
+   * @param sortedValueColumnPairs - sorted RDD of (value, colIndex) pairs
+   * @param ranksLocations Array of (partition Index, list of (column index, rank index of this column at this partition))
+   *
+   * @return
+   */
+  private def findElementsIteratively(sortedValueColumnPairs : RDD[(Double, Int)],
+    ranksLocations : Array[(Int, List[(Int, Long)])]): RDD[(Int, Double)] = {
+
+    sortedValueColumnPairs.mapPartitionsWithIndex((partitionIndex : Int, valueColumnPairs : Iterator[(Double, Int)]) => {
+      val targetsInThisPart: List[(Int, Long)] = ranksLocations(partitionIndex)._2
+      if (!targetsInThisPart.isEmpty) {
+        val columnsRelativeIndex: Map[Int, List[Long]] = targetsInThisPart.groupBy(_._1).mapValues(_.map(_._2))
+        val columnsInThisPart = targetsInThisPart.map(_._1).distinct
+
         val runningTotals : mutable.HashMap[Int, Long]= new mutable.HashMap()
-        keysInThisPart.foreach(key => runningTotals+=((key, 0L)))
-        val newIt : ArrayBuffer[(Int, Double)] = new scala.collection.mutable.ArrayBuffer()
-        it.foreach{ case( value, colIndex) => {
-          if(runningTotals isDefinedAt colIndex){
+        runningTotals ++= columnsInThisPart.map(columnIndex => (columnIndex, 0L)).toMap
+
+        val result : ArrayBuffer[(Int, Double)] = new scala.collection.mutable.ArrayBuffer()
+
+        valueColumnPairs.foreach{ case(value, colIndex) => {
+          if (runningTotals isDefinedAt colIndex) {
            val total = runningTotals(colIndex) + 1L
            runningTotals.update(colIndex, total)
-            if(partMap(colIndex).contains(total)){
-              newIt += ((colIndex,value ))
-            }
+
+            if (columnsRelativeIndex(colIndex).contains(total))
+              result += ((colIndex, value))
          }
        }}
-        newIt.toIterator
+
+        result.toIterator
      }
      else {
        Iterator.empty
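
For context, a minimal sketch of driving the newly documented entry point, assuming the Spark 1.x-era APIs this file targets (where DataFrame.flatMap returns an RDD). Nothing below is part of the commit: `sc` is assumed to be a live SparkContext, and the sample data is the five-row DataFrame from the findQuantiles scaladoc above.

// Hypothetical usage sketch; not part of this commit.
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)  // sc: an existing SparkContext (assumed)
import sqlContext.implicits._

// The five-row, four-column example from the findQuantiles scaladoc.
val df = sc.parallelize(Seq(
  (0.0, 4.5, 7.7, 5.0),
  (1.0, 5.5, 6.7, 6.0),
  (2.0, 5.5, 1.5, 7.0),
  (3.0, 5.5, 0.5, 7.0),
  (4.0, 5.5, 0.5, 8.0))).toDF()

// Rank statistics 1 and 3 for every column; per the scaladoc this yields
// Map(0 -> (0.0, 2.0), 1 -> (4.5, 5.5), 2 -> (7.7, 1.5), 3 -> (5.0, 7.0)).
val rankStats = GoldiLocksFirstTry.findQuantiles(df, List(1L, 3L))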

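Likewise, the Step 2 per-partition aggregation can be sanity-checked in plain Scala against the Partition 1 example from the getColumnFreqPerPartition scaladoc (an illustrative sketch only; `pairs` and `zero` mirror the iterator and zero array used inside that method):

// Counting per-column frequencies for one partition's (value, columnIndex) pairs.
val pairs = Iterator((1.5, 0), (1.25, 1), (2.0, 2), (5.25, 0))
val zero = Array.fill[Long](3)(0)
val totalsPerPart = pairs.aggregate(zero)(
  (a, v) => { a(v._2) += 1L; a },                  // count this pair's column
  (a, b) => a.zip(b).map { case (x, y) => x + y }) // merge two count arrays
// totalsPerPart: Array(2, 1, 1), matching [(0, [2, 1, 1])] in the scaladoc example.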