Skip to content

Commit 56f8b3b

Browse files
uros7251brick authored and cloud-fan committed
[SPARK-53877] Introduce BITMAP_AND_AGG function
Currently, Spark has two bitmap aggregation functions, `bitmap_construct_agg` and `bitmap_or_agg`, for constructing a bitmap out of a set of integers and for computing the union of sets represented by bitmaps, respectively. However, an efficient intersection operation (bitwise AND) is missing.

## What changes were proposed in this pull request?

- **Implemented the `bitmap_and_agg` expression**: a new aggregation function that performs a bitwise AND over binary column inputs.

### Design Decisions

- **The result on empty input is the identity element of the operation**: empty input groups return all-ones bitmaps (the AND identity).
- **Missing bytes handling**: for AND operations, bytes missing from a shorter input are treated as zeros to maintain intersection semantics.

## How was this patch tested?

Added new test cases to cover the `bitmap_and_agg` functionality:

- **`BitmapExpressionsQuerySuite`**: added test cases for basic AND operations, edge cases, empty group handling, and integration with other bitmap functions.

### Does this PR introduce _any_ user-facing change?

Yes. It adds the new `bitmap_and_agg` aggregate function to the SQL, Scala, and Python APIs.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#52586 from uros7251brick/add-bitmap-and-agg.

Authored-by: Uros Stojkovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 343a25b commit 56f8b3b

File tree

9 files changed

+267
-1
lines changed

9 files changed

+267
-1
lines changed

python/pyspark/sql/connect/functions/builtin.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4619,6 +4619,13 @@ def bitmap_or_agg(col: "ColumnOrName") -> Column:
46194619
bitmap_or_agg.__doc__ = pysparkfuncs.bitmap_or_agg.__doc__
46204620

46214621

4622+
def bitmap_and_agg(col: "ColumnOrName") -> Column:
    # Delegates to the shared column-function helper using the SQL function name.
    function_name = "bitmap_and_agg"
    return _invoke_function_over_columns(function_name, col)


# Reuse the docstring of the same-named function in `pysparkfuncs` instead of duplicating it.
bitmap_and_agg.__doc__ = pysparkfuncs.bitmap_and_agg.__doc__
4627+
4628+
46224629
# Call Functions
46234630

46244631

python/pyspark/sql/functions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@
349349
"bit_and",
350350
"bit_or",
351351
"bit_xor",
352+
"bitmap_and_agg",
352353
"bitmap_construct_agg",
353354
"bitmap_or_agg",
354355
"bool_and",

python/pyspark/sql/functions/builtin.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27441,6 +27441,7 @@ def bitmap_or_agg(col: "ColumnOrName") -> Column:
2744127441
:meth:`pyspark.sql.functions.bitmap_bucket_number`
2744227442
:meth:`pyspark.sql.functions.bitmap_construct_agg`
2744327443
:meth:`pyspark.sql.functions.bitmap_count`
27444+
:meth:`pyspark.sql.functions.bitmap_and_agg`
2744427445

2744527446
Parameters
2744627447
----------
@@ -27461,6 +27462,41 @@ def bitmap_or_agg(col: "ColumnOrName") -> Column:
2746127462
return _invoke_function_over_columns("bitmap_or_agg", col)
2746227463

2746327464

27465+
@_try_remote_functions
def bitmap_and_agg(col: "ColumnOrName") -> Column:
    """
    Returns a bitmap that is the bitwise AND of all of the bitmaps from the input column.
    The input column should be bitmaps created from bitmap_construct_agg().

    .. versionadded:: 4.1.0

    See Also
    --------
    :meth:`pyspark.sql.functions.bitmap_bit_position`
    :meth:`pyspark.sql.functions.bitmap_bucket_number`
    :meth:`pyspark.sql.functions.bitmap_construct_agg`
    :meth:`pyspark.sql.functions.bitmap_count`
    :meth:`pyspark.sql.functions.bitmap_or_agg`

    Parameters
    ----------
    col : :class:`~pyspark.sql.Column` or column name
        The input column should be bitmaps created from bitmap_construct_agg().

    Notes
    -----
    NULL inputs are skipped. When an input bitmap is shorter than the result bitmap, its
    missing bytes are treated as zeros. A global aggregation over no rows returns an
    all-ones bitmap, which is the identity element for bitwise AND.

    Examples
    --------
    >>> from pyspark.sql import functions as sf
    >>> df = spark.createDataFrame([("F0",),("70",),("30",)], ["a"])
    >>> df.select(sf.bitmap_and_agg(sf.to_binary(df.a, sf.lit("hex")))).show()
    +---------------------------------+
    |bitmap_and_agg(to_binary(a, hex))|
    +---------------------------------+
    |             [30 00 00 00 00 0...|
    +---------------------------------+
    """
    function_name = "bitmap_and_agg"
    return _invoke_function_over_columns(function_name, col)
