# When running this example make sure to include the built Scala jar:
# $SPARK_HOME/bin/pyspark --jars ./target/examples-0.0.1.jar --driver-class-path ./target/examples-0.0.1.jar
# This example illustrates how to interface Scala and Python code, but use
# caution: it depends on many private members that may change in future
# releases of Spark.
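#
# For a non-interactive run, a roughly equivalent spark-submit invocation (the
# script name simple_perf_test.py is assumed here) would be:
# $SPARK_HOME/bin/spark-submit --jars ./target/examples-0.0.1.jar \
#   --driver-class-path ./target/examples-0.0.1.jar simple_perf_test.py <scalingFactor> <size>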

from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType
from pyspark.sql import DataFrame
import timeit
import time

def generate_scale_data(sqlCtx, rows, numCols):
    """
    Generate scale data for the performance test.

    This also illustrates calling custom Scala code from the driver.

    .. note:: This depends on many internal methods and may break between versions.
    """
    sc = sqlCtx._sc
    # Get the JVM SQLContext; the private attribute name differs between 2.0 and pre-2.0.
    try:
        javaSqlCtx = sqlCtx._jsqlContext
    except AttributeError:
        javaSqlCtx = sqlCtx._ssql_ctx
    jsc = sc._jsc
    scalasc = jsc.sc()
    gateway = sc._gateway
    # Call a Java method that gives us back an RDD of JVM Rows (Int, Double).
    # While Python RDDs are wrapped Java RDDs (even of Rows), the contents are
    # different, so we can't directly wrap this.
    # This returns a Java RDD of Rows - normally it would be better to
    # return a DataFrame directly, but for illustration we will work with an RDD
    # of Rows.
    java_rdd = gateway.jvm.com.highperformancespark.examples.tools.GenerateScalingData. \
        generateMiniScaleRows(scalasc, rows, numCols)
    # Schemas are serialized to JSON and sent back and forth.
    # Construct a Python schema and turn it into a Java schema.
    schema = StructType([StructField("zip", IntegerType()),
                         StructField("fuzzyness", DoubleType())])
    jschema = javaSqlCtx.parseDataType(schema.json())
    # Convert the Java RDD to a Java DataFrame
    java_dataframe = javaSqlCtx.createDataFrame(java_rdd, jschema)
    # Wrap the Java DataFrame into a Python DataFrame
    python_dataframe = DataFrame(java_dataframe, sqlCtx)
    # Convert the Python DataFrame into an RDD of (zip, fuzzyness) pairs
    pairRDD = python_dataframe.rdd.map(lambda row: (row[0], row[1]))
    return (python_dataframe, pairRDD)

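# The bridge also works in the other direction: a Python DataFrame wraps a Java
# DataFrame that can be handed back to Scala code through its private _jdf member.
# This is a minimal sketch only; the Scala helper named below (ScalaDFHelper.processDF)
# is hypothetical and not part of this project, and _jdf may change between releases.
def call_scala_on_df(sc, df):
    gateway = sc._gateway
    jdf = df._jdf  # underlying Java DataFrame/Dataset
    # Hypothetical Scala object and method, used purely for illustration:
    return gateway.jvm.com.highperformancespark.examples.tools.ScalaDFHelper.processDF(jdf)
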
def runOnDF(df):
    result = df.groupBy("zip").avg("fuzzyness").count()
    return result

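# The same aggregation as runOnDF, expressed through the SQL interface instead of the
# DataFrame API; a sketch for comparison only. The temporary table name "scale_data"
# is arbitrary, and registerTempTable is the SQLContext-era API used elsewhere here.
def runOnSQL(sqlCtx, df):
    df.registerTempTable("scale_data")
    return sqlCtx.sql("SELECT zip, AVG(fuzzyness) FROM scale_data GROUP BY zip").count()
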
def runOnRDD(rdd):
    result = rdd.map(lambda kv: (kv[0], (kv[1], 1))). \
        reduceByKey(lambda x, y: (x[0] + y[0], x[1] + y[1])). \
        count()
    return result

def groupOnRDD(rdd):
    return rdd.groupByKey().mapValues(lambda v: sum(v) / float(len(v))).count()

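# For comparison, the same per-zip average can be computed with aggregateByKey, which
# avoids materializing each group the way groupByKey does. This sketch is not timed
# by run() below; it is included only to illustrate the alternative.
def aggOnRDD(rdd):
    sum_count = rdd.aggregateByKey(
        (0.0, 0),                                  # zero value: (running sum, count)
        lambda acc, v: (acc[0] + v, acc[1] + 1),   # fold a value into the accumulator
        lambda a, b: (a[0] + b[0], a[1] + b[1]))   # merge two accumulators
    return sum_count.mapValues(lambda s: s[0] / s[1]).count()
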
def run(sc, sqlCtx, scalingFactor, size):
    (input_df, input_rdd) = generate_scale_data(sqlCtx, scalingFactor, size)
    input_rdd.cache().count()
    rddTimings = timeit.repeat(stmt=lambda: runOnRDD(input_rdd), repeat=10, number=1, timer=time.time, setup='gc.enable()')
    groupTimings = timeit.repeat(stmt=lambda: groupOnRDD(input_rdd), repeat=10, number=1, timer=time.time, setup='gc.enable()')
    input_df.cache().count()
    dfTimings = timeit.repeat(stmt=lambda: runOnDF(input_df), repeat=10, number=1, timer=time.time, setup='gc.enable()')
    print("RDD:")
    print(rddTimings)
    print("group:")
    print(groupTimings)
    print("df:")
    print(dfTimings)
    print("yay")

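# Optional helper, not part of the original example: summarizes one of the raw lists
# that timeit.repeat returns above (best and mean wall-clock time per call).
def summarize(label, timings):
    best = min(timings)
    mean = sum(timings) / float(len(timings))
    print("%s: best %.3fs, mean %.3fs over %d runs" % (label, best, mean, len(timings)))
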
if __name__ == "__main__":
    """
    Usage: simple_perf_test scalingFactor size
    """
    import sys
    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    scalingFactor = int(sys.argv[1])
    size = int(sys.argv[2])
    sc = SparkContext(appName="SimplePythonPerf")
    sqlCtx = SQLContext(sc)
    run(sc, sqlCtx, scalingFactor, size)

    sc.stop()