Skip to content

Commit

Permalink
fix: too many tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
withchao committed Oct 28, 2024
1 parent f1ae650 commit f102134
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 87 deletions.
44 changes: 3 additions & 41 deletions internal/api/mw/mw.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/openimsdk/chat/pkg/common/constant"
"github.com/openimsdk/chat/pkg/protocol/admin"
constantpb "github.com/openimsdk/protocol/constant"
"github.com/openimsdk/tools/apiresp"
"github.com/openimsdk/tools/errs"
)
Expand Down Expand Up @@ -56,74 +55,37 @@ func (o *MW) parseTokenType(c *gin.Context, userType int32) (string, string, err
return userID, token, nil
}

func (o *MW) isValidToken(c *gin.Context, userID string, token string) error {
resp, err := o.client.GetUserToken(c, &admin.GetUserTokenReq{UserID: userID})
if err != nil {
return err
}
if len(resp.TokensMap) == 0 {
return errs.ErrTokenExpired.Wrap()
}
if v, ok := resp.TokensMap[token]; ok {
switch v {
case constantpb.NormalToken:
case constantpb.KickedToken:
return errs.ErrTokenExpired.Wrap()
default:
return errs.ErrTokenUnknown.Wrap()
}
} else {
return errs.ErrTokenExpired.Wrap()
}
return nil
}

func (o *MW) setToken(c *gin.Context, userID string, userType int32) {
SetToken(c, userID, userType)
}

func (o *MW) CheckToken(c *gin.Context) {
userID, userType, token, err := o.parseToken(c)
userID, userType, _, err := o.parseToken(c)
if err != nil {
c.Abort()
apiresp.GinError(c, err)
return
}
if err := o.isValidToken(c, userID, token); err != nil {
c.Abort()
apiresp.GinError(c, err)
return
}
o.setToken(c, userID, userType)
}

func (o *MW) CheckAdmin(c *gin.Context) {
userID, token, err := o.parseTokenType(c, constant.AdminUser)
userID, _, err := o.parseTokenType(c, constant.AdminUser)
if err != nil {
c.Abort()
apiresp.GinError(c, err)
return
}
if err := o.isValidToken(c, userID, token); err != nil {
c.Abort()
apiresp.GinError(c, err)
return
}
o.setToken(c, userID, constant.AdminUser)
}

func (o *MW) CheckUser(c *gin.Context) {
userID, token, err := o.parseTokenType(c, constant.NormalUser)
userID, _, err := o.parseTokenType(c, constant.NormalUser)
if err != nil {
c.Abort()
apiresp.GinError(c, err)
return
}
if err := o.isValidToken(c, userID, token); err != nil {
c.Abort()
apiresp.GinError(c, err)
return
}
o.setToken(c, userID, constant.NormalUser)
}

Expand Down
10 changes: 5 additions & 5 deletions internal/rpc/admin/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg
return err
}
var srv adminServer
srv.Database, err = database.NewAdminDatabase(mgocli, rdb)
srv.Token = &tokenverify.Token{
Expires: time.Duration(config.RpcConfig.TokenPolicy.Expire) * time.Hour * 24,
Secret: config.RpcConfig.Secret,
}
srv.Database, err = database.NewAdminDatabase(mgocli, rdb, srv.Token)
if err != nil {
return err
}
Expand All @@ -56,10 +60,6 @@ func Start(ctx context.Context, config *Config, client discovery.SvcDiscoveryReg
return err
}
srv.Chat = chatClient.NewChatClient(chat.NewChatClient(conn))
srv.Token = &tokenverify.Token{
Expires: time.Duration(config.RpcConfig.TokenPolicy.Expire) * time.Hour * 24,
Secret: config.RpcConfig.Secret,
}
if err := srv.initAdmin(ctx, config.Share.ChatAdmin, config.Share.OpenIM.AdminUserID); err != nil {
return err
}
Expand Down
1 change: 0 additions & 1 deletion internal/rpc/admin/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (

func (o *adminServer) CreateToken(ctx context.Context, req *adminpb.CreateTokenReq) (*adminpb.CreateTokenResp, error) {
token, expire, err := o.Token.CreateToken(req.UserID, req.UserType)

if err != nil {
return nil, err
}
Expand Down
97 changes: 71 additions & 26 deletions pkg/common/db/cache/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,55 +16,36 @@ package cache

import (
"context"

"github.com/openimsdk/chat/pkg/common/tokenverify"
"github.com/openimsdk/tools/utils/stringutil"
"sort"
"time"

"github.com/openimsdk/tools/errs"
"github.com/redis/go-redis/v9"
)

const (
chatToken = "CHAT_UID_TOKEN_STATUS:"
chatToken = "CHAT_UID_TOKEN_STATUS:"
userMaxTokenNum = 20
)

type TokenInterface interface {
AddTokenFlag(ctx context.Context, userID string, token string, flag int) error
AddTokenFlagNXEx(ctx context.Context, userID string, token string, flag int, expire time.Duration) (bool, error)
SetTokenExpire(ctx context.Context, userID string, token string, expire time.Duration) error
GetTokensWithoutError(ctx context.Context, userID string) (map[string]int32, error)
DeleteTokenByUid(ctx context.Context, userID string) error
}

type TokenCacheRedis struct {
token *tokenverify.Token
rdb redis.UniversalClient
accessExpire int64
}

func NewTokenInterface(rdb redis.UniversalClient) *TokenCacheRedis {
func NewTokenInterface(rdb redis.UniversalClient, token *tokenverify.Token) *TokenCacheRedis {
return &TokenCacheRedis{rdb: rdb}
}

func (t *TokenCacheRedis) AddTokenFlag(ctx context.Context, userID string, token string, flag int) error {
key := chatToken + userID
return errs.Wrap(t.rdb.HSet(ctx, key, token, flag).Err())
}

func (t *TokenCacheRedis) AddTokenFlagNXEx(ctx context.Context, userID string, token string, flag int, expire time.Duration) (bool, error) {
key := chatToken + userID
isSet, err := t.rdb.HSetNX(ctx, key, token, flag).Result()
if err != nil {
return false, errs.Wrap(err)
}
if !isSet {
// key already exists
return false, nil
}
if err = t.rdb.Expire(ctx, key, expire).Err(); err != nil {
return false, errs.Wrap(err)
}
return isSet, nil
}

func (t *TokenCacheRedis) GetTokensWithoutError(ctx context.Context, userID string) (map[string]int32, error) {
key := chatToken + userID
m, err := t.rdb.HGetAll(ctx, key).Result()
Expand All @@ -82,3 +63,67 @@ func (t *TokenCacheRedis) DeleteTokenByUid(ctx context.Context, userID string) e
key := chatToken + userID
return errs.Wrap(t.rdb.Del(ctx, key).Err())
}

func (t *TokenCacheRedis) SetTokenExpire(ctx context.Context, userID string, token string, expire time.Duration) error {
key := chatToken + userID
if err := t.rdb.HSet(ctx, key, token, "0").Err(); err != nil {
return errs.Wrap(err)
}
if err := t.rdb.Expire(ctx, key, expire).Err(); err != nil {
return errs.Wrap(err)
}
mm, err := t.rdb.HGetAll(ctx, key).Result()
if err != nil {
return errs.Wrap(err)
}
if len(mm) <= 1 {
return nil
}
var (
fields []string
ts tokenTimes
)
now := time.Now()
for k := range mm {
if k == token {
continue
}
val := t.token.GetExpire(k)
if val.IsZero() || val.Before(now) {
fields = append(fields, k)
} else {
ts = append(ts, tokenTime{Token: k, Time: val})
}
}
var sorted bool
for i := len(mm) - len(fields); i > userMaxTokenNum; i-- {
if !sorted {
sorted = true
sort.Sort(ts)
}
fields = append(fields, ts[i].Token)
}
if err := t.rdb.HDel(ctx, key, fields...).Err(); err != nil {
return errs.Wrap(err)
}
return nil
}

type tokenTime struct {
Token string
Time time.Time
}

type tokenTimes []tokenTime

func (t tokenTimes) Len() int {
return len(t)
}

func (t tokenTimes) Less(i, j int) bool {
return t[i].Time.Before(t[j].Time)
}

func (t tokenTimes) Swap(i, j int) {
t[i], t[j] = t[j], t[i]
}
18 changes: 4 additions & 14 deletions pkg/common/db/database/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ package database

import (
"context"
"github.com/openimsdk/chat/pkg/common/tokenverify"
"time"

"github.com/openimsdk/chat/pkg/common/db/cache"
"github.com/openimsdk/protocol/constant"
"github.com/openimsdk/tools/db/mongoutil"
"github.com/openimsdk/tools/db/pagination"
"github.com/openimsdk/tools/db/tx"
Expand Down Expand Up @@ -80,7 +80,7 @@ type AdminDatabaseInterface interface {
DeleteToken(ctx context.Context, userID string) error
}

func NewAdminDatabase(cli *mongoutil.Client, rdb redis.UniversalClient) (AdminDatabaseInterface, error) {
func NewAdminDatabase(cli *mongoutil.Client, rdb redis.UniversalClient, token *tokenverify.Token) (AdminDatabaseInterface, error) {
a, err := admin.NewAdmin(cli.GetDB())
if err != nil {
return nil, err
Expand Down Expand Up @@ -128,7 +128,7 @@ func NewAdminDatabase(cli *mongoutil.Client, rdb redis.UniversalClient) (AdminDa
registerAddGroup: registerAddGroup,
applet: applet,
clientConfig: clientConfig,
cache: cache.NewTokenInterface(rdb),
cache: cache.NewTokenInterface(rdb, token),
}, nil
}

Expand Down Expand Up @@ -327,17 +327,7 @@ func (o *AdminDatabase) GetLimitUserLoginIP(ctx context.Context, userID string,
}

func (o *AdminDatabase) CacheToken(ctx context.Context, userID string, token string, expire time.Duration) error {
isSet, err := o.cache.AddTokenFlagNXEx(ctx, userID, token, constant.NormalToken, expire)
if err != nil {
return err
}
if !isSet {
// already exists, update
if err = o.cache.AddTokenFlag(ctx, userID, token, constant.NormalToken); err != nil {
return err
}
}
return nil
return o.cache.SetTokenExpire(ctx, userID, token, expire)
}

func (o *AdminDatabase) GetTokens(ctx context.Context, userID string) (map[string]int32, error) {
Expand Down
15 changes: 15 additions & 0 deletions pkg/common/tokenverify/token_verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,21 @@ func (t *Token) GetToken(token string) (string, int32, error) {
return userID, userType, nil
}

func (t *Token) GetExpire(token string) time.Time {
val, err := jwt.ParseWithClaims(token, &claims{}, t.secret())
if err != nil {
return time.Time{}
}
c, ok := val.Claims.(*claims)
if !ok {
return time.Time{}
}
if c.ExpiresAt == nil {
return time.Time{}
}
return c.ExpiresAt.Time
}

//func (t *Token) GetAdminToken(token string) (string, error) {
// userID, userType, err := getToken(token)
// if err != nil {
Expand Down

0 comments on commit f102134

Please sign in to comment.