113 lines
2.6 KiB
Go
113 lines
2.6 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"net/http"
|
||
"time"
|
||
|
||
"kra/pkg/response"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/redis/go-redis/v9"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// LimitConfig IP限流配置
|
||
type LimitConfig struct {
|
||
// GenerationKey 根据业务生成key
|
||
GenerationKey func(c *gin.Context) string
|
||
// CheckOrMark 检查函数
|
||
CheckOrMark func(key string, expire int, limit int) error
|
||
// Expire key 过期时间(秒)
|
||
Expire int
|
||
// Limit 周期内限制次数
|
||
Limit int
|
||
}
|
||
|
||
var redisClient redis.UniversalClient
|
||
var limitLogger *zap.Logger
|
||
|
||
// SetRedisClient 设置Redis客户端
|
||
func SetRedisClient(client redis.UniversalClient) {
|
||
redisClient = client
|
||
}
|
||
|
||
// SetLimitLogger 设置限流日志
|
||
func SetLimitLogger(logger *zap.Logger) {
|
||
limitLogger = logger
|
||
}
|
||
|
||
// LimitWithTime 返回限流中间件
|
||
func (l LimitConfig) LimitWithTime() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
if err := l.CheckOrMark(l.GenerationKey(c), l.Expire, l.Limit); err != nil {
|
||
c.JSON(http.StatusOK, gin.H{"code": response.ERROR, "msg": err.Error()})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// DefaultGenerationKey 默认生成key
|
||
func DefaultGenerationKey(c *gin.Context) string {
|
||
return "KRA_Limit" + c.ClientIP()
|
||
}
|
||
|
||
// DefaultCheckOrMark 默认检查函数
|
||
func DefaultCheckOrMark(key string, expire int, limit int) error {
|
||
if redisClient == nil {
|
||
return nil // Redis未配置,跳过限流
|
||
}
|
||
if err := SetLimitWithTime(key, limit, time.Duration(expire)*time.Second); err != nil {
|
||
if limitLogger != nil {
|
||
limitLogger.Error("limit", zap.Error(err))
|
||
}
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DefaultLimit 默认限流中间件
|
||
func DefaultLimit(expire, limit int) gin.HandlerFunc {
|
||
return LimitConfig{
|
||
GenerationKey: DefaultGenerationKey,
|
||
CheckOrMark: DefaultCheckOrMark,
|
||
Expire: expire,
|
||
Limit: limit,
|
||
}.LimitWithTime()
|
||
}
|
||
|
||
// SetLimitWithTime 设置访问次数
|
||
func SetLimitWithTime(key string, limit int, expiration time.Duration) error {
|
||
ctx := context.Background()
|
||
count, err := redisClient.Exists(ctx, key).Result()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if count == 0 {
|
||
pipe := redisClient.TxPipeline()
|
||
pipe.Incr(ctx, key)
|
||
pipe.Expire(ctx, key, expiration)
|
||
_, err = pipe.Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
times, err := redisClient.Get(ctx, key).Int()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if times >= limit {
|
||
t, err := redisClient.PTTL(ctx, key).Result()
|
||
if err != nil {
|
||
return errors.New("请求太过频繁,请稍后再试")
|
||
}
|
||
return errors.New("请求太过频繁, 请 " + t.String() + " 后尝试")
|
||
}
|
||
|
||
return redisClient.Incr(ctx, key).Err()
|
||
}
|