kra/internal/server/middleware/operation.go

156 lines
3.9 KiB
Go

package middleware
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/go-kratos/kratos/v2/log"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
kratoshttp "github.com/go-kratos/kratos/v2/transport/http"
"gorm.io/gorm"
)
// maxBodySize is the largest request/response payload, in bytes, that an
// OperationRecord stores verbatim; longer payloads are replaced by a
// placeholder string.
const maxBodySize = 1 << 10
// OperationRecord is one audited request (an "operation record"), persisted
// by the OperationLog middleware into the table named by TableName.
type OperationRecord struct {
	// ID is the auto-increment primary key.
	ID uint `gorm:"primarykey"`
	// CreatedAt is the (indexed) row creation time, filled by GORM on insert
	// through its CreatedAt naming convention.
	CreatedAt time.Time `gorm:"index"`
	// Ip is the client IP address of the request.
	Ip string `gorm:"column:ip;comment:请求ip"`
	// Method is the request method.
	Method string `gorm:"column:method;comment:请求方法"`
	// Path is the request path / operation name reported by the transport.
	Path string `gorm:"column:path;comment:请求路径"`
	// Status is the recorded outcome status code.
	Status int `gorm:"column:status;comment:请求状态"`
	// Latency is how long the wrapped handler took; a time.Duration, so
	// presumably stored as integer nanoseconds — confirm against the schema.
	Latency time.Duration `gorm:"column:latency;comment:延迟"`
	// Agent is the client's User-Agent header.
	Agent string `gorm:"column:agent;comment:代理"`
	// ErrorMessage is the handler error text, empty on success.
	ErrorMessage string `gorm:"column:error_message;comment:错误信息"`
	// Body is the (possibly placeholder-substituted) request payload.
	Body string `gorm:"type:text;column:body;comment:请求Body"`
	// Resp is the (possibly placeholder-substituted) JSON response payload.
	Resp string `gorm:"type:text;column:resp;comment:响应Body"`
	// UserID identifies the acting user; 0 when unknown.
	UserID int `gorm:"column:user_id;comment:用户id"`
}
// TableName overrides GORM's default table naming so that OperationRecord
// rows live in "sys_operation_records".
func (r OperationRecord) TableName() string {
	const table = "sys_operation_records"
	return table
}
// OperationLog returns a middleware that audits every request it wraps by
// inserting an OperationRecord into db.
//
// The record contains the transport operation name, the request method and
// User-Agent, and the handler latency. It also holds the JSON-encoded reply
// and the handler error text. Status is 200 on success and 500 whenever the
// handler returns an error. For HTTP requests it additionally records the
// client IP (via getClientIP) and a loggable form of the request (see
// operationRequestBody). The user ID comes from the claims returned by
// GetClaims, falling back to the "x-user-id" request header.
//
// The insert runs in a background goroutine so it adds no latency to the
// request. Insert failures are only logged. When db is nil nothing is
// persisted. A nil logger falls back to log.DefaultLogger.
func OperationLog(db *gorm.DB, logger log.Logger) middleware.Middleware {
	if logger == nil {
		// A nil Logger would panic inside the detached goroutine below.
		logger = log.DefaultLogger
	}
	helper := log.NewHelper(logger)
	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			var record OperationRecord
			if tr, ok := transport.FromServerContext(ctx); ok {
				operationRequestInfo(&record, tr)
			}
			record.UserID = operationUserID(ctx)

			start := time.Now()
			reply, err := handler(ctx, req)
			record.Latency = time.Since(start)

			record.Resp = operationResponse(reply)
			record.Status = http.StatusOK
			if err != nil {
				record.ErrorMessage = err.Error()
				record.Status = http.StatusInternalServerError
			}

			if db != nil {
				// Intentionally not bound to ctx: the request context may be
				// cancelled as soon as the reply is returned.
				go func() {
					if err := db.Create(&record).Error; err != nil {
						helper.Errorf("create operation record error: %v", err)
					}
				}()
			}
			return reply, err
		}
	}
}

// operationRequestInfo copies transport-level request details (operation
// name, method, User-Agent and, for HTTP, client IP and body) into rec.
func operationRequestInfo(rec *OperationRecord, tr transport.Transporter) {
	rec.Path = tr.Operation()
	if header := tr.RequestHeader(); header != nil {
		rec.Method = header.Get(":method")
		rec.Agent = header.Get("User-Agent")
	}
	ht, ok := tr.(kratoshttp.Transporter)
	if !ok {
		return
	}
	r := ht.Request()
	// net/http does not expose pseudo-headers such as ":method" in r.Header,
	// so for HTTP the method must come from the request itself.
	rec.Method = r.Method
	rec.Ip = getClientIP(r)
	rec.Body = operationRequestBody(r)
}

// operationRequestBody returns the text to store as the request body:
// "[文件]" for multipart uploads (which are never read), the JSON-encoded
// query parameters for GET, otherwise the raw body. Oversized content is
// replaced by a placeholder. r.Body remains fully readable afterwards.
func operationRequestBody(r *http.Request) string {
	if strings.Contains(r.Header.Get("Content-Type"), "multipart/form-data") {
		return "[文件]"
	}
	if r.Method == http.MethodGet {
		query := r.URL.RawQuery
		// QueryUnescape returns "" on a malformed escape; keep the raw query
		// in that case instead of losing every parameter.
		if unescaped, err := url.QueryUnescape(query); err == nil {
			query = unescaped
		}
		// Marshal errors leave queryJSON nil, which records as "".
		queryJSON, _ := json.Marshal(parseQuery(query))
		return operationRecordText(queryJSON)
	}
	if r.Body == nil {
		return ""
	}
	// Read one byte past the limit: enough to tell whether the body fits
	// without buffering arbitrarily large payloads in memory. A read error is
	// deliberately not fatal — auditing must not break the request — so
	// whatever was read is recorded.
	head, _ := io.ReadAll(io.LimitReader(r.Body, maxBodySize+1))
	// Splice the consumed prefix back in front of the unread remainder so the
	// body is intact for downstream readers; keep the original Close.
	r.Body = struct {
		io.Reader
		io.Closer
	}{io.MultiReader(bytes.NewReader(head), r.Body), r.Body}
	return operationRecordText(head)
}

// operationUserID returns the acting user's ID: from GetClaims(ctx) when
// present, otherwise from the HTTP "x-user-id" header, otherwise 0.
//
// NOTE(review): "x-user-id" is client-supplied unless an upstream gateway
// sets/strips it — confirm before trusting it in the audit trail.
func operationUserID(ctx context.Context) int {
	if claims, ok := GetClaims(ctx); ok {
		return int(claims.BaseClaims.ID)
	}
	tr, ok := transport.FromServerContext(ctx)
	if !ok {
		return 0
	}
	ht, ok := tr.(kratoshttp.Transporter)
	if !ok {
		return 0
	}
	id, err := strconv.Atoi(ht.Request().Header.Get("x-user-id"))
	if err != nil {
		return 0
	}
	return id
}

// operationResponse JSON-encodes reply for storage; a nil reply or a marshal
// failure yields "".
func operationResponse(reply interface{}) string {
	if reply == nil {
		return ""
	}
	respJSON, _ := json.Marshal(reply)
	return operationRecordText(respJSON)
}

// operationRecordText converts b to a string, substituting a placeholder when
// it is longer than maxBodySize bytes.
func operationRecordText(b []byte) string {
	if len(b) > maxBodySize {
		return "[超出记录长度]"
	}
	return string(b)
}