diff --git a/aggregator/pkg/common/time_provider_test.go b/aggregator/pkg/common/time_provider_test.go
new file mode 100644
index 00000000..1f8a2efb
--- /dev/null
+++ b/aggregator/pkg/common/time_provider_test.go
@@ -0,0 +1,40 @@
+package common
+
+import (
+	"testing"
+	"time"
+)
+
+func TestNewRealTimeProvider_NowReturnsCurrentTime(t *testing.T) {
+	tp := NewRealTimeProvider()
+	before := time.Now().Add(-1 * time.Second)
+	now := tp.Now()
+	after := time.Now().Add(1 * time.Second)
+
+	if now.Before(before) || now.After(after) {
+		t.Fatalf("expected Now() to be close to current time, got %v", now)
+	}
+}
+
+func TestMockTimeProvider_SetAndAdvance(t *testing.T) {
+	initial := time.Date(2024, 10, 1, 12, 0, 0, 0, time.UTC)
+	tp := NewMockTimeProvider(initial)
+
+	if got := tp.Now(); !got.Equal(initial) {
+		t.Fatalf("expected initial time %v, got %v", initial, got)
+	}
+
+	// SetTime
+	next := initial.Add(5 * time.Minute)
+	tp.SetTime(next)
+	if got := tp.Now(); !got.Equal(next) {
+		t.Fatalf("expected set time %v, got %v", next, got)
+	}
+
+	// AdvanceTime
+	tp.AdvanceTime(30 * time.Second)
+	want := next.Add(30 * time.Second)
+	if got := tp.Now(); !got.Equal(want) {
+		t.Fatalf("expected advanced time %v, got %v", want, got)
+	}
+}
diff --git a/aggregator/pkg/configuration/file_configuration_provider_test.go b/aggregator/pkg/configuration/file_configuration_provider_test.go
new file mode 100644
index 00000000..8e890257
--- /dev/null
+++ b/aggregator/pkg/configuration/file_configuration_provider_test.go
@@ -0,0 +1,49 @@
+package configuration
+
+import (
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+func TestLoadConfig_Success_MinimalMemory(t *testing.T) {
+	tmpDir := t.TempDir()
+	cfgPath := filepath.Join(tmpDir, "agg.toml")
+	content := `
+[server]
+  address = ":50051"
+
+[storage]
+  type = "memory"
+
+[chainStatuses]
+  maxChainStatusesPerRequest = 10
+
+[rateLimiting]
+  enabled = false
+
+[committees]
+  [committees.default]
+    [committees.default.quorumConfigs]
+`
+	if err := os.WriteFile(cfgPath, []byte(content), 0o600); err != nil {
+		t.Fatalf("failed to write temp config: %v", err)
+	}
+
+	cfg, err := LoadConfig(cfgPath)
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	if cfg == nil || cfg.Server.Address != ":50051" {
+		t.Fatalf("unexpected config: %+v", cfg)
+	}
+	if cfg.Storage == nil || string(cfg.Storage.StorageType) != "memory" {
+		t.Fatalf("expected memory storage, got %+v", cfg.Storage)
+	}
+}
+
+func TestLoadConfig_Error_FileMissing(t *testing.T) {
+	if _, err := LoadConfig("/non/existent/file.toml"); err == nil {
+		t.Fatalf("expected error for missing file")
+	}
+}
diff --git a/aggregator/pkg/health/http_server_test.go b/aggregator/pkg/health/http_server_test.go
new file mode 100644
index 00000000..101c980e
--- /dev/null
+++ b/aggregator/pkg/health/http_server_test.go
@@ -0,0 +1,100 @@
+package health
+
+import (
+	"context"
+	"encoding/json"
+	"net/http/httptest"
+	"testing"
+
+	"go.uber.org/zap/zapcore"
+
+	"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common"
+	"github.com/smartcontractkit/chainlink-ccv/protocol/common/logging"
+	"github.com/smartcontractkit/chainlink-common/pkg/logger"
+)
+
+type stubHealthyComponent struct{}
+
+func (s *stubHealthyComponent) HealthCheck(_ context.Context) *common.ComponentHealth {
+	return &common.ComponentHealth{Name: "stub", Status: common.HealthStatusHealthy}
+}
+
+type stubDegradedComponent struct{}
+
+func (s *stubDegradedComponent) HealthCheck(_ context.Context) *common.ComponentHealth {
+	return &common.ComponentHealth{Name: "stub", Status: common.HealthStatusDegraded}
+}
+
+type stubUnhealthyComponent struct{}
+
+func (s *stubUnhealthyComponent) HealthCheck(_ context.Context) *common.ComponentHealth {
+	return &common.ComponentHealth{Name: "stub", Status: common.HealthStatusUnhealthy}
+}
+
+func newTestLogger(t *testing.T) logger.SugaredLogger {
+	t.Helper()
+	lggr, err := logger.NewWith(logging.DevelopmentConfig(zapcore.WarnLevel))
+	if err != nil {
+		t.Fatalf("failed to create logger: %v", err)
+	}
+	return logger.Sugared(lggr)
+}
+
+func TestHTTPHealthServer_Liveness(t *testing.T) {
+	m := NewManager()
+	h := NewHTTPHealthServer(m, "0", newTestLogger(t))
+
+	req := httptest.NewRequest("GET", "/health/live", nil)
+	rr := httptest.NewRecorder()
+	h.handleLiveness(rr, req)
+
+	if rr.Code != 200 {
+		t.Fatalf("expected 200, got %d", rr.Code)
+	}
+	var payload common.ComponentHealth
+	if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
+		t.Fatalf("invalid json: %v", err)
+	}
+	if payload.Status != common.HealthStatusHealthy {
+		t.Fatalf("expected healthy, got %s", payload.Status)
+	}
+}
+
+func TestHTTPHealthServer_Readiness_StatusCodes(t *testing.T) {
+	// Healthy
+	{
+		m := NewManager()
+		m.Register(&stubHealthyComponent{})
+		h := NewHTTPHealthServer(m, "0", newTestLogger(t))
+		req := httptest.NewRequest("GET", "/health/ready", nil)
+		rr := httptest.NewRecorder()
+		h.handleReadiness(rr, req)
+		if rr.Code != 200 {
+			t.Fatalf("expected 200 for healthy, got %d", rr.Code)
+		}
+	}
+	// Degraded
+	{
+		m := NewManager()
+		m.Register(&stubDegradedComponent{})
+		h := NewHTTPHealthServer(m, "0", newTestLogger(t))
+		req := httptest.NewRequest("GET", "/health/ready", nil)
+		rr := httptest.NewRecorder()
+		h.handleReadiness(rr, req)
+		if rr.Code != 200 {
+			t.Fatalf("expected 200 for degraded, got %d", rr.Code)
+		}
+	}
+	// Unhealthy
+	{
+		m := NewManager()
+		m.Register(&stubUnhealthyComponent{})
+		h := NewHTTPHealthServer(m, "0", newTestLogger(t))
+		req := httptest.NewRequest("GET", "/health/ready", nil)
+		rr := httptest.NewRecorder()
+		h.handleReadiness(rr, req)
+		if rr.Code != 503 {
+			t.Fatalf("expected 503 for unhealthy, got %d", rr.Code)
+		}
+	}
+}
diff --git a/aggregator/pkg/middlewares/metric_middleware_test.go b/aggregator/pkg/middlewares/metric_middleware_test.go
new file mode 100644
index 00000000..3162ecdd
--- /dev/null
+++ b/aggregator/pkg/middlewares/metric_middleware_test.go
@@ -0,0 +1,47 @@
+package middlewares
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/stretchr/testify/mock"
+	"google.golang.org/grpc"
+
+	aggregation_mocks "github.com/smartcontractkit/chainlink-ccv/aggregator/internal/aggregation_mocks"
+)
+
+func TestMetricMiddleware_RecordsSuccessAndDuration(t *testing.T) {
+	metric := aggregation_mocks.NewMockAggregatorMetricLabeler(t)
+	monitoring := aggregation_mocks.NewMockAggregatorMonitoring(t)
+
+	monitoring.EXPECT().Metrics().Return(metric)
+	metric.EXPECT().With("apiName", "/svc/Method").Return(metric).Maybe()
+	metric.EXPECT().IncrementActiveRequestsCounter(context.Background())
+	metric.EXPECT().DecrementActiveRequestsCounter(context.Background())
+	metric.EXPECT().RecordAPIRequestDuration(context.Background(), mock.Anything)
+
+	mm := NewMetricMiddleware(monitoring)
+	info := &grpc.UnaryServerInfo{FullMethod: "/svc/Method"}
+	handler := func(ctx context.Context, req any) (any, error) { return "ok", nil }
+
+	_, _ = mm.Intercept(context.Background(), nil, info, handler)
+}
+
+func TestMetricMiddleware_RecordsError(t *testing.T) {
+	metric := aggregation_mocks.NewMockAggregatorMetricLabeler(t)
+	monitoring := aggregation_mocks.NewMockAggregatorMonitoring(t)
+
+	monitoring.EXPECT().Metrics().Return(metric)
+	metric.EXPECT().With("apiName", "/svc/Err").Return(metric).Maybe()
+	metric.EXPECT().IncrementActiveRequestsCounter(context.Background())
+	metric.EXPECT().DecrementActiveRequestsCounter(context.Background())
+	metric.EXPECT().RecordAPIRequestDuration(context.Background(), mock.Anything)
+	metric.EXPECT().IncrementAPIRequestErrors(context.Background())
+
+	mm := NewMetricMiddleware(monitoring)
+	info := &grpc.UnaryServerInfo{FullMethod: "/svc/Err"}
+	handler := func(ctx context.Context, req any) (any, error) { return nil, errors.New("boom") }
+
+	_, _ = mm.Intercept(context.Background(), nil, info, handler)
+}
diff --git a/aggregator/pkg/middlewares/scoping_middleware_test.go b/aggregator/pkg/middlewares/scoping_middleware_test.go
new file mode 100644
index 00000000..98583ffa
--- /dev/null
+++ b/aggregator/pkg/middlewares/scoping_middleware_test.go
@@ -0,0 +1,32 @@
+package middlewares
+
+import (
+	"context"
+	"testing"
+
+	"google.golang.org/grpc"
+
+	"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/scope"
+
+	aggregation_mocks "github.com/smartcontractkit/chainlink-ccv/aggregator/internal/aggregation_mocks"
+)
+
+func TestScopingMiddleware_SetsAPINameInContext(t *testing.T) {
+	m := NewScopingMiddleware()
+
+	// Prepare a mock labeler to assert AugmentMetrics adds the apiName label
+	labeler := aggregation_mocks.NewMockAggregatorMetricLabeler(t)
+	// Expect apiName to be added by the middleware
+	fullMethod := "/chainlink_ccv.v1.VerifierResultAPI/GetMessagesSince"
+	labeler.EXPECT().With("apiName", fullMethod).Return(labeler)
+
+	info := &grpc.UnaryServerInfo{FullMethod: fullMethod}
+
+	handler := func(ctx context.Context, req any) (any, error) {
+		// Scoping middleware should have placed apiName in context
+		_ = scope.AugmentMetrics(ctx, labeler)
+		return nil, nil
+	}
+
+	_, _ = m.Intercept(context.Background(), nil, info, handler)
+}
diff --git a/aggregator/pkg/model/config_validation_test.go b/aggregator/pkg/model/config_validation_test.go
new file mode 100644
index 00000000..31e63823
--- /dev/null
+++ b/aggregator/pkg/model/config_validation_test.go
@@ -0,0 +1,81 @@
+package model
+
+import (
+	"testing"
+)
+
+func TestAggregatorConfig_Validate_ErrorScenarios(t *testing.T) {
+	tests := []struct {
+		name            string
+		mutate          func(c *AggregatorConfig)
+		wantErrContains string
+	}{
+		{
+			name:            "chainStatuses must be >0",
+			mutate:          func(c *AggregatorConfig) { c.ChainStatuses.MaxChainStatusesPerRequest = -1 },
+			wantErrContains: "chain status configuration error",
+		},
+		{
+			name:            "batch size must be >0",
+			mutate:          func(c *AggregatorConfig) { c.MaxMessageIDsPerBatch = -1 },
+			wantErrContains: "batch configuration error",
+		},
+		{
+			name:            "batch size cannot exceed 1000",
+			mutate:          func(c *AggregatorConfig) { c.MaxMessageIDsPerBatch = 2000 },
+			wantErrContains: "batch configuration error",
+		},
+		{
+			name:            "aggregation.channelBufferSize must be >0",
+			mutate:          func(c *AggregatorConfig) { c.Aggregation.ChannelBufferSize = -1 },
+			wantErrContains: "aggregation configuration error",
+		},
+		{
+			name:            "aggregation.backgroundWorkerCount must be >0",
+			mutate:          func(c *AggregatorConfig) { c.Aggregation.BackgroundWorkerCount = -1 },
+			wantErrContains: "aggregation configuration error",
+		},
+		{
+			name:            "storage.pageSize must be >0",
+			mutate:          func(c *AggregatorConfig) { c.Storage.PageSize = -1 },
+			wantErrContains: "storage configuration error",
+		},
+		{
+			name: "invalid API key empty id",
+			mutate: func(c *AggregatorConfig) {
+				c.APIKeys.Clients["abc"] = &APIClient{ClientID: ""}
+			},
+			wantErrContains: "api key configuration error",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			cfg := &AggregatorConfig{
+				Storage:               &StorageConfig{PageSize: 10},
+				APIKeys:               APIKeyConfig{Clients: map[string]*APIClient{}},
+				ChainStatuses:         ChainStatusConfig{MaxChainStatusesPerRequest: 1},
+				Aggregation:           AggregationConfig{ChannelBufferSize: 1, BackgroundWorkerCount: 1},
+				MaxMessageIDsPerBatch: 1,
+			}
+			tc.mutate(cfg)
+			if err := cfg.Validate(); err == nil {
+				t.Fatalf("expected error containing %q", tc.wantErrContains)
+			}
+		})
+	}
+}
+
+func TestAggregatorConfig_Validate_Success(t *testing.T) {
+	cfg := &AggregatorConfig{
+		Storage:               &StorageConfig{PageSize: 10},
+		APIKeys:               APIKeyConfig{Clients: map[string]*APIClient{"key1": {ClientID: "client1"}}},
+		ChainStatuses:         ChainStatusConfig{MaxChainStatusesPerRequest: 1},
+		Aggregation:           AggregationConfig{ChannelBufferSize: 10, BackgroundWorkerCount: 2},
+		MaxMessageIDsPerBatch: 10,
+		RateLimiting:          RateLimitingConfig{GroupLimits: map[string]map[string]RateLimitConfig{}},
+	}
+	if err := cfg.Validate(); err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+}
diff --git a/aggregator/pkg/monitoring/noop_test.go b/aggregator/pkg/monitoring/noop_test.go
new file mode 100644
index 00000000..a9205fbc
--- /dev/null
+++ b/aggregator/pkg/monitoring/noop_test.go
@@ -0,0 +1,26 @@
+package monitoring
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+func TestNoopAggregatorMonitoring_DoesNotPanic(t *testing.T) {
+	m := NewNoopAggregatorMonitoring()
+	lbl := m.Metrics()
+
+	ctx := context.Background()
+	_ = lbl.With("key", "value")
+	lbl.IncrementActiveRequestsCounter(ctx)
+	lbl.DecrementActiveRequestsCounter(ctx)
+	lbl.IncrementCompletedAggregations(ctx)
+	lbl.RecordAPIRequestDuration(ctx, 10*time.Millisecond)
+	lbl.IncrementAPIRequestErrors(ctx)
+	lbl.RecordMessageSinceNumberOfRecordsReturned(ctx, 5)
+	lbl.IncrementPendingAggregationsChannelBuffer(ctx, 2)
+	lbl.DecrementPendingAggregationsChannelBuffer(ctx, 1)
+	lbl.RecordStorageLatency(ctx, 5*time.Millisecond)
+	lbl.IncrementStorageError(ctx)
+	lbl.RecordTimeToAggregation(ctx, time.Second)
+}
diff --git a/aggregator/pkg/scope/scope_test.go b/aggregator/pkg/scope/scope_test.go
new file mode 100644
index 00000000..ddcb391f
--- /dev/null
+++ b/aggregator/pkg/scope/scope_test.go
@@ -0,0 +1,56 @@
+package scope
+
+import (
+	"context"
+	"testing"
+
+	"go.uber.org/zap/zapcore"
+
+	"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/auth"
+	"github.com/smartcontractkit/chainlink-ccv/protocol/common/logging"
+	"github.com/smartcontractkit/chainlink-common/pkg/logger"
+
+	aggregation_mocks "github.com/smartcontractkit/chainlink-ccv/aggregator/internal/aggregation_mocks"
+)
+
+func TestWithContextHelpers_RoundTrip(t *testing.T) {
+	// Exercise each of the context setters
+	ctx := context.Background()
+	ctx = WithAPIName(ctx, "GetMessages")
+	ctx = WithRequestID(ctx)
+	ctx = WithMessageID(ctx, []byte{0x01, 0x02})
+	ctx = WithAddress(ctx, []byte{0xab, 0xcd})
+	ctx = WithParticipantID(ctx, "participant")
+	ctx = WithCommitteeID(ctx, "committee-1")
+
+	// Ensure AugmentMetrics applies the expected labels
+	mockLabeler := aggregation_mocks.NewMockAggregatorMetricLabeler(t)
+	// Order is committeeID then apiName (see metricsContextKeys)
+	mockLabeler.EXPECT().With("committeeID", "committee-1").Return(mockLabeler)
+	mockLabeler.EXPECT().With("apiName", "GetMessages").Return(mockLabeler)
+
+	_ = AugmentMetrics(ctx, mockLabeler)
+}
+
+func TestAugmentLogger_NoPanicAndCoversIdentity(t *testing.T) {
+	ctx := context.Background()
+	ctx = WithAPIName(ctx, "Read")
+	ctx = WithMessageID(ctx, []byte{0x0})
+	ctx = WithAddress(ctx, []byte{0x1})
+	ctx = WithParticipantID(ctx, "p")
+	ctx = WithCommitteeID(ctx, "c")
+
+	// Add identity to hit the identity branch in AugmentLogger
+	id := auth.CreateCallerIdentity("caller-123", false)
+	ctx = auth.ToContext(ctx, id)
+
+	// We cannot easily introspect the fields AugmentLogger attaches, so build a real
+	// development logger and simply verify the call completes without panicking.
+	lggr, err := logger.NewWith(logging.DevelopmentConfig(zapcore.InfoLevel))
+	if err != nil {
+		t.Fatalf("failed to create logger: %v", err)
+	}
+	sugared := logger.Sugared(lggr)
+	l := AugmentLogger(ctx, sugared)
+	l.Infof("smoke")
+}
diff --git a/aggregator/pkg/server.go b/aggregator/pkg/server.go
index 1c64a8cc..a4c77f62 100644
--- a/aggregator/pkg/server.go
+++ b/aggregator/pkg/server.go
@@ -128,8 +128,10 @@ func (s *Server) Start(lis net.Listener) error {
 		s.grpcServer.Stop()
 	})
 
