diff --git a/router.go b/router.go index 79547e6..82b01d1 100644 --- a/router.go +++ b/router.go @@ -68,32 +68,33 @@ func (r *Router) mutable(v bool) { } } -func (r *Router) buildMiddlewaresChain(skip ...Middleware) Middlewares { - mdlws := Middlewares{} +func (r *Router) buildMiddlewares(m Middlewares) Middlewares { + m2 := Middlewares{} + m2.Before = append(m2.Before, r.middlewares.Before...) + m2.Before = append(m2.Before, m.Before...) + m2.After = append(m2.After, m.After...) + m2.After = append(m2.After, r.middlewares.After...) - var subMdlws Middlewares + m2.Skip = append(m2.Skip, m.Skip...) + m2.Skip = append(m2.Skip, r.middlewares.Skip...) - if r.parent != nil { - skip = append(skip, r.middlewares.Skip...) - subMdlws = r.parent.buildMiddlewaresChain(skip...) - } else if r.log.DebugEnabled() { + switch { + case r.parent != nil: + return r.parent.buildMiddlewares(m2) + case r.log.DebugEnabled(): debugMiddleware := func(ctx *RequestCtx) error { r.log.Debugf("%s %s", ctx.Method(), ctx.URI()) return ctx.Next() } - // Add debug middleware at first position if the log level is enabled as debug - mdlws.Before = append(mdlws.Before, debugMiddleware) + m2.Before = append([]Middleware{debugMiddleware}, m2.Before...) } - mdlws.Before = appendMiddlewares(mdlws.Before, subMdlws.Before, skip...) - mdlws.Before = appendMiddlewares(mdlws.Before, r.middlewares.Before, skip...) - - mdlws.After = appendMiddlewares(mdlws.After, r.middlewares.After, skip...) - mdlws.After = appendMiddlewares(mdlws.After, subMdlws.After, skip...) + m2.Before = appendMiddlewares(m2.Before[:0], m2.Before, m2.Skip...) + m2.After = appendMiddlewares(m2.After[:0], m2.After, m2.Skip...) - return mdlws + return m2 } func (r *Router) getGroupFullPath(path string) string { @@ -105,10 +106,9 @@ func (r *Router) getGroupFullPath(path string) string { } func (r *Router) handler(fn View, middle Middlewares) fasthttp.RequestHandler { - mdlws := r.buildMiddlewaresChain(middle.Skip...) + middle = r.buildMiddlewares(middle) - chain := append(mdlws.Before, middle.Before...) - chain = append(chain, func(ctx *RequestCtx) error { + chain := append(middle.Before, func(ctx *RequestCtx) error { if !ctx.skipView { if err := fn(ctx); err != nil { return err @@ -117,7 +117,6 @@ func (r *Router) handler(fn View, middle Middlewares) fasthttp.RequestHandler { return ctx.Next() }) chain = append(chain, middle.After...) - chain = append(chain, mdlws.After...) return func(ctx *fasthttp.RequestCtx) { actx := AcquireRequestCtx(ctx) diff --git a/router_test.go b/router_test.go index aea6330..fb0c23f 100644 --- a/router_test.go +++ b/router_test.go @@ -170,44 +170,48 @@ func TestRouter_mutable(t *testing.T) { } } -func TestRouter_buildMiddlewaresChain(t *testing.T) { +func TestRouter_buildMiddlewares(t *testing.T) { logLevels := []string{"fatal", "debug"} - mdlws := Middlewares{ - Before: []Middleware{ - func(ctx *RequestCtx) error { return ctx.Next() }, - func(ctx *RequestCtx) error { return ctx.Next() }, - }, - After: []Middleware{func(ctx *RequestCtx) error { return ctx.Next() }}, + middleware1 := func(ctx *RequestCtx) error { return ctx.Next() } + middleware2 := func(ctx *RequestCtx) error { return ctx.Next() } + middleware3 := func(ctx *RequestCtx) error { return ctx.Next() } + + middle := Middlewares{ + Before: []Middleware{middleware1, middleware2}, + After: []Middleware{middleware3}, + } + m := Middlewares{ + Skip: []Middleware{middleware1}, } for _, level := range logLevels { s := New(Config{LogLevel: level}) - s.Middlewares(mdlws) + s.Middlewares(middle) - chain := s.buildMiddlewaresChain(mdlws.Before[0]) + result := s.buildMiddlewares(m) - wantSkipLen := 0 - if len(chain.Skip) != wantSkipLen { - t.Errorf("Middlewares.Skip length == %d, want %d", len(chain.Skip), wantSkipLen) + wantSkipLen := len(m.Skip) + len(middle.Skip) + if len(result.Skip) != wantSkipLen { + t.Errorf("Middlewares.Skip length == %d, want %d", len(result.Skip), wantSkipLen) } - wantBeforeLen := len(mdlws.Before) - 1 + wantBeforeLen := len(middle.Before) - len(m.Skip) if s.log.DebugEnabled() { wantBeforeLen++ } - if len(chain.Before) != wantBeforeLen { - t.Errorf("Middlewares.Before length == %d, want %d", len(chain.Before), wantBeforeLen) + if len(result.Before) != wantBeforeLen { + t.Errorf("Middlewares.Before length == %d, want %d", len(result.Before), wantBeforeLen) } - if s.log.DebugEnabled() && isEqual(chain.Before[0], mdlws.Before[1]) { + if s.log.DebugEnabled() && isEqual(result.Before[0], middle.Before[1]) { t.Error("First before middleware must be the logger middleware") } - wantAfterLen := len(mdlws.After) - if len(chain.After) != wantAfterLen { - t.Errorf("Middlewares.After length == %d, want %d", len(chain.After), wantAfterLen) + wantAfterLen := len(middle.After) + if len(result.After) != wantAfterLen { + t.Errorf("Middlewares.After length == %d, want %d", len(result.After), wantAfterLen) } } }