任务二完成
This commit is contained in:
parent
d572a03654
commit
7d27c70453
|
|
@ -0,0 +1,64 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
"github.com/go-kratos/kratos/v2/transport"
|
||||||
|
|
||||||
|
pkgcasbin "kra/pkg/casbin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPermissionDenied = errors.New("权限不足")
|
||||||
|
)
|
||||||
|
|
||||||
|
// CasbinRBAC Casbin权限中间件
|
||||||
|
func CasbinRBAC(routerPrefix string) middleware.Middleware {
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
claims, ok := GetClaims(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrMissingToken
|
||||||
|
}
|
||||||
|
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
// 获取请求路径
|
||||||
|
path := tr.Operation()
|
||||||
|
if routerPrefix != "" {
|
||||||
|
path = strings.TrimPrefix(path, routerPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取请求方法
|
||||||
|
act := "GET"
|
||||||
|
if header := tr.RequestHeader(); header != nil {
|
||||||
|
if method := header.Get(":method"); method != "" {
|
||||||
|
act = method
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户角色
|
||||||
|
sub := strconv.Itoa(int(claims.AuthorityID))
|
||||||
|
|
||||||
|
// 检查权限
|
||||||
|
enforcer := pkgcasbin.GetEnforcer()
|
||||||
|
if enforcer == nil {
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
success, err := enforcer.Enforce(sub, path, act)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
return nil, ErrPermissionDenied
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/errors"
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
"github.com/go-kratos/kratos/v2/transport"
|
||||||
|
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CORSWhitelist CORS白名单配置
|
||||||
|
type CORSWhitelist struct {
|
||||||
|
AllowOrigin string
|
||||||
|
AllowHeaders string
|
||||||
|
AllowMethods string
|
||||||
|
ExposeHeaders string
|
||||||
|
AllowCredentials bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORSConfig CORS配置
|
||||||
|
type CORSConfig struct {
|
||||||
|
// Mode 模式: allow-all, whitelist, strict-whitelist
|
||||||
|
Mode string
|
||||||
|
Whitelist []CORSWhitelist
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORS 跨域中间件(放行所有)
|
||||||
|
func CORS() middleware.Middleware {
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||||
|
origin := ht.Request().Header.Get("Origin")
|
||||||
|
header := ht.ReplyHeader()
|
||||||
|
header.Set("Access-Control-Allow-Origin", origin)
|
||||||
|
header.Set("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token,Authorization,Token,X-Token,X-User-Id")
|
||||||
|
header.Set("Access-Control-Allow-Methods", "POST,GET,OPTIONS,DELETE,PUT")
|
||||||
|
header.Set("Access-Control-Expose-Headers", "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type,New-Token,New-Expires-At")
|
||||||
|
header.Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
|
||||||
|
// OPTIONS请求直接返回
|
||||||
|
if ht.Request().Method == http.MethodOptions {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORSByRules 按配置规则处理跨域
|
||||||
|
func CORSByRules(cfg CORSConfig) middleware.Middleware {
|
||||||
|
// 放行全部
|
||||||
|
if cfg.Mode == "allow-all" {
|
||||||
|
return CORS()
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||||
|
r := ht.Request()
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
whitelist := checkCors(origin, cfg.Whitelist)
|
||||||
|
|
||||||
|
header := ht.ReplyHeader()
|
||||||
|
// 通过检查,添加请求头
|
||||||
|
if whitelist != nil {
|
||||||
|
header.Set("Access-Control-Allow-Origin", whitelist.AllowOrigin)
|
||||||
|
header.Set("Access-Control-Allow-Headers", whitelist.AllowHeaders)
|
||||||
|
header.Set("Access-Control-Allow-Methods", whitelist.AllowMethods)
|
||||||
|
header.Set("Access-Control-Expose-Headers", whitelist.ExposeHeaders)
|
||||||
|
if whitelist.AllowCredentials {
|
||||||
|
header.Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 严格白名单模式且未通过检查,直接拒绝
|
||||||
|
if whitelist == nil && cfg.Mode == "strict-whitelist" {
|
||||||
|
// 健康检查放行
|
||||||
|
if !(r.Method == http.MethodGet && r.URL.Path == "/health") {
|
||||||
|
return nil, errors.Forbidden("CORS_FORBIDDEN", "forbidden")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OPTIONS请求直接返回
|
||||||
|
if r.Method == http.MethodOptions {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkCors(currentOrigin string, whitelist []CORSWhitelist) *CORSWhitelist {
|
||||||
|
for _, w := range whitelist {
|
||||||
|
if currentOrigin == w.AllowOrigin {
|
||||||
|
return &w
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,78 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/log"
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
"github.com/go-kratos/kratos/v2/transport"
|
||||||
|
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EmailSender 邮件发送接口
|
||||||
|
type EmailSender interface {
|
||||||
|
SendError(subject, body string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorToEmail 错误邮件通知中间件
|
||||||
|
func ErrorToEmail(sender EmailSender, logger log.Logger) middleware.Middleware {
|
||||||
|
helper := log.NewHelper(logger)
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
var (
|
||||||
|
username string
|
||||||
|
body string
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
ip string
|
||||||
|
)
|
||||||
|
|
||||||
|
// 获取用户信息
|
||||||
|
if claims, ok := GetClaims(ctx); ok {
|
||||||
|
username = claims.Username
|
||||||
|
}
|
||||||
|
if username == "" {
|
||||||
|
username = "Unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
path = tr.Operation()
|
||||||
|
if header := tr.RequestHeader(); header != nil {
|
||||||
|
method = header.Get(":method")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||||
|
r := ht.Request()
|
||||||
|
ip = getClientIP(r)
|
||||||
|
|
||||||
|
// 读取body
|
||||||
|
bodyBytes, _ := io.ReadAll(r.Body)
|
||||||
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
body = string(bodyBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
reply, err := handler(ctx, req)
|
||||||
|
latency := time.Since(start)
|
||||||
|
|
||||||
|
// 如果有错误,发送邮件
|
||||||
|
if err != nil && sender != nil {
|
||||||
|
subject := fmt.Sprintf("%s %s调用了%s报错了", username, ip, path)
|
||||||
|
content := fmt.Sprintf(
|
||||||
|
"接收到的请求为%s\n请求方式为%s\n报错信息如下%s\n耗时%s\n",
|
||||||
|
body, method, err.Error(), latency.String(),
|
||||||
|
)
|
||||||
|
if sendErr := sender.SendError(subject, content); sendErr != nil {
|
||||||
|
helper.Error("ErrorToEmail Failed, err:", sendErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return reply, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,156 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
"github.com/go-kratos/kratos/v2/transport"
|
||||||
|
"github.com/go-kratos/kratos/v2/transport/http"
|
||||||
|
|
||||||
|
pkgjwt "kra/pkg/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authKey struct{}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMissingToken = errors.New("未登录或非法访问")
|
||||||
|
ErrInvalidToken = errors.New("无效的token")
|
||||||
|
ErrBlacklistJWT = errors.New("您的帐户异地登陆或令牌失效")
|
||||||
|
)
|
||||||
|
|
||||||
|
// BlacklistChecker JWT黑名单检查接口
|
||||||
|
type BlacklistChecker interface {
|
||||||
|
IsBlacklist(jwt string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTAuthConfig JWT认证配置
|
||||||
|
type JWTAuthConfig struct {
|
||||||
|
JWT *pkgjwt.JWT
|
||||||
|
BlacklistChecker BlacklistChecker
|
||||||
|
UseMultipoint bool
|
||||||
|
SetRedisJWT func(token, username string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTAuth JWT认证中间件
|
||||||
|
func JWTAuth(cfg JWTAuthConfig) middleware.Middleware {
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
token := extractToken(tr)
|
||||||
|
if token == "" {
|
||||||
|
return nil, ErrMissingToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查黑名单
|
||||||
|
if cfg.BlacklistChecker != nil && cfg.BlacklistChecker.IsBlacklist(token) {
|
||||||
|
// 清除token
|
||||||
|
if ht, ok := tr.(http.Transporter); ok {
|
||||||
|
clearToken(ht)
|
||||||
|
}
|
||||||
|
return nil, ErrBlacklistJWT
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := cfg.JWT.ParseToken(token)
|
||||||
|
if err != nil {
|
||||||
|
if ht, ok := tr.(http.Transporter); ok {
|
||||||
|
clearToken(ht)
|
||||||
|
}
|
||||||
|
if errors.Is(err, pkgjwt.ErrTokenExpired) {
|
||||||
|
return nil, errors.New("登录已过期,请重新登录")
|
||||||
|
}
|
||||||
|
return nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将claims存入context
|
||||||
|
ctx = context.WithValue(ctx, authKey{}, claims)
|
||||||
|
|
||||||
|
// 检查是否需要刷新token
|
||||||
|
if cfg.JWT.NeedRefresh(claims) {
|
||||||
|
newToken, newClaims, err := cfg.JWT.CreateTokenByOldToken(token, *claims)
|
||||||
|
if err == nil {
|
||||||
|
if ht, ok := tr.(http.Transporter); ok {
|
||||||
|
ht.ReplyHeader().Set("new-token", newToken)
|
||||||
|
ht.ReplyHeader().Set("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
|
||||||
|
setToken(ht, newToken, int(time.Until(newClaims.ExpiresAt.Time).Seconds()/60))
|
||||||
|
}
|
||||||
|
// 多点登录记录新JWT
|
||||||
|
if cfg.UseMultipoint && cfg.SetRedisJWT != nil {
|
||||||
|
_ = cfg.SetRedisJWT(newToken, newClaims.Username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTAuthSimple 简化版JWT认证中间件(不含黑名单检查)
|
||||||
|
func JWTAuthSimple(j *pkgjwt.JWT) middleware.Middleware {
|
||||||
|
return JWTAuth(JWTAuthConfig{JWT: j})
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractToken 从请求中提取token
|
||||||
|
func extractToken(tr transport.Transporter) string {
|
||||||
|
// 优先从Authorization头获取
|
||||||
|
auth := tr.RequestHeader().Get("Authorization")
|
||||||
|
if auth != "" {
|
||||||
|
if strings.HasPrefix(auth, "Bearer ") {
|
||||||
|
return strings.TrimPrefix(auth, "Bearer ")
|
||||||
|
}
|
||||||
|
return auth
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从x-token头获取
|
||||||
|
token := tr.RequestHeader().Get("x-token")
|
||||||
|
if token != "" {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearToken 清除token
|
||||||
|
func clearToken(ht http.Transporter) {
|
||||||
|
ht.ReplyHeader().Set("x-token", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// setToken 设置token到cookie
|
||||||
|
func setToken(ht http.Transporter, token string, maxAge int) {
|
||||||
|
// 通过header设置cookie信息,让前端处理
|
||||||
|
ht.ReplyHeader().Set("x-token", token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClaims 从context获取claims
|
||||||
|
func GetClaims(ctx context.Context) (*pkgjwt.CustomClaims, bool) {
|
||||||
|
claims, ok := ctx.Value(authKey{}).(*pkgjwt.CustomClaims)
|
||||||
|
return claims, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserID 从context获取用户ID
|
||||||
|
func GetUserID(ctx context.Context) uint {
|
||||||
|
if claims, ok := GetClaims(ctx); ok {
|
||||||
|
return claims.BaseClaims.ID
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsername 从context获取用户名
|
||||||
|
func GetUsername(ctx context.Context) string {
|
||||||
|
if claims, ok := GetClaims(ctx); ok {
|
||||||
|
return claims.Username
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthorityID 从context获取角色ID
|
||||||
|
func GetAuthorityID(ctx context.Context) uint {
|
||||||
|
if claims, ok := GetClaims(ctx); ok {
|
||||||
|
return claims.AuthorityID
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,133 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/log"
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
"github.com/go-kratos/kratos/v2/transport"
|
||||||
|
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogLayout 日志layout
|
||||||
|
type LogLayout struct {
|
||||||
|
Time time.Time `json:"time"`
|
||||||
|
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
Query string `json:"query,omitempty"`
|
||||||
|
Body string `json:"body,omitempty"`
|
||||||
|
IP string `json:"ip"`
|
||||||
|
UserAgent string `json:"user_agent"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
Cost time.Duration `json:"cost"`
|
||||||
|
Source string `json:"source"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoggerConfig 日志中间件配置
|
||||||
|
type LoggerConfig struct {
|
||||||
|
// Filter 用户自定义过滤
|
||||||
|
Filter func(ctx context.Context, path string) bool
|
||||||
|
// FilterKeyword 关键字过滤
|
||||||
|
FilterKeyword func(layout *LogLayout) bool
|
||||||
|
// AuthProcess 鉴权处理
|
||||||
|
AuthProcess func(ctx context.Context, layout *LogLayout)
|
||||||
|
// Print 日志处理
|
||||||
|
Print func(LogLayout)
|
||||||
|
// Source 服务唯一标识
|
||||||
|
Source string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger 日志中间件
|
||||||
|
func Logger(cfg LoggerConfig) middleware.Middleware {
|
||||||
|
if cfg.Print == nil {
|
||||||
|
cfg.Print = func(layout LogLayout) {
|
||||||
|
v, _ := json.Marshal(layout)
|
||||||
|
log.Info(string(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cfg.Source == "" {
|
||||||
|
cfg.Source = "Kratos"
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
var (
|
||||||
|
path string
|
||||||
|
query string
|
||||||
|
body string
|
||||||
|
ip string
|
||||||
|
userAgent string
|
||||||
|
)
|
||||||
|
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
path = tr.Operation()
|
||||||
|
if header := tr.RequestHeader(); header != nil {
|
||||||
|
userAgent = header.Get("User-Agent")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||||
|
r := ht.Request()
|
||||||
|
query = r.URL.RawQuery
|
||||||
|
ip = getClientIP(r)
|
||||||
|
|
||||||
|
// 读取body(仅在未过滤时)
|
||||||
|
if cfg.Filter == nil || !cfg.Filter(ctx, path) {
|
||||||
|
bodyBytes, _ := io.ReadAll(r.Body)
|
||||||
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
body = string(bodyBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否需要过滤
|
||||||
|
if cfg.Filter != nil && cfg.Filter(ctx, path) {
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
reply, err := handler(ctx, req)
|
||||||
|
cost := time.Since(start)
|
||||||
|
|
||||||
|
layout := LogLayout{
|
||||||
|
Time: time.Now(),
|
||||||
|
Path: path,
|
||||||
|
Query: query,
|
||||||
|
Body: body,
|
||||||
|
IP: ip,
|
||||||
|
UserAgent: userAgent,
|
||||||
|
Cost: cost,
|
||||||
|
Source: cfg.Source,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
layout.Error = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理鉴权信息
|
||||||
|
if cfg.AuthProcess != nil {
|
||||||
|
cfg.AuthProcess(ctx, &layout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关键字过滤
|
||||||
|
if cfg.FilterKeyword != nil {
|
||||||
|
cfg.FilterKeyword(&layout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 输出日志
|
||||||
|
cfg.Print(layout)
|
||||||
|
|
||||||
|
return reply, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultLogger 默认日志中间件
|
||||||
|
func DefaultLogger() middleware.Middleware {
|
||||||
|
return Logger(LoggerConfig{
|
||||||
|
Source: "Kratos",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,155 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/log"
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
"github.com/go-kratos/kratos/v2/transport"
|
||||||
|
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxBodySize = 1024
|
||||||
|
|
||||||
|
// OperationRecord 操作记录
|
||||||
|
type OperationRecord struct {
|
||||||
|
ID uint `gorm:"primarykey"`
|
||||||
|
CreatedAt time.Time `gorm:"index"`
|
||||||
|
Ip string `gorm:"column:ip;comment:请求ip"`
|
||||||
|
Method string `gorm:"column:method;comment:请求方法"`
|
||||||
|
Path string `gorm:"column:path;comment:请求路径"`
|
||||||
|
Status int `gorm:"column:status;comment:请求状态"`
|
||||||
|
Latency time.Duration `gorm:"column:latency;comment:延迟"`
|
||||||
|
Agent string `gorm:"column:agent;comment:代理"`
|
||||||
|
ErrorMessage string `gorm:"column:error_message;comment:错误信息"`
|
||||||
|
Body string `gorm:"type:text;column:body;comment:请求Body"`
|
||||||
|
Resp string `gorm:"type:text;column:resp;comment:响应Body"`
|
||||||
|
UserID int `gorm:"column:user_id;comment:用户id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (OperationRecord) TableName() string {
|
||||||
|
return "sys_operation_records"
|
||||||
|
}
|
||||||
|
|
||||||
|
// OperationLog 操作日志中间件
|
||||||
|
func OperationLog(db *gorm.DB, logger log.Logger) middleware.Middleware {
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
var (
|
||||||
|
body string
|
||||||
|
userID int
|
||||||
|
ip string
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
agent string
|
||||||
|
)
|
||||||
|
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
path = tr.Operation()
|
||||||
|
if header := tr.RequestHeader(); header != nil {
|
||||||
|
method = header.Get(":method")
|
||||||
|
agent = header.Get("User-Agent")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取HTTP请求信息
|
||||||
|
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||||
|
r := ht.Request()
|
||||||
|
ip = getClientIP(r)
|
||||||
|
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
bodyBytes, _ := io.ReadAll(r.Body)
|
||||||
|
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
if len(bodyBytes) > maxBodySize {
|
||||||
|
body = "[超出记录长度]"
|
||||||
|
} else {
|
||||||
|
body = string(bodyBytes)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
query := r.URL.RawQuery
|
||||||
|
query, _ = url.QueryUnescape(query)
|
||||||
|
params := parseQuery(query)
|
||||||
|
bodyBytes, _ := json.Marshal(params)
|
||||||
|
body = string(bodyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否为文件上传
|
||||||
|
if strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
|
body = "[文件]"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取用户ID
|
||||||
|
if claims, ok := GetClaims(ctx); ok {
|
||||||
|
userID = int(claims.BaseClaims.ID)
|
||||||
|
} else {
|
||||||
|
// 尝试从header获取
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||||
|
if idStr := ht.Request().Header.Get("x-user-id"); idStr != "" {
|
||||||
|
if id, err := strconv.Atoi(idStr); err == nil {
|
||||||
|
userID = id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
reply, err := handler(ctx, req)
|
||||||
|
latency := time.Since(start)
|
||||||
|
|
||||||
|
// 记录响应
|
||||||
|
var resp string
|
||||||
|
if reply != nil {
|
||||||
|
respBytes, _ := json.Marshal(reply)
|
||||||
|
if len(respBytes) > maxBodySize {
|
||||||
|
resp = "[超出记录长度]"
|
||||||
|
} else {
|
||||||
|
resp = string(respBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录错误
|
||||||
|
var errMsg string
|
||||||
|
status := 200
|
||||||
|
if err != nil {
|
||||||
|
errMsg = err.Error()
|
||||||
|
status = 500
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存记录
|
||||||
|
record := OperationRecord{
|
||||||
|
Ip: ip,
|
||||||
|
Method: method,
|
||||||
|
Path: path,
|
||||||
|
Status: status,
|
||||||
|
Latency: latency,
|
||||||
|
Agent: agent,
|
||||||
|
ErrorMessage: errMsg,
|
||||||
|
Body: body,
|
||||||
|
Resp: resp,
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if db != nil {
|
||||||
|
go func() {
|
||||||
|
if err := db.Create(&record).Error; err != nil {
|
||||||
|
log.NewHelper(logger).Error("create operation record error:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
return reply, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/errors"
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
"github.com/go-kratos/kratos/v2/transport"
|
||||||
|
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LimitConfig 限流配置
|
||||||
|
type LimitConfig struct {
|
||||||
|
// GenerationKey 根据业务生成key
|
||||||
|
GenerationKey func(ctx context.Context, path string) string
|
||||||
|
// CheckOrMark 检查函数
|
||||||
|
CheckOrMark func(ctx context.Context, key string, expire int, limit int) error
|
||||||
|
// Expire key过期时间(秒)
|
||||||
|
Expire int
|
||||||
|
// Limit 周期内限制次数
|
||||||
|
Limit int
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimit 限流中间件
|
||||||
|
func RateLimit(cfg LimitConfig) middleware.Middleware {
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
var key string
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
path := tr.Operation()
|
||||||
|
key = cfg.GenerationKey(ctx, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cfg.CheckOrMark(ctx, key, cfg.Expire, cfg.Limit); err != nil {
|
||||||
|
return nil, errors.New(429, "RATE_LIMIT", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultGenerationKey 默认生成key(基于IP)
|
||||||
|
func DefaultGenerationKey(ctx context.Context, path string) string {
|
||||||
|
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||||
|
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||||
|
return "Limit_" + getClientIP(ht.Request())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "Limit_unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRedisCheckOrMark 创建基于Redis的检查函数
|
||||||
|
func NewRedisCheckOrMark(rdb *redis.Client) func(ctx context.Context, key string, expire int, limit int) error {
|
||||||
|
return func(ctx context.Context, key string, expire int, limit int) error {
|
||||||
|
if rdb == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return SetLimitWithTime(ctx, rdb, key, limit, time.Duration(expire)*time.Second)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLimitWithTime 设置访问次数
|
||||||
|
func SetLimitWithTime(ctx context.Context, rdb *redis.Client, key string, limit int, expiration time.Duration) error {
|
||||||
|
count, err := rdb.Exists(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
// key不存在,创建并设置过期时间
|
||||||
|
pipe := rdb.TxPipeline()
|
||||||
|
pipe.Incr(ctx, key)
|
||||||
|
pipe.Expire(ctx, key, expiration)
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// key存在,检查次数
|
||||||
|
times, err := rdb.Get(ctx, key).Int()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if times >= limit {
|
||||||
|
// 获取剩余时间
|
||||||
|
ttl, err := rdb.PTTL(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("请求太过频繁,请稍后再试")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("请求太过频繁, 请 %s 后尝试", ttl.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return rdb.Incr(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRateLimit 默认限流中间件
|
||||||
|
func DefaultRateLimit(rdb *redis.Client, expire, limit int) middleware.Middleware {
|
||||||
|
return RateLimit(LimitConfig{
|
||||||
|
GenerationKey: DefaultGenerationKey,
|
||||||
|
CheckOrMark: NewRedisCheckOrMark(rdb),
|
||||||
|
Expire: expire,
|
||||||
|
Limit: limit,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime/debug"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/errors"
|
||||||
|
"github.com/go-kratos/kratos/v2/log"
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Recovery panic恢复中间件
|
||||||
|
func Recovery(logger log.Logger) middleware.Middleware {
|
||||||
|
helper := log.NewHelper(logger)
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
|
||||||
|
defer func() {
|
||||||
|
if rerr := recover(); rerr != nil {
|
||||||
|
// 检查是否为断开的连接
|
||||||
|
var brokenPipe bool
|
||||||
|
if ne, ok := rerr.(*net.OpError); ok {
|
||||||
|
if se, ok := ne.Err.(*os.SyscallError); ok {
|
||||||
|
errStr := strings.ToLower(se.Error())
|
||||||
|
if strings.Contains(errStr, "broken pipe") ||
|
||||||
|
strings.Contains(errStr, "connection reset by peer") {
|
||||||
|
brokenPipe = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stack := string(debug.Stack())
|
||||||
|
if brokenPipe {
|
||||||
|
helper.Errorw(
|
||||||
|
"msg", "[Recovery from panic]",
|
||||||
|
"error", rerr,
|
||||||
|
"broken_pipe", true,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
helper.Errorw(
|
||||||
|
"msg", "[Recovery from panic]",
|
||||||
|
"error", rerr,
|
||||||
|
"stack", stack,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = errors.InternalServer("PANIC", "服务器内部错误")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-kratos/kratos/v2/errors"
|
||||||
|
"github.com/go-kratos/kratos/v2/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Timeout 超时中间件
|
||||||
|
func Timeout(timeout time.Duration) middleware.Middleware {
|
||||||
|
return func(handler middleware.Handler) middleware.Handler {
|
||||||
|
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 使用 buffered channel 避免 goroutine 泄漏
|
||||||
|
type result struct {
|
||||||
|
reply interface{}
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
done := make(chan result, 1)
|
||||||
|
panicChan := make(chan interface{}, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if p := recover(); p != nil {
|
||||||
|
select {
|
||||||
|
case panicChan <- p:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
reply, err := handler(ctx, req)
|
||||||
|
select {
|
||||||
|
case done <- result{reply: reply, err: err}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p := <-panicChan:
|
||||||
|
panic(p)
|
||||||
|
case r := <-done:
|
||||||
|
return r.reply, r.err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, errors.GatewayTimeout("TIMEOUT", "请求超时")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// getClientIP 获取客户端IP
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
// X-Forwarded-For
|
||||||
|
xff := r.Header.Get("X-Forwarded-For")
|
||||||
|
if xff != "" {
|
||||||
|
ips := strings.Split(xff, ",")
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return strings.TrimSpace(ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// X-Real-IP
|
||||||
|
xri := r.Header.Get("X-Real-IP")
|
||||||
|
if xri != "" {
|
||||||
|
return xri
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr
|
||||||
|
ip := r.RemoteAddr
|
||||||
|
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||||
|
ip = ip[:idx]
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseQuery 解析query参数
|
||||||
|
func parseQuery(query string) map[string]string {
|
||||||
|
m := make(map[string]string)
|
||||||
|
for _, v := range strings.Split(query, "&") {
|
||||||
|
kv := strings.Split(v, "=")
|
||||||
|
if len(kv) == 2 {
|
||||||
|
m[kv[0]] = kv[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
@ -103,5 +103,20 @@ func (j *JWT) ParseToken(tokenString string) (*CustomClaims, error) {
|
||||||
|
|
||||||
// NeedRefresh 判断是否需要刷新token
|
// NeedRefresh 判断是否需要刷新token
|
||||||
func (j *JWT) NeedRefresh(claims *CustomClaims) bool {
|
func (j *JWT) NeedRefresh(claims *CustomClaims) bool {
|
||||||
return time.Until(claims.ExpiresAt.Time) < j.BufferTime
|
return claims.ExpiresAt.Unix()-time.Now().Unix() < claims.BufferTime
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenByOldToken 根据旧token创建新token
|
||||||
|
func (j *JWT) CreateTokenByOldToken(oldToken string, oldClaims CustomClaims) (string, *CustomClaims, error) {
|
||||||
|
// 更新过期时间
|
||||||
|
oldClaims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(j.ExpiresAt))
|
||||||
|
oldClaims.IssuedAt = jwt.NewNumericDate(time.Now())
|
||||||
|
oldClaims.NotBefore = jwt.NewNumericDate(time.Now().Add(-1000))
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, oldClaims)
|
||||||
|
newToken, err := token.SignedString(j.SigningKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
return newToken, &oldClaims, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue