Skip to content

Commit 7c4b70e

Browse files
committed May 23, 2016
Merge pull request #52 from mahmoudhanafy/port-UDF-to-java
Port UDFs to Java
2 parents cedfd7d + 37df289 commit 7c4b70e

File tree

2 files changed

+75
-1
lines changed
  • src/main
    • java/com/highperformancespark/examples/dataframe
    • scala/com/high-performance-spark-examples/dataframe

2 files changed

+75
-1
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package com.highperformancespark.examples.dataframe;
2+
3+
import org.apache.spark.sql.Row;
4+
import org.apache.spark.sql.SQLContext;
5+
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
6+
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
7+
import org.apache.spark.sql.types.*;
8+
9+
public class JavaUDFs {
10+
11+
public static void setupUDFs(SQLContext sqlContext) {
12+
sqlContext.udf().register("strlen", (String s) -> s.length(), DataTypes.StringType);
13+
}
14+
15+
public static void setupUDAFs(SQLContext sqlContext) {
16+
17+
class Avg extends UserDefinedAggregateFunction {
18+
19+
@Override
20+
public StructType inputSchema() {
21+
StructType inputSchema =
22+
new StructType(new StructField[]{new StructField("value", DataTypes.DoubleType, true, Metadata.empty())});
23+
return inputSchema;
24+
}
25+
26+
@Override
27+
public StructType bufferSchema() {
28+
StructType bufferSchema =
29+
new StructType(new StructField[]{
30+
new StructField("count", DataTypes.LongType, true, Metadata.empty()),
31+
new StructField("sum", DataTypes.DoubleType, true, Metadata.empty())
32+
});
33+
34+
return bufferSchema;
35+
}
36+
37+
@Override
38+
public DataType dataType() {
39+
return DataTypes.DoubleType;
40+
}
41+
42+
@Override
43+
public boolean deterministic() {
44+
return true;
45+
}
46+
47+
@Override
48+
public void initialize(MutableAggregationBuffer buffer) {
49+
buffer.update(0, 0L);
50+
buffer.update(1, 0.0);
51+
}
52+
53+
@Override
54+
public void update(MutableAggregationBuffer buffer, Row input) {
55+
buffer.update(0, buffer.getLong(0) + 1);
56+
buffer.update(1, buffer.getDouble(1) + input.getDouble(0));
57+
}
58+
59+
@Override
60+
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
61+
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));
62+
buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1));
63+
}
64+
65+
@Override
66+
public Object evaluate(Row buffer) {
67+
return buffer.getDouble(1) / buffer.getLong(0);
68+
}
69+
}
70+
71+
Avg average = new Avg();
72+
sqlContext.udf().register("ourAvg", average);
73+
}
74+
}

‎src/main/scala/com/high-performance-spark-examples/dataframe/UDFs.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ object UDFs {
4747
}
4848

4949
def evaluate(buffer: Row): Any = {
50-
math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
50+
buffer.getDouble(1) / buffer.getLong(0)
5151
}
5252
}
5353
// Optionally register

0 commit comments

Comments
 (0)
Please sign in to comment.