From a372557260d57ef4bac29d43d9bb2a981d148f3b Mon Sep 17 00:00:00 2001 From: Mike Date: Mon, 27 May 2024 14:10:17 -0400 Subject: [PATCH] implement temporary credentials validation endpoint --- server/http/mgmt_api.go | 60 +++++++++++++++++++++++++++++++ server/oscar/middleware/logger.go | 21 ++++++++--- 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/server/http/mgmt_api.go b/server/http/mgmt_api.go index 0d2c762c..786b5d2c 100644 --- a/server/http/mgmt_api.go +++ b/server/http/mgmt_api.go @@ -1,6 +1,7 @@ package http import ( + "encoding/base64" "encoding/json" "errors" "fmt" @@ -8,11 +9,13 @@ import ( "net" "net/http" "os" + "strings" "github.com/google/uuid" "github.com/mk6i/retro-aim-server/config" "github.com/mk6i/retro-aim-server/state" + "github.com/mk6i/retro-aim-server/wire" ) type userWithPassword struct { @@ -33,6 +36,7 @@ type UserManager interface { AllUsers() ([]state.User, error) InsertUser(u state.User) error SetUserPassword(u state.User) error + User(screenName string) (*state.User, error) } type SessionRetriever interface { @@ -50,6 +54,9 @@ func StartManagementAPI(cfg config.Config, userManager UserManager, sessionRetri mux.HandleFunc("/user/password", func(w http.ResponseWriter, r *http.Request) { userPasswordHandler(w, r, userManager, newUser, logger) }) + mux.HandleFunc("/user/login", func(w http.ResponseWriter, r *http.Request) { + loginHandler(w, r, userManager, logger) + }) mux.HandleFunc("/session", func(w http.ResponseWriter, r *http.Request) { sessionHandler(w, r, sessionRetriever) }) @@ -207,3 +214,56 @@ func postUserHandler( w.WriteHeader(http.StatusCreated) fmt.Fprintln(w, "User account created successfully.") } + +// loginHandler is a temporary endpoint for validating user credentials for +// chivanet. do not rely on this endpoint, as it will be eventually removed. +func loginHandler(w http.ResponseWriter, r *http.Request, userManager UserManager, logger *slog.Logger) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + // No authentication header found + w.WriteHeader(http.StatusUnauthorized) + w.Header().Set("WWW-Authenticate", `Basic realm="User Login"`) + w.Write([]byte("401 Unauthorized\n")) + return + } + + auth := strings.SplitN(authHeader, " ", 2) + if len(auth) != 2 || auth[0] != "Basic" { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("401 Unauthorized: Missing Basic prefix\n")) + return + } + + payload, err := base64.StdEncoding.DecodeString(auth[1]) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("401 Unauthorized: Invalid Base64 Encoding\n")) + return + } + + pair := strings.SplitN(string(payload), ":", 2) + if len(pair) != 2 { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("401 Unauthorized: Invalid Authentication Token\n")) + return + } + + username, password := pair[0], pair[1] + + user, err := userManager.User(username) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("500 InternalServerError\n")) + logger.Error("error getting user", "err", err.Error()) + return + } + if user == nil || !user.ValidateHash(wire.StrongMD5PasswordHash(password, user.AuthKey)) { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("401 Unauthorized: Invalid Credentials\n")) + return + } + + // Successfully authenticated + w.WriteHeader(http.StatusOK) + w.Write([]byte("200 OK: Successfully Authenticated\n")) +} diff --git a/server/oscar/middleware/logger.go b/server/oscar/middleware/logger.go index fb10ec3e..a0a82d62 100644 --- a/server/oscar/middleware/logger.go +++ b/server/oscar/middleware/logger.go @@ -1,6 +1,7 @@ package middleware import ( + "bytes" "context" "log/slog" "os" @@ -83,8 +84,8 @@ func (rt RouteLogger) LogRequestAndResponse(ctx context.Context, inFrame wire.SN msg := "client request -> server response" switch { case rt.Logger.Enabled(ctx, LevelTrace): - rt.Logger.LogAttrs(ctx, LevelTrace, msg, snacLogGroupWithPayload("request", inFrame, inSNAC), - snacLogGroupWithPayload("response", outFrame, outSNAC)) + rt.Logger.LogAttrs(ctx, LevelTrace, msg, snacLogGroupWithPayload(rt.Logger, "request", inFrame, inSNAC), + snacLogGroupWithPayload(rt.Logger, "response", outFrame, outSNAC)) case rt.Logger.Enabled(ctx, slog.LevelDebug): rt.Logger.LogAttrs(ctx, slog.LevelDebug, msg, snacLogGroup("request", inFrame), snacLogGroup("response", outFrame)) @@ -113,7 +114,7 @@ func LogRequest(ctx context.Context, logger *slog.Logger, inFrame wire.SNACFrame const msg = "client request" switch { case logger.Enabled(ctx, LevelTrace): - logger.LogAttrs(ctx, LevelTrace, msg, snacLogGroupWithPayload("request", inFrame, inSNAC)) + logger.LogAttrs(ctx, LevelTrace, msg, snacLogGroupWithPayload(logger, "request", inFrame, inSNAC)) case logger.Enabled(ctx, slog.LevelDebug): logger.LogAttrs(ctx, slog.LevelDebug, msg, snacLogGroup("request", inFrame)) } @@ -126,11 +127,21 @@ func snacLogGroup(key string, outFrame wire.SNACFrame) slog.Attr { ) } -func snacLogGroupWithPayload(key string, outFrame wire.SNACFrame, outSNAC any) slog.Attr { +func snacLogGroupWithPayload(logger *slog.Logger, key string, outFrame wire.SNACFrame, outSNAC any) slog.Attr { + frameBuf := &bytes.Buffer{} + if err := wire.Marshal(outFrame, frameBuf); err != nil { + logger.Error("unable to marshal SNAC frame in logger", "err", err.Error()) + } + snacBuf := &bytes.Buffer{} + if outSNAC != nil { + if err := wire.Marshal(outSNAC, snacBuf); err != nil { + logger.Error("unable to marshal SNAC body in logger", "err", err.Error()) + } + } return slog.Group(key, slog.String("food_group", wire.FoodGroupName(outFrame.FoodGroup)), slog.String("sub_group", wire.SubGroupName(outFrame.FoodGroup, outFrame.SubGroup)), - slog.Any("snac_frame", outFrame), + slog.Any("snac_frame", frameBuf), slog.Any("snac_payload", outSNAC), ) }