+	// capture stopChan to avoid a data race on struct field access in the goroutine
+	stopCh := s.stopChan
 	g.Add(func() error {
-		<-s.stopChan
+		<-stopCh
 		s.l.Info("stop signal received, shutting down")
 		return nil
 	}, func(error) {})
diff --git a/aggregator/pkg/server_test.go b/aggregator/pkg/server_test.go
new file mode 100644
index 00000000..d8e6e0e3
--- /dev/null
+++ b/aggregator/pkg/server_test.go
@@ -0,0 +1,72 @@
+package aggregator
+
+import (
+	"net"
+	"testing"
+	"time"
+
+	"go.uber.org/zap/zapcore"
+
+	"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model"
+	"github.com/smartcontractkit/chainlink-ccv/protocol/common/logging"
+	"github.com/smartcontractkit/chainlink-common/pkg/logger"
+)
+
+func newTestLogger(t *testing.T) logger.SugaredLogger {
+	t.Helper()
+	lggr, err := logger.NewWith(logging.DevelopmentConfig(zapcore.WarnLevel))
+	if err != nil {
+		t.Fatalf("failed to create logger: %v", err)
+	}
+	return logger.Sugared(lggr)
+}
+
+func minimalConfig() *model.AggregatorConfig {
+	return &model.AggregatorConfig{
+		Server:       model.ServerConfig{Address: ":0"},
+		Storage:      &model.StorageConfig{StorageType: model.StorageTypeMemory},
+		Monitoring:   model.MonitoringConfig{Enabled: false},
+		RateLimiting: model.RateLimitingConfig{Enabled: false},
+		HealthCheck:  model.HealthCheckConfig{Enabled: false},
+		Committees:   map[model.CommitteeID]*model.Committee{},
+		StubMode:     true,
+	}
+}
+
+func TestServer_StartStop_Memory(t *testing.T) {
+	s := NewServer(newTestLogger(t), minimalConfig())
+	lis, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("failed to listen: %v", err)
+	}
+	if err := s.Start(lis); err != nil {
+		t.Fatalf("start failed: %v", err)
+	}
+	// Give the run group a moment to spin up
+	time.Sleep(50 * time.Millisecond)
+	if err := s.Stop(); err != nil {
+		t.Fatalf("stop failed: %v", err)
+	}
+}
+
+func TestServer_DoubleStart_ReturnsError(t *testing.T) {
+	s := NewServer(newTestLogger(t), minimalConfig())
+	lis, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("failed to listen: %v", err)
+	}
+	if err := s.Start(lis); err != nil {
+		t.Fatalf("first start failed: %v", err)
+	}
+	if err := s.Start(lis); err == nil {
+		t.Fatalf("expected error on second start")
+	}
+	_ = s.Stop()
+}
+
+func TestServer_Stop_WhenNotStarted_NoError(t *testing.T) {
+	s := NewServer(newTestLogger(t), minimalConfig())
+	if err := s.Stop(); err != nil {
+		t.Fatalf("expected no error stopping non-started server, got %v", err)
+	}
+}
diff --git a/aggregator/pkg/storage/factory_test.go b/aggregator/pkg/storage/factory_test.go
new file mode 100644
index 00000000..a28000d9
--- /dev/null
+++ b/aggregator/pkg/storage/factory_test.go
@@ -0,0 +1,51 @@
+package storage
+
+import (
+	"testing"
+
+	"go.uber.org/zap/zapcore"
+
+	"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model"
+	"github.com/smartcontractkit/chainlink-ccv/protocol/common/logging"
+	"github.com/smartcontractkit/chainlink-common/pkg/logger"
+)
+
+func newTestLogger(t *testing.T) logger.SugaredLogger {
+	t.Helper()
+	lggr, err := logger.NewWith(logging.DevelopmentConfig(zapcore.InfoLevel))
+	if err != nil {
+		t.Fatalf("failed to create logger: %v", err)
+	}
+	return logger.Sugared(lggr)
+}
+
+func TestFactory_CreateStorage_Memory(t *testing.T) {
+	f := NewStorageFactory(newTestLogger(t))
+	cfg := &model.StorageConfig{StorageType: model.StorageTypeMemory}
+
+	s, err := f.CreateStorage(cfg, nil)
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	if s == nil {
+		t.Fatalf("expected non-nil storage")
+	}
+}
+
+func TestFactory_CreateStorage_Unsupported(t *testing.T) {
+	f := NewStorageFactory(newTestLogger(t))
+	cfg := &model.StorageConfig{StorageType: "unsupported"}
+
+	if _, err := f.CreateStorage(cfg, nil); err == nil {
+		t.Fatalf("expected error for unsupported storage type")
+	}
+}
+
+func TestFactory_CreateChainStatusStorage_Unsupported(t *testing.T) {
+	f := NewStorageFactory(newTestLogger(t))
+	cfg := &model.StorageConfig{StorageType: "unsupported"}
+
+	if _, err := f.CreateChainStatusStorage(cfg, nil); err == nil {
+		t.Fatalf("expected error for unsupported chain status storage type")
+	}
+}
diff --git a/aggregator/pkg/storage/metrics_aware_chainstatus_test.go b/aggregator/pkg/storage/metrics_aware_chainstatus_test.go
new file mode 100644
index 00000000..54261e4f
--- /dev/null
+++ b/aggregator/pkg/storage/metrics_aware_chainstatus_test.go
@@ -0,0 +1,80 @@
+package storage
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/stretchr/testify/mock"
+
+	"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/common"
+
+	aggregation_mocks "github.com/smartcontractkit/chainlink-ccv/aggregator/internal/aggregation_mocks"
+)
+
+type fakeChainStatusStorage struct {
+	storeErr   error
+	getMap     map[uint64]*common.ChainStatus
+	getErr     error
+	clients    []string
+	clientsErr error
+}
+
+func (f *fakeChainStatusStorage) StoreChainStatus(ctx context.Context, clientID string, chainStatuses map[uint64]*common.ChainStatus) error {
+	return f.storeErr
+}
+
+func (f *fakeChainStatusStorage) GetClientChainStatus(ctx context.Context, clientID string) (map[uint64]*common.ChainStatus, error) {
+	return f.getMap, f.getErr
+}
+
+func (f *fakeChainStatusStorage) GetAllClients(ctx context.Context) ([]string, error) {
+	return f.clients, f.clientsErr
+}
+
+func setupChainStatusMetricMocks(t *testing.T) (*aggregation_mocks.MockAggregatorMetricLabeler, *aggregation_mocks.MockAggregatorMonitoring) {
+	t.Helper()
+	metric := aggregation_mocks.NewMockAggregatorMetricLabeler(t)
+	metric.On("With", mock.Anything, mock.Anything).Return(metric).Maybe()
+	metric.On("RecordStorageLatency", mock.Anything, mock.Anything).Maybe()
+	metric.On("IncrementStorageError", mock.Anything).Maybe()
+
+	mon := aggregation_mocks.NewMockAggregatorMonitoring(t)
+	mon.EXPECT().Metrics().Return(metric).Maybe()
+	return metric, mon
+}
+
+func TestMetricsAwareChainStatusStorage_SuccessPaths(t *testing.T) {
+	metric, mon := setupChainStatusMetricMocks(t)
+	inner := &fakeChainStatusStorage{
+		getMap:  map[uint64]*common.ChainStatus{},
+		clients: []string{"a", "b"},
+	}
+	s := NewMetricsAwareChainStatusStorage(inner, mon)
+
+	ctx := context.Background()
+	_ = s.StoreChainStatus(ctx, "client", map[uint64]*common.ChainStatus{})
+	_, _ = s.GetClientChainStatus(ctx, "client")
+	_, _ = s.GetAllClients(ctx)
+
+	metric.AssertNumberOfCalls(t, "RecordStorageLatency", 3)
+	metric.AssertNumberOfCalls(t, "IncrementStorageError", 0)
+}
+
+func TestMetricsAwareChainStatusStorage_ErrorPaths(t *testing.T) {
+	metric, mon := setupChainStatusMetricMocks(t)
+	inner := &fakeChainStatusStorage{
+		storeErr:   errors.New("store err"),
+		getErr:     errors.New("get err"),
+		clientsErr: errors.New("clients err"),
+	}
+	s := NewMetricsAwareChainStatusStorage(inner, mon)
+
+	ctx := context.Background()
+	_ = s.StoreChainStatus(ctx, "client", map[uint64]*common.ChainStatus{})
+	_, _ = s.GetClientChainStatus(ctx, "client")
+	_, _ = s.GetAllClients(ctx)
+
+	metric.AssertNumberOfCalls(t, "RecordStorageLatency", 3)
+	metric.AssertNumberOfCalls(t, "IncrementStorageError", 3)
+}
diff --git a/aggregator/pkg/storage/metrics_aware_storage_test.go b/aggregator/pkg/storage/metrics_aware_storage_test.go
new file mode 100644
index 00000000..c1a49b4e
--- /dev/null
+++ b/aggregator/pkg/storage/metrics_aware_storage_test.go
@@ -0,0 +1,172 @@
+package storage
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+
+	"github.com/smartcontractkit/chainlink-ccv/aggregator/pkg/model"
+
+	aggregation_mocks "github.com/smartcontractkit/chainlink-ccv/aggregator/internal/aggregation_mocks"
+)
+
+type fakeInnerStorage struct {
+	saveErr   error
+	getRes    *model.CommitVerificationRecord
+	getErr    error
+	listRes   []*model.CommitVerificationRecord
+	listErr   error
+	queryRes  *model.PaginatedAggregatedReports
+	queryErr  error
+	ccvRes    *model.CommitAggregatedReport
+	ccvErr    error
+	batchRes  map[string]*model.CommitAggregatedReport
+	batchErr  error
+	submitErr error
+	// orphaning
+	orphanIDs []model.MessageID
+	orphanErr error
+}
+
+func (f *fakeInnerStorage) SaveCommitVerification(ctx context.Context, record *model.CommitVerificationRecord) error {
+	return f.saveErr
+}
+
+func (f *fakeInnerStorage) GetCommitVerification(ctx context.Context, id model.CommitVerificationRecordIdentifier) (*model.CommitVerificationRecord, error) {
+	return f.getRes, f.getErr
+}
+
+func (f *fakeInnerStorage) ListCommitVerificationByMessageID(ctx context.Context, messageID model.MessageID, committee string) ([]*model.CommitVerificationRecord, error) {
+	return f.listRes, f.listErr
+}
+
+func (f *fakeInnerStorage) QueryAggregatedReports(ctx context.Context, start int64, committeeID string, token *string) (*model.PaginatedAggregatedReports, error) {
+	return f.queryRes, f.queryErr
+}
+
+func (f *fakeInnerStorage) GetCCVData(ctx context.Context, messageID model.MessageID, committeeID string) (*model.CommitAggregatedReport, error) {
+	return f.ccvRes, f.ccvErr
+}
+
+func (f *fakeInnerStorage) GetBatchCCVData(ctx context.Context, messageIDs []model.MessageID, committeeID string) (map[string]*model.CommitAggregatedReport, error) {
+	return f.batchRes, f.batchErr
+}
+
+func (f *fakeInnerStorage) SubmitReport(ctx context.Context, report *model.CommitAggregatedReport) error {
+	return f.submitErr
+}
+
+func (f *fakeInnerStorage) ListOrphanedMessageIDs(ctx context.Context, committeeID model.CommitteeID) (<-chan model.MessageID, <-chan error) {
+	outIDs := make(chan model.MessageID, 1)
+	outErrs := make(chan error, 1)
+	go func() {
+		defer close(outIDs)
+		// Send any configured ids
+		for _, id := range f.orphanIDs {
+			outIDs <- id
+		}
+		if f.orphanErr != nil {
+			outErrs <- f.orphanErr
+		}
+		// Do not close outErrs: mirrors production interface behavior
+	}()
+	return outIDs, outErrs
+}
+
+func setupMetricsMocks(t *testing.T) (*aggregation_mocks.MockAggregatorMetricLabeler, *aggregation_mocks.MockAggregatorMonitoring) {
+	t.Helper()
+	metric := aggregation_mocks.NewMockAggregatorMetricLabeler(t)
+	// Any With(...) returns the labeler itself to allow chaining
+	metric.On("With", mock.Anything, mock.Anything).Return(metric).Maybe()
+	metric.On("RecordStorageLatency", mock.Anything, mock.Anything).Maybe()
+	metric.On("IncrementStorageError", mock.Anything).Maybe()
+
+	mon := aggregation_mocks.NewMockAggregatorMonitoring(t)
+	mon.EXPECT().Metrics().Return(metric).Maybe()
+	return metric, mon
+}
+
+func TestMetricsAwareStorage_SuccessPaths(t *testing.T) {
+	metric, mon := setupMetricsMocks(t)
+	// Count RecordStorageLatency calls after invoking all success paths
+	inner := &fakeInnerStorage{
+		getRes:   &model.CommitVerificationRecord{},
+		listRes:  []*model.CommitVerificationRecord{},
+		queryRes: &model.PaginatedAggregatedReports{},
+		ccvRes:   &model.CommitAggregatedReport{},
+		batchRes: map[string]*model.CommitAggregatedReport{},
+	}
+	s := NewMetricsAwareStorage(inner, mon)
+
+	ctx := context.Background()
+	_ = s.SaveCommitVerification(ctx, &model.CommitVerificationRecord{})
+	_, _ = s.GetCommitVerification(ctx, model.CommitVerificationRecordIdentifier{})
+	_, _ = s.ListCommitVerificationByMessageID(ctx, make([]byte, 0), "c")
+	_, _ = s.QueryAggregatedReports(ctx, time.Now().Unix(), "c", nil)
+	_, _ = s.GetCCVData(ctx, make([]byte, 0), "c")
+	_, _ = s.GetBatchCCVData(ctx, []model.MessageID{}, "c")
+	_ = s.SubmitReport(ctx, &model.CommitAggregatedReport{})
+
+	// 7 operations should each record latency once
+	metric.AssertNumberOfCalls(t, "RecordStorageLatency", 7)
+	metric.AssertNumberOfCalls(t, "IncrementStorageError", 0)
+}
+
+func TestMetricsAwareStorage_ErrorPaths(t *testing.T) {
+	metric, mon := setupMetricsMocks(t)
+	inner := &fakeInnerStorage{
+		saveErr:   errors.New("save err"),
+		getErr:    errors.New("get err"),
+		listErr:   errors.New("list err"),
+		queryErr:  errors.New("query err"),
+		ccvErr:    errors.New("ccv err"),
+		batchErr:  errors.New("batch err"),
+		submitErr: errors.New("submit err"),
+	}
+	s := NewMetricsAwareStorage(inner, mon)
+
+	ctx := context.Background()
+	_ = s.SaveCommitVerification(ctx, &model.CommitVerificationRecord{})
+	_, _ = s.GetCommitVerification(ctx, model.CommitVerificationRecordIdentifier{})
+	_, _ = s.ListCommitVerificationByMessageID(ctx, make([]byte, 0), "c")
+	_, _ = s.QueryAggregatedReports(ctx, time.Now().Unix(), "c", nil)
+	_, _ = s.GetCCVData(ctx, make([]byte, 0), "c")
+	_, _ = s.GetBatchCCVData(ctx, []model.MessageID{}, "c")
+	_ = s.SubmitReport(ctx, &model.CommitAggregatedReport{})
+
+	// 7 latency + 7 errors
+	metric.AssertNumberOfCalls(t, "RecordStorageLatency", 7)
+	metric.AssertNumberOfCalls(t, "IncrementStorageError", 7)
+}
+
+func TestMetricsAwareStorage_ListOrphanedMessageIDs_ProxiesAndRecordsLatency(t *testing.T) {
+	metric, mon := setupMetricsMocks(t)
+	// One orphan id and no error
+	inner := &fakeInnerStorage{
+		orphanIDs: []model.MessageID{[]byte{0x01}},
+	}
+	s := NewMetricsAwareStorage(inner, mon)
+
+	ctx := context.Background()
+	ids, errs := s.ListOrphanedMessageIDs(ctx, "committee-1")
+
+	got := make([]model.MessageID, 0, 1)
+	for id := range ids {
+		got = append(got, id)
+	}
+	// Allow deferred metric recording to run after the goroutine exits
+	time.Sleep(10 * time.Millisecond)
+	select {
+	case <-errs:
+		// drain any residual error; none is expected here
+	default:
+	}
+
+	assert.Equal(t, 1, len(got))
+	// Latency should be recorded once when the goroutine exits
+	metric.AssertNumberOfCalls(t, "RecordStorageLatency", 1)
+}
diff --git a/aggregator/tests/commit_verification_api_test.go b/aggregator/tests/commit_verification_api_test.go
index a3ccd476..1f695e51 100644
--- a/aggregator/tests/commit_verification_api_test.go
+++ b/aggregator/tests/commit_verification_api_test.go
@@ -1790,26 +1790,28 @@ func TestBatchGetVerifierResult_MissingMessages(t *testing.T) {
 			},
 		}
 
