package middleware

import (
	"context"
	"fmt"
	"time"

	"github.com/go-kratos/kratos/v2/errors"
	"github.com/go-kratos/kratos/v2/middleware"
	"github.com/go-kratos/kratos/v2/transport"
	kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
	"github.com/redis/go-redis/v9"
)

// LimitConfig configures the RateLimit middleware.
type LimitConfig struct {
	// GenerationKey derives the rate-limit bucket key for a request from the
	// server context and the transport operation (path). Returning "" exempts
	// the request from limiting. If nil, DefaultGenerationKey is used.
	GenerationKey func(ctx context.Context, path string) string
	// CheckOrMark checks whether key is still within its quota and records the
	// request if it is. A non-nil error rejects the request and its message is
	// returned to the caller. If nil, no limiting is performed.
	CheckOrMark func(ctx context.Context, key string, expire int, limit int) error
	// Expire is the length of the limiting window (the key TTL), in seconds.
	Expire int
	// Limit is the maximum number of requests allowed per window.
	Limit int
}

// RateLimit returns a middleware that limits requests per key produced by
// cfg.GenerationKey, delegating the accounting to cfg.CheckOrMark.
//
// A request is rejected with a kratos error (code 429, reason "RATE_LIMIT")
// whose message is the error returned by CheckOrMark. Requests without a
// server transport, or for which GenerationKey returns "", are not limited.
func RateLimit(cfg LimitConfig) middleware.Middleware {
	genKey := cfg.GenerationKey
	if genKey == nil {
		genKey = DefaultGenerationKey
	}
	check := cfg.CheckOrMark
	return func(handler middleware.Handler) middleware.Handler {
		if check == nil {
			// Nothing to enforce; avoid a nil-func panic on every request.
			return handler
		}
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			tr, ok := transport.FromServerContext(ctx)
			if !ok {
				// No request identity to key on. Limiting under a shared
				// empty key would let one caller throttle every other one.
				return handler(ctx, req)
			}
			key := genKey(ctx, tr.Operation())
			if key == "" {
				return handler(ctx, req)
			}
			if err := check(ctx, key, cfg.Expire, cfg.Limit); err != nil {
				return nil, errors.New(429, "RATE_LIMIT", err.Error())
			}
			return handler(ctx, req)
		}
	}
}

// DefaultGenerationKey keys HTTP requests by the client address returned by
// getClientIP (presumably the client IP), producing "Limit_<ip>". The path
// argument is ignored, so all operations share one bucket per client.
// Requests not served over HTTP all share the single "Limit_unknown" bucket.
//
// NOTE(review): for non-HTTP transports (e.g. gRPC) that shared bucket
// throttles all clients together; consider a transport-specific key.
func DefaultGenerationKey(ctx context.Context, path string) string {
	if tr, ok := transport.FromServerContext(ctx); ok {
		if ht, ok := tr.(kratoshttp.Transporter); ok {
			return "Limit_" + getClientIP(ht.Request())
		}
	}
	return "Limit_unknown"
}

// NewRedisCheckOrMark returns a LimitConfig.CheckOrMark backed by rdb, using
// SetLimitWithTime with a window of expire seconds. If rdb is nil, the
// returned function allows every request.
func NewRedisCheckOrMark(rdb *redis.Client) func(ctx context.Context, key string, expire int, limit int) error {
	if rdb == nil {
		return func(context.Context, string, int, int) error { return nil }
	}
	return func(ctx context.Context, key string, expire int, limit int) error {
		return SetLimitWithTime(ctx, rdb, key, limit, time.Duration(expire)*time.Second)
	}
}

// fixedWindowLimitScript atomically implements a fixed-window counter.
//
// KEYS[1] is the counter key, ARGV[1] the limit and ARGV[2] the window in
// milliseconds. It returns {1, 0} when the request is allowed (and counted),
// or {0, pttl} when the limit is already reached, where pttl is the
// remaining window in milliseconds (negative if the key does not exist).
//
// Doing the check, increment and TTL assignment in one script removes the
// races of separate round trips: concurrent requests cannot overshoot the
// limit, and a counter can never be left without a TTL (keys found without
// one are given a fresh window instead of blocking forever).
var fixedWindowLimitScript = redis.NewScript(`
local current = tonumber(redis.call('GET', KEYS[1]) or '0')
if current >= tonumber(ARGV[1]) then
	local pttl = redis.call('PTTL', KEYS[1])
	if pttl == -1 then
		redis.call('PEXPIRE', KEYS[1], ARGV[2])
		pttl = tonumber(ARGV[2])
	end
	return {0, pttl}
end
redis.call('INCR', KEYS[1])
if redis.call('PTTL', KEYS[1]) == -1 then
	redis.call('PEXPIRE', KEYS[1], ARGV[2])
end
return {1, 0}
`)

// SetLimitWithTime records one request against key and reports whether it is
// within limit for the current window of length expiration.
//
// The window starts at the first request for key and lasts expiration. At
// most limit requests are allowed per window (limit <= 0 rejects all).
// When the limit is reached it returns an error telling the caller how long
// to wait. Redis or script failures are returned unchanged.
func SetLimitWithTime(ctx context.Context, rdb *redis.Client, key string, limit int, expiration time.Duration) error {
	res, err := fixedWindowLimitScript.Run(ctx, rdb, []string{key}, limit, expiration.Milliseconds()).Int64Slice()
	if err != nil {
		return err
	}
	if len(res) != 2 {
		return fmt.Errorf("rate limit script: unexpected reply %v", res)
	}
	if res[0] == 1 {
		return nil
	}
	if res[1] <= 0 {
		// Remaining time unknown (key vanished or zero window).
		return fmt.Errorf("请求太过频繁,请稍后再试")
	}
	ttl := time.Duration(res[1]) * time.Millisecond
	return fmt.Errorf("请求太过频繁, 请 %s 后尝试", ttl.String())
}

// DefaultRateLimit returns a RateLimit middleware that keys requests with
// DefaultGenerationKey and counts them in Redis, allowing limit requests per
// expire seconds. A nil rdb disables limiting.
func DefaultRateLimit(rdb *redis.Client, expire, limit int) middleware.Middleware {
	return RateLimit(LimitConfig{
		GenerationKey: DefaultGenerationKey,
		CheckOrMark:   NewRedisCheckOrMark(rdb),
		Expire:        expire,
		Limit:         limit,
	})
}