Skip to content

Commit

Permalink
Add Atreugo.NewVirtualHost
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio Andres Virviescas Santana committed Jul 5, 2020
1 parent 6b40c09 commit 5829d17
Show file tree
Hide file tree
Showing 4 changed files with 269 additions and 45 deletions.
58 changes: 51 additions & 7 deletions atreugo.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func New(cfg Config) *Atreugo {
}

server := &Atreugo{
server: newFasthttpServer(cfg, r.router.Handler, log),
server: newFasthttpServer(cfg, log),
log: log,
cfg: cfg,
Router: r,
Expand All @@ -70,14 +70,9 @@ func New(cfg Config) *Atreugo {
return server
}

func newFasthttpServer(cfg Config, handler fasthttp.RequestHandler, log fasthttp.Logger) *fasthttp.Server {
if cfg.Compress {
handler = fasthttp.CompressHandler(handler)
}

func newFasthttpServer(cfg Config, log fasthttp.Logger) *fasthttp.Server {
return &fasthttp.Server{
Name: cfg.Name,
Handler: handler,
HeaderReceived: cfg.HeaderReceived,
Concurrency: cfg.Concurrency,
DisableKeepalive: cfg.DisableKeepalive,
Expand All @@ -104,6 +99,28 @@ func newFasthttpServer(cfg Config, handler fasthttp.RequestHandler, log fasthttp
}
}

func (s *Atreugo) handler() fasthttp.RequestHandler {
handler := s.router.Handler

if len(s.virtualHosts) > 0 {
handler = func(ctx *fasthttp.RequestCtx) {
hostname := gotils.B2S(ctx.URI().Host())

if h := s.virtualHosts[hostname]; h != nil {
h(ctx)
} else {
s.router.Handler(ctx)
}
}
}

if s.cfg.Compress {
handler = fasthttp.CompressHandler(handler)
}

return handler
}

// SaveMatchedRoutePath if enabled, adds the matched route path onto the ctx.UserValue context
// before invoking the handler.
// The matched route path is only added to handlers of routes that were
Expand Down Expand Up @@ -184,6 +201,7 @@ func (s *Atreugo) Serve(ln net.Listener) error {

s.cfg.Addr = ln.Addr().String()
s.cfg.Network = ln.Addr().Network()
s.server.Handler = s.handler()

if gotils.StringSliceInclude(tcpNetworks, s.cfg.Network) {
schema := "http"
Expand All @@ -207,3 +225,29 @@ func (s *Atreugo) Serve(ln net.Listener) error {
func (s *Atreugo) SetLogOutput(output io.Writer) {
s.log.SetOutput(output)
}

// NewVirtualHost returns a new sub-router for running more than one web site
// (such as company1.example.com and company2.example.com) on a single atreugo instance.
// Virtual hosts can be "IP-based", meaning that you have a different IP address
// for every web site, or "name-based", meaning that you have multiple names
// running on each IP address.
//
// The fact that they are running on the same atreugo instance is not apparent to the end user.
func (s *Atreugo) NewVirtualHost(hostname string) *Router {
if s.virtualHosts == nil {
s.virtualHosts = make(map[string]fasthttp.RequestHandler)
}

vHost := newRouter(s.log, s.cfg.ErrorView)
vHost.router.NotFound = s.router.NotFound
vHost.router.MethodNotAllowed = s.router.MethodNotAllowed
vHost.router.PanicHandler = s.router.PanicHandler

if s.virtualHosts[hostname] != nil {
panicf("a router is already registered for virtual host '%s'", hostname)
}

s.virtualHosts[hostname] = vHost.router.Handler

return vHost
}
249 changes: 211 additions & 38 deletions atreugo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package atreugo
import (
"bytes"
"errors"
"fmt"
"math/rand"
"net"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -160,75 +163,174 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit
}

func Test_newFasthttpServer(t *testing.T) { //nolint:funlen
type args struct {
compress bool
cfg := Config{
Name: "test",
HeaderReceived: func(header *fasthttp.RequestHeader) fasthttp.RequestConfig {
return fasthttp.RequestConfig{}
},
Concurrency: rand.Int(), // nolint:gosec
DisableKeepalive: true,
ReadBufferSize: rand.Int(), // nolint:gosec
WriteBufferSize: rand.Int(), // nolint:gosec
ReadTimeout: time.Duration(rand.Int()), // nolint:gosec
WriteTimeout: time.Duration(rand.Int()), // nolint:gosec
IdleTimeout: time.Duration(rand.Int()), // nolint:gosec
MaxConnsPerIP: rand.Int(), // nolint:gosec
MaxRequestsPerConn: rand.Int(), // nolint:gosec
MaxRequestBodySize: rand.Int(), // nolint:gosec
ReduceMemoryUsage: true,
GetOnly: true,
DisablePreParseMultipartForm: true,
LogAllErrors: true,
DisableHeaderNamesNormalizing: true,
SleepWhenConcurrencyLimitsExceeded: time.Duration(rand.Int()), // nolint:gosec
NoDefaultServerHeader: true,
NoDefaultDate: true,
NoDefaultContentType: true,
ConnState: func(net.Conn, fasthttp.ConnState) {},
KeepHijackedConns: true,
}

type want struct {
compress bool
srv := newFasthttpServer(cfg, testLog)

if srv == nil {
t.Fatal("newFasthttpServer() == nil")
}

fasthttpServerType := reflect.TypeOf(fasthttp.Server{})
configType := reflect.TypeOf(Config{})

fasthttpServerValue := reflect.ValueOf(*srv) // nolint:govet
configValue := reflect.ValueOf(cfg)

for i := 0; i < fasthttpServerType.NumField(); i++ {
field := fasthttpServerType.Field(i)

if !unicode.IsUpper(rune(field.Name[0])) { // Check if the field is public
continue
} else if gotils.StringSliceInclude(notConfigFasthttpFields, field.Name) {
continue
}

_, exist := configType.FieldByName(field.Name)
if !exist {
t.Errorf("The field '%s' does not exist in atreugo.Config", field.Name)
}

v1 := fmt.Sprint(fasthttpServerValue.FieldByName(field.Name).Interface())
v2 := fmt.Sprint(configValue.FieldByName(field.Name).Interface())

if v1 != v2 {
t.Errorf("fasthttp.Server.%s == %s, want %s", field.Name, v1, v2)
}
}

if srv.Handler != nil {
t.Error("fasthttp.Server.Handler must be nil")
}

if !isEqual(srv.Logger, testLog) {
t.Errorf("fasthttp.Server.Logger == %p, want %p", srv.Logger, testLog)
}
}

func TestAtreugo_handler(t *testing.T) { // nolint:funlen,gocognit
type args struct {
cfg Config
hosts []string
}

tests := []struct {
name string
args args
want want
}{
{
name: "NotCompress",
name: "Default",
args: args{
compress: false,
},
want: want{
compress: false,
cfg: Config{},
},
},
{
name: "Compress",
args: args{
compress: true,
cfg: Config{Compress: true},
},
want: want{
compress: true,
},
{
name: "MultiHost",
args: args{
cfg: Config{},
hosts: []string{"localhost", "example.com"},
},
},
{
name: "MultiHostCompress",
args: args{
cfg: Config{Compress: true},
hosts: []string{"localhost", "example.com"},
},
},
}

handler := func(ctx *fasthttp.RequestCtx) {}

for _, test := range tests {
tt := test

t.Run(tt.name, func(t *testing.T) {
cfg := Config{
LogLevel: "fatal",
Compress: tt.args.compress,
testView := func(ctx *RequestCtx) error {
return ctx.JSONResponse(JSON{"data": gotils.RandBytes(make([]byte, 300))})
}
srv := newFasthttpServer(cfg, handler, testLog)
testPath := "/"

s := New(tt.args.cfg)
s.GET(testPath, testView)

if (reflect.ValueOf(handler).Pointer() == reflect.ValueOf(srv.Handler).Pointer()) == tt.want.compress {
t.Error("The handler has not been wrapped by compression handler")
for _, hostname := range tt.args.hosts {
vHost := s.NewVirtualHost(hostname)
vHost.GET(testPath, testView)
}
})
}
}

func TestAtreugo_ConfigFasthttpFields(t *testing.T) {
fasthttpServerType := reflect.TypeOf(fasthttp.Server{})
configType := reflect.TypeOf(Config{})
handler := s.handler()

for i := 0; i < fasthttpServerType.NumField(); i++ {
field := fasthttpServerType.Field(i)
if handler == nil {
t.Errorf("handler is nil")
}

if !unicode.IsUpper(rune(field.Name[0])) { // Check if the field is public
continue
} else if gotils.StringSliceInclude(notConfigFasthttpFields, field.Name) {
continue
}
newHostname := string(gotils.RandBytes(make([]byte, 10))) + ".com"

_, exist := configType.FieldByName(field.Name)
if !exist {
t.Errorf("The field '%s' does not exist in atreugo.Config", field.Name)
}
hosts := tt.args.hosts
hosts = append(hosts, newHostname)

for _, hostname := range hosts {
for _, path := range []string{testPath, "/notfound"} {
ctx := new(fasthttp.RequestCtx)
ctx.Request.Header.Set(fasthttp.HeaderAcceptEncoding, "gzip")
ctx.Request.Header.Set(fasthttp.HeaderHost, hostname)
ctx.Request.URI().SetHost(hostname)
ctx.Request.SetRequestURI(path)

handler(ctx)

statusCode := ctx.Response.StatusCode()
wantStatusCode := fasthttp.StatusOK

if path != testPath {
wantStatusCode = fasthttp.StatusNotFound
}

if statusCode != wantStatusCode {
t.Errorf("Host %s - Path %s, Status code == %d, want %d", hostname, path, statusCode, wantStatusCode)
}

if wantStatusCode == fasthttp.StatusNotFound {
continue
}

if tt.args.cfg.Compress && len(ctx.Response.Header.Peek(fasthttp.HeaderContentEncoding)) == 0 {
t.Errorf("The header '%s' is not setted", fasthttp.HeaderContentEncoding)
}
}
}
})
}
}

Expand Down Expand Up @@ -359,6 +461,10 @@ func TestAtreugo_Serve(t *testing.T) {
if s.cfg.Network != lnNetwork {
t.Errorf("Atreugo.Config.Network = %s, want %s", s.cfg.Network, lnNetwork)
}

if s.server.Handler == nil {
t.Error("Atreugo.server.Handler is nil")
}
}

func TestAtreugo_SetLogOutput(t *testing.T) {
Expand All @@ -372,3 +478,70 @@ func TestAtreugo_SetLogOutput(t *testing.T) {
t.Error("SetLogOutput() log output was not changed")
}
}

func TestAtreugo_NewVirtualHost(t *testing.T) {
hostname := "localhost"
s := New(testAtreugoConfig)

if s.virtualHosts != nil {
t.Error("Atreugo.virtualHosts must be nil before register a new virtual host")
}

vHost := s.NewVirtualHost(hostname)
if vHost == nil {
t.Fatal("Atreugo.NewVirtualHost() returned a nil router")
}

if !isEqual(vHost.router.NotFound, s.router.NotFound) {
t.Errorf("VirtualHost router.NotFound == %p, want %p", vHost.router.NotFound, s.router.NotFound)
}

if !isEqual(vHost.router.MethodNotAllowed, s.router.MethodNotAllowed) {
t.Errorf(
"VirtualHost router.MethodNotAllowed == %p, want %p",
vHost.router.MethodNotAllowed,
s.router.MethodNotAllowed,
)
}

if !isEqual(vHost.router.PanicHandler, s.router.PanicHandler) {
t.Errorf("VirtualHost router.PanicHandler == %p, want %p", vHost.router.PanicHandler, s.router.PanicHandler)
}

if h := s.virtualHosts[hostname]; h == nil {
t.Error("The new virtual host is not registeded")
}

defer func() {
err := recover()
if err == nil {
t.Error("Expected panic when a virtual host is duplicated")
}

wantErrString := fmt.Sprintf("a router is already registered for virtual host '%s'", hostname)
if err != wantErrString {
t.Errorf("Error string == %s, want %s", err, wantErrString)
}
}()

// panic when a virtual host is duplicated
s.NewVirtualHost(hostname)
}

// Benchmarks.
func Benchmark_Handler(b *testing.B) {
s := New(testAtreugoConfig)
s.GET("/", func(ctx *RequestCtx) error { return nil })

ctx := new(fasthttp.RequestCtx)
ctx.Request.Header.SetMethod("GET")
ctx.Request.SetRequestURI("/")

handler := s.handler()

b.ResetTimer()

for i := 0; i <= b.N; i++ {
handler(ctx)
}
}
Loading

0 comments on commit 5829d17

Please sign in to comment.