Skip to content
This repository was archived by the owner on Apr 12, 2024. It is now read-only.

Commit 584563a

Browse files
committed
Implement AuthNZ provider mechanism
Signed-off-by: Josh Kim <[email protected]>
1 parent 733b4b1 commit 584563a

File tree

8 files changed

+317
-3
lines changed

8 files changed

+317
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package auth_provider_manager
2+
3+
import (
4+
"fmt"
5+
"regexp"
6+
"sync"
7+
)
8+
9+
var authProviderMgrInstance AuthProviderManager
10+
11+
type AuthProviderManager interface {
12+
GetRESTAuthProvider(uri string) (*RESTAuthProvider, error)
13+
GetSTOMPAuthProvider() (*STOMPAuthProvider, error)
14+
SetRESTAuthProvider(regex *regexp.Regexp, provider *RESTAuthProvider) int
15+
SetSTOMPAuthProvider(provider *STOMPAuthProvider) error
16+
DeleteRESTAuthProvider(idx int) error
17+
DeleteSTOMPAuthProvider() error
18+
}
19+
20+
type regexpRESTAuthProviderPair struct {
21+
regexp *regexp.Regexp
22+
restAuthProvider *RESTAuthProvider
23+
}
24+
25+
type authProviderManager struct {
26+
stompAuthProvider *STOMPAuthProvider
27+
uriPatternRestAuthProviderPairs []*regexpRESTAuthProviderPair
28+
mu sync.Mutex
29+
}
30+
31+
type AuthProviderNotFoundError struct {}
32+
33+
func (e *AuthProviderNotFoundError) Error() string {
34+
return fmt.Sprintf("no auth provider was found at the given location/name")
35+
}
36+
37+
type AuthError struct {
38+
Code int
39+
Message string
40+
}
41+
func (e *AuthError) Error() string {
42+
return fmt.Sprintf("authentication/authorization error (%d): %s", e.Code, e.Message)
43+
}
44+
45+
func (a *authProviderManager) GetRESTAuthProvider(uri string) (*RESTAuthProvider, error) {
46+
a.mu.Lock()
47+
defer a.mu.Unlock()
48+
49+
// perform regex tests to find the first matching REST auth provider
50+
for _, pair := range a.uriPatternRestAuthProviderPairs {
51+
if pair.regexp.Match([]byte(uri)) {
52+
return pair.restAuthProvider, nil
53+
}
54+
}
55+
return nil, &AuthProviderNotFoundError{}
56+
}
57+
58+
func (a *authProviderManager) GetSTOMPAuthProvider() (*STOMPAuthProvider, error) {
59+
a.mu.Lock()
60+
defer a.mu.Unlock()
61+
if a.stompAuthProvider == nil {
62+
return nil, &AuthProviderNotFoundError{}
63+
}
64+
return a.stompAuthProvider, nil
65+
}
66+
67+
func (a *authProviderManager) SetRESTAuthProvider(regex *regexp.Regexp, provider *RESTAuthProvider) int {
68+
a.mu.Lock()
69+
defer a.mu.Unlock()
70+
a.uriPatternRestAuthProviderPairs = append(a.uriPatternRestAuthProviderPairs, &regexpRESTAuthProviderPair{
71+
regexp: regex,
72+
restAuthProvider: provider,
73+
})
74+
75+
return len(a.uriPatternRestAuthProviderPairs)-1
76+
}
77+
78+
func (a *authProviderManager) SetSTOMPAuthProvider(provider *STOMPAuthProvider) error {
79+
a.mu.Lock()
80+
defer a.mu.Unlock()
81+
a.stompAuthProvider = provider
82+
return nil
83+
}
84+
85+
func (a *authProviderManager) DeleteRESTAuthProvider(idx int) error {
86+
a.mu.Lock()
87+
defer a.mu.Unlock()
88+
if idx < 0 || idx > len(a.uriPatternRestAuthProviderPairs)-1 {
89+
return fmt.Errorf("no REST auth provider exists at index %d", idx)
90+
}
91+
borderLeft := idx
92+
borderRight := idx+1
93+
if idx > 0 {
94+
borderLeft--
95+
}
96+
a.uriPatternRestAuthProviderPairs = append(a.uriPatternRestAuthProviderPairs[:borderLeft], a.uriPatternRestAuthProviderPairs[borderRight:]...)
97+
return nil
98+
}
99+
100+
func (a *authProviderManager) DeleteSTOMPAuthProvider() error {
101+
a.mu.Lock()
102+
defer a.mu.Unlock()
103+
a.stompAuthProvider = nil
104+
return nil
105+
}
106+
107+
func GetAuthProviderManager() AuthProviderManager {
108+
if authProviderMgrInstance == nil {
109+
authProviderMgrInstance = &authProviderManager{}
110+
}
111+
return authProviderMgrInstance
112+
}
113+
114+
func DestroyAuthProviderManager() {
115+
authProviderMgrInstance = nil
116+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package auth_provider_manager
2+
3+
import (
4+
"github.com/vmware/transport-go/plank/utils"
5+
"net/http"
6+
"sort"
7+
"sync"
8+
)
9+
10+
type HttpRequestValidatorFn func(request *http.Request) *AuthError
11+
12+
type httpRequestValidatorRule struct {
13+
name string
14+
priority int
15+
validatorFn HttpRequestValidatorFn
16+
}
17+
18+
type RESTAuthProvider struct {
19+
rules map[string]*httpRequestValidatorRule
20+
rulesByPriority []*httpRequestValidatorRule
21+
mu sync.Mutex
22+
}
23+
24+
func (ap *RESTAuthProvider) Validate(request *http.Request) *AuthError {
25+
ap.mu.Lock()
26+
defer ap.mu.Unlock()
27+
defer func() {
28+
if r := recover(); r != nil {
29+
utils.Log.Errorln(r)
30+
}
31+
}()
32+
33+
for _, rule := range ap.rulesByPriority {
34+
err := rule.validatorFn(request)
35+
if err != nil {
36+
return err
37+
}
38+
}
39+
return nil
40+
}
41+
42+
func (ap *RESTAuthProvider) AddRule(name string, priority int, validatorFn HttpRequestValidatorFn) {
43+
ap.mu.Lock()
44+
defer ap.mu.Unlock()
45+
46+
rule := &httpRequestValidatorRule{
47+
name: name,
48+
priority: priority,
49+
validatorFn: validatorFn,
50+
}
51+
ap.rules[name] = rule
52+
ap.rulesByPriority = append(ap.rulesByPriority, rule)
53+
sort.SliceStable(ap.rulesByPriority, func(i, j int) bool {
54+
return ap.rulesByPriority[i].priority < ap.rulesByPriority[j].priority
55+
})
56+
}
57+
58+
func (ap *RESTAuthProvider) Reset() {
59+
ap.mu.Lock()
60+
defer ap.mu.Unlock()
61+
ap.rules = make(map[string]*httpRequestValidatorRule)
62+
ap.rulesByPriority = make([]*httpRequestValidatorRule, 0)
63+
}
64+
65+
func NewRESTAuthProvider() *RESTAuthProvider {
66+
ap := &RESTAuthProvider{}
67+
ap.Reset()
68+
return ap
69+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package auth_provider_manager
2+
3+
import (
4+
"github.com/go-stomp/stomp/frame"
5+
"github.com/vmware/transport-go/plank/utils"
6+
"sort"
7+
"sync"
8+
)
9+
10+
type StompFrameValidatorFn func(fr *frame.Frame) *AuthError
11+
12+
type stompFrameValidatorRule struct {
13+
msgFrameType string
14+
priority int
15+
validatorFn StompFrameValidatorFn
16+
}
17+
18+
type STOMPAuthProvider struct {
19+
rules map[string][]*stompFrameValidatorRule
20+
mu sync.Mutex
21+
}
22+
23+
func (ap *STOMPAuthProvider) Validate(fr *frame.Frame) error {
24+
ap.mu.Lock()
25+
defer ap.mu.Unlock()
26+
defer func() {
27+
if r := recover(); r != nil {
28+
utils.Log.Errorln(r)
29+
}
30+
}()
31+
32+
rules, found := ap.rules[fr.Command]
33+
// if no rule was found let the request pass through
34+
if !found {
35+
return nil
36+
}
37+
38+
for _, rule := range rules {
39+
err := rule.validatorFn(fr)
40+
if err != nil {
41+
return err
42+
}
43+
}
44+
return nil
45+
}
46+
47+
func (ap *STOMPAuthProvider) AddRule(types []string, priority int, validatorFn StompFrameValidatorFn) {
48+
ap.mu.Lock()
49+
defer ap.mu.Unlock()
50+
51+
for _, typ := range types {
52+
rule := &stompFrameValidatorRule{
53+
msgFrameType: typ,
54+
priority: priority,
55+
validatorFn: validatorFn,
56+
}
57+
if _, ok := ap.rules[typ]; !ok {
58+
ap.rules[typ] = make([]*stompFrameValidatorRule, 0)
59+
}
60+
ap.rules[typ] = append(ap.rules[typ], rule)
61+
sort.SliceStable(ap.rules[typ], func(i, j int) bool {
62+
return ap.rules[typ][i].priority < ap.rules[typ][j].priority
63+
})
64+
}
65+
}
66+
67+
func (ap *STOMPAuthProvider) Reset() {
68+
ap.mu.Lock()
69+
defer ap.mu.Unlock()
70+
ap.rules = make(map[string][]*stompFrameValidatorRule)
71+
}
72+
73+
func NewSTOMPAuthProvider() *STOMPAuthProvider {
74+
return &STOMPAuthProvider{
75+
rules: make(map[string][]*stompFrameValidatorRule),
76+
}
77+
}

plank/pkg/server/endpointer_handler_factory.go

+20
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package server
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"github.com/sirupsen/logrus"
89
"github.com/vmware/transport-go/bus"
910
"github.com/vmware/transport-go/model"
11+
"github.com/vmware/transport-go/plank/pkg/server/auth_provider_manager"
1012
"github.com/vmware/transport-go/plank/utils"
1113
"github.com/vmware/transport-go/service"
1214
"net/http"
@@ -25,6 +27,24 @@ func buildEndpointHandler(svcChannel string, reqBuilder service.RequestBuilder,
2527
}
2628
}()
2729

30+
apm := auth_provider_manager.GetAuthProviderManager()
31+
provider, err := apm.GetRESTAuthProvider(r.URL.Path)
32+
33+
if err != nil && !errors.Is(err, &auth_provider_manager.AuthProviderNotFoundError{}){
34+
utils.Log.Errorln(err)
35+
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
36+
return
37+
}
38+
39+
// run validation against rules registered in the provider only if such provider exists
40+
if provider != nil {
41+
err := provider.Validate(r)
42+
if err != nil {
43+
http.Error(w, err.Message, err.Code)
44+
return
45+
}
46+
}
47+
2848
// set context that expires after the provided amount of time in restBridgeTimeout to prevent requests from hanging forever
2949
ctx, cancelFn := context.WithTimeout(context.Background(), restBridgeTimeout)
3050
defer cancelFn()

plank/pkg/server/server.go

+14
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,20 @@ func (ps *platformServer) SetHttpChannelBridge(bridgeConfig *service.RESTBridgeC
334334
bridgeConfig.ServiceChannel, bridgeConfig.Uri, bridgeConfig.Method)
335335
}
336336

