任务二完成

This commit is contained in:
Yvan 2026-01-07 10:35:49 +08:00
parent d572a03654
commit 7d27c70453
11 changed files with 968 additions and 1 deletions

View File

@ -0,0 +1,64 @@
package middleware
import (
"context"
"errors"
"strconv"
"strings"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
pkgcasbin "kra/pkg/casbin"
)
var (
ErrPermissionDenied = errors.New("权限不足")
)
// CasbinRBAC Casbin权限中间件
func CasbinRBAC(routerPrefix string) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
claims, ok := GetClaims(ctx)
if !ok {
return nil, ErrMissingToken
}
if tr, ok := transport.FromServerContext(ctx); ok {
// 获取请求路径
path := tr.Operation()
if routerPrefix != "" {
path = strings.TrimPrefix(path, routerPrefix)
}
// 获取请求方法
act := "GET"
if header := tr.RequestHeader(); header != nil {
if method := header.Get(":method"); method != "" {
act = method
}
}
// 获取用户角色
sub := strconv.Itoa(int(claims.AuthorityID))
// 检查权限
enforcer := pkgcasbin.GetEnforcer()
if enforcer == nil {
return handler(ctx, req)
}
success, err := enforcer.Enforce(sub, path, act)
if err != nil {
return nil, err
}
if !success {
return nil, ErrPermissionDenied
}
}
return handler(ctx, req)
}
}
}

View File

@ -0,0 +1,107 @@
package middleware
import (
"context"
"net/http"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
)
// CORSWhitelist CORS白名单配置
type CORSWhitelist struct {
AllowOrigin string
AllowHeaders string
AllowMethods string
ExposeHeaders string
AllowCredentials bool
}
// CORSConfig CORS配置
type CORSConfig struct {
// Mode 模式: allow-all, whitelist, strict-whitelist
Mode string
Whitelist []CORSWhitelist
}
// CORS 跨域中间件(放行所有)
func CORS() middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
if tr, ok := transport.FromServerContext(ctx); ok {
if ht, ok := tr.(kratoshttp.Transporter); ok {
origin := ht.Request().Header.Get("Origin")
header := ht.ReplyHeader()
header.Set("Access-Control-Allow-Origin", origin)
header.Set("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token,Authorization,Token,X-Token,X-User-Id")
header.Set("Access-Control-Allow-Methods", "POST,GET,OPTIONS,DELETE,PUT")
header.Set("Access-Control-Expose-Headers", "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type,New-Token,New-Expires-At")
header.Set("Access-Control-Allow-Credentials", "true")
// OPTIONS请求直接返回
if ht.Request().Method == http.MethodOptions {
return nil, nil
}
}
}
return handler(ctx, req)
}
}
}
// CORSByRules 按配置规则处理跨域
func CORSByRules(cfg CORSConfig) middleware.Middleware {
// 放行全部
if cfg.Mode == "allow-all" {
return CORS()
}
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
if tr, ok := transport.FromServerContext(ctx); ok {
if ht, ok := tr.(kratoshttp.Transporter); ok {
r := ht.Request()
origin := r.Header.Get("Origin")
whitelist := checkCors(origin, cfg.Whitelist)
header := ht.ReplyHeader()
// 通过检查,添加请求头
if whitelist != nil {
header.Set("Access-Control-Allow-Origin", whitelist.AllowOrigin)
header.Set("Access-Control-Allow-Headers", whitelist.AllowHeaders)
header.Set("Access-Control-Allow-Methods", whitelist.AllowMethods)
header.Set("Access-Control-Expose-Headers", whitelist.ExposeHeaders)
if whitelist.AllowCredentials {
header.Set("Access-Control-Allow-Credentials", "true")
}
}
// 严格白名单模式且未通过检查,直接拒绝
if whitelist == nil && cfg.Mode == "strict-whitelist" {
// 健康检查放行
if !(r.Method == http.MethodGet && r.URL.Path == "/health") {
return nil, errors.Forbidden("CORS_FORBIDDEN", "forbidden")
}
}
// OPTIONS请求直接返回
if r.Method == http.MethodOptions {
return nil, nil
}
}
}
return handler(ctx, req)
}
}
}
func checkCors(currentOrigin string, whitelist []CORSWhitelist) *CORSWhitelist {
for _, w := range whitelist {
if currentOrigin == w.AllowOrigin {
return &w
}
}
return nil
}

View File

