Skip to content

Commit d571e93

Browse files
authored
ARROW-17730: [Go] Implement Take kernels for FSB and VarBinary (apache#14127)
Authored-by: Matt Topol <[email protected]> Signed-off-by: Matt Topol <[email protected]>
1 parent 68e0fa7 commit d571e93

File tree

2 files changed

+218
-10
lines changed

2 files changed

+218
-10
lines changed

Diff for: go/arrow/compute/internal/kernels/vector_selection.go

+167-10
Original file line numberDiff line numberDiff line change
@@ -991,19 +991,171 @@ func binaryFilterImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, values, filter
991991
return nil
992992
}
993993

994-
func FilterFSB(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
994+
func takeExecImpl[T exec.UintTypes](ctx *exec.KernelCtx, outputLen int64, values, indices *exec.ArraySpan, out *exec.ExecResult, visitValid func(int64) error, visitNull func() error) error {
995995
var (
996-
values = &batch.Values[0].Array
997-
selection = &batch.Values[1].Array
998-
outputLength = getFilterOutputSize(selection, ctx.State.(FilterState).NullSelection)
999-
valueSize = int64(values.Type.(arrow.FixedWidthDataType).Bytes())
1000-
valueData = values.Buffers[1].Buf[values.Offset*valueSize:]
996+
validityBuilder = validityBuilder{mem: exec.GetAllocator(ctx.Ctx)}
997+
indicesValues = exec.GetSpanValues[T](indices, 1)
998+
isValid = indices.Buffers[0].Buf
999+
valuesHaveNulls = values.MayHaveNulls()
1000+
1001+
indicesIsValid = bitutil.OptionalBitIndexer{Bitmap: isValid, Offset: int(indices.Offset)}
1002+
valuesIsValid = bitutil.OptionalBitIndexer{Bitmap: values.Buffers[0].Buf, Offset: int(values.Offset)}
1003+
bitCounter = bitutils.NewOptionalBitBlockCounter(isValid, indices.Offset, indices.Len)
1004+
pos int64
1005+
)
1006+
1007+
validityBuilder.Reserve(outputLen)
1008+
for pos < indices.Len {
1009+
block := bitCounter.NextBlock()
1010+
indicesHaveNulls := block.Popcnt < block.Len
1011+
if !indicesHaveNulls && !valuesHaveNulls {
1012+
// fastest path, neither indices nor values have nulls
1013+
validityBuilder.UnsafeAppendN(int64(block.Len), true)
1014+
for i := 0; i < int(block.Len); i++ {
1015+
if err := visitValid(int64(indicesValues[pos])); err != nil {
1016+
return err
1017+
}
1018+
pos++
1019+
}
1020+
} else if block.Popcnt > 0 {
1021+
// since we have to branch on whether indices are null or not,
1022+
// we combine the "non-null indices block but some values null"
1023+
// and "some null indices block but values non-null" into single loop
1024+
for i := 0; i < int(block.Len); i++ {
1025+
if (!indicesHaveNulls || indicesIsValid.GetBit(int(pos))) && valuesIsValid.GetBit(int(indicesValues[pos])) {
1026+
validityBuilder.UnsafeAppend(true)
1027+
if err := visitValid(int64(indicesValues[pos])); err != nil {
1028+
return err
1029+
}
1030+
} else {
1031+
validityBuilder.UnsafeAppend(false)
1032+
if err := visitNull(); err != nil {
1033+
return err
1034+
}
1035+
}
1036+
pos++
1037+
}
1038+
} else {
1039+
// the whole block is null
1040+
validityBuilder.UnsafeAppendN(int64(block.Len), false)
1041+
for i := 0; i < int(block.Len); i++ {
1042+
if err := visitNull(); err != nil {
1043+
return err
1044+
}
1045+
}
1046+
pos += int64(block.Len)
1047+
}
1048+
}
1049+
1050+
out.Len = int64(validityBuilder.bitLength)
1051+
out.Nulls = int64(validityBuilder.falseCount)
1052+
out.Buffers[0].WrapBuffer(validityBuilder.Finish())
1053+
return nil
1054+
}
1055+
1056+
func takeExec(ctx *exec.KernelCtx, outputLen int64, values, indices *exec.ArraySpan, out *exec.ExecResult, visitValid func(int64) error, visitNull func() error) error {
1057+
indexWidth := indices.Type.(arrow.FixedWidthDataType).Bytes()
1058+
1059+
switch indexWidth {
1060+
case 1:
1061+
return takeExecImpl[uint8](ctx, outputLen, values, indices, out, visitValid, visitNull)
1062+
case 2:
1063+
return takeExecImpl[uint16](ctx, outputLen, values, indices, out, visitValid, visitNull)
1064+
case 4:
1065+
return takeExecImpl[uint32](ctx, outputLen, values, indices, out, visitValid, visitNull)
1066+
case 8:
1067+
return takeExecImpl[uint64](ctx, outputLen, values, indices, out, visitValid, visitNull)
1068+
default:
1069+
return fmt.Errorf("%w: invalid index width", arrow.ErrInvalid)
1070+
}
1071+
}
1072+
1073+
// outputFn is the signature of a selection driver such as takeExec: given the
// output length, the values span, the selection (indices or filter) span and
// the result to populate, it invokes the first callback with an index into
// values for each valid output slot and the second callback for each null
// output slot.
type outputFn func(*exec.KernelCtx, int64, *exec.ArraySpan, *exec.ArraySpan, *exec.ExecResult, func(int64) error, func() error) error
1074+
// implFn is the signature of a type-specific output writer such as FSBImpl or
// VarBinaryImpl: it receives the input batch, the already computed output
// length and the result to populate, and uses the supplied outputFn to drive
// the selection.
type implFn func(*exec.KernelCtx, *exec.ExecSpan, int64, *exec.ExecResult, outputFn) error
1075+
1076+
func FilterExec(impl implFn, fn outputFn) exec.ArrayKernelExec {
1077+
return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
1078+
var (
1079+
selection = &batch.Values[1].Array
1080+
outputLength = getFilterOutputSize(selection, ctx.State.(FilterState).NullSelection)
1081+
)
1082+
return impl(ctx, batch, outputLength, out, fn)
1083+
}
1084+
}
1085+
1086+
func TakeExec(impl implFn, fn outputFn) exec.ArrayKernelExec {
1087+
return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
1088+
if ctx.State.(TakeState).BoundsCheck {
1089+
if err := checkIndexBounds(&batch.Values[1].Array, uint64(batch.Values[0].Array.Len)); err != nil {
1090+
return err
1091+
}
1092+
}
1093+
1094+
return impl(ctx, batch, batch.Values[1].Array.Len, out, fn)
1095+
}
1096+
}
1097+
1098+
func VarBinaryImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn outputFn) error {
1099+
var (
1100+
values = &batch.Values[0].Array
1101+
selection = &batch.Values[1].Array
1102+
rawOffsets = exec.GetSpanOffsets[OffsetT](values, 1)
1103+
rawData = values.Buffers[2].Buf
1104+
offsetBuilder = newBufferBuilder[OffsetT](exec.GetAllocator(ctx.Ctx))
1105+
dataBuilder = newBufferBuilder[uint8](exec.GetAllocator(ctx.Ctx))
1106+
)
1107+
1108+
// presize the data builder with a rough estimate of the required data size
1109+
if values.Len > 0 {
1110+
dataLength := rawOffsets[values.Len] - rawOffsets[0]
1111+
meanValueLen := float64(dataLength) / float64(values.Len)
1112+
dataBuilder.reserve(int(meanValueLen))
1113+
}
1114+
1115+
offsetBuilder.reserve(int(outputLength) + 1)
1116+
spaceAvail := dataBuilder.cap()
1117+
var offset OffsetT
1118+
err := fn(ctx, outputLength, values, selection, out,
1119+
func(idx int64) error {
1120+
offsetBuilder.unsafeAppend(offset)
1121+
valOffset := rawOffsets[idx]
1122+
valSize := rawOffsets[idx+1] - valOffset
1123+
1124+
offset += valSize
1125+
if valSize > OffsetT(spaceAvail) {
1126+
dataBuilder.reserve(int(valSize))
1127+
spaceAvail = dataBuilder.cap() - dataBuilder.len()
1128+
}
1129+
dataBuilder.unsafeAppendSlice(rawData[valOffset : valOffset+valSize])
1130+
spaceAvail -= int(valSize)
1131+
return nil
1132+
}, func() error {
1133+
offsetBuilder.unsafeAppend(offset)
1134+
return nil
1135+
})
1136+
1137+
if err != nil {
1138+
return err
1139+
}
1140+
1141+
offsetBuilder.unsafeAppend(offset)
1142+
out.Buffers[1].WrapBuffer(offsetBuilder.finish())
1143+
out.Buffers[2].WrapBuffer(dataBuilder.finish())
1144+
return nil
1145+
}
1146+
1147+
func FSBImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn outputFn) error {
1148+
var (
1149+
values = &batch.Values[0].Array
1150+
selection = &batch.Values[1].Array
1151+
valueSize = int64(values.Type.(arrow.FixedWidthDataType).Bytes())
1152+
valueData = values.Buffers[1].Buf[values.Offset*valueSize:]
10011153
)
10021154

10031155
out.Buffers[1].WrapBuffer(ctx.Allocate(int(valueSize * outputLength)))
10041156
buf := out.Buffers[1].Buf
10051157

1006-
err := filterExec(ctx, outputLength, values, selection, out,
1158+
err := fn(ctx, outputLength, values, selection, out,
10071159
func(idx int64) error {
10081160
start := idx * int64(valueSize)
10091161
copy(buf, valueData[start:start+valueSize])
@@ -1076,16 +1228,21 @@ func GetVectorSelectionKernels() (filterkernels, takeKernels []SelectionKernelDa
10761228
filterkernels = []SelectionKernelData{
10771229
{In: exec.NewMatchedInput(exec.Primitive()), Exec: PrimitiveFilter},
10781230
{In: exec.NewExactInput(arrow.Null), Exec: NullFilter},
1079-
{In: exec.NewIDInput(arrow.DECIMAL128), Exec: FilterFSB},
1080-
{In: exec.NewIDInput(arrow.DECIMAL256), Exec: FilterFSB},
1081-
{In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: FilterFSB},
1231+
{In: exec.NewIDInput(arrow.DECIMAL128), Exec: FilterExec(FSBImpl, filterExec)},
1232+
{In: exec.NewIDInput(arrow.DECIMAL256), Exec: FilterExec(FSBImpl, filterExec)},
1233+
{In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: FilterExec(FSBImpl, filterExec)},
10821234
{In: exec.NewMatchedInput(exec.BinaryLike()), Exec: FilterBinary},
10831235
{In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: FilterBinary},
10841236
}
10851237

10861238
takeKernels = []SelectionKernelData{
10871239
{In: exec.NewExactInput(arrow.Null), Exec: NullTake},
10881240
{In: exec.NewMatchedInput(exec.Primitive()), Exec: PrimitiveTake},
1241+
{In: exec.NewIDInput(arrow.DECIMAL128), Exec: TakeExec(FSBImpl, takeExec)},
1242+
{In: exec.NewIDInput(arrow.DECIMAL256), Exec: TakeExec(FSBImpl, takeExec)},
1243+
{In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: TakeExec(FSBImpl, takeExec)},
1244+
{In: exec.NewMatchedInput(exec.BinaryLike()), Exec: TakeExec(VarBinaryImpl[int32], takeExec)},
1245+
{In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: TakeExec(VarBinaryImpl[int64], takeExec)},
10891246
}
10901247
return
10911248
}

Diff for: go/arrow/compute/vector_selection_test.go

+51
Original file line numberDiff line numberDiff line change
@@ -663,11 +663,62 @@ func (tk *TakeKernelTestNumeric) TestTakeNumeric() {
663663
})
664664
}
665665

666+
type TakeKernelTestFSB struct {
667+
TakeKernelTestTyped
668+
}
669+
670+
func (tk *TakeKernelTestFSB) SetupSuite() {
671+
tk.dt = &arrow.FixedSizeBinaryType{ByteWidth: 3}
672+
}
673+
674+
func (tk *TakeKernelTestFSB) TestFixedSizeBinary() {
675+
// YWFh == base64("aaa")
676+
// YmJi == base64("bbb")
677+
// Y2Nj == base64("ccc")
678+
tk.assertTake(`["YWFh", "YmJi", "Y2Nj"]`, `[0, 1, 0]`, `["YWFh", "YmJi", "YWFh"]`)
679+
tk.assertTake(`[null, "YmJi", "Y2Nj"]`, `[0, 1, 0]`, `[null, "YmJi", null]`)
680+
tk.assertTake(`["YWFh", "YmJi", "Y2Nj"]`, `[null, 1, 0]`, `[null, "YmJi", "YWFh"]`)
681+
682+
tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `["YWFh", "YmJi", "Y2Nj"]`, `[0, 1, 0]`)
683+
684+
_, err := tk.takeJSON(tk.dt, `["YWFh", "YmJi", "Y2Nj"]`, arrow.PrimitiveTypes.Int8, `[0, 9, 0]`)
685+
tk.ErrorIs(err, arrow.ErrIndex)
686+
_, err = tk.takeJSON(tk.dt, `["YWFh", "YmJi", "Y2Nj"]`, arrow.PrimitiveTypes.Int64, `[2, 5]`)
687+
tk.ErrorIs(err, arrow.ErrIndex)
688+
}
689+
690+
type TakeKernelTestString struct {
691+
TakeKernelTestTyped
692+
}
693+
694+
func (tk *TakeKernelTestString) TestTakeString() {
695+
tk.Run(tk.dt.String(), func() {
696+
// base64 encoded so the binary non-utf8 arrays work
697+
// YQ== -> "a"
698+
// Yg== -> "b"
699+
// Yw== -> "c"
700+
tk.assertTake(`["YQ==", "Yg==", "Yw=="]`, `[0, 1, 0]`, `["YQ==", "Yg==", "YQ=="]`)
701+
tk.assertTake(`[null, "Yg==", "Yw=="]`, `[0, 1, 0]`, `[null, "Yg==", null]`)
702+
tk.assertTake(`["YQ==", "Yg==", "Yw=="]`, `[null, 1, 0]`, `[null, "Yg==", "YQ=="]`)
703+
704+
tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `["YQ==", "Yg==", "Yw=="]`, `[0, 1, 0]`)
705+
706+
_, err := tk.takeJSON(tk.dt, `["YQ==", "Yg==", "Yw=="]`, arrow.PrimitiveTypes.Int8, `[0, 9, 0]`)
707+
tk.ErrorIs(err, arrow.ErrIndex)
708+
_, err = tk.takeJSON(tk.dt, `["YQ==", "Yg==", "Yw=="]`, arrow.PrimitiveTypes.Int64, `[2, 5]`)
709+
tk.ErrorIs(err, arrow.ErrIndex)
710+
})
711+
}
712+
666713
func TestTakeKernels(t *testing.T) {
667714
suite.Run(t, new(TakeKernelTest))
668715
for _, dt := range numericTypes {
669716
suite.Run(t, &TakeKernelTestNumeric{TakeKernelTestTyped: TakeKernelTestTyped{dt: dt}})
670717
}
718+
suite.Run(t, new(TakeKernelTestFSB))
719+
for _, dt := range baseBinaryTypes {
720+
suite.Run(t, &TakeKernelTestString{TakeKernelTestTyped: TakeKernelTestTyped{dt: dt}})
721+
}
671722
}
672723

673724
func TestFilterKernels(t *testing.T) {

0 commit comments

Comments
 (0)