Skip to content

Commit

Permalink
implement temporary credentials validation endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
mk6i committed May 27, 2024
1 parent 3a41734 commit a372557
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
60 changes: 60 additions & 0 deletions server/http/mgmt_api.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
package http

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"os"
"strings"

"github.com/google/uuid"

"github.com/mk6i/retro-aim-server/config"
"github.com/mk6i/retro-aim-server/state"
"github.com/mk6i/retro-aim-server/wire"
)

type userWithPassword struct {
Expand All @@ -33,6 +36,7 @@ type UserManager interface {
AllUsers() ([]state.User, error)
InsertUser(u state.User) error
SetUserPassword(u state.User) error
User(screenName string) (*state.User, error)
}

type SessionRetriever interface {
Expand All @@ -50,6 +54,9 @@ func StartManagementAPI(cfg config.Config, userManager UserManager, sessionRetri
mux.HandleFunc("/user/password", func(w http.ResponseWriter, r *http.Request) {
userPasswordHandler(w, r, userManager, newUser, logger)
})
mux.HandleFunc("/user/login", func(w http.ResponseWriter, r *http.Request) {
loginHandler(w, r, userManager, logger)
})
mux.HandleFunc("/session", func(w http.ResponseWriter, r *http.Request) {
sessionHandler(w, r, sessionRetriever)
})
Expand Down Expand Up @@ -207,3 +214,56 @@ func postUserHandler(
w.WriteHeader(http.StatusCreated)
fmt.Fprintln(w, "User account created successfully.")
}

// loginHandler is a temporary endpoint for validating user credentials for
// chivanet. do not rely on this endpoint, as it will be eventually removed.
func loginHandler(w http.ResponseWriter, r *http.Request, userManager UserManager, logger *slog.Logger) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
// No authentication header found
w.WriteHeader(http.StatusUnauthorized)
w.Header().Set("WWW-Authenticate", `Basic realm="User Login"`)
w.Write([]byte("401 Unauthorized\n"))
return
}

auth := strings.SplitN(authHeader, " ", 2)
if len(auth) != 2 || auth[0] != "Basic" {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 Unauthorized: Missing Basic prefix\n"))
return
}

payload, err := base64.StdEncoding.DecodeString(auth[1])
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 Unauthorized: Invalid Base64 Encoding\n"))
return
}

pair := strings.SplitN(string(payload), ":", 2)
if len(pair) != 2 {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 Unauthorized: Invalid Authentication Token\n"))
return
}

username, password := pair[0], pair[1]

user, err := userManager.User(username)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("500 InternalServerError\n"))
logger.Error("error getting user", "err", err.Error())
return
}
if user == nil || !user.ValidateHash(wire.StrongMD5PasswordHash(password, user.AuthKey)) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("401 Unauthorized: Invalid Credentials\n"))
return
}

// Successfully authenticated
w.WriteHeader(http.StatusOK)
w.Write([]byte("200 OK: Successfully Authenticated\n"))
}
21 changes: 16 additions & 5 deletions server/oscar/middleware/logger.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"bytes"
"context"
"log/slog"
"os"
Expand Down Expand Up @@ -83,8 +84,8 @@ func (rt RouteLogger) LogRequestAndResponse(ctx context.Context, inFrame wire.SN
msg := "client request -> server response"
switch {
case rt.Logger.Enabled(ctx, LevelTrace):
rt.Logger.LogAttrs(ctx, LevelTrace, msg, snacLogGroupWithPayload("request", inFrame, inSNAC),
snacLogGroupWithPayload("response", outFrame, outSNAC))
rt.Logger.LogAttrs(ctx, LevelTrace, msg, snacLogGroupWithPayload(rt.Logger, "request", inFrame, inSNAC),
snacLogGroupWithPayload(rt.Logger, "response", outFrame, outSNAC))
case rt.Logger.Enabled(ctx, slog.LevelDebug):
rt.Logger.LogAttrs(ctx, slog.LevelDebug, msg, snacLogGroup("request", inFrame),
snacLogGroup("response", outFrame))
Expand Down Expand Up @@ -113,7 +114,7 @@ func LogRequest(ctx context.Context, logger *slog.Logger, inFrame wire.SNACFrame
const msg = "client request"
switch {
case logger.Enabled(ctx, LevelTrace):
logger.LogAttrs(ctx, LevelTrace, msg, snacLogGroupWithPayload("request", inFrame, inSNAC))
logger.LogAttrs(ctx, LevelTrace, msg, snacLogGroupWithPayload(logger, "request", inFrame, inSNAC))
case logger.Enabled(ctx, slog.LevelDebug):
logger.LogAttrs(ctx, slog.LevelDebug, msg, snacLogGroup("request", inFrame))
}
Expand All @@ -126,11 +127,21 @@ func snacLogGroup(key string, outFrame wire.SNACFrame) slog.Attr {
)
}

func snacLogGroupWithPayload(key string, outFrame wire.SNACFrame, outSNAC any) slog.Attr {
func snacLogGroupWithPayload(logger *slog.Logger, key string, outFrame wire.SNACFrame, outSNAC any) slog.Attr {
frameBuf := &bytes.Buffer{}
if err := wire.Marshal(outFrame, frameBuf); err != nil {
logger.Error("unable to marshal SNAC frame in logger", "err", err.Error())
}
snacBuf := &bytes.Buffer{}
if outSNAC != nil {
if err := wire.Marshal(outSNAC, snacBuf); err != nil {
logger.Error("unable to marshal SNAC body in logger", "err", err.Error())
}
}
return slog.Group(key,
slog.String("food_group", wire.FoodGroupName(outFrame.FoodGroup)),
slog.String("sub_group", wire.SubGroupName(outFrame.FoodGroup, outFrame.SubGroup)),
slog.Any("snac_frame", outFrame),
slog.Any("snac_frame", frameBuf),
slog.Any("snac_payload", outSNAC),
)
}

0 comments on commit a372557

Please sign in to comment.