diff --git a/callbacks/aspect_inject_test.go b/callbacks/aspect_inject_test.go index ee6fb5d..9763643 100644 --- a/callbacks/aspect_inject_test.go +++ b/callbacks/aspect_inject_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/assert" + "github.com/cloudwego/eino/internal/callbacks" "github.com/cloudwego/eino/schema" ) @@ -181,3 +182,19 @@ func TestAspectInject(t *testing.T) { assert.Equal(t, 186, cnt) }) } + +func TestGlobalCallbacksRepeated(t *testing.T) { + times := 0 + testHandler := NewHandlerBuilder().OnStartFn(func(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context { + times++ + return ctx + }).Build() + callbacks.GlobalHandlers = append(callbacks.GlobalHandlers, testHandler) + + ctx := context.Background() + ctx = callbacks.AppendHandlers(ctx, &RunInfo{}) + ctx = callbacks.AppendHandlers(ctx, &RunInfo{}) + + callbacks.On(ctx, "test", callbacks.OnStartHandle[string], TimingOnStart) + assert.Equal(t, times, 1) +} diff --git a/internal/callbacks/inject.go b/internal/callbacks/inject.go index b7e063a..95abe2b 100644 --- a/internal/callbacks/inject.go +++ b/internal/callbacks/inject.go @@ -59,8 +59,8 @@ func On[T any](ctx context.Context, inOut T, handle Handle[T], timing CallbackTi return ctx, inOut } - hs := make([]Handler, 0, len(mgr.handlers)) - for _, handler := range mgr.handlers { + hs := make([]Handler, 0, len(mgr.handlers)+len(mgr.globalHandlers)) + for _, handler := range append(mgr.handlers, mgr.globalHandlers...) { timingChecker, ok_ := handler.(TimingChecker) if !ok_ || timingChecker.Needed(ctx, mgr.runInfo, timing) { hs = append(hs, handler) diff --git a/internal/callbacks/manager.go b/internal/callbacks/manager.go index d40c8c1..86192fc 100644 --- a/internal/callbacks/manager.go +++ b/internal/callbacks/manager.go @@ -19,24 +19,25 @@ package callbacks import "context" type manager struct { - handlers []Handler - runInfo *RunInfo + globalHandlers []Handler + handlers []Handler + runInfo *RunInfo } var GlobalHandlers []Handler func newManager(runInfo *RunInfo, handlers ...Handler) (*manager, bool) { - l := len(handlers) + len(GlobalHandlers) - if l == 0 { + if len(handlers)+len(GlobalHandlers) == 0 { return nil, false } - hs := make([]Handler, 0, l) - hs = append(hs, GlobalHandlers...) - hs = append(hs, handlers...) + + hs := make([]Handler, len(GlobalHandlers)) + copy(hs, GlobalHandlers) return &manager{ - handlers: hs, - runInfo: runInfo, + globalHandlers: hs, + handlers: handlers, + runInfo: runInfo, }, true } @@ -50,8 +51,9 @@ func (m *manager) withRunInfo(runInfo *RunInfo) *manager { } return &manager{ - handlers: m.handlers, - runInfo: runInfo, + globalHandlers: m.globalHandlers, + handlers: m.handlers, + runInfo: runInfo, } } @@ -60,8 +62,9 @@ func managerFromCtx(ctx context.Context) (*manager, bool) { m, ok := v.(*manager) if ok && m != nil { return &manager{ - handlers: m.handlers, - runInfo: m.runInfo, + globalHandlers: m.globalHandlers, + handlers: m.handlers, + runInfo: m.runInfo, }, true }