kra/internal/server/middleware/jwt.go

197 lines
5.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"encoding/json"
"errors"
"strconv"
"strings"
"time"
"github.com/go-kratos/kratos/v2/metadata"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/http"
pkgjwt "kra/pkg/jwt"
)
const (
// Metadata keys under which JWTAuth publishes the authenticated caller's claims.
// mdKeyUserID, mdKeyUsername, mdKeyAuthorityID and mdKeyUUID each carry one
// claim field as a string; mdKeyClaims carries the whole claims object
// serialized as JSON (read back by GetClaims).
mdKeyUserID = "x-md-user-id"
mdKeyUsername = "x-md-username"
mdKeyAuthorityID = "x-md-authority-id"
mdKeyUUID = "x-md-uuid"
mdKeyClaims = "x-md-claims"
)
// Sentinel errors returned by JWTAuth. The messages are user-facing and are
// intentionally kept in Chinese.
var (
// ErrMissingToken is returned when no token can be extracted from the
// request ("not logged in or illegal access").
ErrMissingToken = errors.New("未登录或非法访问")
// ErrInvalidToken is returned when the token fails to parse for any reason
// other than expiry ("invalid token").
ErrInvalidToken = errors.New("无效的token")
// ErrBlacklistJWT is returned when the configured BlacklistChecker reports
// the token as blacklisted ("account logged in elsewhere or token invalidated").
ErrBlacklistJWT = errors.New("您的帐户异地登陆或令牌失效")
)
// BlacklistChecker reports whether a JWT has been blacklisted.
// JWTAuth consults it, when configured, before parsing the token.
type BlacklistChecker interface {
// IsBlacklist returns true when the given raw token string must be rejected.
IsBlacklist(jwt string) bool
}
// JWTAuthConfig holds the dependencies and switches used by JWTAuth.
type JWTAuthConfig struct {
// JWT parses incoming tokens, decides whether they need refreshing and
// issues the replacement token. JWTAuth dereferences it unconditionally,
// so it must be non-nil.
JWT *pkgjwt.JWT
// BlacklistChecker is optional; when nil the blacklist check is skipped.
BlacklistChecker BlacklistChecker
// UseMultipoint, together with a non-nil SetRedisJWT, makes JWTAuth report
// every refreshed token through SetRedisJWT.
UseMultipoint bool
// SetRedisJWT receives the refreshed token and the username it belongs to.
// JWTAuth ignores the returned error.
// NOTE(review): presumably persists the user's active token in Redis — confirm against the caller.
SetRedisJWT func(token, username string) error
}
// JWTAuth returns a server middleware that authenticates every request with
// the JWT found by extractToken.
//
// A request is rejected with:
//   - ErrMissingToken when the context carries no transport information or
//     the request carries no token (the middleware fails closed);
//   - ErrBlacklistJWT when cfg.BlacklistChecker is set and reports the token;
//   - a "login expired" error when parsing fails with pkgjwt.ErrTokenExpired;
//   - ErrInvalidToken for any other parse failure.
//
// On blacklist and parse failures over HTTP the x-token reply header is
// blanked via clearToken.
//
// On success the claims are published under the mdKey* keys in two places:
// the server metadata of the context handed to the next handler, which is
// what GetClaims, GetUserID, GetUsername and GetAuthorityID read, and the
// client metadata, as before, for outgoing calls made with that context.
//
// When cfg.JWT.NeedRefresh reports true, a replacement token is issued. Over
// HTTP it is returned in the new-token, new-expires-at and x-token reply
// headers. It is also passed to cfg.SetRedisJWT when cfg.UseMultipoint is set
// and the callback is non-nil. A failed refresh is ignored and the request
// continues with the still-valid old token.
func JWTAuth(cfg JWTAuthConfig) middleware.Middleware {
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			tr, ok := transport.FromServerContext(ctx)
			if !ok {
				// Fail closed: without transport information no token can be
				// read, so the request cannot be authenticated.
				return nil, ErrMissingToken
			}

			token := extractToken(tr)
			if token == "" {
				return nil, ErrMissingToken
			}

			// Reply headers can only be written on HTTP transports.
			ht, isHTTP := tr.(http.Transporter)

			// Blacklist check.
			if cfg.BlacklistChecker != nil && cfg.BlacklistChecker.IsBlacklist(token) {
				if isHTTP {
					clearToken(ht)
				}
				return nil, ErrBlacklistJWT
			}

			claims, err := cfg.JWT.ParseToken(token)
			if err != nil {
				if isHTTP {
					clearToken(ht)
				}
				if errors.Is(err, pkgjwt.ErrTokenExpired) {
					return nil, errors.New("登录已过期,请重新登录")
				}
				return nil, ErrInvalidToken
			}

			// Best effort: if the claims cannot be serialized, the individual
			// identity keys are still published and GetClaims simply reports
			// that no claims are available.
			claimsJSON, err := json.Marshal(claims)
			if err != nil {
				claimsJSON = nil
			}

			kv := []string{
				mdKeyUserID, strconv.FormatUint(uint64(claims.BaseClaims.ID), 10),
				mdKeyUsername, claims.Username,
				mdKeyAuthorityID, strconv.FormatUint(uint64(claims.AuthorityID), 10),
				mdKeyUUID, claims.UUID,
				mdKeyClaims, string(claimsJSON),
			}

			// Keep the claims on the client metadata so they accompany
			// outgoing calls made with this context.
			ctx = metadata.AppendToClientContext(ctx, kv...)

			// Publish the claims on the server metadata, which is where the
			// Get* helpers in this file look. Work on a copy so metadata
			// shared with other holders of the parent context is not mutated.
			md := metadata.New()
			if current, found := metadata.FromServerContext(ctx); found {
				md = current.Clone()
			}
			for i := 0; i+1 < len(kv); i += 2 {
				// Drop any value already present under the key (for example
				// one copied from request headers by an upstream metadata
				// middleware) so only verified claims are visible.
				delete(md, kv[i])
				md.Set(kv[i], kv[i+1])
			}
			ctx = metadata.NewServerContext(ctx, md)

			// Refresh the token when it is close to expiry.
			if cfg.JWT.NeedRefresh(claims) {
				newToken, newClaims, refreshErr := cfg.JWT.CreateTokenByOldToken(token, *claims)
				if refreshErr == nil {
					if isHTTP {
						ht.ReplyHeader().Set("new-token", newToken)
						ht.ReplyHeader().Set("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
						setToken(ht, newToken, int(time.Until(newClaims.ExpiresAt.Time).Seconds()/60))
					}
					// Multipoint login: record the new JWT. Recording is best
					// effort and must not fail an authenticated request.
					if cfg.UseMultipoint && cfg.SetRedisJWT != nil {
						_ = cfg.SetRedisJWT(newToken, newClaims.Username)
					}
				}
			}

			return handler(ctx, req)
		}
	}
}
// JWTAuthSimple is a shorthand for JWTAuth configured with only the JWT
// helper: no blacklist checker, multipoint disabled and no SetRedisJWT callback.
func JWTAuthSimple(j *pkgjwt.JWT) middleware.Middleware {
	cfg := JWTAuthConfig{JWT: j}
	return JWTAuth(cfg)
}
// extractToken returns the raw JWT carried by the request, or "" when the
// request carries none.
//
// The Authorization header takes precedence. A leading "Bearer " scheme is
// matched case-insensitively (auth schemes are case-insensitive per RFC 7235)
// and stripped, and whitespace around the remaining token is trimmed. Any
// other non-empty Authorization value is returned unchanged. When the
// Authorization header is absent, the x-token header is used.
func extractToken(tr transport.Transporter) string {
	const scheme = "Bearer "

	header := tr.RequestHeader()
	if auth := header.Get("Authorization"); auth != "" {
		if len(auth) >= len(scheme) && strings.EqualFold(auth[:len(scheme)], scheme) {
			return strings.TrimSpace(auth[len(scheme):])
		}
		return auth
	}

	// Fall back to the x-token header; Get yields "" when it is missing.
	return header.Get("x-token")
}
// clearToken blanks the x-token reply header so the response no longer
// carries a token value.
func clearToken(ht http.Transporter) {
	reply := ht.ReplyHeader()
	reply.Set("x-token", "")
}
// setToken hands a token back to the caller through the x-token reply header.
// No cookie is written here; storing the token is left to the client.
// maxAge is accepted but not used by the current implementation.
func setToken(ht http.Transporter, token string, maxAge int) {
	reply := ht.ReplyHeader()
	reply.Set("x-token", token)
}
// GetClaims decodes the JSON-serialized claims stored under mdKeyClaims in
// the server metadata of ctx. It returns (nil, false) when ctx has no server
// metadata, the key is empty, or the stored value is not valid claims JSON.
func GetClaims(ctx context.Context) (*pkgjwt.CustomClaims, bool) {
	md, ok := metadata.FromServerContext(ctx)
	if !ok {
		return nil, false
	}

	raw := md.Get(mdKeyClaims)
	if raw == "" {
		return nil, false
	}

	claims := new(pkgjwt.CustomClaims)
	if err := json.Unmarshal([]byte(raw), claims); err != nil {
		return nil, false
	}
	return claims, true
}
// GetUserID returns the user ID stored under mdKeyUserID in the server
// metadata of ctx, or 0 when the metadata is absent or the value is missing
// or not a base-10 unsigned integer.
func GetUserID(ctx context.Context) uint {
	md, ok := metadata.FromServerContext(ctx)
	if !ok {
		return 0
	}

	// An empty value fails to parse as well, so it also yields 0.
	id, err := strconv.ParseUint(md.Get(mdKeyUserID), 10, 64)
	if err != nil {
		return 0
	}
	return uint(id)
}
// GetUsername returns the username stored under mdKeyUsername in the server
// metadata of ctx, or "" when ctx carries no server metadata.
func GetUsername(ctx context.Context) string {
	md, ok := metadata.FromServerContext(ctx)
	if !ok {
		return ""
	}
	return md.Get(mdKeyUsername)
}
// GetAuthorityID returns the authority (role) ID stored under
// mdKeyAuthorityID in the server metadata of ctx, or 0 when the metadata is
// absent or the value is missing or not a base-10 unsigned integer.
func GetAuthorityID(ctx context.Context) uint {
	md, ok := metadata.FromServerContext(ctx)
	if !ok {
		return 0
	}

	// An empty value fails to parse as well, so it also yields 0.
	id, err := strconv.ParseUint(md.Get(mdKeyAuthorityID), 10, 64)
	if err != nil {
		return 0
	}
	return uint(id)
}
// GetToken re-extracts the raw token from the request behind ctx using
// extractToken. It returns "" when ctx carries no server transport.
func GetToken(ctx context.Context) string {
	tr, ok := transport.FromServerContext(ctx)
	if !ok {
		return ""
	}
	return extractToken(tr)
}