Skip to content
This repository has been archived by the owner on Apr 12, 2024. It is now read-only.

Commit

Permalink
Implement AuthNZ provider mechanism
Browse files Browse the repository at this point in the history
Signed-off-by: Josh Kim <[email protected]>
  • Loading branch information
jooskim committed Aug 6, 2021
1 parent b0852fd commit 2bf6b03
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 6 deletions.
2 changes: 2 additions & 0 deletions plank/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ require (
golang.org/x/tools v0.1.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)

replace github.com/vmware/transport-go => ../
127 changes: 127 additions & 0 deletions plank/pkg/server/auth_provider_manager/auth_provider_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package auth_provider_manager

import (
"fmt"
"regexp"
"sync"
)

type AuthProviderType int

const (
STOMPAuthProviderType AuthProviderType = iota
RESTAuthProviderType

AUTH_PROVIDER_MANAGER_STORE = "AuthProviderManager"
AUTH_PROVIDER_MANAGER_STORE_INIT_EVENT = "INIT"
AUTH_PROVIDER_MANAGER_INSTANCE_KEY = "instance"
)

var authProviderMgrInstance AuthProviderManager

type AuthProviderManager interface {
GetRESTAuthProvider(uri string) (*RESTAuthProvider, error)
GetSTOMPAuthProvider() (*STOMPAuthProvider, error)
SetRESTAuthProvider(regex *regexp.Regexp, provider *RESTAuthProvider) int
SetSTOMPAuthProvider(provider *STOMPAuthProvider) error
DeleteRESTAuthProvider(idx int) error
DeleteSTOMPAuthProvider() error
}

type regexpRESTAuthProviderPair struct {
regexp *regexp.Regexp
restAuthProvider *RESTAuthProvider
}

type authProviderManager struct {
stompAuthProvider *STOMPAuthProvider
uriPatternRestAuthProviderPairs []*regexpRESTAuthProviderPair
mu sync.Mutex
}

type AuthProviderNotFoundError struct {}

func (e *AuthProviderNotFoundError) Error() string {
return fmt.Sprintf("no auth provider was found at the given location/name")
}

type AuthError struct {
Code int
Message string
}
func (e *AuthError) Error() string {
return fmt.Sprintf("authentication/authorization error (%d): %s", e.Code, e.Message)
}

func (a *authProviderManager) GetRESTAuthProvider(uri string) (*RESTAuthProvider, error) {
a.mu.Lock()
defer a.mu.Unlock()

// perform regex tests to find the first matching REST auth provider
for _, pair := range a.uriPatternRestAuthProviderPairs {
if pair.regexp.Match([]byte(uri)) {
return pair.restAuthProvider, nil
}
}
return nil, &AuthProviderNotFoundError{}
}

func (a *authProviderManager) GetSTOMPAuthProvider() (*STOMPAuthProvider, error) {
a.mu.Lock()
defer a.mu.Unlock()
if a.stompAuthProvider == nil {
return nil, &AuthProviderNotFoundError{}
}
return a.stompAuthProvider, nil
}

func (a *authProviderManager) SetRESTAuthProvider(regex *regexp.Regexp, provider *RESTAuthProvider) int {
a.mu.Lock()
defer a.mu.Unlock()
a.uriPatternRestAuthProviderPairs = append(a.uriPatternRestAuthProviderPairs, &regexpRESTAuthProviderPair{
regexp: regex,
restAuthProvider: provider,
})

return len(a.uriPatternRestAuthProviderPairs)-1
}

func (a *authProviderManager) SetSTOMPAuthProvider(provider *STOMPAuthProvider) error {
a.mu.Lock()
defer a.mu.Unlock()
a.stompAuthProvider = provider
return nil
}

func (a *authProviderManager) DeleteRESTAuthProvider(idx int) error {
a.mu.Lock()
defer a.mu.Unlock()
if idx < 0 || idx > len(a.uriPatternRestAuthProviderPairs)-1 {
return fmt.Errorf("no REST auth provider exists at index %d", idx)
}
borderLeft := idx
borderRight := idx+1
if idx > 0 {
borderLeft--
}
a.uriPatternRestAuthProviderPairs = append(a.uriPatternRestAuthProviderPairs[:borderLeft], a.uriPatternRestAuthProviderPairs[borderRight:]...)
return nil
}