27498+
27499+
2746427500
# ---------------------------- User Defined Function ----------------------------------
2746527501

2746627502

sql/api/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4236,6 +4236,15 @@ object functions {
42364236
*/
42374237
def bitmap_or_agg(col: Column): Column = Column.fn("bitmap_or_agg", col)
42384238

4239+
/**
 * Aggregate function: computes the bitwise AND of every bitmap in the input column. The input
 * column should be bitmaps created from bitmap_construct_agg().
 *
 * @group agg_funcs
 * @since 4.1.0
 */
def bitmap_and_agg(col: Column): Column = {
  Column.fn("bitmap_and_agg", col)
}
4247+
42394248
//////////////////////////////////////////////////////////////////////////////////////////////
42404249
// String functions
42414250
//////////////////////////////////////////////////////////////////////////////////////////////

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/BitmapExpressionUtils.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,17 @@ public static void bitmapMerge(byte[] bitmap1, byte[] bitmap2) {
5656
bitmap1[i] = (byte) ((bitmap1[i] & 0x0FF) | (bitmap2[i] & 0x0FF));
5757
}
5858
}
59+
60+
/** Performs bitwise AND on both bitmaps and writes the result into bitmap1. */
61+
public static void bitmapAndMerge(byte[] bitmap1, byte[] bitmap2) {
62+
int minLen = java.lang.Math.min(bitmap1.length, bitmap2.length);
63+
for (int i = 0; i < minLen; ++i) {
64+
bitmap1[i] = (byte) ((bitmap1[i] & 0x0FF) & (bitmap2[i] & 0x0FF));
65+
}
66+
// For AND operation, any bytes beyond the input length should be set to 0
67+
// since they represent bits that don't exist in the input bitmap
68+
for (int i = bitmap2.length; i < bitmap1.length; ++i) {
69+
bitmap1[i] = 0;
70+
}
71+
}
5972
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,7 @@ object FunctionRegistry {
853853
expression[BitmapConstructAgg]("bitmap_construct_agg"),
854854
expression[BitmapCount]("bitmap_count"),
855855
expression[BitmapOrAgg]("bitmap_or_agg"),
856+
expression[BitmapAndAgg]("bitmap_and_agg"),
856857

857858
// json
858859
expression[StructsToJson]("to_json"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitmapExpressions.scala

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,96 @@ case class BitmapOrAgg(child: Expression,
322322
buffer.getBinary(mutableAggBufferOffset)
323323
}
324324
}
325+
326+
@ExpressionDescription(
327+
usage = """
328+
_FUNC_(child) - Returns a bitmap that is the bitwise AND of all of the bitmaps from the child
329+
expression. The input should be bitmaps created from bitmap_construct_agg().
330+
""",
331+
// scalastyle:off line.size.limit
332+
examples = """
333+
Examples:
334+
> SELECT substring(hex(_FUNC_(col)), 0, 6) FROM VALUES (X 'F0'), (X '70'), (X '30') AS tab(col);
335+
300000
336+
> SELECT substring(hex(_FUNC_(col)), 0, 6) FROM VALUES (X 'FF'), (X 'FF'), (X 'FF') AS tab(col);
337+
FF0000
338+
""",
339+
// scalastyle:on line.size.limit
340+
since = "4.1.0",
341+
group = "agg_funcs")
342+
case class BitmapAndAgg(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends ImperativeAggregate
  with UnaryLike[Expression] {

  // Imperative aggregate that folds binary inputs together with bitwise AND. The buffer starts
  // as an all-ones bitmap (the identity for AND) and every non-null input is AND-ed into it.

  def this(child: Expression) = {
    this(child = child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
  }

  // Only BINARY input is accepted; anything else is reported as a type mismatch on the first
  // argument.
  override def checkInputDataTypes(): TypeCheckResult = {
    if (child.dataType == BinaryType) {
      TypeCheckSuccess
    } else {
      DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(0),
          "requiredType" -> toSQLType(BinaryType),
          "inputSql" -> toSQLExpr(child),
          "inputType" -> toSQLType(child.dataType)))
    }
  }

  override def dataType: DataType = BinaryType

  override def prettyName: String = "bitmap_and_agg"

  override protected def withNewChildInternal(newChild: Expression): BitmapAndAgg =
    copy(child = newChild)

  override def withNewMutableAggBufferOffset(
      newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  // Never null: with no input the result is the all-ones default below.
  override def nullable: Boolean = false

  override def aggBufferSchema: StructType = DataTypeUtils.fromAttributes(aggBufferAttributes)

  // The aggregation buffer is a single fixed-size, non-nullable binary column.
  private val bitmapAttr = AttributeReference("bitmap", BinaryType, false)()

  override def aggBufferAttributes: Seq[AttributeReference] = bitmapAttr :: Nil

  // Builds a fresh NUM_BYTES-long bitmap with every bit set (byte value -1 == 0xFF). This is a
  // `def` on purpose: buffers are mutated in place, so each caller needs its own array.
  private def allOnesBitmap: Array[Byte] =
    Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(-1)

  // Result used when there is no input at all: the identity element of bitwise AND.
  override def defaultResult: Option[Literal] = Option(Literal(allOnesBitmap))

  override val inputAggBufferAttributes: Seq[AttributeReference] =
    aggBufferAttributes.map(_.newInstance())

  override def initialize(buffer: InternalRow): Unit = {
    buffer.update(mutableAggBufferOffset, allOnesBitmap)
  }

  // AND one input value into the buffer; null inputs are skipped and leave the buffer untouched.
  override def update(buffer: InternalRow, input: InternalRow): Unit = {
    val inputBitmap = child.eval(input).asInstanceOf[Array[Byte]]
    if (inputBitmap != null) {
      // bitmapAndMerge writes its result into its first argument.
      val accumulated = buffer.getBinary(mutableAggBufferOffset)
      BitmapExpressionUtils.bitmapAndMerge(accumulated, inputBitmap)
    }
  }

  // Combine two partial aggregation buffers; the result is written into buffer1's bitmap.
  override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = {
    val target = buffer1.getBinary(mutableAggBufferOffset)
    val other = buffer2.getBinary(inputAggBufferOffset)
    BitmapExpressionUtils.bitmapAndMerge(target, other)
  }

  // The final result is the accumulated bitmap itself.
  override def eval(buffer: InternalRow): Any = buffer.getBinary(mutableAggBufferOffset)
}

sql/core/src/test/resources/sql-functions/sql-expression-schema.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
| org.apache.spark.sql.catalyst.expressions.Between | between | SELECT 0.5 between 0.1 AND 1.0 | struct<between(0.5, 0.1, 1.0):boolean> |
4949
| org.apache.spark.sql.catalyst.expressions.Bin | bin | SELECT bin(13) | struct<bin(13):string> |
5050
| org.apache.spark.sql.catalyst.expressions.BitLength | bit_length | SELECT bit_length('Spark SQL') | struct<bit_length(Spark SQL):int> |
51+
| org.apache.spark.sql.catalyst.expressions.BitmapAndAgg | bitmap_and_agg | SELECT substring(hex(bitmap_and_agg(col)), 0, 6) FROM VALUES (X 'F0'), (X '70'), (X '30') AS tab(col) | struct<substring(hex(bitmap_and_agg(col)), 0, 6):string> |
5152
| org.apache.spark.sql.catalyst.expressions.BitmapBitPosition | bitmap_bit_position | SELECT bitmap_bit_position(1) | struct<bitmap_bit_position(1):bigint> |
5253
| org.apache.spark.sql.catalyst.expressions.BitmapBucketNumber | bitmap_bucket_number | SELECT bitmap_bucket_number(123) | struct<bitmap_bucket_number(123):bigint> |
5354
| org.apache.spark.sql.catalyst.expressions.BitmapConstructAgg | bitmap_construct_agg | SELECT substring(hex(bitmap_construct_agg(bitmap_bit_position(col))), 0, 6) FROM VALUES (1), (2), (3) AS tab(col) | struct<substring(hex(bitmap_construct_agg(bitmap_bit_position(col))), 0, 6):string> |

sql/core/src/test/scala/org/apache/spark/sql/BitmapExpressionsQuerySuite.scala

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql
1919

20-
import org.apache.spark.sql.functions.{bitmap_bit_position, bitmap_bucket_number, bitmap_construct_agg, bitmap_count, bitmap_or_agg, col, hex, lit, substring, to_binary}
20+
import org.apache.spark.sql.functions.{bitmap_and_agg, bitmap_bit_position, bitmap_bucket_number, bitmap_construct_agg, bitmap_count, bitmap_or_agg, col, expr, hex, lit, substring, to_binary}
2121
import org.apache.spark.sql.test.SharedSparkSession
2222

2323
class BitmapExpressionsQuerySuite extends QueryTest with SharedSparkSession {
@@ -208,6 +208,95 @@ class BitmapExpressionsQuerySuite extends QueryTest with SharedSparkSession {
208208
)
209209
}
210210

211+
test("bitmap_and_agg") {
  // SQL shared by most checks: AND-aggregate the hex-decoded column `a` and keep the first six
  // hex digits (three bytes) of the resulting bitmap.
  val andAggHexPrefix = "substring(hex(bitmap_and_agg(to_binary(a, 'hex'))), 0, 6)"

  // Basic AND through both the SQL and the DataFrame API: F0 & 70 & 30 = 30
  val basicDf = Seq("F0", "70", "30").toDF("a")
  checkAnswer(basicDf.selectExpr(andAggHexPrefix), Seq(Row("300000")))
  checkAnswer(
    basicDf.select(substring(hex(bitmap_and_agg(to_binary(col("a"), lit("hex")))), 0, 6)),
    Seq(Row("300000")))

  // Remaining value-based cases, as (input hex strings, expected hex prefix).
  val cases: Seq[(Seq[String], String)] = Seq(
    // All 1s in the first byte - should return FF
    Seq("FF", "FF", "FF") -> "FF0000",
    // Mixed values - A0 & F0 & 80 = 80
    Seq("A0", "F0", "80") -> "800000",
    // One zero - anything & 00 = 00
    Seq("FF", "00", "FF") -> "000000",
    // Binary values of different lengths - "FF" & "FFFF" & "FFFFFF" = "FF0000"
    Seq("FF", "FFFF", "FFFFFF") -> "FF0000",
    // No rows at all - should return all 1s as AND identity
    Seq.empty[String] -> "FFFFFF")
  cases.foreach { case (hexInputs, expectedPrefix) =>
    checkAnswer(hexInputs.toDF("a").selectExpr(andAggHexPrefix), Seq(Row(expectedPrefix)))
  }

  // Same aggregate restricted by a FILTER clause and aliased, reused by the two checks below.
  val filteredAndAgg = "bitmap_and_agg(to_binary(a, 'hex')) filter (where b = 1) as and_agg"

  // Global aggregate over an empty, filtered input - still the all 1s AND identity
  val emptyWithFilter = Seq.empty[(String, Int)].toDF("a", "b")
  checkAnswer(
    emptyWithFilter
      .selectExpr(filteredAndAgg)
      .select(substring(hex(col("and_agg")), 0, 6)),
    Seq(Row("FFFFFF")))

  // Grouped aggregate over an empty input - no groups, so an empty DataFrame
  val emptyGrouped = Seq.empty[(String, Int, Int)].toDF("a", "b", "c")
  checkAnswer(
    emptyGrouped
      .groupBy("c")
      .agg(expr(filteredAndAgg).alias("and_agg"))
      .select(substring(hex(col("and_agg")), 0, 6)),
    Seq())
}
269+
270+
test("bitmap_and_agg with complex bitmaps from bitmap_construct_agg") {
  val tableName = "bitmap_and_test_table"
  withTable(tableName) {
    // Three groups whose bit positions overlap only partially.
    val createTable = s"""
      | CREATE TABLE $tableName (group_id INT, bit_pos LONG)
      | """.stripMargin
    val insertRows = s"""
      | INSERT INTO $tableName VALUES
      | (1, 1), (1, 3), (1, 5), -- Group 1: bits 1,3,5 set
      | (2, 1), (2, 2), (2, 3), -- Group 2: bits 1,2,3 set
      | (3, 3), (3, 4), (3, 5) -- Group 3: bits 3,4,5 set
      | """.stripMargin
    spark.sql(createTable)
    spark.sql(insertRows)

    // Build one bitmap per group, AND all of them together, and count the surviving bits:
    // that count is the size of the intersection of the three groups.
    val countIntersection = s"""
      | SELECT bitmap_count(
      | bitmap_and_agg(group_bitmap)
      | ) as intersection_count
      | FROM (
      | SELECT bitmap_construct_agg(bitmap_bit_position(bit_pos)) as group_bitmap
      | FROM $tableName
      | GROUP BY group_id
      | )
      | """.stripMargin
    val intersectionDf = spark.sql(countIntersection)

    // Only bit 3 is present in every group, so the intersection has exactly one element.
    checkAnswer(intersectionDf, Seq(Row(1)))
  }
}
299+
211300
test("bitmap_count called with non-binary type") {
212301
val df = Seq(12).toDF("a")
213302
checkError(
@@ -251,4 +340,20 @@ class BitmapExpressionsQuerySuite extends QueryTest with SharedSparkSession {
251340
)
252341
)
253342
}
343+
344+
test("bitmap_and_agg called with non-binary type") {
  // An INT column must be rejected at analysis time with a type-mismatch error.
  val intDf = Seq(12).toDF("a")
  val analysisError = intercept[AnalysisException] {
    intDf.selectExpr("bitmap_and_agg(a)")
  }
  val expectedParameters = Map(
    "sqlExpr" -> "\"bitmap_and_agg(a)\"",
    "paramIndex" -> "first",
    "requiredType" -> "\"BINARY\"",
    "inputSql" -> "\"a\"",
    "inputType" -> "\"INT\"")
  checkError(
    exception = analysisError,
    condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
    parameters = expectedParameters,
    context = ExpectedContext(fragment = "bitmap_and_agg(a)", start = 0, stop = 16))
}
254359
}

0 commit comments

Comments
 (0)