package middleware

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"go.uber.org/zap"

	"kra/pkg/utils"
)

// respPool creates byte slices of bufferSize bytes on demand. It is not used
// by the code in this file.
var respPool sync.Pool

// bufferSize is the largest request body (and the largest download response)
// in bytes that is stored verbatim in an operation record. It is also the
// length of the slices created by respPool.
var bufferSize = 1024

func init() {
	respPool.New = func() interface{} {
		return make([]byte, bufferSize)
	}
}

// SysOperationRecord is the audit entry built for one handled request.
type SysOperationRecord struct {
	IP           string        `json:"ip"`            // client IP as reported by gin
	Method       string        `json:"method"`        // HTTP method of the request
	Path         string        `json:"path"`          // URL path of the request
	Status       int           `json:"status"`        // response status code
	Latency      time.Duration `json:"latency"`       // time spent in the rest of the handler chain
	Agent        string        `json:"agent"`         // User-Agent header of the request
	ErrorMessage string        `json:"error_message"` // private gin errors attached to the context
	Body         string        `json:"body"`          // request body, or a placeholder
	Resp         string        `json:"resp"`          // captured response body, or a placeholder
	UserID       int           `json:"user_id"`       // id of the caller, 0 when unknown
}

// OperationRecordCreator persists operation records produced by the
// middleware.
type OperationRecordCreator interface {
	CreateOperationRecord(record *SysOperationRecord) error
}

// Package-level sinks used by the middleware. Both may be nil, in which case
// records are not stored and errors are not logged.
var (
	operationRecordCreator OperationRecordCreator
	operationLogger        *zap.Logger
)

// SetOperationRecordCreator installs the creator that receives every record.
// The assignment is not synchronized.
func SetOperationRecordCreator(creator OperationRecordCreator) {
	operationRecordCreator = creator
}

// SetOperationLogger installs the logger used to report middleware errors.
// The assignment is not synchronized.
func SetOperationLogger(logger *zap.Logger) {
	operationLogger = logger
}

// OperationRecord is an alias for OperationRecordMiddleware, kept so callers
// can use the same name as in KRA.
func OperationRecord() gin.HandlerFunc {
	return OperationRecordMiddleware()
}

// OperationRecordMiddleware returns a gin handler that builds a
// SysOperationRecord for every request and hands it to the configured
// OperationRecordCreator.
//
// The record holds the request body (or the GET query rendered as a JSON
// object), the caller's user id, the response status, the latency of the
// remaining handler chain, any private gin errors and the response body.
// Multipart request bodies are replaced by the placeholder "[文件]"; request
// bodies longer than bufferSize, and download responses longer than
// bufferSize, are replaced by "[超出记录长度]".
func OperationRecordMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		body := requestBodyForRecord(c)
		userID := resolveUserID(c)

		record := SysOperationRecord{
			IP:     c.ClientIP(),
			Method: c.Request.Method,
			Path:   c.Request.URL.Path,
			Agent:  c.Request.UserAgent(),
			Body:   body,
			UserID: userID,
		}

		// Wrap the writer so everything the downstream handlers emit is
		// mirrored into an in-memory buffer.
		writer := responseBodyWriter{
			ResponseWriter: c.Writer,
			body:           &bytes.Buffer{},
		}
		c.Writer = writer

		start := time.Now()
		c.Next()
		record.Latency = time.Since(start)

		record.ErrorMessage = c.Errors.ByType(gin.ErrorTypePrivate).String()
		record.Status = c.Writer.Status()
		record.Resp = writer.body.String()

		// Only download-style responses are cut off; other responses are
		// recorded in full regardless of size.
		if isDownloadResponse(c.Writer.Header()) && len(record.Resp) > bufferSize {
			record.Resp = "[超出记录长度]"
		}

		if operationRecordCreator == nil {
			return
		}
		if err := operationRecordCreator.CreateOperationRecord(&record); err != nil {
			logOperationError("create operation record error:", err)
		}
	}
}

