kra/internal/data/system/jwt_blacklist.go

72 lines
1.9 KiB
Go

package system
import (
"context"
"time"
"kra/internal/biz/system"
"kra/internal/data/model"
"kra/internal/data/query"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// jwtBlacklistRepo is the JWT blacklist repository returned (as
// system.JwtBlacklistRepo) by the constructors in this file. Blacklisted
// tokens live in the database; per-user tokens can optionally be kept in Redis.
type jwtBlacklistRepo struct {
// db is the database connection; LoadAll queries it directly.
db *gorm.DB
// redisClient may be nil, in which case GetRedisJWT returns redis.Nil and
// SetRedisJWT does nothing.
redisClient redis.UniversalClient
// expiresTime is the TTL passed to Redis SET by SetRedisJWT.
expiresTime time.Duration
}
// NewJwtBlacklistRepo creates a JWT blacklist repository backed only by the
// database. No Redis client is configured (the field keeps its nil zero
// value), so the Redis-backed methods degrade to redis.Nil / no-op; the Redis
// TTL is still initialised to seven days.
func NewJwtBlacklistRepo(db *gorm.DB) system.JwtBlacklistRepo {
	const defaultRedisTTL = 7 * 24 * time.Hour

	repo := &jwtBlacklistRepo{db: db, expiresTime: defaultRedisTTL}
	// Bind the generated query package to this connection (NOTE(review):
	// presumably package-global state in the gen-generated code — confirm).
	query.SetDefault(db)
	return repo
}
// NewJwtBlacklistRepoWithRedis creates a JWT blacklist repository that, in
// addition to the database, stores per-user tokens in Redis. expiresTime is
// used as the TTL when SetRedisJWT writes a token. A nil redisClient behaves
// like NewJwtBlacklistRepo (Redis methods degrade to redis.Nil / no-op).
func NewJwtBlacklistRepoWithRedis(db *gorm.DB, redisClient redis.UniversalClient, expiresTime time.Duration) system.JwtBlacklistRepo {
	repo := &jwtBlacklistRepo{
		db:          db,
		redisClient: redisClient,
		expiresTime: expiresTime,
	}
	// Bind the generated query package to this connection (NOTE(review):
	// presumably package-global state in the gen-generated code — confirm).
	query.SetDefault(db)
	return repo
}
// Create inserts jwt as a new row in the blacklist table and returns any
// database error.
func (r *jwtBlacklistRepo) Create(ctx context.Context, jwt string) error {
	entry := model.JwtBlacklist{Jwt: jwt}
	return query.JwtBlacklist.WithContext(ctx).Create(&entry)
}
// IsInBlacklist reports whether jwt has at least one row in the blacklist
// table.
//
// The method signature allows only a bool, so a lookup error cannot be
// returned to the caller. The original code discarded the error, which made
// any failed query look like "not blacklisted" and let revoked tokens through
// whenever the database was unreachable. The check now fails closed: if the
// count cannot be obtained, the token is treated as blacklisted.
func (r *jwtBlacklistRepo) IsInBlacklist(ctx context.Context, jwt string) bool {
	count, err := query.JwtBlacklist.WithContext(ctx).Where(query.JwtBlacklist.Jwt.Eq(jwt)).Count()
	if err != nil {
		// Fail closed: an unverifiable token must not be accepted.
		return true
	}
	return count > 0
}
// LoadAll returns the "jwt" column of every row in the blacklist table.
//
// On error it returns a nil slice together with the error, instead of
// whatever partial result Pluck may have written into the destination.
func (r *jwtBlacklistRepo) LoadAll(ctx context.Context) ([]string, error) {
	var jwts []string
	if err := r.db.WithContext(ctx).Model(&model.JwtBlacklist{}).Pluck("jwt", &jwts).Error; err != nil {
		return nil, err
	}
	return jwts, nil
}
// GetRedisJWT returns the token stored in Redis under the key
// "jwt:<username>". When no Redis client is configured it returns redis.Nil,
// the same error go-redis reports for a missing key, so callers can treat both
// cases as "no token stored".
func (r *jwtBlacklistRepo) GetRedisJWT(ctx context.Context, username string) (string, error) {
	client := r.redisClient
	if client != nil {
		return client.Get(ctx, "jwt:"+username).Result()
	}
	return "", redis.Nil
}
// SetRedisJWT stores token in Redis under the key "jwt:<username>", using
// r.expiresTime as the TTL. It does nothing and returns nil when no Redis
// client is configured.
func (r *jwtBlacklistRepo) SetRedisJWT(ctx context.Context, token, username string) error {
	if client := r.redisClient; client != nil {
		key := "jwt:" + username
		return client.Set(ctx, key, token, r.expiresTime).Err()
	}
	return nil
}