Skip to content

Commit c642425

Browse files
committed
support oracle aqsort
Signed-off-by: BornChanger <[email protected]> (cherry picked from commit f531a797a8400376cc4906291f3b7000a2b74294)
1 parent 3acb646 commit c642425

19 files changed

+2059
-8
lines changed

pkg/executor/sortexec/BUILD.bazel

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
33
go_library(
44
name = "sortexec",
55
srcs = [
6+
"aqsort_switch.go",
67
"multi_way_merge.go",
78
"parallel_sort_spill_helper.go",
89
"parallel_sort_worker.go",
@@ -25,8 +26,10 @@ go_library(
2526
"//pkg/sessionctx/vardef",
2627
"//pkg/types",
2728
"//pkg/util",
29+
"//pkg/util/aqsort",
2830
"//pkg/util/channel",
2931
"//pkg/util/chunk",
32+
"//pkg/util/codec",
3033
"//pkg/util/dbterror/exeerrors",
3134
"//pkg/util/disk",
3235
"//pkg/util/logutil",
@@ -53,7 +56,7 @@ go_test(
5356
],
5457
embed = [":sortexec"],
5558
flaky = True,
56-
shard_count = 19,
59+
shard_count = 23,
5760
deps = [
5861
"//pkg/config",
5962
"//pkg/executor/internal/exec",
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2026 PingCAP, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package sortexec
16+
17+
import (
18+
"sync/atomic"
19+
20+
"github.com/pingcap/tidb/pkg/util/logutil"
21+
"go.uber.org/zap"
22+
)
23+
24+
// aqsortEnabled is the process-wide default for the experimental AQS-based
// in-memory sort path. It stays false unless SetAQSortEnabled(true) is called.
var aqsortEnabled atomic.Bool

// SetAQSortEnabled enables/disables the experimental AQS-based in-memory sort
// path for SortExec.
//
// This switch is intended for benchmarking/experimentation only. The default
// is disabled.
func SetAQSortEnabled(enabled bool) {
	aqsortEnabled.Store(enabled)
}

// isAQSortEnabled reports the current process-wide AQSort default.
func isAQSortEnabled() bool {
	return aqsortEnabled.Load()
}
34+
35+
// aqsortControl is the per-executor switch for the experimental AQSort path.
// One instance is created in SortExec.Open and shared with the sort partition
// and every parallel sort worker, hence the atomic fields.
type aqsortControl struct {
	// enabled reports whether the AQSort path is still active; it is cleared
	// by disableWithWarn when key encoding fails, falling back to std sort.
	enabled atomic.Bool
	// warned guarantees the fallback warning is logged at most once.
	warned atomic.Bool

	// connID and executorID identify the session and executor in the
	// fallback warning log.
	connID     uint64
	executorID int
}
42+
43+
func newAQSortControl(enabled bool, connID uint64, executorID int) *aqsortControl {
44+
ctrl := &aqsortControl{
45+
connID: connID,
46+
executorID: executorID,
47+
}
48+
ctrl.enabled.Store(enabled)
49+
return ctrl
50+
}
51+
52+
func (c *aqsortControl) isEnabled() bool {
53+
if c == nil {
54+
return false
55+
}
56+
return c.enabled.Load()
57+
}
58+
59+
func (c *aqsortControl) disableWithWarn(err error, extraFields ...zap.Field) {
60+
if c == nil {
61+
return
62+
}
63+
c.enabled.Store(false)
64+
if c.warned.CompareAndSwap(false, true) {
65+
fields := make([]zap.Field, 0, 3+len(extraFields))
66+
fields = append(fields,
67+
zap.Error(err),
68+
zap.Uint64("conn_id", c.connID),
69+
zap.Int("executor_id", c.executorID),
70+
)
71+
fields = append(fields, extraFields...)
72+
logutil.BgLogger().Warn("AQSort disabled for SortExec, falling back to std sort", fields...)
73+
}
74+
}

pkg/executor/sortexec/parallel_sort_worker.go

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ import (
2222

2323
"github.com/pingcap/errors"
2424
"github.com/pingcap/failpoint"
25+
"github.com/pingcap/tidb/pkg/types"
26+
"github.com/pingcap/tidb/pkg/util/aqsort"
2527
"github.com/pingcap/tidb/pkg/util/chunk"
28+
"github.com/pingcap/tidb/pkg/util/codec"
2629
"github.com/pingcap/tidb/pkg/util/memory"
2730
"github.com/pingcap/tidb/pkg/util/sqlkiller"
31+
"go.uber.org/zap"
2832
)
2933

3034
// SignalCheckpointForSort indicates the number of row comparisons after which a signal detection will be triggered.
@@ -54,6 +58,19 @@ type parallelSortWorker struct {
5458
merger *multiWayMerger
5559

5660
sqlKiller *sqlkiller.SQLKiller
61+
62+
fieldTypes []*types.FieldType
63+
keyColumns []int
64+
byItemsDesc []bool
65+
loc *time.Location
66+
67+
aqsPairs []aqsort.Pair[chunk.Row]
68+
aqsSorter aqsort.PairSorter[chunk.Row]
69+
aqsArena []byte
70+
aqsKeyCap int
71+
72+
useAQSort bool
73+
aqsortCtrl *aqsortControl
5774
}
5875

5976
func newParallelSortWorker(
@@ -68,6 +85,10 @@ func newParallelSortWorker(
6885
maxChunkSize int,
6986
spillHelper *parallelSortSpillHelper,
7087
sqlKiller *sqlkiller.SQLKiller,
88+
fieldTypes []*types.FieldType,
89+
keyColumns []int,
90+
byItemsDesc []bool,
91+
loc *time.Location,
7192
) *parallelSortWorker {
7293
return &parallelSortWorker{
7394
workerIDForTest: workerIDForTest,
@@ -82,6 +103,12 @@ func newParallelSortWorker(
82103
maxSortedRowsLimit: maxChunkSize * 30,
83104
spillHelper: spillHelper,
84105
sqlKiller: sqlKiller,
106+
fieldTypes: fieldTypes,
107+
keyColumns: keyColumns,
108+
byItemsDesc: byItemsDesc,
109+
loc: loc,
110+
aqsKeyCap: 64,
111+
useAQSort: isAQSortEnabled(),
85112
}
86113
}
87114

@@ -172,10 +199,119 @@ func (p *parallelSortWorker) convertChunksToRows() []chunk.Row {
172199

173200
func (p *parallelSortWorker) sortBatchRows() {
174201
rows := p.convertChunksToRows()
175-
slices.SortFunc(rows, p.keyColumnsLess)
202+
useAQSort := p.useAQSort
203+
if p.aqsortCtrl != nil {
204+
useAQSort = p.aqsortCtrl.isEnabled()
205+
}
206+
if useAQSort {
207+
p.sortBatchRowsByEncodedKey(rows)
208+
} else {
209+
slices.SortFunc(rows, p.keyColumnsLess)
210+
}
176211
p.localSortedRows = append(p.localSortedRows, chunk.NewIterator4Slice(rows))
177212
}
178213

214+
// sortBatchRowsByEncodedKey sorts rows in place with the AQSort pair sorter
// over encoded sort keys. If encoding any key fails, it disables AQSort
// (logging once via the shared control when present) and falls back to the
// comparison-based slices.SortFunc for this batch.
func (p *parallelSortWorker) sortBatchRowsByEncodedKey(rows []chunk.Row) {
	// 0 or 1 rows are already sorted.
	if len(rows) <= 1 {
		return
	}
	// Key encoding needs a location for time values; default to UTC when the
	// session location was not supplied.
	if p.loc == nil {
		p.loc = time.UTC
	}

	// Test hook: simulate a key-encoding failure to exercise the fallback.
	failpoint.Inject("AQSortForceEncodeKeyError", func(val failpoint.Value) {
		if val.(bool) {
			err := errors.NewNoStackError("injected aqsort sort key encode error")
			if p.aqsortCtrl != nil {
				p.aqsortCtrl.disableWithWarn(err, zap.Int("worker_id", p.workerIDForTest))
			} else {
				p.useAQSort = false
			}
			slices.SortFunc(rows, p.keyColumnsLess)
			failpoint.Return()
		}
	})

	// Reuse the pair slice across batches; reallocate only when it is too small.
	if cap(p.aqsPairs) < len(rows) {
		p.aqsPairs = make([]aqsort.Pair[chunk.Row], len(rows))
	} else {
		p.aqsPairs = p.aqsPairs[:len(rows)]
	}

	// keyCap is the per-row slot size inside the shared arena; at least 64
	// bytes, adapted upward below from what previous batches actually needed.
	keyCap := p.aqsKeyCap
	if keyCap < 64 {
		keyCap = 64
	}
	neededArena := len(rows) * keyCap
	if cap(p.aqsArena) < neededArena {
		p.aqsArena = make([]byte, neededArena)
	} else {
		p.aqsArena = p.aqsArena[:neededArena]
	}

	maxKeyLen := 0
	for i := range rows {
		// Periodically check for a kill signal during the encoding pass.
		if i%1024 == 0 {
			p.checkKillSignal()
		}
		// Three-index slice: appends past keyCap reallocate instead of
		// stomping the neighboring row's arena slot.
		key := p.aqsArena[i*keyCap : i*keyCap : i*keyCap+keyCap]
		encoded, err := p.encodeRowSortKey(rows[i], key)
		if err != nil {
			// Disable AQSort (once, with a warning) and sort this batch the
			// standard way.
			if p.aqsortCtrl != nil {
				p.aqsortCtrl.disableWithWarn(err, zap.Int("worker_id", p.workerIDForTest))
			} else {
				p.useAQSort = false
			}
			slices.SortFunc(rows, p.keyColumnsLess)
			return
		}
		if ln := len(encoded); ln > maxKeyLen {
			maxKeyLen = ln
		}
		p.aqsPairs[i] = aqsort.Pair[chunk.Row]{Key: encoded, Val: rows[i]}
	}
	// Grow the per-row slot estimate for the next batch, capped at 1024 so a
	// single outlier key cannot balloon the arena.
	if maxKeyLen > keyCap {
		if maxKeyLen > 1024 {
			p.aqsKeyCap = 1024
		} else {
			p.aqsKeyCap = maxKeyLen
		}
	}

	p.aqsSorter.SortWithCheckpoint(p.aqsPairs, SignalCheckpointForSort, p.checkKillSignal)
	// Write the sorted order back into the caller's slice.
	for i := range rows {
		rows[i] = p.aqsPairs[i].Val
	}
}
286+
287+
func (p *parallelSortWorker) checkKillSignal() {
288+
if p.sqlKiller == nil {
289+
return
290+
}
291+
if err := p.sqlKiller.HandleSignal(); err != nil {
292+
panic(err)
293+
}
294+
}
295+
296+
func (p *parallelSortWorker) encodeRowSortKey(row chunk.Row, dst []byte) ([]byte, error) {
297+
key := dst[:0]
298+
for i, colIdx := range p.keyColumns {
299+
start := len(key)
300+
datum := row.GetDatum(colIdx, p.fieldTypes[colIdx])
301+
var err error
302+
key, err = codec.EncodeKey(p.loc, key, datum)
303+
if err != nil {
304+
return nil, err
305+
}
306+
if p.byItemsDesc[i] {
307+
for j := start; j < len(key); j++ {
308+
key[j] = ^key[j]
309+
}
310+
}
311+
}
312+
return key, nil
313+
}
314+
179315
func (p *parallelSortWorker) sortLocalRows() ([]chunk.Row, error) {
180316
// Handle Remaining batchRows whose row number is not over the `maxSortedRowsLimit`
181317
if p.rowNumInChunkIters > 0 {

pkg/executor/sortexec/sort.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ type SortExec struct {
9797
}
9898

9999
enableTmpStorageOnOOM bool
100+
aqsortCtrl *aqsortControl
100101
}
101102

102103
// When fetcher and workers are not created, we need to initiatively close these channels
@@ -156,6 +157,11 @@ func (e *SortExec) Open(ctx context.Context) error {
156157
e.fetched = &atomic.Bool{}
157158
e.fetched.Store(false)
158159
e.enableTmpStorageOnOOM = vardef.EnableTmpStorageOnOOM.Load()
160+
e.aqsortCtrl = newAQSortControl(
161+
e.Ctx().GetSessionVars().EnableAQSort || isAQSortEnabled(),
162+
e.Ctx().GetSessionVars().ConnectionID,
163+
e.ID(),
164+
)
159165
e.finishCh = make(chan struct{}, 1)
160166

161167
// To avoid duplicated initialization for TopNExec.
@@ -543,7 +549,8 @@ func (e *SortExec) switchToNewSortPartition(fields []*types.FieldType, byItemsDe
543549
}
544550
}
545551

546-
e.curPartition = newSortPartition(fields, byItemsDesc, e.keyColumns, e.keyCmpFuncs, e.spillLimit, e.FileNamePrefixForTest)
552+
e.curPartition = newSortPartition(fields, byItemsDesc, e.keyColumns, e.keyCmpFuncs, e.spillLimit, e.FileNamePrefixForTest, e.Ctx().GetSessionVars().Location())
553+
e.curPartition.aqsortCtrl = e.aqsortCtrl
547554
e.curPartition.getMemTracker().AttachTo(e.memTracker)
548555
e.curPartition.getMemTracker().SetLabel(memory.LabelForRowChunks)
549556
e.Unparallel.spillAction = e.curPartition.actionSpill()
@@ -659,13 +666,35 @@ func (e *SortExec) fetchChunksUnparallel(ctx context.Context) error {
659666
}
660667

661668
func (e *SortExec) fetchChunksParallel(ctx context.Context) error {
669+
fields := exec.RetTypes(e)
670+
byItemsDesc := make([]bool, len(e.ByItems))
671+
for i, byItem := range e.ByItems {
672+
byItemsDesc[i] = byItem.Desc
673+
}
662674
// Wait for the finish of all workers
663675
workersWaiter := util.WaitGroupWrapper{}
664676
// Wait for the finish of chunk fetcher
665677
fetcherWaiter := util.WaitGroupWrapper{}
666678

667679
for i := range e.Parallel.workers {
668-
e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper, &e.Ctx().GetSessionVars().SQLKiller)
680+
e.Parallel.workers[i] = newParallelSortWorker(
681+
i,
682+
e.lessRow,
683+
e.Parallel.chunkChannel,
684+
e.Parallel.fetcherAndWorkerSyncer,
685+
e.Parallel.resultChannel,
686+
e.finishCh,
687+
e.memTracker,
688+
e.Parallel.sortedRowsIters[i],
689+
e.MaxChunkSize(),
690+
e.Parallel.spillHelper,
691+
&e.Ctx().GetSessionVars().SQLKiller,
692+
fields,
693+
e.keyColumns,
694+
byItemsDesc,
695+
e.Ctx().GetSessionVars().Location(),
696+
)
697+
e.Parallel.workers[i].aqsortCtrl = e.aqsortCtrl
669698
worker := e.Parallel.workers[i]
670699
workersWaiter.Run(func() {
671700
worker.run()

0 commit comments

Comments
 (0)