[SPARK-43415][CONNECT][SQL] Implement KVGDS.agg with custom `mapValues` function

### What changes were proposed in this pull request?

This PR implements the `KVGDS.agg(typedColumn)` function for the case where a `mapValues` function is defined. This use case was previously unsupported: the `mapValues` function was simply not applied.
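
For illustration, a minimal sketch of the pattern this enables (hypothetical snippet, assuming a Spark session with `spark.implicits._` in scope; `SumAgg` is a made-up aggregator along the lines of those added in the new tests):

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// A trivial typed aggregator, defined only for this sketch.
object SumAgg extends Aggregator[Int, Int, Int] {
  def zero: Int = 0
  def reduce(b: Int, a: Int): Int = b + a
  def merge(b1: Int, b2: Int): Int = b1 + b2
  def finish(reduction: Int): Int = reduction
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
ds.groupByKey(_._1)
  .mapValues(_._2 * 2)  // previously not applied when followed by agg(...)
  .agg(SumAgg.toColumn) // typed column backed by an Aggregator
  .collect()            // Array(("a", 60), ("b", 2))
```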

This PR makes the special handling of `kvds.reduce()` obsolete. However, we keep the server-side code to maintain compatibility with older clients.

The implementation is done purely on the client side and is transparent to the Connect server. The mechanism is to first create an intermediate DF that contains only two struct columns:
```
df
 |- iv: struct<...schema of the original df...>
 |- v: struct<...schema of the output of the mapValues func...>
```
Then we rewrite all grouping expressions to take the `iv` column as input and all aggregating expressions to take the `v` column as input. The rules are as follows (a rough sketch follows the list):

- Prefix every column reference with `iv` or `v`, e.g., `col1` becomes `iv.col1`.
- Rewrite `*` to
  - `iv.value`, if the original df schema is a primitive type; or
  - `iv`, if the original df schema is a struct type.
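
As a rough sketch of this mechanism (hypothetical code, not the actual client implementation; assumes a Spark session with `spark.implicits._` in scope), for a `Dataset[(String, Int)]` whose `mapValues` doubles the value:

```scala
import org.apache.spark.sql.functions.{count, sum}

val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

// The original row type (String, Int) is already a struct, so it becomes `iv`
// unchanged; the mapValues output Int is primitive, so it is wrapped in a
// single-field struct to become `v` (field names here come from Tuple1 and may
// differ from the real implementation).
val intermediate = ds
  .map(r => (r, Tuple1(r._2 * 2)))
  .toDF("iv", "v")
// intermediate:
//  |- iv: struct<_1: string, _2: int>
//  |- v:  struct<_1: int>

// Grouping expressions are rewritten against `iv` (e.g. `_1` becomes `iv._1`)
// and aggregating expressions against `v`; `*` becomes `iv` here because the
// original schema is a struct type.
intermediate.groupBy($"iv._1").agg(count($"iv"), sum($"v._1"))
```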

Follow-up:

- [SPARK-50837](https://issues.apache.org/jira/browse/SPARK-50837): fix wrong output column names. This issue is caused by the way we manipulate the DF schema.
- [SPARK-50846](https://issues.apache.org/jira/browse/SPARK-50846): consolidate aggregator-to-proto transformation code path.

### Why are the changes needed?

To support a use case that was previously unsupported.

### Does this PR introduce _any_ user-facing change?

Yes, see the first section.

### How was this patch tested?

New test cases.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49111 from xupefei/kvds-mapvalues.

Authored-by: Paddy Xu <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
xupefei authored and hvanhovell committed Feb 6, 2025
1 parent aefaa66 commit b968ce1
Showing 7 changed files with 318 additions and 48 deletions.

```diff
@@ -27,10 +27,11 @@ import org.apache.spark.sql.{Encoder, Encoders}
  *
  * This class currently assumes there is at least one input row.
  */
+@SerialVersionUID(5066084382969966160L)
 private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
   extends Aggregator[T, (Boolean, T), T] {
 
-  @transient private val encoder = implicitly[Encoder[T]]
+  @transient private lazy val encoder = implicitly[Encoder[T]]
 
   private val _zero = encoder.clsTag.runtimeClass match {
     case java.lang.Boolean.TYPE => false
```
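
For context on the `ReduceAggregator` change above: with `reduce` now going through the regular `agg` path on the client, the aggregator is presumably Java-serialized and shipped to the server, so its `@transient` encoder must be recomputable rather than lost. A plain `@transient val` comes back as `null` after deserialization, while a `@transient lazy val` is re-evaluated on first access; the explicit `@SerialVersionUID` presumably pins the serialization ID so slightly different builds stay compatible. A minimal, Spark-free illustration of the `@transient lazy val` behaviour (hypothetical `Holder` class):

```scala
import java.io._

class Holder(val tag: String) extends Serializable {
  // Recomputed from `tag` on first access after deserialization; a plain
  // `@transient val` here would come back as null instead.
  @transient lazy val derived: String = s"derived-from-$tag"
}

val bytes = {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(new Holder("x"))
  oos.close()
  bos.toByteArray
}

val restored = new ObjectInputStream(new ByteArrayInputStream(bytes))
  .readObject()
  .asInstanceOf[Holder]
println(restored.derived) // derived-from-x
```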

```diff
@@ -779,6 +779,16 @@ object UDFAdaptors extends Serializable {
   def mapToMapPartitions[V, U](f: MapFunction[V, U]): Iterator[V] => Iterator[U] =
     values => values.map(f.call)
 
+  def mapValues[IV, V](
+      vFunc: IV => V,
+      ivIsStruct: Boolean,
+      vIsStruct: Boolean): Iterator[IV] => Iterator[(Any, Any)] = {
+    val ivFunc = (iv: IV) => identity(iv)
+    val wrappedIvFunc = if (ivIsStruct) ivFunc else ivFunc.andThen(Tuple1(_))
+    val wrappedVFunc = if (vIsStruct) vFunc else vFunc.andThen(Tuple1(_))
+    input => input.map(i => (wrappedIvFunc(i), wrappedVFunc(i)))
+  }
+
   def foreachToForeachPartition[T](f: T => Unit): Iterator[T] => Unit =
     values => values.foreach(f)
 
```
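
To make the wrapping behaviour of the new `UDFAdaptors.mapValues` concrete, here is the same logic restated outside Spark (`UDFAdaptors` itself is `private[sql]`, so this standalone copy is for illustration only):

```scala
// Standalone copy of the adaptor logic, for illustration.
def mapValues[IV, V](
    vFunc: IV => V,
    ivIsStruct: Boolean,
    vIsStruct: Boolean): Iterator[IV] => Iterator[(Any, Any)] = {
  val ivFunc = (iv: IV) => identity(iv)
  // Non-struct (primitive) values are wrapped in Tuple1 so that both halves of
  // the output can later be encoded as struct columns (`iv` and `v`).
  val wrappedIvFunc = if (ivIsStruct) ivFunc else ivFunc.andThen(Tuple1(_))
  val wrappedVFunc = if (vIsStruct) vFunc else vFunc.andThen(Tuple1(_))
  input => input.map(i => (wrappedIvFunc(i), wrappedVFunc(i)))
}

// Rows of type (String, Int) are already struct-shaped, so they pass through as
// `iv`; the mapValues output Int is primitive, so it is wrapped in Tuple1 for `v`.
val fn = mapValues[(String, Int), Int](_._2 * 2, ivIsStruct = true, vIsStruct = false)
fn(Iterator(("a", 10), ("b", 1))).foreach(println)
// ((a,10),(20))
// ((b,1),(2))
```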

```diff
@@ -19,9 +19,10 @@ package org.apache.spark.sql.connect
 import java.sql.Timestamp
 import java.util.Arrays
 
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, Encoder, Encoders, Row}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append
 import org.apache.spark.sql.connect.test.{QueryTest, RemoteSparkSession}
+import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
 import org.apache.spark.sql.types._
@@ -453,6 +454,18 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
     }
   }
 
+  // TODO(SPARK-50837): "ds.schema" is wrong: the column is named as "iv.key".
+  ignore(
+    "SPARK-26085: fix key attribute name for atomic type for typed aggregation - mapValues") {
+    val ds = Seq(1, 2, 3).toDS()
+    assert(ds.groupByKey(x => x).mapValues(x => x).count().schema.head.name == "key")
+
+    // Enable legacy flag to follow previous Spark behavior
+    withSQLConf("spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue" -> "true") {
+      assert(ds.groupByKey(x => x).mapValues(x => x).count().schema.head.name == "value")
+    }
+  }
+
   test("reduceGroups") {
     val ds = Seq("abc", "xyz", "hello").toDS()
     checkDatasetUnorderly(
```

```diff
@@ -469,10 +482,98 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
       (5, "hello"))
   }
 
+  object IntSumAgg extends Aggregator[Int, Int, Int] {
+    def zero: Int = 0
+    def reduce(b: Int, a: Int): Int = b + a
+    def merge(b1: Int, b2: Int): Int = b1 + b2
+    def finish(reduction: Int): Int = reduction
+    def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+    def outputEncoder: Encoder[Int] = Encoders.scalaInt
+  }
+
+  test("agg with mapValues - DS") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val values = ds
+      .groupByKey(_._1)
+      .mapValues(_._2 * 2) // value *= 2 to make sure `mapValues` is really applied
+      .agg(count("*"), IntSumAgg.toColumn)
+      .collect()
+    assert(values === Array(("a", 2, 60), ("b", 2, 6), ("c", 1, 2)))
+  }
+
+  test("agg with mapValues - DF") {
+    val df = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDF("kkk", "vvv")
+    val values = df
+      .groupByKey(_.getAs[String]("kkk"))
+      .mapValues(_.getAs[Int]("vvv") * 2) // value *= 2 to make sure `mapValues` is really applied
+      .agg(count("*"), IntSumAgg.toColumn)
+      .collect()
+    assert(values === Array(("a", 2, 60), ("b", 2, 6), ("c", 1, 2)))
+  }
+
+  test("agg with mapValues (RGDS to KVDS)") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val values = ds
+      .groupBy($"_1")
+      .as[String, (String, Int)]
+      .mapValues(_._2 * 2) // value *= 2 to make sure `mapValues` is really applied
+      .agg(count("*"), IntSumAgg.toColumn)
+      .collect()
+    assert(values === Array(("a", 2, 60), ("b", 2, 6), ("c", 1, 2)))
+  }
+
+  object IntTupleSumAgg extends Aggregator[(Int, Int), (Int, Int), Int] {
+    def zero: (Int, Int) = (0, 0)
+    def reduce(b: (Int, Int), a: (Int, Int)): (Int, Int) = (b._1 + a._1, b._2 + a._2)
+    def merge(b1: (Int, Int), b2: (Int, Int)): (Int, Int) = (b1._1 + b2._1, b1._2 + b2._2)
+    def finish(reduction: (Int, Int)): Int = reduction._1 + reduction._2
+    def bufferEncoder: Encoder[(Int, Int)] = Encoders.tuple(Encoders.scalaInt, Encoders.scalaInt)
+    def outputEncoder: Encoder[Int] = Encoders.scalaInt
+  }
+
+  test("agg with mapValues - Tuple - DS") {
+    val ds =
+      Seq(("a", (10, 10)), ("a", (20, 20)), ("b", (1, 1)), ("b", (2, 2)), ("c", (1, 1))).toDS()
+    val values = ds
+      .groupByKey(_._1)
+      .mapValues(v => (v._2._1 * 2, v._2._2 * 2))
+      .agg(count("_1"), IntTupleSumAgg.toColumn)
+      .collect()
+    assert(values === Array(("a", 2, 120), ("b", 2, 12), ("c", 1, 4)))
+  }
+
+  test("agg with mapValues - Tuple - DF") {
+    val ds = Seq(("a", (10, 10)), ("a", (20, 20)), ("b", (1, 1)), ("b", (2, 2)), ("c", (1, 1)))
+      .toDF("kkk", "vvv")
+    val values = ds
+      .groupByKey(_.getAs[String]("kkk"))
+      .mapValues { v =>
+        (v.getStruct(1).getAs[Int]("_1") * 2, v.getStruct(1).getAs[Int]("_2") * 2)
+      }
+      .agg(count("_1"), IntTupleSumAgg.toColumn)
+      .collect()
+    assert(values === Array(("a", 2, 120), ("b", 2, 12), ("c", 1, 4)))
+  }
+
+  test("agg with mapValues (RGDS to KVDS) - Tuple") {
+    val ds =
+      Seq(("a", (10, 10)), ("a", (20, 20)), ("b", (1, 1)), ("b", (2, 2)), ("c", (1, 1))).toDS()
+    val values = ds
+      .groupBy($"_1")
+      .as[String, (String, (Int, Int))]
+      .mapValues(v => (v._2._1 * 2, v._2._2 * 2))
+      .agg(count("*"), IntTupleSumAgg.toColumn)
+      .collect()
+    assert(values === Array(("a", 2, 120), ("b", 2, 12), ("c", 1, 4)))
+  }
+
   test("groupby") {
     val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
       .toDF("key", "seq", "value")
-    val grouped = ds.groupBy($"key").as[String, (String, Int, Int)]
+    val grouped = ds
+      .groupBy($"key")
+      .as[String, (String, Int, Int)]
+      .mapValues(i => i.copy(_1 = "p_" + i._1))
     val aggregated = grouped
       .flatMapSortedGroups($"seq", expr("length(key)"), $"value") { (g, iter) =>
         Iterator(g, iter.mkString(", "))
@@ -481,11 +582,11 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with RemoteSparkSessi
     checkDatasetUnorderly(
       aggregated,
       "a",
-      "(a,1,10), (a,2,20)",
+      "(p_a,1,10), (p_a,2,20)",
       "b",
-      "(b,1,2), (b,2,1)",
+      "(p_b,1,2), (p_b,2,1)",
       "c",
-      "(c,1,1)")
+      "(p_c,1,1)")
   }
 
   test("SPARK-50693: groupby on unresolved plan") {
```

(Diffs for the remaining changed files are not shown here.)
