package service

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/flipped-aurora/gin-vue-admin/server/global"
	"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/request"
	"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/response"
	"github.com/google/uuid"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"go.uber.org/zap"
)

// LLMService wraps the Volcengine ARK runtime client and tracks in-flight
// requests so callers can cancel them via StopGeneration.
type LLMService struct {
	client *arkruntime.Client
	initMu sync.Mutex // guards lazy initialization of client

	// activeSessions maps requestID -> context.CancelFunc for every
	// request currently being served.
	activeSessions sync.Map
}

// LLMServiceApp is the shared singleton instance of LLMService.
var LLMServiceApp = new(LLMService)

// InitVolcengineClient 初始化火山引擎客户端
//
// InitVolcengineClient creates the ARK runtime client from the AK/SK pair in
// the global Volcengine config. It returns an error when either key is empty.
func (l *LLMService) InitVolcengineClient() error {
	config := global.GVA_CONFIG.Volcengine

	// 验证配置
	if config.AccessKey == "" || config.SecretKey == "" {
		return fmt.Errorf("volcengine access key or secret key is empty")
	}

	// 创建ARK Runtime客户端
	l.client = arkruntime.NewClientWithAkSk(config.AccessKey, config.SecretKey)
	global.GVA_LOG.Info("Volcengine client initialized successfully")
	return nil
}

// ensureClient lazily initializes the client exactly once under a mutex so
// that concurrent first calls do not race on the l.client write.
func (l *LLMService) ensureClient() error {
	l.initMu.Lock()
	defer l.initMu.Unlock()
	if l.client != nil {
		return nil
	}
	return l.InitVolcengineClient()
}

// buildChatRequest converts the plugin-level request into an ARK SDK request,
// filling in the configured default model and the optional sampling
// parameters (temperature, max tokens, top-p) when they are set.
func (l *LLMService) buildChatRequest(req request.ChatRequest, stream bool) *model.ChatCompletionRequest {
	// 转换消息格式
	messages := make([]*model.ChatCompletionMessage, len(req.Messages))
	for i := range req.Messages {
		// Copy the content into a fresh variable before taking its address:
		// taking &msg.Content of the range variable would make every message
		// alias the same storage (pre-Go-1.22 loop-variable semantics).
		content := req.Messages[i].Content
		messages[i] = &model.ChatCompletionMessage{
			Role: req.Messages[i].Role,
			Content: &model.ChatCompletionMessageContent{
				StringValue: &content,
			},
		}
	}

	// 设置模型
	modelName := req.Model
	if modelName == "" {
		modelName = global.GVA_CONFIG.Volcengine.DefaultModel
	}

	chatReq := &model.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
		Stream:   stream,
	}

	// 设置可选参数
	if req.Temperature > 0 {
		chatReq.Temperature = float32(req.Temperature)
	}
	if req.MaxTokens > 0 {
		chatReq.MaxTokens = req.MaxTokens
	}
	if req.TopP > 0 {
		chatReq.TopP = float32(req.TopP)
	}
	return chatReq
}

// ChatCompletion 非流式聊天完成
//
// ChatCompletion performs a blocking (non-streaming) chat completion. It
// assigns a request ID when the caller did not provide one, registers the
// request's cancel function so StopGeneration can abort it, and converts the
// SDK response into the plugin response format.
func (l *LLMService) ChatCompletion(req request.ChatRequest) (*response.ChatResponse, error) {
	if err := l.ensureClient(); err != nil {
		return nil, err
	}

	// 生成请求ID
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}

	chatReq := l.buildChatRequest(req, false)

	// 创建上下文
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	// 存储取消函数
	l.activeSessions.Store(req.RequestID, cancel)
	defer l.activeSessions.Delete(req.RequestID)

	// 调用API
	resp, err := l.client.CreateChatCompletion(ctx, chatReq)
	if err != nil {
		global.GVA_LOG.Error("Chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
		return nil, fmt.Errorf("chat completion failed: %w", err)
	}

	// 转换响应格式
	chatResp := &response.ChatResponse{
		ID:      req.RequestID,
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   chatReq.Model,
		Choices: make([]response.Choice, len(resp.Choices)),
	}
	for i, choice := range resp.Choices {
		messageContent := ""
		if choice.Message.Content != nil && choice.Message.Content.StringValue != nil {
			messageContent = *choice.Message.Content.StringValue
		}
		chatResp.Choices[i] = response.Choice{
			Index: i,
			Message: response.Message{
				Role:    choice.Message.Role,
				Content: messageContent,
			},
			FinishReason: string(choice.FinishReason),
		}
	}

	// 设置使用量统计
	chatResp.Usage = &response.Usage{
		PromptTokens:     resp.Usage.PromptTokens,
		CompletionTokens: resp.Usage.CompletionTokens,
		TotalTokens:      resp.Usage.TotalTokens,
	}

	global.GVA_LOG.Info("Chat completion successful", zap.String("requestID", req.RequestID))
	return chatResp, nil
}