func (a *authProviderManager) DeleteSTOMPAuthProvider() error {
a.mu.Lock()
defer a.mu.Unlock()
a.stompAuthProvider = nil
return nil
}

func GetAuthProviderManager() AuthProviderManager {
if authProviderMgrInstance == nil {
authProviderMgrInstance = &authProviderManager{}
}
return authProviderMgrInstance
}

func DestroyAuthProviderManager() {
authProviderMgrInstance = nil
}
69 changes: 69 additions & 0 deletions plank/pkg/server/auth_provider_manager/rest_auth_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package auth_provider_manager

import (
"github.com/vmware/transport-go/plank/utils"
"net/http"
"sort"
"sync"
)

type HttpRequestValidatorFn func(request *http.Request) *AuthError

type httpRequestValidatorRule struct {
name string
priority int
validatorFn HttpRequestValidatorFn
}

type RESTAuthProvider struct {
rules map[string]*httpRequestValidatorRule
rulesByPriority []*httpRequestValidatorRule
mu sync.Mutex
}

func (ap *RESTAuthProvider) Validate(request *http.Request) *AuthError {
ap.mu.Lock()
defer ap.mu.Unlock()
defer func() {
if r := recover(); r != nil {
utils.Log.Errorln(r)
}
}()

for _, rule := range ap.rulesByPriority {
err := rule.validatorFn(request)
if err != nil {
return err
}
}
return nil
}

func (ap *RESTAuthProvider) AddRule(name string, priority int, validatorFn HttpRequestValidatorFn) {
ap.mu.Lock()
defer ap.mu.Unlock()

rule := &httpRequestValidatorRule{
name: name,
priority: priority,
validatorFn: validatorFn,
}
ap.rules[name] = rule
ap.rulesByPriority = append(ap.rulesByPriority, rule)
sort.SliceStable(ap.rulesByPriority, func(i, j int) bool {
return ap.rulesByPriority[i].priority < ap.rulesByPriority[j].priority
})
}

func (ap *RESTAuthProvider) Reset() {
ap.mu.Lock()
defer ap.mu.Unlock()
ap.rules = make(map[string]*httpRequestValidatorRule)
ap.rulesByPriority = make([]*httpRequestValidatorRule, 0)
}

func NewRESTAuthProvider() *RESTAuthProvider {
ap := &RESTAuthProvider{}
ap.Reset()
return ap
}
77 changes: 77 additions & 0 deletions plank/pkg/server/auth_provider_manager/stomp_auth_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package auth_provider_manager

import (
"github.com/go-stomp/stomp/frame"
"github.com/vmware/transport-go/plank/utils"
"sort"
"sync"
)

type StompFrameValidatorFn func(fr *frame.Frame) *AuthError

type stompFrameValidatorRule struct {
msgFrameType string
priority int
validatorFn StompFrameValidatorFn
}

type STOMPAuthProvider struct {
rules map[string][]*stompFrameValidatorRule
mu sync.Mutex
}

func (ap *STOMPAuthProvider) Validate(fr *frame.Frame) error {
ap.mu.Lock()
defer ap.mu.Unlock()
defer func() {
if r := recover(); r != nil {
utils.Log.Errorln(r)
}
}()

rules, found := ap.rules[fr.Command]
// if no rule was found let the request pass through
if !found {
return nil
}

for _, rule := range rules {
err := rule.validatorFn(fr)
if err != nil {
return err
}
}
return nil
}

func (ap *STOMPAuthProvider) AddRule(types []string, priority int, validatorFn StompFrameValidatorFn) {
ap.mu.Lock()
defer ap.mu.Unlock()

for _, typ := range types {
rule := &stompFrameValidatorRule{
msgFrameType: typ,
priority: priority,
validatorFn: validatorFn,
}
if _, ok := ap.rules[typ]; !ok {
ap.rules[typ] = make([]*stompFrameValidatorRule, 0)
}
ap.rules[typ] = append(ap.rules[typ], rule)
sort.SliceStable(ap.rules[typ], func(i, j int) bool {
return ap.rules[typ][i].priority < ap.rules[typ][j].priority
})
}
}

