Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit df88d23

Browse files
Logout hook plugin (#611)
* logout hook plugin
1 parent b32a3d3 commit df88d23

File tree

6 files changed

+141
-14
lines changed

6 files changed

+141
-14
lines changed

auth/cookie.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,11 @@ func NewRedirectCookie(ctx context.Context, redirectURL string) *http.Cookie {
163163
}
164164
}
165165

166+
// GetAuthFlowEndRedirect returns the redirect URI according to data in request.
166167
// At the end of the OAuth flow, the server needs to send the user somewhere. This should have been stored as a cookie
167168
// during the initial /login call. If that cookie is missing from the request, it will default to the one configured
168169
// in this package's Config object.
169-
func getAuthFlowEndRedirect(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request) string {
170+
func GetAuthFlowEndRedirect(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request) string {
170171
queryParams := request.URL.Query()
171172
// Use the redirect URL specified in the request if one is available.
172173
if redirectURL := queryParams.Get(RedirectURLParameter); len(redirectURL) > 0 {

auth/cookie_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ import (
99
"net/url"
1010
"testing"
1111

12-
"github.com/flyteorg/flyteadmin/auth/config"
13-
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
1412
stdConfig "github.com/flyteorg/flytestdlib/config"
1513
"github.com/gorilla/securecookie"
1614
"github.com/stretchr/testify/assert"
15+
16+
"github.com/flyteorg/flyteadmin/auth/config"
17+
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
1718
)
1819

1920
func mustParseURL(t testing.TB, u string) url.URL {
@@ -131,7 +132,7 @@ func TestGetAuthFlowEndRedirect(t *testing.T) {
131132
assert.NotNil(t, cookie)
132133
request.AddCookie(cookie)
133134
mockAuthCtx := &mocks.AuthenticationContext{}
134-
redirect := getAuthFlowEndRedirect(ctx, mockAuthCtx, request)
135+
redirect := GetAuthFlowEndRedirect(ctx, mockAuthCtx, request)
135136
assert.Equal(t, "/console", redirect)
136137
})
137138

@@ -145,7 +146,7 @@ func TestGetAuthFlowEndRedirect(t *testing.T) {
145146
RedirectURL: stdConfig.URL{URL: mustParseURL(t, "/api/v1/projects")},
146147
},
147148
})
148-
redirect := getAuthFlowEndRedirect(ctx, mockAuthCtx, request)
149+
redirect := GetAuthFlowEndRedirect(ctx, mockAuthCtx, request)
149150
assert.Equal(t, "/api/v1/projects", redirect)
150151
})
151152
}

auth/handlers.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func (e *PreRedirectHookError) Error() string {
4848
// PreRedirectHookError is the error interface which allows the user to set correct http status code and Message to be set in case the function returns an error
4949
// without which the current usage in GetCallbackHandler will set this to InternalServerError
5050
type PreRedirectHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) *PreRedirectHookError
51+
type LogoutHookFunc func(ctx context.Context, authCtx interfaces.AuthenticationContext, request *http.Request, w http.ResponseWriter) error
5152
type HTTPRequestToMetadataAnnotator func(ctx context.Context, request *http.Request) metadata.MD
5253
type UserInfoForwardResponseHandler func(ctx context.Context, w http.ResponseWriter, m protoiface.MessageV1) error
5354

@@ -68,7 +69,7 @@ func RegisterHandlers(ctx context.Context, handler interfaces.HandlerRegisterer,
6869
handler.HandleFunc(fmt.Sprintf("/%s", OIdCMetadataEndpoint), GetOIdCMetadataEndpointRedirectHandler(ctx, authCtx))
6970

7071
// These endpoints require authentication
71-
handler.HandleFunc("/logout", GetLogoutEndpointHandler(ctx, authCtx))
72+
handler.HandleFunc("/logout", GetLogoutEndpointHandler(ctx, authCtx, pluginRegistry))
7273
}
7374

7475
// Look for access token and refresh token, if both are present and the access token is expired, then attempt to
@@ -123,7 +124,7 @@ func RefreshTokensIfExists(ctx context.Context, authCtx interfaces.Authenticatio
123124
return
124125
}
125126

