pet-ai/server/service/pet/pet_chat_service.go

452 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package pet
import (
"context"
"fmt"
"time"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/pet"
petRequest "github.com/flipped-aurora/gin-vue-admin/server/model/pet/request"
petResponse "github.com/flipped-aurora/gin-vue-admin/server/model/pet/response"
"github.com/flipped-aurora/gin-vue-admin/server/utils"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Short local names for the chat request/response models, plus the service type.
type (
	// ChatRequest is the input accepted by SendMessage and StreamChat.
	ChatRequest = petRequest.ChatRequest
	// ChatResponse is the result returned by SendMessage.
	ChatResponse = petResponse.ChatResponse
	// StreamEvent is one event that StreamChat pushes to its caller.
	StreamEvent = petResponse.StreamEvent

	// PetChatService implements the pet-assistant chat features (chatting,
	// history and session management). It has no fields, so its zero value is
	// ready to use.
	PetChatService struct{}
)
// SendMessage handles one non-streaming chat turn for userId.
//
// Steps:
//  1. Screen req.Message with the sensitive-word filter.
//  2. Continue req.SessionId, or start a new session (a fresh UUID) when it
//     is empty.
//  3. Load up to 10 recent messages of that session as context.
//  4. Call the Volcengine LLM and screen its reply.
//  5. Persist both turns and return the reply with token usage and elapsed
//     milliseconds.
//
// A message that trips the filter is stored with IsSensitive=true under
// req.SessionId and answered with a canned refusal; the LLM is not called.
//
// Failures to persist messages are logged but do not fail the request.
// Errors from the filter, the history lookup or the LLM call are returned
// wrapped with %w, so callers can inspect the cause with errors.Is/errors.As.
// An LLM response that contains no choice at all is reported as an error
// rather than saved as an empty assistant turn.
func (p *PetChatService) SendMessage(ctx context.Context, userId uint, req ChatRequest) (*ChatResponse, error) {
	startTime := time.Now()

	// 1. Screen the user message before anything reaches the LLM.
	sensitiveUtil := utils.GetSensitiveWordUtil()
	filtered, hasSensitive, err := sensitiveUtil.FilterText(req.Message)
	if err != nil {
		global.GVA_LOG.Error("敏感词检测失败", zap.Error(err))
		return nil, fmt.Errorf("敏感词检测失败: %w", err)
	}
	if hasSensitive {
		global.GVA_LOG.Warn("用户消息包含敏感词", zap.Uint("userId", userId), zap.String("original", req.Message), zap.String("filtered", filtered))
		// NOTE(review): the rejected message is persisted under req.SessionId and
		// GetChatHistory does not filter on IsSensitive, so later turns of the
		// same session will replay it to the LLM as context — confirm intent.
		if err := p.saveUserMessage(ctx, userId, req.SessionId, req.Message, true); err != nil {
			global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
		}
		return &ChatResponse{
			Message:      "您的消息包含不当内容,请重新输入。",
			SessionId:    req.SessionId,
			IsSensitive:  true,
			TokenCount:   0,
			ResponseTime: time.Since(startTime).Milliseconds(),
		}, nil
	}

	// 2. Continue the given session or start a new one.
	sessionId := req.SessionId
	if sessionId == "" {
		sessionId = uuid.New().String()
	}

	// 3. Load the 10 most recent messages of this session as context.
	history, err := p.GetChatHistory(ctx, userId, sessionId, 10)
	if err != nil {
		global.GVA_LOG.Error("获取对话历史失败", zap.Error(err))
		return nil, fmt.Errorf("获取对话历史失败: %w", err)
	}

	// 4. Build the non-streaming LLM request.
	llmReq := utils.LLMRequest{
		Messages:    p.buildMessages(history, req.Message),
		Model:       req.Model,
		Stream:      false,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		RequestID:   uuid.New().String(),
	}

	// 5. Call the LLM.
	llmResp, err := utils.GetVolcengineLLMUtil().ChatCompletion(llmReq)
	if err != nil {
		global.GVA_LOG.Error("LLM调用失败", zap.Error(err))
		return nil, fmt.Errorf("LLM调用失败: %w", err)
	}
	// With no choice there is no reply. Persisting "" would poison the history
	// that later turns replay to the LLM, so fail the turn instead.
	if len(llmResp.Choices) == 0 {
		global.GVA_LOG.Error("LLM返回空响应", zap.String("requestId", llmReq.RequestID))
		return nil, fmt.Errorf("LLM返回空响应")
	}
	aiMessage := llmResp.Choices[0].Message.Content

	// 6. Screen the reply. If the filter itself fails, the reply is kept
	// unchanged (fail open).
	aiFiltered, aiHasSensitive, err := sensitiveUtil.FilterText(aiMessage)
	if err != nil {
		global.GVA_LOG.Error("AI回复敏感词检测失败", zap.Error(err))
		aiFiltered = aiMessage
	}
	if aiHasSensitive {
		global.GVA_LOG.Warn("AI回复包含敏感词", zap.String("original", aiMessage), zap.String("filtered", aiFiltered))
		aiMessage = aiFiltered
	}

	// 7. Persist both turns; storage failures are logged, not returned.
	responseTime := time.Since(startTime).Milliseconds()
	tokenCount := 0
	if llmResp.Usage != nil {
		tokenCount = llmResp.Usage.TotalTokens
	}
	if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, false); err != nil {
		global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
	}
	if err := p.saveAssistantMessage(ctx, userId, sessionId, aiMessage, aiHasSensitive, tokenCount, responseTime); err != nil {
		global.GVA_LOG.Error("保存AI回复失败", zap.Error(err))
	}

	return &ChatResponse{
		Message:      aiMessage,
		SessionId:    sessionId,
		IsSensitive:  aiHasSensitive,
		TokenCount:   tokenCount,
		ResponseTime: responseTime,
		RequestId:    llmReq.RequestID,
	}, nil
}
// StreamChat handles one streaming chat turn for userId and pushes events to
// eventChan as they become available.
//
// Events sent on eventChan:
//   - "message": a StreamMessageData per LLM delta (or the canned refusal);
//   - "done":    a StreamDoneData carrying the full reply and the session ID;
//   - "error":   a StreamErrorData describing the failure.
//
// StreamChat never closes eventChan; the caller owns it. Each send is
// abandoned once ctx is cancelled, so a caller that stops reading only has to
// cancel ctx for StreamChat to return.
//
// The LLM producer (StreamChatCompletion) runs in its own goroutine and is the
// only sender on llmEventChan. StreamChat therefore never closes that channel:
// closing it from the receiving side could panic with "send on closed channel"
// or "close of closed channel". The producer's return value arrives on a
// dedicated buffered channel. If StreamChat returns before the producer has
// finished, a background goroutine keeps draining llmEventChan until the
// producer returns, so the producer cannot block forever on a full buffer.
//
// If the producer returns nil without emitting "done", the reply gathered so
// far is finalised (saved and announced with "done") as if "done" had arrived.
//
// Returns nil on normal completion (including a rejected sensitive message),
// ctx.Err() on cancellation, or the error from the sensitive-word check, the
// history lookup or the LLM.
func (p *PetChatService) StreamChat(ctx context.Context, userId uint, req ChatRequest, eventChan chan<- StreamEvent) error {
	startTime := time.Now()

	// send delivers ev to the caller, giving up if ctx is cancelled first so a
	// client that stopped reading cannot block this goroutine forever.
	send := func(ev StreamEvent) bool {
		select {
		case eventChan <- ev:
			return true
		case <-ctx.Done():
			return false
		}
	}
	sendError := func(msg string) {
		send(StreamEvent{Event: "error", Data: petResponse.StreamErrorData{Error: msg}})
	}

	// 1. Screen the user message before anything reaches the LLM.
	sensitiveUtil := utils.GetSensitiveWordUtil()
	filtered, hasSensitive, err := sensitiveUtil.FilterText(req.Message)
	if err != nil {
		global.GVA_LOG.Error("敏感词检测失败", zap.Error(err))
		sendError("敏感词检测失败")
		return err
	}
	if hasSensitive {
		global.GVA_LOG.Warn("用户消息包含敏感词", zap.Uint("userId", userId), zap.String("original", req.Message), zap.String("filtered", filtered))
		// NOTE(review): the rejected message is persisted under req.SessionId and
		// GetChatHistory does not filter on IsSensitive, so later turns of the
		// same session will replay it to the LLM as context — confirm intent.
		if err := p.saveUserMessage(ctx, userId, req.SessionId, req.Message, true); err != nil {
			global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
		}
		send(StreamEvent{
			Event: "message",
			Data:  petResponse.StreamMessageData{Delta: "您的消息包含不当内容,请重新输入。"},
		})
		send(StreamEvent{
			Event: "done",
			Data: petResponse.StreamDoneData{
				Message:   "您的消息包含不当内容,请重新输入。",
				SessionId: req.SessionId,
			},
		})
		return nil
	}

	// 2. Continue the given session or start a new one.
	sessionId := req.SessionId
	if sessionId == "" {
		sessionId = uuid.New().String()
	}

	// 3. Load the 10 most recent messages of this session as context.
	history, err := p.GetChatHistory(ctx, userId, sessionId, 10)
	if err != nil {
		global.GVA_LOG.Error("获取对话历史失败", zap.Error(err))
		sendError("获取对话历史失败")
		return err
	}

	// 4. Build the streaming LLM request.
	llmReq := utils.LLMRequest{
		Messages:    p.buildMessages(history, req.Message),
		Model:       req.Model,
		Stream:      true,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		RequestID:   uuid.New().String(),
	}

	// 5. Start the producer. producerDone is buffered, so the goroutine can
	// always report its result and exit even if nobody is receiving any more.
	llmEventChan := make(chan utils.LLMStreamEvent, 100)
	producerDone := make(chan error, 1)
	llmUtil := utils.GetVolcengineLLMUtil()
	go func() {
		producerDone <- llmUtil.StreamChatCompletion(llmReq, llmEventChan)
	}()

	producerFinished := false
	defer func() {
		if producerFinished {
			return
		}
		// We stopped reading early. Keep draining so the producer never blocks
		// on a full buffer, until it reports that it has returned.
		go func() {
			events := llmEventChan
			for {
				select {
				case _, ok := <-events:
					if !ok {
						events = nil // closed: only producerDone can fire now
					}
				case <-producerDone:
					return
				}
			}
		}()
	}()

	// 6. Consume the stream.
	var reply []byte // accumulated deltas; appending avoids quadratic string +=
	tokenCount := 0

	// finish screens the complete reply, persists both turns and sends "done".
	// Storage failures are logged, not returned.
	// NOTE(review): the deltas were already forwarded unfiltered, so the filter
	// only affects the stored text and the "done" payload.
	finish := func() error {
		responseTime := time.Since(startTime).Milliseconds()
		fullMessage := string(reply)
		aiFiltered, aiHasSensitive, err := sensitiveUtil.FilterText(fullMessage)
		if err != nil {
			global.GVA_LOG.Error("AI回复敏感词检测失败", zap.Error(err))
			aiFiltered = fullMessage
		}
		if aiHasSensitive {
			global.GVA_LOG.Warn("AI回复包含敏感词", zap.String("original", fullMessage), zap.String("filtered", aiFiltered))
			fullMessage = aiFiltered
		}
		if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, false); err != nil {
			global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
		}
		if err := p.saveAssistantMessage(ctx, userId, sessionId, fullMessage, aiHasSensitive, tokenCount, responseTime); err != nil {
			global.GVA_LOG.Error("保存AI回复失败", zap.Error(err))
		}
		send(StreamEvent{
			Event: "done",
			Data: petResponse.StreamDoneData{
				Message:   fullMessage,
				SessionId: sessionId,
			},
		})
		return nil
	}

	// handle processes one producer event. It reports whether the stream has
	// ended and, if so, the result StreamChat should return.
	handle := func(ev utils.LLMStreamEvent) (bool, error) {
		switch ev.Event {
		case "message":
			if len(ev.Data.Choices) > 0 {
				delta := ev.Data.Choices[0].Delta.Content
				reply = append(reply, delta...)
				if !send(StreamEvent{Event: "message", Data: petResponse.StreamMessageData{Delta: delta}}) {
					return true, ctx.Err()
				}
			}
			return false, nil
		case "done":
			if ev.Data.Usage != nil {
				tokenCount = ev.Data.Usage.TotalTokens
			}
			return true, finish()
		case "error":
			global.GVA_LOG.Error("LLM流式响应错误", zap.Any("error", ev.Data.Error))
			sendError("LLM响应错误")
			return true, fmt.Errorf("LLM响应错误: %v", ev.Data.Error)
		default:
			return false, nil
		}
	}

	events := llmEventChan // set to nil once the producer closes it
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case ev, ok := <-events:
			if !ok {
				// Closed by the producer; its return value decides the outcome.
				events = nil
				continue
			}
			if ended, err := handle(ev); ended {
				return err
			}
		case perr := <-producerDone:
			producerFinished = true
			// The producer has returned, so nothing more will be sent. Handle
			// whatever is still buffered before deciding the outcome, because
			// select may have picked this case ahead of pending events.
			for {
				select {
				case ev, ok := <-events:
					if !ok {
						events = nil
						continue
					}
					if ended, err := handle(ev); ended {
						return err
					}
				default:
					if perr != nil {
						global.GVA_LOG.Error("LLM流式调用失败", zap.Error(perr))
						sendError("LLM调用失败")
						return perr
					}
					// Clean return without a "done" event: finalise what we have.
					return finish()
				}
			}
		}
	}
}
// GetChatHistory returns at most limit of userId's most recent messages,
// ordered oldest first. A non-empty sessionId restricts the search to that
// session; an empty one searches all of the user's messages.
//
// Rows with a NULL Role or MessageContent are skipped, so fewer than limit
// items may be returned. A NULL SessionId is reported as "".
//
// NOTE(review): rows are not filtered on IsSensitive, so user messages that
// were rejected and stored are returned here and, via buildMessages, replayed
// to the LLM as context — confirm whether that is intended.
func (p *PetChatService) GetChatHistory(ctx context.Context, userId uint, sessionId string, limit int) ([]petResponse.ChatHistoryItem, error) {
	var conversations []pet.PetAiAssistantConversations
	query := global.GVA_DB.WithContext(ctx).Where("user_id = ?", userId)
	if sessionId != "" {
		query = query.Where("session_id = ?", sessionId)
	}

	// Newest first, so LIMIT keeps the latest rows. The user and assistant
	// turns of one exchange are inserted back to back and may share a
	// created_at value. id breaks the tie so the pair keeps its insertion order
	// (assumes id is an auto-increment primary key — confirm).
	if err := query.Order("created_at DESC, id DESC").Limit(limit).Find(&conversations).Error; err != nil {
		return nil, err
	}

	items := make([]petResponse.ChatHistoryItem, 0, len(conversations))
	// Walk the newest-first rows backwards to emit them oldest first.
	for i := len(conversations) - 1; i >= 0; i-- {
		conv := &conversations[i]
		if conv.Role == nil || conv.MessageContent == nil {
			continue
		}
		item := petResponse.ChatHistoryItem{
			ID:        conv.ID,
			Role:      *conv.Role,
			Message:   *conv.MessageContent,
			CreatedAt: conv.CreatedAt,
		}
		if conv.SessionId != nil {
			item.SessionId = *conv.SessionId
		}
		items = append(items, item)
	}
	return items, nil
}
// SaveConversation inserts conversation as a new row using GORM's Create,
// bound to ctx, and returns the resulting database error (nil on success).
func (p *PetChatService) SaveConversation(ctx context.Context, conversation *pet.PetAiAssistantConversations) error {
	result := global.GVA_DB.WithContext(ctx).Create(conversation)
	return result.Error
}
// ClearChatHistory deletes userId's chat messages. With a non-empty sessionId
// only that session's messages are deleted; otherwise every message of the
// user is. The user_id condition is always applied.
func (p *PetChatService) ClearChatHistory(ctx context.Context, userId uint, sessionId string) error {
	db := global.GVA_DB.WithContext(ctx).Where("user_id = ?", userId)
	if sessionId == "" {
		return db.Delete(&pet.PetAiAssistantConversations{}).Error
	}
	return db.Where("session_id = ?", sessionId).
		Delete(&pet.PetAiAssistantConversations{}).Error
}
// GetChatSessions lists userId's chat sessions, most recently active first.
//
// Each element is a row map with the selected columns "session_id",
// "last_updated" (MAX(updated_at)) and "message_count" (COUNT(*)). Rows whose
// session_id is NULL or empty are excluded.
//
// On success the result is never nil, so it serialises as [] rather than null
// when the user has no sessions. On error it returns nil and the error.
func (p *PetChatService) GetChatSessions(ctx context.Context, userId uint) ([]map[string]interface{}, error) {
	sessions := make([]map[string]interface{}, 0)
	err := global.GVA_DB.WithContext(ctx).
		Model(&pet.PetAiAssistantConversations{}).
		Select("session_id, MAX(updated_at) as last_updated, COUNT(*) as message_count").
		Where("user_id = ? AND session_id IS NOT NULL AND session_id != ''", userId).
		Group("session_id").
		Order("last_updated DESC").
		Scan(&sessions).Error
	if err != nil {
		// Don't hand back partially scanned rows alongside an error.
		return nil, err
	}
	return sessions, nil
}
// buildMessages assembles the message list sent to the LLM, in this order:
//   - the fixed pet-assistant system prompt;
//   - every history item, with role and text copied verbatim;
//   - userMessage as the final "user" turn.
func (p *PetChatService) buildMessages(history []petResponse.ChatHistoryItem, userMessage string) []utils.LLMMessage {
	// One system prompt + the history + the new user turn.
	out := make([]utils.LLMMessage, 0, len(history)+2)
	out = append(out, utils.LLMMessage{
		Role:    "system",
		Content: "你是一个专业的宠物助手,专门为宠物主人提供关于宠物护理、健康、训练和日常生活的建议。请用友善、专业的语气回答问题。",
	})
	for i := range history {
		out = append(out, utils.LLMMessage{
			Role:    history[i].Role,
			Content: history[i].Message,
		})
	}
	return append(out, utils.LLMMessage{
		Role:    "user",
		Content: userMessage,
	})
}
// saveUserMessage persists one "user" turn for userId in sessionId, flagged
// with isSensitive, through SaveConversation.
func (p *PetChatService) saveUserMessage(ctx context.Context, userId uint, sessionId, message string, isSensitive bool) error {
	// The model stores every column as a pointer. The parameters are per-call
	// copies, so their addresses can be used directly.
	uid, role := int(userId), "user"
	return p.SaveConversation(ctx, &pet.PetAiAssistantConversations{
		UserId:         &uid,
		Role:           &role,
		MessageContent: &message,
		SessionId:      &sessionId,
		IsSensitive:    &isSensitive,
	})
}
// saveAssistantMessage persists one "assistant" turn for userId in sessionId
// through SaveConversation. It records isSensitive, the LLM token count, and
// responseTime (in milliseconds, as measured by the callers in this file),
// converted to int.
func (p *PetChatService) saveAssistantMessage(ctx context.Context, userId uint, sessionId, message string, isSensitive bool, tokenCount int, responseTime int64) error {
	// The model stores every column as a pointer. The parameters are per-call
	// copies, so their addresses can be used directly.
	uid, role, elapsed := int(userId), "assistant", int(responseTime)
	return p.SaveConversation(ctx, &pet.PetAiAssistantConversations{
		UserId:         &uid,
		Role:           &role,
		MessageContent: &message,
		SessionId:      &sessionId,
		IsSensitive:    &isSensitive,
		TokenCount:     &tokenCount,
		ResponseTime:   &elapsed,
	})
}