// requestBodyForRecord returns the text stored in SysOperationRecord.Body.
//
// Multipart requests yield "[文件]" and their body is left untouched, so an
// upload is never buffered in memory just to be discarded. GET requests yield
// their query string as a JSON object. Any other request has its body read
// and then restored on c.Request so later handlers can read it again; a body
// longer than bufferSize yields "[超出记录长度]".
func requestBodyForRecord(c *gin.Context) string {
	if strings.Contains(c.GetHeader("Content-Type"), "multipart/form-data") {
		return "[文件]"
	}

	var body []byte
	if c.Request.Method == http.MethodGet {
		body = queryAsJSON(c.Request.URL.RawQuery)
	} else {
		var err error
		body, err = io.ReadAll(c.Request.Body)
		if err != nil {
			// Best effort: whatever was read before the failure is still
			// recorded, and the request body is left as it is.
			logOperationError("read body from request error:", err)
		} else {
			c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
		}
	}

	if len(body) > bufferSize {
		return "[超出记录长度]"
	}
	return string(body)
}

// queryAsJSON renders a raw URL query string as a JSON object of string
// values. Pairs are split on "&" and on the first "=" before being
// percent-decoded, so encoded separators and "=" characters inside a value
// survive. Parameters without "=" are skipped, and the last occurrence of a
// repeated key wins.
func queryAsJSON(rawQuery string) []byte {
	params := make(map[string]string)
	for _, pair := range strings.Split(rawQuery, "&") {
		kv := strings.SplitN(pair, "=", 2)
		if len(kv) != 2 {
			continue
		}
		params[unescapeOrRaw(kv[0])] = unescapeOrRaw(kv[1])
	}
	// Marshalling a map[string]string has no failure mode, so the error is
	// not inspected.
	encoded, _ := json.Marshal(params)
	return encoded
}

// unescapeOrRaw percent-decodes s and falls back to s itself when it contains
// a malformed escape, so a single bad component does not hide the others.
func unescapeOrRaw(s string) string {
	decoded, err := url.QueryUnescape(s)
	if err != nil {
		return s
	}
	return decoded
}

// resolveUserID returns the caller's id: the id from the request's claims
// when it is non-zero, otherwise the integer in the "x-user-id" header, and 0
// when neither is available.
func resolveUserID(c *gin.Context) int {
	// A failure to obtain claims is not an error here; the header is the
	// fallback.
	claims, _ := utils.GetClaims(c)
	if claims != nil && claims.BaseClaims.ID != 0 {
		return int(claims.BaseClaims.ID)
	}

	id, err := strconv.Atoi(c.Request.Header.Get("x-user-id"))
	if err != nil {
		// On a range error Atoi returns a saturated value rather than 0, so
		// its result must not be used when err is set.
		return 0
	}
	return id
}

// isDownloadResponse reports whether the response headers carry any of the
// markers this middleware treats as a file download.
func isDownloadResponse(h http.Header) bool {
	contentType := h.Get("Content-Type")
	switch {
	case strings.Contains(h.Get("Pragma"), "public"),
		strings.Contains(h.Get("Expires"), "0"),
		strings.Contains(h.Get("Cache-Control"), "must-revalidate, post-check=0, pre-check=0"),
		strings.Contains(contentType, "application/force-download"),
		strings.Contains(contentType, "application/octet-stream"),
		strings.Contains(contentType, "application/vnd.ms-excel"),
		strings.Contains(contentType, "application/download"),
		strings.Contains(h.Get("Content-Disposition"), "attachment"),
		strings.Contains(h.Get("Content-Transfer-Encoding"), "binary"):
		return true
	}
	return false
}

// logOperationError reports err through operationLogger when one is set.
func logOperationError(msg string, err error) {
	if operationLogger != nil {
		operationLogger.Error(msg, zap.Error(err))
	}
}

// responseBodyWriter wraps a gin.ResponseWriter and copies everything written
// through it into body.
type responseBodyWriter struct {
	gin.ResponseWriter
	body *bytes.Buffer
}

// Write mirrors b into the capture buffer and forwards it to the wrapped
// writer.
func (r responseBodyWriter) Write(b []byte) (int, error) {
	r.body.Write(b)
	return r.ResponseWriter.Write(b)
}

// WriteString mirrors s into the capture buffer and forwards it to the
// wrapped writer. Without this override the embedded writer's WriteString
// would be used and that output would be missing from the capture.
func (r responseBodyWriter) WriteString(s string) (int, error) {
	r.body.WriteString(s)
	return r.ResponseWriter.WriteString(s)
}