func (ap *STOMPAuthProvider) Reset() {
ap.mu.Lock()
defer ap.mu.Unlock()
ap.rules = make(map[string][]*stompFrameValidatorRule)
}

func NewSTOMPAuthProvider() *STOMPAuthProvider {
return &STOMPAuthProvider{
rules: make(map[string][]*stompFrameValidatorRule),
}
}
25 changes: 22 additions & 3 deletions plank/pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/vmware/transport-go/plank/pkg/server/auth_provider_manager"
"io/ioutil"
"log"
"net"
Expand All @@ -22,14 +23,14 @@ import (

"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/vmware/transport-go/plank/pkg/middleware"
"github.com/vmware/transport-go/plank/utils"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
"github.com/urfave/cli"
"github.com/vmware/transport-go/bus"
"github.com/vmware/transport-go/model"
"github.com/vmware/transport-go/plank/pkg/middleware"
"github.com/vmware/transport-go/plank/utils"
"github.com/vmware/transport-go/service"
"github.com/vmware/transport-go/stompserver"
)
Expand Down Expand Up @@ -356,10 +357,28 @@ func buildEndpointHandler(svcChannel string, reqBuilder func(w http.ResponseWrit
defer func() {
if r := recover(); r != nil {
utils.Log.Errorln(r)
http.Error(w, "Internal Server Error", 500)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()

apm := auth_provider_manager.GetAuthProviderManager()
provider, err := apm.GetRESTAuthProvider(r.URL.Path)

if err != nil && !errors.Is(err, &auth_provider_manager.AuthProviderNotFoundError{}){
utils.Log.Errorln(err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

// run validation against rules registered in the provider only if such provider exists
if provider != nil {
err := provider.Validate(r)
if err != nil {
http.Error(w, err.Message, err.Code)
return
}
}

// set context that would expire after 30 seconds by default to prevent requests from hanging forever
ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
defer cancelFn()
Expand Down
2 changes: 1 addition & 1 deletion stompserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (s *stompServer) waitForConnections() {
if err != nil {
log.Println("Failed to establish client connection:", err)
} else {
c := NewStompConn(rawConn, s.config, s.connectionEvents)
c := NewStompConn(rawConn, s.config, s.connectionEvents, true)

s.connectionEvents <- &connEvent{
conn: c,
Expand Down
20 changes: 19 additions & 1 deletion stompserver/stomp_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/go-stomp/stomp"
"github.com/go-stomp/stomp/frame"
"github.com/google/uuid"
"github.com/vmware/transport-go/plank/pkg/server/auth_provider_manager"
"log"
"strconv"
"strings"
Expand Down Expand Up @@ -52,9 +53,10 @@ type stompConn struct {
subscriptions map[string]*subscription
currentMessageId uint64
closeOnce sync.Once
authProviderManager auth_provider_manager.AuthProviderManager
}

func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *connEvent) StompConn {
func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *connEvent, useAuthProvider bool) StompConn {
conn := &stompConn{
rawConnection: rawConnection,
state: connecting,
Expand All @@ -66,6 +68,10 @@ func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *
subscriptions: make(map[string]*subscription),
}

if useAuthProvider {
conn.authProviderManager = auth_provider_manager.GetAuthProviderManager()
}

go conn.run()
go conn.readInFrames()

Expand Down Expand Up @@ -157,6 +163,18 @@ func (conn *stompConn) run() {
}

func (conn *stompConn) handleIncomingFrame(f *frame.Frame) error {
if conn.authProviderManager != nil {
// if a STOMP auth provider was configured try to validate each frame with the rules defined
// in the provider. if auth provider is not found let the payload through
if provider, _ := conn.authProviderManager.GetSTOMPAuthProvider(); provider != nil {
err := provider.Validate(f) // TODO: should we return error through a transport channel?
if err != nil {
fmt.Println(f.Command, err)
return err
}
}
}

switch f.Command {

case frame.CONNECT, frame.STOMP:
Expand Down
Loading

0 comments on commit 2bf6b03

Please sign in to comment.