@ -0,0 +1,78 @@
package middleware
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
)
// EmailSender 邮件发送接口
type EmailSender interface {
SendError(subject, body string) error
}
// ErrorToEmail 错误邮件通知中间件
func ErrorToEmail(sender EmailSender, logger log.Logger) middleware.Middleware {
helper := log.NewHelper(logger)
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
var (
username string
body string
method string
path string
ip string
)
// 获取用户信息
if claims, ok := GetClaims(ctx); ok {
username = claims.Username
}
if username == "" {
username = "Unknown"
}
if tr, ok := transport.FromServerContext(ctx); ok {
path = tr.Operation()
if header := tr.RequestHeader(); header != nil {
method = header.Get(":method")
}
if ht, ok := tr.(kratoshttp.Transporter); ok {
r := ht.Request()
ip = getClientIP(r)
// 读取body
bodyBytes, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
body = string(bodyBytes)
}
}
start := time.Now()
reply, err := handler(ctx, req)
latency := time.Since(start)
// 如果有错误,发送邮件
if err != nil && sender != nil {
subject := fmt.Sprintf("%s %s调用了%s报错了", username, ip, path)
content := fmt.Sprintf(
"接收到的请求为%s\n请求方式为%s\n报错信息如下%s\n耗时%s\n",
body, method, err.Error(), latency.String(),
)
if sendErr := sender.SendError(subject, content); sendErr != nil {
helper.Error("ErrorToEmail Failed, err:", sendErr)
}
}
return reply, err
}
}
}

View File

@ -0,0 +1,156 @@
package middleware
import (
"context"
"errors"
"strconv"
"strings"
"time"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
"github.com/go-kratos/kratos/v2/transport/http"
pkgjwt "kra/pkg/jwt"
)
type authKey struct{}
var (
ErrMissingToken = errors.New("未登录或非法访问")
ErrInvalidToken = errors.New("无效的token")
ErrBlacklistJWT = errors.New("您的帐户异地登陆或令牌失效")
)
// BlacklistChecker JWT黑名单检查接口
type BlacklistChecker interface {
IsBlacklist(jwt string) bool
}
// JWTAuthConfig JWT认证配置
type JWTAuthConfig struct {
JWT *pkgjwt.JWT
BlacklistChecker BlacklistChecker
UseMultipoint bool
SetRedisJWT func(token, username string) error
}
// JWTAuth JWT认证中间件
func JWTAuth(cfg JWTAuthConfig) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
if tr, ok := transport.FromServerContext(ctx); ok {
token := extractToken(tr)
if token == "" {
return nil, ErrMissingToken
}
// 检查黑名单
if cfg.BlacklistChecker != nil && cfg.BlacklistChecker.IsBlacklist(token) {
// 清除token
if ht, ok := tr.(http.Transporter); ok {
clearToken(ht)
}
return nil, ErrBlacklistJWT
}
claims, err := cfg.JWT.ParseToken(token)
if err != nil {
if ht, ok := tr.(http.Transporter); ok {
clearToken(ht)
}
if errors.Is(err, pkgjwt.ErrTokenExpired) {
return nil, errors.New("登录已过期,请重新登录")
}
return nil, ErrInvalidToken
}
// 将claims存入context
ctx = context.WithValue(ctx, authKey{}, claims)
// 检查是否需要刷新token
if cfg.JWT.NeedRefresh(claims) {
newToken, newClaims, err := cfg.JWT.CreateTokenByOldToken(token, *claims)
if err == nil {
if ht, ok := tr.(http.Transporter); ok {
ht.ReplyHeader().Set("new-token", newToken)
ht.ReplyHeader().Set("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
setToken(ht, newToken, int(time.Until(newClaims.ExpiresAt.Time).Seconds()/60))
}
// 多点登录记录新JWT
if cfg.UseMultipoint && cfg.SetRedisJWT != nil {
_ = cfg.SetRedisJWT(newToken, newClaims.Username)
}
}
}
}
return handler(ctx, req)
}
}
}
// JWTAuthSimple 简化版JWT认证中间件不含黑名单检查
func JWTAuthSimple(j *pkgjwt.JWT) middleware.Middleware {
return JWTAuth(JWTAuthConfig{JWT: j})
}
// extractToken 从请求中提取token
func extractToken(tr transport.Transporter) string {
// 优先从Authorization头获取
auth := tr.RequestHeader().Get("Authorization")
if auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer ")
}
return auth
}
// 从x-token头获取
token := tr.RequestHeader().Get("x-token")
if token != "" {
return token
}
return ""
}
// clearToken 清除token
func clearToken(ht http.Transporter) {
ht.ReplyHeader().Set("x-token", "")
}
// setToken 设置token到cookie
func setToken(ht http.Transporter, token string, maxAge int) {
// 通过header设置cookie信息让前端处理
ht.ReplyHeader().Set("x-token", token)
}
// GetClaims 从context获取claims
func GetClaims(ctx context.Context) (*pkgjwt.CustomClaims, bool) {
claims, ok := ctx.Value(authKey{}).(*pkgjwt.CustomClaims)
return claims, ok
}
// GetUserID 从context获取用户ID
func GetUserID(ctx context.Context) uint {
if claims, ok := GetClaims(ctx); ok {
return claims.BaseClaims.ID
}
return 0
}
// GetUsername 从context获取用户名
func GetUsername(ctx context.Context) string {
if claims, ok := GetClaims(ctx); ok {
return claims.Username
}
return ""
}
// GetAuthorityID 从context获取角色ID
func GetAuthorityID(ctx context.Context) uint {
if claims, ok := GetClaims(ctx); ok {
return claims.AuthorityID
}
return 0
}

View File

@ -0,0 +1,133 @@
package middleware
import (
"bytes"
"context"
"encoding/json"
"io"
"time"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
)
// LogLayout 日志layout
type LogLayout struct {
Time time.Time `json:"time"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Path string `json:"path"`
Query string `json:"query,omitempty"`
Body string `json:"body,omitempty"`
IP string `json:"ip"`
UserAgent string `json:"user_agent"`
Error string `json:"error,omitempty"`
Cost time.Duration `json:"cost"`
Source string `json:"source"`
}
// LoggerConfig 日志中间件配置
type LoggerConfig struct {
// Filter 用户自定义过滤
Filter func(ctx context.Context, path string) bool
// FilterKeyword 关键字过滤
FilterKeyword func(layout *LogLayout) bool
// AuthProcess 鉴权处理
AuthProcess func(ctx context.Context, layout *LogLayout)
// Print 日志处理
Print func(LogLayout)
// Source 服务唯一标识
Source string
}
// Logger 日志中间件
func Logger(cfg LoggerConfig) middleware.Middleware {
if cfg.Print == nil {
cfg.Print = func(layout LogLayout) {
v, _ := json.Marshal(layout)
log.Info(string(v))
}
}
if cfg.Source == "" {
cfg.Source = "Kratos"
}
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
var (
path string
query string
body string
ip string
userAgent string
)
if tr, ok := transport.FromServerContext(ctx); ok {
path = tr.Operation()
if header := tr.RequestHeader(); header != nil {
userAgent = header.Get("User-Agent")
}
if ht, ok := tr.(kratoshttp.Transporter); ok {
r := ht.Request()
query = r.URL.RawQuery
ip = getClientIP(r)
// 读取body仅在未过滤时
if cfg.Filter == nil || !cfg.Filter(ctx, path) {
bodyBytes, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
body = string(bodyBytes)
}
}
}
// 检查是否需要过滤
if cfg.Filter != nil && cfg.Filter(ctx, path) {
return handler(ctx, req)
}
start := time.Now()
reply, err := handler(ctx, req)
cost := time.Since(start)
layout := LogLayout{
Time: time.Now(),
Path: path,
Query: query,
Body: body,
IP: ip,
UserAgent: userAgent,
Cost: cost,
Source: cfg.Source,
}
if err != nil {
layout.Error = err.Error()
}
// 处理鉴权信息
if cfg.AuthProcess != nil {
cfg.AuthProcess(ctx, &layout)
}
// 关键字过滤
if cfg.FilterKeyword != nil {
cfg.FilterKeyword(&layout)
}
// 输出日志
cfg.Print(layout)
return reply, err
}
}
}
// DefaultLogger 默认日志中间件
func DefaultLogger() middleware.Middleware {
return Logger(LoggerConfig{
Source: "Kratos",
})
}

View File

@ -0,0 +1,155 @@
package middleware
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
"gorm.io/gorm"
)
const maxBodySize = 1024
// OperationRecord 操作记录
type OperationRecord struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time `gorm:"index"`
Ip string `gorm:"column:ip;comment:请求ip"`
Method string `gorm:"column:method;comment:请求方法"`
Path string `gorm:"column:path;comment:请求路径"`
Status int `gorm:"column:status;comment:请求状态"`
Latency time.Duration `gorm:"column:latency;comment:延迟"`
Agent string `gorm:"column:agent;comment:代理"`
ErrorMessage string `gorm:"column:error_message;comment:错误信息"`
Body string `gorm:"type:text;column:body;comment:请求Body"`
Resp string `gorm:"type:text;column:resp;comment:响应Body"`
UserID int `gorm:"column:user_id;comment:用户id"`
}
func (OperationRecord) TableName() string {
return "sys_operation_records"
}
// OperationLog 操作日志中间件
func OperationLog(db *gorm.DB, logger log.Logger) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
var (
body string
userID int
ip string
method string
path string
agent string
)
if tr, ok := transport.FromServerContext(ctx); ok {
path = tr.Operation()
if header := tr.RequestHeader(); header != nil {
method = header.Get(":method")
agent = header.Get("User-Agent")
}
// 获取HTTP请求信息
if ht, ok := tr.(kratoshttp.Transporter); ok {
r := ht.Request()
ip = getClientIP(r)
if r.Method != http.MethodGet {
bodyBytes, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if len(bodyBytes) > maxBodySize {
body = "[超出记录长度]"
} else {
body = string(bodyBytes)
}
} else {
query := r.URL.RawQuery
query, _ = url.QueryUnescape(query)
params := parseQuery(query)
bodyBytes, _ := json.Marshal(params)
body = string(bodyBytes)
}
// 检查是否为文件上传
if strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
body = "[文件]"
}
}
}
// 获取用户ID
if claims, ok := GetClaims(ctx); ok {
userID = int(claims.BaseClaims.ID)
} else {
// 尝试从header获取
if tr, ok := transport.FromServerContext(ctx); ok {
if ht, ok := tr.(kratoshttp.Transporter); ok {
if idStr := ht.Request().Header.Get("x-user-id"); idStr != "" {
if id, err := strconv.Atoi(idStr); err == nil {
userID = id
}
}
}
}
}
start := time.Now()
reply, err := handler(ctx, req)
latency := time.Since(start)
// 记录响应
var resp string
if reply != nil {
respBytes, _ := json.Marshal(reply)
if len(respBytes) > maxBodySize {
resp = "[超出记录长度]"
} else {
resp = string(respBytes)
}
}
// 记录错误
var errMsg string
status := 200
if err != nil {
errMsg = err.Error()
status = 500
}
// 保存记录
record := OperationRecord{
Ip: ip,
Method: method,
Path: path,
Status: status,
Latency: latency,
Agent: agent,
ErrorMessage: errMsg,
Body: body,
Resp: resp,
UserID: userID,
}
if db != nil {
go func() {
if err := db.Create(&record).Error; err != nil {
log.NewHelper(logger).Error("create operation record error:", err)
}
}()
}
return reply, err
}
}
}

View File

@ -0,0 +1,108 @@
package middleware
import (
"context"
"fmt"
"time"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
"github.com/redis/go-redis/v9"
)
// LimitConfig 限流配置
type LimitConfig struct {
// GenerationKey 根据业务生成key
GenerationKey func(ctx context.Context, path string) string
// CheckOrMark 检查函数
CheckOrMark func(ctx context.Context, key string, expire int, limit int) error
// Expire key过期时间
Expire int
// Limit 周期内限制次数
Limit int
}
// RateLimit 限流中间件
func RateLimit(cfg LimitConfig) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
var key string
if tr, ok := transport.FromServerContext(ctx); ok {
path := tr.Operation()
key = cfg.GenerationKey(ctx, path)
}
if err := cfg.CheckOrMark(ctx, key, cfg.Expire, cfg.Limit); err != nil {
return nil, errors.New(429, "RATE_LIMIT", err.Error())
}
return handler(ctx, req)
}
}
}
// DefaultGenerationKey 默认生成key基于IP
func DefaultGenerationKey(ctx context.Context, path string) string {
if tr, ok := transport.FromServerContext(ctx); ok {
if ht, ok := tr.(kratoshttp.Transporter); ok {
return "Limit_" + getClientIP(ht.Request())
}
}
return "Limit_unknown"
}
// NewRedisCheckOrMark 创建基于Redis的检查函数
func NewRedisCheckOrMark(rdb *redis.Client) func(ctx context.Context, key string, expire int, limit int) error {
return func(ctx context.Context, key string, expire int, limit int) error {
if rdb == nil {
return nil
}
return SetLimitWithTime(ctx, rdb, key, limit, time.Duration(expire)*time.Second)
}
}
// SetLimitWithTime 设置访问次数
func SetLimitWithTime(ctx context.Context, rdb *redis.Client, key string, limit int, expiration time.Duration) error {
count, err := rdb.Exists(ctx, key).Result()
if err != nil {
return err
}
if count == 0 {
// key不存在创建并设置过期时间
pipe := rdb.TxPipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, expiration)
_, err = pipe.Exec(ctx)
return err
}
// key存在检查次数
times, err := rdb.Get(ctx, key).Int()
if err != nil {
return err
}
if times >= limit {
// 获取剩余时间
ttl, err := rdb.PTTL(ctx, key).Result()
if err != nil {
return fmt.Errorf("请求太过频繁,请稍后再试")
}
return fmt.Errorf("请求太过频繁, 请 %s 后尝试", ttl.String())
}
return rdb.Incr(ctx, key).Err()
}
// DefaultRateLimit 默认限流中间件
func DefaultRateLimit(rdb *redis.Client, expire, limit int) middleware.Middleware {
return RateLimit(LimitConfig{
GenerationKey: DefaultGenerationKey,
CheckOrMark: NewRedisCheckOrMark(rdb),
Expire: expire,
Limit: limit,
})
}

View File

@ -0,0 +1,55 @@
package middleware
import (
"context"
"net"
"os"
"runtime/debug"
"strings"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
)
// Recovery panic恢复中间件
func Recovery(logger log.Logger) middleware.Middleware {
helper := log.NewHelper(logger)
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
defer func() {
if rerr := recover(); rerr != nil {
// 检查是否为断开的连接
var brokenPipe bool
if ne, ok := rerr.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
errStr := strings.ToLower(se.Error())
if strings.Contains(errStr, "broken pipe") ||
strings.Contains(errStr, "connection reset by peer") {
brokenPipe = true
}
}
}
stack := string(debug.Stack())
if brokenPipe {
helper.Errorw(
"msg", "[Recovery from panic]",
"error", rerr,
"broken_pipe", true,
)
} else {
helper.Errorw(
"msg", "[Recovery from panic]",
"error", rerr,
"stack", stack,
)
}
err = errors.InternalServer("PANIC", "服务器内部错误")
}
}()
return handler(ctx, req)
}
}
}

View File

@ -0,0 +1,53 @@
package middleware
import (
"context"
"time"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware"
)
// Timeout 超时中间件
func Timeout(timeout time.Duration) middleware.Middleware {
return func(handler middleware.Handler) middleware.Handler {
return func(ctx context.Context, req interface{}) (interface{}, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// 使用 buffered channel 避免 goroutine 泄漏
type result struct {
reply interface{}
err error
}
done := make(chan result, 1)
panicChan := make(chan interface{}, 1)
go func() {
defer func() {
if p := recover(); p != nil {
select {
case panicChan <- p:
default:
}
}
}()
reply, err := handler(ctx, req)
select {
case done <- result{reply: reply, err: err}:
default:
}
}()
select {
case p := <-panicChan:
panic(p)
case r := <-done:
return r.reply, r.err
case <-ctx.Done():
return nil, errors.GatewayTimeout("TIMEOUT", "请求超时")
}
}
}
}

View File

@ -0,0 +1,43 @@
package middleware
import (
"net/http"
"strings"
)
// getClientIP 获取客户端IP
func getClientIP(r *http.Request) string {
// X-Forwarded-For
xff := r.Header.Get("X-Forwarded-For")
if xff != "" {
ips := strings.Split(xff, ",")
if len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
// X-Real-IP
xri := r.Header.Get("X-Real-IP")
if xri != "" {
return xri
}
// RemoteAddr
ip := r.RemoteAddr
if idx := strings.LastIndex(ip, ":"); idx != -1 {
ip = ip[:idx]
}
return ip
}
// parseQuery 解析query参数
func parseQuery(query string) map[string]string {
m := make(map[string]string)
for _, v := range strings.Split(query, "&") {
kv := strings.Split(v, "=")
if len(kv) == 2 {
m[kv[0]] = kv[1]
}
}
return m
}

View File

@ -103,5 +103,20 @@ func (j *JWT) ParseToken(tokenString string) (*CustomClaims, error) {
// NeedRefresh 判断是否需要刷新token
func (j *JWT) NeedRefresh(claims *CustomClaims) bool {
return time.Until(claims.ExpiresAt.Time) < j.BufferTime
return claims.ExpiresAt.Unix()-time.Now().Unix() < claims.BufferTime
}
// CreateTokenByOldToken 根据旧token创建新token
func (j *JWT) CreateTokenByOldToken(oldToken string, oldClaims CustomClaims) (string, *CustomClaims, error) {
// 更新过期时间
oldClaims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(j.ExpiresAt))
oldClaims.IssuedAt = jwt.NewNumericDate(time.Now())
oldClaims.NotBefore = jwt.NewNumericDate(time.Now().Add(-1000))
token := jwt.NewWithClaims(jwt.SigningMethodHS256, oldClaims)
newToken, err := token.SignedString(j.SigningKey)
if err != nil {
return "", nil, err
}
return newToken, &oldClaims, nil
}