package middleware import ( "context" "errors" "net/http" "time" "kra/pkg/response" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" "go.uber.org/zap" ) // LimitConfig IP限流配置 type LimitConfig struct { // GenerationKey 根据业务生成key GenerationKey func(c *gin.Context) string // CheckOrMark 检查函数 CheckOrMark func(key string, expire int, limit int) error // Expire key 过期时间(秒) Expire int // Limit 周期内限制次数 Limit int } var redisClient redis.UniversalClient var limitLogger *zap.Logger // SetRedisClient 设置Redis客户端 func SetRedisClient(client redis.UniversalClient) { redisClient = client } // SetLimitLogger 设置限流日志 func SetLimitLogger(logger *zap.Logger) { limitLogger = logger } // LimitWithTime 返回限流中间件 func (l LimitConfig) LimitWithTime() gin.HandlerFunc { return func(c *gin.Context) { if err := l.CheckOrMark(l.GenerationKey(c), l.Expire, l.Limit); err != nil { c.JSON(http.StatusOK, gin.H{"code": response.ERROR, "msg": err.Error()}) c.Abort() return } c.Next() } } // DefaultGenerationKey 默认生成key func DefaultGenerationKey(c *gin.Context) string { return "KRA_Limit" + c.ClientIP() } // DefaultCheckOrMark 默认检查函数 func DefaultCheckOrMark(key string, expire int, limit int) error { if redisClient == nil { return nil // Redis未配置,跳过限流 } if err := SetLimitWithTime(key, limit, time.Duration(expire)*time.Second); err != nil { if limitLogger != nil { limitLogger.Error("limit", zap.Error(err)) } return err } return nil } // DefaultLimit 默认限流中间件 func DefaultLimit(expire, limit int) gin.HandlerFunc { return LimitConfig{ GenerationKey: DefaultGenerationKey, CheckOrMark: DefaultCheckOrMark, Expire: expire, Limit: limit, }.LimitWithTime() } // SetLimitWithTime 设置访问次数 func SetLimitWithTime(key string, limit int, expiration time.Duration) error { ctx := context.Background() count, err := redisClient.Exists(ctx, key).Result() if err != nil { return err } if count == 0 { pipe := redisClient.TxPipeline() pipe.Incr(ctx, key) pipe.Expire(ctx, key, expiration) _, err = pipe.Exec(ctx) return err } times, err := redisClient.Get(ctx, key).Int() if err != nil { return err } if times >= limit { t, err := redisClient.PTTL(ctx, key).Result() if err != nil { return errors.New("请求太过频繁,请稍后再试") } return errors.New("请求太过频繁, 请 " + t.String() + " 后尝试") } return redisClient.Incr(ctx, key).Err() }