337+
func (ps *platformServer) clearHttpChannelBridgesForService(serviceChannel string) {
338+
ps.lock.Lock()
339+
defer ps.lock.Unlock()
340+
341+
// if in override mode delete existing mappings associated with the service
342+
existingMappings := ps.serviceChanToBridgeEndpoints[serviceChannel]
343+
ps.serviceChanToBridgeEndpoints[serviceChannel] = make([]string, 0)
344+
for _, handlerKey := range existingMappings {
345+
utils.Log.Infof("Removing existing service - REST mapping '%s' for service '%s'", handlerKey, serviceChannel)
346+
ps.router.Get(fmt.Sprintf("bridge-%s", handlerKey)).Handler(http.NotFoundHandler())
347+
delete(ps.endpointHandlerMap, handlerKey)
348+
}
349+
}
350+
337351
// GetMiddlewareManager returns the MiddleManager instance
338352
func (ps *platformServer) GetMiddlewareManager() middleware.MiddlewareManager {
339353
return ps.middlewareManager

stompserver/server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ func (s *stompServer) waitForConnections() {
200200
if err != nil {
201201
log.Println("Failed to establish client connection:", err)
202202
} else {
203-
c := NewStompConn(rawConn, s.config, s.connectionEvents)
203+
c := NewStompConn(rawConn, s.config, s.connectionEvents, true)
204204

205205
s.connectionEvents <- &ConnEvent{
206206
ConnId: c.GetId(),

stompserver/stomp_connection.go

+19-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/go-stomp/stomp"
99
"github.com/go-stomp/stomp/frame"
1010
"github.com/google/uuid"
11+
"github.com/vmware/transport-go/plank/pkg/server/auth_provider_manager"
1112
"log"
1213
"strconv"
1314
"strings"
@@ -52,9 +53,10 @@ type stompConn struct {
5253
subscriptions map[string]*subscription
5354
currentMessageId uint64
5455
closeOnce sync.Once
56+
authProviderManager auth_provider_manager.AuthProviderManager
5557
}
5658

57-
func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *ConnEvent) StompConn {
59+
func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *ConnEvent, useAuthProvider bool) StompConn {
5860
conn := &stompConn{
5961
rawConnection: rawConnection,
6062
state: connecting,
@@ -66,6 +68,10 @@ func NewStompConn(rawConnection RawConnection, config StompConfig, events chan *
6668
subscriptions: make(map[string]*subscription),
6769
}
6870

71+
if useAuthProvider {
72+
conn.authProviderManager = auth_provider_manager.GetAuthProviderManager()
73+
}
74+
6975
go conn.run()
7076
go conn.readInFrames()
7177

@@ -158,6 +164,18 @@ func (conn *stompConn) run() {
158164
}
159165

160166
func (conn *stompConn) handleIncomingFrame(f *frame.Frame) error {
167+
if conn.authProviderManager != nil {
168+
// if a STOMP auth provider was configured try to validate each frame with the rules defined
169+
// in the provider. if auth provider is not found let the payload through
170+
if provider, _ := conn.authProviderManager.GetSTOMPAuthProvider(); provider != nil {
171+
err := provider.Validate(f) // TODO: should we return error through a transport channel?
172+
if err != nil {
173+
fmt.Println(f.Command, err)
174+
return err
175+
}
176+
}
177+
}
178+
161179
switch f.Command {
162180

163181
case frame.CONNECT, frame.STOMP:

stompserver/stomp_connection_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func getTestStompConn(conf StompConfig, events chan *ConnEvent) (*stompConn, *Mo
9797
}
9898

9999
rawConn := NewMockRawConnection()
100-
return NewStompConn(rawConn, conf, events).(*stompConn), rawConn, events
100+
return NewStompConn(rawConn, conf, events, true).(*stompConn), rawConn, events
101101
}
102102

103103
func TestStompConn_Connect(t *testing.T) {

0 commit comments

Comments
 (0)