package utils

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"

	"github.com/flipped-aurora/gin-vue-admin/server/global"
	"github.com/google/uuid"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"go.uber.org/zap"
)

// LLMMessage is a single chat message.
type LLMMessage struct {
	Role    string `json:"role"`    // Role: system, user, assistant
	Content string `json:"content"` // Message content
}

// LLMRequest is a chat completion request.
type LLMRequest struct {
	Messages    []LLMMessage `json:"messages" binding:"required"` // Conversation messages
	Model       string       `json:"model"`                       // Model name; falls back to the configured default when empty
	Stream      bool         `json:"stream"`                      // Whether to stream the response; defaults to false
	Temperature float64      `json:"temperature"`                 // Sampling temperature, controls randomness, range 0-1
	MaxTokens   int          `json:"max_tokens"`                  // Maximum number of tokens to generate
	TopP        float64      `json:"top_p"`                       // Nucleus sampling parameter
	RequestID   string       `json:"request_id,omitempty"`        // Request ID, used to manage streaming sessions
}

// LLMChoice is a single response choice. Message and Delta are pointers so
// that `omitempty` actually omits the unused field: encoding/json never
// treats a non-pointer struct value as empty.
type LLMChoice struct {
	Index        int         `json:"index"`                   // Choice index
	Message      *LLMMessage `json:"message,omitempty"`       // Full response message (non-streaming)
	Delta        *LLMMessage `json:"delta,omitempty"`         // Incremental message (streaming)
	FinishReason string      `json:"finish_reason,omitempty"` // Finish reason: stop, length, content_filter
}

// LLMUsage holds token usage statistics.
type LLMUsage struct {
	PromptTokens     int `json:"prompt_tokens"`     // Input tokens
	CompletionTokens int `json:"completion_tokens"` // Output tokens
	TotalTokens      int `json:"total_tokens"`      // Total tokens
}

// LLMError is an API error.
type LLMError struct {
	Code    string `json:"code"`    // Error code
	Message string `json:"message"` // Error message
	Type    string `json:"type"`    // Error type
}

// LLMResponse is a chat completion response.
type LLMResponse struct {
	ID      string      `json:"id"`              // Response ID
	Object  string      `json:"object"`          // Object type: chat.completion or chat.completion.chunk
	Created int64       `json:"created"`         // Creation timestamp
	Model   string      `json:"model"`           // Model used
	Choices []LLMChoice `json:"choices"`         // Response choices
	Usage   *LLMUsage   `json:"usage,omitempty"` // Usage statistics (non-streaming responses)
	Error   *LLMError   `json:"error,omitempty"` // Error details
}

// LLMStreamEvent is a streaming event.
type LLMStreamEvent struct {
	Event string      `json:"event"` // Event type: message, error, done
	Data  LLMResponse `json:"data"`  // Event payload
}

// VolcengineLLMUtil wraps the Volcengine ARK LLM client.
type VolcengineLLMUtil struct {
	client         *arkruntime.Client
	activeSessions sync.Map // requestID -> context.CancelFunc
	once           sync.Once
	initErr        error
}

// Global singleton instance.
var (
	volcengineLLMInstance *VolcengineLLMUtil
	volcengineLLMOnce     sync.Once
)

// GetVolcengineLLMUtil returns the singleton Volcengine LLM utility instance.
func GetVolcengineLLMUtil() *VolcengineLLMUtil {
	volcengineLLMOnce.Do(func() {
		volcengineLLMInstance = &VolcengineLLMUtil{}
	})
	return volcengineLLMInstance
}

// InitClient initializes the Volcengine client exactly once. Note that if
// initialization fails (e.g. the API key is missing), the error is cached by
// sync.Once and returned on every subsequent call; the process must be
// restarted after fixing the configuration.
func (v *VolcengineLLMUtil) InitClient() error {
	v.once.Do(func() {
		config := global.GVA_CONFIG.Volcengine

		// Validate configuration.
		if config.ApiKey == "" {
			v.initErr = fmt.Errorf("volcengine api key is empty")
			global.GVA_LOG.Error("Volcengine configuration error", zap.Error(v.initErr))
			return
		}

		// Create the ARK Runtime client using the API key.
		v.client = arkruntime.NewClientWithApiKey(config.ApiKey)
		global.GVA_LOG.Info("Volcengine LLM client initialized successfully")
	})
	return v.initErr
}

// ChatCompletion performs a non-streaming chat completion.
func (v *VolcengineLLMUtil) ChatCompletion(req LLMRequest) (*LLMResponse, error) {
	// Always go through InitClient: sync.Once makes repeated calls cheap and
	// avoids an unsynchronized read of v.client from multiple goroutines.
	if err := v.InitClient(); err != nil {
		return nil, err
	}

	// Generate a request ID if the caller did not supply one.
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}

	// Convert messages to the SDK format. Copy msg.Content to a local so the
	// pointer does not alias the loop variable (pre-Go-1.22, taking
	// &msg.Content would make every message point at the last iteration).
	messages := make([]*model.ChatCompletionMessage, len(req.Messages))
	for i, msg := range req.Messages {
		content := msg.Content
		messages[i] = &model.ChatCompletionMessage{
			Role: msg.Role,
			Content: &model.ChatCompletionMessageContent{
				StringValue: &content,
			},
		}
	}

	// Resolve the model name.
	modelName := req.Model
	if modelName == "" {
		modelName = global.GVA_CONFIG.Volcengine.Model
	}

	// Build the request.
	chatReq := &model.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
		Stream:   false,
	}

	// Optional parameters.
	if req.Temperature > 0 {
		chatReq.Temperature = float32(req.Temperature)
	}
	if req.MaxTokens > 0 {
		chatReq.MaxTokens = req.MaxTokens
	}
	if req.TopP > 0 {
		chatReq.TopP = float32(req.TopP)
	}

	// Call the API.
	resp, err := v.client.CreateChatCompletion(context.Background(), chatReq)
	if err != nil {
		global.GVA_LOG.Error("Chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
		return nil, err
	}

	// Convert the response. Guard the StringValue dereference: the SDK can
	// return content in a non-string form, in which case StringValue is nil
	// and the unconditional dereference would panic.
	choices := make([]LLMChoice, len(resp.Choices))
	for i, choice := range resp.Choices {
		content := ""
		if choice.Message.Content != nil && choice.Message.Content.StringValue != nil {
			content = *choice.Message.Content.StringValue
		}
		choices[i] = LLMChoice{
			Index: choice.Index,
			Message: &LLMMessage{
				Role:    choice.Message.Role,
				Content: content,
			},
			FinishReason: string(choice.FinishReason),
		}
	}

	usage := &LLMUsage{
		PromptTokens:     resp.Usage.PromptTokens,
		CompletionTokens: resp.Usage.CompletionTokens,
		TotalTokens:      resp.Usage.TotalTokens,
	}

	chatResp := &LLMResponse{
		ID:      resp.ID,
		Object:  resp.Object,
		Created: resp.Created,
		Model:   resp.Model,
		Choices: choices,
		Usage:   usage,
	}

	global.GVA_LOG.Info("Chat completion successful", zap.String("requestID", req.RequestID))
	return chatResp, nil
}
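// Usage sketch for ChatCompletion, assuming it is called from another package
// (the message contents and temperature below are illustrative values, not
// defaults of this package):
//
//	util := utils.GetVolcengineLLMUtil()
//	resp, err := util.ChatCompletion(utils.LLMRequest{
//		Messages: []utils.LLMMessage{
//			{Role: "system", Content: "You are a helpful assistant."},
//			{Role: "user", Content: "Hello"},
//		},
//		Temperature: 0.7,
//	})
//	if err != nil {
//		// handle error
//	}
//	answer := resp.Choices[0].Message.Content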
// StreamChatCompletion performs a streaming chat completion, pushing events
// onto eventChan until the stream finishes, fails, or is cancelled via
// StopGeneration. The channel is not closed by this function; that is the
// caller's responsibility.
func (v *VolcengineLLMUtil) StreamChatCompletion(req LLMRequest, eventChan chan<- LLMStreamEvent) error {
	if err := v.InitClient(); err != nil {
		return err
	}

	// Generate a request ID if the caller did not supply one.
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}

	// Convert messages to the SDK format (copy msg.Content to a local so the
	// pointer does not alias the loop variable).
	messages := make([]*model.ChatCompletionMessage, len(req.Messages))
	for i, msg := range req.Messages {
		content := msg.Content
		messages[i] = &model.ChatCompletionMessage{
			Role: msg.Role,
			Content: &model.ChatCompletionMessageContent{
				StringValue: &content,
			},
		}
	}

	// Resolve the model name.
	modelName := req.Model
	if modelName == "" {
		modelName = global.GVA_CONFIG.Volcengine.Model
	}

	// Build the streaming request.
	chatReq := &model.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
		Stream:   true,
	}

	// Optional parameters.
	if req.Temperature > 0 {
		chatReq.Temperature = float32(req.Temperature)
	}
	if req.MaxTokens > 0 {
		chatReq.MaxTokens = req.MaxTokens
	}
	if req.TopP > 0 {
		chatReq.TopP = float32(req.TopP)
	}

	// Create a cancellable context and register it so StopGeneration can
	// interrupt this stream by request ID.
	ctx, cancel := context.WithCancel(context.Background())
	v.activeSessions.Store(req.RequestID, cancel)
	defer func() {
		v.activeSessions.Delete(req.RequestID)
		cancel()
	}()

	// Open the stream.
	streamResp, err := v.client.CreateChatCompletionStream(ctx, chatReq)
	if err != nil {
		global.GVA_LOG.Error("Stream chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
		return err
	}
	defer streamResp.Close()

	global.GVA_LOG.Info("Stream chat completion started", zap.String("requestID", req.RequestID))

	// Consume the stream.
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
			recv, err := streamResp.Recv()
			if err != nil {
				// Compare against io.EOF with errors.Is instead of matching
				// the error string, which is brittle for wrapped errors.
				if errors.Is(err, io.EOF) {
					// End of stream.
					eventChan <- LLMStreamEvent{
						Event: "done",
						Data: LLMResponse{
							ID:     req.RequestID,
							Object: "chat.completion.chunk",
						},
					}
					global.GVA_LOG.Info("Stream chat completion finished", zap.String("requestID", req.RequestID))
					return nil
				}
				global.GVA_LOG.Error("Stream receive error", zap.Error(err), zap.String("requestID", req.RequestID))
				eventChan <- LLMStreamEvent{
					Event: "error",
					Data: LLMResponse{
						ID: req.RequestID,
						Error: &LLMError{
							Code:    "stream_error",
							Message: err.Error(),
							Type:    "stream_error",
						},
					},
				}
				return err
			}

			// Convert the chunk.
			choices := make([]LLMChoice, len(recv.Choices))
			for i, choice := range recv.Choices {
				choices[i] = LLMChoice{
					Index: choice.Index,
					Delta: &LLMMessage{
						Role:    choice.Delta.Role,
						Content: choice.Delta.Content,
					},
					FinishReason: string(choice.FinishReason),
				}
			}

			var usage *LLMUsage
			if recv.Usage != nil {
				usage = &LLMUsage{
					PromptTokens:     recv.Usage.PromptTokens,
					CompletionTokens: recv.Usage.CompletionTokens,
					TotalTokens:      recv.Usage.TotalTokens,
				}
			}

			// Emit a message event.
			eventChan <- LLMStreamEvent{
				Event: "message",
				Data: LLMResponse{
					ID:      recv.ID,
					Object:  recv.Object,
					Created: recv.Created,
					Model:   recv.Model,
					Choices: choices,
					Usage:   usage,
				},
			}
		}
	}
}
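// Usage sketch for StreamChatCompletion (the buffer size and event handling
// below are assumptions; wire the events to SSE or WebSocket as your handler
// requires). Since StreamChatCompletion does not close eventChan, the caller
// closes it once the call returns, which also ends the range loop:
//
//	eventChan := make(chan utils.LLMStreamEvent, 16)
//	go func() {
//		defer close(eventChan)
//		_ = utils.GetVolcengineLLMUtil().StreamChatCompletion(req, eventChan)
//	}()
//	for ev := range eventChan {
//		switch ev.Event {
//		case "message":
//			// append ev.Data.Choices[0].Delta.Content to the output
//			// (guard against empty Choices in real handlers)
//		case "error", "done":
//			// terminal events
//		}
//	}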
// StopGeneration cancels the active streaming session with the given request ID.
func (v *VolcengineLLMUtil) StopGeneration(requestID string) error {
	if cancel, ok := v.activeSessions.Load(requestID); ok {
		if cancelFunc, ok := cancel.(context.CancelFunc); ok {
			cancelFunc()
			v.activeSessions.Delete(requestID)
			global.GVA_LOG.Info("Generation stopped", zap.String("requestID", requestID))
			return nil
		}
	}
	return fmt.Errorf("session not found: %s", requestID)
}

// GetActiveSessionsCount returns the number of active streaming sessions.
func (v *VolcengineLLMUtil) GetActiveSessionsCount() int {
	count := 0
	v.activeSessions.Range(func(key, value interface{}) bool {
		count++
		return true
	})
	return count
}
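// Usage sketch for StopGeneration (the handler shape is an assumption): a
// stop endpoint only needs the RequestID that was issued when the stream
// started; "session not found" means the stream already finished or never
// existed:
//
//	if err := utils.GetVolcengineLLMUtil().StopGeneration(requestID); err != nil {
//		// already finished or unknown requestID
//	}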