Skip to content

Commit

Permalink
Refactor and fix Router.buildMiddlewares
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio Andres Virviescas Santana committed Jun 17, 2020
1 parent 4a5d4ee commit 3430e5a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 38 deletions.
37 changes: 18 additions & 19 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,32 +68,33 @@ func (r *Router) mutable(v bool) {
}
}

func (r *Router) buildMiddlewaresChain(skip ...Middleware) Middlewares {
mdlws := Middlewares{}
func (r *Router) buildMiddlewares(m Middlewares) Middlewares {
m2 := Middlewares{}
m2.Before = append(m2.Before, r.middlewares.Before...)
m2.Before = append(m2.Before, m.Before...)
m2.After = append(m2.After, m.After...)
m2.After = append(m2.After, r.middlewares.After...)

var subMdlws Middlewares
m2.Skip = append(m2.Skip, m.Skip...)
m2.Skip = append(m2.Skip, r.middlewares.Skip...)

if r.parent != nil {
skip = append(skip, r.middlewares.Skip...)
subMdlws = r.parent.buildMiddlewaresChain(skip...)
} else if r.log.DebugEnabled() {
switch {
case r.parent != nil:
return r.parent.buildMiddlewares(m2)
case r.log.DebugEnabled():
debugMiddleware := func(ctx *RequestCtx) error {
r.log.Debugf("%s %s", ctx.Method(), ctx.URI())

return ctx.Next()
}

// Add debug middleware at first position if the log level is enabled as debug
mdlws.Before = append(mdlws.Before, debugMiddleware)
m2.Before = append([]Middleware{debugMiddleware}, m2.Before...)
}

mdlws.Before = appendMiddlewares(mdlws.Before, subMdlws.Before, skip...)
mdlws.Before = appendMiddlewares(mdlws.Before, r.middlewares.Before, skip...)

mdlws.After = appendMiddlewares(mdlws.After, r.middlewares.After, skip...)
mdlws.After = appendMiddlewares(mdlws.After, subMdlws.After, skip...)
m2.Before = appendMiddlewares(m2.Before[:0], m2.Before, m2.Skip...)
m2.After = appendMiddlewares(m2.After[:0], m2.After, m2.Skip...)

return mdlws
return m2
}

func (r *Router) getGroupFullPath(path string) string {
Expand All @@ -105,10 +106,9 @@ func (r *Router) getGroupFullPath(path string) string {
}

func (r *Router) handler(fn View, middle Middlewares) fasthttp.RequestHandler {
mdlws := r.buildMiddlewaresChain(middle.Skip...)
middle = r.buildMiddlewares(middle)

chain := append(mdlws.Before, middle.Before...)
chain = append(chain, func(ctx *RequestCtx) error {
chain := append(middle.Before, func(ctx *RequestCtx) error {
if !ctx.skipView {
if err := fn(ctx); err != nil {
return err
Expand All @@ -117,7 +117,6 @@ func (r *Router) handler(fn View, middle Middlewares) fasthttp.RequestHandler {
return ctx.Next()
})
chain = append(chain, middle.After...)
chain = append(chain, mdlws.After...)

return func(ctx *fasthttp.RequestCtx) {
actx := AcquireRequestCtx(ctx)
Expand Down
42 changes: 23 additions & 19 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,44 +170,48 @@ func TestRouter_mutable(t *testing.T) {
}
}

func TestRouter_buildMiddlewaresChain(t *testing.T) {
func TestRouter_buildMiddlewares(t *testing.T) {
logLevels := []string{"fatal", "debug"}

mdlws := Middlewares{
Before: []Middleware{
func(ctx *RequestCtx) error { return ctx.Next() },
func(ctx *RequestCtx) error { return ctx.Next() },
},
After: []Middleware{func(ctx *RequestCtx) error { return ctx.Next() }},
middleware1 := func(ctx *RequestCtx) error { return ctx.Next() }
middleware2 := func(ctx *RequestCtx) error { return ctx.Next() }
middleware3 := func(ctx *RequestCtx) error { return ctx.Next() }

middle := Middlewares{
Before: []Middleware{middleware1, middleware2},
After: []Middleware{middleware3},
}
m := Middlewares{
Skip: []Middleware{middleware1},
}

for _, level := range logLevels {
s := New(Config{LogLevel: level})
s.Middlewares(mdlws)
s.Middlewares(middle)

chain := s.buildMiddlewaresChain(mdlws.Before[0])
result := s.buildMiddlewares(m)

wantSkipLen := 0
if len(chain.Skip) != wantSkipLen {
t.Errorf("Middlewares.Skip length == %d, want %d", len(chain.Skip), wantSkipLen)
wantSkipLen := len(m.Skip) + len(middle.Skip)
if len(result.Skip) != wantSkipLen {
t.Errorf("Middlewares.Skip length == %d, want %d", len(result.Skip), wantSkipLen)
}

wantBeforeLen := len(mdlws.Before) - 1
wantBeforeLen := len(middle.Before) - len(m.Skip)
if s.log.DebugEnabled() {
wantBeforeLen++
}

if len(chain.Before) != wantBeforeLen {
t.Errorf("Middlewares.Before length == %d, want %d", len(chain.Before), wantBeforeLen)
if len(result.Before) != wantBeforeLen {
t.Errorf("Middlewares.Before length == %d, want %d", len(result.Before), wantBeforeLen)
}

if s.log.DebugEnabled() && isEqual(chain.Before[0], mdlws.Before[1]) {
if s.log.DebugEnabled() && isEqual(result.Before[0], middle.Before[1]) {
t.Error("First before middleware must be the logger middleware")
}

wantAfterLen := len(mdlws.After)
if len(chain.After) != wantAfterLen {
t.Errorf("Middlewares.After length == %d, want %d", len(chain.After), wantAfterLen)
wantAfterLen := len(middle.After)
if len(result.After) != wantAfterLen {
t.Errorf("Middlewares.After length == %d, want %d", len(result.After), wantAfterLen)
}
}
}
Expand Down

0 comments on commit 3430e5a

Please sign in to comment.