diff --git a/docker-compose-tests.yml b/docker-compose-tests.yml index eca2cb8ce..4227bd3df 100644 --- a/docker-compose-tests.yml +++ b/docker-compose-tests.yml @@ -41,7 +41,7 @@ services: interval: 5s test: - image: "golang:1.21" + image: "golang:1.23.2" command: /bin/sh -c "mkdir -p /nakama/internal/gopher-lua/_lua5.1-tests/libs/P1; go test -v -race ./..." working_dir: "/nakama" diff --git a/internal/ctxkeys/ctxkeys.go b/internal/ctxkeys/ctxkeys.go new file mode 100644 index 000000000..a7d28e170 --- /dev/null +++ b/internal/ctxkeys/ctxkeys.go @@ -0,0 +1,9 @@ +package ctxkeys + +// Keys used for storing/retrieving user information in the context of a request after authentication. +type UserIDKey struct{} +type UsernameKey struct{} +type VarsKey struct{} +type ExpiryKey struct{} +type TokenIDKey struct{} +type TokenIssuedAtKey struct{} diff --git a/internal/satori/satori.go b/internal/satori/satori.go index 5c8d39307..2137835b3 100644 --- a/internal/satori/satori.go +++ b/internal/satori/satori.go @@ -30,37 +30,38 @@ import ( "github.com/golang-jwt/jwt/v4" "github.com/heroiclabs/nakama-common/runtime" + "github.com/heroiclabs/nakama/v3/internal/ctxkeys" "go.uber.org/zap" ) var _ runtime.Satori = &SatoriClient{} -type CtxTokenIDKey struct{} - type SatoriClient struct { - logger *zap.Logger - httpc *http.Client - url *url.URL - urlString string - apiKeyName string - apiKey string - signingKey string - tokenExpirySec int - invalidConfig bool + logger *zap.Logger + httpc *http.Client + url *url.URL + urlString string + apiKeyName string + apiKey string + signingKey string + tokenExpirySec int + nakamaTokenExpirySec int64 + invalidConfig bool } -func NewSatoriClient(logger *zap.Logger, satoriUrl, apiKeyName, apiKey, signingKey string) *SatoriClient { +func NewSatoriClient(logger *zap.Logger, satoriUrl, apiKeyName, apiKey, signingKey string, nakamaTokenExpirySec int64) *SatoriClient { parsedUrl, _ := url.Parse(satoriUrl) sc := &SatoriClient{ - logger: logger, - urlString: satoriUrl, - httpc: &http.Client{Timeout: 2 * time.Second}, - url: parsedUrl, - apiKeyName: strings.TrimSpace(apiKeyName), - apiKey: strings.TrimSpace(apiKey), - signingKey: strings.TrimSpace(signingKey), - tokenExpirySec: 3600, + logger: logger, + urlString: satoriUrl, + httpc: &http.Client{Timeout: 2 * time.Second}, + url: parsedUrl, + apiKeyName: strings.TrimSpace(apiKeyName), + apiKey: strings.TrimSpace(apiKey), + signingKey: strings.TrimSpace(signingKey), + tokenExpirySec: 3600, + nakamaTokenExpirySec: nakamaTokenExpirySec, } if sc.urlString == "" && sc.apiKeyName == "" && sc.apiKey == "" && sc.signingKey == "" { @@ -121,13 +122,34 @@ func (stc *sessionTokenClaims) Valid() error { } func (s *SatoriClient) generateToken(ctx context.Context, id string) (string, error) { - tid, _ := ctx.Value(CtxTokenIDKey{}).(string) + tid, ok := ctx.Value(ctxkeys.TokenIDKey{}).(string) + if !ok { + s.logger.Warn("satori request token id was not found in ctx") + } + tIssuedAt, ok := ctx.Value(ctxkeys.TokenIssuedAtKey{}).(int64) + if !ok { + s.logger.Warn("satori request token issued at was not found in ctx") + } + tExpirySec, ok := ctx.Value(ctxkeys.ExpiryKey{}).(int64) + if !ok { + s.logger.Warn("satori request token expires at was not found in ctx") + } + timestamp := time.Now().UTC() + if tIssuedAt == 0 && tExpirySec > s.nakamaTokenExpirySec { + // Token was issued before 'IssuedAt' had been added to the session token. + // Thus, Nakama will make a guess of that value. + tIssuedAt = tExpirySec - s.nakamaTokenExpirySec + } else if tIssuedAt == 0 { + // Unable to determine the token's issued at. + tIssuedAt = timestamp.Unix() + } + claims := sessionTokenClaims{ SessionID: tid, IdentityId: id, ExpiresAt: timestamp.Add(time.Duration(s.tokenExpirySec) * time.Second).Unix(), - IssuedAt: timestamp.Unix(), + IssuedAt: tIssuedAt, ApiKeyName: s.apiKeyName, } token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &claims).SignedString([]byte(s.signingKey)) @@ -351,6 +373,10 @@ func (s *SatoriClient) EventsPublish(ctx context.Context, id string, events []*r case 200: return nil default: + errBody, err := io.ReadAll(res.Body) + if err == nil && len(errBody) > 0 { + return fmt.Errorf("%d status code: %s", res.StatusCode, string(errBody)) + } return fmt.Errorf("%d status code", res.StatusCode) } } @@ -395,13 +421,13 @@ func (s *SatoriClient) ExperimentsList(ctx context.Context, id string, names ... defer res.Body.Close() + resBody, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + switch res.StatusCode { case 200: - resBody, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } - var experiments runtime.ExperimentList if err = json.Unmarshal(resBody, &experiments); err != nil { return nil, err @@ -409,6 +435,10 @@ func (s *SatoriClient) ExperimentsList(ctx context.Context, id string, names ... return &experiments, nil default: + if len(resBody) > 0 { + return nil, fmt.Errorf("%d status code: %s", res.StatusCode, string(resBody)) + } + return nil, fmt.Errorf("%d status code", res.StatusCode) } } @@ -452,14 +482,13 @@ func (s *SatoriClient) FlagsList(ctx context.Context, id string, names ...string } defer res.Body.Close() + resBody, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } switch res.StatusCode { case 200: - resBody, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } - var flags runtime.FlagList if err = json.Unmarshal(resBody, &flags); err != nil { return nil, err @@ -467,6 +496,10 @@ func (s *SatoriClient) FlagsList(ctx context.Context, id string, names ...string return &flags, nil default: + if len(resBody) > 0 { + return nil, fmt.Errorf("%d status code: %s", res.StatusCode, string(resBody)) + } + return nil, fmt.Errorf("%d status code", res.StatusCode) } } @@ -510,13 +543,13 @@ func (s *SatoriClient) LiveEventsList(ctx context.Context, id string, names ...s } defer res.Body.Close() + resBody, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } switch res.StatusCode { case 200: - resBody, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } var liveEvents runtime.LiveEventList if err = json.Unmarshal(resBody, &liveEvents); err != nil { return nil, err @@ -524,6 +557,9 @@ func (s *SatoriClient) LiveEventsList(ctx context.Context, id string, names ...s return &liveEvents, nil default: + if len(resBody) > 0 { + return nil, fmt.Errorf("%d status code: %s", res.StatusCode, string(resBody)) + } return nil, fmt.Errorf("%d status code", res.StatusCode) } } @@ -572,13 +608,13 @@ func (s *SatoriClient) MessagesList(ctx context.Context, id string, limit int, f } defer res.Body.Close() + resBody, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } switch res.StatusCode { case 200: - resBody, err := io.ReadAll(res.Body) - if err != nil { - return nil, err - } var messages runtime.MessageList if err = json.Unmarshal(resBody, &messages); err != nil { return nil, err @@ -586,6 +622,9 @@ func (s *SatoriClient) MessagesList(ctx context.Context, id string, limit int, f return &messages, nil default: + if len(resBody) > 0 { + return nil, fmt.Errorf("%d status code: %s", res.StatusCode, string(resBody)) + } return nil, fmt.Errorf("%d status code", res.StatusCode) } } @@ -635,6 +674,10 @@ func (s *SatoriClient) MessageUpdate(ctx context.Context, id, messageId string, case 200: return nil default: + errBody, err := io.ReadAll(res.Body) + if err == nil && len(errBody) > 0 { + return fmt.Errorf("%d status code: %s", res.StatusCode, string(errBody)) + } return fmt.Errorf("%d status code", res.StatusCode) } } @@ -678,6 +721,10 @@ func (s *SatoriClient) MessageDelete(ctx context.Context, id, messageId string) case 200: return nil default: + errBody, err := io.ReadAll(res.Body) + if err == nil && len(errBody) > 0 { + return fmt.Errorf("%d status code: %s", res.StatusCode, string(errBody)) + } return fmt.Errorf("%d status code", res.StatusCode) } } diff --git a/internal/satori/satori_test.go b/internal/satori/satori_test.go index 128eb49de..436e4daf8 100644 --- a/internal/satori/satori_test.go +++ b/internal/satori/satori_test.go @@ -32,7 +32,7 @@ func TestSatoriClient_EventsPublish(t *testing.T) { identityID := uuid.Must(uuid.NewV4()).String() logger := NewConsoleLogger(os.Stdout, true) - client := NewSatoriClient(logger, "", "", "", "") + client := NewSatoriClient(logger, "", "", "", "", 0) ctx, ctxCancelFn := context.WithTimeout(context.Background(), 5*time.Second) defer ctxCancelFn() diff --git a/server/api.go b/server/api.go index 6e93b1931..1f0957fa2 100644 --- a/server/api.go +++ b/server/api.go @@ -39,7 +39,7 @@ import ( grpcgw "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/heroiclabs/nakama-common/api" "github.com/heroiclabs/nakama/v3/apigrpc" - "github.com/heroiclabs/nakama/v3/internal/satori" + "github.com/heroiclabs/nakama/v3/internal/ctxkeys" "github.com/heroiclabs/nakama/v3/social" "go.uber.org/zap" "google.golang.org/grpc" @@ -61,11 +61,12 @@ var once sync.Once const byteBracket byte = '{' // Keys used for storing/retrieving user information in the context of a request after authentication. -type ctxUserIDKey struct{} -type ctxUsernameKey struct{} -type ctxVarsKey struct{} -type ctxExpiryKey struct{} -type ctxTokenIDKey = satori.CtxTokenIDKey +type ctxUserIDKey = ctxkeys.UserIDKey +type ctxUsernameKey = ctxkeys.UsernameKey +type ctxVarsKey = ctxkeys.VarsKey +type ctxExpiryKey = ctxkeys.ExpiryKey +type ctxTokenIDKey = ctxkeys.TokenIDKey +type ctxTokenIssuedAtKey = ctxkeys.TokenIssuedAtKey type ctxFullMethodKey struct{} @@ -430,7 +431,7 @@ func securityInterceptorFunc(logger *zap.Logger, config Config, sessionCache Ses // Value of "authorization" or "grpc-authorization" was empty or repeated. return nil, status.Error(codes.Unauthenticated, "Auth token invalid") } - userID, username, vars, exp, tokenId, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0]) + userID, username, vars, exp, tokenId, tokenIssuedAt, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0]) if !ok { // Value of "authorization" or "grpc-authorization" was malformed or expired. return nil, status.Error(codes.Unauthenticated, "Auth token invalid") @@ -438,7 +439,7 @@ func securityInterceptorFunc(logger *zap.Logger, config Config, sessionCache Ses if !sessionCache.IsValidSession(userID, exp, tokenId) { return nil, status.Error(codes.Unauthenticated, "Auth token invalid") } - ctx = context.WithValue(context.WithValue(context.WithValue(context.WithValue(context.WithValue(ctx, ctxUserIDKey{}, userID), ctxUsernameKey{}, username), ctxVarsKey{}, vars), ctxExpiryKey{}, exp), ctxTokenIDKey{}, tokenId) + ctx = context.WithValue(context.WithValue(context.WithValue(context.WithValue(context.WithValue(context.WithValue(ctx, ctxUserIDKey{}, userID), ctxUsernameKey{}, username), ctxVarsKey{}, vars), ctxExpiryKey{}, exp), ctxTokenIDKey{}, tokenId), ctxTokenIssuedAtKey{}, tokenIssuedAt) default: // Unless explicitly defined above, handlers require full user authentication. md, ok := metadata.FromIncomingContext(ctx) @@ -458,7 +459,7 @@ func securityInterceptorFunc(logger *zap.Logger, config Config, sessionCache Ses // Value of "authorization" or "grpc-authorization" was empty or repeated. return nil, status.Error(codes.Unauthenticated, "Auth token invalid") } - userID, username, vars, exp, tokenId, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0]) + userID, username, vars, exp, tokenId, tokenIssuedAt, ok := parseBearerAuth([]byte(config.GetSession().EncryptionKey), auth[0]) if !ok { // Value of "authorization" or "grpc-authorization" was malformed or expired. return nil, status.Error(codes.Unauthenticated, "Auth token invalid") @@ -466,7 +467,7 @@ func securityInterceptorFunc(logger *zap.Logger, config Config, sessionCache Ses if !sessionCache.IsValidSession(userID, exp, tokenId) { return nil, status.Error(codes.Unauthenticated, "Auth token invalid") } - ctx = context.WithValue(context.WithValue(context.WithValue(context.WithValue(context.WithValue(ctx, ctxUserIDKey{}, userID), ctxUsernameKey{}, username), ctxVarsKey{}, vars), ctxExpiryKey{}, exp), ctxTokenIDKey{}, tokenId) + ctx = context.WithValue(context.WithValue(context.WithValue(context.WithValue(context.WithValue(context.WithValue(ctx, ctxUserIDKey{}, userID), ctxUsernameKey{}, username), ctxVarsKey{}, vars), ctxExpiryKey{}, exp), ctxTokenIDKey{}, tokenId), ctxTokenIssuedAtKey{}, tokenIssuedAt) } return context.WithValue(ctx, ctxFullMethodKey{}, info.FullMethod), nil } @@ -491,7 +492,7 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { return cs[:s], cs[s+1:], true } -func parseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, username string, vars map[string]string, exp int64, tokenId string, ok bool) { +func parseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, username string, vars map[string]string, exp int64, tokenId string, issuedAt int64, ok bool) { if auth == "" { return } @@ -502,7 +503,7 @@ func parseBearerAuth(hmacSecretByte []byte, auth string) (userID uuid.UUID, user return parseToken(hmacSecretByte, auth[len(prefix):]) } -func parseToken(hmacSecretByte []byte, tokenString string) (userID uuid.UUID, username string, vars map[string]string, exp int64, tokenId string, ok bool) { +func parseToken(hmacSecretByte []byte, tokenString string) (userID uuid.UUID, username string, vars map[string]string, exp int64, tokenId string, issuedAt int64, ok bool) { jwtToken, err := jwt.ParseWithClaims(tokenString, &SessionTokenClaims{}, func(token *jwt.Token) (interface{}, error) { if s, ok := token.Method.(*jwt.SigningMethodHMAC); !ok || s.Hash != crypto.SHA256 { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) @@ -520,7 +521,7 @@ func parseToken(hmacSecretByte []byte, tokenString string) (userID uuid.UUID, us if err != nil { return } - return userID, claims.Username, claims.Vars, claims.ExpiresAt, claims.TokenId, true + return userID, claims.Username, claims.Vars, claims.ExpiresAt, claims.TokenId, claims.IssuedAt, true } func decompressHandler(logger *zap.Logger, h http.Handler) http.HandlerFunc { diff --git a/server/api_authenticate.go b/server/api_authenticate.go index a3d9523c3..09957c2ab 100644 --- a/server/api_authenticate.go +++ b/server/api_authenticate.go @@ -42,6 +42,7 @@ type SessionTokenClaims struct { Username string `json:"usn,omitempty"` Vars map[string]string `json:"vrs,omitempty"` ExpiresAt int64 `json:"exp,omitempty"` + IssuedAt int64 `json:"iat,omitempty"` } func (stc *SessionTokenClaims) Valid() error { @@ -107,9 +108,10 @@ func (s *ApiServer) AuthenticateApple(ctx context.Context, in *api.AuthenticateA s.sessionCache.RemoveAll(uuid.Must(uuid.FromString(dbUserID))) } + tokenIssuedAt := time.Now().Unix() tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -179,8 +181,9 @@ func (s *ApiServer) AuthenticateCustom(ctx context.Context, in *api.Authenticate } tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) + tokenIssuedAt := time.Now().Unix() + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -250,8 +253,9 @@ func (s *ApiServer) AuthenticateDevice(ctx context.Context, in *api.Authenticate } tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) + tokenIssuedAt := time.Now().Unix() + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -351,8 +355,9 @@ func (s *ApiServer) AuthenticateEmail(ctx context.Context, in *api.AuthenticateE } tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, username, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, username, in.Account.Vars) + tokenIssuedAt := time.Now().Unix() + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, username, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, username, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -423,8 +428,9 @@ func (s *ApiServer) AuthenticateFacebook(ctx context.Context, in *api.Authentica } tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) + tokenIssuedAt := time.Now().Unix() + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -490,8 +496,9 @@ func (s *ApiServer) AuthenticateFacebookInstantGame(ctx context.Context, in *api } tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) + tokenIssuedAt := time.Now().Unix() + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -569,8 +576,9 @@ func (s *ApiServer) AuthenticateGameCenter(ctx context.Context, in *api.Authenti } tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) + tokenIssuedAt := time.Now().Unix() + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -636,8 +644,9 @@ func (s *ApiServer) AuthenticateGoogle(ctx context.Context, in *api.Authenticate } tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) + tokenIssuedAt := time.Now().Unix() + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -712,8 +721,9 @@ func (s *ApiServer) AuthenticateSteam(ctx context.Context, in *api.AuthenticateS } tokenID := uuid.Must(uuid.NewV4()).String() - token, exp := generateToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) - refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, dbUserID, dbUsername, in.Account.Vars) + tokenIssuedAt := time.Now().Unix() + token, exp := generateToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) + refreshToken, refreshExp := generateRefreshToken(s.config, tokenID, tokenIssuedAt, dbUserID, dbUsername, in.Account.Vars) s.sessionCache.Add(uuid.FromStringOrNil(dbUserID), exp, tokenID, refreshExp, tokenID) session := &api.Session{Created: created, Token: token, RefreshToken: refreshToken} @@ -730,23 +740,24 @@ func (s *ApiServer) AuthenticateSteam(ctx context.Context, in *api.AuthenticateS return session, nil } -func generateToken(config Config, tokenID, userID, username string, vars map[string]string) (string, int64) { +func generateToken(config Config, tokenID string, tokenIssuedAt int64, userID, username string, vars map[string]string) (string, int64) { exp := time.Now().UTC().Add(time.Duration(config.GetSession().TokenExpirySec) * time.Second).Unix() - return generateTokenWithExpiry(config.GetSession().EncryptionKey, tokenID, userID, username, vars, exp) + return generateTokenWithExpiry(config.GetSession().EncryptionKey, tokenID, tokenIssuedAt, userID, username, vars, exp) } -func generateRefreshToken(config Config, tokenID, userID string, username string, vars map[string]string) (string, int64) { +func generateRefreshToken(config Config, tokenID string, tokenIssuedAt int64, userID string, username string, vars map[string]string) (string, int64) { exp := time.Now().UTC().Add(time.Duration(config.GetSession().RefreshTokenExpirySec) * time.Second).Unix() - return generateTokenWithExpiry(config.GetSession().RefreshEncryptionKey, tokenID, userID, username, vars, exp) + return generateTokenWithExpiry(config.GetSession().RefreshEncryptionKey, tokenID, tokenIssuedAt, userID, username, vars, exp) } -func generateTokenWithExpiry(signingKey, tokenID, userID, username string, vars map[string]string, exp int64) (string, int64) { +func generateTokenWithExpiry(signingKey, tokenID string, tokenIssuedAt int64, userID, username string, vars map[string]string, exp int64) (string, int64) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, &SessionTokenClaims{ TokenId: tokenID, UserId: userID, Username: username, Vars: vars, ExpiresAt: exp, + IssuedAt: tokenIssuedAt, }) signedToken, _ := token.SignedString([]byte(signingKey)) return signedToken, exp diff --git a/server/api_rpc.go b/server/api_rpc.go index 81615ef3d..839800e71 100644 --- a/server/api_rpc.go +++ b/server/api_rpc.go @@ -51,6 +51,7 @@ func (s *ApiServer) RpcFuncHttp(w http.ResponseWriter, r *http.Request) { var username string var vars map[string]string var expiry int64 + requestCtx := r.Context() if httpKey := queryParams.Get("http_key"); httpKey != "" { if httpKey != s.config.GetRuntime().HTTPKey { // HTTP key did not match. @@ -75,9 +76,15 @@ func (s *ApiServer) RpcFuncHttp(w http.ResponseWriter, r *http.Request) { return } } else { - var token string - userID, username, vars, expiry, token, isTokenAuth = parseBearerAuth([]byte(s.config.GetSession().EncryptionKey), auth[0]) - if !isTokenAuth || !s.sessionCache.IsValidSession(userID, expiry, token) { + var tokenId string + var tokenIssuedAt int64 + userID, username, vars, expiry, tokenId, tokenIssuedAt, isTokenAuth = parseBearerAuth([]byte(s.config.GetSession().EncryptionKey), auth[0]) + requestCtx = context.WithValue(requestCtx, ctxTokenIDKey{}, tokenId) + requestCtx = context.WithValue(requestCtx, ctxExpiryKey{}, expiry) + requestCtx = context.WithValue(requestCtx, ctxTokenIssuedAtKey{}, tokenIssuedAt) + requestCtx = context.WithValue(requestCtx, ctxVarsKey{}, vars) + + if !isTokenAuth || !s.sessionCache.IsValidSession(userID, expiry, tokenId) { // Auth token not valid or expired. w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusUnauthorized) @@ -206,7 +213,7 @@ func (s *ApiServer) RpcFuncHttp(w http.ResponseWriter, r *http.Request) { } // Execute the function. - result, fnErr, code := fn(r.Context(), headers, queryParams, uid, username, vars, expiry, "", clientIP, clientPort, "", payload) + result, fnErr, code := fn(requestCtx, headers, queryParams, uid, username, vars, expiry, "", clientIP, clientPort, "", payload) if fnErr != nil { response, _ := json.Marshal(map[string]interface{}{"error": fnErr, "message": fnErr.Error(), "code": code}) w.Header().Set("content-type", "application/json") diff --git a/server/api_session.go b/server/api_session.go index ff275a4b5..9f268fa34 100644 --- a/server/api_session.go +++ b/server/api_session.go @@ -53,7 +53,7 @@ func (s *ApiServer) SessionRefresh(ctx context.Context, in *api.SessionRefreshRe return nil, status.Error(codes.InvalidArgument, "Refresh token is required.") } - userID, username, vars, tokenId, err := SessionRefresh(ctx, s.logger, s.db, s.config, s.sessionCache, in.Token) + userID, username, vars, tokenId, tokenIssuedAt, err := SessionRefresh(ctx, s.logger, s.db, s.config, s.sessionCache, in.Token) if err != nil { return nil, err } @@ -72,8 +72,8 @@ func (s *ApiServer) SessionRefresh(ctx context.Context, in *api.SessionRefreshRe //s.sessionCache.Add(userID, tokenExp, newTokenId, refreshTokenExp, newTokenId) //session := &api.Session{Created: false, Token: token, RefreshToken: refreshToken} - token, tokenExp := generateToken(s.config, tokenId, userIDStr, username, useVars) - refreshToken, refreshTokenExp := generateRefreshToken(s.config, tokenId, userIDStr, username, useVars) + token, tokenExp := generateToken(s.config, tokenId, tokenIssuedAt, userIDStr, username, useVars) + refreshToken, refreshTokenExp := generateRefreshToken(s.config, tokenId, tokenIssuedAt, userIDStr, username, useVars) s.sessionCache.Add(userID, tokenExp, tokenId, refreshTokenExp, tokenId) session := &api.Session{Created: false, Token: token, RefreshToken: refreshToken} diff --git a/server/core_session.go b/server/core_session.go index 4b2852563..0c856d144 100644 --- a/server/core_session.go +++ b/server/core_session.go @@ -31,13 +31,13 @@ var ( ErrRefreshTokenInvalid = errors.New("refresh token invalid") ) -func SessionRefresh(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config, sessionCache SessionCache, token string) (uuid.UUID, string, map[string]string, string, error) { - userID, _, vars, exp, tokenId, ok := parseToken([]byte(config.GetSession().RefreshEncryptionKey), token) +func SessionRefresh(ctx context.Context, logger *zap.Logger, db *sql.DB, config Config, sessionCache SessionCache, token string) (uuid.UUID, string, map[string]string, string, int64, error) { + userID, _, vars, exp, tokenId, tokenIssuedAt, ok := parseToken([]byte(config.GetSession().RefreshEncryptionKey), token) if !ok { - return uuid.Nil, "", nil, "", status.Error(codes.Unauthenticated, "Refresh token invalid or expired.") + return uuid.Nil, "", nil, "", 0, status.Error(codes.Unauthenticated, "Refresh token invalid or expired.") } if !sessionCache.IsValidRefresh(userID, exp, tokenId) { - return uuid.Nil, "", nil, "", status.Error(codes.Unauthenticated, "Refresh token invalid or expired.") + return uuid.Nil, "", nil, "", 0, status.Error(codes.Unauthenticated, "Refresh token invalid or expired.") } // Look for an existing account. @@ -48,19 +48,19 @@ func SessionRefresh(ctx context.Context, logger *zap.Logger, db *sql.DB, config if err != nil { if err == sql.ErrNoRows { // Account not found and creation is never allowed for this type. - return uuid.Nil, "", nil, "", status.Error(codes.NotFound, "User account not found.") + return uuid.Nil, "", nil, "", 0, status.Error(codes.NotFound, "User account not found.") } logger.Error("Error looking up user by ID.", zap.Error(err), zap.String("id", userID.String())) - return uuid.Nil, "", nil, "", status.Error(codes.Internal, "Error finding user account.") + return uuid.Nil, "", nil, "", 0, status.Error(codes.Internal, "Error finding user account.") } // Check if it's disabled. if dbDisableTime.Valid && dbDisableTime.Time.Unix() != 0 { logger.Info("User account is disabled.", zap.String("id", userID.String())) - return uuid.Nil, "", nil, "", status.Error(codes.PermissionDenied, "User account banned.") + return uuid.Nil, "", nil, "", 0, status.Error(codes.PermissionDenied, "User account banned.") } - return userID, dbUsername, vars, tokenId, nil + return userID, dbUsername, vars, tokenId, tokenIssuedAt, nil } func SessionLogout(config Config, sessionCache SessionCache, userID uuid.UUID, token, refreshToken string) error { @@ -69,7 +69,7 @@ func SessionLogout(config Config, sessionCache SessionCache, userID uuid.UUID, t if token != "" { var sessionUserID uuid.UUID var ok bool - sessionUserID, _, _, maybeSessionExp, maybeSessionTokenId, ok = parseToken([]byte(config.GetSession().EncryptionKey), token) + sessionUserID, _, _, maybeSessionExp, maybeSessionTokenId, _, ok = parseToken([]byte(config.GetSession().EncryptionKey), token) if !ok || sessionUserID != userID { return ErrSessionTokenInvalid } @@ -80,7 +80,7 @@ func SessionLogout(config Config, sessionCache SessionCache, userID uuid.UUID, t if refreshToken != "" { var refreshUserID uuid.UUID var ok bool - refreshUserID, _, _, maybeRefreshExp, maybeRefreshTokenId, ok = parseToken([]byte(config.GetSession().RefreshEncryptionKey), refreshToken) + refreshUserID, _, _, maybeRefreshExp, maybeRefreshTokenId, _, ok = parseToken([]byte(config.GetSession().RefreshEncryptionKey), refreshToken) if !ok || refreshUserID != userID { return ErrRefreshTokenInvalid } diff --git a/server/runtime_go_nakama.go b/server/runtime_go_nakama.go index f6beadb5f..2225d0a8e 100644 --- a/server/runtime_go_nakama.go +++ b/server/runtime_go_nakama.go @@ -89,7 +89,14 @@ func NewRuntimeGoNakamaModule(logger *zap.Logger, db *sql.DB, protojsonMarshaler node: config.GetName(), - satori: satori.NewSatoriClient(logger, config.GetSatori().Url, config.GetSatori().ApiKeyName, config.GetSatori().ApiKey, config.GetSatori().SigningKey), + satori: satori.NewSatoriClient( + logger, + config.GetSatori().Url, + config.GetSatori().ApiKeyName, + config.GetSatori().ApiKey, + config.GetSatori().SigningKey, + config.GetSession().TokenExpirySec, + ), } } @@ -425,7 +432,8 @@ func (n *RuntimeGoNakamaModule) AuthenticateTokenGenerate(userID, username strin } tokenId := uuid.Must(uuid.NewV4()).String() - token, exp := generateTokenWithExpiry(n.config.GetSession().EncryptionKey, tokenId, userID, username, vars, exp) + tokenIssuedAt := time.Now().Unix() + token, exp := generateTokenWithExpiry(n.config.GetSession().EncryptionKey, tokenId, tokenIssuedAt, userID, username, vars, exp) n.sessionCache.Add(uid, exp, tokenId, 0, "") return token, exp, nil } diff --git a/server/runtime_javascript_nakama.go b/server/runtime_javascript_nakama.go index 8b0f0a174..8fb1095cb 100644 --- a/server/runtime_javascript_nakama.go +++ b/server/runtime_javascript_nakama.go @@ -117,11 +117,19 @@ func NewRuntimeJavascriptNakamaModule(logger *zap.Logger, db *sql.DB, protojsonM eventFn: eventFn, matchCreateFn: matchCreateFn, - satori: satori.NewSatoriClient(logger, config.GetSatori().Url, config.GetSatori().ApiKeyName, config.GetSatori().ApiKey, config.GetSatori().SigningKey), + satori: satori.NewSatoriClient( + logger, + config.GetSatori().Url, + config.GetSatori().ApiKeyName, + config.GetSatori().ApiKey, + config.GetSatori().SigningKey, + config.GetSession().TokenExpirySec, + ), } } -func (n *runtimeJavascriptNakamaModule) Constructor(r *goja.Runtime) (*goja.Object, error) { +func (n *runtimeJavascriptNakamaModule) Constructor(r *goja.Runtime) (*goja.Object, + error) { satoriJsObj, err := n.satoriConstructor(r) if err != nil { return nil, err @@ -1868,7 +1876,8 @@ func (n *runtimeJavascriptNakamaModule) authenticateTokenGenerate(r *goja.Runtim } tokenId := uuid.Must(uuid.NewV4()).String() - token, exp := generateTokenWithExpiry(n.config.GetSession().EncryptionKey, tokenId, userIDString, username, vars, exp) + tokenIssuedAt := time.Now().Unix() + token, exp := generateTokenWithExpiry(n.config.GetSession().EncryptionKey, tokenId, tokenIssuedAt, userIDString, username, vars, exp) n.sessionCache.Add(uid, exp, tokenId, 0, "") return r.ToValue(map[string]interface{}{ diff --git a/server/runtime_lua_nakama.go b/server/runtime_lua_nakama.go index 3ed26b8bd..d97579de6 100644 --- a/server/runtime_lua_nakama.go +++ b/server/runtime_lua_nakama.go @@ -126,7 +126,14 @@ func NewRuntimeLuaNakamaModule(logger *zap.Logger, db *sql.DB, protojsonMarshale matchCreateFn: matchCreateFn, eventFn: eventFn, - satori: satori.NewSatoriClient(logger, config.GetSatori().Url, config.GetSatori().ApiKeyName, config.GetSatori().ApiKey, config.GetSatori().SigningKey), + satori: satori.NewSatoriClient( + logger, + config.GetSatori().Url, + config.GetSatori().ApiKeyName, + config.GetSatori().ApiKey, + config.GetSatori().SigningKey, + config.GetSession().TokenExpirySec, + ), } } @@ -2260,7 +2267,8 @@ func (n *RuntimeLuaNakamaModule) authenticateTokenGenerate(l *lua.LState) int { } tokenId := uuid.Must(uuid.NewV4()).String() - token, exp := generateTokenWithExpiry(n.config.GetSession().EncryptionKey, tokenId, userIDString, username, varsMap, exp) + tokenIssuedAt := time.Now().Unix() + token, exp := generateTokenWithExpiry(n.config.GetSession().EncryptionKey, tokenId, tokenIssuedAt, userIDString, username, varsMap, exp) n.sessionCache.Add(uid, exp, tokenId, 0, "") l.Push(lua.LString(token)) diff --git a/server/socket_ws.go b/server/socket_ws.go index 0af807d2b..517a6e185 100644 --- a/server/socket_ws.go +++ b/server/socket_ws.go @@ -71,7 +71,7 @@ func NewSocketWsAcceptor(logger *zap.Logger, config Config, sessionRegistry Sess http.Error(w, "Missing or invalid token", 401) return } - userID, username, vars, expiry, _, ok := parseToken([]byte(config.GetSession().EncryptionKey), token) + userID, username, vars, expiry, _, _, ok := parseToken([]byte(config.GetSession().EncryptionKey), token) if !ok || !sessionCache.IsValidSession(userID, expiry, token) { http.Error(w, "Missing or invalid token", 401) return