// StreamChatCompletion 流式聊天完成
//
// StreamChatCompletion runs a streaming chat completion and forwards every
// chunk to eventChan as a "message" event, followed by a terminal "done" (on
// clean EOF) or "error" event. It blocks until the stream finishes, fails, or
// the request is cancelled (timeout or StopGeneration).
func (l *LLMService) StreamChatCompletion(req request.ChatRequest, eventChan chan<- response.StreamEvent) error {
	if err := l.ensureClient(); err != nil {
		return err
	}

	// 生成请求ID
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}

	// 创建流式请求
	chatReq := l.buildChatRequest(req, true)

	// 创建上下文 — 流式响应需要更长时间
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
	defer cancel()

	// 存储取消函数
	l.activeSessions.Store(req.RequestID, cancel)
	defer l.activeSessions.Delete(req.RequestID)

	// 调用流式API
	streamResp, err := l.client.CreateChatCompletionStream(ctx, chatReq)
	if err != nil {
		global.GVA_LOG.Error("Stream chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
		return fmt.Errorf("stream chat completion failed: %w", err)
	}
	defer streamResp.Close()

	// 处理流式响应
	for {
		select {
		case <-ctx.Done():
			// 上下文取消 (timeout or StopGeneration)
			eventChan <- response.StreamEvent{
				Event: "error",
				Data: response.ChatResponse{
					ID: req.RequestID,
					Error: &response.APIError{
						Code:    "context_cancelled",
						Message: "Request was cancelled",
						Type:    "request_cancelled",
					},
				},
			}
			return ctx.Err()
		default:
			// 接收流式数据
			recv, err := streamResp.Recv()
			if err != nil {
				// Use errors.Is so wrapped io.EOF is recognized; keep the
				// string fallback for SDK errors that stringify as "EOF"
				// without wrapping the sentinel.
				if errors.Is(err, io.EOF) || err.Error() == "EOF" {
					// 流结束
					eventChan <- response.StreamEvent{
						Event: "done",
						Data: response.ChatResponse{
							ID:     req.RequestID,
							Object: "chat.completion.chunk",
						},
					}
					global.GVA_LOG.Info("Stream chat completion finished", zap.String("requestID", req.RequestID))
					return nil
				}
				global.GVA_LOG.Error("Stream receive error", zap.Error(err), zap.String("requestID", req.RequestID))
				eventChan <- response.StreamEvent{
					Event: "error",
					Data: response.ChatResponse{
						ID: req.RequestID,
						Error: &response.APIError{
							Code:    "stream_error",
							Message: err.Error(),
							Type:    "stream_error",
						},
					},
				}
				return err
			}

			// 转换并发送响应
			chatResp := response.ChatResponse{
				ID:      req.RequestID,
				Object:  "chat.completion.chunk",
				Created: time.Now().Unix(),
				Model:   chatReq.Model,
				Choices: make([]response.Choice, len(recv.Choices)),
			}
			for i, choice := range recv.Choices {
				chatResp.Choices[i] = response.Choice{
					Index: i,
					Delta: response.Message{
						Role:    choice.Delta.Role,
						Content: choice.Delta.Content,
					},
				}
				if choice.FinishReason != "" {
					chatResp.Choices[i].FinishReason = string(choice.FinishReason)
				}
			}
			eventChan <- response.StreamEvent{
				Event: "message",
				Data:  chatResp,
			}
		}
	}
}

// StopGeneration 停止生成
//
// StopGeneration cancels the in-flight request identified by req.RequestID.
// It reports Success=false when no matching active session exists (unknown ID
// or already completed).
func (l *LLMService) StopGeneration(req request.StopRequest) (*response.StopResponse, error) {
	// 查找并取消对应的请求
	if cancelFunc, ok := l.activeSessions.Load(req.RequestID); ok {
		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
			cancel()
			l.activeSessions.Delete(req.RequestID)
			global.GVA_LOG.Info("Request stopped successfully", zap.String("requestID", req.RequestID))
			return &response.StopResponse{
				Success:   true,
				Message:   "Request stopped successfully",
				RequestID: req.RequestID,
				StoppedAt: time.Now().Unix(),
			}, nil
		}
	}

	global.GVA_LOG.Warn("Request not found or already completed", zap.String("requestID", req.RequestID))
	return &response.StopResponse{
		Success:   false,
		Message:   "Request not found or already completed",
		RequestID: req.RequestID,
		StoppedAt: time.Now().Unix(),
	}, nil
}

// GetActiveSessionsCount 获取活跃会话数量
//
// GetActiveSessionsCount returns the number of requests currently registered
// in activeSessions. The count is a snapshot and may change immediately.
func (l *LLMService) GetActiveSessionsCount() int {
	count := 0
	l.activeSessions.Range(func(key, value interface{}) bool {
		count++
		return true
	})
	return count
}