pet-ai/server/service/pet/pet_chat_service.go

458 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package pet
import (
"context"
"fmt"
"time"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/pet"
"github.com/flipped-aurora/gin-vue-admin/server/utils"
"github.com/google/uuid"
"go.uber.org/zap"
)
// ChatRequest is the client payload for one chat turn, consumed by SendMessage and StreamChat.
type ChatRequest struct {
Message string `json:"message" binding:"required"` // User message text; marked required for request binding.
SessionId string `json:"sessionId"` // Optional session ID; SendMessage/StreamChat generate a new UUID when it is empty.
Stream bool `json:"stream"` // Client's streaming preference; not read by this service (SendMessage/StreamChat set LLMRequest.Stream themselves).
Temperature float64 `json:"temperature"` // Sampling temperature, forwarded as-is to the LLM request.
MaxTokens int `json:"maxTokens"` // Token limit, forwarded as-is to the LLM request.
Model string `json:"model"` // Model name, forwarded as-is to the LLM request.
}
// ChatResponse is the result of one chat turn: returned by SendMessage and used as an event payload by StreamChat.
type ChatResponse struct {
Message string `json:"message"` // Text shown to the user: the AI reply (possibly filtered) or a refusal notice.
SessionId string `json:"sessionId"` // Session the turn belongs to.
IsSensitive bool `json:"isSensitive"` // True when the user message was rejected or the AI reply was flagged by the sensitive-word check.
TokenCount int `json:"tokenCount"` // Total tokens from the LLM usage data; 0 when no usage was reported.
ResponseTime int64 `json:"responseTime"` // Elapsed handling time in milliseconds.
RequestId string `json:"requestId,omitempty"` // ID of the underlying LLM request; omitted from JSON when empty.
}
// StreamEvent is one event that StreamChat publishes on its event channel.
type StreamEvent struct {
Event string `json:"event"` // Event type: "message", "error" or "done".
Data interface{} `json:"data"` // Payload: a ChatResponse, or a map such as {"delta","sessionId"}, {"error"} or {"finished"}.
}
// PetChatService implements the pet-assistant chat features. It has no fields; its methods use
// global.GVA_DB for persistence and the utils sensitive-word / Volcengine LLM helpers.
type PetChatService struct{}
// SendMessage handles one non-streaming chat turn for userId and returns the
// assistant's reply.
//
// Steps:
//  1. Resolve the session ID (a new UUID when req.SessionId is empty).
//  2. Screen req.Message with the sensitive-word filter. A flagged message is
//     stored with IsSensitive=true and answered with a fixed refusal; the LLM
//     is not called.
//  3. Load up to 10 recent rows of the session as context and call the LLM.
//  4. Screen the reply; when flagged, its filtered form replaces it.
//  5. Persist the user message and the reply. Persistence failures are only
//     logged, so a generated reply is still returned.
//
// Errors from the filter, the history lookup and the LLM call are wrapped
// with %w so callers can inspect the cause with errors.Is / errors.As.
func (p *PetChatService) SendMessage(ctx context.Context, userId uint, req ChatRequest) (*ChatResponse, error) {
	startTime := time.Now()

	// Resolve the session before anything is persisted so that even a
	// rejected message is stored under a real session: GetChatSessions skips
	// rows whose session_id is empty.
	sessionId := req.SessionId
	if sessionId == "" {
		sessionId = uuid.New().String()
	}

	// 1. Screen the user message.
	sensitiveUtil := utils.GetSensitiveWordUtil()
	filtered, hasSensitive, err := sensitiveUtil.FilterText(req.Message)
	if err != nil {
		global.GVA_LOG.Error("敏感词检测失败", zap.Error(err))
		return nil, fmt.Errorf("敏感词检测失败: %w", err)
	}
	if hasSensitive {
		global.GVA_LOG.Warn("用户消息包含敏感词", zap.Uint("userId", userId), zap.String("original", req.Message), zap.String("filtered", filtered))
		// Keep a record of the rejected message, flagged as sensitive.
		if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, true); err != nil {
			global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
		}
		return &ChatResponse{
			Message:      "您的消息包含不当内容,请重新输入。",
			SessionId:    sessionId,
			IsSensitive:  true,
			TokenCount:   0,
			ResponseTime: time.Since(startTime).Milliseconds(),
		}, nil
	}

	// 2. Load the most recent rows of this session as context.
	history, err := p.GetChatHistory(ctx, userId, sessionId, 10)
	if err != nil {
		global.GVA_LOG.Error("获取对话历史失败", zap.Error(err))
		return nil, fmt.Errorf("获取对话历史失败: %w", err)
	}

	// 3. Call the LLM.
	llmReq := utils.LLMRequest{
		Messages:    p.buildMessages(history, req.Message),
		Model:       req.Model,
		Stream:      false,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		RequestID:   uuid.New().String(),
	}
	llmResp, err := utils.GetVolcengineLLMUtil().ChatCompletion(llmReq)
	if err != nil {
		global.GVA_LOG.Error("LLM调用失败", zap.Error(err))
		return nil, fmt.Errorf("LLM调用失败: %w", err)
	}

	// An empty Choices slice yields an empty reply rather than an error.
	var aiMessage string
	if len(llmResp.Choices) > 0 {
		aiMessage = llmResp.Choices[0].Message.Content
	}

	// 4. Screen the reply; if the filter itself fails, keep the raw reply.
	aiFiltered, aiHasSensitive, err := sensitiveUtil.FilterText(aiMessage)
	if err != nil {
		global.GVA_LOG.Error("AI回复敏感词检测失败", zap.Error(err))
		aiFiltered = aiMessage
	}
	if aiHasSensitive {
		global.GVA_LOG.Warn("AI回复包含敏感词", zap.String("original", aiMessage), zap.String("filtered", aiFiltered))
		aiMessage = aiFiltered
	}

	responseTime := time.Since(startTime).Milliseconds()
	tokenCount := 0
	if llmResp.Usage != nil {
		tokenCount = llmResp.Usage.TotalTokens
	}

	// 5. Persist both sides of the turn (best effort).
	if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, false); err != nil {
		global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
	}
	if err := p.saveAssistantMessage(ctx, userId, sessionId, aiMessage, aiHasSensitive, tokenCount, responseTime); err != nil {
		global.GVA_LOG.Error("保存AI回复失败", zap.Error(err))
	}

	return &ChatResponse{
		Message:      aiMessage,
		SessionId:    sessionId,
		IsSensitive:  aiHasSensitive,
		TokenCount:   tokenCount,
		ResponseTime: responseTime,
		RequestId:    llmReq.RequestID,
	}, nil
}
// StreamChat handles one streaming chat turn for userId and publishes its
// progress on eventChan as StreamEvent values:
//   - "message": an incremental {"delta", "sessionId"} map, or, for a rejected
//     user message, a single ChatResponse carrying the refusal text;
//   - "error":   {"error": <reason>} when a step fails;
//   - "done":    the final ChatResponse (or {"finished": true} after a refusal).
//
// eventChan is owned by the caller and is never closed here. Every send gives
// up when ctx is cancelled, so a consumer that stops reading cannot block this
// method forever. The result is nil on a normal finish, ctx.Err() on
// cancellation, or the underlying failure otherwise.
//
// A new session ID is generated when req.SessionId is empty. This happens
// before anything is persisted, including a rejected message.
//
// NOTE(review): deltas are forwarded before the complete reply is screened, so
// the client may already have displayed text that is filtered only in the
// stored record and in the final "done" payload.
func (p *PetChatService) StreamChat(ctx context.Context, userId uint, req ChatRequest, eventChan chan<- StreamEvent) error {
	startTime := time.Now()

	// send delivers ev to the caller, or gives up once ctx is cancelled.
	send := func(ev StreamEvent) error {
		select {
		case eventChan <- ev:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	// sendError reports a failure to the client. Delivery is best effort: the
	// caller returns the underlying error whether or not the event got through.
	sendError := func(msg string) {
		_ = send(StreamEvent{
			Event: "error",
			Data:  map[string]interface{}{"error": msg},
		})
	}

	// Resolve the session before anything is persisted (GetChatSessions skips
	// rows whose session_id is empty).
	sessionId := req.SessionId
	if sessionId == "" {
		sessionId = uuid.New().String()
	}

	// 1. Screen the user message.
	sensitiveUtil := utils.GetSensitiveWordUtil()
	filtered, hasSensitive, err := sensitiveUtil.FilterText(req.Message)
	if err != nil {
		global.GVA_LOG.Error("敏感词检测失败", zap.Error(err))
		sendError("敏感词检测失败")
		return err
	}
	if hasSensitive {
		global.GVA_LOG.Warn("用户消息包含敏感词", zap.Uint("userId", userId), zap.String("original", req.Message), zap.String("filtered", filtered))
		if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, true); err != nil {
			global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
		}
		if err := send(StreamEvent{
			Event: "message",
			Data: ChatResponse{
				Message:      "您的消息包含不当内容,请重新输入。",
				SessionId:    sessionId,
				IsSensitive:  true,
				TokenCount:   0,
				ResponseTime: time.Since(startTime).Milliseconds(),
			},
		}); err != nil {
			return err
		}
		return send(StreamEvent{
			Event: "done",
			Data:  map[string]interface{}{"finished": true},
		})
	}

	// 2. Load the most recent rows of this session as context.
	history, err := p.GetChatHistory(ctx, userId, sessionId, 10)
	if err != nil {
		global.GVA_LOG.Error("获取对话历史失败", zap.Error(err))
		sendError("获取对话历史失败")
		return err
	}

	// 3. Start the LLM stream in a producer goroutine.
	llmReq := utils.LLMRequest{
		Messages:    p.buildMessages(history, req.Message),
		Model:       req.Model,
		Stream:      true,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		RequestID:   uuid.New().String(),
	}
	llmEvents := make(chan utils.LLMStreamEvent, 100)
	producerDone := make(chan struct{})
	var producerErr error // written before producerDone is closed; read only after that
	llmUtil := utils.GetVolcengineLLMUtil()
	go func() {
		defer close(producerDone)
		producerErr = llmUtil.StreamChatCompletion(llmReq, llmEvents)
	}()
	// llmEvents is never closed on this side. Only the sender may close it,
	// and closing it here could make a still-running producer panic on send.
	// StreamChatCompletion takes no context and cannot be cancelled. If we
	// return early, keep draining in the background until it returns;
	// otherwise it could block forever on a full buffer and leak.
	defer func() {
		go func(events <-chan utils.LLMStreamEvent) {
			for {
				select {
				case <-producerDone:
					return
				case _, ok := <-events:
					if !ok {
						events = nil // closed by the producer; just wait for it to return
					}
				}
			}
		}(llmEvents)
	}()

	// 4. Relay deltas and complete the turn.
	var reply []byte // accumulated reply; appending avoids quadratic string concatenation
	tokenCount := 0

	// finish screens and persists the complete reply, then emits "done".
	finish := func() error {
		responseTime := time.Since(startTime).Milliseconds()
		fullMessage := string(reply)
		aiFiltered, aiHasSensitive, err := sensitiveUtil.FilterText(fullMessage)
		if err != nil {
			global.GVA_LOG.Error("AI回复敏感词检测失败", zap.Error(err))
			aiFiltered = fullMessage
		}
		if aiHasSensitive {
			global.GVA_LOG.Warn("AI回复包含敏感词", zap.String("original", fullMessage), zap.String("filtered", aiFiltered))
			fullMessage = aiFiltered
		}
		if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, false); err != nil {
			global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
		}
		if err := p.saveAssistantMessage(ctx, userId, sessionId, fullMessage, aiHasSensitive, tokenCount, responseTime); err != nil {
			global.GVA_LOG.Error("保存AI回复失败", zap.Error(err))
		}
		return send(StreamEvent{
			Event: "done",
			Data: ChatResponse{
				Message:      fullMessage,
				SessionId:    sessionId,
				IsSensitive:  aiHasSensitive,
				TokenCount:   tokenCount,
				ResponseTime: responseTime,
				RequestId:    llmReq.RequestID,
			},
		})
	}

	// handle processes one LLM event; finished reports that the turn is over.
	handle := func(ev utils.LLMStreamEvent) (finished bool, err error) {
		switch ev.Event {
		case "message":
			if len(ev.Data.Choices) == 0 {
				return false, nil
			}
			delta := ev.Data.Choices[0].Delta.Content
			reply = append(reply, delta...)
			return false, send(StreamEvent{
				Event: "message",
				Data: map[string]interface{}{
					"delta":     delta,
					"sessionId": sessionId,
				},
			})
		case "done":
			if ev.Data.Usage != nil {
				tokenCount = ev.Data.Usage.TotalTokens
			}
			return true, finish()
		case "error":
			global.GVA_LOG.Error("LLM流式响应错误", zap.Any("error", ev.Data.Error))
			sendError("LLM响应错误")
			return true, fmt.Errorf("LLM响应错误: %v", ev.Data.Error)
		default:
			return false, nil
		}
	}

	stream := (<-chan utils.LLMStreamEvent)(llmEvents)
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case ev, ok := <-stream:
			if !ok {
				// Closed without a terminal event: wait for the producer's
				// return value to tell success from failure.
				stream = nil
				continue
			}
			if finished, err := handle(ev); finished || err != nil {
				return err
			}
		case <-producerDone:
			// Everything the producer sent is now buffered. Process it before
			// looking at its error so a completed stream is not reported as
			// a failure.
			for {
				var ev utils.LLMStreamEvent
				ok := false
				select {
				case ev, ok = <-stream:
				default:
				}
				if !ok {
					break
				}
				if finished, err := handle(ev); finished || err != nil {
					return err
				}
			}
			if producerErr != nil {
				global.GVA_LOG.Error("LLM流式调用失败", zap.Error(producerErr))
				sendError("LLM调用失败")
				return producerErr
			}
			// Clean return without an explicit "done" event: finish with the
			// text received so far instead of leaving the client waiting.
			return finish()
		}
	}
}
// GetChatHistory returns up to limit of userId's most recent conversation
// rows, oldest first. When sessionId is non-empty only that session is
// searched; otherwise rows from all of the user's sessions are considered.
//
// Rows are selected newest-first by created_at, with id as a tie-breaker.
// Without it, a user message and its reply saved with the same timestamp
// could come back in either order and swap turns in the LLM prompt. The page
// is then reversed into chronological order.
func (p *PetChatService) GetChatHistory(ctx context.Context, userId uint, sessionId string, limit int) ([]pet.PetAiAssistantConversations, error) {
	var conversations []pet.PetAiAssistantConversations
	query := global.GVA_DB.WithContext(ctx).Where("user_id = ?", userId)
	if sessionId != "" {
		query = query.Where("session_id = ?", sessionId)
	}
	// NOTE(review): assumes the model has an auto-increment "id" primary key
	// (as with gin-vue-admin's GVA_MODEL) — confirm against the model definition.
	if err := query.Order("created_at DESC").Order("id DESC").Limit(limit).Find(&conversations).Error; err != nil {
		return nil, err
	}
	// Reverse in place: newest-first -> oldest-first.
	for i, j := 0, len(conversations)-1; i < j; i, j = i+1, j-1 {
		conversations[i], conversations[j] = conversations[j], conversations[i]
	}
	return conversations, nil
}
// SaveConversation inserts a single conversation row using the request-scoped
// database handle and returns any error reported by the insert.
func (p *PetChatService) SaveConversation(ctx context.Context, conversation *pet.PetAiAssistantConversations) error {
	db := global.GVA_DB.WithContext(ctx)
	result := db.Create(conversation)
	return result.Error
}
// ClearChatHistory deletes userId's conversation rows. With a non-empty
// sessionId only that session is cleared; otherwise every session of the user is.
func (p *PetChatService) ClearChatHistory(ctx context.Context, userId uint, sessionId string) error {
	byUser := global.GVA_DB.WithContext(ctx).Where("user_id = ?", userId)
	if sessionId == "" {
		return byUser.Delete(&pet.PetAiAssistantConversations{}).Error
	}
	return byUser.Where("session_id = ?", sessionId).Delete(&pet.PetAiAssistantConversations{}).Error
}
// GetChatSessions lists userId's chat sessions, most recently updated first.
// Each entry is a row map keyed by the selected aliases: "session_id",
// "last_updated" (MAX(updated_at)) and "message_count" (COUNT(*)). Rows with
// a NULL or empty session_id are excluded.
func (p *PetChatService) GetChatSessions(ctx context.Context, userId uint) ([]map[string]interface{}, error) {
	const (
		columns = "session_id, MAX(updated_at) as last_updated, COUNT(*) as message_count"
		filter  = "user_id = ? AND session_id IS NOT NULL AND session_id != ''"
	)
	var sessions []map[string]interface{}
	base := global.GVA_DB.WithContext(ctx).Model(&pet.PetAiAssistantConversations{})
	grouped := base.Select(columns).Where(filter, userId).Group("session_id")
	err := grouped.Order("last_updated DESC").Scan(&sessions).Error
	return sessions, err
}
// buildMessages assembles the LLM prompt. It starts with the fixed system
// prompt, follows with the stored history in the given order, and ends with
// userMessage as the final "user" turn.
//
// History rows with a nil role or nil content are skipped. User rows flagged
// IsSensitive are also skipped: a rejected message is stored verbatim and
// never received a reply, so replaying it would send the rejected text to
// the model.
func (p *PetChatService) buildMessages(history []pet.PetAiAssistantConversations, userMessage string) []utils.LLMMessage {
	// system prompt + history + current user message
	messages := make([]utils.LLMMessage, 0, len(history)+2)
	messages = append(messages, utils.LLMMessage{
		Role:    "system",
		Content: "你是一个专业的宠物助手,专门为宠物主人提供关于宠物护理、健康、训练和日常生活的建议。请用友善、专业的语气回答问题。",
	})

	for i := range history {
		conv := &history[i]
		if conv.MessageContent == nil || conv.Role == nil {
			continue
		}
		if *conv.Role == "user" && conv.IsSensitive != nil && *conv.IsSensitive {
			continue
		}
		messages = append(messages, utils.LLMMessage{
			Role:    *conv.Role,
			Content: *conv.MessageContent,
		})
	}

	messages = append(messages, utils.LLMMessage{
		Role:    "user",
		Content: userMessage,
	})
	return messages
}
// saveUserMessage stores one user-authored message in sessionId via
// SaveConversation, recording whether the sensitive-word check flagged it.
func (p *PetChatService) saveUserMessage(ctx context.Context, userId uint, sessionId, message string, isSensitive bool) error {
	var (
		uid  = int(userId)
		role = "user"
	)
	record := pet.PetAiAssistantConversations{
		UserId:         &uid,
		Role:           &role,
		SessionId:      &sessionId,
		MessageContent: &message,
		IsSensitive:    &isSensitive,
	}
	return p.SaveConversation(ctx, &record)
}
// saveAssistantMessage stores one assistant reply in sessionId via
// SaveConversation. It also records the sensitive flag, the token count and the
// response time; the time is narrowed from int64 to int for storage.
func (p *PetChatService) saveAssistantMessage(ctx context.Context, userId uint, sessionId, message string, isSensitive bool, tokenCount int, responseTime int64) error {
	var (
		uid     = int(userId)
		role    = "assistant"
		elapsed = int(responseTime)
	)
	record := pet.PetAiAssistantConversations{
		UserId:         &uid,
		Role:           &role,
		SessionId:      &sessionId,
		MessageContent: &message,
		IsSensitive:    &isSensitive,
		TokenCount:     &tokenCount,
		ResponseTime:   &elapsed,
	}
	return p.SaveConversation(ctx, &record)
}