任务二完成
This commit is contained in:
parent
d572a03654
commit
7d27c70453
|
|
@ -0,0 +1,64 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
|
||||
pkgcasbin "kra/pkg/casbin"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPermissionDenied = errors.New("权限不足")
|
||||
)
|
||||
|
||||
// CasbinRBAC Casbin权限中间件
|
||||
func CasbinRBAC(routerPrefix string) middleware.Middleware {
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
claims, ok := GetClaims(ctx)
|
||||
if !ok {
|
||||
return nil, ErrMissingToken
|
||||
}
|
||||
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
// 获取请求路径
|
||||
path := tr.Operation()
|
||||
if routerPrefix != "" {
|
||||
path = strings.TrimPrefix(path, routerPrefix)
|
||||
}
|
||||
|
||||
// 获取请求方法
|
||||
act := "GET"
|
||||
if header := tr.RequestHeader(); header != nil {
|
||||
if method := header.Get(":method"); method != "" {
|
||||
act = method
|
||||
}
|
||||
}
|
||||
|
||||
// 获取用户角色
|
||||
sub := strconv.Itoa(int(claims.AuthorityID))
|
||||
|
||||
// 检查权限
|
||||
enforcer := pkgcasbin.GetEnforcer()
|
||||
if enforcer == nil {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
success, err := enforcer.Enforce(sub, path, act)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !success {
|
||||
return nil, ErrPermissionDenied
|
||||
}
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||
)
|
||||
|
||||
// CORSWhitelist CORS白名单配置
|
||||
type CORSWhitelist struct {
|
||||
AllowOrigin string
|
||||
AllowHeaders string
|
||||
AllowMethods string
|
||||
ExposeHeaders string
|
||||
AllowCredentials bool
|
||||
}
|
||||
|
||||
// CORSConfig CORS配置
|
||||
type CORSConfig struct {
|
||||
// Mode 模式: allow-all, whitelist, strict-whitelist
|
||||
Mode string
|
||||
Whitelist []CORSWhitelist
|
||||
}
|
||||
|
||||
// CORS 跨域中间件(放行所有)
|
||||
func CORS() middleware.Middleware {
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||
origin := ht.Request().Header.Get("Origin")
|
||||
header := ht.ReplyHeader()
|
||||
header.Set("Access-Control-Allow-Origin", origin)
|
||||
header.Set("Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token,Authorization,Token,X-Token,X-User-Id")
|
||||
header.Set("Access-Control-Allow-Methods", "POST,GET,OPTIONS,DELETE,PUT")
|
||||
header.Set("Access-Control-Expose-Headers", "Content-Length,Access-Control-Allow-Origin,Access-Control-Allow-Headers,Content-Type,New-Token,New-Expires-At")
|
||||
header.Set("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// OPTIONS请求直接返回
|
||||
if ht.Request().Method == http.MethodOptions {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CORSByRules 按配置规则处理跨域
|
||||
func CORSByRules(cfg CORSConfig) middleware.Middleware {
|
||||
// 放行全部
|
||||
if cfg.Mode == "allow-all" {
|
||||
return CORS()
|
||||
}
|
||||
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||
r := ht.Request()
|
||||
origin := r.Header.Get("Origin")
|
||||
whitelist := checkCors(origin, cfg.Whitelist)
|
||||
|
||||
header := ht.ReplyHeader()
|
||||
// 通过检查,添加请求头
|
||||
if whitelist != nil {
|
||||
header.Set("Access-Control-Allow-Origin", whitelist.AllowOrigin)
|
||||
header.Set("Access-Control-Allow-Headers", whitelist.AllowHeaders)
|
||||
header.Set("Access-Control-Allow-Methods", whitelist.AllowMethods)
|
||||
header.Set("Access-Control-Expose-Headers", whitelist.ExposeHeaders)
|
||||
if whitelist.AllowCredentials {
|
||||
header.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
// 严格白名单模式且未通过检查,直接拒绝
|
||||
if whitelist == nil && cfg.Mode == "strict-whitelist" {
|
||||
// 健康检查放行
|
||||
if !(r.Method == http.MethodGet && r.URL.Path == "/health") {
|
||||
return nil, errors.Forbidden("CORS_FORBIDDEN", "forbidden")
|
||||
}
|
||||
}
|
||||
|
||||
// OPTIONS请求直接返回
|
||||
if r.Method == http.MethodOptions {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkCors(currentOrigin string, whitelist []CORSWhitelist) *CORSWhitelist {
|
||||
for _, w := range whitelist {
|
||||
if currentOrigin == w.AllowOrigin {
|
||||
return &w
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||
)
|
||||
|
||||
// EmailSender 邮件发送接口
|
||||
type EmailSender interface {
|
||||
SendError(subject, body string) error
|
||||
}
|
||||
|
||||
// ErrorToEmail 错误邮件通知中间件
|
||||
func ErrorToEmail(sender EmailSender, logger log.Logger) middleware.Middleware {
|
||||
helper := log.NewHelper(logger)
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
var (
|
||||
username string
|
||||
body string
|
||||
method string
|
||||
path string
|
||||
ip string
|
||||
)
|
||||
|
||||
// 获取用户信息
|
||||
if claims, ok := GetClaims(ctx); ok {
|
||||
username = claims.Username
|
||||
}
|
||||
if username == "" {
|
||||
username = "Unknown"
|
||||
}
|
||||
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
path = tr.Operation()
|
||||
if header := tr.RequestHeader(); header != nil {
|
||||
method = header.Get(":method")
|
||||
}
|
||||
|
||||
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||
r := ht.Request()
|
||||
ip = getClientIP(r)
|
||||
|
||||
// 读取body
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
body = string(bodyBytes)
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
reply, err := handler(ctx, req)
|
||||
latency := time.Since(start)
|
||||
|
||||
// 如果有错误,发送邮件
|
||||
if err != nil && sender != nil {
|
||||
subject := fmt.Sprintf("%s %s调用了%s报错了", username, ip, path)
|
||||
content := fmt.Sprintf(
|
||||
"接收到的请求为%s\n请求方式为%s\n报错信息如下%s\n耗时%s\n",
|
||||
body, method, err.Error(), latency.String(),
|
||||
)
|
||||
if sendErr := sender.SendError(subject, content); sendErr != nil {
|
||||
helper.Error("ErrorToEmail Failed, err:", sendErr)
|
||||
}
|
||||
}
|
||||
|
||||
return reply, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
"github.com/go-kratos/kratos/v2/transport/http"
|
||||
|
||||
pkgjwt "kra/pkg/jwt"
|
||||
)
|
||||
|
||||
type authKey struct{}
|
||||
|
||||
var (
|
||||
ErrMissingToken = errors.New("未登录或非法访问")
|
||||
ErrInvalidToken = errors.New("无效的token")
|
||||
ErrBlacklistJWT = errors.New("您的帐户异地登陆或令牌失效")
|
||||
)
|
||||
|
||||
// BlacklistChecker JWT黑名单检查接口
|
||||
type BlacklistChecker interface {
|
||||
IsBlacklist(jwt string) bool
|
||||
}
|
||||
|
||||
// JWTAuthConfig JWT认证配置
|
||||
type JWTAuthConfig struct {
|
||||
JWT *pkgjwt.JWT
|
||||
BlacklistChecker BlacklistChecker
|
||||
UseMultipoint bool
|
||||
SetRedisJWT func(token, username string) error
|
||||
}
|
||||
|
||||
// JWTAuth JWT认证中间件
|
||||
func JWTAuth(cfg JWTAuthConfig) middleware.Middleware {
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
token := extractToken(tr)
|
||||
if token == "" {
|
||||
return nil, ErrMissingToken
|
||||
}
|
||||
|
||||
// 检查黑名单
|
||||
if cfg.BlacklistChecker != nil && cfg.BlacklistChecker.IsBlacklist(token) {
|
||||
// 清除token
|
||||
if ht, ok := tr.(http.Transporter); ok {
|
||||
clearToken(ht)
|
||||
}
|
||||
return nil, ErrBlacklistJWT
|
||||
}
|
||||
|
||||
claims, err := cfg.JWT.ParseToken(token)
|
||||
if err != nil {
|
||||
if ht, ok := tr.(http.Transporter); ok {
|
||||
clearToken(ht)
|
||||
}
|
||||
if errors.Is(err, pkgjwt.ErrTokenExpired) {
|
||||
return nil, errors.New("登录已过期,请重新登录")
|
||||
}
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
// 将claims存入context
|
||||
ctx = context.WithValue(ctx, authKey{}, claims)
|
||||
|
||||
// 检查是否需要刷新token
|
||||
if cfg.JWT.NeedRefresh(claims) {
|
||||
newToken, newClaims, err := cfg.JWT.CreateTokenByOldToken(token, *claims)
|
||||
if err == nil {
|
||||
if ht, ok := tr.(http.Transporter); ok {
|
||||
ht.ReplyHeader().Set("new-token", newToken)
|
||||
ht.ReplyHeader().Set("new-expires-at", strconv.FormatInt(newClaims.ExpiresAt.Unix(), 10))
|
||||
setToken(ht, newToken, int(time.Until(newClaims.ExpiresAt.Time).Seconds()/60))
|
||||
}
|
||||
// 多点登录记录新JWT
|
||||
if cfg.UseMultipoint && cfg.SetRedisJWT != nil {
|
||||
_ = cfg.SetRedisJWT(newToken, newClaims.Username)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JWTAuthSimple 简化版JWT认证中间件(不含黑名单检查)
|
||||
func JWTAuthSimple(j *pkgjwt.JWT) middleware.Middleware {
|
||||
return JWTAuth(JWTAuthConfig{JWT: j})
|
||||
}
|
||||
|
||||
// extractToken 从请求中提取token
|
||||
func extractToken(tr transport.Transporter) string {
|
||||
// 优先从Authorization头获取
|
||||
auth := tr.RequestHeader().Get("Authorization")
|
||||
if auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
return strings.TrimPrefix(auth, "Bearer ")
|
||||
}
|
||||
return auth
|
||||
}
|
||||
|
||||
// 从x-token头获取
|
||||
token := tr.RequestHeader().Get("x-token")
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// clearToken 清除token
|
||||
func clearToken(ht http.Transporter) {
|
||||
ht.ReplyHeader().Set("x-token", "")
|
||||
}
|
||||
|
||||
// setToken 设置token到cookie
|
||||
func setToken(ht http.Transporter, token string, maxAge int) {
|
||||
// 通过header设置cookie信息,让前端处理
|
||||
ht.ReplyHeader().Set("x-token", token)
|
||||
}
|
||||
|
||||
// GetClaims 从context获取claims
|
||||
func GetClaims(ctx context.Context) (*pkgjwt.CustomClaims, bool) {
|
||||
claims, ok := ctx.Value(authKey{}).(*pkgjwt.CustomClaims)
|
||||
return claims, ok
|
||||
}
|
||||
|
||||
// GetUserID 从context获取用户ID
|
||||
func GetUserID(ctx context.Context) uint {
|
||||
if claims, ok := GetClaims(ctx); ok {
|
||||
return claims.BaseClaims.ID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetUsername 从context获取用户名
|
||||
func GetUsername(ctx context.Context) string {
|
||||
if claims, ok := GetClaims(ctx); ok {
|
||||
return claims.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAuthorityID 从context获取角色ID
|
||||
func GetAuthorityID(ctx context.Context) uint {
|
||||
if claims, ok := GetClaims(ctx); ok {
|
||||
return claims.AuthorityID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||
)
|
||||
|
||||
// LogLayout 日志layout
|
||||
type LogLayout struct {
|
||||
Time time.Time `json:"time"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
Path string `json:"path"`
|
||||
Query string `json:"query,omitempty"`
|
||||
Body string `json:"body,omitempty"`
|
||||
IP string `json:"ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Cost time.Duration `json:"cost"`
|
||||
Source string `json:"source"`
|
||||
}
|
||||
|
||||
// LoggerConfig 日志中间件配置
|
||||
type LoggerConfig struct {
|
||||
// Filter 用户自定义过滤
|
||||
Filter func(ctx context.Context, path string) bool
|
||||
// FilterKeyword 关键字过滤
|
||||
FilterKeyword func(layout *LogLayout) bool
|
||||
// AuthProcess 鉴权处理
|
||||
AuthProcess func(ctx context.Context, layout *LogLayout)
|
||||
// Print 日志处理
|
||||
Print func(LogLayout)
|
||||
// Source 服务唯一标识
|
||||
Source string
|
||||
}
|
||||
|
||||
// Logger 日志中间件
|
||||
func Logger(cfg LoggerConfig) middleware.Middleware {
|
||||
if cfg.Print == nil {
|
||||
cfg.Print = func(layout LogLayout) {
|
||||
v, _ := json.Marshal(layout)
|
||||
log.Info(string(v))
|
||||
}
|
||||
}
|
||||
if cfg.Source == "" {
|
||||
cfg.Source = "Kratos"
|
||||
}
|
||||
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
var (
|
||||
path string
|
||||
query string
|
||||
body string
|
||||
ip string
|
||||
userAgent string
|
||||
)
|
||||
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
path = tr.Operation()
|
||||
if header := tr.RequestHeader(); header != nil {
|
||||
userAgent = header.Get("User-Agent")
|
||||
}
|
||||
|
||||
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||
r := ht.Request()
|
||||
query = r.URL.RawQuery
|
||||
ip = getClientIP(r)
|
||||
|
||||
// 读取body(仅在未过滤时)
|
||||
if cfg.Filter == nil || !cfg.Filter(ctx, path) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
body = string(bodyBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要过滤
|
||||
if cfg.Filter != nil && cfg.Filter(ctx, path) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
reply, err := handler(ctx, req)
|
||||
cost := time.Since(start)
|
||||
|
||||
layout := LogLayout{
|
||||
Time: time.Now(),
|
||||
Path: path,
|
||||
Query: query,
|
||||
Body: body,
|
||||
IP: ip,
|
||||
UserAgent: userAgent,
|
||||
Cost: cost,
|
||||
Source: cfg.Source,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
layout.Error = err.Error()
|
||||
}
|
||||
|
||||
// 处理鉴权信息
|
||||
if cfg.AuthProcess != nil {
|
||||
cfg.AuthProcess(ctx, &layout)
|
||||
}
|
||||
|
||||
// 关键字过滤
|
||||
if cfg.FilterKeyword != nil {
|
||||
cfg.FilterKeyword(&layout)
|
||||
}
|
||||
|
||||
// 输出日志
|
||||
cfg.Print(layout)
|
||||
|
||||
return reply, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultLogger 默认日志中间件
|
||||
func DefaultLogger() middleware.Middleware {
|
||||
return Logger(LoggerConfig{
|
||||
Source: "Kratos",
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const maxBodySize = 1024
|
||||
|
||||
// OperationRecord 操作记录
|
||||
type OperationRecord struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
CreatedAt time.Time `gorm:"index"`
|
||||
Ip string `gorm:"column:ip;comment:请求ip"`
|
||||
Method string `gorm:"column:method;comment:请求方法"`
|
||||
Path string `gorm:"column:path;comment:请求路径"`
|
||||
Status int `gorm:"column:status;comment:请求状态"`
|
||||
Latency time.Duration `gorm:"column:latency;comment:延迟"`
|
||||
Agent string `gorm:"column:agent;comment:代理"`
|
||||
ErrorMessage string `gorm:"column:error_message;comment:错误信息"`
|
||||
Body string `gorm:"type:text;column:body;comment:请求Body"`
|
||||
Resp string `gorm:"type:text;column:resp;comment:响应Body"`
|
||||
UserID int `gorm:"column:user_id;comment:用户id"`
|
||||
}
|
||||
|
||||
func (OperationRecord) TableName() string {
|
||||
return "sys_operation_records"
|
||||
}
|
||||
|
||||
// OperationLog 操作日志中间件
|
||||
func OperationLog(db *gorm.DB, logger log.Logger) middleware.Middleware {
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
var (
|
||||
body string
|
||||
userID int
|
||||
ip string
|
||||
method string
|
||||
path string
|
||||
agent string
|
||||
)
|
||||
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
path = tr.Operation()
|
||||
if header := tr.RequestHeader(); header != nil {
|
||||
method = header.Get(":method")
|
||||
agent = header.Get("User-Agent")
|
||||
}
|
||||
|
||||
// 获取HTTP请求信息
|
||||
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||
r := ht.Request()
|
||||
ip = getClientIP(r)
|
||||
|
||||
if r.Method != http.MethodGet {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
if len(bodyBytes) > maxBodySize {
|
||||
body = "[超出记录长度]"
|
||||
} else {
|
||||
body = string(bodyBytes)
|
||||
}
|
||||
} else {
|
||||
query := r.URL.RawQuery
|
||||
query, _ = url.QueryUnescape(query)
|
||||
params := parseQuery(query)
|
||||
bodyBytes, _ := json.Marshal(params)
|
||||
body = string(bodyBytes)
|
||||
}
|
||||
|
||||
// 检查是否为文件上传
|
||||
if strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
body = "[文件]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取用户ID
|
||||
if claims, ok := GetClaims(ctx); ok {
|
||||
userID = int(claims.BaseClaims.ID)
|
||||
} else {
|
||||
// 尝试从header获取
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||
if idStr := ht.Request().Header.Get("x-user-id"); idStr != "" {
|
||||
if id, err := strconv.Atoi(idStr); err == nil {
|
||||
userID = id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
reply, err := handler(ctx, req)
|
||||
latency := time.Since(start)
|
||||
|
||||
// 记录响应
|
||||
var resp string
|
||||
if reply != nil {
|
||||
respBytes, _ := json.Marshal(reply)
|
||||
if len(respBytes) > maxBodySize {
|
||||
resp = "[超出记录长度]"
|
||||
} else {
|
||||
resp = string(respBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
var errMsg string
|
||||
status := 200
|
||||
if err != nil {
|
||||
errMsg = err.Error()
|
||||
status = 500
|
||||
}
|
||||
|
||||
// 保存记录
|
||||
record := OperationRecord{
|
||||
Ip: ip,
|
||||
Method: method,
|
||||
Path: path,
|
||||
Status: status,
|
||||
Latency: latency,
|
||||
Agent: agent,
|
||||
ErrorMessage: errMsg,
|
||||
Body: body,
|
||||
Resp: resp,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
if db != nil {
|
||||
go func() {
|
||||
if err := db.Create(&record).Error; err != nil {
|
||||
log.NewHelper(logger).Error("create operation record error:", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return reply, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
"github.com/go-kratos/kratos/v2/transport"
|
||||
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// LimitConfig 限流配置
|
||||
type LimitConfig struct {
|
||||
// GenerationKey 根据业务生成key
|
||||
GenerationKey func(ctx context.Context, path string) string
|
||||
// CheckOrMark 检查函数
|
||||
CheckOrMark func(ctx context.Context, key string, expire int, limit int) error
|
||||
// Expire key过期时间(秒)
|
||||
Expire int
|
||||
// Limit 周期内限制次数
|
||||
Limit int
|
||||
}
|
||||
|
||||
// RateLimit 限流中间件
|
||||
func RateLimit(cfg LimitConfig) middleware.Middleware {
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
var key string
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
path := tr.Operation()
|
||||
key = cfg.GenerationKey(ctx, path)
|
||||
}
|
||||
|
||||
if err := cfg.CheckOrMark(ctx, key, cfg.Expire, cfg.Limit); err != nil {
|
||||
return nil, errors.New(429, "RATE_LIMIT", err.Error())
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultGenerationKey 默认生成key(基于IP)
|
||||
func DefaultGenerationKey(ctx context.Context, path string) string {
|
||||
if tr, ok := transport.FromServerContext(ctx); ok {
|
||||
if ht, ok := tr.(kratoshttp.Transporter); ok {
|
||||
return "Limit_" + getClientIP(ht.Request())
|
||||
}
|
||||
}
|
||||
return "Limit_unknown"
|
||||
}
|
||||
|
||||
// NewRedisCheckOrMark 创建基于Redis的检查函数
|
||||
func NewRedisCheckOrMark(rdb *redis.Client) func(ctx context.Context, key string, expire int, limit int) error {
|
||||
return func(ctx context.Context, key string, expire int, limit int) error {
|
||||
if rdb == nil {
|
||||
return nil
|
||||
}
|
||||
return SetLimitWithTime(ctx, rdb, key, limit, time.Duration(expire)*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLimitWithTime 设置访问次数
|
||||
func SetLimitWithTime(ctx context.Context, rdb *redis.Client, key string, limit int, expiration time.Duration) error {
|
||||
count, err := rdb.Exists(ctx, key).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
// key不存在,创建并设置过期时间
|
||||
pipe := rdb.TxPipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, expiration)
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// key存在,检查次数
|
||||
times, err := rdb.Get(ctx, key).Int()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if times >= limit {
|
||||
// 获取剩余时间
|
||||
ttl, err := rdb.PTTL(ctx, key).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("请求太过频繁,请稍后再试")
|
||||
}
|
||||
return fmt.Errorf("请求太过频繁, 请 %s 后尝试", ttl.String())
|
||||
}
|
||||
|
||||
return rdb.Incr(ctx, key).Err()
|
||||
}
|
||||
|
||||
// DefaultRateLimit 默认限流中间件
|
||||
func DefaultRateLimit(rdb *redis.Client, expire, limit int) middleware.Middleware {
|
||||
return RateLimit(LimitConfig{
|
||||
GenerationKey: DefaultGenerationKey,
|
||||
CheckOrMark: NewRedisCheckOrMark(rdb),
|
||||
Expire: expire,
|
||||
Limit: limit,
|
||||
})
|
||||
}
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
"github.com/go-kratos/kratos/v2/log"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
)
|
||||
|
||||
// Recovery panic恢复中间件
|
||||
func Recovery(logger log.Logger) middleware.Middleware {
|
||||
helper := log.NewHelper(logger)
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
|
||||
defer func() {
|
||||
if rerr := recover(); rerr != nil {
|
||||
// 检查是否为断开的连接
|
||||
var brokenPipe bool
|
||||
if ne, ok := rerr.(*net.OpError); ok {
|
||||
if se, ok := ne.Err.(*os.SyscallError); ok {
|
||||
errStr := strings.ToLower(se.Error())
|
||||
if strings.Contains(errStr, "broken pipe") ||
|
||||
strings.Contains(errStr, "connection reset by peer") {
|
||||
brokenPipe = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stack := string(debug.Stack())
|
||||
if brokenPipe {
|
||||
helper.Errorw(
|
||||
"msg", "[Recovery from panic]",
|
||||
"error", rerr,
|
||||
"broken_pipe", true,
|
||||
)
|
||||
} else {
|
||||
helper.Errorw(
|
||||
"msg", "[Recovery from panic]",
|
||||
"error", rerr,
|
||||
"stack", stack,
|
||||
)
|
||||
}
|
||||
|
||||
err = errors.InternalServer("PANIC", "服务器内部错误")
|
||||
}
|
||||
}()
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-kratos/kratos/v2/errors"
|
||||
"github.com/go-kratos/kratos/v2/middleware"
|
||||
)
|
||||
|
||||
// Timeout 超时中间件
|
||||
func Timeout(timeout time.Duration) middleware.Middleware {
|
||||
return func(handler middleware.Handler) middleware.Handler {
|
||||
return func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// 使用 buffered channel 避免 goroutine 泄漏
|
||||
type result struct {
|
||||
reply interface{}
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
panicChan := make(chan interface{}, 1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
select {
|
||||
case panicChan <- p:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
reply, err := handler(ctx, req)
|
||||
select {
|
||||
case done <- result{reply: reply, err: err}:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case p := <-panicChan:
|
||||
panic(p)
|
||||
case r := <-done:
|
||||
return r.reply, r.err
|
||||
case <-ctx.Done():
|
||||
return nil, errors.GatewayTimeout("TIMEOUT", "请求超时")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,43 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// getClientIP 获取客户端IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// X-Forwarded-For
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// X-Real-IP
|
||||
xri := r.Header.Get("X-Real-IP")
|
||||
if xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// RemoteAddr
|
||||
ip := r.RemoteAddr
|
||||
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
||||
ip = ip[:idx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// parseQuery 解析query参数
|
||||
func parseQuery(query string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, v := range strings.Split(query, "&") {
|
||||
kv := strings.Split(v, "=")
|
||||
if len(kv) == 2 {
|
||||
m[kv[0]] = kv[1]
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
|
@ -103,5 +103,20 @@ func (j *JWT) ParseToken(tokenString string) (*CustomClaims, error) {
|
|||
|
||||
// NeedRefresh 判断是否需要刷新token
|
||||
func (j *JWT) NeedRefresh(claims *CustomClaims) bool {
|
||||
return time.Until(claims.ExpiresAt.Time) < j.BufferTime
|
||||
return claims.ExpiresAt.Unix()-time.Now().Unix() < claims.BufferTime
|
||||
}
|
||||
|
||||
// CreateTokenByOldToken 根据旧token创建新token
|
||||
func (j *JWT) CreateTokenByOldToken(oldToken string, oldClaims CustomClaims) (string, *CustomClaims, error) {
|
||||
// 更新过期时间
|
||||
oldClaims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(j.ExpiresAt))
|
||||
oldClaims.IssuedAt = jwt.NewNumericDate(time.Now())
|
||||
oldClaims.NotBefore = jwt.NewNumericDate(time.Now().Add(-1000))
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, oldClaims)
|
||||
newToken, err := token.SignedString(j.SigningKey)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return newToken, &oldClaims, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue