diff --git a/internal/server/middleware/casbin.go b/internal/server/middleware/casbin.go new file mode 100644 index 0000000..074ce38 --- /dev/null +++ b/internal/server/middleware/casbin.go @@ -0,0 +1,64 @@ +package middleware + +import ( + "context" + "errors" + "strconv" + "strings" + + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" + + pkgcasbin "kra/pkg/casbin" +) + +var ( + ErrPermissionDenied = errors.New("权限不足") +) + +// CasbinRBAC Casbin权限中间件 +func CasbinRBAC(routerPrefix string) middleware.Middleware { + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + claims, ok := GetClaims(ctx) + if !ok { + return nil, ErrMissingToken + } + + if tr, ok := transport.FromServerContext(ctx); ok { + // 获取请求路径 + path := tr.Operation() + if routerPrefix != "" { + path = strings.TrimPrefix(path, routerPrefix) + } + + // 获取请求方法 + act := "GET" + if header := tr.RequestHeader(); header != nil { + if method := header.Get(":method"); method != "" { + act = method + } + } + + // 获取用户角色 + sub := strconv.Itoa(int(claims.AuthorityID)) + + // 检查权限 + enforcer := pkgcasbin.GetEnforcer() + if enforcer == nil { + return handler(ctx, req) + } + + success, err := enforcer.Enforce(sub, path, act) + if err != nil { + return nil, err + } + if !success { + return nil, ErrPermissionDenied + } + } + + return handler(ctx, req) + } + } +} diff --git a/internal/server/middleware/cors.go b/internal/server/middleware/cors.go new file mode 100644 index 0000000..71761c8 --- /dev/null +++ b/internal/server/middleware/cors.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" + kratoshttp "github.com/go-kratos/kratos/v2/transport/http" +) + +// CORSWhitelist CORS白名单配置 +type CORSWhitelist struct { + AllowOrigin string + AllowHeaders string + AllowMethods string + ExposeHeaders string + AllowCredentials bool +} + +// CORSConfig CORS配置 +type CORSConfig struct { + // Mode 模式: allow-all, whitelist, strict-whitelist + Mode string + Whitelist []CORSWhitelist +} + +// CORS 跨域中间件(放行所有) +func CORS() middleware.Middleware { + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + if tr, ok := transport.FromServerContext(ctx); ok { + if ht, ok := tr.(kratoshttp.Transporter); ok { + origin := ht.Request().Header.Get("Origin") + header := ht.ReplyHeader() + header.Set("Access-Control-Allow-Origin", origin) + header.Set("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token,Authorization,Token,X-Token,X-User-Id") + header.Set("Access-Control-Allow-Methods", "POST,GET,OPTIONS,DELETE,PUT") + header.Set("Access-Control-Expose-Headers", "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type,New-Token,New-Expires-At") + header.Set("Access-Control-Allow-Credentials", "true") + + // OPTIONS请求直接返回 + if ht.Request().Method == http.MethodOptions { + return nil, nil + } + } + } + return handler(ctx, req) + } + } +} + +// CORSByRules 按配置规则处理跨域 +func CORSByRules(cfg CORSConfig) middleware.Middleware { + // 放行全部 + if cfg.Mode == "allow-all" { + return CORS() + } + + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + if tr, ok := transport.FromServerContext(ctx); ok { + if ht, ok := tr.(kratoshttp.Transporter); ok { + r := ht.Request() + origin := r.Header.Get("Origin") + whitelist := checkCors(origin, cfg.Whitelist) + + header := ht.ReplyHeader() + // 通过检查,添加请求头 + if whitelist != nil { + header.Set("Access-Control-Allow-Origin", whitelist.AllowOrigin) + header.Set("Access-Control-Allow-Headers", whitelist.AllowHeaders) + header.Set("Access-Control-Allow-Methods", whitelist.AllowMethods) + header.Set("Access-Control-Expose-Headers", whitelist.ExposeHeaders) + if whitelist.AllowCredentials { + header.Set("Access-Control-Allow-Credentials", "true") + } + } + + // 严格白名单模式且未通过检查,直接拒绝 + if whitelist == nil && cfg.Mode == "strict-whitelist" { + // 健康检查放行 + if !(r.Method == http.MethodGet && r.URL.Path == "/health") { + return nil, errors.Forbidden("CORS_FORBIDDEN", "forbidden") + } + } + + // OPTIONS请求直接返回 + if r.Method == http.MethodOptions { + return nil, nil + } + } + } + return handler(ctx, req) + } + } +} + +func checkCors(currentOrigin string, whitelist []CORSWhitelist) *CORSWhitelist { + for _, w := range whitelist { + if currentOrigin == w.AllowOrigin { + return &w + } + } + return nil +} diff --git a/internal/server/middleware/email.go b/internal/server/middleware/email.go new file mode 100644 index 0000000..08a1c58 --- /dev/null +++ b/internal/server/middleware/email.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "bytes" + "context" + "fmt" + "io" + "time" + + "github.com/go-kratos/kratos/v2/log" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" + kratoshttp "github.com/go-kratos/kratos/v2/transport/http" +) + +// EmailSender 邮件发送接口 +type EmailSender interface { + SendError(subject, body string) error +} + +// ErrorToEmail 错误邮件通知中间件 +func ErrorToEmail(sender EmailSender, logger log.Logger) middleware.Middleware { + helper := log.NewHelper(logger) + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + var ( + username string + body string + method string + path string + ip string + ) + + // 获取用户信息 + if claims, ok := GetClaims(ctx); ok { + username = claims.Username + } + if username == "" { + username = "Unknown" + } + + if tr, ok := transport.FromServerContext(ctx); ok { + path = tr.Operation() + if header := tr.RequestHeader(); header != nil { + method = header.Get(":method") + } + + if ht, ok := tr.(kratoshttp.Transporter); ok { + r := ht.Request() + ip = getClientIP(r) + + // 读取body + bodyBytes, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + body = string(bodyBytes) + } + } + + start := time.Now() + reply, err := handler(ctx, req) + latency := time.Since(start) + + // 如果有错误,发送邮件 + if err != nil && sender != nil { + subject := fmt.Sprintf("%s %s调用了%s报错了", username, ip, path) + content := fmt.Sprintf( + "接收到的请求为%s\n请求方式为%s\n报错信息如下%s\n耗时%s\n", + body, method, err.Error(), latency.String(), + ) + if sendErr := sender.SendError(subject, content); sendErr != nil { + helper.Error("ErrorToEmail Failed, err:", sendErr) + } + } + + return reply, err + } + } +} diff --git a/internal/server/middleware/jwt.go b/internal/server/middleware/jwt.go new file mode 100644 index 0000000..d4a8c96 --- /dev/null +++ b/internal/server/middleware/jwt.go @@ -0,0 +1,156 @@ +package middleware + +import ( + "context" + "errors" + "strconv" + "strings" + "time" + + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" + "github.com/go-kratos/kratos/v2/transport/http" + + pkgjwt "kra/pkg/jwt" +) + +type authKey struct{} + +var ( + ErrMissingToken = errors.New("未登录或非法访问") + ErrInvalidToken = errors.New("无效的token") + ErrBlacklistJWT = errors.New("您的帐户异地登陆或令牌失效") +) + +// BlacklistChecker JWT黑名单检查接口 +type BlacklistChecker interface { + IsBlacklist(jwt string) bool +} + +// JWTAuthConfig JWT认证配置 +type JWTAuthConfig struct { + JWT *pkgjwt.JWT + BlacklistChecker BlacklistChecker + UseMultipoint bool + SetRedisJWT func(token, username string) error +} + +// JWTAuth JWT认证中间件 +func JWTAuth(cfg JWTAuthConfig) middleware.Middleware { + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + if tr, ok := transport.FromServerContext(ctx); ok { + token := extractToken(tr) + if token == "" { + return nil, ErrMissingToken + } + + // 检查黑名单 + if cfg.BlacklistChecker != nil && cfg.BlacklistChecker.IsBlacklist(token) { + // 清除token + if ht, ok := tr.(http.Transporter); ok { + clearToken(ht) + } + return nil, ErrBlacklistJWT + } + + claims, err := cfg.JWT.ParseToken(token) + if err != nil { + if ht, ok := tr.(http.Transporter); ok { + clearToken(ht) + } + if errors.Is(err, pkgjwt.ErrTokenExpired) { + return nil, errors.New("登录已过期,请重新登录") + } + return nil, ErrInvalidToken + } + + // 将claims存入context + ctx = context.WithValue(ctx, authKey{}, claims) + + // 检查是否需要刷新token + if cfg.JWT.NeedRefresh(claims) { + newToken, newClaims, err := cfg.JWT.CreateTokenByOldToken(token, *claims) + if err == nil { + if ht, ok := tr.(http.Transporter); ok { + ht.ReplyHeader().Set("new-token", newToken) + ht.ReplyHeader().Set("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10)) + setToken(ht, newToken, int(time.Until(newClaims.ExpiresAt.Time).Seconds()/60)) + } + // 多点登录记录新JWT + if cfg.UseMultipoint && cfg.SetRedisJWT != nil { + _ = cfg.SetRedisJWT(newToken, newClaims.Username) + } + } + } + } + return handler(ctx, req) + } + } +} + +// JWTAuthSimple 简化版JWT认证中间件(不含黑名单检查) +func JWTAuthSimple(j *pkgjwt.JWT) middleware.Middleware { + return JWTAuth(JWTAuthConfig{JWT: j}) +} + +// extractToken 从请求中提取token +func extractToken(tr transport.Transporter) string { + // 优先从Authorization头获取 + auth := tr.RequestHeader().Get("Authorization") + if auth != "" { + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + return auth + } + + // 从x-token头获取 + token := tr.RequestHeader().Get("x-token") + if token != "" { + return token + } + + return "" +} + +// clearToken 清除token +func clearToken(ht http.Transporter) { + ht.ReplyHeader().Set("x-token", "") +} + +// setToken 设置token到cookie +func setToken(ht http.Transporter, token string, maxAge int) { + // 通过header设置cookie信息,让前端处理 + ht.ReplyHeader().Set("x-token", token) +} + +// GetClaims 从context获取claims +func GetClaims(ctx context.Context) (*pkgjwt.CustomClaims, bool) { + claims, ok := ctx.Value(authKey{}).(*pkgjwt.CustomClaims) + return claims, ok +} + +// GetUserID 从context获取用户ID +func GetUserID(ctx context.Context) uint { + if claims, ok := GetClaims(ctx); ok { + return claims.BaseClaims.ID + } + return 0 +} + +// GetUsername 从context获取用户名 +func GetUsername(ctx context.Context) string { + if claims, ok := GetClaims(ctx); ok { + return claims.Username + } + return "" +} + +// GetAuthorityID 从context获取角色ID +func GetAuthorityID(ctx context.Context) uint { + if claims, ok := GetClaims(ctx); ok { + return claims.AuthorityID + } + return 0 +} diff --git a/internal/server/middleware/logger.go b/internal/server/middleware/logger.go new file mode 100644 index 0000000..77aaa77 --- /dev/null +++ b/internal/server/middleware/logger.go @@ -0,0 +1,133 @@ +package middleware + +import ( + "bytes" + "context" + "encoding/json" + "io" + "time" + + "github.com/go-kratos/kratos/v2/log" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" + kratoshttp "github.com/go-kratos/kratos/v2/transport/http" +) + +// LogLayout 日志layout +type LogLayout struct { + Time time.Time `json:"time"` + Metadata map[string]interface{} `json:"metadata,omitempty"` + Path string `json:"path"` + Query string `json:"query,omitempty"` + Body string `json:"body,omitempty"` + IP string `json:"ip"` + UserAgent string `json:"user_agent"` + Error string `json:"error,omitempty"` + Cost time.Duration `json:"cost"` + Source string `json:"source"` +} + +// LoggerConfig 日志中间件配置 +type LoggerConfig struct { + // Filter 用户自定义过滤 + Filter func(ctx context.Context, path string) bool + // FilterKeyword 关键字过滤 + FilterKeyword func(layout *LogLayout) bool + // AuthProcess 鉴权处理 + AuthProcess func(ctx context.Context, layout *LogLayout) + // Print 日志处理 + Print func(LogLayout) + // Source 服务唯一标识 + Source string +} + +// Logger 日志中间件 +func Logger(cfg LoggerConfig) middleware.Middleware { + if cfg.Print == nil { + cfg.Print = func(layout LogLayout) { + v, _ := json.Marshal(layout) + log.Info(string(v)) + } + } + if cfg.Source == "" { + cfg.Source = "Kratos" + } + + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + var ( + path string + query string + body string + ip string + userAgent string + ) + + if tr, ok := transport.FromServerContext(ctx); ok { + path = tr.Operation() + if header := tr.RequestHeader(); header != nil { + userAgent = header.Get("User-Agent") + } + + if ht, ok := tr.(kratoshttp.Transporter); ok { + r := ht.Request() + query = r.URL.RawQuery + ip = getClientIP(r) + + // 读取body(仅在未过滤时) + if cfg.Filter == nil || !cfg.Filter(ctx, path) { + bodyBytes, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + body = string(bodyBytes) + } + } + } + + // 检查是否需要过滤 + if cfg.Filter != nil && cfg.Filter(ctx, path) { + return handler(ctx, req) + } + + start := time.Now() + reply, err := handler(ctx, req) + cost := time.Since(start) + + layout := LogLayout{ + Time: time.Now(), + Path: path, + Query: query, + Body: body, + IP: ip, + UserAgent: userAgent, + Cost: cost, + Source: cfg.Source, + } + + if err != nil { + layout.Error = err.Error() + } + + // 处理鉴权信息 + if cfg.AuthProcess != nil { + cfg.AuthProcess(ctx, &layout) + } + + // 关键字过滤 + if cfg.FilterKeyword != nil { + cfg.FilterKeyword(&layout) + } + + // 输出日志 + cfg.Print(layout) + + return reply, err + } + } +} + +// DefaultLogger 默认日志中间件 +func DefaultLogger() middleware.Middleware { + return Logger(LoggerConfig{ + Source: "Kratos", + }) +} diff --git a/internal/server/middleware/operation.go b/internal/server/middleware/operation.go new file mode 100644 index 0000000..bc858be --- /dev/null +++ b/internal/server/middleware/operation.go @@ -0,0 +1,155 @@ +package middleware + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/go-kratos/kratos/v2/log" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" + kratoshttp "github.com/go-kratos/kratos/v2/transport/http" + "gorm.io/gorm" +) + +const maxBodySize = 1024 + +// OperationRecord 操作记录 +type OperationRecord struct { + ID uint `gorm:"primarykey"` + CreatedAt time.Time `gorm:"index"` + Ip string `gorm:"column:ip;comment:请求ip"` + Method string `gorm:"column:method;comment:请求方法"` + Path string `gorm:"column:path;comment:请求路径"` + Status int `gorm:"column:status;comment:请求状态"` + Latency time.Duration `gorm:"column:latency;comment:延迟"` + Agent string `gorm:"column:agent;comment:代理"` + ErrorMessage string `gorm:"column:error_message;comment:错误信息"` + Body string `gorm:"type:text;column:body;comment:请求Body"` + Resp string `gorm:"type:text;column:resp;comment:响应Body"` + UserID int `gorm:"column:user_id;comment:用户id"` +} + +func (OperationRecord) TableName() string { + return "sys_operation_records" +} + +// OperationLog 操作日志中间件 +func OperationLog(db *gorm.DB, logger log.Logger) middleware.Middleware { + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + var ( + body string + userID int + ip string + method string + path string + agent string + ) + + if tr, ok := transport.FromServerContext(ctx); ok { + path = tr.Operation() + if header := tr.RequestHeader(); header != nil { + method = header.Get(":method") + agent = header.Get("User-Agent") + } + + // 获取HTTP请求信息 + if ht, ok := tr.(kratoshttp.Transporter); ok { + r := ht.Request() + ip = getClientIP(r) + + if r.Method != http.MethodGet { + bodyBytes, _ := io.ReadAll(r.Body) + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + if len(bodyBytes) > maxBodySize { + body = "[超出记录长度]" + } else { + body = string(bodyBytes) + } + } else { + query := r.URL.RawQuery + query, _ = url.QueryUnescape(query) + params := parseQuery(query) + bodyBytes, _ := json.Marshal(params) + body = string(bodyBytes) + } + + // 检查是否为文件上传 + if strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") { + body = "[文件]" + } + } + } + + // 获取用户ID + if claims, ok := GetClaims(ctx); ok { + userID = int(claims.BaseClaims.ID) + } else { + // 尝试从header获取 + if tr, ok := transport.FromServerContext(ctx); ok { + if ht, ok := tr.(kratoshttp.Transporter); ok { + if idStr := ht.Request().Header.Get("x-user-id"); idStr != "" { + if id, err := strconv.Atoi(idStr); err == nil { + userID = id + } + } + } + } + } + + start := time.Now() + reply, err := handler(ctx, req) + latency := time.Since(start) + + // 记录响应 + var resp string + if reply != nil { + respBytes, _ := json.Marshal(reply) + if len(respBytes) > maxBodySize { + resp = "[超出记录长度]" + } else { + resp = string(respBytes) + } + } + + // 记录错误 + var errMsg string + status := 200 + if err != nil { + errMsg = err.Error() + status = 500 + } + + // 保存记录 + record := OperationRecord{ + Ip: ip, + Method: method, + Path: path, + Status: status, + Latency: latency, + Agent: agent, + ErrorMessage: errMsg, + Body: body, + Resp: resp, + UserID: userID, + } + + if db != nil { + go func() { + if err := db.Create(&record).Error; err != nil { + log.NewHelper(logger).Error("create operation record error:", err) + } + }() + } + + return reply, err + } + } +} diff --git a/internal/server/middleware/ratelimit.go b/internal/server/middleware/ratelimit.go new file mode 100644 index 0000000..99d7b34 --- /dev/null +++ b/internal/server/middleware/ratelimit.go @@ -0,0 +1,108 @@ +package middleware + +import ( + "context" + "fmt" + "time" + + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/middleware" + "github.com/go-kratos/kratos/v2/transport" + kratoshttp "github.com/go-kratos/kratos/v2/transport/http" + "github.com/redis/go-redis/v9" +) + +// LimitConfig 限流配置 +type LimitConfig struct { + // GenerationKey 根据业务生成key + GenerationKey func(ctx context.Context, path string) string + // CheckOrMark 检查函数 + CheckOrMark func(ctx context.Context, key string, expire int, limit int) error + // Expire key过期时间(秒) + Expire int + // Limit 周期内限制次数 + Limit int +} + +// RateLimit 限流中间件 +func RateLimit(cfg LimitConfig) middleware.Middleware { + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + var key string + if tr, ok := transport.FromServerContext(ctx); ok { + path := tr.Operation() + key = cfg.GenerationKey(ctx, path) + } + + if err := cfg.CheckOrMark(ctx, key, cfg.Expire, cfg.Limit); err != nil { + return nil, errors.New(429, "RATE_LIMIT", err.Error()) + } + + return handler(ctx, req) + } + } +} + +// DefaultGenerationKey 默认生成key(基于IP) +func DefaultGenerationKey(ctx context.Context, path string) string { + if tr, ok := transport.FromServerContext(ctx); ok { + if ht, ok := tr.(kratoshttp.Transporter); ok { + return "Limit_" + getClientIP(ht.Request()) + } + } + return "Limit_unknown" +} + +// NewRedisCheckOrMark 创建基于Redis的检查函数 +func NewRedisCheckOrMark(rdb *redis.Client) func(ctx context.Context, key string, expire int, limit int) error { + return func(ctx context.Context, key string, expire int, limit int) error { + if rdb == nil { + return nil + } + return SetLimitWithTime(ctx, rdb, key, limit, time.Duration(expire)*time.Second) + } +} + +// SetLimitWithTime 设置访问次数 +func SetLimitWithTime(ctx context.Context, rdb *redis.Client, key string, limit int, expiration time.Duration) error { + count, err := rdb.Exists(ctx, key).Result() + if err != nil { + return err + } + + if count == 0 { + // key不存在,创建并设置过期时间 + pipe := rdb.TxPipeline() + pipe.Incr(ctx, key) + pipe.Expire(ctx, key, expiration) + _, err = pipe.Exec(ctx) + return err + } + + // key存在,检查次数 + times, err := rdb.Get(ctx, key).Int() + if err != nil { + return err + } + + if times >= limit { + // 获取剩余时间 + ttl, err := rdb.PTTL(ctx, key).Result() + if err != nil { + return fmt.Errorf("请求太过频繁,请稍后再试") + } + return fmt.Errorf("请求太过频繁, 请 %s 后尝试", ttl.String()) + } + + return rdb.Incr(ctx, key).Err() +} + +// DefaultRateLimit 默认限流中间件 +func DefaultRateLimit(rdb *redis.Client, expire, limit int) middleware.Middleware { + return RateLimit(LimitConfig{ + GenerationKey: DefaultGenerationKey, + CheckOrMark: NewRedisCheckOrMark(rdb), + Expire: expire, + Limit: limit, + }) +} diff --git a/internal/server/middleware/recovery.go b/internal/server/middleware/recovery.go new file mode 100644 index 0000000..98910f6 --- /dev/null +++ b/internal/server/middleware/recovery.go @@ -0,0 +1,55 @@ +package middleware + +import ( + "context" + "net" + "os" + "runtime/debug" + "strings" + + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/log" + "github.com/go-kratos/kratos/v2/middleware" +) + +// Recovery panic恢复中间件 +func Recovery(logger log.Logger) middleware.Middleware { + helper := log.NewHelper(logger) + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (reply interface{}, err error) { + defer func() { + if rerr := recover(); rerr != nil { + // 检查是否为断开的连接 + var brokenPipe bool + if ne, ok := rerr.(*net.OpError); ok { + if se, ok := ne.Err.(*os.SyscallError); ok { + errStr := strings.ToLower(se.Error()) + if strings.Contains(errStr, "broken pipe") || + strings.Contains(errStr, "connection reset by peer") { + brokenPipe = true + } + } + } + + stack := string(debug.Stack()) + if brokenPipe { + helper.Errorw( + "msg", "[Recovery from panic]", + "error", rerr, + "broken_pipe", true, + ) + } else { + helper.Errorw( + "msg", "[Recovery from panic]", + "error", rerr, + "stack", stack, + ) + } + + err = errors.InternalServer("PANIC", "服务器内部错误") + } + }() + return handler(ctx, req) + } + } +} diff --git a/internal/server/middleware/timeout.go b/internal/server/middleware/timeout.go new file mode 100644 index 0000000..772cdd1 --- /dev/null +++ b/internal/server/middleware/timeout.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "context" + "time" + + "github.com/go-kratos/kratos/v2/errors" + "github.com/go-kratos/kratos/v2/middleware" +) + +// Timeout 超时中间件 +func Timeout(timeout time.Duration) middleware.Middleware { + return func(handler middleware.Handler) middleware.Handler { + return func(ctx context.Context, req interface{}) (interface{}, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // 使用 buffered channel 避免 goroutine 泄漏 + type result struct { + reply interface{} + err error + } + done := make(chan result, 1) + panicChan := make(chan interface{}, 1) + + go func() { + defer func() { + if p := recover(); p != nil { + select { + case panicChan <- p: + default: + } + } + }() + + reply, err := handler(ctx, req) + select { + case done <- result{reply: reply, err: err}: + default: + } + }() + + select { + case p := <-panicChan: + panic(p) + case r := <-done: + return r.reply, r.err + case <-ctx.Done(): + return nil, errors.GatewayTimeout("TIMEOUT", "请求超时") + } + } + } +} diff --git a/internal/server/middleware/utils.go b/internal/server/middleware/utils.go new file mode 100644 index 0000000..62de1bb --- /dev/null +++ b/internal/server/middleware/utils.go @@ -0,0 +1,43 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// getClientIP 获取客户端IP +func getClientIP(r *http.Request) string { + // X-Forwarded-For + xff := r.Header.Get("X-Forwarded-For") + if xff != "" { + ips := strings.Split(xff, ",") + if len(ips) > 0 { + return strings.TrimSpace(ips[0]) + } + } + + // X-Real-IP + xri := r.Header.Get("X-Real-IP") + if xri != "" { + return xri + } + + // RemoteAddr + ip := r.RemoteAddr + if idx := strings.LastIndex(ip, ":"); idx != -1 { + ip = ip[:idx] + } + return ip +} + +// parseQuery 解析query参数 +func parseQuery(query string) map[string]string { + m := make(map[string]string) + for _, v := range strings.Split(query, "&") { + kv := strings.Split(v, "=") + if len(kv) == 2 { + m[kv[0]] = kv[1] + } + } + return m +} diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index ab64790..3e96c0c 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -103,5 +103,20 @@ func (j *JWT) ParseToken(tokenString string) (*CustomClaims, error) { // NeedRefresh 判断是否需要刷新token func (j *JWT) NeedRefresh(claims *CustomClaims) bool { - return time.Until(claims.ExpiresAt.Time) < j.BufferTime + return claims.ExpiresAt.Unix()-time.Now().Unix() < claims.BufferTime +} + +// CreateTokenByOldToken 根据旧token创建新token +func (j *JWT) CreateTokenByOldToken(oldToken string, oldClaims CustomClaims) (string, *CustomClaims, error) { + // 更新过期时间 + oldClaims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(j.ExpiresAt)) + oldClaims.IssuedAt = jwt.NewNumericDate(time.Now()) + oldClaims.NotBefore = jwt.NewNumericDate(time.Now().Add(-1000)) + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, oldClaims) + newToken, err := token.SignedString(j.SigningKey) + if err != nil { + return "", nil, err + } + return newToken, &oldClaims, nil }