kra/internal/server/middleware/gin_limit_ip.go

113 lines
2.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"errors"
"net/http"
"time"
"kra/pkg/response"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// LimitConfig configures a per-key request limiter (keyed by client IP when
// DefaultGenerationKey is used). Each request is mapped to a key by
// GenerationKey, and CheckOrMark decides whether the request may proceed;
// a non-nil error from CheckOrMark rejects the request.
type LimitConfig struct {
// GenerationKey derives the rate-limit key for a request from the gin
// context (for example, DefaultGenerationKey uses the client IP).
GenerationKey func(c *gin.Context) string
// CheckOrMark checks the counter stored under key and records the current
// request. It is called with this config's Expire and Limit; returning an
// error aborts the request, and the error text is sent back to the client.
CheckOrMark func(key string, expire int, limit int) error
// Expire is the lifetime of a key in seconds, i.e. the length of one window.
Expire int
// Limit is the maximum number of requests allowed within one window.
Limit int
}
// Package-level dependencies used by the default limiter implementation.
// Both are expected to be configured once at startup via the setters below.
var (
	// redisClient stores the request counters; when nil, limiting is skipped.
	redisClient redis.UniversalClient
	// limitLogger, when non-nil, receives errors from DefaultCheckOrMark.
	limitLogger *zap.Logger
)

// SetRedisClient installs the Redis client that backs the default limiter.
func SetRedisClient(client redis.UniversalClient) {
	redisClient = client
}

// SetLimitLogger installs the logger used to report limiter errors.
func SetLimitLogger(logger *zap.Logger) {
	limitLogger = logger
}
// LimitWithTime builds a gin middleware that enforces this configuration.
//
// For every request it derives a key with GenerationKey and passes it, along
// with Expire and Limit, to CheckOrMark. If CheckOrMark returns an error, the
// request is aborted. The response is HTTP 200 with a JSON body of the form
// {"code": response.ERROR, "msg": <error text>}. Otherwise the chain continues.
//
// A nil GenerationKey or CheckOrMark falls back to DefaultGenerationKey or
// DefaultCheckOrMark respectively, instead of panicking on every request.
func (l LimitConfig) LimitWithTime() gin.HandlerFunc {
	// Resolve the callbacks once, when the middleware is built, not per request.
	genKey := l.GenerationKey
	if genKey == nil {
		genKey = DefaultGenerationKey
	}
	check := l.CheckOrMark
	if check == nil {
		check = DefaultCheckOrMark
	}
	expire, limit := l.Expire, l.Limit

	return func(c *gin.Context) {
		if err := check(genKey(c), expire, limit); err != nil {
			// The project reports business errors with HTTP 200 and an error
			// code in the body; keep that convention.
			c.AbortWithStatusJSON(http.StatusOK, gin.H{"code": response.ERROR, "msg": err.Error()})
			return
		}
		c.Next()
	}
}
// DefaultGenerationKey returns the rate-limit key for a request: the fixed
// prefix "KRA_Limit" immediately followed by the client IP reported by gin.
func DefaultGenerationKey(c *gin.Context) string {
	const prefix = "KRA_Limit"
	ip := c.ClientIP()
	return prefix + ip
}
// DefaultCheckOrMark is the default LimitConfig.CheckOrMark implementation.
// It counts the request under key in Redis through SetLimitWithTime, using a
// window of expire seconds and at most limit requests per window.
//
// If no Redis client has been configured, the limiter fails open and allows
// every request. Any error from SetLimitWithTime is logged to limitLogger
// (when one is set) and returned unchanged.
func DefaultCheckOrMark(key string, expire int, limit int) error {
	if redisClient == nil {
		return nil // Redis未配置跳过限流
	}
	window := time.Duration(expire) * time.Second
	err := SetLimitWithTime(key, limit, window)
	if err != nil && limitLogger != nil {
		limitLogger.Error("limit", zap.Error(err))
	}
	return err
}
// DefaultLimit returns a middleware that allows at most limit requests per
// client IP within each window of expire seconds. It uses
// DefaultGenerationKey and DefaultCheckOrMark.
func DefaultLimit(expire, limit int) gin.HandlerFunc {
	cfg := LimitConfig{
		GenerationKey: DefaultGenerationKey,
		CheckOrMark:   DefaultCheckOrMark,
		Expire:        expire,
		Limit:         limit,
	}
	return cfg.LimitWithTime()
}
// SetLimitWithTime records one request under key using a fixed-window
// counter stored in Redis.
//
// The first request creates the counter and starts a window of length
// expiration. Requests are allowed while the post-increment count is at most
// limit. Once the count exceeds limit, an error telling the caller how long
// until the window resets is returned. Redis errors are returned as-is.
// If no Redis client has been configured, the request is allowed.
//
// INCR and PTTL run together in one MULTI/EXEC transaction. The previous
// EXISTS/GET/INCR sequence could race with key expiry and had two failure
// modes:
//   - a spurious "redis: nil" error rejected a legitimate request;
//   - INCR recreated the key without a TTL, which permanently blocked the key.
func SetLimitWithTime(key string, limit int, expiration time.Duration) error {
	if redisClient == nil {
		return nil // consistent with DefaultCheckOrMark: no Redis, no limiting
	}
	ctx := context.Background()

	pipe := redisClient.TxPipeline()
	incr := pipe.Incr(ctx, key)
	pttl := pipe.PTTL(ctx, key)
	if _, err := pipe.Exec(ctx); err != nil {
		return err
	}

	ttl := pttl.Val()
	// A negative TTL right after INCR means the key has no expiry. Either it
	// was just created, or an earlier expiry call never landed. Attach the
	// window now so a counter can never outlive it.
	if ttl < 0 {
		if err := redisClient.Expire(ctx, key, expiration).Err(); err != nil {
			return err
		}
		ttl = expiration
	}

	// Compare the post-increment count, so exactly `limit` requests pass per
	// window. A limit <= 0 now blocks every request; the first one is no
	// longer let through.
	if incr.Val() > int64(limit) {
		return errors.New("请求太过频繁, 请 " + ttl.String() + " 后尝试")
	}
	return nil
}