-		batchResp, err := ccvDataClient.BatchGetVerifierResultForMessage(t.Context(), batchReqWithMissing)
-		require.NoError(t, err, "BatchGetVerifierResultForMessage with missing should not error")
-		require.NotNil(t, batchResp, "batch response with missing should not be nil")
-
-		// Should have 1 result and 2 errors (1:1 correspondence with requests)
-		require.Len(t, batchResp.Results, 1, "should have 1 result (existing message)")
-		require.Len(t, batchResp.Errors, 2, "should have 2 errors (1:1 with requests)")
-
-		// First request (existing) should have Status with Code 0
-		require.NotNil(t, batchResp.Errors[0], "existing message should have Status with Code 0")
-		require.Equal(t, int32(codes.OK), batchResp.Errors[0].Code, "existing message should have Code 0")
-
-		// Second request (missing) should have NotFound error
-		require.NotNil(t, batchResp.Errors[1], "missing message should have error")
-		require.Equal(t, int32(codes.NotFound), batchResp.Errors[1].Code, "missing message should have NotFound error")
-
-		// Verify the result is correct
-		result := batchResp.Results[0]
-		require.Equal(t, uint64(1001), result.GetMessage().GetNonce(), "nonce should match")
-		require.Equal(t, sourceVerifierAddress, result.SourceVerifierAddress, "source verifier address should match")
+		require.EventuallyWithTf(t, func(collect *assert.CollectT) {
+			batchResp, err := ccvDataClient.BatchGetVerifierResultForMessage(t.Context(), batchReqWithMissing)
+			require.NoError(collect, err, "BatchGetVerifierResultForMessage with missing should not error")
+			require.NotNil(collect, batchResp, "batch response with missing should not be nil")
+
+			// Should have 1 result and 2 errors (1:1 correspondence with requests)
+			require.Len(collect, batchResp.Results, 1, "should have 1 result (existing message)")
+			require.Len(collect, batchResp.Errors, 2, "should have 2 errors (1:1 with requests)")
+
+			// First request (existing) should have Status with Code 0
+			require.NotNil(collect, batchResp.Errors[0], "existing message should have Status with Code 0")
+			require.Equal(collect, int32(codes.OK), batchResp.Errors[0].Code, "existing message should have Code 0")
+
+			// Second request (missing) should have NotFound error
+			require.NotNil(collect, batchResp.Errors[1], "missing message should have error")
+			require.Equal(collect, int32(codes.NotFound), batchResp.Errors[1].Code, "missing message should have NotFound error")
+
+			// Verify the result is correct
+			result := batchResp.Results[0]
+			require.Equal(collect, uint64(1001), result.GetMessage().GetNonce(), "nonce should match")
+			require.Equal(collect, sourceVerifierAddress, result.SourceVerifierAddress, "source verifier address should match")
+		}, 5*time.Second, 200*time.Millisecond, "batch result for existing message not ready in time")
 	}
 
 	for _, storageType := range storageTypes {