kra/internal/server/middleware/ratelimit.go

109 lines
2.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"fmt"
"time"
"github.com/go-kratos/kratos/v2/errors"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
"github.com/redis/go-redis/v9"
)
// LimitConfig describes how the RateLimit middleware identifies callers and
// how it enforces their quota.
type LimitConfig struct {
	// GenerationKey maps a request (its context and the transport operation
	// path) to the bucket key under which hits are counted.
	GenerationKey func(ctx context.Context, path string) string

	// CheckOrMark checks whether the bucket identified by key is still within
	// quota and records the hit. A non-nil error rejects the request.
	CheckOrMark func(ctx context.Context, key string, expire int, limit int) error

	// Expire is the length of the counting window. NewRedisCheckOrMark
	// interprets it as seconds.
	Expire int

	// Limit is the number of requests allowed per window.
	Limit int
}
// RateLimit returns a Kratos middleware that rejects requests exceeding the
// quota described by cfg.
//
// For each request:
//   - the operation path is read from the server transport (empty string
//     when no transport is attached to ctx);
//   - cfg.GenerationKey turns ctx and the path into a bucket key, so the
//     generator's own fallback applies even without a transport;
//   - cfg.CheckOrMark checks and records the hit.
//
// Any error from CheckOrMark is returned to the caller as code 429 with
// reason "RATE_LIMIT". The original error is attached as the cause.
// CheckOrMark errors are not classified, so a backend failure (for example
// an unreachable Redis) is also reported as 429.
//
// If either CheckOrMark or GenerationKey is nil, limiting is disabled and
// requests pass straight through. This mirrors the "no backend, no limiting"
// behavior of NewRedisCheckOrMark instead of panicking per request.
func RateLimit(cfg LimitConfig) middleware.Middleware {
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			if cfg.CheckOrMark == nil || cfg.GenerationKey == nil {
				return handler(ctx, req)
			}

			var path string
			if tr, ok := transport.FromServerContext(ctx); ok {
				path = tr.Operation()
			}
			key := cfg.GenerationKey(ctx, path)

			if err := cfg.CheckOrMark(ctx, key, cfg.Expire, cfg.Limit); err != nil {
				return nil, errors.New(429, "RATE_LIMIT", err.Error()).WithCause(err)
			}
			return handler(ctx, req)
		}
	}
}
// DefaultGenerationKey keys rate limiting on the client IP of an HTTP request.
//
// It returns "Limit_" followed by the address reported by getClientIP for the
// underlying *http.Request. When the context carries no server transport, or
// the transport is not HTTP, it returns "Limit_unknown", so all such callers
// share one bucket.
//
// The path argument is ignored: every operation called from the same IP
// counts against a single shared counter.
func DefaultGenerationKey(ctx context.Context, path string) string {
	tr, found := transport.FromServerContext(ctx)
	if !found {
		return "Limit_unknown"
	}
	httpTr, isHTTP := tr.(kratoshttp.Transporter)
	if !isHTTP {
		return "Limit_unknown"
	}
	return "Limit_" + getClientIP(httpTr.Request())
}
// NewRedisCheckOrMark builds a LimitConfig.CheckOrMark backed by rdb.
//
// The returned function treats expire as a window length in seconds and
// delegates to SetLimitWithTime. If rdb is nil, the returned function always
// reports success, which effectively disables rate limiting.
func NewRedisCheckOrMark(rdb *redis.Client) func(ctx context.Context, key string, expire int, limit int) error {
	if rdb == nil {
		// No Redis configured: admit every request.
		return func(context.Context, string, int, int) error { return nil }
	}
	return func(ctx context.Context, key string, expire int, limit int) error {
		window := time.Duration(expire) * time.Second
		return SetLimitWithTime(ctx, rdb, key, limit, window)
	}
}
// limitScript performs the fixed-window check-and-count as one atomic step on
// the Redis server, so concurrent requests cannot race between reading and
// incrementing the counter.
//
// KEYS[1] is the counter key, ARGV[1] the limit, ARGV[2] the window in ms.
// Reply: {1, 0} when the request is admitted, {0, pttl} when it is rejected.
//
// Semantics match the original client-side logic:
//   - a missing key is created with value 1 and the window TTL;
//   - an existing counter at or above the limit rejects without counting;
//   - otherwise the counter is incremented.
//
// A counter that somehow has no TTL is re-armed instead of locking the caller
// out forever.
var limitScript = redis.NewScript(`
local current = tonumber(redis.call('GET', KEYS[1]))
if current == nil then
	redis.call('INCR', KEYS[1])
	redis.call('PEXPIRE', KEYS[1], ARGV[2])
	return {1, 0}
end
if redis.call('PTTL', KEYS[1]) == -1 then
	redis.call('PEXPIRE', KEYS[1], ARGV[2])
end
if current >= tonumber(ARGV[1]) then
	return {0, redis.call('PTTL', KEYS[1])}
end
redis.call('INCR', KEYS[1])
return {1, 0}
`)

// SetLimitWithTime records one hit against key and reports whether the caller
// is still within limit hits per expiration window.
//
// It returns nil when the request is admitted. When the quota is exhausted it
// returns an error whose message includes the remaining wait time if known.
// Redis or script failures are returned unchanged.
//
// A nil rdb admits the request, consistent with NewRedisCheckOrMark.
// Sub-millisecond positive windows are rounded up to 1ms. Non-positive
// windows make Redis drop the key immediately, so no limit is enforced.
func SetLimitWithTime(ctx context.Context, rdb *redis.Client, key string, limit int, expiration time.Duration) error {
	if rdb == nil {
		return nil
	}

	windowMs := expiration.Milliseconds()
	if expiration > 0 && windowMs == 0 {
		windowMs = 1
	}

	reply, err := limitScript.Run(ctx, rdb, []string{key}, limit, windowMs).Int64Slice()
	if err != nil {
		return err
	}
	if len(reply) != 2 {
		return fmt.Errorf("unexpected rate limit script reply: %v", reply)
	}
	if reply[0] == 1 {
		return nil
	}

	// PTTL is in milliseconds. Non-positive means the key has no expiry or is
	// already gone, so no meaningful wait time can be shown.
	ttl := time.Duration(reply[1]) * time.Millisecond
	if ttl <= 0 {
		return fmt.Errorf("请求太过频繁,请稍后再试")
	}
	return fmt.Errorf("请求太过频繁, 请 %s 后尝试", ttl.String())
}
// DefaultRateLimit returns a RateLimit middleware with the package defaults:
// keys are derived from the client IP by DefaultGenerationKey, and counting
// is done in Redis via NewRedisCheckOrMark(rdb).
//
// expire is the window length (seconds, as interpreted by
// NewRedisCheckOrMark) and limit is the number of requests allowed per window.
func DefaultRateLimit(rdb *redis.Client, expire, limit int) middleware.Middleware {
	cfg := LimitConfig{
		GenerationKey: DefaultGenerationKey,
		CheckOrMark:   NewRedisCheckOrMark(rdb),
		Expire:        expire,
		Limit:         limit,
	}
	return RateLimit(cfg)
}