package middleware

import (
	"context"
	"encoding/json"
	"errors"
	"strconv"
	"strings"
	"time"

	"github.com/go-kratos/kratos/v2/metadata"
	"github.com/go-kratos/kratos/v2/middleware"
	"github.com/go-kratos/kratos/v2/transport"
	"github.com/go-kratos/kratos/v2/transport/http"

	pkgjwt "kra/pkg/jwt"
)

// Metadata keys under which JWTAuth publishes the authenticated claims.
const (
	mdKeyUserID      = "x-md-user-id"
	mdKeyUsername    = "x-md-username"
	mdKeyAuthorityID = "x-md-authority-id"
	mdKeyUUID        = "x-md-uuid"
	mdKeyClaims      = "x-md-claims"
)

// Sentinel errors returned by JWTAuth. Compare with errors.Is.
var (
	// ErrMissingToken is returned when the request carries no token, or when
	// the middleware runs without transport information in the context.
	ErrMissingToken = errors.New("未登录或非法访问")
	// ErrInvalidToken is returned when the token cannot be parsed or verified.
	ErrInvalidToken = errors.New("无效的token")
	// ErrBlacklistJWT is returned when the token is reported as blacklisted.
	ErrBlacklistJWT = errors.New("您的帐户异地登陆或令牌失效")
	// ErrTokenExpired is returned when token parsing fails with
	// pkgjwt.ErrTokenExpired.
	ErrTokenExpired = errors.New("登录已过期,请重新登录")
)

// BlacklistChecker reports whether a raw JWT string has been blacklisted.
type BlacklistChecker interface {
	IsBlacklist(jwt string) bool
}

// JWTAuthConfig configures the JWTAuth middleware.
type JWTAuthConfig struct {
	// JWT parses, validates and re-issues tokens. It must not be nil.
	JWT *pkgjwt.JWT
	// BlacklistChecker is optional; when nil no blacklist check is performed.
	BlacklistChecker BlacklistChecker
	// UseMultipoint enables recording of refreshed tokens through SetRedisJWT.
	UseMultipoint bool
	// SetRedisJWT is optional; it is called with a refreshed token and the
	// token's username when UseMultipoint is true.
	SetRedisJWT func(token, username string) error
}

// JWTAuth returns a middleware that authenticates each request by its JWT.
//
// For every request it:
//  1. extracts the token from the request headers (see extractToken);
//  2. rejects the request with ErrMissingToken, ErrBlacklistJWT,
//     ErrTokenExpired or ErrInvalidToken when the token is absent,
//     blacklisted, expired or otherwise invalid;
//  3. publishes the claims as metadata on the context so that GetClaims,
//     GetUserID, GetUsername and GetAuthorityID can read them;
//  4. when cfg.JWT.NeedRefresh reports true, issues a new token and, for HTTP
//     transports, returns it in the "new-token", "new-expires-at" and
//     "x-token" reply headers.
//
// For HTTP transports the "x-token" reply header is cleared when the token is
// blacklisted or fails to parse.
func JWTAuth(cfg JWTAuthConfig) middleware.Middleware {
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			tr, ok := transport.FromServerContext(ctx)
			if !ok {
				// Fail closed: without transport information there is no
				// token to verify, so the request must not reach the handler.
				return nil, ErrMissingToken
			}
			ht, isHTTP := tr.(http.Transporter)

			token := extractToken(tr)
			if token == "" {
				return nil, ErrMissingToken
			}

			// Blacklist check.
			if cfg.BlacklistChecker != nil && cfg.BlacklistChecker.IsBlacklist(token) {
				if isHTTP {
					clearToken(ht)
				}
				return nil, ErrBlacklistJWT
			}

			claims, err := cfg.JWT.ParseToken(token)
			if err != nil {
				if isHTTP {
					clearToken(ht)
				}
				if errors.Is(err, pkgjwt.ErrTokenExpired) {
					return nil, ErrTokenExpired
				}
				return nil, ErrInvalidToken
			}

			claimsJSON, err := json.Marshal(claims)
			if err != nil {
				// Claims that cannot be serialized cannot be handed to the
				// handler; reject rather than publish an empty value.
				return nil, ErrInvalidToken
			}

			pairs := []string{
				mdKeyUserID, strconv.FormatUint(uint64(claims.BaseClaims.ID), 10),
				mdKeyUsername, claims.Username,
				mdKeyAuthorityID, strconv.FormatUint(uint64(claims.AuthorityID), 10),
				mdKeyUUID, claims.UUID,
				mdKeyClaims, string(claimsJSON),
			}
			// Kept for backward compatibility: the claims remain available as
			// client (outgoing) metadata exactly as before.
			ctx = metadata.AppendToClientContext(ctx, pairs...)
			// The getters in this file read server metadata, so the claims
			// must be published there as well.
			ctx = setServerMetadata(ctx, pairs...)

			// Refresh the token when it is close to expiry. A refresh failure
			// is not fatal: the current token is still valid.
			if cfg.JWT.NeedRefresh(claims) {
				newToken, newClaims, refreshErr := cfg.JWT.CreateTokenByOldToken(token, *claims)
				if refreshErr == nil {
					if isHTTP {
						reply := ht.ReplyHeader()
						reply.Set("new-token", newToken)
						reply.Set("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
						// Remaining lifetime in whole minutes.
						maxAge := int(time.Until(newClaims.ExpiresAt.Time).Seconds() / 60)
						setToken(ht, newToken, maxAge)
					}
					// Record the new JWT for multipoint login.
					if cfg.UseMultipoint && cfg.SetRedisJWT != nil {
						// Best effort: the request is already authenticated
						// and the new token already issued, so a failure to
						// record it must not fail the request.
						_ = cfg.SetRedisJWT(newToken, newClaims.Username)
					}
				}
			}

			return handler(ctx, req)
		}
	}
}

// JWTAuthSimple returns a JWTAuth middleware that only validates the token:
// no blacklist check and no multipoint recording.
func JWTAuthSimple(j *pkgjwt.JWT) middleware.Middleware {
	return JWTAuth(JWTAuthConfig{JWT: j})
}

// setServerMetadata returns a context whose server metadata is a copy of the
// existing server metadata with the given key/value pairs applied. kv is read
// as alternating keys and values; a trailing key without a value is ignored.
func setServerMetadata(ctx context.Context, kv ...string) context.Context {
	md, _ := metadata.FromServerContext(ctx)
	md = md.Clone()
	for i := 0; i+1 < len(kv); i += 2 {
		// Remove any value already stored under this key (possibly copied
		// from request headers) so that only the verified claim can be
		// observed, even when the new value is empty.
		delete(md, strings.ToLower(kv[i]))
		md.Set(kv[i], kv[i+1])
	}
	return metadata.NewServerContext(ctx, md)
}

// extractToken returns the token carried by the request headers, or "" when
// there is none.
//
// The Authorization header takes precedence. A leading "Bearer " scheme is
// matched case-insensitively and stripped; a value without that scheme is
// used as-is. When Authorization yields no token, the "x-token" header is
// used instead.
func extractToken(tr transport.Transporter) string {
	const bearer = "Bearer "

	header := tr.RequestHeader()
	auth := header.Get("Authorization")
	if len(auth) >= len(bearer) && strings.EqualFold(auth[:len(bearer)], bearer) {
		auth = auth[len(bearer):]
	}
	if auth = strings.TrimSpace(auth); auth != "" {
		return auth
	}
	return header.Get("x-token")
}

// clearToken sets the "x-token" reply header to the empty string.
func clearToken(ht http.Transporter) {
	ht.ReplyHeader().Set("x-token", "")
}

// setToken sets the "x-token" reply header to token. No cookie is written
// here. maxAge is currently unused and is kept only so existing call sites
// remain valid.
func setToken(ht http.Transporter, token string, maxAge int) {
	ht.ReplyHeader().Set("x-token", token)
}

// GetClaims returns the claims stored in the context's server metadata under
// the claims key. The boolean is false when no claims are present or when
// they cannot be decoded.
func GetClaims(ctx context.Context) (*pkgjwt.CustomClaims, bool) {
	md, ok := metadata.FromServerContext(ctx)
	if !ok {
		return nil, false
	}
	raw := md.Get(mdKeyClaims)
	if raw == "" {
		return nil, false
	}
	var claims pkgjwt.CustomClaims
	if err := json.Unmarshal([]byte(raw), &claims); err != nil {
		return nil, false
	}
	return &claims, true
}

// GetUserID returns the user ID stored in the context's server metadata, or 0
// when it is absent or not a valid unsigned integer.
func GetUserID(ctx context.Context) uint {
	return uintFromServerMetadata(ctx, mdKeyUserID)
}

// GetUsername returns the username stored in the context's server metadata,
// or "" when it is absent.
func GetUsername(ctx context.Context) string {
	if md, ok := metadata.FromServerContext(ctx); ok {
		return md.Get(mdKeyUsername)
	}
	return ""
}

// GetAuthorityID returns the authority (role) ID stored in the context's
// server metadata, or 0 when it is absent or not a valid unsigned integer.
func GetAuthorityID(ctx context.Context) uint {
	return uintFromServerMetadata(ctx, mdKeyAuthorityID)
}

// GetToken returns the token carried by the current request, extracted the
// same way JWTAuth extracts it, or "" when the context has no transport
// information.
func GetToken(ctx context.Context) string {
	if tr, ok := transport.FromServerContext(ctx); ok {
		return extractToken(tr)
	}
	return ""
}

// uintFromServerMetadata parses the server metadata value stored under key as
// a base-10 unsigned integer. It returns 0 when the value is absent, malformed
// or does not fit in a uint.
func uintFromServerMetadata(ctx context.Context, key string) uint {
	md, ok := metadata.FromServerContext(ctx)
	if !ok {
		return 0
	}
	raw := md.Get(key)
	if raw == "" {
		return 0
	}
	// strconv.IntSize makes an out-of-range value an error instead of letting
	// the uint conversion truncate it on 32-bit platforms.
	id, err := strconv.ParseUint(raw, 10, strconv.IntSize)
	if err != nil {
		return 0
	}
	return uint(id)
}