package pet

import (
	"context"
	"fmt"
	"time"

	"github.com/flipped-aurora/gin-vue-admin/server/global"
	"github.com/flipped-aurora/gin-vue-admin/server/model/pet"
	petRequest "github.com/flipped-aurora/gin-vue-admin/server/model/pet/request"
	petResponse "github.com/flipped-aurora/gin-vue-admin/server/model/pet/response"
	"github.com/flipped-aurora/gin-vue-admin/server/utils"
	"github.com/google/uuid"
	"go.uber.org/zap"
)

// Aliases for the newly defined model structures
type ChatRequest = petRequest.ChatRequest
type ChatResponse = petResponse.ChatResponse
type StreamEvent = petResponse.StreamEvent

// PetChatService is the pet chat service.
type PetChatService struct{}

// SendMessage sends a message (non-streaming).
func (p *PetChatService) SendMessage(ctx context.Context, userId uint, req ChatRequest) (*ChatResponse, error) {
	startTime := time.Now()

	// 1. Sensitive-word check
	sensitiveUtil := utils.GetSensitiveWordUtil()
	filtered, hasSensitive, err := sensitiveUtil.FilterText(req.Message)
	if err != nil {
		global.GVA_LOG.Error("敏感词检测失败", zap.Error(err))
		return nil, fmt.Errorf("敏感词检测失败: %v", err)
	}

	// If the message contains sensitive words, log it and return a prompt
	if hasSensitive {
		global.GVA_LOG.Warn("用户消息包含敏感词", zap.Uint("userId", userId), zap.String("original", req.Message), zap.String("filtered", filtered))

		// Save the user message (flagged as sensitive)
		if err := p.saveUserMessage(ctx, userId, req.SessionId, req.Message, true); err != nil {
			global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
		}

		return &ChatResponse{
			Message:      "您的消息包含不当内容,请重新输入。",
			SessionId:    req.SessionId,
			IsSensitive:  true,
			TokenCount:   0,
			ResponseTime: time.Since(startTime).Milliseconds(),
		}, nil
	}

	// 2. Generate a session ID if none was provided
	sessionId := req.SessionId
	if sessionId == "" {
		sessionId = uuid.New().String()
	}

	// 3. Fetch the conversation history to build the context
	history, err := p.GetChatHistory(ctx, userId, sessionId, 10) // fetch the 10 most recent messages
	if err != nil {
		global.GVA_LOG.Error("获取对话历史失败", zap.Error(err))
		return nil, fmt.Errorf("获取对话历史失败: %v", err)
	}

	// 4. Build the LLM request
	messages := p.buildMessages(history, req.Message)
	llmReq := utils.LLMRequest{
		Messages:    messages,
		Model:       req.Model,
		Stream:      false,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		RequestID:   uuid.New().String(),
	}

	// 5. Call the LLM service
	llmUtil := utils.GetVolcengineLLMUtil()
	llmResp, err := llmUtil.ChatCompletion(llmReq)
	if err != nil {
		global.GVA_LOG.Error("LLM调用失败", zap.Error(err))
		return nil, fmt.Errorf("LLM调用失败: %v", err)
	}

	// 6. Extract the AI reply
	var aiMessage string
	if len(llmResp.Choices) > 0 {
		aiMessage = llmResp.Choices[0].Message.Content
	}

	// 7. Run the sensitive-word check on the AI reply
	aiFiltered, aiHasSensitive, err := sensitiveUtil.FilterText(aiMessage)
	if err != nil {
		global.GVA_LOG.Error("AI回复敏感词检测失败", zap.Error(err))
		aiFiltered = aiMessage // fall back to the original message if the check fails
	}

	if aiHasSensitive {
		global.GVA_LOG.Warn("AI回复包含敏感词", zap.String("original", aiMessage), zap.String("filtered", aiFiltered))
		aiMessage = aiFiltered
	}

	// 8. Persist the conversation
	responseTime := time.Since(startTime).Milliseconds()
	tokenCount := 0
	if llmResp.Usage != nil {
		tokenCount = llmResp.Usage.TotalTokens
	}

	// Save the user message
	if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, false); err != nil {
		global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
	}

	// Save the AI reply
	if err := p.saveAssistantMessage(ctx, userId, sessionId, aiMessage, aiHasSensitive, tokenCount, responseTime); err != nil {
		global.GVA_LOG.Error("保存AI回复失败", zap.Error(err))
	}

	return &ChatResponse{
		Message:      aiMessage,
		SessionId:    sessionId,
		IsSensitive:  aiHasSensitive,
		TokenCount:   tokenCount,
		ResponseTime: responseTime,
		RequestId:    llmReq.RequestID,
	}, nil
}
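
// A minimal, illustrative sketch of how an API layer might call SendMessage; the
// way ctx, userId and the request are obtained is an assumption, not part of this
// service:
//
//	svc := &PetChatService{}
//	resp, err := svc.SendMessage(ctx, userId, ChatRequest{Message: "question text"})
//	if err != nil {
//		// map the error to an API response
//	}
//	// resp.Message carries the filtered AI reply; resp.SessionId can be passed
//	// back on the next call to keep the conversation context.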

// StreamChat handles streaming chat.
func (p *PetChatService) StreamChat(ctx context.Context, userId uint, req ChatRequest, eventChan chan<- StreamEvent) error {
	startTime := time.Now()

	// 1. Sensitive-word check
	sensitiveUtil := utils.GetSensitiveWordUtil()
	filtered, hasSensitive, err := sensitiveUtil.FilterText(req.Message)
	if err != nil {
		global.GVA_LOG.Error("敏感词检测失败", zap.Error(err))
		eventChan <- StreamEvent{
			Event: "error",
			Data: map[string]interface{}{
				"error": "敏感词检测失败",
			},
		}
		return err
	}

	// If the message contains sensitive words, return a prompt and finish
	if hasSensitive {
		global.GVA_LOG.Warn("用户消息包含敏感词", zap.Uint("userId", userId), zap.String("original", req.Message), zap.String("filtered", filtered))

		// Save the user message (flagged as sensitive)
		if err := p.saveUserMessage(ctx, userId, req.SessionId, req.Message, true); err != nil {
			global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
		}

		eventChan <- StreamEvent{
			Event: "message",
			Data: ChatResponse{
				Message:      "您的消息包含不当内容,请重新输入。",
				SessionId:    req.SessionId,
				IsSensitive:  true,
				TokenCount:   0,
				ResponseTime: time.Since(startTime).Milliseconds(),
			},
		}

		eventChan <- StreamEvent{
			Event: "done",
			Data:  map[string]interface{}{"finished": true},
		}
		return nil
	}

	// 2. Generate a session ID if none was provided
	sessionId := req.SessionId
	if sessionId == "" {
		sessionId = uuid.New().String()
	}

	// 3. Fetch the conversation history to build the context
	history, err := p.GetChatHistory(ctx, userId, sessionId, 10)
	if err != nil {
		global.GVA_LOG.Error("获取对话历史失败", zap.Error(err))
		eventChan <- StreamEvent{
			Event: "error",
			Data: map[string]interface{}{
				"error": "获取对话历史失败",
			},
		}
		return err
	}

	// 4. Build the LLM request
	messages := p.buildMessages(history, req.Message)
	llmReq := utils.LLMRequest{
		Messages:    messages,
		Model:       req.Model,
		Stream:      true,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		RequestID:   uuid.New().String(),
	}

	// 5. Create the LLM streaming response channel
	llmEventChan := make(chan utils.LLMStreamEvent, 100)

	// 6. Start the streaming LLM call. The producing goroutine closes llmEventChan
	// once the call returns, so the consumer loop below can observe completion;
	// closing the channel here with defer would risk a send on a closed channel.
	llmUtil := utils.GetVolcengineLLMUtil()
	go func() {
		defer close(llmEventChan)
		if err := llmUtil.StreamChatCompletion(llmReq, llmEventChan); err != nil {
			global.GVA_LOG.Error("LLM流式调用失败", zap.Error(err))
			eventChan <- StreamEvent{
				Event: "error",
				Data: map[string]interface{}{
					"error": "LLM调用失败",
				},
			}
		}
	}()

	// 7. Process the streaming response
	var fullMessage string
	var tokenCount int

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case llmEvent, ok := <-llmEventChan:
			if !ok {
				// The channel is closed: the streaming response has ended
				return nil
			}

			switch llmEvent.Event {
			case "message":
				// Handle a message event
				if len(llmEvent.Data.Choices) > 0 {
					delta := llmEvent.Data.Choices[0].Delta.Content
					fullMessage += delta

					// Forward the chunk to the client
					eventChan <- StreamEvent{
						Event: "message",
						Data: map[string]interface{}{
							"delta":     delta,
							"sessionId": sessionId,
						},
					}
				}
			case "done":
				// The streaming response is complete
				responseTime := time.Since(startTime).Milliseconds()
				if llmEvent.Data.Usage != nil {
					tokenCount = llmEvent.Data.Usage.TotalTokens
				}

				// Run the sensitive-word check on the full message
				aiFiltered, aiHasSensitive, err := sensitiveUtil.FilterText(fullMessage)
				if err != nil {
					global.GVA_LOG.Error("AI回复敏感词检测失败", zap.Error(err))
					aiFiltered = fullMessage
				}

				if aiHasSensitive {
					global.GVA_LOG.Warn("AI回复包含敏感词", zap.String("original", fullMessage), zap.String("filtered", aiFiltered))
					fullMessage = aiFiltered
				}

				// Persist the conversation
				if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, false); err != nil {
					global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
				}

				if err := p.saveAssistantMessage(ctx, userId, sessionId, fullMessage, aiHasSensitive, tokenCount, responseTime); err != nil {
					global.GVA_LOG.Error("保存AI回复失败", zap.Error(err))
				}

				// Send the done event
				eventChan <- StreamEvent{
					Event: "done",
					Data: ChatResponse{
						Message:      fullMessage,
						SessionId:    sessionId,
						IsSensitive:  aiHasSensitive,
						TokenCount:   tokenCount,
						ResponseTime: responseTime,
						RequestId:    llmReq.RequestID,
					},
				}
				return nil

			case "error":
				// Handle an error event
				global.GVA_LOG.Error("LLM流式响应错误", zap.Any("error", llmEvent.Data.Error))
				eventChan <- StreamEvent{
					Event: "error",
					Data: map[string]interface{}{
						"error": "LLM响应错误",
					},
				}
				return fmt.Errorf("LLM响应错误")
			}
		}
	}
}
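
// A minimal consumption sketch for StreamChat (illustrative; how the events are
// written out to the client, e.g. as SSE frames, is an assumption about the
// calling API layer, not part of this service):
//
//	svc := &PetChatService{}
//	eventChan := make(chan StreamEvent, 100)
//	go func() {
//		defer close(eventChan)
//		_ = svc.StreamChat(ctx, userId, req, eventChan)
//	}()
//	for ev := range eventChan {
//		// forward ev.Event / ev.Data to the client, e.g. one SSE frame per event
//	}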

// GetChatHistory returns the conversation history.
func (p *PetChatService) GetChatHistory(ctx context.Context, userId uint, sessionId string, limit int) ([]pet.PetAiAssistantConversations, error) {
	var conversations []pet.PetAiAssistantConversations

	query := global.GVA_DB.WithContext(ctx).Where("user_id = ?", userId)

	// Filter by session ID if one was provided
	if sessionId != "" {
		query = query.Where("session_id = ?", sessionId)
	}

	// Order by creation time descending and limit the number of rows
	if err := query.Order("created_at DESC").Limit(limit).Find(&conversations).Error; err != nil {
		return nil, err
	}

	// Reverse the slice so the messages are in chronological order
	for i, j := 0, len(conversations)-1; i < j; i, j = i+1, j-1 {
		conversations[i], conversations[j] = conversations[j], conversations[i]
	}

	return conversations, nil
}

// SaveConversation persists a conversation record.
func (p *PetChatService) SaveConversation(ctx context.Context, conversation *pet.PetAiAssistantConversations) error {
	return global.GVA_DB.WithContext(ctx).Create(conversation).Error
}

// ClearChatHistory clears the conversation history.
func (p *PetChatService) ClearChatHistory(ctx context.Context, userId uint, sessionId string) error {
	query := global.GVA_DB.WithContext(ctx).Where("user_id = ?", userId)

	// If a session ID is provided, only clear that session
	if sessionId != "" {
		query = query.Where("session_id = ?", sessionId)
	}

	return query.Delete(&pet.PetAiAssistantConversations{}).Error
}

// GetChatSessions returns the user's chat session list.
func (p *PetChatService) GetChatSessions(ctx context.Context, userId uint) ([]map[string]interface{}, error) {
	var sessions []map[string]interface{}

	// Query all of the user's sessions, grouped by session and ordered by last update time
	err := global.GVA_DB.WithContext(ctx).
		Model(&pet.PetAiAssistantConversations{}).
		Select("session_id, MAX(updated_at) as last_updated, COUNT(*) as message_count").
		Where("user_id = ? AND session_id IS NOT NULL AND session_id != ''", userId).
		Group("session_id").
		Order("last_updated DESC").
		Scan(&sessions).Error

	return sessions, err
}
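
// Each row scanned above is a map keyed by the selected columns; an illustrative
// (not verbatim, driver-dependent) example of one element of the returned slice:
//
//	map[string]interface{}{
//		"session_id":    "4f6b1c2e-...",
//		"last_updated":  "2024-01-01 12:00:00",
//		"message_count": 12,
//	}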

// buildMessages builds the message list for the LLM request.
func (p *PetChatService) buildMessages(history []pet.PetAiAssistantConversations, userMessage string) []utils.LLMMessage {
	messages := []utils.LLMMessage{
		{
			Role:    "system",
			Content: "你是一个专业的宠物助手,专门为宠物主人提供关于宠物护理、健康、训练和日常生活的建议。请用友善、专业的语气回答问题。",
		},
	}

	// Append the conversation history
	for _, conv := range history {
		if conv.MessageContent != nil && conv.Role != nil {
			messages = append(messages, utils.LLMMessage{
				Role:    *conv.Role,
				Content: *conv.MessageContent,
			})
		}
	}

	// Append the current user message
	messages = append(messages, utils.LLMMessage{
		Role:    "user",
		Content: userMessage,
	})

	return messages
}
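
// For a two-turn history followed by a new user message, the slice built above has
// the following shape (illustrative sketch; the history contents come from the database):
//
//	[]utils.LLMMessage{
//		{Role: "system", Content: "<pet assistant system prompt>"},
//		{Role: "user", Content: "<previous user message>"},
//		{Role: "assistant", Content: "<previous assistant reply>"},
//		{Role: "user", Content: userMessage},
//	}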

// saveUserMessage persists a user message.
func (p *PetChatService) saveUserMessage(ctx context.Context, userId uint, sessionId, message string, isSensitive bool) error {
	userIdPtr := int(userId)
	rolePtr := "user"
	messagePtr := message
	sessionIdPtr := sessionId
	isSensitivePtr := isSensitive

	conversation := &pet.PetAiAssistantConversations{
		UserId:         &userIdPtr,
		MessageContent: &messagePtr,
		Role:           &rolePtr,
		SessionId:      &sessionIdPtr,
		IsSensitive:    &isSensitivePtr,
	}

	return p.SaveConversation(ctx, conversation)
}

// saveAssistantMessage persists an AI assistant message.
func (p *PetChatService) saveAssistantMessage(ctx context.Context, userId uint, sessionId, message string, isSensitive bool, tokenCount int, responseTime int64) error {
	userIdPtr := int(userId)
	rolePtr := "assistant"
	messagePtr := message
	sessionIdPtr := sessionId
	isSensitivePtr := isSensitive
	tokenCountPtr := tokenCount
	responseTimePtr := int(responseTime)

	conversation := &pet.PetAiAssistantConversations{
		UserId:         &userIdPtr,
		MessageContent: &messagePtr,
		Role:           &rolePtr,
		SessionId:      &sessionIdPtr,
		IsSensitive:    &isSensitivePtr,
		TokenCount:     &tokenCountPtr,
		ResponseTime:   &responseTimePtr,
	}

	return p.SaveConversation(ctx, conversation)
}