197 lines
5.1 KiB
Go
197 lines
5.1 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/go-kratos/kratos/v2/metadata"
|
||
"github.com/go-kratos/kratos/v2/middleware"
|
||
"github.com/go-kratos/kratos/v2/transport"
|
||
"github.com/go-kratos/kratos/v2/transport/http"
|
||
|
||
pkgjwt "kra/pkg/jwt"
|
||
)
|
||
|
||
// Metadata keys under which JWTAuth publishes the authenticated user's claims
// and from which the Get* helpers in this file read them back.
const (
	mdKeyUserID      = "x-md-user-id"      // user ID, base-10 encoded
	mdKeyUsername    = "x-md-username"     // username
	mdKeyAuthorityID = "x-md-authority-id" // authority (role) ID, base-10 encoded
	mdKeyUUID        = "x-md-uuid"         // user UUID
	mdKeyClaims      = "x-md-claims"       // complete claims, JSON-encoded
)
|
||
|
||
// Sentinel errors returned by JWTAuth. Their messages are Chinese text and are
// part of the observable behavior, so they must not be reworded casually.
var (
	// ErrMissingToken is returned when the request carries no token
	// (message: "not logged in or illegal access").
	ErrMissingToken = errors.New("未登录或非法访问")

	// ErrInvalidToken is returned when the token fails to parse for a reason
	// other than expiry (message: "invalid token").
	ErrInvalidToken = errors.New("无效的token")

	// ErrBlacklistJWT is returned when the configured BlacklistChecker rejects
	// the token (message: "your account logged in elsewhere or the token is
	// no longer valid").
	ErrBlacklistJWT = errors.New("您的帐户异地登陆或令牌失效")
)
|
||
|
||
// BlacklistChecker decides whether a raw JWT string is on the blacklist.
// When set in JWTAuthConfig, JWTAuth consults it before parsing the token and
// rejects blacklisted tokens with ErrBlacklistJWT.
type BlacklistChecker interface {
	// IsBlacklist reports true when jwt must be rejected.
	IsBlacklist(jwt string) bool
}
|
||
|
||
// JWTAuthConfig JWT认证配置
|
||
type JWTAuthConfig struct {
|
||
JWT *pkgjwt.JWT
|
||
BlacklistChecker BlacklistChecker
|
||
UseMultipoint bool
|
||
SetRedisJWT func(token, username string) error
|
||
}
|
||
|
||
// JWTAuth JWT认证中间件
|
||
func JWTAuth(cfg JWTAuthConfig) middleware.Middleware {
|
||
return func(handler middleware.Handler) middleware.Handler {
|
||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||
token := extractToken(tr)
|
||
if token == "" {
|
||
return nil, ErrMissingToken
|
||
}
|
||
|
||
// 检查黑名单
|
||
if cfg.BlacklistChecker != nil && cfg.BlacklistChecker.IsBlacklist(token) {
|
||
// 清除token
|
||
if ht, ok := tr.(http.Transporter); ok {
|
||
clearToken(ht)
|
||
}
|
||
return nil, ErrBlacklistJWT
|
||
}
|
||
|
||
claims, err := cfg.JWT.ParseToken(token)
|
||
if err != nil {
|
||
if ht, ok := tr.(http.Transporter); ok {
|
||
clearToken(ht)
|
||
}
|
||
if errors.Is(err, pkgjwt.ErrTokenExpired) {
|
||
return nil, errors.New("登录已过期,请重新登录")
|
||
}
|
||
return nil, ErrInvalidToken
|
||
}
|
||
|
||
// 将claims存入metadata
|
||
claimsJSON, _ := json.Marshal(claims)
|
||
ctx = metadata.AppendToClientContext(ctx,
|
||
mdKeyUserID, strconv.FormatUint(uint64(claims.BaseClaims.ID), 10),
|
||
mdKeyUsername, claims.Username,
|
||
mdKeyAuthorityID, strconv.FormatUint(uint64(claims.AuthorityID), 10),
|
||
mdKeyUUID, claims.UUID,
|
||
mdKeyClaims, string(claimsJSON),
|
||
)
|
||
|
||
// 检查是否需要刷新token
|
||
if cfg.JWT.NeedRefresh(claims) {
|
||
newToken, newClaims, err := cfg.JWT.CreateTokenByOldToken(token, *claims)
|
||
if err == nil {
|
||
if ht, ok := tr.(http.Transporter); ok {
|
||
ht.ReplyHeader().Set("new-token", newToken)
|
||
ht.ReplyHeader().Set("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
|
||
setToken(ht, newToken, int(time.Until(newClaims.ExpiresAt.Time).Seconds()/60))
|
||
}
|
||
// 多点登录记录新JWT
|
||
if cfg.UseMultipoint && cfg.SetRedisJWT != nil {
|
||
_ = cfg.SetRedisJWT(newToken, newClaims.Username)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return handler(ctx, req)
|
||
}
|
||
}
|
||
}
|
||
|
||
// JWTAuthSimple 简化版JWT认证中间件(不含黑名单检查)
|
||
func JWTAuthSimple(j *pkgjwt.JWT) middleware.Middleware {
|
||
return JWTAuth(JWTAuthConfig{JWT: j})
|
||
}
|
||
|
||
// extractToken 从请求中提取token
|
||
func extractToken(tr transport.Transporter) string {
|
||
// 优先从Authorization头获取
|
||
auth := tr.RequestHeader().Get("Authorization")
|
||
if auth != "" {
|
||
if strings.HasPrefix(auth, "Bearer ") {
|
||
return strings.TrimPrefix(auth, "Bearer ")
|
||
}
|
||
return auth
|
||
}
|
||
|
||
// 从x-token头获取
|
||
token := tr.RequestHeader().Get("x-token")
|
||
if token != "" {
|
||
return token
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
// clearToken 清除token
|
||
func clearToken(ht http.Transporter) {
|
||
ht.ReplyHeader().Set("x-token", "")
|
||
}
|
||
|
||
// setToken 设置token到cookie
|
||
func setToken(ht http.Transporter, token string, maxAge int) {
|
||
// 通过header设置cookie信息,让前端处理
|
||
ht.ReplyHeader().Set("x-token", token)
|
||
}
|
||
|
||
// GetClaims 从context获取claims
|
||
func GetClaims(ctx context.Context) (*pkgjwt.CustomClaims, bool) {
|
||
if md, ok := metadata.FromServerContext(ctx); ok {
|
||
claimsStr := md.Get(mdKeyClaims)
|
||
if claimsStr != "" {
|
||
var claims pkgjwt.CustomClaims
|
||
if err := json.Unmarshal([]byte(claimsStr), &claims); err == nil {
|
||
return &claims, true
|
||
}
|
||
}
|
||
}
|
||
return nil, false
|
||
}
|
||
|
||
// GetUserID 从context获取用户ID
|
||
func GetUserID(ctx context.Context) uint {
|
||
if md, ok := metadata.FromServerContext(ctx); ok {
|
||
if idStr := md.Get(mdKeyUserID); idStr != "" {
|
||
if id, err := strconv.ParseUint(idStr, 10, 64); err == nil {
|
||
return uint(id)
|
||
}
|
||
}
|
||
}
|
||
return 0
|
||
}
|
||
|
||
// GetUsername 从context获取用户名
|
||
func GetUsername(ctx context.Context) string {
|
||
if md, ok := metadata.FromServerContext(ctx); ok {
|
||
return md.Get(mdKeyUsername)
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// GetAuthorityID 从context获取角色ID
|
||
func GetAuthorityID(ctx context.Context) uint {
|
||
if md, ok := metadata.FromServerContext(ctx); ok {
|
||
if idStr := md.Get(mdKeyAuthorityID); idStr != "" {
|
||
if id, err := strconv.ParseUint(idStr, 10, 64); err == nil {
|
||
return uint(id)
|
||
}
|
||
}
|
||
}
|
||
return 0
|
||
}
|
||
|
||
// GetToken 从context获取token
|
||
func GetToken(ctx context.Context) string {
|
||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||
return extractToken(tr)
|
||
}
|
||
return ""
|
||
}
|