pet-ai/server/plugin/volcengine/service/llm_service.go

package service

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/flipped-aurora/gin-vue-admin/server/global"
	"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/request"
	"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/response"
	"github.com/google/uuid"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"go.uber.org/zap"
)
// LLMService wraps the Volcengine ARK runtime client and tracks in-flight
// requests so they can be cancelled by ID.
type LLMService struct {
	client         *arkruntime.Client
	activeSessions sync.Map // requestID -> context.CancelFunc
}

var LLMServiceApp = new(LLMService)
// InitVolcengineClient initializes the Volcengine ARK runtime client from the
// global configuration.
func (l *LLMService) InitVolcengineClient() error {
	config := global.GVA_CONFIG.Volcengine
	// Validate configuration.
	if config.AccessKey == "" || config.SecretKey == "" {
		return fmt.Errorf("volcengine access key or secret key is empty")
	}
	// Create the ARK runtime client.
	l.client = arkruntime.NewClientWithAkSk(config.AccessKey, config.SecretKey)
	global.GVA_LOG.Info("Volcengine client initialized successfully")
	return nil
}
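
// Note: the lazy `l.client == nil` checks below are not synchronized, so two
// concurrent first requests could both run InitVolcengineClient. A minimal
// race-free sketch, assuming a hypothetical `initOnce sync.Once` field were
// added to LLMService (not in the original struct):
//
//	func (l *LLMService) ensureClient() error {
//		var initErr error
//		l.initOnce.Do(func() { initErr = l.InitVolcengineClient() })
//		if initErr == nil && l.client == nil {
//			initErr = fmt.Errorf("volcengine client not initialized")
//		}
//		return initErr
//	}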
// ChatCompletion performs a non-streaming chat completion.
func (l *LLMService) ChatCompletion(req request.ChatRequest) (*response.ChatResponse, error) {
	if l.client == nil {
		if err := l.InitVolcengineClient(); err != nil {
			return nil, err
		}
	}
	// Generate a request ID if the caller did not supply one.
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}
	// Convert messages to the SDK format. Copy the content into a local
	// variable before taking its address, so all messages do not end up
	// aliasing the same loop variable (a bug before Go 1.22).
	messages := make([]*model.ChatCompletionMessage, len(req.Messages))
	for i, msg := range req.Messages {
		content := msg.Content
		messages[i] = &model.ChatCompletionMessage{
			Role: msg.Role,
			Content: &model.ChatCompletionMessageContent{
				StringValue: &content,
			},
		}
	}
	// Resolve the model, falling back to the configured default.
	modelName := req.Model
	if modelName == "" {
		modelName = global.GVA_CONFIG.Volcengine.DefaultModel
	}
	// Build the request.
	chatReq := &model.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
	}
	// Optional parameters: zero values mean "use the server default".
	if req.Temperature > 0 {
		chatReq.Temperature = float32(req.Temperature)
	}
	if req.MaxTokens > 0 {
		chatReq.MaxTokens = req.MaxTokens
	}
	if req.TopP > 0 {
		chatReq.TopP = float32(req.TopP)
	}
	// Create a context with a timeout and register its cancel function so the
	// request can be aborted via StopGeneration.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	l.activeSessions.Store(req.RequestID, cancel)
	defer l.activeSessions.Delete(req.RequestID)
	// Call the API.
	resp, err := l.client.CreateChatCompletion(ctx, chatReq)
	if err != nil {
		global.GVA_LOG.Error("Chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
		return nil, fmt.Errorf("chat completion failed: %w", err)
	}
	// Convert the response.
	chatResp := &response.ChatResponse{
		ID:      req.RequestID,
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   modelName,
		Choices: make([]response.Choice, len(resp.Choices)),
	}
	for i, choice := range resp.Choices {
		messageContent := ""
		if choice.Message.Content != nil && choice.Message.Content.StringValue != nil {
			messageContent = *choice.Message.Content.StringValue
		}
		chatResp.Choices[i] = response.Choice{
			Index: i,
			Message: response.Message{
				Role:    choice.Message.Role,
				Content: messageContent,
			},
			FinishReason: string(choice.FinishReason),
		}
	}
	// Copy usage statistics.
	chatResp.Usage = &response.Usage{
		PromptTokens:     resp.Usage.PromptTokens,
		CompletionTokens: resp.Usage.CompletionTokens,
		TotalTokens:      resp.Usage.TotalTokens,
	}
	global.GVA_LOG.Info("Chat completion successful", zap.String("requestID", req.RequestID))
	return chatResp, nil
}
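
// A minimal caller sketch for ChatCompletion (hypothetical, using only the
// request/response fields that appear above; how the Messages slice is
// populated depends on the plugin's request package):
//
//	resp, err := LLMServiceApp.ChatCompletion(req)
//	if err != nil {
//		return err
//	}
//	if len(resp.Choices) > 0 {
//		fmt.Println(resp.Choices[0].Message.Content)
//	}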
// StreamChatCompletion performs a streaming chat completion, sending chunks to
// eventChan as they arrive.
func (l *LLMService) StreamChatCompletion(req request.ChatRequest, eventChan chan<- response.StreamEvent) error {
	if l.client == nil {
		if err := l.InitVolcengineClient(); err != nil {
			return err
		}
	}
	// Generate a request ID if the caller did not supply one.
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}
	// Convert messages to the SDK format (same local-copy pattern as in
	// ChatCompletion to avoid aliasing the loop variable).
	messages := make([]*model.ChatCompletionMessage, len(req.Messages))
	for i, msg := range req.Messages {
		content := msg.Content
		messages[i] = &model.ChatCompletionMessage{
			Role: msg.Role,
			Content: &model.ChatCompletionMessageContent{
				StringValue: &content,
			},
		}
	}
	// Resolve the model, falling back to the configured default.
	modelName := req.Model
	if modelName == "" {
		modelName = global.GVA_CONFIG.Volcengine.DefaultModel
	}
	// Build the streaming request.
	chatReq := &model.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
		Stream:   true,
	}
	// Optional parameters: zero values mean "use the server default".
	if req.Temperature > 0 {
		chatReq.Temperature = float32(req.Temperature)
	}
	if req.MaxTokens > 0 {
		chatReq.MaxTokens = req.MaxTokens
	}
	if req.TopP > 0 {
		chatReq.TopP = float32(req.TopP)
	}
	// Streaming responses need a longer timeout than one-shot completions.
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
	defer cancel()
	// Register the cancel function so StopGeneration can abort this stream.
	l.activeSessions.Store(req.RequestID, cancel)
	defer l.activeSessions.Delete(req.RequestID)
	// Open the stream.
	streamResp, err := l.client.CreateChatCompletionStream(ctx, chatReq)
	if err != nil {
		global.GVA_LOG.Error("Stream chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
		return fmt.Errorf("stream chat completion failed: %w", err)
	}
	defer streamResp.Close()
	// Consume the stream. Recv blocks, so cancellation mostly surfaces as a
	// Recv error; the ctx.Done case catches a cancel that lands between reads.
	for {
		select {
		case <-ctx.Done():
			eventChan <- response.StreamEvent{
				Event: "error",
				Data: response.ChatResponse{
					ID: req.RequestID,
					Error: &response.APIError{
						Code:    "context_cancelled",
						Message: "Request was cancelled",
						Type:    "request_cancelled",
					},
				},
			}
			return ctx.Err()
		default:
			recv, err := streamResp.Recv()
			if err != nil {
				if errors.Is(err, io.EOF) {
					// End of stream.
					eventChan <- response.StreamEvent{
						Event: "done",
						Data: response.ChatResponse{
							ID:     req.RequestID,
							Object: "chat.completion.chunk",
						},
					}
					global.GVA_LOG.Info("Stream chat completion finished", zap.String("requestID", req.RequestID))
					return nil
				}
				global.GVA_LOG.Error("Stream receive error", zap.Error(err), zap.String("requestID", req.RequestID))
				eventChan <- response.StreamEvent{
					Event: "error",
					Data: response.ChatResponse{
						ID: req.RequestID,
						Error: &response.APIError{
							Code:    "stream_error",
							Message: err.Error(),
							Type:    "stream_error",
						},
					},
				}
				return err
			}
			// Convert and forward the chunk.
			chatResp := response.ChatResponse{
				ID:      req.RequestID,
				Object:  "chat.completion.chunk",
				Created: time.Now().Unix(),
				Model:   modelName,
				Choices: make([]response.Choice, len(recv.Choices)),
			}
			for i, choice := range recv.Choices {
				chatResp.Choices[i] = response.Choice{
					Index: i,
					Delta: response.Message{
						Role:    choice.Delta.Role,
						Content: choice.Delta.Content,
					},
				}
				if choice.FinishReason != "" {
					chatResp.Choices[i].FinishReason = string(choice.FinishReason)
				}
			}
			eventChan <- response.StreamEvent{
				Event: "message",
				Data:  chatResp,
			}
		}
	}
}
// StopGeneration cancels an in-flight request by ID.
func (l *LLMService) StopGeneration(req request.StopRequest) (*response.StopResponse, error) {
	// Look up the request's cancel function and invoke it.
	if cancelFunc, ok := l.activeSessions.Load(req.RequestID); ok {
		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
			cancel()
			l.activeSessions.Delete(req.RequestID)
			global.GVA_LOG.Info("Request stopped successfully", zap.String("requestID", req.RequestID))
			return &response.StopResponse{
				Success:   true,
				Message:   "Request stopped successfully",
				RequestID: req.RequestID,
				StoppedAt: time.Now().Unix(),
			}, nil
		}
	}
	global.GVA_LOG.Warn("Request not found or already completed", zap.String("requestID", req.RequestID))
	return &response.StopResponse{
		Success:   false,
		Message:   "Request not found or already completed",
		RequestID: req.RequestID,
		StoppedAt: time.Now().Unix(),
	}, nil
}

// GetActiveSessionsCount returns the number of in-flight requests.
func (l *LLMService) GetActiveSessionsCount() int {
	count := 0
	l.activeSessions.Range(func(key, value interface{}) bool {
		count++
		return true
	})
	return count
}
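
// exampleStreamUsage is a hypothetical caller, not part of the original file:
// it sketches one way to consume StreamChatCompletion. The producer runs in a
// goroutine and the channel is buffered so a slow consumer does not block the
// first sends; StreamChatCompletion returns after emitting a "done" or "error"
// event, at which point the channel is closed and the loop exits.
func exampleStreamUsage(req request.ChatRequest) error {
	eventChan := make(chan response.StreamEvent, 16)
	errChan := make(chan error, 1)
	go func() {
		defer close(eventChan)
		errChan <- LLMServiceApp.StreamChatCompletion(req, eventChan)
	}()
	for ev := range eventChan {
		// Forward each event to the client, e.g. as a server-sent event.
		fmt.Printf("event: %s\n", ev.Event)
	}
	return <-errChan
}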