109 lines
2.9 KiB
Go
109 lines
2.9 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"github.com/go-kratos/kratos/v2/errors"
|
||
"github.com/go-kratos/kratos/v2/middleware"
|
||
"github.com/go-kratos/kratos/v2/transport"
|
||
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
// LimitConfig configures the RateLimit middleware.
type LimitConfig struct {
	// GenerationKey derives the rate-limit bucket key for a request from its
	// context and the transport operation (path).
	GenerationKey func(ctx context.Context, path string) string

	// CheckOrMark checks the request against the bucket identified by key and
	// records it; a non-nil error causes the request to be rejected. expire is
	// the window length in seconds and limit the per-window request budget.
	CheckOrMark func(ctx context.Context, key string, expire, limit int) error

	// Expire is the key expiry (window length) in seconds; Limit is the number
	// of requests permitted within one window.
	Expire, Limit int
}
|
||
|
||
// RateLimit 限流中间件
|
||
func RateLimit(cfg LimitConfig) middleware.Middleware {
|
||
return func(handler middleware.Handler) middleware.Handler {
|
||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||
var key string
|
||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||
path := tr.Operation()
|
||
key = cfg.GenerationKey(ctx, path)
|
||
}
|
||
|
||
if err := cfg.CheckOrMark(ctx, key, cfg.Expire, cfg.Limit); err != nil {
|
||
return nil, errors.New(429, "RATE_LIMIT", err.Error())
|
||
}
|
||
|
||
return handler(ctx, req)
|
||
}
|
||
}
|
||
}
|
||
|
||
// DefaultGenerationKey 默认生成key(基于IP)
|
||
func DefaultGenerationKey(ctx context.Context, path string) string {
|
||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||
return "Limit_" + getClientIP(ht.Request())
|
||
}
|
||
}
|
||
return "Limit_unknown"
|
||
}
|
||
|
||
// NewRedisCheckOrMark 创建基于Redis的检查函数
|
||
func NewRedisCheckOrMark(rdb *redis.Client) func(ctx context.Context, key string, expire int, limit int) error {
|
||
return func(ctx context.Context, key string, expire int, limit int) error {
|
||
if rdb == nil {
|
||
return nil
|
||
}
|
||
return SetLimitWithTime(ctx, rdb, key, limit, time.Duration(expire)*time.Second)
|
||
}
|
||
}
|
||
|
||
// SetLimitWithTime 设置访问次数
|
||
func SetLimitWithTime(ctx context.Context, rdb *redis.Client, key string, limit int, expiration time.Duration) error {
|
||
count, err := rdb.Exists(ctx, key).Result()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if count == 0 {
|
||
// key不存在,创建并设置过期时间
|
||
pipe := rdb.TxPipeline()
|
||
pipe.Incr(ctx, key)
|
||
pipe.Expire(ctx, key, expiration)
|
||
_, err = pipe.Exec(ctx)
|
||
return err
|
||
}
|
||
|
||
// key存在,检查次数
|
||
times, err := rdb.Get(ctx, key).Int()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if times >= limit {
|
||
// 获取剩余时间
|
||
ttl, err := rdb.PTTL(ctx, key).Result()
|
||
if err != nil {
|
||
return fmt.Errorf("请求太过频繁,请稍后再试")
|
||
}
|
||
return fmt.Errorf("请求太过频繁, 请 %s 后尝试", ttl.String())
|
||
}
|
||
|
||
return rdb.Incr(ctx, key).Err()
|
||
}
|
||
|
||
// DefaultRateLimit 默认限流中间件
|
||
func DefaultRateLimit(rdb *redis.Client, expire, limit int) middleware.Middleware {
|
||
return RateLimit(LimitConfig{
|
||
GenerationKey: DefaultGenerationKey,
|
||
CheckOrMark: NewRedisCheckOrMark(rdb),
|
||
Expire: expire,
|
||
Limit: limit,
|
||
})
|
||
}
|