pet-ai/server/utils/volcengine_llm.go


package utils

import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"

	"github.com/flipped-aurora/gin-vue-admin/server/global"
	"github.com/google/uuid"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"go.uber.org/zap"
)
// LLMMessage is a single chat message.
type LLMMessage struct {
	Role    string `json:"role"`    // role: system, user, assistant
	Content string `json:"content"` // message content
}

// LLMRequest is a chat completion request.
type LLMRequest struct {
	Messages    []LLMMessage `json:"messages" binding:"required"` // conversation message list
	Model       string       `json:"model"`                       // model name; falls back to the configured default when empty
	Stream      bool         `json:"stream"`                      // whether to stream the response (default false)
	Temperature float64      `json:"temperature"`                 // sampling temperature controlling randomness, range 0-1
	MaxTokens   int          `json:"max_tokens"`                  // maximum number of tokens to generate
	TopP        float64      `json:"top_p"`                       // nucleus sampling parameter
	RequestID   string       `json:"request_id,omitempty"`        // request ID, used to manage streaming sessions
}

// LLMChoice is one response choice.
type LLMChoice struct {
	Index        int        `json:"index"`                   // choice index
	Message      LLMMessage `json:"message,omitempty"`       // full response message (non-streaming)
	Delta        LLMMessage `json:"delta,omitempty"`         // incremental message (streaming)
	FinishReason string     `json:"finish_reason,omitempty"` // finish reason: stop, length, content_filter
}

// LLMUsage holds token usage statistics.
type LLMUsage struct {
	PromptTokens     int `json:"prompt_tokens"`     // input tokens
	CompletionTokens int `json:"completion_tokens"` // output tokens
	TotalTokens      int `json:"total_tokens"`      // total tokens
}

// LLMError is an API error.
type LLMError struct {
	Code    string `json:"code"`    // error code
	Message string `json:"message"` // error message
	Type    string `json:"type"`    // error type
}

// LLMResponse is a chat completion response.
type LLMResponse struct {
	ID      string      `json:"id"`              // response ID
	Object  string      `json:"object"`          // object type: chat.completion or chat.completion.chunk
	Created int64       `json:"created"`         // creation timestamp
	Model   string      `json:"model"`           // model used
	Choices []LLMChoice `json:"choices"`         // response choice list
	Usage   *LLMUsage   `json:"usage,omitempty"` // usage statistics (non-streaming responses)
	Error   *LLMError   `json:"error,omitempty"` // error information
}

// LLMStreamEvent is a streaming event.
type LLMStreamEvent struct {
	Event string      `json:"event"` // event type: message, error, done
	Data  LLMResponse `json:"data"`  // event payload
}
// VolcengineLLMUtil wraps the Volcengine ARK LLM client.
type VolcengineLLMUtil struct {
	client         *arkruntime.Client
	activeSessions sync.Map // requestID -> context.CancelFunc
	once           sync.Once
	initErr        error
}

// Global singleton instance.
var (
	volcengineLLMInstance *VolcengineLLMUtil
	volcengineLLMOnce     sync.Once
)

// GetVolcengineLLMUtil returns the singleton Volcengine LLM utility instance.
func GetVolcengineLLMUtil() *VolcengineLLMUtil {
	volcengineLLMOnce.Do(func() {
		volcengineLLMInstance = &VolcengineLLMUtil{}
	})
	return volcengineLLMInstance
}
// InitClient initializes the Volcengine client (guarded by sync.Once, so it
// is safe and cheap to call repeatedly).
func (v *VolcengineLLMUtil) InitClient() error {
	v.once.Do(func() {
		config := global.GVA_CONFIG.Volcengine
		// Validate configuration.
		if config.ApiKey == "" {
			v.initErr = fmt.Errorf("volcengine api key is empty")
			global.GVA_LOG.Error("Volcengine configuration error", zap.Error(v.initErr))
			return
		}
		// Create the ARK Runtime client using the API key.
		v.client = arkruntime.NewClientWithApiKey(config.ApiKey)
		global.GVA_LOG.Info("Volcengine LLM client initialized successfully")
	})
	return v.initErr
}
// ChatCompletion performs a non-streaming chat completion.
func (v *VolcengineLLMUtil) ChatCompletion(req LLMRequest) (*LLMResponse, error) {
	// InitClient is idempotent; calling it unconditionally avoids an
	// unsynchronized read of v.client.
	if err := v.InitClient(); err != nil {
		return nil, err
	}
	// Generate a request ID if the caller did not supply one.
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}
	// Convert messages into the SDK format. Index into the slice instead of
	// taking the address of the loop variable's field: before Go 1.22 the
	// loop variable is reused, so all content pointers would alias the last
	// message.
	messages := make([]*model.ChatCompletionMessage, len(req.Messages))
	for i := range req.Messages {
		messages[i] = &model.ChatCompletionMessage{
			Role: req.Messages[i].Role,
			Content: &model.ChatCompletionMessageContent{
				StringValue: &req.Messages[i].Content,
			},
		}
	}
	// Resolve the model name, falling back to the configured default.
	modelName := req.Model
	if modelName == "" {
		modelName = global.GVA_CONFIG.Volcengine.Model
	}
	// Build the request.
	chatReq := &model.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
		Stream:   false,
	}
	// Apply optional parameters.
	if req.Temperature > 0 {
		chatReq.Temperature = float32(req.Temperature)
	}
	if req.MaxTokens > 0 {
		chatReq.MaxTokens = req.MaxTokens
	}
	if req.TopP > 0 {
		chatReq.TopP = float32(req.TopP)
	}
	// Call the API.
	resp, err := v.client.CreateChatCompletion(context.Background(), chatReq)
	if err != nil {
		global.GVA_LOG.Error("Chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
		return nil, err
	}
	// Convert the response into the local format.
	choices := make([]LLMChoice, len(resp.Choices))
	for i, choice := range resp.Choices {
		choices[i] = LLMChoice{
			Index: choice.Index,
			Message: LLMMessage{
				Role:    choice.Message.Role,
				Content: *choice.Message.Content.StringValue,
			},
			FinishReason: string(choice.FinishReason),
		}
	}
	usage := &LLMUsage{
		PromptTokens:     resp.Usage.PromptTokens,
		CompletionTokens: resp.Usage.CompletionTokens,
		TotalTokens:      resp.Usage.TotalTokens,
	}
	chatResp := &LLMResponse{
		ID:      resp.ID,
		Object:  resp.Object,
		Created: resp.Created,
		Model:   resp.Model,
		Choices: choices,
		Usage:   usage,
	}
	global.GVA_LOG.Info("Chat completion successful", zap.String("requestID", req.RequestID))
	return chatResp, nil
}
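
// exampleChatCompletion is a hypothetical caller, not part of the original
// file: a minimal sketch of a non-streaming request, assuming the global
// gin-vue-admin configuration (API key and default model) is already loaded.
// The role strings and prompt content are placeholders.
func exampleChatCompletion() {
	resp, err := GetVolcengineLLMUtil().ChatCompletion(LLMRequest{
		Messages: []LLMMessage{
			{Role: "system", Content: "You are a helpful pet-care assistant."},
			{Role: "user", Content: "My cat sneezes a lot. What should I check?"},
		},
		Temperature: 0.7, // optional; zero values are simply not forwarded
	})
	if err != nil {
		fmt.Println("chat completion failed:", err)
		return
	}
	if len(resp.Choices) > 0 {
		fmt.Println(resp.Choices[0].Message.Content)
	}
}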
// StreamChatCompletion performs a streaming chat completion, emitting events
// on eventChan until the stream ends, errors, or is cancelled.
func (v *VolcengineLLMUtil) StreamChatCompletion(req LLMRequest, eventChan chan<- LLMStreamEvent) error {
	// InitClient is idempotent; calling it unconditionally avoids an
	// unsynchronized read of v.client.
	if err := v.InitClient(); err != nil {
		return err
	}
	// Generate a request ID if the caller did not supply one.
	if req.RequestID == "" {
		req.RequestID = uuid.New().String()
	}
	// Convert messages into the SDK format (indexing avoids aliasing the loop
	// variable's address; see ChatCompletion).
	messages := make([]*model.ChatCompletionMessage, len(req.Messages))
	for i := range req.Messages {
		messages[i] = &model.ChatCompletionMessage{
			Role: req.Messages[i].Role,
			Content: &model.ChatCompletionMessageContent{
				StringValue: &req.Messages[i].Content,
			},
		}
	}
	// Resolve the model name, falling back to the configured default.
	modelName := req.Model
	if modelName == "" {
		modelName = global.GVA_CONFIG.Volcengine.Model
	}
	// Build the streaming request.
	chatReq := &model.ChatCompletionRequest{
		Model:    modelName,
		Messages: messages,
		Stream:   true,
	}
	// Apply optional parameters.
	if req.Temperature > 0 {
		chatReq.Temperature = float32(req.Temperature)
	}
	if req.MaxTokens > 0 {
		chatReq.MaxTokens = req.MaxTokens
	}
	if req.TopP > 0 {
		chatReq.TopP = float32(req.TopP)
	}
	// Create a cancellable context and register it so StopGeneration can
	// cancel this session by request ID.
	ctx, cancel := context.WithCancel(context.Background())
	v.activeSessions.Store(req.RequestID, cancel)
	defer func() {
		v.activeSessions.Delete(req.RequestID)
		cancel()
	}()
	// Call the streaming API.
	streamResp, err := v.client.CreateChatCompletionStream(ctx, chatReq)
	if err != nil {
		global.GVA_LOG.Error("Stream chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
		return err
	}
	defer streamResp.Close()
	global.GVA_LOG.Info("Stream chat completion started", zap.String("requestID", req.RequestID))
	// Consume the stream.
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
			// Receive the next chunk.
			recv, err := streamResp.Recv()
			if err != nil {
				if errors.Is(err, io.EOF) {
					// The stream ended normally.
					eventChan <- LLMStreamEvent{
						Event: "done",
						Data: LLMResponse{
							ID:     req.RequestID,
							Object: "chat.completion.chunk",
						},
					}
					global.GVA_LOG.Info("Stream chat completion finished", zap.String("requestID", req.RequestID))
					return nil
				}
				global.GVA_LOG.Error("Stream receive error", zap.Error(err), zap.String("requestID", req.RequestID))
				eventChan <- LLMStreamEvent{
					Event: "error",
					Data: LLMResponse{
						ID: req.RequestID,
						Error: &LLMError{
							Code:    "stream_error",
							Message: err.Error(),
							Type:    "stream_error",
						},
					},
				}
				return err
			}
			// Convert the chunk into the local format.
			choices := make([]LLMChoice, len(recv.Choices))
			for i, choice := range recv.Choices {
				choices[i] = LLMChoice{
					Index: choice.Index,
					Delta: LLMMessage{
						Role:    choice.Delta.Role,
						Content: choice.Delta.Content,
					},
					FinishReason: string(choice.FinishReason),
				}
			}
			var usage *LLMUsage
			if recv.Usage != nil {
				usage = &LLMUsage{
					PromptTokens:     recv.Usage.PromptTokens,
					CompletionTokens: recv.Usage.CompletionTokens,
					TotalTokens:      recv.Usage.TotalTokens,
				}
			}
			// Emit a message event for this chunk.
			eventChan <- LLMStreamEvent{
				Event: "message",
				Data: LLMResponse{
					ID:      recv.ID,
					Object:  recv.Object,
					Created: recv.Created,
					Model:   recv.Model,
					Choices: choices,
					Usage:   usage,
				},
			}
		}
	}
}
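
// exampleStreamChatCompletion is a hypothetical caller, not part of the
// original file: a sketch of consuming the event channel. The producer runs
// in a goroutine and the caller drains events until the channel is closed;
// closing the channel is the caller's responsibility, since
// StreamChatCompletion only sends on it.
func exampleStreamChatCompletion() {
	llm := GetVolcengineLLMUtil()
	events := make(chan LLMStreamEvent)
	req := LLMRequest{
		RequestID: uuid.New().String(), // keep the ID so StopGeneration can target this stream
		Stream:    true,
		Messages: []LLMMessage{
			{Role: "user", Content: "Tell me a short story about a dog."},
		},
	}
	go func() {
		defer close(events)
		// Recv errors are also surfaced on the channel as "error" events,
		// so the returned error is only logged here.
		if err := llm.StreamChatCompletion(req, events); err != nil {
			fmt.Println("stream ended with error:", err)
		}
	}()
	for ev := range events {
		switch ev.Event {
		case "message":
			for _, c := range ev.Data.Choices {
				fmt.Print(c.Delta.Content)
			}
		case "done":
			fmt.Println()
		case "error":
			fmt.Println("stream error:", ev.Data.Error.Message)
		}
	}
}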
// StopGeneration cancels an active streaming session by request ID.
func (v *VolcengineLLMUtil) StopGeneration(requestID string) error {
	if cancel, ok := v.activeSessions.Load(requestID); ok {
		if cancelFunc, ok := cancel.(context.CancelFunc); ok {
			cancelFunc()
			v.activeSessions.Delete(requestID)
			global.GVA_LOG.Info("Generation stopped", zap.String("requestID", requestID))
			return nil
		}
	}
	return fmt.Errorf("session not found: %s", requestID)
}
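
// exampleStopGeneration is a hypothetical caller, not part of the original
// file: if the request ID used to start a stream is kept (for example, echoed
// to the client when the stream begins), any other goroutine can cancel that
// stream mid-flight. Cancellation is observed by the select on ctx.Done() in
// StreamChatCompletion's receive loop.
func exampleStopGeneration(requestID string) {
	if err := GetVolcengineLLMUtil().StopGeneration(requestID); err != nil {
		fmt.Println("no active session to stop:", err)
	}
}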
// GetActiveSessionsCount returns the number of active streaming sessions.
func (v *VolcengineLLMUtil) GetActiveSessionsCount() int {
	count := 0
	v.activeSessions.Range(func(key, value interface{}) bool {
		count++
		return true
	})
	return count
}