This commit is contained in:
parent
7455586bcf
commit
b8447604c2
|
|
@ -1,10 +1,14 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
|
||||
petRequest "github.com/flipped-aurora/gin-vue-admin/server/model/pet/request"
|
||||
petResponse "github.com/flipped-aurora/gin-vue-admin/server/model/pet/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/service"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
|
@ -15,6 +19,108 @@ type PetAssistantApi struct{}
|
|||
|
||||
var petChatService = service.ServiceGroupApp.PetServiceGroup.PetChatService
|
||||
|
||||
// StreamAskPetAssistant 向宠物助手流式提问
|
||||
// @Tags PetAssistant
|
||||
// @Summary 向宠物助手流式提问接口
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept application/json
|
||||
// @Produce text/event-stream
|
||||
// @Param data body petRequest.ChatRequest true "宠物助手流式提问请求"
|
||||
// @Success 200 {string} string "流式响应"
|
||||
// @Router /api/v1/pet/user/assistant/stream-ask [post]
|
||||
func (p *PetAssistantApi) StreamAskPetAssistant(ctx *gin.Context) {
|
||||
// 创建业务用Context
|
||||
businessCtx := ctx.Request.Context()
|
||||
|
||||
// 获取用户ID
|
||||
userId := utils.GetAppUserID(ctx)
|
||||
if userId == 0 {
|
||||
global.GVA_LOG.Error("获取用户ID失败")
|
||||
response.FailWithMessage("用户认证失败", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 绑定请求参数
|
||||
var req petRequest.ChatRequest
|
||||
if err := ctx.ShouldBindJSON(&req); err != nil {
|
||||
global.GVA_LOG.Error("参数绑定失败", zap.Error(err))
|
||||
response.FailWithMessage("参数错误: "+err.Error(), ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 参数验证
|
||||
if req.Message == "" {
|
||||
response.FailWithMessage("消息内容不能为空", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 强制设置为流式响应
|
||||
req.Stream = true
|
||||
|
||||
// 设置默认参数
|
||||
if req.Temperature <= 0 {
|
||||
req.Temperature = 0.7
|
||||
}
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = 1000
|
||||
}
|
||||
|
||||
// 设置SSE响应头
|
||||
ctx.Header("Content-Type", "text/event-stream")
|
||||
ctx.Header("Cache-Control", "no-cache")
|
||||
ctx.Header("Connection", "keep-alive")
|
||||
ctx.Header("Access-Control-Allow-Origin", "*")
|
||||
ctx.Header("Access-Control-Allow-Headers", "Cache-Control")
|
||||
|
||||
// 创建事件通道
|
||||
eventChan := make(chan petResponse.StreamEvent, 100)
|
||||
defer close(eventChan)
|
||||
|
||||
// 启动流式聊天
|
||||
go func() {
|
||||
if err := petChatService.StreamChat(businessCtx, userId, req, eventChan); err != nil {
|
||||
global.GVA_LOG.Error("流式聊天失败", zap.Error(err), zap.Uint("userId", userId))
|
||||
eventChan <- petResponse.StreamEvent{
|
||||
Event: "error",
|
||||
Data: map[string]interface{}{
|
||||
"error": "流式聊天失败: " + err.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 发送流式数据
|
||||
ctx.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case event, ok := <-eventChan:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch event.Event {
|
||||
case "message":
|
||||
// 发送消息事件
|
||||
ctx.SSEvent("message", event.Data)
|
||||
case "error":
|
||||
// 发送错误事件
|
||||
ctx.SSEvent("error", event.Data)
|
||||
return false
|
||||
case "done":
|
||||
// 发送完成事件
|
||||
ctx.SSEvent("done", event.Data)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case <-time.After(30 * time.Second):
|
||||
// 超时处理
|
||||
ctx.SSEvent("error", map[string]interface{}{
|
||||
"error": "流式响应超时",
|
||||
})
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GetAssistantHistory 获取宠物助手对话历史
|
||||
// @Tags PetAssistant
|
||||
// @Summary 获取宠物助手对话历史记录(简化版本,只返回必要字段)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package global
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
"github.com/olahol/melody"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
|
@ -38,7 +37,6 @@ var (
|
|||
GVA_ROUTERS gin.RoutesInfo
|
||||
GVA_ACTIVE_DBNAME *string
|
||||
GVA_MCP_SERVER *server.MCPServer
|
||||
MELODY *melody.Melody
|
||||
BlackCache local_cache.Cache
|
||||
lock sync.RWMutex
|
||||
)
|
||||
|
|
|
|||
|
|
@ -99,7 +99,6 @@ require (
|
|||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
|
||||
|
|
@ -132,7 +131,6 @@ require (
|
|||
github.com/mozillazg/go-httpheader v0.4.0 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/nwaples/rardecode/v2 v2.1.0 // indirect
|
||||
github.com/olahol/melody v1.3.0 // indirect
|
||||
github.com/otiai10/mint v1.6.3 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.22 // indirect
|
||||
|
|
|
|||
|
|
@ -257,8 +257,6 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
|||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
|
|
@ -396,8 +394,6 @@ github.com/nwaples/rardecode/v2 v2.1.0/go.mod h1:7uz379lSxPe6j9nvzxUZ+n7mnJNgjsR
|
|||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/olahol/melody v1.3.0 h1:n7UlKiQnxVrgxKoM0d7usZiN+Z0y2lVENtYLgKtXS6s=
|
||||
github.com/olahol/melody v1.3.0/go.mod h1:GgkTl6Y7yWj/HtfD48Q5vLKPVoZOH+Qqgfa7CvJgJM4=
|
||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
|
||||
github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0=
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import (
|
|||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/middleware"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/router"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
|
|
@ -43,9 +42,6 @@ func Routers() *gin.Engine {
|
|||
|
||||
sseServer := McpRun()
|
||||
|
||||
// 初始化WebSocket
|
||||
InitWebSocket()
|
||||
|
||||
// 注册mcp服务
|
||||
Router.GET(global.GVA_CONFIG.MCP.SSEPath, func(c *gin.Context) {
|
||||
sseServer.SSEHandler().ServeHTTP(c.Writer, c.Request)
|
||||
|
|
@ -82,9 +78,6 @@ func Routers() *gin.Engine {
|
|||
PrivateGroup.Use(middleware.JWTAuth()).Use(middleware.CasbinHandler())
|
||||
UserGroup.Use(middleware.UserJWTAuth())
|
||||
|
||||
// WebSocket路由(不使用UserJWTAuth中间件,使用自定义认证)
|
||||
Router.GET("/user/ws", websocket.WebSocketAuthMiddleware(), websocket.HandleConnection)
|
||||
|
||||
{
|
||||
// 健康监测
|
||||
PublicGroup.GET("/health", func(c *gin.Context) {
|
||||
|
|
|
|||
|
|
@ -1,10 +0,0 @@
|
|||
package initialize
|
||||
|
||||
import (
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/websocket"
|
||||
)
|
||||
|
||||
// InitWebSocket 初始化WebSocket
|
||||
func InitWebSocket() {
|
||||
websocket.InitWebSocketServer()
|
||||
}
|
||||
|
|
@ -1,109 +0,0 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/olahol/melody"
|
||||
)
|
||||
|
||||
// HandleConnection 处理WebSocket连接
|
||||
func HandleConnection(ctx *gin.Context) {
|
||||
// 从中间件获取用户ID
|
||||
userId, exists := ctx.Get("userId")
|
||||
if !exists {
|
||||
ctx.JSON(401, gin.H{"error": "用户认证失败"})
|
||||
return
|
||||
}
|
||||
|
||||
// 升级为WebSocket连接,并存储用户信息
|
||||
global.MELODY.HandleRequestWithKeys(ctx.Writer, ctx.Request, map[string]interface{}{
|
||||
"userId": userId,
|
||||
})
|
||||
}
|
||||
|
||||
// SendMessage 发送消息到WebSocket连接
|
||||
func SendMessage(s *melody.Session, message interface{}) error {
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Write(data)
|
||||
}
|
||||
|
||||
// SendErrorMessage 发送错误消息
|
||||
func SendErrorMessage(s *melody.Session, errorMsg string) {
|
||||
SendMessage(s, NewErrorEvent(errorMsg))
|
||||
}
|
||||
|
||||
// BroadcastToUser 向指定用户广播消息
|
||||
func BroadcastToUser(userID uint, message interface{}) error {
|
||||
messageBytes, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 遍历所有连接,找到对应用户的连接
|
||||
return global.MELODY.BroadcastFilter(messageBytes, func(s *melody.Session) bool {
|
||||
if userIdInterface, exists := s.Get("userId"); exists {
|
||||
if userId, ok := userIdInterface.(uint); ok {
|
||||
return userId == userID
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
// BroadcastMessageToUser 向指定用户广播流式消息
|
||||
func BroadcastMessageToUser(userID uint, delta string) error {
|
||||
return BroadcastToUser(userID, NewMessageEvent(delta))
|
||||
}
|
||||
|
||||
// BroadcastErrorToUser 向指定用户广播错误消息
|
||||
func BroadcastErrorToUser(userID uint, errorMsg string) error {
|
||||
return BroadcastToUser(userID, NewErrorEvent(errorMsg))
|
||||
}
|
||||
|
||||
// BroadcastDoneToUser 向指定用户广播完成消息
|
||||
func BroadcastDoneToUser(userID uint, message, sessionId string) error {
|
||||
return BroadcastToUser(userID, NewDoneEvent(message, sessionId))
|
||||
}
|
||||
|
||||
// GetConnectedUserCount 获取当前连接的用户数量
|
||||
func GetConnectedUserCount() int {
|
||||
count := 0
|
||||
global.MELODY.BroadcastFilter([]byte{}, func(s *melody.Session) bool {
|
||||
if _, exists := s.Get("userId"); exists {
|
||||
count++
|
||||
}
|
||||
return false // 不实际发送消息,只统计
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
// SendPingToUser 向指定用户发送心跳消息
|
||||
func SendPingToUser(userID uint) error {
|
||||
pingMsg := Message{
|
||||
Type: "ping",
|
||||
Data: map[string]interface{}{
|
||||
"timestamp": time.Now().Unix(),
|
||||
},
|
||||
}
|
||||
return BroadcastToUser(userID, pingMsg)
|
||||
}
|
||||
|
||||
// IsUserConnected 检查用户是否在线
|
||||
func IsUserConnected(userID uint) bool {
|
||||
connected := false
|
||||
global.MELODY.BroadcastFilter([]byte{}, func(s *melody.Session) bool {
|
||||
if userIdInterface, exists := s.Get("userId"); exists {
|
||||
if userId, ok := userIdInterface.(uint); ok && userId == userID {
|
||||
connected = true
|
||||
}
|
||||
}
|
||||
return false // 不实际发送消息,只检查
|
||||
})
|
||||
return connected
|
||||
}
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
package websocket
|
||||
|
||||
// Message WebSocket下发消息结构体
|
||||
type Message struct {
|
||||
Type string `json:"type"` // 消息类型:message, error, done
|
||||
Data interface{} `json:"data"` // 消息数据
|
||||
}
|
||||
|
||||
// MessageData 流式消息数据结构体
|
||||
type MessageData struct {
|
||||
Delta string `json:"delta"` // 增量消息内容
|
||||
}
|
||||
|
||||
// ErrorData 错误数据结构体
|
||||
type ErrorData struct {
|
||||
Error string `json:"error"` // 错误信息
|
||||
}
|
||||
|
||||
// DoneData 完成数据结构体
|
||||
type DoneData struct {
|
||||
Message string `json:"message"` // 完整消息内容
|
||||
SessionId string `json:"sessionId"` // 会话ID
|
||||
}
|
||||
|
||||
// NewMessageEvent 创建消息事件
|
||||
func NewMessageEvent(delta string) Message {
|
||||
return Message{
|
||||
Type: "message",
|
||||
Data: MessageData{
|
||||
Delta: delta,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorEvent 创建错误事件
|
||||
func NewErrorEvent(errorMsg string) Message {
|
||||
return Message{
|
||||
Type: "error",
|
||||
Data: ErrorData{
|
||||
Error: errorMsg,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewDoneEvent 创建完成事件
|
||||
func NewDoneEvent(message, sessionId string) Message {
|
||||
return Message{
|
||||
Type: "done",
|
||||
Data: DoneData{
|
||||
Message: message,
|
||||
SessionId: sessionId,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/olahol/melody"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// InitWebSocketServer 初始化WebSocket服务器
|
||||
func InitWebSocketServer() {
|
||||
// 创建melody实例
|
||||
global.MELODY = melody.New()
|
||||
|
||||
// 设置WebSocket配置
|
||||
global.MELODY.Config.MaxMessageSize = 1024 * 1024 // 1MB
|
||||
global.MELODY.Config.MessageBufferSize = 256
|
||||
global.MELODY.Config.PongWait = 60 * time.Second // 等待pong消息的时间
|
||||
global.MELODY.Config.PingPeriod = 54 * time.Second // 发送ping消息的间隔
|
||||
global.MELODY.Config.WriteWait = 10 * time.Second // 写入超时时间
|
||||
global.MELODY.Config.ConcurrentMessageHandling = false // 禁用并发消息处理
|
||||
|
||||
// 连接建立时的处理
|
||||
global.MELODY.HandleConnect(func(s *melody.Session) {
|
||||
if userIdInterface, exists := s.Get("userId"); exists {
|
||||
if userId, ok := userIdInterface.(uint); ok {
|
||||
global.GVA_LOG.Info("WebSocket连接建立",
|
||||
zap.Uint("userId", userId),
|
||||
zap.String("remoteAddr", s.Request.RemoteAddr))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 连接断开时的处理
|
||||
global.MELODY.HandleDisconnect(func(s *melody.Session) {
|
||||
if userIdInterface, exists := s.Get("userId"); exists {
|
||||
if userId, ok := userIdInterface.(uint); ok {
|
||||
global.GVA_LOG.Info("WebSocket连接断开",
|
||||
zap.Uint("userId", userId),
|
||||
zap.String("remoteAddr", s.Request.RemoteAddr))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 错误处理
|
||||
global.MELODY.HandleError(func(s *melody.Session, err error) {
|
||||
// 检查是否是正常关闭或客户端主动断开
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "close 1000") ||
|
||||
strings.Contains(errMsg, "close 1001") ||
|
||||
strings.Contains(errMsg, "use of closed network connection") {
|
||||
// 正常关闭,只记录debug级别日志
|
||||
if userIdInterface, exists := s.Get("userId"); exists {
|
||||
if userId, ok := userIdInterface.(uint); ok {
|
||||
global.GVA_LOG.Debug("WebSocket正常关闭",
|
||||
zap.String("reason", errMsg),
|
||||
zap.Uint("userId", userId))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 异常错误,记录error级别日志
|
||||
if userIdInterface, exists := s.Get("userId"); exists {
|
||||
if userId, ok := userIdInterface.(uint); ok {
|
||||
global.GVA_LOG.Error("WebSocket异常错误",
|
||||
zap.Error(err),
|
||||
zap.Uint("userId", userId))
|
||||
}
|
||||
} else {
|
||||
global.GVA_LOG.Error("WebSocket异常错误", zap.Error(err))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 由于只下发数据,不处理客户端消息,所以不设置消息处理器
|
||||
global.GVA_LOG.Info("WebSocket服务器初始化完成")
|
||||
}
|
||||
|
||||
// WebSocketAuthMiddleware WebSocket鉴权中间件
|
||||
func WebSocketAuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 从查询参数获取token
|
||||
token := c.Query("token")
|
||||
if token == "" {
|
||||
// 从头部获取token
|
||||
token = c.GetHeader("Authorization")
|
||||
if token != "" && len(token) > 7 && token[:7] == "Bearer " {
|
||||
token = token[7:]
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少认证token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 解析token获取用户ID
|
||||
j := utils.NewJWT()
|
||||
claims, err := j.ParseAppUserToken(token)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("WebSocket token解析失败", zap.Error(err))
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "token无效"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户ID存储到上下文中
|
||||
c.Set("userId", claims.AppBaseClaims.ID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue