package service
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/flipped-aurora/gin-vue-admin/server/global"
	"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/request"
	"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/response"
	"github.com/google/uuid"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"go.uber.org/zap"
)
|
|
|
|
// LLMService wraps the Volcengine ARK runtime client and tracks in-flight
// requests so they can be cancelled by request ID (see StopGeneration).
type LLMService struct {
	// client is lazily initialized by InitVolcengineClient.
	// NOTE(review): the lazy nil-check in ChatCompletion/StreamChatCompletion
	// is not synchronized; concurrent first calls may race — consider sync.Once.
	client         *arkruntime.Client
	activeSessions sync.Map // requestID -> context.CancelFunc
}
|
|
|
|
// LLMServiceApp is the package-level singleton instance of LLMService.
var LLMServiceApp = new(LLMService)
|
|
|
|
// InitVolcengineClient 初始化火山引擎客户端
|
|
func (l *LLMService) InitVolcengineClient() error {
|
|
config := global.GVA_CONFIG.Volcengine
|
|
|
|
// 验证配置
|
|
if config.AccessKey == "" || config.SecretKey == "" {
|
|
return fmt.Errorf("volcengine access key or secret key is empty")
|
|
}
|
|
|
|
// 创建ARK Runtime客户端
|
|
l.client = arkruntime.NewClientWithAkSk(config.AccessKey, config.SecretKey)
|
|
|
|
global.GVA_LOG.Info("Volcengine client initialized successfully")
|
|
return nil
|
|
}
|
|
|
|
// ChatCompletion 非流式聊天完成
|
|
func (l *LLMService) ChatCompletion(req request.ChatRequest) (*response.ChatResponse, error) {
|
|
if l.client == nil {
|
|
if err := l.InitVolcengineClient(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// 生成请求ID
|
|
if req.RequestID == "" {
|
|
req.RequestID = uuid.New().String()
|
|
}
|
|
|
|
// 转换消息格式
|
|
messages := make([]*model.ChatCompletionMessage, len(req.Messages))
|
|
for i, msg := range req.Messages {
|
|
messages[i] = &model.ChatCompletionMessage{
|
|
Role: msg.Role,
|
|
Content: &model.ChatCompletionMessageContent{
|
|
StringValue: &msg.Content,
|
|
},
|
|
}
|
|
}
|
|
|
|
// 设置模型
|
|
modelName := req.Model
|
|
if modelName == "" {
|
|
modelName = global.GVA_CONFIG.Volcengine.DefaultModel
|
|
}
|
|
|
|
// 创建请求
|
|
chatReq := &model.ChatCompletionRequest{
|
|
Model: modelName,
|
|
Messages: messages,
|
|
}
|
|
|
|
// 设置可选参数
|
|
if req.Temperature > 0 {
|
|
chatReq.Temperature = float32(req.Temperature)
|
|
}
|
|
if req.MaxTokens > 0 {
|
|
chatReq.MaxTokens = req.MaxTokens
|
|
}
|
|
if req.TopP > 0 {
|
|
chatReq.TopP = float32(req.TopP)
|
|
}
|
|
|
|
// 创建上下文
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
|
|
// 存储取消函数
|
|
l.activeSessions.Store(req.RequestID, cancel)
|
|
defer l.activeSessions.Delete(req.RequestID)
|
|
|
|
// 调用API
|
|
resp, err := l.client.CreateChatCompletion(ctx, chatReq)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("Chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
|
|
return nil, fmt.Errorf("chat completion failed: %w", err)
|
|
}
|
|
|
|
// 转换响应格式
|
|
chatResp := &response.ChatResponse{
|
|
ID: req.RequestID,
|
|
Object: "chat.completion",
|
|
Created: time.Now().Unix(),
|
|
Model: modelName,
|
|
Choices: make([]response.Choice, len(resp.Choices)),
|
|
}
|
|
|
|
for i, choice := range resp.Choices {
|
|
messageContent := ""
|
|
if choice.Message.Content != nil && choice.Message.Content.StringValue != nil {
|
|
messageContent = *choice.Message.Content.StringValue
|
|
}
|
|
|
|
chatResp.Choices[i] = response.Choice{
|
|
Index: i,
|
|
Message: response.Message{
|
|
Role: choice.Message.Role,
|
|
Content: messageContent,
|
|
},
|
|
FinishReason: string(choice.FinishReason),
|
|
}
|
|
}
|
|
|
|
// 设置使用量统计
|
|
chatResp.Usage = &response.Usage{
|
|
PromptTokens: resp.Usage.PromptTokens,
|
|
CompletionTokens: resp.Usage.CompletionTokens,
|
|
TotalTokens: resp.Usage.TotalTokens,
|
|
}
|
|
|
|
global.GVA_LOG.Info("Chat completion successful", zap.String("requestID", req.RequestID))
|
|
return chatResp, nil
|
|
}
|
|
|
|
// StreamChatCompletion 流式聊天完成
|
|
func (l *LLMService) StreamChatCompletion(req request.ChatRequest, eventChan chan<- response.StreamEvent) error {
|
|
if l.client == nil {
|
|
if err := l.InitVolcengineClient(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// 生成请求ID
|
|
if req.RequestID == "" {
|
|
req.RequestID = uuid.New().String()
|
|
}
|
|
|
|
// 转换消息格式
|
|
messages := make([]*model.ChatCompletionMessage, len(req.Messages))
|
|
for i, msg := range req.Messages {
|
|
messages[i] = &model.ChatCompletionMessage{
|
|
Role: msg.Role,
|
|
Content: &model.ChatCompletionMessageContent{
|
|
StringValue: &msg.Content,
|
|
},
|
|
}
|
|
}
|
|
|
|
// 设置模型
|
|
modelName := req.Model
|
|
if modelName == "" {
|
|
modelName = global.GVA_CONFIG.Volcengine.DefaultModel
|
|
}
|
|
|
|
// 创建流式请求
|
|
chatReq := &model.ChatCompletionRequest{
|
|
Model: modelName,
|
|
Messages: messages,
|
|
Stream: true,
|
|
}
|
|
|
|
// 设置可选参数
|
|
if req.Temperature > 0 {
|
|
chatReq.Temperature = float32(req.Temperature)
|
|
}
|
|
if req.MaxTokens > 0 {
|
|
chatReq.MaxTokens = req.MaxTokens
|
|
}
|
|
if req.TopP > 0 {
|
|
chatReq.TopP = float32(req.TopP)
|
|
}
|
|
|
|
// 创建上下文
|
|
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) // 流式响应需要更长时间
|
|
defer cancel()
|
|
|
|
// 存储取消函数
|
|
l.activeSessions.Store(req.RequestID, cancel)
|
|
defer l.activeSessions.Delete(req.RequestID)
|
|
|
|
// 调用流式API
|
|
stream_resp, err := l.client.CreateChatCompletionStream(ctx, chatReq)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("Stream chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
|
|
return fmt.Errorf("stream chat completion failed: %w", err)
|
|
}
|
|
defer stream_resp.Close()
|
|
|
|
// 处理流式响应
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
// 上下文取消
|
|
eventChan <- response.StreamEvent{
|
|
Event: "error",
|
|
Data: response.ChatResponse{
|
|
ID: req.RequestID,
|
|
Error: &response.APIError{
|
|
Code: "context_cancelled",
|
|
Message: "Request was cancelled",
|
|
Type: "request_cancelled",
|
|
},
|
|
},
|
|
}
|
|
return ctx.Err()
|
|
default:
|
|
// 接收流式数据
|
|
recv, err := stream_resp.Recv()
|
|
if err != nil {
|
|
if err.Error() == "EOF" {
|
|
// 流结束
|
|
eventChan <- response.StreamEvent{
|
|
Event: "done",
|
|
Data: response.ChatResponse{
|
|
ID: req.RequestID,
|
|
Object: "chat.completion.chunk",
|
|
},
|
|
}
|
|
global.GVA_LOG.Info("Stream chat completion finished", zap.String("requestID", req.RequestID))
|
|
return nil
|
|
}
|
|
global.GVA_LOG.Error("Stream receive error", zap.Error(err), zap.String("requestID", req.RequestID))
|
|
eventChan <- response.StreamEvent{
|
|
Event: "error",
|
|
Data: response.ChatResponse{
|
|
ID: req.RequestID,
|
|
Error: &response.APIError{
|
|
Code: "stream_error",
|
|
Message: err.Error(),
|
|
Type: "stream_error",
|
|
},
|
|
},
|
|
}
|
|
return err
|
|
}
|
|
|
|
// 转换并发送响应
|
|
chatResp := response.ChatResponse{
|
|
ID: req.RequestID,
|
|
Object: "chat.completion.chunk",
|
|
Created: time.Now().Unix(),
|
|
Model: modelName,
|
|
Choices: make([]response.Choice, len(recv.Choices)),
|
|
}
|
|
|
|
for i, choice := range recv.Choices {
|
|
chatResp.Choices[i] = response.Choice{
|
|
Index: i,
|
|
Delta: response.Message{
|
|
Role: choice.Delta.Role,
|
|
Content: choice.Delta.Content,
|
|
},
|
|
}
|
|
if choice.FinishReason != "" {
|
|
chatResp.Choices[i].FinishReason = string(choice.FinishReason)
|
|
}
|
|
}
|
|
|
|
eventChan <- response.StreamEvent{
|
|
Event: "message",
|
|
Data: chatResp,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// StopGeneration 停止生成
|
|
func (l *LLMService) StopGeneration(req request.StopRequest) (*response.StopResponse, error) {
|
|
// 查找并取消对应的请求
|
|
if cancelFunc, ok := l.activeSessions.Load(req.RequestID); ok {
|
|
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
|
cancel()
|
|
l.activeSessions.Delete(req.RequestID)
|
|
|
|
global.GVA_LOG.Info("Request stopped successfully", zap.String("requestID", req.RequestID))
|
|
return &response.StopResponse{
|
|
Success: true,
|
|
Message: "Request stopped successfully",
|
|
RequestID: req.RequestID,
|
|
StoppedAt: time.Now().Unix(),
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
global.GVA_LOG.Warn("Request not found or already completed", zap.String("requestID", req.RequestID))
|
|
return &response.StopResponse{
|
|
Success: false,
|
|
Message: "Request not found or already completed",
|
|
RequestID: req.RequestID,
|
|
StoppedAt: time.Now().Unix(),
|
|
}, nil
|
|
}
|
|
|
|
// GetActiveSessionsCount 获取活跃会话数量
|
|
func (l *LLMService) GetActiveSessionsCount() int {
|
|
count := 0
|
|
l.activeSessions.Range(func(key, value interface{}) bool {
|
|
count++
|
|
return true
|
|
})
|
|
return count
|
|
}
|