diff --git a/br/pkg/restore/ingestrec/ingest_recorder.go b/br/pkg/restore/ingestrec/ingest_recorder.go index f406d36b0d680..772c18635476b 100644 --- a/br/pkg/restore/ingestrec/ingest_recorder.go +++ b/br/pkg/restore/ingestrec/ingest_recorder.go @@ -247,3 +247,71 @@ func (i *IngestRecorder) IterateForeignKeys(f func(*ForeignKeyRecord) error) err } return nil } + +// RecorderState is a serializable snapshot of ingest recorder data. +type RecorderState struct { + Items map[int64]map[int64]IndexState `json:"items,omitempty"` +} + +// IndexState is a minimal representation of an ingested index. +type IndexState struct { + IsPrimary bool `json:"is_primary,omitempty"` +} + +// ExportState returns a snapshot of the ingest recorder state. +func (i *IngestRecorder) ExportState() *RecorderState { + if i == nil || len(i.items) == 0 { + return nil + } + state := &RecorderState{ + Items: make(map[int64]map[int64]IndexState, len(i.items)), + } + for tableID, indexes := range i.items { + if len(indexes) == 0 { + continue + } + tableIndexes := make(map[int64]IndexState, len(indexes)) + for indexID, info := range indexes { + if info == nil { + continue + } + tableIndexes[indexID] = IndexState{IsPrimary: info.IsPrimary} + } + if len(tableIndexes) > 0 { + state.Items[tableID] = tableIndexes + } + } + if len(state.Items) == 0 { + return nil + } + return state +} + +// MergeState merges a snapshot into the ingest recorder. +func (i *IngestRecorder) MergeState(state *RecorderState) { + if i == nil || state == nil || len(state.Items) == 0 { + return + } + if i.items == nil { + i.items = make(map[int64]map[int64]*IngestIndexInfo) + } + for tableID, indexes := range state.Items { + if len(indexes) == 0 { + continue + } + tableIndexes, exists := i.items[tableID] + if !exists { + tableIndexes = make(map[int64]*IngestIndexInfo, len(indexes)) + i.items[tableID] = tableIndexes + } + for indexID, info := range indexes { + if _, ok := tableIndexes[indexID]; ok { + continue + } + tableIndexes[indexID] = &IngestIndexInfo{ + IsPrimary: info.IsPrimary, + Updated: false, + } + } + } +} diff --git a/br/pkg/restore/log_client/client.go b/br/pkg/restore/log_client/client.go index 07d77db537a56..78f5d93c1b76f 100644 --- a/br/pkg/restore/log_client/client.go +++ b/br/pkg/restore/log_client/client.go @@ -1008,14 +1008,6 @@ type FullBackupStorageConfig struct { Opts *storeapi.Options } -type GetIDMapConfig struct { - // required - LoadSavedIDMap bool - - // optional - TableMappingManager *stream.TableMappingManager -} - // GetBaseIDMapAndMerge get the id map from following ways // 1. from previously saved id map if the same task has been running and built/saved id map already but failed later // 2. from previous different task. 
A PiTR job might be split into multiple runs/tasks and each task only restores @@ -1025,20 +1017,22 @@ func (rc *LogClient) GetBaseIDMapAndMerge( hasFullBackupStorageConfig, loadSavedIDMap bool, logCheckpointMetaManager checkpoint.LogMetaManagerT, - tableMappingManager *stream.TableMappingManager, -) error { +) (*SegmentedPiTRState, error) { var ( - err error - dbMaps []*backuppb.PitrDBMap - dbReplaces map[stream.UpstreamID]*stream.DBReplace + err error + state *SegmentedPiTRState + dbMaps []*backuppb.PitrDBMap ) // this is a retry, id map saved last time, load it from external storage if loadSavedIDMap { log.Info("try to load previously saved pitr id maps") - dbMaps, err = rc.loadSchemasMap(ctx, rc.restoreTS, logCheckpointMetaManager) + state, err = rc.loadSegmentedPiTRState(ctx, rc.restoreTS, logCheckpointMetaManager, true) if err != nil { - return errors.Trace(err) + return nil, errors.Trace(err) + } + if state != nil { + dbMaps = state.DbMaps } } @@ -1046,28 +1040,25 @@ func (rc *LogClient) GetBaseIDMapAndMerge( // schemas map whose `restore-ts`` is the task's `start-ts`. if len(dbMaps) <= 0 && !hasFullBackupStorageConfig { log.Info("try to load pitr id maps of the previous task", zap.Uint64("start-ts", rc.startTS)) - dbMaps, err = rc.loadSchemasMap(ctx, rc.startTS, logCheckpointMetaManager) + state, err = rc.loadSegmentedPiTRState(ctx, rc.startTS, logCheckpointMetaManager, false) if err != nil { - return errors.Trace(err) + return nil, errors.Trace(err) } - err := rc.validateNoTiFlashReplica() - if err != nil { - return errors.Trace(err) + if state != nil { + dbMaps = state.DbMaps + } + if len(dbMaps) > 0 { + if err := rc.validateNoTiFlashReplica(); err != nil { + return nil, errors.Trace(err) + } } } if len(dbMaps) <= 0 && !hasFullBackupStorageConfig { log.Error("no id maps found") - return errors.New("no base id map found from saved id or last restored PiTR") - } - dbReplaces = stream.FromDBMapProto(dbMaps) - - stream.LogDBReplaceMap("base db replace info", dbReplaces) - if len(dbReplaces) != 0 { - tableMappingManager.SetFromPiTRIDMap() - tableMappingManager.MergeBaseDBReplace(dbReplaces) + return nil, errors.New("no base id map found from saved id or last restored PiTR") } - return nil + return state, nil } func SortMetaKVFiles(files []*backuppb.DataFileInfo) []*backuppb.DataFileInfo { @@ -1979,14 +1970,14 @@ func (rc *LogClient) GetGCRows() []*stream.PreDelRangeQuery { func (rc *LogClient) SaveIdMapWithFailPoints( ctx context.Context, - manager *stream.TableMappingManager, + state *SegmentedPiTRState, logCheckpointMetaManager checkpoint.LogMetaManagerT, ) error { failpoint.Inject("failed-before-id-maps-saved", func(_ failpoint.Value) { failpoint.Return(errors.New("failpoint: failed before id maps saved")) }) - if err := rc.saveIDMap(ctx, manager, logCheckpointMetaManager); err != nil { + if err := rc.SaveSegmentedPiTRState(ctx, state, logCheckpointMetaManager); err != nil { return errors.Trace(err) } diff --git a/br/pkg/restore/log_client/client_test.go b/br/pkg/restore/log_client/client_test.go index 2563510038b09..07b8073022e4d 100644 --- a/br/pkg/restore/log_client/client_test.go +++ b/br/pkg/restore/log_client/client_test.go @@ -20,6 +20,7 @@ import ( "fmt" "math" "path/filepath" + "strings" "sync" "testing" "time" @@ -1346,9 +1347,8 @@ func TestInitSchemasReplaceForDDL(t *testing.T) { require.NoError(t, err) err = stg.WriteFile(ctx, logclient.PitrIDMapsFilename(123, 1), []byte("123")) require.NoError(t, err) - err = client.GetBaseIDMapAndMerge(ctx, false, false, nil, 
stream.NewTableMappingManager()) - require.Error(t, err) - require.Contains(t, err.Error(), "proto: wrong") + _, err = client.GetBaseIDMapAndMerge(ctx, false, false, nil) + requireInvalidProtoError(t, err) err = stg.DeleteFile(ctx, logclient.PitrIDMapsFilename(123, 1)) require.NoError(t, err) } @@ -1358,9 +1358,8 @@ func TestInitSchemasReplaceForDDL(t *testing.T) { client.SetStorage(ctx, backend, nil) err := stg.WriteFile(ctx, logclient.PitrIDMapsFilename(123, 2), []byte("123")) require.NoError(t, err) - err = client.GetBaseIDMapAndMerge(ctx, false, true, nil, stream.NewTableMappingManager()) - require.Error(t, err) - require.Contains(t, err.Error(), "proto: wrong") + _, err = client.GetBaseIDMapAndMerge(ctx, false, true, nil) + requireInvalidProtoError(t, err) err = stg.DeleteFile(ctx, logclient.PitrIDMapsFilename(123, 2)) require.NoError(t, err) } @@ -1373,12 +1372,20 @@ func TestInitSchemasReplaceForDDL(t *testing.T) { se, err := g.CreateSession(s.Mock.Storage) require.NoError(t, err) client := logclient.TEST_NewLogClient(123, 1, 2, 1, s.Mock.Domain, se) - err = client.GetBaseIDMapAndMerge(ctx, false, true, nil, stream.NewTableMappingManager()) + _, err = client.GetBaseIDMapAndMerge(ctx, false, true, nil) require.Error(t, err) require.Contains(t, err.Error(), "no base id map found from saved id or last restored PiTR") } } +func requireInvalidProtoError(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + errMsg := err.Error() + require.True(t, strings.Contains(errMsg, "proto") || strings.Contains(errMsg, "EOF"), + "unexpected error: %s", errMsg) +} + func downstreamID(upstreamID int64) int64 { return upstreamID + 10000000 } @@ -1446,8 +1453,30 @@ func TestPITRIDMap(t *testing.T) { baseTableMappingManager := &stream.TableMappingManager{ DBReplaceMap: getDBMap(), } - err = client.TEST_saveIDMap(ctx, baseTableMappingManager, nil) + tiflashItems := map[int64]model.TiFlashReplicaInfo{ + 1: {Count: 1, Available: true}, + 2: {Count: 2, LocationLabels: []string{"zone", "rack"}, AvailablePartitionIDs: []int64{3, 4}}, + } + ingestState := &ingestrec.RecorderState{ + Items: map[int64]map[int64]ingestrec.IndexState{ + 10: { + 1: {IsPrimary: true}, + 2: {IsPrimary: false}, + }, + }, + } + state := &logclient.SegmentedPiTRState{ + DbMaps: baseTableMappingManager.ToProto(), + TiFlashItems: tiflashItems, + IngestRecorderState: ingestState, + } + err = client.TEST_saveIDMap(ctx, state, nil) require.NoError(t, err) + loadedState, err := client.TEST_loadSegmentedPiTRState(ctx, 2, nil) + require.NoError(t, err) + require.NotNil(t, loadedState) + require.Equal(t, tiflashItems, loadedState.TiFlashItems) + require.Equal(t, ingestState, loadedState.IngestRecorderState) newSchemaReplaces, err := client.TEST_initSchemasMap(ctx, 1, nil) require.NoError(t, err) require.Nil(t, newSchemaReplaces) @@ -1496,7 +1525,10 @@ func TestPITRIDMapOnStorage(t *testing.T) { baseTableMappingManager := &stream.TableMappingManager{ DBReplaceMap: getDBMap(), } - err = client.TEST_saveIDMap(ctx, baseTableMappingManager, nil) + state := &logclient.SegmentedPiTRState{ + DbMaps: baseTableMappingManager.ToProto(), + } + err = client.TEST_saveIDMap(ctx, state, nil) require.NoError(t, err) newSchemaReplaces, err := client.TEST_initSchemasMap(ctx, 1, nil) require.NoError(t, err) @@ -1552,7 +1584,10 @@ func TestPITRIDMapOnCheckpointStorage(t *testing.T) { baseTableMappingManager := &stream.TableMappingManager{ DBReplaceMap: getDBMap(), } - err = client.TEST_saveIDMap(ctx, baseTableMappingManager, logCheckpointMetaManager) + 
state := &logclient.SegmentedPiTRState{ + DbMaps: baseTableMappingManager.ToProto(), + } + err = client.TEST_saveIDMap(ctx, state, logCheckpointMetaManager) require.NoError(t, err) newSchemaReplaces, err := client.TEST_initSchemasMap(ctx, 1, logCheckpointMetaManager) require.NoError(t, err) diff --git a/br/pkg/restore/log_client/export_test.go b/br/pkg/restore/log_client/export_test.go index db5104b59d9b0..c6f6a630cad4b 100644 --- a/br/pkg/restore/log_client/export_test.go +++ b/br/pkg/restore/log_client/export_test.go @@ -22,7 +22,6 @@ import ( "github.com/pingcap/kvproto/pkg/encryptionpb" "github.com/pingcap/tidb/br/pkg/checkpoint" "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/stream" "github.com/pingcap/tidb/br/pkg/utils/iter" "github.com/pingcap/tidb/pkg/domain" "github.com/pingcap/tidb/pkg/objstore/storeapi" @@ -66,10 +65,10 @@ func (m *PhysicalWithMigrations) Physical() *backuppb.DataFileGroup { func (rc *LogClient) TEST_saveIDMap( ctx context.Context, - m *stream.TableMappingManager, + state *SegmentedPiTRState, logCheckpointMetaManager checkpoint.LogMetaManagerT, ) error { - return rc.SaveIdMapWithFailPoints(ctx, m, logCheckpointMetaManager) + return rc.SaveIdMapWithFailPoints(ctx, state, logCheckpointMetaManager) } func (rc *LogClient) TEST_initSchemasMap( @@ -77,7 +76,22 @@ func (rc *LogClient) TEST_initSchemasMap( restoreTS uint64, logCheckpointMetaManager checkpoint.LogMetaManagerT, ) ([]*backuppb.PitrDBMap, error) { - return rc.loadSchemasMap(ctx, restoreTS, logCheckpointMetaManager) + state, err := rc.loadSegmentedPiTRState(ctx, restoreTS, logCheckpointMetaManager, true) + if err != nil { + return nil, err + } + if state == nil { + return nil, nil + } + return state.DbMaps, nil +} + +func (rc *LogClient) TEST_loadSegmentedPiTRState( + ctx context.Context, + restoreTS uint64, + logCheckpointMetaManager checkpoint.LogMetaManagerT, +) (*SegmentedPiTRState, error) { + return rc.loadSegmentedPiTRState(ctx, restoreTS, logCheckpointMetaManager, true) } // readStreamMetaByTS is used for streaming task. collect all meta file by TS, it is for test usage. diff --git a/br/pkg/restore/log_client/id_map.go b/br/pkg/restore/log_client/id_map.go index 0ad5bb92ee9d0..4fdb45d2b38db 100644 --- a/br/pkg/restore/log_client/id_map.go +++ b/br/pkg/restore/log_client/id_map.go @@ -23,9 +23,7 @@ import ( backuppb "github.com/pingcap/kvproto/pkg/brpb" "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/checkpoint" - "github.com/pingcap/tidb/br/pkg/metautil" "github.com/pingcap/tidb/br/pkg/restore" - "github.com/pingcap/tidb/br/pkg/stream" "github.com/pingcap/tidb/pkg/kv" "github.com/pingcap/tidb/pkg/objstore/storeapi" "github.com/pingcap/tidb/pkg/parser/ast" @@ -56,58 +54,67 @@ func (rc *LogClient) tryGetCheckpointStorage( return logCheckpointMetaManager.TryGetStorage() } -// saveIDMap saves the id mapping information. 
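+// SaveSegmentedPiTRState persists the segmented PiTR state (the id maps plus any optional TiFlash replica and ingest recorder payload) to the checkpoint storage, the mysql.tidb_pitr_id_map table, or the external storage, matching where loadSegmentedPiTRState reads it back from.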
-func (rc *LogClient) saveIDMap( +func (rc *LogClient) SaveSegmentedPiTRState( ctx context.Context, - manager *stream.TableMappingManager, + state *SegmentedPiTRState, logCheckpointMetaManager checkpoint.LogMetaManagerT, ) error { - dbmaps := manager.ToProto() + if state == nil { + return errors.New("segmented pitr state is nil") + } + pbState, err := state.toProto() + if err != nil { + return errors.Trace(err) + } if checkpointStorage := rc.tryGetCheckpointStorage(logCheckpointMetaManager); checkpointStorage != nil { - log.Info("checkpoint storage is specified, load pitr id map from the checkpoint storage.") - if err := rc.saveIDMap2Storage(ctx, checkpointStorage, dbmaps); err != nil { + log.Info("checkpoint storage is specified, save pitr id map to the checkpoint storage.") + if err := rc.saveSegmentedPiTRStateToStorage(ctx, checkpointStorage, pbState); err != nil { return errors.Trace(err) } } else if rc.pitrIDMapTableExists() { - if err := rc.saveIDMap2Table(ctx, dbmaps); err != nil { + if err := rc.saveSegmentedPiTRStateToTable(ctx, pbState); err != nil { return errors.Trace(err) } } else { log.Info("the table mysql.tidb_pitr_id_map does not exist, maybe the cluster version is old.") - if err := rc.saveIDMap2Storage(ctx, rc.storage, dbmaps); err != nil { + if err := rc.saveSegmentedPiTRStateToStorage(ctx, rc.storage, pbState); err != nil { return errors.Trace(err) } } if rc.useCheckpoint { - log.Info("save checkpoint task info with InLogRestoreAndIdMapPersist status") - if err := logCheckpointMetaManager.SaveCheckpointProgress(ctx, &checkpoint.CheckpointProgress{ - Progress: checkpoint.InLogRestoreAndIdMapPersisted, - }); err != nil { + exists, err := logCheckpointMetaManager.ExistsCheckpointProgress(ctx) + if err != nil { return errors.Trace(err) } + if !exists { + log.Info("save checkpoint task info with InLogRestoreAndIdMapPersist status") + if err := logCheckpointMetaManager.SaveCheckpointProgress(ctx, &checkpoint.CheckpointProgress{ + Progress: checkpoint.InLogRestoreAndIdMapPersisted, + }); err != nil { + return errors.Trace(err) + } + } } return nil } -func (rc *LogClient) saveIDMap2Storage( +func (rc *LogClient) saveSegmentedPiTRStateToStorage( ctx context.Context, storage storeapi.Storage, - dbMaps []*backuppb.PitrDBMap, + state *backuppb.SegmentedPiTRState, ) error { clusterID := rc.GetClusterID(ctx) metaFileName := PitrIDMapsFilename(clusterID, rc.restoreTS) - metaWriter := metautil.NewMetaWriter(storage, metautil.MetaFileSize, false, metaFileName, nil) - metaWriter.Update(func(m *backuppb.BackupMeta) { - m.ClusterId = clusterID - m.DbMaps = dbMaps - }) - return metaWriter.FlushBackupMeta(ctx) + data, err := proto.Marshal(state) + if err != nil { + return errors.Trace(err) + } + return storage.WriteFile(ctx, metaFileName, data) } -func (rc *LogClient) saveIDMap2Table(ctx context.Context, dbMaps []*backuppb.PitrDBMap) error { - backupmeta := &backuppb.BackupMeta{DbMaps: dbMaps} - data, err := proto.Marshal(backupmeta) +func (rc *LogClient) saveSegmentedPiTRStateToTable(ctx context.Context, state *backuppb.SegmentedPiTRState) error { + data, err := proto.Marshal(state) if err != nil { return errors.Trace(err) } @@ -153,30 +160,31 @@ func (rc *LogClient) saveIDMap2Table(ctx context.Context, dbMaps []*backuppb.Pit return nil } -func (rc *LogClient) loadSchemasMap( +func (rc *LogClient) loadSegmentedPiTRState( ctx context.Context, restoredTS uint64, logCheckpointMetaManager checkpoint.LogMetaManagerT, -) ([]*backuppb.PitrDBMap, error) { + onlyThisRestore bool, +) 
(*SegmentedPiTRState, error) { if checkpointStorage := rc.tryGetCheckpointStorage(logCheckpointMetaManager); checkpointStorage != nil { log.Info("checkpoint storage is specified, load pitr id map from the checkpoint storage.") - dbMaps, err := rc.loadSchemasMapFromStorage(ctx, checkpointStorage, restoredTS) - return dbMaps, errors.Trace(err) + state, err := rc.loadSegmentedPiTRStateFromStorage(ctx, checkpointStorage, restoredTS) + return state, errors.Trace(err) } if rc.pitrIDMapTableExists() { - dbMaps, err := rc.loadSchemasMapFromTable(ctx, restoredTS) - return dbMaps, errors.Trace(err) + state, err := rc.loadSegmentedPiTRStateFromTable(ctx, restoredTS, onlyThisRestore) + return state, errors.Trace(err) } log.Info("the table mysql.tidb_pitr_id_map does not exist, maybe the cluster version is old.") - dbMaps, err := rc.loadSchemasMapFromStorage(ctx, rc.storage, restoredTS) - return dbMaps, errors.Trace(err) + state, err := rc.loadSegmentedPiTRStateFromStorage(ctx, rc.storage, restoredTS) + return state, errors.Trace(err) } -func (rc *LogClient) loadSchemasMapFromStorage( +func (rc *LogClient) loadSegmentedPiTRStateFromStorage( ctx context.Context, storage storeapi.Storage, restoredTS uint64, -) ([]*backuppb.PitrDBMap, error) { +) (*SegmentedPiTRState, error) { clusterID := rc.GetClusterID(ctx) metaFileName := PitrIDMapsFilename(clusterID, restoredTS) exist, err := storage.FileExists(ctx, metaFileName) @@ -192,26 +200,34 @@ func (rc *LogClient) loadSchemasMapFromStorage( if err != nil { return nil, errors.Trace(err) } - backupMeta := &backuppb.BackupMeta{} - if err := backupMeta.Unmarshal(metaData); err != nil { + state := &backuppb.SegmentedPiTRState{} + if err := state.Unmarshal(metaData); err != nil { return nil, errors.Trace(err) } - return backupMeta.GetDbMaps(), nil + return segmentedPiTRStateFromProto(state) } -func (rc *LogClient) loadSchemasMapFromTable( +func (rc *LogClient) loadSegmentedPiTRStateFromTable( ctx context.Context, restoredTS uint64, -) ([]*backuppb.PitrDBMap, error) { + onlyThisRestore bool, +) (*SegmentedPiTRState, error) { hasRestoreIDColumn := rc.pitrIDMapHasRestoreIDColumn() var getPitrIDMapSQL string var args []any + var withRestoreID bool if hasRestoreIDColumn { - // new version with restore_id column - getPitrIDMapSQL = "SELECT segment_id, id_map FROM mysql.tidb_pitr_id_map WHERE restore_id = %? and restored_ts = %? and upstream_cluster_id = %? ORDER BY segment_id;" - args = []any{rc.restoreID, restoredTS, rc.upstreamClusterID} + if onlyThisRestore { + // new version with restore_id column + getPitrIDMapSQL = "SELECT segment_id, id_map FROM mysql.tidb_pitr_id_map WHERE restore_id = %? and restored_ts = %? and upstream_cluster_id = %? ORDER BY segment_id;" + args = []any{rc.restoreID, restoredTS, rc.upstreamClusterID} + } else { + getPitrIDMapSQL = "SELECT restore_id, segment_id, id_map FROM mysql.tidb_pitr_id_map WHERE restored_ts = %? and upstream_cluster_id = %? 
ORDER BY restore_id, segment_id;" + args = []any{restoredTS, rc.upstreamClusterID} + withRestoreID = true + } } else { // old version without restore_id column log.Info("mysql.tidb_pitr_id_map table does not have restore_id column, using backward compatible mode") @@ -234,21 +250,38 @@ func (rc *LogClient) loadSchemasMapFromTable( return nil, nil } metaData := make([]byte, 0, len(rows)*PITRIdMapBlockSize) + var expectedSegmentID uint64 + var selectedRestoreID uint64 for i, row := range rows { - elementID := row.GetUint64(0) - if uint64(i) != elementID { + var elementID uint64 + var data []byte + if withRestoreID { + restoreID := row.GetUint64(0) + if i == 0 { + selectedRestoreID = restoreID + } else if restoreID != selectedRestoreID { + return nil, errors.Errorf("multiple restore_id values found for restored_ts=%d and upstream_cluster_id=%d: %d, %d", + restoredTS, rc.upstreamClusterID, selectedRestoreID, restoreID) + } + elementID = row.GetUint64(1) + data = row.GetBytes(2) + } else { + elementID = row.GetUint64(0) + data = row.GetBytes(1) + } + if expectedSegmentID != elementID { return nil, errors.Errorf("the part(segment_id = %d) of pitr id map is lost", i) } - d := row.GetBytes(1) - if len(d) == 0 { + if len(data) == 0 { return nil, errors.Errorf("get the empty part(segment_id = %d) of pitr id map", i) } - metaData = append(metaData, d...) + metaData = append(metaData, data...) + expectedSegmentID++ } - backupMeta := &backuppb.BackupMeta{} - if err := backupMeta.Unmarshal(metaData); err != nil { + state := &backuppb.SegmentedPiTRState{} + if err := state.Unmarshal(metaData); err != nil { return nil, errors.Trace(err) } - return backupMeta.GetDbMaps(), nil + return segmentedPiTRStateFromProto(state) } diff --git a/br/pkg/restore/log_client/segmented_state.go b/br/pkg/restore/log_client/segmented_state.go new file mode 100644 index 0000000000000..277ecafd2ff12 --- /dev/null +++ b/br/pkg/restore/log_client/segmented_state.go @@ -0,0 +1,93 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logclient + +import ( + "encoding/json" + + "github.com/pingcap/errors" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/tidb/br/pkg/restore/ingestrec" + "github.com/pingcap/tidb/pkg/meta/model" +) + +const segmentedPiTRStatePayloadVersion uint64 = 1 + +// SegmentedPiTRState is the decoded segmented restore state stored in tidb_pitr_id_map. 
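+// It bundles the PiTR DB id maps with the optional TiFlash replica items and ingest recorder snapshot that a segmented restore carries between segments.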
+type SegmentedPiTRState struct { + DbMaps []*backuppb.PitrDBMap + TiFlashItems map[int64]model.TiFlashReplicaInfo + IngestRecorderState *ingestrec.RecorderState +} + +type segmentedPiTRStatePayload struct { + TiFlashItems map[int64]model.TiFlashReplicaInfo `json:"tiflash_items"` + IngestRecorderState *ingestrec.RecorderState `json:"ingest_recorder,omitempty"` +} + +func (s *SegmentedPiTRState) hasPayload() bool { + if s == nil { + return false + } + if s.TiFlashItems != nil { + return true + } + return s.IngestRecorderState != nil +} + +func (s *SegmentedPiTRState) toProto() (*backuppb.SegmentedPiTRState, error) { + if s == nil { + return nil, errors.New("segmented pitr state is nil") + } + state := &backuppb.SegmentedPiTRState{ + DbMaps: s.DbMaps, + } + if !s.hasPayload() { + return state, nil + } + payload := segmentedPiTRStatePayload{ + TiFlashItems: s.TiFlashItems, + IngestRecorderState: s.IngestRecorderState, + } + data, err := json.Marshal(payload) + if err != nil { + return nil, errors.Trace(err) + } + state.SegmentedPitrStateVer = segmentedPiTRStatePayloadVersion + state.SegmentedPitrState = [][]byte{data} + return state, nil +} + +func segmentedPiTRStateFromProto(state *backuppb.SegmentedPiTRState) (*SegmentedPiTRState, error) { + if state == nil { + return nil, nil + } + result := &SegmentedPiTRState{ + DbMaps: state.GetDbMaps(), + } + if state.GetSegmentedPitrStateVer() == 0 || len(state.GetSegmentedPitrState()) == 0 { + return result, nil + } + if state.GetSegmentedPitrStateVer() != segmentedPiTRStatePayloadVersion { + return nil, errors.Errorf("unsupported segmented pitr state version: %d", state.GetSegmentedPitrStateVer()) + } + var payload segmentedPiTRStatePayload + if err := json.Unmarshal(state.GetSegmentedPitrState()[0], &payload); err != nil { + return nil, errors.Trace(err) + } + result.TiFlashItems = payload.TiFlashItems + result.IngestRecorderState = payload.IngestRecorderState + return result, nil +} diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 3916e4292fcca..d386933c17689 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -105,6 +105,8 @@ const ( // FlagStreamStartTS and FlagStreamRestoreTS is used for log restore timestamp range. FlagStreamStartTS = "start-ts" FlagStreamRestoreTS = "restored-ts" + // FlagLastSegment indicates whether this restore is the last segment. + FlagLastSegment = "last-segment" // FlagStreamFullBackupStorage is used for log restore, represents the full backup storage. FlagStreamFullBackupStorage = "full-backup-storage" // FlagPiTRBatchCount and FlagPiTRBatchSize are used for restore log with batch method. @@ -279,11 +281,13 @@ type RestoreConfig struct { // whether RestoreTS was explicitly specified by user vs auto-detected IsRestoredTSUserSpecified bool `json:"-" toml:"-"` // rewriteTS is the rewritten timestamp of meta kvs. 
- RewriteTS uint64 `json:"-" toml:"-"` - tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"` - PitrBatchCount uint32 `json:"pitr-batch-count" toml:"pitr-batch-count"` - PitrBatchSize uint32 `json:"pitr-batch-size" toml:"pitr-batch-size"` - PitrConcurrency uint32 `json:"-" toml:"-"` + RewriteTS uint64 `json:"-" toml:"-"` + tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"` + LastRestore bool `json:"last-segment" toml:"last-segment"` + IsLastRestoreUserSpecified bool `json:"-" toml:"-"` + PitrBatchCount uint32 `json:"pitr-batch-count" toml:"pitr-batch-count"` + PitrBatchSize uint32 `json:"pitr-batch-size" toml:"pitr-batch-size"` + PitrConcurrency uint32 `json:"-" toml:"-"` UseCheckpoint bool `json:"use-checkpoint" toml:"use-checkpoint"` CheckpointStorage string `json:"checkpoint-storage" toml:"checkpoint-storage"` @@ -383,6 +387,7 @@ func DefineStreamRestoreFlags(command *cobra.Command) { "support TSO or datetime, e.g. '400036290571534337' or '2018-05-11 01:42:23+0800'") command.Flags().String(FlagStreamRestoreTS, "", "the point of restore, used for log restore.\n"+ "support TSO or datetime, e.g. '400036290571534337' or '2018-05-11 01:42:23+0800'") + command.Flags().Bool(FlagLastSegment, true, "whether this restore is the last segment of a segmented PiTR task") command.Flags().String(FlagStreamFullBackupStorage, "", "specify the backup full storage. "+ "fill it if want restore full backup before restore log.") command.Flags().Uint32(FlagPiTRBatchCount, defaultPiTRBatchCount, "specify the batch count to restore log.") @@ -406,6 +411,11 @@ func (cfg *RestoreConfig) ParseStreamRestoreFlags(flags *pflag.FlagSet) error { if cfg.RestoreTS, err = ParseTSString(tsString, true); err != nil { return errors.Trace(err) } + cfg.LastRestore, err = flags.GetBool(FlagLastSegment) + if err != nil { + return errors.Trace(err) + } + cfg.IsLastRestoreUserSpecified = flags.Changed(FlagLastSegment) // check if RestoreTS was explicitly specified by user cfg.IsRestoredTSUserSpecified = flags.Changed(FlagStreamRestoreTS) @@ -614,6 +624,9 @@ func (cfg *RestoreConfig) Adjust() { } func (cfg *RestoreConfig) adjustRestoreConfigForStreamRestore() { + if !cfg.IsLastRestoreUserSpecified { + cfg.LastRestore = true + } if cfg.PitrConcurrency == 0 { cfg.PitrConcurrency = defaultPiTRConcurrency } diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index 611d7d391b2c2..93cdc95206c5e 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -1478,6 +1478,7 @@ type LogRestoreConfig struct { tableMappingManager *stream.TableMappingManager logClient *logclient.LogClient ddlFiles []logclient.Log + ingestRecorderState *ingestrec.RecorderState } // restoreStream starts the log restore @@ -1627,6 +1628,9 @@ func restoreStream( if err != nil { return errors.Trace(err) } + if cfg.ingestRecorderState != nil { + schemasReplace.GetIngestRecorder().MergeState(cfg.ingestRecorderState) + } importModeSwitcher := restore.NewImportModeSwitcher(mgr.GetPDClient(), cfg.Config.SwitchModeInterval, mgr.GetTLSConfig()) @@ -1679,9 +1683,6 @@ func restoreStream( rewriteRules := buildRewriteRules(schemasReplace) ingestRecorder := schemasReplace.GetIngestRecorder() - if err := rangeFilterFromIngestRecorder(ingestRecorder, rewriteRules); err != nil { - return errors.Trace(err) - } logFilesIter, err := client.LoadDMLFiles(ctx) if err != nil { @@ -1790,17 +1791,33 @@ func restoreStream( return errors.Annotate(err, "failed to insert rows into gc_delete_range") } - // index ingestion is not captured by regular log backup, 
so we need to manually ingest again - if err = client.RepairIngestIndex(ctx, ingestRecorder, cfg.logCheckpointMetaManager, g); err != nil { - return errors.Annotate(err, "failed to repair ingest index") - } + if !cfg.LastRestore { + state := &logclient.SegmentedPiTRState{ + DbMaps: cfg.tableMappingManager.ToProto(), + IngestRecorderState: ingestRecorder.ExportState(), + } + if cfg.tiflashRecorder != nil { + state.TiFlashItems = cfg.tiflashRecorder.GetItems() + } + if err := client.SaveSegmentedPiTRState(ctx, state, cfg.logCheckpointMetaManager); err != nil { + return errors.Annotate(err, "failed to save segmented pitr state") + } + } else { + if err := rangeFilterFromIngestRecorder(ingestRecorder, rewriteRules); err != nil { + return errors.Trace(err) + } + // index ingestion is not captured by regular log backup, so we need to manually ingest again + if err = client.RepairIngestIndex(ctx, ingestRecorder, cfg.logCheckpointMetaManager, g); err != nil { + return errors.Annotate(err, "failed to repair ingest index") + } - if cfg.tiflashRecorder != nil { - sqls := cfg.tiflashRecorder.GenerateAlterTableDDLs(mgr.GetDomain().InfoSchema()) - log.Info("Generating SQLs for restoring TiFlash Replica", - zap.Strings("sqls", sqls)) - if err := client.ResetTiflashReplicas(ctx, sqls, g); err != nil { - return errors.Annotate(err, "failed to reset tiflash replicas") + if cfg.tiflashRecorder != nil { + sqls := cfg.tiflashRecorder.GenerateAlterTableDDLs(mgr.GetDomain().InfoSchema()) + log.Info("Generating SQLs for restoring TiFlash Replica", + zap.Strings("sqls", sqls)) + if err := client.ResetTiflashReplicas(ctx, sqls, g); err != nil { + return errors.Annotate(err, "failed to reset tiflash replicas") + } } } @@ -2218,11 +2235,21 @@ func buildAndSaveIDMapIfNeeded(ctx context.Context, client *logclient.LogClient, // get the schemas ID replace information. 
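+	// The base id map may come from this task's previously saved map or from the previous PiTR segment; otherwise it is rebuilt below and saved.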
saved := isCurrentIdMapSaved(cfg.checkpointTaskInfo) hasFullBackupStorage := len(cfg.FullBackupStorage) != 0 - err := client.GetBaseIDMapAndMerge(ctx, hasFullBackupStorage, saved, - cfg.logCheckpointMetaManager, cfg.tableMappingManager) + state, err := client.GetBaseIDMapAndMerge(ctx, hasFullBackupStorage, saved, + cfg.logCheckpointMetaManager) if err != nil { return errors.Trace(err) } + if state != nil { + if len(state.DbMaps) > 0 { + cfg.tableMappingManager.SetFromPiTRIDMap() + cfg.tableMappingManager.MergeBaseDBReplace(stream.FromDBMapProto(state.DbMaps)) + } + if state.TiFlashItems != nil && cfg.tiflashRecorder != nil { + cfg.tiflashRecorder.Load(state.TiFlashItems) + } + cfg.ingestRecorderState = state.IngestRecorderState + } if saved { return nil @@ -2240,7 +2267,16 @@ func buildAndSaveIDMapIfNeeded(ctx context.Context, client *logclient.LogClient, if err != nil { return errors.Trace(err) } - if err = client.SaveIdMapWithFailPoints(ctx, cfg.tableMappingManager, cfg.logCheckpointMetaManager); err != nil { + newState := &logclient.SegmentedPiTRState{ + DbMaps: cfg.tableMappingManager.ToProto(), + IngestRecorderState: cfg.ingestRecorderState, + } + if cfg.tiflashRecorder != nil { + newState.TiFlashItems = cfg.tiflashRecorder.GetItems() + } else if state != nil && state.TiFlashItems != nil { + newState.TiFlashItems = state.TiFlashItems + } + if err = client.SaveIdMapWithFailPoints(ctx, newState, cfg.logCheckpointMetaManager); err != nil { return errors.Trace(err) } return nil diff --git a/go.mod b/go.mod index b677ac577614e..5091bfc1db717 100644 --- a/go.mod +++ b/go.mod @@ -369,6 +369,7 @@ replace ( // Downgrade grpc to v1.63.2, as well as other related modules. github.com/apache/arrow-go/v18 => github.com/joechenrh/arrow-go/v18 v18.0.0-20250911101656-62c34c9a3b82 github.com/go-ldap/ldap/v3 => github.com/YangKeao/ldap/v3 v3.4.5-0.20230421065457-369a3bab1117 + github.com/pingcap/kvproto => /root/workspace/kvproto/worktree/pitr_id_map_payload github.com/pingcap/tidb/pkg/parser => ./pkg/parser // TODO: `sourcegraph.com/sourcegraph/appdash` has been archived, and the original host has been removed. diff --git a/pkg/testkit/brhelper/workload/context.go b/pkg/testkit/brhelper/workload/context.go new file mode 100644 index 0000000000000..fd51f91691f28 --- /dev/null +++ b/pkg/testkit/brhelper/workload/context.go @@ -0,0 +1,60 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
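+// Package workload defines a small prepare/tick/exit/verify case framework that BR integration tests use to run background workloads against a database.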
+ +package workload + +import ( + "context" + "database/sql" + "encoding/json" + "math/rand/v2" +) + +type Context struct { + context.Context + DB *sql.DB +} + +type TickContext struct { + Context + RNG *rand.Rand + + UpdateStateFn func(json.RawMessage) +} + +func (c TickContext) UpdateState(state json.RawMessage) { + if c.UpdateStateFn != nil { + c.UpdateStateFn(state) + } +} + +type ExitContext struct { + Context + + UpdateStateFn func(json.RawMessage) +} + +func (c ExitContext) UpdateState(state json.RawMessage) { + if c.UpdateStateFn != nil { + c.UpdateStateFn(state) + } +} + +type Case interface { + Name() string + Prepare(Context) (json.RawMessage, error) + Tick(TickContext, json.RawMessage) error + Exit(ExitContext, json.RawMessage) error + Verify(Context, json.RawMessage) error +} diff --git a/pkg/testkit/brhelper/workload/runner.go b/pkg/testkit/brhelper/workload/runner.go new file mode 100644 index 0000000000000..22984abac7386 --- /dev/null +++ b/pkg/testkit/brhelper/workload/runner.go @@ -0,0 +1,368 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package workload + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "hash/fnv" + "math/rand/v2" + "strings" + "sync" + "time" +) + +type CaseSpec struct { + Name string + Case Case +} + +type RunConfig struct { + TickCount int + TickInterval time.Duration + Seed int64 + Parallel bool +} + +type Runner struct { + db *sql.DB + store StateStore + cases []CaseSpec +} + +func NewRunner(db *sql.DB, store StateStore, cases ...Case) (*Runner, error) { + specs := make([]CaseSpec, len(cases)) + for i, c := range cases { + specs[i] = CaseSpec{Case: c} + } + return NewRunnerWithSpecs(db, store, specs...) 
+} + +func NewRunnerWithSpecs(db *sql.DB, store StateStore, specs ...CaseSpec) (*Runner, error) { + if db == nil { + return nil, fmt.Errorf("workload: nil db") + } + if store == nil { + return nil, fmt.Errorf("workload: nil state store") + } + + caseSpecs, err := normalizeCaseSpecs(specs) + if err != nil { + return nil, err + } + return &Runner{db: db, store: store, cases: caseSpecs}, nil +} + +func (r *Runner) Cases() []CaseSpec { + out := make([]CaseSpec, len(r.cases)) + copy(out, r.cases) + return out +} + +func (r *Runner) Prepare(ctx context.Context) error { + if err := r.store.Reset(ctx); err != nil { + return err + } + + for _, spec := range r.cases { + state, err := spec.Case.Prepare(Context{Context: ctx, DB: r.db}) + if err != nil { + return err + } + if err := r.store.Put(ctx, spec.Name, state); err != nil { + return err + } + } + return nil +} + +func (r *Runner) Run(ctx context.Context, cfg RunConfig) error { + if cfg.TickCount <= 0 { + return fmt.Errorf("workload: TickCount must be > 0") + } + if cfg.TickInterval < 0 { + return fmt.Errorf("workload: TickInterval must be >= 0") + } + + states, err := r.store.GetAll(ctx) + if err != nil { + return err + } + byName := make(map[string]Case, len(r.cases)) + for _, spec := range r.cases { + byName[spec.Name] = spec.Case + } + for name := range states { + if _, ok := byName[name]; !ok { + return fmt.Errorf("workload: unknown case %q in state store", name) + } + } + + selected := make([]CaseSpec, 0, len(r.cases)) + for _, spec := range r.cases { + if _, ok := states[spec.Name]; ok { + selected = append(selected, spec) + } + } + if len(selected) == 0 { + return fmt.Errorf("workload: no cases in state store; run Prepare first") + } + + rngs := newCaseRNGs(cfg.Seed, selected) + if cfg.Parallel { + if err := r.runParallelTicks(ctx, cfg, selected, states, rngs); err != nil { + return err + } + } else { + if err := r.runSequentialTicks(ctx, cfg, selected, states, rngs); err != nil { + return err + } + } + + for _, spec := range selected { + state, ok := states[spec.Name] + if !ok { + return fmt.Errorf("workload: case %q not found in state store; run Prepare first", spec.Name) + } + exitCtx := ExitContext{ + Context: Context{Context: ctx, DB: r.db}, + UpdateStateFn: func(updated json.RawMessage) { + states[spec.Name] = updated + }, + } + if err := spec.Case.Exit(exitCtx, state); err != nil { + return err + } + } + + finalStates := make(map[string]json.RawMessage, len(selected)) + for _, spec := range selected { + if state, ok := states[spec.Name]; ok { + finalStates[spec.Name] = state + } + } + if err := r.store.PutMany(ctx, finalStates); err != nil { + return err + } + return nil +} + +func (r *Runner) Verify(ctx context.Context) error { + states, err := r.store.GetAll(ctx) + if err != nil { + return err + } + byName := make(map[string]Case, len(r.cases)) + for _, spec := range r.cases { + byName[spec.Name] = spec.Case + } + + for name, state := range states { + c, ok := byName[name] + if !ok { + base, _, cut := strings.Cut(name, "#") + if cut { + c, ok = byName[base] + } + } + if !ok { + return fmt.Errorf("workload: unknown case %q in state store", name) + } + if err := c.Verify(Context{Context: ctx, DB: r.db}, state); err != nil { + return err + } + } + return nil +} + +func (r *Runner) runSequentialTicks( + ctx context.Context, + cfg RunConfig, + selected []CaseSpec, + states map[string]json.RawMessage, + rngs map[string]*rand.Rand, +) error { + shuffleRNG := rand.New(rand.NewPCG(uint64(cfg.Seed), uint64(cfg.Seed>>1))) + for tick := 0; 
tick < cfg.TickCount; tick++ { + shuffleRNG.Shuffle(len(selected), func(i, j int) { selected[i], selected[j] = selected[j], selected[i] }) + + for _, spec := range selected { + state, ok := states[spec.Name] + if !ok { + return fmt.Errorf("workload: case %q not found in state store; run Prepare first", spec.Name) + } + rng := rngs[spec.Name] + tickCtx := TickContext{ + Context: Context{Context: ctx, DB: r.db}, + RNG: rng, + UpdateStateFn: func(updated json.RawMessage) { + states[spec.Name] = updated + }, + } + if err := spec.Case.Tick(tickCtx, state); err != nil { + return err + } + } + + if tick != cfg.TickCount-1 { + if err := sleep(ctx, cfg.TickInterval); err != nil { + return err + } + } + } + return nil +} + +func (r *Runner) runParallelTicks( + ctx context.Context, + cfg RunConfig, + selected []CaseSpec, + states map[string]json.RawMessage, + rngs map[string]*rand.Rand, +) error { + var mu sync.Mutex + for tick := 0; tick < cfg.TickCount; tick++ { + if err := r.runParallelTick(ctx, selected, states, rngs, &mu); err != nil { + return err + } + + if tick != cfg.TickCount-1 { + if err := sleep(ctx, cfg.TickInterval); err != nil { + return err + } + } + } + return nil +} + +func (r *Runner) runParallelTick( + ctx context.Context, + selected []CaseSpec, + states map[string]json.RawMessage, + rngs map[string]*rand.Rand, + mu *sync.Mutex, +) error { + if len(selected) == 0 { + return nil + } + + runCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + var once sync.Once + var firstErr error + + for _, spec := range selected { + wg.Go( + func() { + mu.Lock() + state, ok := states[spec.Name] + mu.Unlock() + if !ok { + once.Do(func() { + firstErr = fmt.Errorf("workload: case %q not found in state store; run Prepare first", spec.Name) + cancel() + }) + return + } + rng := rngs[spec.Name] + + tickCtx := TickContext{ + Context: Context{Context: runCtx, DB: r.db}, + RNG: rng, + UpdateStateFn: func(updated json.RawMessage) { + mu.Lock() + states[spec.Name] = updated + mu.Unlock() + }, + } + if err := spec.Case.Tick(tickCtx, state); err != nil { + once.Do(func() { + firstErr = err + cancel() + }) + return + } + }) + } + wg.Wait() + return firstErr +} + +func newCaseRNGs(seed int64, selected []CaseSpec) map[string]*rand.Rand { + out := make(map[string]*rand.Rand, len(selected)) + for _, spec := range selected { + out[spec.Name] = newCaseRNG(seed, spec.Name) + } + return out +} + +func newCaseRNG(seed int64, name string) *rand.Rand { + h := fnv.New64a() + _, _ = h.Write([]byte(name)) + seq := h.Sum64() | 1 + return rand.New(rand.NewPCG(uint64(seed), seq)) +} + +func normalizeCaseSpecs(specs []CaseSpec) ([]CaseSpec, error) { + out := make([]CaseSpec, 0, len(specs)) + nameCounts := make(map[string]int, len(specs)) + for _, spec := range specs { + if spec.Case == nil { + return nil, fmt.Errorf("workload: nil case") + } + if spec.Name == "" { + nameCounts[spec.Case.Name()]++ + } + } + + used := make(map[string]struct{}, len(specs)) + caseIndex := make(map[string]int, len(specs)) + for _, spec := range specs { + name := spec.Name + if name == "" { + base := spec.Case.Name() + if nameCounts[base] > 1 { + caseIndex[base]++ + name = fmt.Sprintf("%s#%d", base, caseIndex[base]) + } else { + name = base + } + } + if _, ok := used[name]; ok { + return nil, fmt.Errorf("workload: duplicate case name %q", name) + } + used[name] = struct{}{} + out = append(out, CaseSpec{Name: name, Case: spec.Case}) + } + return out, nil +} + +func sleep(ctx context.Context, d time.Duration) error { + if 
d <= 0 { + return nil + } + t := time.NewTimer(d) + defer t.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + return nil + } +} diff --git a/pkg/testkit/brhelper/workload/state_store.go b/pkg/testkit/brhelper/workload/state_store.go new file mode 100644 index 0000000000000..f484ed385e4c8 --- /dev/null +++ b/pkg/testkit/brhelper/workload/state_store.go @@ -0,0 +1,84 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package workload + +import ( + "context" + "encoding/json" + "sync" +) + +type StateStore interface { + Reset(ctx context.Context) error + Put(ctx context.Context, caseName string, state json.RawMessage) error + PutMany(ctx context.Context, states map[string]json.RawMessage) error + GetAll(ctx context.Context) (map[string]json.RawMessage, error) +} + +type MemoryStore struct { + mu sync.RWMutex + states map[string]json.RawMessage +} + +func NewMemoryStore() *MemoryStore { + return &MemoryStore{ + states: make(map[string]json.RawMessage), + } +} + +func (s *MemoryStore) Reset(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() + s.states = make(map[string]json.RawMessage) + return nil +} + +func (s *MemoryStore) Put(ctx context.Context, caseName string, state json.RawMessage) error { + s.mu.Lock() + s.states[caseName] = cloneRaw(state) + s.mu.Unlock() + return nil +} + +func (s *MemoryStore) PutMany(ctx context.Context, states map[string]json.RawMessage) error { + if len(states) == 0 { + return nil + } + s.mu.Lock() + for caseName, state := range states { + s.states[caseName] = cloneRaw(state) + } + s.mu.Unlock() + return nil +} + +func (s *MemoryStore) GetAll(ctx context.Context) (map[string]json.RawMessage, error) { + s.mu.RLock() + defer s.mu.RUnlock() + out := make(map[string]json.RawMessage, len(s.states)) + for caseName, state := range s.states { + out[caseName] = cloneRaw(state) + } + return out, nil +} + +func cloneRaw(state json.RawMessage) json.RawMessage { + if len(state) == 0 { + return nil + } + out := make([]byte, len(state)) + copy(out, state) + return out +} diff --git a/pkg/testkit/brhelper/workload/util.go b/pkg/testkit/brhelper/workload/util.go new file mode 100644 index 0000000000000..3f88b394dce71 --- /dev/null +++ b/pkg/testkit/brhelper/workload/util.go @@ -0,0 +1,182 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
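+// util.go collects shared SQL and assertion helpers for workload cases: identifier quoting, schema/table/column/index existence checks, and ADMIN CHECKSUM TABLE parsing.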
+ +package workload + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "encoding/json" + "fmt" + "strings" +) + +func RandSuffix() (string, error) { + var b [8]byte + if _, err := rand.Read(b[:]); err != nil { + return "", err + } + return hex.EncodeToString(b[:]), nil +} + +func QIdent(s string) string { + return "`" + strings.ReplaceAll(s, "`", "``") + "`" +} + +func QTable(schema, table string) string { + return QIdent(schema) + "." + QIdent(table) +} + +func ExecAll(ctx context.Context, db *sql.DB, stmts []string) error { + for _, stmt := range stmts { + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + } + return nil +} + +func SchemaExists(ctx context.Context, db *sql.DB, schema string) (bool, error) { + var n int + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", schema).Scan(&n) + return n > 0, err +} + +func TableExists(ctx context.Context, db *sql.DB, schema, table string) (bool, error) { + var n int + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?", schema, table).Scan(&n) + return n > 0, err +} + +func ColumnExists(ctx context.Context, db *sql.DB, schema, table, column string) (bool, error) { + var n int + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND COLUMN_NAME = ?", schema, table, column).Scan(&n) + return n > 0, err +} + +func IndexExists(ctx context.Context, db *sql.DB, schema, table, index string) (bool, error) { + var n int + err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM information_schema.STATISTICS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? AND INDEX_NAME = ?", schema, table, index).Scan(&n) + return n > 0, err +} + +func TiFlashReplicaCount(ctx context.Context, db *sql.DB, schema, table string) (int, error) { + var n sql.NullInt64 + err := db.QueryRowContext(ctx, "SELECT REPLICA_COUNT FROM information_schema.TIFLASH_REPLICA WHERE TABLE_SCHEMA = ? 
AND TABLE_NAME = ?", schema, table).Scan(&n) + if err == sql.ErrNoRows { + return 0, nil + } + if err != nil { + return 0, err + } + if !n.Valid { + return 0, nil + } + return int(n.Int64), nil +} + +type TableChecksum struct { + TotalKvs string `json:"total_kvs,omitempty"` + TotalBytes string `json:"total_bytes,omitempty"` + ChecksumCRC64Xor string `json:"checksum_crc64_xor,omitempty"` +} + +func (c *TableChecksum) UnmarshalJSON(b []byte) error { + if len(b) == 0 || string(b) == "null" { + return nil + } + if b[0] == '"' { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + *c = TableChecksum{TotalKvs: s} + return nil + } + type alias TableChecksum + var v alias + if err := json.Unmarshal(b, &v); err != nil { + return err + } + *c = TableChecksum(v) + return nil +} + +func AdminChecksumTable(ctx context.Context, db *sql.DB, schema, table string) (TableChecksum, error) { + rows, err := db.QueryContext(ctx, "ADMIN CHECKSUM TABLE "+QTable(schema, table)) + if err != nil { + return TableChecksum{}, err + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + return TableChecksum{}, err + } + if !rows.Next() { + if err := rows.Err(); err != nil { + return TableChecksum{}, err + } + return TableChecksum{}, fmt.Errorf("checksum: no rows returned") + } + raw := make([]sql.RawBytes, len(cols)) + dest := make([]any, len(cols)) + for i := range raw { + dest[i] = &raw[i] + } + if err := rows.Scan(dest...); err != nil { + return TableChecksum{}, err + } + + var out TableChecksum + for i, c := range cols { + v := strings.TrimSpace(string(raw[i])) + switch { + case strings.EqualFold(c, "Total_kvs"): + out.TotalKvs = v + case strings.EqualFold(c, "Total_bytes"): + out.TotalBytes = v + case strings.EqualFold(c, "Checksum_crc64_xor"): + out.ChecksumCRC64Xor = v + } + } + + var missing []string + if out.TotalKvs == "" { + missing = append(missing, "Total_kvs") + } + if out.TotalBytes == "" { + missing = append(missing, "Total_bytes") + } + if out.ChecksumCRC64Xor == "" { + missing = append(missing, "Checksum_crc64_xor") + } + if len(missing) > 0 { + return TableChecksum{}, fmt.Errorf("checksum: column(s) not found: %v; columns=%v", missing, cols) + } + return out, nil +} + +func Require(cond bool, format string, args ...any) error { + if cond { + return nil + } + return fmt.Errorf(format, args...) +} + +func EveryNTick(tick int, n int) bool { + return n > 0 && tick%n == 0 +} diff --git a/tests/realtikvtest/brietest/BUILD.bazel b/tests/realtikvtest/brietest/BUILD.bazel index 37184d9791185..68fb3c0dcc050 100644 --- a/tests/realtikvtest/brietest/BUILD.bazel +++ b/tests/realtikvtest/brietest/BUILD.bazel @@ -12,6 +12,7 @@ go_test( "operator_test.go", "pitr_test.go", "registry_test.go", + "segmented_restore_test.go", "scheduler_test.go", ], flaky = True, @@ -45,6 +46,7 @@ go_test( "//pkg/util/printer", "//pkg/util/table-filter", "//tests/realtikvtest", + "//tests/realtikvtest/brietest/workloadcases", "@com_github_google_uuid//:uuid", "@com_github_pingcap_failpoint//:failpoint", "@com_github_pingcap_kvproto//pkg/brpb", diff --git a/tests/realtikvtest/brietest/segmented_restore_test.go b/tests/realtikvtest/brietest/segmented_restore_test.go new file mode 100644 index 0000000000000..ad4fe480a7f72 --- /dev/null +++ b/tests/realtikvtest/brietest/segmented_restore_test.go @@ -0,0 +1,213 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package brietest + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/registry" + "github.com/pingcap/tidb/br/pkg/task" + "github.com/pingcap/tidb/pkg/testkit" + "github.com/pingcap/tidb/pkg/testkit/brhelper/workload" + "github.com/pingcap/tidb/tests/realtikvtest/brietest/workloadcases" + "github.com/stretchr/testify/require" +) + +func TestSegmentedRestoreWorkload(t *testing.T) { + kit := NewLogBackupKit(t) + taskName := "segmented_restore_" + t.Name() + kit.StopTaskIfExists(taskName) + + db := testkit.CreateMockDB(kit.tk) + t.Cleanup(func() { + _ = db.Close() + }) + + store := workload.NewMemoryStore() + cases := []workload.Case{ + &workloadcases.NexusDDLDestructiveCase{}, + &workloadcases.NexusDDLCase{}, + &workloadcases.AddIndexCase{}, + } + if tiflashCount := tiflashStoreCount(t, kit.tk); tiflashCount > 0 { + cases = append(cases, &workloadcases.ModifyTiFlashCase{NAP: tiflashCount}) + } else { + t.Log("TiFlash not found in environment, won't run tiflash related cases.") + } + runner, err := workload.NewRunner(db, store, cases...) + require.NoError(t, err) + + ctx := context.Background() + err = runner.Prepare(ctx) + require.NoError(t, err) + + kit.RunFullBackup(func(cfg *task.BackupConfig) { + cfg.Storage = kit.LocalURI("full") + }) + backupTS := readBackupEndTS(t, kit.LocalURI("full")) + + kit.RunLogStart(taskName, func(cfg *task.StreamConfig) { + cfg.StartTS = backupTS + }) + t.Cleanup(func() { + kit.StopTaskIfExists(taskName) + }) + + checkpoints := make([]uint64, 0, 5) + runCfg := workload.RunConfig{ + TickCount: 100, + TickInterval: 0, + Seed: 1, + Parallel: true, + } + + for range 4 { + err := runner.Run(ctx, runCfg) + require.NoError(t, err) + kit.forceFlushAndWait(taskName) + checkpoints = append(checkpoints, kit.CheckpointTSOf(taskName)) + } + kit.forceFlushAndWait(taskName) + checkpoints = append(checkpoints, kit.CheckpointTSOf(taskName)) + kit.StopTaskIfExists(taskName) + + cleanupWorkloadSchemas(t, kit.tk) + cleanupRestoreRegistry(t, kit.tk) + + checkpointDir := filepath.Join(kit.base, "checkpoint") + require.NoError(t, os.RemoveAll(checkpointDir)) + + for i, restoreTS := range checkpoints { + idx := i + rcTS := restoreTS + kit.RunStreamRestore(func(rc *task.RestoreConfig) { + rc.RestoreTS = rcTS + rc.IsRestoredTSUserSpecified = true + rc.LastRestore = idx == len(checkpoints)-1 + rc.IsLastRestoreUserSpecified = true + rc.UseCheckpoint = true + if idx > 0 { + rc.StartTS = checkpoints[idx-1] + rc.FullBackupStorage = "" + } + }) + } + + require.NoError(t, runner.Verify(ctx)) +} + +func readBackupEndTS(t *testing.T, storage string) uint64 { + cfg := task.DefaultConfig() + cfg.Storage = storage + _, _, backupMeta, err := task.ReadBackupMeta(context.Background(), metautil.MetaFile, &cfg) + require.NoError(t, err) + return backupMeta.GetEndVersion() +} + +func cleanupWorkloadSchemas(t *testing.T, tk *testkit.TestKit) { + t.Helper() + + droppedAt := make(map[string]time.Time) + var lastLog time.Time + require.Eventuallyf(t, func() bool { + rows 
:= tk.MustQuery("SELECT schema_name FROM information_schema.schemata").Rows() + remaining := make([]string, 0, len(rows)) + now := time.Now() + + for _, row := range rows { + name := fmt.Sprint(row[0]) + if isSystemSchema(name) { + continue + } + key := strings.ToLower(name) + remaining = append(remaining, name) + if last, ok := droppedAt[key]; !ok || now.Sub(last) > 5*time.Second { + tk.MustExec("DROP DATABASE IF EXISTS " + workload.QIdent(name)) + droppedAt[key] = now + } + } + + if len(remaining) == 0 { + return true + } + if now.Sub(lastLog) > 10*time.Second { + t.Logf("waiting for schemas to drop: %v", remaining) + lastLog = now + } + return false + }, 2*time.Minute, 500*time.Millisecond, "user schemas still exist") + + tk.MustExec("CREATE DATABASE IF NOT EXISTS test") +} + +func cleanupRestoreRegistry(t *testing.T, tk *testkit.TestKit) { + t.Helper() + + rows := tk.MustQuery(fmt.Sprintf( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = '%s' AND table_name = '%s'", + registry.RestoreRegistryDBName, + registry.RestoreRegistryTableName, + )).Rows() + require.Len(t, rows, 1) + count, err := parseCount(rows[0][0]) + require.NoError(t, err) + if count == 0 { + return + } + tk.MustExec(fmt.Sprintf("DELETE FROM %s.%s", registry.RestoreRegistryDBName, registry.RestoreRegistryTableName)) +} + +func isSystemSchema(name string) bool { + switch strings.ToLower(name) { + case "mysql", + "information_schema", + "performance_schema", + "sys", + "metrics_schema": + return true + default: + return false + } +} + +func tiflashStoreCount(t *testing.T, tk *testkit.TestKit) int { + rows := tk.MustQuery("SELECT COUNT(*) FROM information_schema.tikv_store_status WHERE JSON_SEARCH(LABEL, 'one', 'tiflash') IS NOT NULL").Rows() + require.Len(t, rows, 1) + count, err := parseCount(rows[0][0]) + require.NoError(t, err) + return count +} + +func parseCount(raw any) (int, error) { + switch v := raw.(type) { + case string: + var out int + _, err := fmt.Sscanf(v, "%d", &out) + return out, err + case int: + return v, nil + case int64: + return int(v), nil + default: + return 0, fmt.Errorf("unexpected count type %T", raw) + } +} diff --git a/tests/realtikvtest/brietest/workloadcases/BUILD.bazel b/tests/realtikvtest/brietest/workloadcases/BUILD.bazel new file mode 100644 index 0000000000000..fc5dd10ca3cd7 --- /dev/null +++ b/tests/realtikvtest/brietest/workloadcases/BUILD.bazel @@ -0,0 +1,17 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "workloadcases", + srcs = [ + "add_index.go", + "modify_tiflash.go", + "nexus_common.go", + "nexus_ddl.go", + "nexus_ddl_destructive.go", + ], + importpath = "github.com/pingcap/tidb/tests/realtikvtest/brietest/workloadcases", + visibility = ["//visibility:public"], + deps = [ + "//pkg/testkit/brhelper/workload", + ], +) diff --git a/tests/realtikvtest/brietest/workloadcases/add_index.go b/tests/realtikvtest/brietest/workloadcases/add_index.go new file mode 100644 index 0000000000000..961a6561d9d2b --- /dev/null +++ b/tests/realtikvtest/brietest/workloadcases/add_index.go @@ -0,0 +1,273 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package workloadcases + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/pingcap/tidb/pkg/testkit/brhelper/workload" +) + +type AddIndexCase struct { + Suffix string `json:"suffix"` + N int `json:"n"` + NR int `json:"nr"` +} + +type addIndexSpec struct { + Name string `json:"name"` + Columns []string `json:"columns"` +} + +type addIndexState struct { + Suffix string `json:"suffix"` + DB string `json:"db"` + Table string `json:"table"` + N int `json:"n"` + NR int `json:"nr"` + + Inserted int `json:"inserted"` + Ticked int `json:"ticked"` + + NextIndexID int `json:"next_index_id"` + + Indexes []addIndexSpec `json:"indexes"` + + Checksum workload.TableChecksum `json:"checksum"` + LogDone bool `json:"log_done"` +} + +func (c *AddIndexCase) Name() string { return "AddIndex" } + +func (c *AddIndexCase) Prepare(ctx workload.Context) (json.RawMessage, error) { + suffix := c.Suffix + if suffix == "" { + var err error + suffix, err = workload.RandSuffix() + if err != nil { + return nil, err + } + } + n := c.N + if n <= 0 { + n = 100 + } + nr := c.NR + if nr <= 0 { + nr = 150 + } + st := addIndexState{ + Suffix: suffix, + DB: fmt.Sprintf("test_add_index_%s", suffix), + Table: "t1", + N: n, + NR: nr, + NextIndexID: 0, + } + if err := workload.ExecAll(ctx, ctx.DB, []string{ + "CREATE DATABASE IF NOT EXISTS " + workload.QIdent(st.DB), + "CREATE TABLE IF NOT EXISTS " + workload.QTable(st.DB, st.Table) + " (" + + "id BIGINT PRIMARY KEY AUTO_INCREMENT," + + "a BIGINT," + + "b BIGINT," + + "c BIGINT," + + "d BIGINT," + + "e BIGINT" + + ")", + }); err != nil { + return nil, err + } + + return json.Marshal(st) +} + +func (c *AddIndexCase) Tick(ctx workload.TickContext, raw json.RawMessage) error { + var st addIndexState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + normalizeAddIndexState(&st) + + tickNo := st.Ticked + 1 + + if err := addIndexInsertRow(ctx, &st); err != nil { + return err + } + if err := c.maybeAddIndex(ctx, &st, tickNo); err != nil { + return err + } + if err := c.maybeDropIndex(ctx, &st, tickNo); err != nil { + return err + } + + st.Ticked++ + st.LogDone = true + + updated, err := json.Marshal(st) + if err != nil { + return err + } + ctx.UpdateState(updated) + return nil +} + +func (c *AddIndexCase) Exit(ctx workload.ExitContext, raw json.RawMessage) error { + var st addIndexState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + + checksum, err := workload.AdminChecksumTable(ctx, ctx.DB, st.DB, st.Table) + if err != nil { + return err + } + st.Checksum = checksum + st.LogDone = true + + updated, err := json.Marshal(st) + if err != nil { + return err + } + ctx.UpdateState(updated) + return nil +} + +func (c *AddIndexCase) Verify(ctx workload.Context, raw json.RawMessage) error { + var st addIndexState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + if err := workload.Require(st.LogDone, "AddIndex: log not executed"); err != nil { + return err + } + if err := workload.Require(st.Checksum.TotalKvs != "", "AddIndex: checksum not recorded; run Exit first"); err != nil { + return err + } + + for 
_, idx := range st.Indexes { + ok, err := workload.IndexExists(ctx, ctx.DB, st.DB, st.Table, idx.Name) + if err != nil { + return err + } + if err := workload.Require(ok, "AddIndex: index %q not found", idx.Name); err != nil { + return err + } + } + + checksum, err := workload.AdminChecksumTable(ctx, ctx.DB, st.DB, st.Table) + if err != nil { + return err + } + if err := workload.Require(checksum.TotalKvs == st.Checksum.TotalKvs, "AddIndex: Total_kvs mismatch: got %q want %q", checksum.TotalKvs, st.Checksum.TotalKvs); err != nil { + return err + } + if st.Checksum.TotalBytes != "" { + return workload.Require(checksum.TotalBytes == st.Checksum.TotalBytes, "AddIndex: Total_bytes mismatch: got %q want %q", checksum.TotalBytes, st.Checksum.TotalBytes) + } + return nil +} + +func hasAddIndexSpec(indexes []addIndexSpec, name string) bool { + for _, idx := range indexes { + if idx.Name == name { + return true + } + } + return false +} + +func normalizeAddIndexState(st *addIndexState) { + if st.N <= 0 { + st.N = 100 + } + if st.NR <= 0 { + st.NR = 150 + } + if st.NextIndexID < len(st.Indexes) { + st.NextIndexID = len(st.Indexes) + } +} + +func addIndexInsertRow(ctx workload.TickContext, st *addIndexState) error { + v := int64(st.Inserted) + if _, err := ctx.DB.ExecContext(ctx, "INSERT INTO "+workload.QTable(st.DB, st.Table)+" (a,b,c,d,e) VALUES (?,?,?,?,?)", + v, v*7+1, v*11+2, v*13+3, v*17+4, + ); err != nil { + return err + } + st.Inserted++ + return nil +} + +func (c *AddIndexCase) maybeAddIndex(ctx workload.TickContext, st *addIndexState, tickNo int) error { + if !workload.EveryNTick(tickNo, st.N) { + return nil + } + allCols := []string{"a", "b", "c", "d", "e"} + idxID := st.NextIndexID + idxName := fmt.Sprintf("idx_%d", idxID) + + colN := 1 + (idxID % 3) + start := idxID % len(allCols) + cols := make([]string, 0, colN) + for i := 0; i < colN; i++ { + cols = append(cols, allCols[(start+i)%len(allCols)]) + } + + exists, err := workload.IndexExists(ctx, ctx.DB, st.DB, st.Table, idxName) + if err != nil { + return err + } + if !exists { + colSQL := make([]string, 0, len(cols)) + for _, col := range cols { + colSQL = append(colSQL, workload.QIdent(col)) + } + stmt := "CREATE INDEX " + workload.QIdent(idxName) + " ON " + workload.QTable(st.DB, st.Table) + " (" + strings.Join(colSQL, ",") + ")" + if _, err := ctx.DB.ExecContext(ctx, stmt); err != nil { + return err + } + } + + spec := addIndexSpec{Name: idxName, Columns: cols} + if !hasAddIndexSpec(st.Indexes, idxName) { + st.Indexes = append(st.Indexes, spec) + } + st.NextIndexID++ + return nil +} + +func (c *AddIndexCase) maybeDropIndex(ctx workload.TickContext, st *addIndexState, tickNo int) error { + if !workload.EveryNTick(tickNo, st.NR) || len(st.Indexes) == 0 { + return nil + } + idx := ctx.RNG.IntN(len(st.Indexes)) + dropSpec := st.Indexes[idx] + + exists, err := workload.IndexExists(ctx, ctx.DB, st.DB, st.Table, dropSpec.Name) + if err != nil { + return err + } + if exists { + stmt := "DROP INDEX " + workload.QIdent(dropSpec.Name) + " ON " + workload.QTable(st.DB, st.Table) + if _, err := ctx.DB.ExecContext(ctx, stmt); err != nil { + return err + } + } + st.Indexes = append(st.Indexes[:idx], st.Indexes[idx+1:]...) 
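+ // The bookkeeping entry is removed whether or not the index still existed,
+ // so recorded state and the actual schema converge even on replayed ticks.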
+ return nil +} diff --git a/tests/realtikvtest/brietest/workloadcases/modify_tiflash.go b/tests/realtikvtest/brietest/workloadcases/modify_tiflash.go new file mode 100644 index 0000000000000..7cd0649b08448 --- /dev/null +++ b/tests/realtikvtest/brietest/workloadcases/modify_tiflash.go @@ -0,0 +1,192 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package workloadcases + +import ( + "encoding/json" + "fmt" + + "github.com/pingcap/tidb/pkg/testkit/brhelper/workload" +) + +type ModifyTiFlashCase struct { + Suffix string `json:"suffix"` + N int `json:"n"` + NAP int `json:"nap"` +} + +type modifyTiFlashState struct { + Suffix string `json:"suffix"` + DB string `json:"db"` + Table string `json:"table"` + N int `json:"n"` + NAP int `json:"nap"` + + Ticked int `json:"ticked"` + Inserted int `json:"inserted"` + + Replica int `json:"replica"` + + Checksum workload.TableChecksum `json:"checksum"` + LogDone bool `json:"log_done"` +} + +func (c *ModifyTiFlashCase) Name() string { return "ModifyTiFlash" } + +func (c *ModifyTiFlashCase) Prepare(ctx workload.Context) (json.RawMessage, error) { + suffix := c.Suffix + if suffix == "" { + var err error + suffix, err = workload.RandSuffix() + if err != nil { + return nil, err + } + } + n := c.N + if n <= 0 { + n = 100 + } + nap := c.NAP + if nap <= 0 { + nap = 1 + } + st := modifyTiFlashState{ + Suffix: suffix, + DB: fmt.Sprintf("test_modify_tiflash_%s", suffix), + Table: "t1", + N: n, + NAP: nap, + Replica: 0, + } + if err := workload.ExecAll(ctx, ctx.DB, []string{ + "CREATE DATABASE IF NOT EXISTS " + workload.QIdent(st.DB), + "CREATE TABLE IF NOT EXISTS " + workload.QTable(st.DB, st.Table) + " (" + + "id BIGINT PRIMARY KEY AUTO_INCREMENT," + + "a BIGINT," + + "b BIGINT," + + "c BIGINT" + + ")", + "ALTER TABLE " + workload.QTable(st.DB, st.Table) + " SET TIFLASH REPLICA 0", + }); err != nil { + return nil, err + } + + return json.Marshal(st) +} + +func (c *ModifyTiFlashCase) Tick(ctx workload.TickContext, raw json.RawMessage) error { + var st modifyTiFlashState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + if st.N <= 0 { + st.N = 100 + } + if st.NAP <= 0 { + st.NAP = 2 + } + + tickNo := st.Ticked + 1 + + if _, err := ctx.DB.ExecContext(ctx, "INSERT INTO "+workload.QTable(st.DB, st.Table)+" (a,b,c) VALUES (?,?,?)", + int64(st.Inserted), int64(st.Inserted*7+1), int64(st.Inserted*11+2), + ); err != nil { + return err + } + st.Inserted++ + + if workload.EveryNTick(tickNo, st.N) { + max := st.NAP + if max > 0 { + next := tickNo % (max + 1) + if next == st.Replica { + next = (next + 1) % (max + 1) + } + stmt := fmt.Sprintf("ALTER TABLE %s SET TIFLASH REPLICA %d", workload.QTable(st.DB, st.Table), next) + if _, err := ctx.DB.ExecContext(ctx, stmt); err != nil { + return err + } + st.Replica = next + } + } + + st.Ticked++ + st.LogDone = true + + updated, err := json.Marshal(st) + if err != nil { + return err + } + ctx.UpdateState(updated) + return nil +} + +func (c *ModifyTiFlashCase) Exit(ctx 
workload.ExitContext, raw json.RawMessage) error { + var st modifyTiFlashState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + + sum, err := workload.AdminChecksumTable(ctx, ctx.DB, st.DB, st.Table) + if err != nil { + return err + } + replica, err := workload.TiFlashReplicaCount(ctx, ctx.DB, st.DB, st.Table) + if err != nil { + return err + } + st.Checksum = sum + st.Replica = replica + st.LogDone = true + + updated, err := json.Marshal(st) + if err != nil { + return err + } + ctx.UpdateState(updated) + return nil +} + +func (c *ModifyTiFlashCase) Verify(ctx workload.Context, raw json.RawMessage) error { + var st modifyTiFlashState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + if err := workload.Require(st.LogDone, "ModifyTiFlash: log not executed"); err != nil { + return err + } + if err := workload.Require(st.Checksum.TotalKvs != "", "ModifyTiFlash: checksum not recorded; run Exit first"); err != nil { + return err + } + + sum, err := workload.AdminChecksumTable(ctx, ctx.DB, st.DB, st.Table) + if err != nil { + return err + } + if err := workload.Require(sum.TotalKvs == st.Checksum.TotalKvs, "ModifyTiFlash: Total_kvs mismatch: got %q want %q", sum.TotalKvs, st.Checksum.TotalKvs); err != nil { + return err + } + if st.Checksum.TotalBytes != "" { + if err := workload.Require(sum.TotalBytes == st.Checksum.TotalBytes, "ModifyTiFlash: Total_bytes mismatch: got %q want %q", sum.TotalBytes, st.Checksum.TotalBytes); err != nil { + return err + } + } + + replica, err := workload.TiFlashReplicaCount(ctx, ctx.DB, st.DB, st.Table) + if err != nil { + return err + } + return workload.Require(replica == st.Replica, "ModifyTiFlash: tiflash replica mismatch: got %d want %d", replica, st.Replica) +} diff --git a/tests/realtikvtest/brietest/workloadcases/nexus_common.go b/tests/realtikvtest/brietest/workloadcases/nexus_common.go new file mode 100644 index 0000000000000..f9a41a04edfbf --- /dev/null +++ b/tests/realtikvtest/brietest/workloadcases/nexus_common.go @@ -0,0 +1,94 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
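+
+// This file holds the state structs and helpers shared by the NexusDDL and
+// NexusDDLDestructive workload cases.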
+ +package workloadcases + +import ( + "context" + "database/sql" + "fmt" + + "github.com/pingcap/tidb/pkg/testkit/brhelper/workload" +) + +type nexusTableState struct { + Name string `json:"name"` + NextColID int `json:"next_col_id,omitempty"` + Cols []string `json:"cols,omitempty"` +} + +type nexusState struct { + Suffix string `json:"suffix"` + DB string `json:"db"` + N int `json:"n"` + + Ticked int `json:"ticked"` + NextTableID int `json:"next_table_id"` + Tables []nexusTableState `json:"tables"` + + Checksums map[string]workload.TableChecksum `json:"checksums,omitempty"` + LogDone bool `json:"log_done"` +} + +func nexusDefaultN(n int) int { + if n <= 0 { + return 50 + } + return n +} + +func nexusHalf(n int) int { + h := n / 2 + if h <= 0 { + return 1 + } + return h +} + +func nexusTableName(id int) string { + return fmt.Sprintf("t_%d", id) +} + +func nexusExecDDL(ctx context.Context, db *sql.DB, tick int, stmt string) error { + _, err := db.ExecContext(ctx, stmt) + return err +} + +func nexusCreateTable(ctx context.Context, db *sql.DB, tick int, schema, table string) error { + stmt := "CREATE TABLE IF NOT EXISTS " + workload.QTable(schema, table) + " (" + + "id BIGINT PRIMARY KEY AUTO_INCREMENT," + + "v BIGINT," + + "s VARCHAR(64) NOT NULL" + + ")" + return nexusExecDDL(ctx, db, tick, stmt) +} + +func nexusInsertRow(ctx context.Context, db *sql.DB, schema, table string, tick int) error { + _, err := db.ExecContext(ctx, "INSERT INTO "+workload.QTable(schema, table)+" (v,s) VALUES (?,?)", + int64(tick), fmt.Sprintf("%s_%d", table, tick), + ) + return err +} + +func nexusRecordChecksums(ctx context.Context, db *sql.DB, schema string, tables []nexusTableState) (map[string]workload.TableChecksum, error) { + out := make(map[string]workload.TableChecksum, len(tables)) + for _, t := range tables { + sum, err := workload.AdminChecksumTable(ctx, db, schema, t.Name) + if err != nil { + return nil, err + } + out[t.Name] = sum + } + return out, nil +} diff --git a/tests/realtikvtest/brietest/workloadcases/nexus_ddl.go b/tests/realtikvtest/brietest/workloadcases/nexus_ddl.go new file mode 100644 index 0000000000000..456ae457ab352 --- /dev/null +++ b/tests/realtikvtest/brietest/workloadcases/nexus_ddl.go @@ -0,0 +1,250 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
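+
+// This file implements NexusDDLCase, a workload that periodically creates and
+// drops tables, adds and drops columns, and inserts one row into every table
+// each tick; Verify checks table/column existence and per-table checksums.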
+ +package workloadcases + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "slices" + + "github.com/pingcap/tidb/pkg/testkit/brhelper/workload" +) + +type NexusDDLCase struct { + Suffix string `json:"suffix"` + N int `json:"n"` +} + +func (c *NexusDDLCase) Name() string { return "NexusDDL" } + +func (c *NexusDDLCase) Prepare(ctx workload.Context) (json.RawMessage, error) { + suffix := c.Suffix + if suffix == "" { + var err error + suffix, err = workload.RandSuffix() + if err != nil { + return nil, err + } + } + n := c.N + if n <= 0 { + n = 50 + } + st := nexusState{ + Suffix: suffix, + DB: fmt.Sprintf("test_nexus_ddl_%s", suffix), + N: n, + Ticked: 0, + NextTableID: 1, + Tables: []nexusTableState{{Name: "t_0"}}, + } + if err := nexusExecDDL(ctx, ctx.DB, 0, "CREATE DATABASE IF NOT EXISTS "+workload.QIdent(st.DB)); err != nil { + return nil, err + } + if err := nexusCreateTable(ctx, ctx.DB, 0, st.DB, st.Tables[0].Name); err != nil { + return nil, err + } + return json.Marshal(st) +} + +func (c *NexusDDLCase) Tick(ctx workload.TickContext, raw json.RawMessage) error { + var st nexusState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + st.N = nexusDefaultN(st.N) + if st.NextTableID <= 0 { + st.NextTableID = len(st.Tables) + } + for i := range st.Tables { + if st.Tables[i].NextColID < len(st.Tables[i].Cols) { + st.Tables[i].NextColID = len(st.Tables[i].Cols) + } + } + + tickNo := st.Ticked + 1 + half := nexusHalf(st.N) + + if workload.EveryNTick(tickNo, 2*st.N) && len(st.Tables) > 0 { + oldest := st.Tables[0].Name + stmt := "DROP TABLE IF EXISTS " + workload.QTable(st.DB, oldest) + if err := nexusExecDDL(ctx, ctx.DB, tickNo, stmt); err != nil { + return err + } + st.Tables = st.Tables[1:] + } + + if workload.EveryNTick(tickNo, st.N) { + name := nexusTableName(st.NextTableID) + st.NextTableID++ + if err := nexusCreateTable(ctx, ctx.DB, tickNo, st.DB, name); err != nil { + return err + } + st.Tables = append(st.Tables, nexusTableState{Name: name}) + } + + if workload.EveryNTick(tickNo, half) && len(st.Tables) > 0 { + youngest := &st.Tables[len(st.Tables)-1] + if err := nexusAddOneColumn(ctx, ctx.DB, &st, tickNo, youngest); err != nil { + return err + } + } + + if workload.EveryNTick(tickNo, st.N) && len(st.Tables) > 0 { + oldest := &st.Tables[0] + if err := nexusDropOneColumn(ctx, ctx.DB, &st, tickNo, oldest); err != nil { + return err + } + } + + for _, t := range st.Tables { + if err := nexusInsertRow(ctx, ctx.DB, st.DB, t.Name, tickNo); err != nil { + return err + } + } + + st.Ticked++ + st.LogDone = true + + updated, err := json.Marshal(st) + if err != nil { + return err + } + ctx.UpdateState(updated) + return nil +} + +func (c *NexusDDLCase) Exit(ctx workload.ExitContext, raw json.RawMessage) error { + var st nexusState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + + sums, err := nexusRecordChecksums(ctx, ctx.DB, st.DB, st.Tables) + if err != nil { + return err + } + st.Checksums = sums + st.LogDone = true + + updated, err := json.Marshal(st) + if err != nil { + return err + } + ctx.UpdateState(updated) + return nil +} + +func (c *NexusDDLCase) Verify(ctx workload.Context, raw json.RawMessage) error { + var st nexusState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + if err := workload.Require(st.LogDone, "NexusDDL: log not executed"); err != nil { + return err + } + if err := workload.Require(len(st.Checksums) > 0, "NexusDDL: checksum not recorded; run Exit first"); err != nil { + return err + } + + for 
_, t := range st.Tables { + ok, err := workload.TableExists(ctx, ctx.DB, st.DB, t.Name) + if err != nil { + return err + } + if err := workload.Require(ok, "NexusDDL: table %s.%s not found", st.DB, t.Name); err != nil { + return err + } + + for _, col := range t.Cols { + has, err := workload.ColumnExists(ctx, ctx.DB, st.DB, t.Name, col) + if err != nil { + return err + } + if err := workload.Require(has, "NexusDDL: %s.%s column %q not found", st.DB, t.Name, col); err != nil { + return err + } + } + + want, ok := st.Checksums[t.Name] + if !ok { + return fmt.Errorf("NexusDDL: missing checksum for table %s.%s", st.DB, t.Name) + } + got, err := workload.AdminChecksumTable(ctx, ctx.DB, st.DB, t.Name) + if err != nil { + return err + } + if err := workload.Require(got.TotalKvs == want.TotalKvs, "NexusDDL: Total_kvs mismatch for %s.%s: got %q want %q", st.DB, t.Name, got.TotalKvs, want.TotalKvs); err != nil { + return err + } + if want.TotalBytes != "" { + if err := workload.Require(got.TotalBytes == want.TotalBytes, "NexusDDL: Total_bytes mismatch for %s.%s: got %q want %q", st.DB, t.Name, got.TotalBytes, want.TotalBytes); err != nil { + return err + } + } + } + return nil +} + +func nexusAddOneColumn(ctx context.Context, db *sql.DB, st *nexusState, tick int, t *nexusTableState) error { + if t == nil { + return nil + } + if t.NextColID < len(t.Cols) { + t.NextColID = len(t.Cols) + } + + col := fmt.Sprintf("c_%d", t.NextColID) + exists, err := workload.ColumnExists(ctx, db, st.DB, t.Name, col) + if err != nil { + return err + } + if exists { + if !slices.Contains(t.Cols, col) { + t.Cols = append(t.Cols, col) + } + t.NextColID++ + return nil + } + + stmt := "ALTER TABLE " + workload.QTable(st.DB, t.Name) + " ADD COLUMN " + workload.QIdent(col) + " BIGINT" + if err := nexusExecDDL(ctx, db, tick, stmt); err != nil { + return err + } + t.Cols = append(t.Cols, col) + t.NextColID++ + return nil +} + +func nexusDropOneColumn(ctx context.Context, db *sql.DB, st *nexusState, tick int, t *nexusTableState) error { + if t == nil || len(t.Cols) == 0 { + return nil + } + col := t.Cols[0] + exists, err := workload.ColumnExists(ctx, db, st.DB, t.Name, col) + if err != nil { + return err + } + if exists { + stmt := "ALTER TABLE " + workload.QTable(st.DB, t.Name) + " DROP COLUMN " + workload.QIdent(col) + if err := nexusExecDDL(ctx, db, tick, stmt); err != nil { + return err + } + } + t.Cols = t.Cols[1:] + return nil +} diff --git a/tests/realtikvtest/brietest/workloadcases/nexus_ddl_destructive.go b/tests/realtikvtest/brietest/workloadcases/nexus_ddl_destructive.go new file mode 100644 index 0000000000000..55e27d81bc1fa --- /dev/null +++ b/tests/realtikvtest/brietest/workloadcases/nexus_ddl_destructive.go @@ -0,0 +1,180 @@ +// Copyright 2025 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
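+
+// This file implements NexusDDLDestructiveCase, a workload that creates,
+// renames, and truncates tables while inserting rows each tick; Verify checks
+// table existence and per-table checksums.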
+ +package workloadcases + +import ( + "encoding/json" + "fmt" + + "github.com/pingcap/tidb/pkg/testkit/brhelper/workload" +) + +type NexusDDLDestructiveCase struct { + Suffix string `json:"suffix"` + N int `json:"n"` +} + +func (c *NexusDDLDestructiveCase) Name() string { return "NexusDDLDestructive" } + +func (c *NexusDDLDestructiveCase) Prepare(ctx workload.Context) (json.RawMessage, error) { + suffix := c.Suffix + if suffix == "" { + var err error + suffix, err = workload.RandSuffix() + if err != nil { + return nil, err + } + } + n := c.N + if n <= 0 { + n = 50 + } + st := nexusState{ + Suffix: suffix, + DB: fmt.Sprintf("test_nexus_ddl_destructive_%s", suffix), + N: n, + Ticked: 0, + NextTableID: 1, + Tables: []nexusTableState{{Name: "t_0"}}, + } + if err := nexusExecDDL(ctx, ctx.DB, 0, "CREATE DATABASE IF NOT EXISTS "+workload.QIdent(st.DB)); err != nil { + return nil, err + } + if err := nexusCreateTable(ctx, ctx.DB, 0, st.DB, st.Tables[0].Name); err != nil { + return nil, err + } + return json.Marshal(st) +} + +func (c *NexusDDLDestructiveCase) Tick(ctx workload.TickContext, raw json.RawMessage) error { + var st nexusState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + st.N = nexusDefaultN(st.N) + if st.NextTableID <= 0 { + st.NextTableID = len(st.Tables) + } + + tickNo := st.Ticked + 1 + half := nexusHalf(st.N) + + if workload.EveryNTick(tickNo, st.N) { + name := nexusTableName(st.NextTableID) + st.NextTableID++ + if err := nexusCreateTable(ctx, ctx.DB, tickNo, st.DB, name); err != nil { + return err + } + st.Tables = append(st.Tables, nexusTableState{Name: name}) + } + + if workload.EveryNTick(tickNo, half) && len(st.Tables) > 0 { + idx := ctx.RNG.IntN(len(st.Tables)) + oldName := st.Tables[idx].Name + newName := nexusTableName(st.NextTableID) + st.NextTableID++ + stmt := "RENAME TABLE " + workload.QTable(st.DB, oldName) + " TO " + workload.QTable(st.DB, newName) + if err := nexusExecDDL(ctx, ctx.DB, tickNo, stmt); err != nil { + return err + } + st.Tables[idx].Name = newName + } + + if workload.EveryNTick(tickNo, 2*st.N) && len(st.Tables) > 0 { + idx := ctx.RNG.IntN(len(st.Tables)) + stmt := "TRUNCATE TABLE " + workload.QTable(st.DB, st.Tables[idx].Name) + if err := nexusExecDDL(ctx, ctx.DB, tickNo, stmt); err != nil { + return err + } + } + + for _, t := range st.Tables { + if err := nexusInsertRow(ctx, ctx.DB, st.DB, t.Name, tickNo); err != nil { + return err + } + } + + st.Ticked++ + st.LogDone = true + + updated, err := json.Marshal(st) + if err != nil { + return err + } + ctx.UpdateState(updated) + return nil +} + +func (c *NexusDDLDestructiveCase) Exit(ctx workload.ExitContext, raw json.RawMessage) error { + var st nexusState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + + sums, err := nexusRecordChecksums(ctx, ctx.DB, st.DB, st.Tables) + if err != nil { + return err + } + st.Checksums = sums + st.LogDone = true + + updated, err := json.Marshal(st) + if err != nil { + return err + } + ctx.UpdateState(updated) + return nil +} + +func (c *NexusDDLDestructiveCase) Verify(ctx workload.Context, raw json.RawMessage) error { + var st nexusState + if err := json.Unmarshal(raw, &st); err != nil { + return err + } + if err := workload.Require(st.LogDone, "NexusDDLDestructive: log not executed"); err != nil { + return err + } + if err := workload.Require(len(st.Checksums) > 0, "NexusDDLDestructive: checksum not recorded; run Exit first"); err != nil { + return err + } + + for _, t := range st.Tables { + ok, err := 
workload.TableExists(ctx, ctx.DB, st.DB, t.Name) + if err != nil { + return err + } + if err := workload.Require(ok, "NexusDDLDestructive: table %s.%s not found", st.DB, t.Name); err != nil { + return err + } + + want, ok := st.Checksums[t.Name] + if !ok { + return fmt.Errorf("NexusDDLDestructive: missing checksum for table %s.%s", st.DB, t.Name) + } + got, err := workload.AdminChecksumTable(ctx, ctx.DB, st.DB, t.Name) + if err != nil { + return err + } + if err := workload.Require(got.TotalKvs == want.TotalKvs, "NexusDDLDestructive: Total_kvs mismatch for %s.%s: got %q want %q", st.DB, t.Name, got.TotalKvs, want.TotalKvs); err != nil { + return err + } + if want.TotalBytes != "" { + if err := workload.Require(got.TotalBytes == want.TotalBytes, "NexusDDLDestructive: Total_bytes mismatch for %s.%s: got %q want %q", st.DB, t.Name, got.TotalBytes, want.TotalBytes); err != nil { + return err + } + } + } + return nil +}