@@ -14,6 +14,7 @@ import (
1414 "github.com/go-openapi/runtime/middleware"
1515 "golang.org/x/exp/slices"
1616
17+ "github.com/armadaproject/armada/internal/common/auth"
1718 "github.com/armadaproject/armada/internal/common/serve"
1819 "github.com/armadaproject/armada/internal/lookout/configuration"
1920 "github.com/armadaproject/armada/internal/lookout/gen/restapi/operations"
@@ -28,6 +29,12 @@ func SetCorsAllowedOrigins(allowedOrigins []string) {
2829 corsAllowedOrigins = allowedOrigins
2930}
3031
32+ var authService auth.AuthService
33+
34+ func SetAuthService (s auth.AuthService ) {
35+ authService = s
36+ }
37+
3138func configureFlags (api * operations.LookoutAPI ) {
3239 // api.CommandLineOptionsGroups = []swag.CommandLineOptionsGroup{ ... }
3340}
@@ -91,7 +98,33 @@ var UIConfig configuration.UIConfig
9198// The middleware configuration happens before anything, this middleware also applies to serving the swagger.json document.
9299// So this is a good place to plug in a panic handling middleware, logging and metrics.
93100func setupGlobalMiddleware (apiHandler http.Handler ) http.Handler {
94- return recordRequestDuration (allowCORS (uiHandler (apiHandler ), corsAllowedOrigins ))
101+ return allowCORS (
102+ uiHandler (
103+ authHandler (
104+ recordRequestDuration (
105+ apiHandler ,
106+ ),
107+ ),
108+ ), corsAllowedOrigins )
109+ }
110+
111+ func authHandler (handler http.Handler ) http.Handler {
112+ mux := http .NewServeMux ()
113+
114+ // do not authenticate requests to healthchecker endpoint
115+ mux .Handle ("/health" , handler )
116+
117+ authFunction := auth .CreateHttpMiddlewareAuthFunction (authService )
118+ mux .Handle ("/api/" , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
119+ ctxWithPrincipal , err := authFunction (w , r )
120+ if err != nil {
121+ return
122+ }
123+
124+ handler .ServeHTTP (w , r .WithContext (ctxWithPrincipal ))
125+ }))
126+
127+ return mux
95128}
96129
97130func uiHandler (apiHandler http.Handler ) http.Handler {
@@ -128,13 +161,11 @@ func allowCORS(handler http.Handler, corsAllowedOrigins []string) http.Handler {
128161
129162func recordRequestDuration (handler http.Handler ) http.Handler {
130163 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
131- // TODO: for autheticated users, record the username
132- const unknownUser = "unknown"
133164 start := time .Now ()
134165 handler .ServeHTTP (w , r )
135166 duration := time .Since (start )
136167 if strings .HasPrefix (r .URL .Path , "/api/v1/" ) {
137- metrics .RecordRequestDuration (unknownUser , r .URL .Path , float64 (duration .Milliseconds ()))
168+ metrics .RecordRequestDuration (auth . GetPrincipal ( r . Context ()). GetName () , r .URL .Path , float64 (duration .Milliseconds ()))
138169 }
139170 })
140171}
0 commit comments