Skip to content

Commit

Permalink
add the target info to the ts manager key
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG committed Oct 22, 2024
1 parent 3ea1ebf commit 74073f4
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 46 deletions.
1 change: 1 addition & 0 deletions core/config/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ type ReaderConfig struct {
Retry RetrySettings
SourceChannelNum int
TargetChannelNum int
ReplicateID string
}
66 changes: 45 additions & 21 deletions core/reader/replicate_channel_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ var replicatePool = conc.NewPool[struct{}](10, conc.WithExpiryDuration(time.Minu

type replicateChannelManager struct {
replicateCtx context.Context
replicateID string
streamDispatchClient msgdispatcher.Client
streamCreator StreamCreator
targetClient api.TargetAPI
Expand Down Expand Up @@ -108,6 +109,7 @@ func NewReplicateChannelManagerWithDispatchClient(
downstream string,
) (api.ChannelManager, error) {
return &replicateChannelManager{
replicateID: readConfig.ReplicateID,
streamDispatchClient: dispatchClient,
streamCreator: NewDisptachClientStreamCreator(factory, dispatchClient),
targetClient: client,
Expand Down Expand Up @@ -395,7 +397,8 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, db *m
channelHandlers = append(channelHandlers, channelHandler)
}
successChannels = append(successChannels, sourcePChannel)
log.Info("start read channel",
log.Info("start read channel in the manager",
zap.Bool("nil_handler", channelHandler == nil),
zap.String("channel", sourcePChannel),
zap.String("target_channel", targetPChannel),
zap.Int64("collection_id", info.ID))
Expand Down Expand Up @@ -468,7 +471,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *mode
}
r.channelLock.RUnlock()
if len(handlers) == 0 {
partitionLog.Info("waiting handler", zap.Int64("collection_id", collectionID))
partitionLog.Info("waiting handler")
return errors.New("no handler found")
}
return nil
Expand All @@ -478,7 +481,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *mode
}

if len(handlers) == 0 {
partitionLog.Warn("no handler found", zap.Int64("collection_id", collectionID))
partitionLog.Warn("no handler found")
return errors.New("no handler found")
}

Expand Down Expand Up @@ -591,11 +594,11 @@ func (r *replicateChannelManager) StopReadCollection(ctx context.Context, info *
}

func (r *replicateChannelManager) GetChannelChan() <-chan string {
return GetTSManager().GetTargetChannelChan()
return GetTSManager().GetTargetChannelChan(r.replicateID)
}

func (r *replicateChannelManager) GetMsgChan(pChannel string) <-chan *api.ReplicateMsg {
return GetTSManager().GetTargetMsgChan(pChannel)
return GetTSManager().GetTargetMsgChan(r.replicateID, pChannel)
}

func (r *replicateChannelManager) GetEventChan() <-chan *api.ReplicateAPIEvent {
Expand Down Expand Up @@ -641,6 +644,7 @@ func (r *replicateChannelManager) startReadChannel(sourceInfo *model.SourceColle
channelHandler.forwardMsgFunc = r.forwardMsg
channelHandler.isDroppedCollection = r.isDroppedCollection
channelHandler.isDroppedPartition = r.isDroppedPartition
channelHandler.replicateID = r.replicateID
diffValueForKey := r.channelMapping.CheckKeyNotExist(sourceInfo.PChannel, targetInfo.PChannel)

if !diffValueForKey {
Expand All @@ -658,9 +662,6 @@ func (r *replicateChannelManager) startReadChannel(sourceInfo *model.SourceColle
}
return nil, nil
}
if sourceInfo.SeekPosition != nil {
GetTSManager().CollectTS(channelHandler.targetPChannel, sourceInfo.SeekPosition.GetTimestamp())
}
if !r.channelMapping.CheckKeyExist(sourceInfo.PChannel, targetInfo.PChannel) {
log.Info("diff target pchannel",
zap.String("source_channel", sourceInfo.PChannel),
Expand Down Expand Up @@ -831,6 +832,7 @@ func (r *replicateChannelManager) stopReadChannel(pChannelName string, collectio

type replicateChannelHandler struct {
replicateCtx context.Context
replicateID string
sourcePChannel string
targetPChannel string
targetClient api.TargetAPI
Expand Down Expand Up @@ -867,6 +869,7 @@ type replicateChannelHandler struct {

func (r *replicateChannelHandler) AddCollection(sourceInfo *model.SourceCollectionInfo, targetInfo *model.TargetCollectionInfo) {
<-r.startReadChan
r.collectionSourceSeekPosition(sourceInfo.SeekPosition)
collectionID := sourceInfo.CollectionID
streamChan, closeStreamFunc, err := r.streamCreator.GetStreamChan(r.replicateCtx, sourceInfo.VChannel, sourceInfo.SeekPosition)
if err != nil {
Expand Down Expand Up @@ -1145,6 +1148,10 @@ func (r *replicateChannelHandler) Close() {
// r.stream.Close()
}

// getTSManagerChannelKey derives the TS-manager lookup key for the given
// channel, namespaced by this handler's replicate ID so that multiple
// replicate tasks sharing one TS manager do not collide.
func (r *replicateChannelHandler) getTSManagerChannelKey(channelName string) string {
	key := FormatChanKey(r.replicateID, channelName)
	return key
}

func (r *replicateChannelHandler) innerHandleReplicateMsg(forward bool, msg *api.ReplicateMsg) {
msgPack := msg.MsgPack
p := r.handlePack(forward, msgPack)
Expand All @@ -1154,17 +1161,33 @@ func (r *replicateChannelHandler) innerHandleReplicateMsg(forward bool, msg *api
p.CollectionID = msg.CollectionID
p.CollectionName = msg.CollectionName
p.PChannelName = msg.PChannelName
GetTSManager().SendTargetMsg(r.targetPChannel, p)
GetTSManager().SendTargetMsg(r.getTSManagerChannelKey(r.targetPChannel), p)
}

// collectionSourceSeekPosition feeds the timestamp of the source seek
// position into the TS manager under this handler's target-channel key.
// A nil position is silently ignored.
func (r *replicateChannelHandler) collectionSourceSeekPosition(sourceSeekPosition *msgstream.MsgPosition) {
	if sourceSeekPosition != nil {
		channelKey := r.getTSManagerChannelKey(r.targetPChannel)
		GetTSManager().CollectTS(channelKey, sourceSeekPosition.GetTimestamp())
	}
}

func (r *replicateChannelHandler) startReadChannel() {
close(r.startReadChan)
var cts uint64 = math.MaxUint64
if r.sourceSeekPosition != nil {
cts = r.sourceSeekPosition.GetTimestamp()
}
GetTSManager().InitTSInfo(r.targetPChannel, time.Duration(r.handlerOpts.TTInterval)*time.Millisecond, cts, r.handlerOpts.MessageBufferSize)
log.Info("start read channel",
log.Info("start read channel in the handler before",
zap.String("channel_name", r.sourcePChannel),
zap.String("target_channel", r.targetPChannel),
)
GetTSManager().InitTSInfo(r.replicateID, r.targetPChannel, time.Duration(r.handlerOpts.TTInterval)*time.Millisecond, cts, r.handlerOpts.MessageBufferSize)
log.Info("start read channel in the handler",
zap.String("channel_name", r.sourcePChannel),
zap.String("target_channel", r.targetPChannel),
)
close(r.startReadChan)
r.collectionSourceSeekPosition(r.sourceSeekPosition)
log.Info("start read channel in the handler end",
zap.String("channel_name", r.sourcePChannel),
zap.String("target_channel", r.targetPChannel),
)
Expand Down Expand Up @@ -1289,6 +1312,7 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa
return pack.Msgs[i].BeginTs() < pack.Msgs[j].BeginTs() ||
(pack.Msgs[i].BeginTs() == pack.Msgs[j].BeginTs() && pack.Msgs[i].Type() == commonpb.MsgType_Delete)
})
tsManagerChannelKey := r.getTSManagerChannelKey(r.targetPChannel)

r.addCollectionLock.RLock()
if *r.addCollectionCnt != 0 {
Expand Down Expand Up @@ -1326,7 +1350,7 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa
log.Warn("begin timestamp is 0", zap.Uint64("end_ts", pack.EndTs), zap.Any("hasValidMsg", hasValidMsg))
}
}
GetTSManager().CollectTS(r.targetPChannel, beginTS)
GetTSManager().CollectTS(tsManagerChannelKey, beginTS)
r.addCollectionLock.RUnlock()

if r.msgPackCallback != nil {
Expand Down Expand Up @@ -1562,26 +1586,26 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa
position.ChannelName = pChannel
}

maxTS, _ := GetTSManager().GetMaxTS(r.targetPChannel)
maxTS, _ := GetTSManager().GetMaxTS(tsManagerChannelKey)
resetTS := resetMsgPackTimestamp(newPack, maxTS)
if resetTS {
GetTSManager().CollectTS(r.targetPChannel, newPack.EndTs)
GetTSManager().CollectTS(tsManagerChannelKey, newPack.EndTs)
}

GetTSManager().LockTargetChannel(r.targetPChannel)
defer GetTSManager().UnLockTargetChannel(r.targetPChannel)
GetTSManager().LockTargetChannel(tsManagerChannelKey)
defer GetTSManager().UnLockTargetChannel(tsManagerChannelKey)

if !needTsMsg && len(newPack.Msgs) == 0 && !GetTSManager().UnsafeShouldSendTSMsg(r.targetPChannel) {
if !needTsMsg && len(newPack.Msgs) == 0 && !GetTSManager().UnsafeShouldSendTSMsg(tsManagerChannelKey) {
return api.EmptyMsgPack
}

generateTS, ok := GetTSManager().UnsafeGetMaxTS(r.targetPChannel)
generateTS, ok := GetTSManager().UnsafeGetMaxTS(tsManagerChannelKey)
if !ok {
log.Warn("not found the max ts", zap.String("channel", r.targetPChannel))
r.sendErrEvent(fmt.Errorf("not found the max ts"))
return nil
}
GetTSManager().UnsafeUpdatePackTS(r.targetPChannel, newPack.BeginTs, func(newTS uint64) (uint64, bool) {
GetTSManager().UnsafeUpdatePackTS(tsManagerChannelKey, newPack.BeginTs, func(newTS uint64) (uint64, bool) {
generateTS = newTS
reset := resetMsgPackTimestamp(newPack, newTS)
return newPack.EndTs, reset
Expand Down Expand Up @@ -1611,7 +1635,7 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa
}
newPack.Msgs = append(newPack.Msgs, timeTickMsg)

GetTSManager().UnsafeUpdateTSInfo(r.targetPChannel, generateTS, resetLastTs)
GetTSManager().UnsafeUpdateTSInfo(tsManagerChannelKey, generateTS, resetLastTs)
msgTime, _ := tsoutil.ParseHybridTs(generateTS)
TSMetricVec.WithLabelValues(r.targetPChannel).Set(float64(msgTime))
r.ttRateLog.Debug("time tick msg", zap.String("channel", r.targetPChannel), zap.Uint64("max_ts", generateTS))
Expand Down
12 changes: 9 additions & 3 deletions core/reader/replicate_channel_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ func TestStartReadCollectionForMilvus(t *testing.T) {
InitBackOff: 1,
MaxBackOff: 1,
},
ReplicateID: "127.0.0.1:19530",
}, &api.DefaultMetaOp{}, func(s string, pack *msgstream.MsgPack) {
}, "milvus")
assert.NoError(t, err)
Expand Down Expand Up @@ -410,6 +411,7 @@ func TestStartReadCollectionForKafka(t *testing.T) {
InitBackOff: 1,
MaxBackOff: 1,
},
ReplicateID: "127.0.0.1:19530",
}, &api.DefaultMetaOp{}, func(s string, pack *msgstream.MsgPack) {
}, "kafka")
assert.NoError(t, err)
Expand Down Expand Up @@ -665,6 +667,7 @@ func TestReplicateChannelHandler(t *testing.T) {
factory := msgstream.NewMockFactory(t)
stream := msgstream.NewMockMsgStream(t)
targetClient := mocks.NewTargetAPI(t)
replicateID := "127.0.0.1:19530"

factory.EXPECT().NewMsgStream(mock.Anything).Return(stream, nil)
stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Times(4)
Expand All @@ -689,6 +692,7 @@ func TestReplicateChannelHandler(t *testing.T) {
handler.isDroppedPartition = func(i int64) bool {
return false
}
handler.replicateID = replicateID
time.Sleep(100 * time.Millisecond)
handler.startReadChannel()

Expand Down Expand Up @@ -718,7 +722,7 @@ func TestReplicateChannelHandler(t *testing.T) {
handler.RemovePartitionInfo(2, "p2", 10002)

assert.False(t, handler.IsEmpty())
assert.NotNil(t, GetTSManager().GetTargetMsgChan(handler.targetPChannel))
assert.NotNil(t, GetTSManager().GetTargetMsgChan(replicateID, handler.targetPChannel))

// test updateTargetPartitionInfo
targetClient.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("mock error 2")).Once()
Expand Down Expand Up @@ -751,6 +755,7 @@ func TestReplicateChannelHandler(t *testing.T) {
stream := msgstream.NewMockMsgStream(t)
targetClient := mocks.NewTargetAPI(t)
streamChan := make(chan *msgstream.MsgPack)
replicateID := "127.0.0.1:19530"

factory.EXPECT().NewMsgStream(mock.Anything).Return(stream, nil)
stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice()
Expand Down Expand Up @@ -782,6 +787,7 @@ func TestReplicateChannelHandler(t *testing.T) {
TTInterval: 10000,
})
assert.NoError(t, err)
handler.replicateID = replicateID
handler.startReadChannel()

handler.isDroppedCollection = func(i int64) bool {
Expand All @@ -790,7 +796,7 @@ func TestReplicateChannelHandler(t *testing.T) {
handler.isDroppedPartition = func(i int64) bool {
return false
}
GetTSManager().InitTSInfo(handler.targetPChannel, 100*time.Millisecond, math.MaxUint64, 10)
GetTSManager().InitTSInfo(replicateID, handler.targetPChannel, 100*time.Millisecond, math.MaxUint64, 10)

err = handler.AddPartitionInfo(&pb.CollectionInfo{
ID: 1,
Expand All @@ -807,7 +813,7 @@ func TestReplicateChannelHandler(t *testing.T) {
noRetry(handler)

done := make(chan struct{})
targetMsgChan := GetTSManager().GetTargetMsgChan(handler.targetPChannel)
targetMsgChan := GetTSManager().GetTargetMsgChan(replicateID, handler.targetPChannel)

go func() {
defer close(done)
Expand Down
Loading

0 comments on commit 74073f4

Please sign in to comment.