126-
redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request)
127+
redirectURL := GetAuthFlowEndRedirect(ctx, authCtx, request)
127128
http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect)
128129
}
129130
}
@@ -210,7 +211,7 @@ func GetCallbackHandler(ctx context.Context, authCtx interfaces.AuthenticationCo
210211
}
211212
logger.Info(ctx, "Successfully called the preRedirect hook")
212213
}
213-
redirectURL := getAuthFlowEndRedirect(ctx, authCtx, request)
214+
redirectURL := GetAuthFlowEndRedirect(ctx, authCtx, request)
214215
http.Redirect(writer, request, redirectURL, http.StatusTemporaryRedirect)
215216
}
216217
}
@@ -466,9 +467,19 @@ func GetOIdCMetadataEndpointRedirectHandler(ctx context.Context, authCtx interfa
466467
}
467468
}
468469

469-
func GetLogoutEndpointHandler(ctx context.Context, authCtx interfaces.AuthenticationContext) http.HandlerFunc {
470+
func GetLogoutEndpointHandler(ctx context.Context, authCtx interfaces.AuthenticationContext, pluginRegistry *plugins.Registry) http.HandlerFunc {
470471
return func(writer http.ResponseWriter, request *http.Request) {
471-
logger.Debugf(ctx, "Deleting auth cookies")
472+
hook := plugins.Get[LogoutHookFunc](pluginRegistry, plugins.PluginIDLogoutHook)
473+
if hook != nil {
474+
if err := hook(ctx, authCtx, request, writer); err != nil {
475+
logger.Errorf(ctx, "logout hook failed: %v", err)
476+
writer.WriteHeader(http.StatusInternalServerError)
477+
return
478+
}
479+
logger.Debugf(ctx, "logout hook called")
480+
}
481+
482+
logger.Debugf(ctx, "deleting auth cookies")
472483
authCtx.CookieManager().DeleteCookies(ctx, writer)
473484

474485
// Redirect if one was given

auth/handlers_test.go

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package auth
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"net/http"
@@ -11,8 +12,11 @@ import (
1112
"testing"
1213

1314
"github.com/coreos/go-oidc"
15+
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
16+
stdConfig "github.com/flyteorg/flytestdlib/config"
1417
"github.com/stretchr/testify/assert"
1518
"github.com/stretchr/testify/mock"
19+
"github.com/stretchr/testify/require"
1620
"golang.org/x/oauth2"
1721
"google.golang.org/protobuf/types/known/structpb"
1822

@@ -21,8 +25,6 @@ import (
2125
"github.com/flyteorg/flyteadmin/auth/interfaces/mocks"
2226
"github.com/flyteorg/flyteadmin/pkg/common"
2327
"github.com/flyteorg/flyteadmin/plugins"
24-
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
25-
stdConfig "github.com/flyteorg/flytestdlib/config"
2628
)
2729

2830
const (
@@ -50,8 +52,8 @@ func setupMockedAuthContextAtEndpoint(endpoint string) *mocks.AuthenticationCont
5052
Timeout: IdpConnectionTimeout,
5153
}
5254
mockAuthCtx.OnCookieManagerMatch().Return(mockCookieHandler)
53-
mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
54-
mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
55+
mockCookieHandler.OnSetTokenCookiesMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
56+
mockCookieHandler.OnSetUserInfoCookieMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
5557
mockAuthCtx.OnOAuth2ClientConfigMatch(mock.Anything).Return(&dummyOAuth2Config)
5658
mockAuthCtx.OnGetHTTPClient().Return(dummyHTTPClient)
5759
return mockAuthCtx
@@ -255,6 +257,97 @@ func TestGetLoginHandler(t *testing.T) {
255257
assert.True(t, strings.Contains(w.Header().Get("Set-Cookie"), "flyte_csrf_state="))
256258
}
257259

260+
func TestGetLogoutHandler(t *testing.T) {
261+
ctx := context.Background()
262+
263+
t.Run("no_hook_no_redirect", func(t *testing.T) {
264+
cookieHandler := &CookieManager{}
265+
authCtx := mocks.AuthenticationContext{}
266+
authCtx.OnCookieManager().Return(cookieHandler).Once()
267+
w := httptest.NewRecorder()
268+
r := plugins.NewRegistry()
269+
req, err := http.NewRequest(http.MethodGet, "/logout", nil)
270+
require.NoError(t, err)
271+
272+
GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req)
273+
274+
assert.Equal(t, http.StatusOK, w.Code)
275+
require.Len(t, w.Result().Cookies(), 3)
276+
authCtx.AssertExpectations(t)
277+
})
278+
279+
t.Run("no_hook_with_redirect", func(t *testing.T) {
280+
ctx := context.Background()
281+
cookieHandler := &CookieManager{}
282+
authCtx := mocks.AuthenticationContext{}
283+
authCtx.OnCookieManager().Return(cookieHandler).Once()
284+
w := httptest.NewRecorder()
285+
r := plugins.NewRegistry()
286+
req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil)
287+
require.NoError(t, err)
288+
289+
GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req)
290+
291+
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
292+
authCtx.AssertExpectations(t)
293+
require.Len(t, w.Result().Cookies(), 3)
294+
})
295+
296+
t.Run("with_hook_with_redirect", func(t *testing.T) {
297+
ctx := context.Background()
298+
cookieHandler := &CookieManager{}
299+
authCtx := mocks.AuthenticationContext{}
300+
authCtx.OnCookieManager().Return(cookieHandler).Once()
301+
w := httptest.NewRecorder()
302+
r := plugins.NewRegistry()
303+
hook := new(mock.Mock)
304+
err := r.Register(plugins.PluginIDLogoutHook, LogoutHookFunc(func(
305+
ctx context.Context,
306+
authCtx interfaces.AuthenticationContext,
307+
request *http.Request,
308+
w http.ResponseWriter) error {
309+
return hook.MethodCalled("hook").Error(0)
310+
}))
311+
hook.On("hook").Return(nil).Once()
312+
require.NoError(t, err)
313+
req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil)
314+
require.NoError(t, err)
315+
316+
GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req)
317+
318+
assert.Equal(t, http.StatusTemporaryRedirect, w.Code)
319+
require.Len(t, w.Result().Cookies(), 3)
320+
authCtx.AssertExpectations(t)
321+
hook.AssertExpectations(t)
322+
})
323+
324+
t.Run("hook_error", func(t *testing.T) {
325+
ctx := context.Background()
326+
authCtx := mocks.AuthenticationContext{}
327+
w := httptest.NewRecorder()
328+
r := plugins.NewRegistry()
329+
hook := new(mock.Mock)
330+
err := r.Register(plugins.PluginIDLogoutHook, LogoutHookFunc(func(
331+
ctx context.Context,
332+
authCtx interfaces.AuthenticationContext,
333+
request *http.Request,
334+
w http.ResponseWriter) error {
335+
return hook.MethodCalled("hook").Error(0)
336+
}))
337+
hook.On("hook").Return(errors.New("fail")).Once()
338+
require.NoError(t, err)
339+
req, err := http.NewRequest(http.MethodGet, "/logout?redirect_url=/foo", nil)
340+
require.NoError(t, err)
341+
342+
GetLogoutEndpointHandler(ctx, &authCtx, r)(w, req)
343+
344+
assert.Equal(t, http.StatusInternalServerError, w.Code)
345+
assert.Empty(t, w.Result().Cookies())
346+
authCtx.AssertExpectations(t)
347+
hook.AssertExpectations(t)
348+
})
349+
}
350+
258351
func TestGetHTTPRequestCookieToMetadataHandler(t *testing.T) {
259352
ctx := context.Background()
260353
// These were generated for unit testing only.

plugins/registry.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const (
1313
PluginIDDataProxy PluginID = "DataProxy"
1414
PluginIDUnaryServiceMiddleware PluginID = "UnaryServiceMiddleware"
1515
PluginIDPreRedirectHook PluginID = "PreRedirectHook"
16+
PluginIDLogoutHook PluginID = "LogoutHook"
1617
)
1718

1819
type AtomicRegistry struct {

plugins/registry_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,26 @@ func TestRedirectHook(t *testing.T) {
4141
assert.Equal(t, fmt.Errorf("redirect hook error"), err)
4242
}
4343

44+
type LogoutHook func(context.Context) error
45+
46+
func TestLogoutHook(t *testing.T) {
47+
ar := NewAtomicRegistry(nil)
48+
r := NewRegistry()
49+
50+
hook := LogoutHook(func(ctx context.Context) error {
51+
return fmt.Errorf("redirect hook error")
52+
})
53+
err := r.Register(PluginIDLogoutHook, hook)
54+
assert.NoError(t, err)
55+
56+
ar.Store(r)
57+
r = ar.Load()
58+
fn := Get[LogoutHook](r, PluginIDLogoutHook)
59+
err = fn(context.Background())
60+
61+
assert.Equal(t, fmt.Errorf("redirect hook error"), err)
62+
}
63+
4464
func TestRegistry_RegisterDefault(t *testing.T) {
4565
r := NewRegistry()
4666
r.RegisterDefault("hello", 5)

0 commit comments

Comments
 (0)