diff --git a/pkg/index/job/correction/service/corrector.go b/pkg/index/job/correction/service/corrector.go
index 37519f8bb99..aa2adf47ec2 100644
--- a/pkg/index/job/correction/service/corrector.go
+++ b/pkg/index/job/correction/service/corrector.go
@@ -34,6 +34,7 @@ import (
 	"github.com/vdaas/vald/internal/net/grpc"
 	"github.com/vdaas/vald/internal/net/grpc/codes"
 	"github.com/vdaas/vald/internal/net/grpc/status"
+	"github.com/vdaas/vald/internal/safety"
 	"github.com/vdaas/vald/internal/sync"
 	"github.com/vdaas/vald/internal/sync/errgroup"
 	"github.com/vdaas/vald/pkg/index/job/correction/config"
@@ -106,9 +107,8 @@ func (c *correct) Start(ctx context.Context) (<-chan error, error) {
 		return nil, err
 	}
 
-	// For debugging
 	c.indexInfos.Range(func(addr string, info *payload.Info_Index_Count) bool {
-		log.Debugf("index info: addr(%s), stored(%d), uncommitted(%d)", addr, info.GetStored(), info.GetUncommitted())
+		log.Infof("index info: addr(%s), stored(%d), uncommitted(%d)", addr, info.GetStored(), info.GetUncommitted())
 		return true
 	})
 
@@ -133,30 +133,22 @@ func (c *correct) correct(ctx context.Context) (err error) {
 	// This is used to know which agents possibly have the same index as the target replica.
 	// We can say this because, thanks to caching, there is no way that the target replica is
 	// in the agent that has already been corrected.
-	leftAgentAddrs := make([]string, len(c.agentAddrs))
-	n := copy(leftAgentAddrs, c.agentAddrs)
-	if n != len(c.agentAddrs) {
-		return fmt.Errorf("failed to copy agentAddrs")
-	}
 
 	// Vector with time after this should not be processed
-	correctionStartTime, err := getCorrectionStartTime(ctx)
+	correctionStartTime, err := correctionStartTime(ctx)
 	if err != nil {
 		log.Errorf("cannot determine correction start time: %w", err)
 		return err
 	}
 
+	curTargetAgent := 0
 	if err := c.discoverer.GetClient().OrderedRange(ctx, c.agentAddrs,
 		func(ctx context.Context, addr string, conn *grpc.ClientConn, copts ...grpc.CallOption) error {
-			// current address is the leftAgentAddrs[0] because this is OrderedRange and
-			// leftAgentAddrs is copied from c.agentAddrs
-			leftAgentAddrs = leftAgentAddrs[1:]
-
-			vc := vald.NewValdClient(conn)
-			stream, err := vc.StreamListObject(ctx, &payload.Object_List_Request{})
-			if err != nil {
-				return err
-			}
+			// c.agentAddrs[curTargetAgent] is the current address because
+			// this is OrderedRange over c.agentAddrs
+			defer func() {
+				curTargetAgent++
+			}()
 
 			// context and errgroup for stream.Recv and correction
 			sctx, scancel := context.WithCancel(ctx)
@@ -173,6 +165,12 @@ func (c *correct) correct(ctx context.Context) (err error) {
 			var mu sync.Mutex
 			log.Infof("starting correction for agent %s, stream concurrency: %d, bbolt concurrency: %d", addr, sconcurrency, bconcurrency)
 
+			vc := vald.NewValdClient(conn)
+			stream, err := vc.StreamListObject(ctx, &payload.Object_List_Request{})
+			if err != nil {
+				return err
+			}
+
 			// The number of items to be received in advance is not known in advance.
 			// This is because there is a possibility of new items being inserted during processing.
 			for {
@@ -181,9 +179,26 @@ func (c *correct) correct(ctx context.Context) (err error) {
 					if !errors.Is(sctx.Err(), context.Canceled) {
 						log.Errorf("context done unexpectedly: %v", sctx.Err())
 					}
-					goto Finalize
+
+					// Finalize
+					err = seg.Wait()
+					if err != nil {
+						log.Errorf("err group returned error: %v", err)
+					}
+
+					berr := bolteg.Wait()
+					if berr != nil {
+						log.Errorf("bbolt err group returned error: %v", berr)
+						err = errors.Join(err, berr)
+					} else {
+						log.Info("bbolt all batch finished")
+					}
+
+					log.Infof("correction finished for agent %s", addr)
+					return err
+
 				default:
-					seg.Go(func() error {
+					seg.Go(safety.RecoverFunc(func() error {
 						mu.Lock()
 						// As long as we don't stream.Recv() from the stream, we do not consume the memory of the message.
 						// So by limiting the number of this errgroup.Go instances, we can limit the memory usage
@@ -235,7 +250,7 @@ func (c *correct) correct(ctx context.Context) (err error) {
 								addr: addr,
 								vec:  vec,
 							},
-							leftAgentAddrs,
+							curTargetAgent,
 						); err != nil {
 							log.Errorf("failed to check consistency: %v", err)
 							return nil // continue other processes
@@ -245,25 +260,9 @@ func (c *correct) correct(ctx context.Context) (err error) {
 						c.checkedID.AsyncSet(bolteg, []byte(id), nil)
 
 						return nil
-					})
+					}))
 				}
 			}
-
-		Finalize:
-			err = seg.Wait()
-			if err != nil {
-				log.Errorf("err group returned error: %v", err)
-			}
-
-			berr := bolteg.Wait()
-			if berr != nil {
-				log.Errorf("bolt err group returned error: %v", err)
-				err = errors.Join(err, berr)
-			}
-			log.Info("bbolt all batch finished")
-
-			log.Infof("correction finished for agent %s", addr)
-			return err
 		},
 	); err != nil {
 		log.Errorf("failed to range over agents(%v): %v", c.agentAddrs, err)
@@ -279,39 +278,23 @@ type vectorReplica struct {
 }
 
 // Validate len(addrs) >= 2 before calling this function
-func (c *correct) checkConsistency(ctx context.Context, targetReplica *vectorReplica, leftAgentAddrs []string) error {
-	// availableAddrs is the agents' addr that doesn't have the target replica thus is available to insert the replica
-	// to fix the index replica number if required.
-	availableAddrs := make([]string, 0, len(c.agentAddrs)-1)
-	for _, addr := range c.agentAddrs {
-		if addr != targetReplica.addr {
-			availableAddrs = append(availableAddrs, addr)
-		}
-	}
+// Would passing just idx be enough, since c.addrs has all the information?
+func (c *correct) checkConsistency(ctx context.Context, targetReplica *vectorReplica, targetAgentIdx int) error {
+	// leftAgentAddrs is the agents' addr that hasn't been corrected yet.
+	leftAgentAddrs := c.agentAddrs[targetAgentIdx+1:]
 
 	// Vector with time after this should not be processed
-	correctionStartTime, err := getCorrectionStartTime(ctx)
+	correctionStartTime, err := correctionStartTime(ctx)
 	if err != nil {
 		log.Errorf("cannot determine correction start time: %w", err)
 		return err
 	}
 
-	foundReplicas := make([]*vectorReplica, 0, len(availableAddrs))
+	foundReplicas := make([]*vectorReplica, 0, len(c.agentAddrs))
 	var mu sync.Mutex
 	if err := c.discoverer.GetClient().OrderedRangeConcurrent(ctx, leftAgentAddrs, len(leftAgentAddrs),
 		func(ctx context.Context, addr string, conn *grpc.ClientConn, copts ...grpc.CallOption) error {
-			// To avoid GetObject to myself. To maintain backward compatibility for withoug cache operation
-			if addr == targetReplica.addr {
-				return nil
-			}
-
-			select {
-			case <-ctx.Done():
-				return ctx.Err()
-			default:
-			}
-
-			vc := vald.NewValdClient(conn)
-			vec, err := vc.GetObject(ctx, &payload.Object_VectorRequest{
+			vec, err := vald.NewValdClient(conn).GetObject(ctx, &payload.Object_VectorRequest{
 				Id: &payload.Object_ID{
 					Id: targetReplica.vec.GetId(),
 				},
@@ -344,12 +327,6 @@ func (c *correct) checkConsistency(ctx context.Context, targetReplica *vectorRep
 				addr: addr,
 				vec:  vec,
 			})
-
-			// Remove this addr from availableAddrs because this addr has the target replica
-			// and not available to insert the replica to fix the index replica number
-			slices.DeleteFunc(availableAddrs, func(availableAddr string) bool {
-				return availableAddr == addr
-			})
 			mu.Unlock()
 
 			return nil
@@ -364,7 +341,7 @@ func (c *correct) checkConsistency(ctx context.Context, targetReplica *vectorRep
 	}
 
 	// check replica number
-	if err := c.correctReplica(ctx, targetReplica, foundReplicas, availableAddrs); err != nil {
+	if err := c.correctReplica(ctx, targetReplica, foundReplicas); err != nil {
 		return fmt.Errorf("failed to fix index replica: %w", err)
 	}
 
@@ -413,7 +390,6 @@ func (c *correct) correctReplica(
 	ctx context.Context,
 	targetReplica *vectorReplica,
 	foundReplicas []*vectorReplica,
-	availableAddrs []string,
 ) error {
 	// diff < 0 means there is less replica than the correct number
 	existReplica := len(foundReplicas) + 1
@@ -423,6 +399,19 @@ func (c *correct) correctReplica(
 		return nil
 	}
 
+	availableAddrs := make([]string, 0, len(c.agentAddrs))
+	for _, addr := range c.agentAddrs {
+		if addr == targetReplica.addr {
+			continue
+		}
+		if slices.ContainsFunc(foundReplicas, func(replica *vectorReplica) bool {
+			return replica.addr == addr
+		}) {
+			continue
+		}
+		availableAddrs = append(availableAddrs, addr)
+	}
+
 	// when there are less replicas than the correct number, add the extra replicas
 	if diff < 0 {
 		log.Infof("replica shortage of vector %s. inserting to other agents...", targetReplica.vec.GetId())
@@ -596,7 +585,7 @@ func embedTime(ctx context.Context) context.Context {
 	return context.WithValue(ctx, correctionStartTimeKey, time.Now())
 }
 
-func getCorrectionStartTime(ctx context.Context) (time.Time, error) {
+func correctionStartTime(ctx context.Context) (time.Time, error) {
 	v := ctx.Value(correctionStartTimeKey)
 	if t, ok := v.(time.Time); ok {
 		return t, nil
diff --git a/pkg/index/job/correction/service/corrector_test.go b/pkg/index/job/correction/service/corrector_test.go
index 533a8c2fd1d..7c30375de7e 100644
--- a/pkg/index/job/correction/service/corrector_test.go
+++ b/pkg/index/job/correction/service/corrector_test.go
@@ -437,9 +437,16 @@ func Test_correct_correctReplica(t *testing.T) {
 				},
 			},
 		}
+
+		// agentAddrs = availableAddrs + target.addr + found.addr
+		c.agentAddrs = append(test.args.availableAddrs, test.args.target.addr)
+		for _, found := range test.args.found {
+			c.agentAddrs = append(c.agentAddrs, found.addr)
+		}
+
 		t.Run(test.name, func(tt *testing.T) {
 			tt.Parallel()
-			err := c.correctReplica(context.Background(), test.args.target, test.args.found, test.args.availableAddrs)
+			err := c.correctReplica(context.Background(), test.args.target, test.args.found)
 			if test.want.err != nil {
 				require.ErrorIs(t, test.want.err, err)
 			}
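
Two patterns in the corrector hunks are worth illustrating on their own: `seg.Go` is now wrapped in `safety.RecoverFunc` so a panic inside a stream handler is converted to an error instead of killing the job, and a mutex-serialized `stream.Recv()` fanned out to a size-limited errgroup bounds how many received messages are in flight at once (and therefore the stream's memory footprint). Below is a minimal, self-contained sketch of that combination; `recvStream`, `recoverFunc`, and `fakeStream` are hypothetical stand-ins for illustration only, not Vald's `safety.RecoverFunc` or its gRPC stream types.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"

	"golang.org/x/sync/errgroup"
)

// recvStream is a hypothetical stand-in for a gRPC server-streaming client
// such as the one returned by StreamListObject.
type recvStream interface {
	Recv() (string, error)
}

// recoverFunc mimics the idea behind safety.RecoverFunc: a panic inside an
// errgroup goroutine becomes an ordinary error instead of crashing the process.
func recoverFunc(fn func() error) func() error {
	return func() (err error) {
		defer func() {
			if r := recover(); r != nil {
				err = fmt.Errorf("recovered from panic: %v", r)
			}
		}()
		return fn()
	}
}

// consume drains the stream with at most limit handlers in flight. A message's
// memory is only paid for once Recv returns, so bounding the goroutines that
// call Recv also bounds memory usage.
func consume(ctx context.Context, stream recvStream, limit int, handle func(context.Context, string) error) error {
	sctx, cancel := context.WithCancel(ctx)
	defer cancel()

	eg, sctx := errgroup.WithContext(sctx)
	eg.SetLimit(limit) // eg.Go blocks until a slot frees up

	var mu sync.Mutex // Recv is not safe for concurrent use; serialize it
	for {
		select {
		case <-sctx.Done():
			// Stream ended (EOF cancels sctx) or a handler failed: finalize.
			return eg.Wait()
		default:
			eg.Go(recoverFunc(func() error {
				mu.Lock()
				msg, err := stream.Recv()
				mu.Unlock()
				if err != nil {
					cancel() // stop scheduling new readers
					if errors.Is(err, io.EOF) {
						return nil
					}
					return err
				}
				return handle(sctx, msg)
			}))
		}
	}
}

// fakeStream yields n messages and then io.EOF.
type fakeStream struct {
	mu sync.Mutex
	n  int
}

func (f *fakeStream) Recv() (string, error) {
	f.mu.Lock()
	defer f.mu.Unlock()
	if f.n == 0 {
		return "", io.EOF
	}
	f.n--
	return fmt.Sprintf("vec-%d", f.n), nil
}

func main() {
	err := consume(context.Background(), &fakeStream{n: 10}, 3,
		func(_ context.Context, id string) error {
			fmt.Println("corrected", id)
			return nil
		})
	fmt.Println("done, err =", err)
}
```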
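The other structural change is that `correctReplica` now derives `availableAddrs` on demand (agents holding neither the target replica nor any found replica) instead of threading the slice through `checkConsistency`, which is what lets the parameter be dropped. A small runnable sketch of just that selection logic, with `vectorReplica` reduced to its `addr` field and all addresses made up for illustration:

```go
package main

import (
	"fmt"
	"slices"
)

type vectorReplica struct {
	addr string
}

// availableAddrs returns agents that neither own the target replica nor
// already hold a found replica, i.e. agents eligible to receive an extra
// replica when the replica count is short.
func availableAddrs(agentAddrs []string, target *vectorReplica, found []*vectorReplica) []string {
	addrs := make([]string, 0, len(agentAddrs))
	for _, addr := range agentAddrs {
		if addr == target.addr {
			continue
		}
		if slices.ContainsFunc(found, func(r *vectorReplica) bool {
			return r.addr == addr
		}) {
			continue
		}
		addrs = append(addrs, addr)
	}
	return addrs
}

func main() {
	agents := []string{"agent-0", "agent-1", "agent-2", "agent-3"} // hypothetical
	target := &vectorReplica{addr: "agent-1"}
	found := []*vectorReplica{{addr: "agent-3"}}

	// agent-1 holds the target and agent-3 already has a replica,
	// so only agent-0 and agent-2 remain available.
	fmt.Println(availableAddrs(agents, target, found)) // [agent-0 agent-2]
}
```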