// File: pet-ai/server/plugin/volcengine/api/llm_api.go
// HTTP handlers for the Volcengine LLM plugin (chat completion, stop, session count).
package api
import (
"fmt"
"io"
"time"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/request"
volcengineResponse "github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/response"
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/service"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// LLMApi groups the Volcengine LLM HTTP handlers (chat completion,
// stop generation, active-session count). It carries no state.
type LLMApi struct{}
// ChatCompletion
// @Tags Volcengine
// @Summary LLM聊天完成
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.ChatRequest true "聊天请求参数"
// @Success 200 {object} response.Response{data=volcengineResponse.ChatResponse} "聊天响应"
// @Router /volcengine/llm/chat [post]
// ChatCompletion binds the chat request, assigns a request ID when the
// caller omitted one, and dispatches to either the SSE streaming handler
// or the plain-JSON handler depending on the Stream flag.
func (l *LLMApi) ChatCompletion(c *gin.Context) {
	var chatReq request.ChatRequest
	if err := c.ShouldBindJSON(&chatReq); err != nil {
		response.FailWithMessage(err.Error(), c)
		return
	}
	// Ensure every request carries an ID for log correlation.
	if chatReq.RequestID == "" {
		chatReq.RequestID = uuid.New().String()
	}
	// Stream flag selects SSE streaming vs. a single JSON response.
	if chatReq.Stream {
		l.handleStreamResponse(c, chatReq)
		return
	}
	l.handleNormalResponse(c, chatReq)
}
// handleNormalResponse runs the chat completion synchronously and writes
// the result as a single JSON response.
func (l *LLMApi) handleNormalResponse(c *gin.Context, chatReq request.ChatRequest) {
	// Shared log field: the request ID appears on both outcomes.
	reqID := zap.String("requestID", chatReq.RequestID)
	resp, err := service.LLMServiceApp.ChatCompletion(chatReq)
	if err != nil {
		global.GVA_LOG.Error("聊天完成失败!", zap.Error(err), reqID)
		response.FailWithMessage("聊天完成失败: "+err.Error(), c)
		return
	}
	global.GVA_LOG.Info("聊天完成成功", reqID)
	response.OkWithDetailed(resp, "聊天完成成功", c)
}
// handleStreamResponse streams chat-completion results to the client as
// Server-Sent Events (SSE).
//
// A producer goroutine runs the streaming service call and feeds events
// into eventChan; the gin Stream callback below consumes them. Only the
// producer closes the channel: closing it from this handler (the old
// `defer close(eventChan)` here) could panic with "send on closed
// channel" when the consumer returned early — on the timeout path or
// after an "error"/"done" event — while the goroutine was still sending.
// NOTE(review): this assumes service.LLMServiceApp.StreamChatCompletion
// sends into eventChan but does not close it — confirm in the service.
func (l *LLMApi) handleStreamResponse(c *gin.Context, chatReq request.ChatRequest) {
	// Standard SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Allow-Headers", "Cache-Control")
	// Buffered so a slow client does not immediately stall the producer.
	// If the consumer exits early (timeout/disconnect) the buffer also
	// lets the producer finish without blocking in most cases.
	eventChan := make(chan volcengineResponse.StreamEvent, 100)
	// Producer: run the streaming service call; close the channel when
	// done so the consumer loop terminates via the !ok branch.
	go func() {
		defer close(eventChan)
		if err := service.LLMServiceApp.StreamChatCompletion(chatReq, eventChan); err != nil {
			global.GVA_LOG.Error("流式聊天失败!", zap.Error(err), zap.String("requestID", chatReq.RequestID))
			eventChan <- volcengineResponse.StreamEvent{
				Event: "error",
				Data: volcengineResponse.ChatResponse{
					ID: chatReq.RequestID,
					Error: &volcengineResponse.APIError{
						Code:    "stream_error",
						Message: err.Error(),
						Type:    "internal_error",
					},
				},
			}
		}
	}()
	// Consumer: forward events to the client; returning false ends the
	// stream. gin flushes after each callback invocation.
	c.Stream(func(w io.Writer) bool {
		select {
		case event, ok := <-eventChan:
			if !ok {
				// Producer finished and closed the channel.
				return false
			}
			switch event.Event {
			case "message":
				c.SSEvent("message", event.Data)
			case "error":
				c.SSEvent("error", event.Data)
				return false
			case "done":
				c.SSEvent("done", event.Data)
				return false
			}
			return true
		case <-time.After(30 * time.Second):
			// Per-event idle timeout: no event arrived for 30s.
			c.SSEvent("error", volcengineResponse.ChatResponse{
				ID: chatReq.RequestID,
				Error: &volcengineResponse.APIError{
					Code:    "timeout",
					Message: "Stream timeout",
					Type:    "timeout_error",
				},
			})
			return false
		}
	})
}
// StopGeneration
// @Tags Volcengine
// @Summary 停止LLM生成
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.StopRequest true "停止请求参数"
// @Success 200 {object} response.Response{data=volcengineResponse.StopResponse} "停止响应"
// @Router /volcengine/llm/stop [post]
// StopGeneration asks the service to cancel an in-flight generation and
// reports success or failure based on the service's result.
func (l *LLMApi) StopGeneration(c *gin.Context) {
	var stopReq request.StopRequest
	if err := c.ShouldBindJSON(&stopReq); err != nil {
		response.FailWithMessage(err.Error(), c)
		return
	}
	resp, err := service.LLMServiceApp.StopGeneration(stopReq)
	if err != nil {
		global.GVA_LOG.Error("停止生成失败!", zap.Error(err), zap.String("requestID", stopReq.RequestID))
		response.FailWithMessage("停止生成失败: "+err.Error(), c)
		return
	}
	// The service can return a non-error "not stopped" outcome;
	// surface it as a failure with the service-provided message.
	if !resp.Success {
		global.GVA_LOG.Warn("停止生成失败", zap.String("requestID", stopReq.RequestID), zap.String("reason", resp.Message))
		response.FailWithDetailed(resp, resp.Message, c)
		return
	}
	global.GVA_LOG.Info("停止生成成功", zap.String("requestID", stopReq.RequestID))
	response.OkWithDetailed(resp, "停止生成成功", c)
}
// GetActiveSessionsCount
// @Tags Volcengine
// @Summary 获取活跃会话数量
// @Security ApiKeyAuth
// @Produce application/json
// @Success 200 {object} response.Response{data=int} "活跃会话数量"
// @Router /volcengine/llm/sessions [get]
// GetActiveSessionsCount returns the service's current active-session
// count as {"count": n} with a human-readable message.
func (l *LLMApi) GetActiveSessionsCount(c *gin.Context) {
	count := service.LLMServiceApp.GetActiveSessionsCount()
	msg := fmt.Sprintf("当前活跃会话数量: %d", count)
	response.OkWithDetailed(gin.H{"count": count}, msg, c)
}