package middleware import ( "bytes" "context" "encoding/json" "io" "net/http" "net/url" "strconv" "strings" "time" "github.com/go-kratos/kratos/v2/log" "github.com/go-kratos/kratos/v2/middleware" "github.com/go-kratos/kratos/v2/transport" kratoshttp "github.com/go-kratos/kratos/v2/transport/http" "gorm.io/gorm" ) const maxBodySize = 1024 // OperationRecord 操作记录 type OperationRecord struct { ID uint `gorm:"primarykey"` CreatedAt time.Time `gorm:"index"` Ip string `gorm:"column:ip;comment:请求ip"` Method string `gorm:"column:method;comment:请求方法"` Path string `gorm:"column:path;comment:请求路径"` Status int `gorm:"column:status;comment:请求状态"` Latency time.Duration `gorm:"column:latency;comment:延迟"` Agent string `gorm:"column:agent;comment:代理"` ErrorMessage string `gorm:"column:error_message;comment:错误信息"` Body string `gorm:"type:text;column:body;comment:请求Body"` Resp string `gorm:"type:text;column:resp;comment:响应Body"` UserID int `gorm:"column:user_id;comment:用户id"` } func (OperationRecord) TableName() string { return "sys_operation_records" } // OperationLog 操作日志中间件 func OperationLog(db *gorm.DB, logger log.Logger) middleware.Middleware { return func(handler middleware.Handler) middleware.Handler { return func(ctx context.Context, req interface{}) (interface{}, error) { var ( body string userID int ip string method string path string agent string ) if tr, ok := transport.FromServerContext(ctx); ok { path = tr.Operation() if header := tr.RequestHeader(); header != nil { method = header.Get(":method") agent = header.Get("User-Agent") } // 获取HTTP请求信息 if ht, ok := tr.(kratoshttp.Transporter); ok { r := ht.Request() ip = getClientIP(r) if r.Method != http.MethodGet { bodyBytes, _ := io.ReadAll(r.Body) r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) if len(bodyBytes) > maxBodySize { body = "[超出记录长度]" } else { body = string(bodyBytes) } } else { query := r.URL.RawQuery query, _ = url.QueryUnescape(query) params := parseQuery(query) bodyBytes, _ := json.Marshal(params) body = string(bodyBytes) } // 检查是否为文件上传 if strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") { body = "[文件]" } } } // 获取用户ID if claims, ok := GetClaims(ctx); ok { userID = int(claims.BaseClaims.ID) } else { // 尝试从header获取 if tr, ok := transport.FromServerContext(ctx); ok { if ht, ok := tr.(kratoshttp.Transporter); ok { if idStr := ht.Request().Header.Get("x-user-id"); idStr != "" { if id, err := strconv.Atoi(idStr); err == nil { userID = id } } } } } start := time.Now() reply, err := handler(ctx, req) latency := time.Since(start) // 记录响应 var resp string if reply != nil { respBytes, _ := json.Marshal(reply) if len(respBytes) > maxBodySize { resp = "[超出记录长度]" } else { resp = string(respBytes) } } // 记录错误 var errMsg string status := 200 if err != nil { errMsg = err.Error() status = 500 } // 保存记录 record := OperationRecord{ Ip: ip, Method: method, Path: path, Status: status, Latency: latency, Agent: agent, ErrorMessage: errMsg, Body: body, Resp: resp, UserID: userID, } if db != nil { go func() { if err := db.Create(&record).Error; err != nil { log.NewHelper(logger).Error("create operation record error:", err) } }() } return reply, err } } }