package utils
import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"

	"github.com/flipped-aurora/gin-vue-admin/server/global"
	"github.com/google/uuid"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"go.uber.org/zap"
)
// LLMMessage is a single chat message exchanged with the LLM.
type LLMMessage struct {
	Role    string `json:"role"`    // Role: system, user, assistant
	Content string `json:"content"` // Message text
}
// LLMRequest describes a chat-completion request.
type LLMRequest struct {
	Messages    []LLMMessage `json:"messages" binding:"required"` // Conversation message list
	Model       string       `json:"model"`                       // Model name; the configured default is used when empty
	Stream      bool         `json:"stream"`                      // Whether to stream the response; defaults to false
	Temperature float64      `json:"temperature"`                 // Sampling temperature controlling randomness, range 0-1
	MaxTokens   int          `json:"max_tokens"`                  // Maximum number of tokens to generate
	TopP        float64      `json:"top_p"`                       // Nucleus-sampling parameter
	RequestID   string       `json:"request_id,omitempty"`        // Request ID, used to manage streaming sessions
}
// LLMChoice is one candidate completion within a response.
type LLMChoice struct {
	Index        int        `json:"index"`                   // Choice index
	Message      LLMMessage `json:"message,omitempty"`       // Full response message (non-streaming)
	Delta        LLMMessage `json:"delta,omitempty"`         // Incremental message (streaming)
	FinishReason string     `json:"finish_reason,omitempty"` // Finish reason: stop, length, content_filter
}
// LLMUsage holds token-usage statistics for a completion.
type LLMUsage struct {
	PromptTokens     int `json:"prompt_tokens"`     // Input token count
	CompletionTokens int `json:"completion_tokens"` // Output token count
	TotalTokens      int `json:"total_tokens"`      // Total token count
}
// LLMError describes an API error returned to the caller.
type LLMError struct {
	Code    string `json:"code"`    // Error code
	Message string `json:"message"` // Error message
	Type    string `json:"type"`    // Error type
}
// LLMResponse is a chat-completion response (full or streamed chunk).
type LLMResponse struct {
	ID      string      `json:"id"`              // Response ID
	Object  string      `json:"object"`          // Object type: chat.completion or chat.completion.chunk
	Created int64       `json:"created"`         // Creation timestamp
	Model   string      `json:"model"`           // Model used
	Choices []LLMChoice `json:"choices"`         // Response choices
	Usage   *LLMUsage   `json:"usage,omitempty"` // Usage statistics (non-streaming responses)
	Error   *LLMError   `json:"error,omitempty"` // Error information, if any
}
// LLMStreamEvent is one event emitted on the streaming channel.
type LLMStreamEvent struct {
	Event string      `json:"event"` // Event type: message, error, done
	Data  LLMResponse `json:"data"`  // Event payload
}
// VolcengineLLMUtil wraps the Volcengine ARK runtime client and tracks
// active streaming sessions so they can be cancelled by request ID.
type VolcengineLLMUtil struct {
	client         *arkruntime.Client
	activeSessions sync.Map  // requestID -> context.CancelFunc
	once           sync.Once // guards one-time client initialization
	initErr        error     // error recorded by the initialization attempt
}
// Package-level singleton instance and its initialization guard.
var (
	volcengineLLMInstance *VolcengineLLMUtil
	volcengineLLMOnce     sync.Once
)
// GetVolcengineLLMUtil 获取火山引擎LLM工具单例实例
|
||
func GetVolcengineLLMUtil() *VolcengineLLMUtil {
|
||
volcengineLLMOnce.Do(func() {
|
||
volcengineLLMInstance = &VolcengineLLMUtil{}
|
||
})
|
||
return volcengineLLMInstance
|
||
}
|
||
|
||
// InitClient 初始化火山引擎客户端(单例模式)
|
||
func (v *VolcengineLLMUtil) InitClient() error {
|
||
v.once.Do(func() {
|
||
config := global.GVA_CONFIG.Volcengine
|
||
|
||
// 验证配置
|
||
if config.ApiKey == "" {
|
||
v.initErr = fmt.Errorf("volcengine api key is empty")
|
||
global.GVA_LOG.Error("Volcengine configuration error", zap.Error(v.initErr))
|
||
return
|
||
}
|
||
|
||
// 创建ARK Runtime客户端,使用API Key
|
||
v.client = arkruntime.NewClientWithApiKey(config.ApiKey)
|
||
|
||
global.GVA_LOG.Info("Volcengine LLM client initialized successfully")
|
||
})
|
||
return v.initErr
|
||
}
|
||
|
||
// ChatCompletion 非流式聊天完成
|
||
func (v *VolcengineLLMUtil) ChatCompletion(req LLMRequest) (*LLMResponse, error) {
|
||
if v.client == nil {
|
||
if err := v.InitClient(); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
// 生成请求ID
|
||
if req.RequestID == "" {
|
||
req.RequestID = uuid.New().String()
|
||
}
|
||
|
||
// 转换消息格式
|
||
messages := make([]*model.ChatCompletionMessage, len(req.Messages))
|
||
for i, msg := range req.Messages {
|
||
messages[i] = &model.ChatCompletionMessage{
|
||
Role: msg.Role,
|
||
Content: &model.ChatCompletionMessageContent{
|
||
StringValue: &msg.Content,
|
||
},
|
||
}
|
||
}
|
||
|
||
// 设置模型
|
||
modelName := req.Model
|
||
if modelName == "" {
|
||
modelName = global.GVA_CONFIG.Volcengine.Model
|
||
}
|
||
|
||
// 创建请求
|
||
chatReq := &model.ChatCompletionRequest{
|
||
Model: modelName,
|
||
Messages: messages,
|
||
Stream: false,
|
||
}
|
||
|
||
// 设置可选参数
|
||
if req.Temperature > 0 {
|
||
chatReq.Temperature = float32(req.Temperature)
|
||
}
|
||
if req.MaxTokens > 0 {
|
||
chatReq.MaxTokens = req.MaxTokens
|
||
}
|
||
if req.TopP > 0 {
|
||
chatReq.TopP = float32(req.TopP)
|
||
}
|
||
|
||
// 调用API
|
||
resp, err := v.client.CreateChatCompletion(context.Background(), chatReq)
|
||
if err != nil {
|
||
global.GVA_LOG.Error("Chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
|
||
return nil, err
|
||
}
|
||
|
||
// 转换响应格式
|
||
choices := make([]LLMChoice, len(resp.Choices))
|
||
for i, choice := range resp.Choices {
|
||
choices[i] = LLMChoice{
|
||
Index: choice.Index,
|
||
Message: LLMMessage{
|
||
Role: choice.Message.Role,
|
||
Content: *choice.Message.Content.StringValue,
|
||
},
|
||
FinishReason: string(choice.FinishReason),
|
||
}
|
||
}
|
||
|
||
usage := &LLMUsage{
|
||
PromptTokens: resp.Usage.PromptTokens,
|
||
CompletionTokens: resp.Usage.CompletionTokens,
|
||
TotalTokens: resp.Usage.TotalTokens,
|
||
}
|
||
|
||
chatResp := &LLMResponse{
|
||
ID: resp.ID,
|
||
Object: resp.Object,
|
||
Created: resp.Created,
|
||
Model: resp.Model,
|
||
Choices: choices,
|
||
Usage: usage,
|
||
}
|
||
|
||
global.GVA_LOG.Info("Chat completion successful", zap.String("requestID", req.RequestID))
|
||
return chatResp, nil
|
||
}
|
||
|
||
// StreamChatCompletion 流式聊天完成
|
||
func (v *VolcengineLLMUtil) StreamChatCompletion(req LLMRequest, eventChan chan<- LLMStreamEvent) error {
|
||
if v.client == nil {
|
||
if err := v.InitClient(); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 生成请求ID
|
||
if req.RequestID == "" {
|
||
req.RequestID = uuid.New().String()
|
||
}
|
||
|
||
// 转换消息格式
|
||
messages := make([]*model.ChatCompletionMessage, len(req.Messages))
|
||
for i, msg := range req.Messages {
|
||
messages[i] = &model.ChatCompletionMessage{
|
||
Role: msg.Role,
|
||
Content: &model.ChatCompletionMessageContent{
|
||
StringValue: &msg.Content,
|
||
},
|
||
}
|
||
}
|
||
|
||
// 设置模型
|
||
modelName := req.Model
|
||
if modelName == "" {
|
||
modelName = global.GVA_CONFIG.Volcengine.Model
|
||
}
|
||
|
||
// 创建流式请求
|
||
chatReq := &model.ChatCompletionRequest{
|
||
Model: modelName,
|
||
Messages: messages,
|
||
Stream: true,
|
||
}
|
||
|
||
// 设置可选参数
|
||
if req.Temperature > 0 {
|
||
chatReq.Temperature = float32(req.Temperature)
|
||
}
|
||
if req.MaxTokens > 0 {
|
||
chatReq.MaxTokens = req.MaxTokens
|
||
}
|
||
if req.TopP > 0 {
|
||
chatReq.TopP = float32(req.TopP)
|
||
}
|
||
|
||
// 创建上下文和取消函数
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
v.activeSessions.Store(req.RequestID, cancel)
|
||
defer func() {
|
||
v.activeSessions.Delete(req.RequestID)
|
||
cancel()
|
||
}()
|
||
|
||
// 调用流式API
|
||
stream_resp, err := v.client.CreateChatCompletionStream(ctx, chatReq)
|
||
if err != nil {
|
||
global.GVA_LOG.Error("Stream chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
|
||
return err
|
||
}
|
||
defer stream_resp.Close()
|
||
|
||
global.GVA_LOG.Info("Stream chat completion started", zap.String("requestID", req.RequestID))
|
||
|
||
// 处理流式响应
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return ctx.Err()
|
||
default:
|
||
// 接收流式数据
|
||
recv, err := stream_resp.Recv()
|
||
if err != nil {
|
||
if err.Error() == "EOF" {
|
||
// 流结束
|
||
eventChan <- LLMStreamEvent{
|
||
Event: "done",
|
||
Data: LLMResponse{
|
||
ID: req.RequestID,
|
||
Object: "chat.completion.chunk",
|
||
},
|
||
}
|
||
global.GVA_LOG.Info("Stream chat completion finished", zap.String("requestID", req.RequestID))
|
||
return nil
|
||
}
|
||
global.GVA_LOG.Error("Stream receive error", zap.Error(err), zap.String("requestID", req.RequestID))
|
||
eventChan <- LLMStreamEvent{
|
||
Event: "error",
|
||
Data: LLMResponse{
|
||
ID: req.RequestID,
|
||
Error: &LLMError{
|
||
Code: "stream_error",
|
||
Message: err.Error(),
|
||
Type: "stream_error",
|
||
},
|
||
},
|
||
}
|
||
return err
|
||
}
|
||
|
||
// 转换响应格式
|
||
choices := make([]LLMChoice, len(recv.Choices))
|
||
for i, choice := range recv.Choices {
|
||
choices[i] = LLMChoice{
|
||
Index: choice.Index,
|
||
Delta: LLMMessage{
|
||
Role: choice.Delta.Role,
|
||
Content: choice.Delta.Content,
|
||
},
|
||
FinishReason: string(choice.FinishReason),
|
||
}
|
||
}
|
||
|
||
var usage *LLMUsage
|
||
if recv.Usage != nil {
|
||
usage = &LLMUsage{
|
||
PromptTokens: recv.Usage.PromptTokens,
|
||
CompletionTokens: recv.Usage.CompletionTokens,
|
||
TotalTokens: recv.Usage.TotalTokens,
|
||
}
|
||
}
|
||
|
||
// 发送消息事件
|
||
eventChan <- LLMStreamEvent{
|
||
Event: "message",
|
||
Data: LLMResponse{
|
||
ID: recv.ID,
|
||
Object: recv.Object,
|
||
Created: recv.Created,
|
||
Model: recv.Model,
|
||
Choices: choices,
|
||
Usage: usage,
|
||
},
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// StopGeneration 停止生成
|
||
func (v *VolcengineLLMUtil) StopGeneration(requestID string) error {
|
||
if cancel, ok := v.activeSessions.Load(requestID); ok {
|
||
if cancelFunc, ok := cancel.(context.CancelFunc); ok {
|
||
cancelFunc()
|
||
v.activeSessions.Delete(requestID)
|
||
global.GVA_LOG.Info("Generation stopped", zap.String("requestID", requestID))
|
||
return nil
|
||
}
|
||
}
|
||
return fmt.Errorf("session not found: %s", requestID)
|
||
}
|
||
|
||
// GetActiveSessionsCount 获取活跃会话数量
|
||
func (v *VolcengineLLMUtil) GetActiveSessionsCount() int {
|
||
count := 0
|
||
v.activeSessions.Range(func(key, value interface{}) bool {
|
||
count++
|
||
return true
|
||
})
|
||
return count
|
||
}
|