Merge remote-tracking branch 'yvan/main'
# Conflicts: # server/go.mod # server/go.sum
This commit is contained in:
commit
73063af861
|
|
@ -1,6 +1,9 @@
|
|||
package pet
|
||||
|
||||
import "github.com/flipped-aurora/gin-vue-admin/server/service"
|
||||
import (
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/api/v1/pet/user"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/service"
|
||||
)
|
||||
|
||||
type ApiGroup struct {
|
||||
PetAdoptionApplicationsApi
|
||||
|
|
@ -13,6 +16,7 @@ type ApiGroup struct {
|
|||
PetFamilyPetsApi
|
||||
PetPetsApi
|
||||
PetRecordsApi
|
||||
PetUserApiGroup user.ApiGroup
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package user
|
|||
import "github.com/flipped-aurora/gin-vue-admin/server/service"
|
||||
|
||||
type ApiGroup struct {
|
||||
PetAssistantApi
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
|||
|
|
@ -0,0 +1,298 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
|
||||
petRequest "github.com/flipped-aurora/gin-vue-admin/server/model/pet/request"
|
||||
petResponse "github.com/flipped-aurora/gin-vue-admin/server/model/pet/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/service"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type PetAssistantApi struct{}
|
||||
|
||||
var petChatService = service.ServiceGroupApp.PetServiceGroup.PetChatService
|
||||
|
||||
// AskPetAssistant 向宠物助手提问(非流式)
|
||||
// @Tags PetAssistant
|
||||
// @Summary 向宠物助手提问
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body petRequest.ChatRequest true "宠物助手提问请求"
|
||||
// @Success 200 {object} response.Response{data=petResponse.ChatResponse,msg=string} "提问成功"
|
||||
// @Router /api/v1/pet/user/assistant/ask [post]
|
||||
func (p *PetAssistantApi) AskPetAssistant(ctx *gin.Context) {
|
||||
// 创建业务用Context
|
||||
businessCtx := ctx.Request.Context()
|
||||
|
||||
// 获取用户ID
|
||||
userId := utils.GetAppUserID(ctx)
|
||||
if userId == 0 {
|
||||
global.GVA_LOG.Error("获取用户ID失败")
|
||||
response.FailWithMessage("用户认证失败", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 绑定请求参数
|
||||
var req petRequest.ChatRequest
|
||||
if err := ctx.ShouldBindJSON(&req); err != nil {
|
||||
global.GVA_LOG.Error("参数绑定失败", zap.Error(err))
|
||||
response.FailWithMessage("参数错误: "+err.Error(), ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 参数验证
|
||||
if req.Message == "" {
|
||||
response.FailWithMessage("消息内容不能为空", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认参数
|
||||
if req.Temperature <= 0 {
|
||||
req.Temperature = 0.7
|
||||
}
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = 1000
|
||||
}
|
||||
|
||||
// 调用服务层
|
||||
resp, err := petChatService.SendMessage(businessCtx, userId, req)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("发送消息失败", zap.Error(err), zap.Uint("userId", userId))
|
||||
response.FailWithMessage("发送消息失败: "+err.Error(), ctx)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithDetailed(resp, "发送成功", ctx)
|
||||
}
|
||||
|
||||
// StreamAskPetAssistant 向宠物助手流式提问
|
||||
// @Tags PetAssistant
|
||||
// @Summary 向宠物助手流式提问接口
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept application/json
|
||||
// @Produce text/event-stream
|
||||
// @Param data body petRequest.ChatRequest true "宠物助手流式提问请求"
|
||||
// @Success 200 {string} string "流式响应"
|
||||
// @Router /api/v1/pet/user/assistant/stream-ask [post]
|
||||
func (p *PetAssistantApi) StreamAskPetAssistant(ctx *gin.Context) {
|
||||
// 创建业务用Context
|
||||
businessCtx := ctx.Request.Context()
|
||||
|
||||
// 获取用户ID
|
||||
userId := utils.GetAppUserID(ctx)
|
||||
if userId == 0 {
|
||||
global.GVA_LOG.Error("获取用户ID失败")
|
||||
response.FailWithMessage("用户认证失败", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 绑定请求参数
|
||||
var req petRequest.ChatRequest
|
||||
if err := ctx.ShouldBindJSON(&req); err != nil {
|
||||
global.GVA_LOG.Error("参数绑定失败", zap.Error(err))
|
||||
response.FailWithMessage("参数错误: "+err.Error(), ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 参数验证
|
||||
if req.Message == "" {
|
||||
response.FailWithMessage("消息内容不能为空", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 强制设置为流式响应
|
||||
req.Stream = true
|
||||
|
||||
// 设置默认参数
|
||||
if req.Temperature <= 0 {
|
||||
req.Temperature = 0.7
|
||||
}
|
||||
if req.MaxTokens <= 0 {
|
||||
req.MaxTokens = 1000
|
||||
}
|
||||
|
||||
// 设置SSE响应头
|
||||
ctx.Header("Content-Type", "text/event-stream")
|
||||
ctx.Header("Cache-Control", "no-cache")
|
||||
ctx.Header("Connection", "keep-alive")
|
||||
ctx.Header("Access-Control-Allow-Origin", "*")
|
||||
ctx.Header("Access-Control-Allow-Headers", "Cache-Control")
|
||||
|
||||
// 创建事件通道
|
||||
eventChan := make(chan petResponse.StreamEvent, 100)
|
||||
defer close(eventChan)
|
||||
|
||||
// 启动流式聊天
|
||||
go func() {
|
||||
if err := petChatService.StreamChat(businessCtx, userId, req, eventChan); err != nil {
|
||||
global.GVA_LOG.Error("流式聊天失败", zap.Error(err), zap.Uint("userId", userId))
|
||||
eventChan <- petResponse.StreamEvent{
|
||||
Event: "error",
|
||||
Data: map[string]interface{}{
|
||||
"error": "流式聊天失败: " + err.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 发送流式数据
|
||||
ctx.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case event, ok := <-eventChan:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch event.Event {
|
||||
case "message":
|
||||
// 发送消息事件
|
||||
ctx.SSEvent("message", event.Data)
|
||||
case "error":
|
||||
// 发送错误事件
|
||||
ctx.SSEvent("error", event.Data)
|
||||
return false
|
||||
case "done":
|
||||
// 发送完成事件
|
||||
ctx.SSEvent("done", event.Data)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case <-time.After(30 * time.Second):
|
||||
// 超时处理
|
||||
ctx.SSEvent("error", map[string]interface{}{
|
||||
"error": "流式响应超时",
|
||||
})
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GetAssistantHistory 获取宠物助手对话历史
|
||||
// @Tags PetAssistant
|
||||
// @Summary 获取宠物助手对话历史记录
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept application/json
|
||||
// @Produce application/json
|
||||
// @Param sessionId query string false "会话ID"
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param pageSize query int false "每页数量" default(20)
|
||||
// @Success 200 {object} response.Response{data=[]pet.PetAiAssistantConversations,msg=string} "获取成功"
|
||||
// @Router /api/v1/pet/user/assistant/history [get]
|
||||
func (p *PetAssistantApi) GetAssistantHistory(ctx *gin.Context) {
|
||||
// 创建业务用Context
|
||||
businessCtx := ctx.Request.Context()
|
||||
|
||||
// 获取用户ID
|
||||
userId := utils.GetAppUserID(ctx)
|
||||
if userId == 0 {
|
||||
global.GVA_LOG.Error("获取用户ID失败")
|
||||
response.FailWithMessage("用户认证失败", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取查询参数
|
||||
sessionId := ctx.Query("sessionId")
|
||||
pageStr := ctx.DefaultQuery("page", "1")
|
||||
pageSizeStr := ctx.DefaultQuery("pageSize", "20")
|
||||
|
||||
page, err := strconv.Atoi(pageStr)
|
||||
if err != nil || page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
pageSize, err := strconv.Atoi(pageSizeStr)
|
||||
if err != nil || pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
// 计算limit
|
||||
limit := pageSize
|
||||
|
||||
// 调用服务层
|
||||
conversations, err := petChatService.GetChatHistory(businessCtx, userId, sessionId, limit)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取聊天历史失败", zap.Error(err), zap.Uint("userId", userId))
|
||||
response.FailWithMessage("获取聊天历史失败: "+err.Error(), ctx)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithDetailed(conversations, "获取成功", ctx)
|
||||
}
|
||||
|
||||
// ClearAssistantHistory 清空宠物助手对话历史
|
||||
// @Tags PetAssistant
|
||||
// @Summary 清空宠物助手对话历史记录
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept application/json
|
||||
// @Produce application/json
|
||||
// @Param sessionId query string false "会话ID,不传则清空所有会话"
|
||||
// @Success 200 {object} response.Response{msg=string} "清空成功"
|
||||
// @Router /api/v1/pet/user/assistant/clear-history [delete]
|
||||
func (p *PetAssistantApi) ClearAssistantHistory(ctx *gin.Context) {
|
||||
// 创建业务用Context
|
||||
businessCtx := ctx.Request.Context()
|
||||
|
||||
// 获取用户ID
|
||||
userId := utils.GetAppUserID(ctx)
|
||||
if userId == 0 {
|
||||
global.GVA_LOG.Error("获取用户ID失败")
|
||||
response.FailWithMessage("用户认证失败", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取会话ID(可选)
|
||||
sessionId := ctx.Query("sessionId")
|
||||
|
||||
// 调用服务层
|
||||
err := petChatService.ClearChatHistory(businessCtx, userId, sessionId)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("清空聊天历史失败", zap.Error(err), zap.Uint("userId", userId))
|
||||
response.FailWithMessage("清空聊天历史失败: "+err.Error(), ctx)
|
||||
return
|
||||
}
|
||||
|
||||
if sessionId != "" {
|
||||
response.OkWithMessage("指定会话历史清空成功", ctx)
|
||||
} else {
|
||||
response.OkWithMessage("所有聊天历史清空成功", ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAssistantSessions 获取宠物助手会话列表
|
||||
// @Tags PetAssistant
|
||||
// @Summary 获取用户的宠物助手会话列表
|
||||
// @Security ApiKeyAuth
|
||||
// @Accept application/json
|
||||
// @Produce application/json
|
||||
// @Success 200 {object} response.Response{data=[]map[string]interface{},msg=string} "获取成功"
|
||||
// @Router /api/v1/pet/user/assistant/sessions [get]
|
||||
func (p *PetAssistantApi) GetAssistantSessions(ctx *gin.Context) {
|
||||
// 创建业务用Context
|
||||
businessCtx := ctx.Request.Context()
|
||||
|
||||
// 获取用户ID
|
||||
userId := utils.GetAppUserID(ctx)
|
||||
if userId == 0 {
|
||||
global.GVA_LOG.Error("获取用户ID失败")
|
||||
response.FailWithMessage("用户认证失败", ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务层
|
||||
sessions, err := petChatService.GetChatSessions(businessCtx, userId)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取会话列表失败", zap.Error(err), zap.Uint("userId", userId))
|
||||
response.FailWithMessage("获取会话列表失败: "+err.Error(), ctx)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithDetailed(sessions, "获取成功", ctx)
|
||||
}
|
||||
|
|
@ -237,11 +237,8 @@ tencent-cos:
|
|||
base-url: https://gin.vue.admin
|
||||
path-prefix: github.com/flipped-aurora/gin-vue-admin/server
|
||||
volcengine:
|
||||
access-key: your-access-key
|
||||
secret-key: your-secret-key
|
||||
region: cn-beijing
|
||||
endpoint: https://ark.cn-beijing.volces.com
|
||||
default-model: ep-xxx
|
||||
api-key: 7562d83f-fc5c-4229-9aed-e8d9242c8683
|
||||
model: ep-20250909151934-m777l
|
||||
wechat:
|
||||
mini-app-id: wx0f5dc17ba3f9fe31
|
||||
mini-app-secret: 5e700810a6f56717e28af76dbb63983d
|
||||
|
|
|
|||
|
|
@ -1,9 +1,6 @@
|
|||
package config
|
||||
|
||||
type Volcengine struct {
|
||||
AccessKey string `mapstructure:"access-key" json:"access-key" yaml:"access-key"` // 火山引擎访问密钥ID
|
||||
SecretKey string `mapstructure:"secret-key" json:"secret-key" yaml:"secret-key"` // 火山引擎访问密钥Secret
|
||||
Region string `mapstructure:"region" json:"region" yaml:"region"` // 区域,如:cn-beijing
|
||||
Endpoint string `mapstructure:"endpoint" json:"endpoint" yaml:"endpoint"` // 服务端点,如:https://ark.cn-beijing.volces.com
|
||||
DefaultModel string `mapstructure:"default-model" json:"default-model" yaml:"default-model"` // 默认模型ID,如:ep-xxx
|
||||
ApiKey string `mapstructure:"api-key" json:"api-key" yaml:"api-key"` // 火山引擎API密钥
|
||||
Model string `mapstructure:"model" json:"model" yaml:"model"` // 模型ID,如:ep-xxx
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,12 @@
|
|||
module github.com/flipped-aurora/gin-vue-admin/server
|
||||
|
||||
go 1.23
|
||||
|
||||
toolchain go1.23.9
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/aliyun/aliyun-oss-go-sdk v3.0.2+incompatible
|
||||
github.com/aws/aws-sdk-go v1.55.6
|
||||
github.com/casbin/casbin/v2 v2.103.0
|
||||
github.com/casbin/gorm-adapter/v3 v3.32.0
|
||||
github.com/dzwvip/gorm-oracle v0.1.2
|
||||
github.com/fsnotify/fsnotify v1.8.0
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
|
|
@ -34,13 +31,13 @@ require (
|
|||
github.com/silenceper/wechat/v2 v2.1.9
|
||||
github.com/songzhibin97/gkit v1.2.13
|
||||
github.com/spf13/viper v1.19.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.60
|
||||
github.com/unrolled/secure v1.17.0
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.32
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.31
|
||||
github.com/xuri/excelize/v2 v2.9.0
|
||||
go.mongodb.org/mongo-driver v1.17.2
|
||||
go.uber.org/automaxprocs v1.6.0
|
||||
|
|
@ -78,7 +75,6 @@ require (
|
|||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/emirpasic/gods v1.12.0 // indirect
|
||||
github.com/fatih/structs v1.1.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/gammazero/toposort v0.1.1 // indirect
|
||||
|
|
@ -113,6 +109,7 @@ require (
|
|||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/kirklin/go-swd v0.0.3 // indirect
|
||||
github.com/klauspost/compress v1.17.11 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
github.com/klauspost/pgzip v1.2.6 // indirect
|
||||
|
|
@ -144,7 +141,6 @@ require (
|
|||
github.com/sagikazarmark/locafero v0.7.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/sijms/go-ora/v2 v2.7.17 // indirect
|
||||
github.com/sirupsen/logrus v1.9.0 // indirect
|
||||
github.com/sorairolake/lzip-go v0.3.5 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
|
|
@ -153,7 +149,6 @@ require (
|
|||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/therootcompany/xz v1.0.1 // indirect
|
||||
github.com/thoas/go-funk v0.7.0 // indirect
|
||||
github.com/tidwall/gjson v1.14.1 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
|
|
|
|||
|
|
@ -115,10 +115,6 @@ github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707/go.mod h1:qssHWj6
|
|||
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/dzwvip/gorm-oracle v0.1.2 h1:811aFDY7oDfKWHc0Z0lHdXzzr89EmKBSwc/jLJ8GU5g=
|
||||
github.com/dzwvip/gorm-oracle v0.1.2/go.mod h1:TbF7idnO9UgGpJ0qJpDZby1/wGquzP5GYof88ScBITE=
|
||||
github.com/emirpasic/gods v1.12.0 h1:QAUIPSaCu4G+POclxeqb3F+WPpdKqFGlw36+yOzGlrg=
|
||||
github.com/emirpasic/gods v1.12.0/go.mod h1:YfzfFFoVP/catgzJb4IKIqXjX78Ha8FMSDh3ymbK86o=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
|
||||
|
|
@ -295,7 +291,6 @@ github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJk
|
|||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
|
||||
|
|
@ -310,6 +305,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
|
|||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
|
||||
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
|
||||
github.com/kirklin/go-swd v0.0.3 h1:NvtV4jV9mm/1gobGYp7UE6bgBhCdJfBckO+ePzFffbo=
|
||||
github.com/kirklin/go-swd v0.0.3/go.mod h1:bnU1Fz3Uil9T1mRyiwjyeV39567+iDFaxj7XRQUmK2s=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
|
||||
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
||||
|
|
@ -460,8 +457,6 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt
|
|||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||
github.com/sijms/go-ora/v2 v2.7.17 h1:M/pYIqjaMUeBxyzOWp2oj4ntF6fHSBloJWGNH9vbmsU=
|
||||
github.com/sijms/go-ora/v2 v2.7.17/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk=
|
||||
github.com/silenceper/wechat/v2 v2.1.9 h1:wc092gUkGbbBRTdzPxROhQhOH5iE98stnfzKA73mnTo=
|
||||
github.com/silenceper/wechat/v2 v2.1.9/go.mod h1:7Iu3EhQYVtDUJAj+ZVRy8yom75ga7aDWv8RurLkVm0s=
|
||||
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
|
||||
|
|
@ -496,8 +491,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE=
|
||||
|
|
@ -512,8 +508,6 @@ github.com/tencentyun/cos-go-sdk-v5 v0.7.60 h1:/e/tmvRmfKexr/QQIBzWhOkZWsmY3EK72
|
|||
github.com/tencentyun/cos-go-sdk-v5 v0.7.60/go.mod h1:8+hG+mQMuRP/OIS9d83syAvXvrMj9HhkND6Q1fLghw0=
|
||||
github.com/therootcompany/xz v1.0.1 h1:CmOtsn1CbtmyYiusbfmhmkpAAETj0wBIH6kCYaX+xzw=
|
||||
github.com/therootcompany/xz v1.0.1/go.mod h1:3K3UH1yCKgBneZYhuQUvJ9HPD19UEXEI0BWbMn8qNMY=
|
||||
github.com/thoas/go-funk v0.7.0 h1:GmirKrs6j6zJbhJIficOsz2aAI7700KsU/5YrdHRM1Y=
|
||||
github.com/thoas/go-funk v0.7.0/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
|
||||
github.com/tidwall/gjson v1.14.1 h1:iymTbGkQBhveq21bEvAQ81I0LEBork8BFe1CUZXdyuo=
|
||||
github.com/tidwall/gjson v1.14.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
|
|
@ -535,8 +529,10 @@ github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbW
|
|||
github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.32 h1:ixNtkXFyf1hy1LL0np17wP6CaN9T42GAhzv8hZiKFPc=
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.32/go.mod h1:EyKoi6t6eZxoPNGr2GdFCZti2Skd7MO3eUzx7TtSvNo=
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.25 h1:wwR2DTJGw2sOZ1wTWaQLn03PGO0O+motGvsoVvAp5Zk=
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.25/go.mod h1:EyKoi6t6eZxoPNGr2GdFCZti2Skd7MO3eUzx7TtSvNo=
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.31 h1:qrNd/fu+ZWzH93EBSPBSntGwgwo7cHITxxv1IdvLxls=
|
||||
github.com/volcengine/volcengine-go-sdk v1.1.31/go.mod h1:EyKoi6t6eZxoPNGr2GdFCZti2Skd7MO3eUzx7TtSvNo=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
||||
|
|
@ -898,7 +894,6 @@ gorm.io/driver/sqlserver v1.5.4 h1:xA+Y1KDNspv79q43bPyjDMUgHoYHLhXYmdFcYPobg8g=
|
|||
gorm.io/driver/sqlserver v1.5.4/go.mod h1:+frZ/qYmuna11zHPlh5oc2O6ZA/lS88Keb0XSH1Zh/g=
|
||||
gorm.io/gen v0.3.26 h1:sFf1j7vNStimPRRAtH4zz5NiHM+1dr6eA9aaRdplyhY=
|
||||
gorm.io/gen v0.3.26/go.mod h1:a5lq5y3w4g5LMxBcw0wnO6tYUCdNutWODq5LrIt75LE=
|
||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
||||
gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/plugin/email"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils/plugin"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
|
@ -34,7 +33,6 @@ func bizPluginV1(group ...*gin.RouterGroup) {
|
|||
global.GVA_CONFIG.Email.IsSSL,
|
||||
global.GVA_CONFIG.Email.IsLoginAuth,
|
||||
),
|
||||
volcengine.CreateVolcenginePlug(),
|
||||
)
|
||||
holder(public, private)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ func Routers() *gin.Engine {
|
|||
InstallPlugin(PrivateGroup, PublicGroup, Router)
|
||||
|
||||
// 注册业务路由
|
||||
initBizRouter(PrivateGroup, PublicGroup)
|
||||
initBizRouter(PrivateGroup, PublicGroup, UserGroup)
|
||||
|
||||
global.GVA_ROUTERS = Router.Routes()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ func holder(routers ...*gin.RouterGroup) {
|
|||
func initBizRouter(routers ...*gin.RouterGroup) {
|
||||
privateGroup := routers[0]
|
||||
publicGroup := routers[1]
|
||||
userGroup := routers[2]
|
||||
holder(publicGroup, privateGroup)
|
||||
{
|
||||
petRouter := router.RouterGroupApp.Pet
|
||||
|
|
@ -25,5 +26,10 @@ func initBizRouter(routers ...*gin.RouterGroup) {
|
|||
petRouter.InitPetFamilyPetsRouter(privateGroup, publicGroup)
|
||||
petRouter.InitPetPetsRouter(privateGroup, publicGroup) // 占位方法,保证文件可以正确加载,避免go空变量检测报错,请勿删除。
|
||||
petRouter.InitPetRecordsRouter(privateGroup, publicGroup)
|
||||
|
||||
// 用户相关路由(需要UserJWTAuth认证)
|
||||
if userGroup != nil {
|
||||
petRouter.InitPetAssistantRouter(userGroup, publicGroup)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,10 @@ type PetAiAssistantConversations struct {
|
|||
UserId *int `json:"userId" form:"userId" gorm:"comment:用户ID;column:user_id;size:20;" binding:"required"` //用户ID
|
||||
MessageContent *string `json:"messageContent" form:"messageContent" gorm:"comment:消息内容;column:message_content;" binding:"required"` //消息内容
|
||||
Role *string `json:"role" form:"role" gorm:"comment:发送者角色:user-用户,assistant-AI助手;column:role;size:20;" binding:"required"` //发送者角色:user-用户,assistant-AI助手
|
||||
SessionId *string `json:"sessionId" form:"sessionId" gorm:"comment:会话ID;column:session_id;size:64;"` //会话ID
|
||||
IsSensitive *bool `json:"isSensitive" form:"isSensitive" gorm:"comment:是否包含敏感词;column:is_sensitive;default:false;"` //是否包含敏感词
|
||||
TokenCount *int `json:"tokenCount" form:"tokenCount" gorm:"comment:Token消耗数量;column:token_count;default:0;"` //Token消耗数量
|
||||
ResponseTime *int `json:"responseTime" form:"responseTime" gorm:"comment:响应时间(ms);column:response_time;default:0;"` //响应时间(ms)
|
||||
}
|
||||
|
||||
// TableName petAiAssistantConversations表 PetAiAssistantConversations自定义表名 pet_ai_assistant_conversations
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
package request
|
||||
|
||||
import (
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/request"
|
||||
)
|
||||
|
||||
// ChatRequest 宠物助手聊天请求结构体
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required" validate:"required,min=1,max=2000"` // 用户消息内容,必填,1-2000字符
|
||||
SessionId string `json:"sessionId" validate:"omitempty,uuid4"` // 会话ID,可选,UUID格式
|
||||
Stream bool `json:"stream"` // 是否流式响应,默认false
|
||||
Temperature float64 `json:"temperature" validate:"omitempty,min=0,max=2"` // 温度参数,控制随机性,范围0-2
|
||||
MaxTokens int `json:"maxTokens" validate:"omitempty,min=1,max=4000"` // 最大生成token数,范围1-4000
|
||||
Model string `json:"model" validate:"omitempty,min=1,max=100"` // 模型名称,可选
|
||||
TopP float64 `json:"topP" validate:"omitempty,min=0,max=1"` // 核采样参数,范围0-1
|
||||
}
|
||||
|
||||
// ChatHistoryRequest 获取聊天历史请求结构体
|
||||
type ChatHistoryRequest struct {
|
||||
SessionId string `json:"sessionId" form:"sessionId" validate:"omitempty,uuid4"` // 会话ID,可选,UUID格式
|
||||
request.PageInfo // 分页信息
|
||||
}
|
||||
|
||||
// ClearHistoryRequest 清空聊天历史请求结构体
|
||||
type ClearHistoryRequest struct {
|
||||
SessionId string `json:"sessionId" form:"sessionId" validate:"omitempty,uuid4"` // 会话ID,可选,不传则清空所有会话
|
||||
}
|
||||
|
||||
// SessionsRequest 获取会话列表请求结构体
|
||||
type SessionsRequest struct {
|
||||
request.PageInfo // 分页信息
|
||||
}
|
||||
|
||||
// StopGenerationRequest 停止生成请求结构体
|
||||
type StopGenerationRequest struct {
|
||||
RequestId string `json:"requestId" binding:"required" validate:"required,uuid4"` // 请求ID,必填,UUID格式
|
||||
}
|
||||
|
||||
// RegenerateRequest 重新生成回复请求结构体
|
||||
type RegenerateRequest struct {
|
||||
SessionId string `json:"sessionId" binding:"required" validate:"required,uuid4"` // 会话ID,必填
|
||||
MessageId uint `json:"messageId" binding:"required" validate:"required,min=1"` // 要重新生成的消息ID
|
||||
Temperature float64 `json:"temperature" validate:"omitempty,min=0,max=2"` // 温度参数
|
||||
MaxTokens int `json:"maxTokens" validate:"omitempty,min=1,max=4000"` // 最大token数
|
||||
}
|
||||
|
||||
// FeedbackRequest 用户反馈请求结构体
|
||||
type FeedbackRequest struct {
|
||||
MessageId uint `json:"messageId" binding:"required" validate:"required,min=1"` // 消息ID,必填
|
||||
FeedbackType string `json:"feedbackType" binding:"required" validate:"required,oneof=like dislike"` // 反馈类型:like/dislike
|
||||
Comment string `json:"comment" validate:"omitempty,max=500"` // 反馈评论,可选,最多500字符
|
||||
}
|
||||
|
||||
// ExportHistoryRequest 导出聊天历史请求结构体
|
||||
type ExportHistoryRequest struct {
|
||||
SessionId string `json:"sessionId" form:"sessionId" validate:"omitempty,uuid4"` // 会话ID,可选
|
||||
Format string `json:"format" form:"format" validate:"omitempty,oneof=json txt markdown"` // 导出格式:json/txt/markdown
|
||||
StartTime string `json:"startTime" form:"startTime" validate:"omitempty,datetime=2006-01-02"` // 开始时间,可选
|
||||
EndTime string `json:"endTime" form:"endTime" validate:"omitempty,datetime=2006-01-02"` // 结束时间,可选
|
||||
}
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
package response
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/pet"
|
||||
)
|
||||
|
||||
// ChatResponse 宠物助手聊天响应结构体
|
||||
type ChatResponse struct {
|
||||
ID uint `json:"id"` // 消息ID
|
||||
Message string `json:"message"` // AI回复消息内容
|
||||
SessionId string `json:"sessionId"` // 会话ID
|
||||
IsSensitive bool `json:"isSensitive"` // 是否包含敏感词
|
||||
TokenCount int `json:"tokenCount"` // Token消耗数量
|
||||
ResponseTime int64 `json:"responseTime"` // 响应时间(毫秒)
|
||||
RequestId string `json:"requestId,omitempty"` // 请求ID,用于流式响应管理
|
||||
Model string `json:"model,omitempty"` // 使用的模型名称
|
||||
CreatedAt time.Time `json:"createdAt"` // 创建时间
|
||||
FinishReason string `json:"finishReason,omitempty"` // 完成原因:stop/length/content_filter
|
||||
}
|
||||
|
||||
// StreamEvent 流式事件结构体
|
||||
type StreamEvent struct {
|
||||
Event string `json:"event"` // 事件类型: message, error, done, start
|
||||
Data interface{} `json:"data"` // 事件数据
|
||||
}
|
||||
|
||||
// StreamMessageData 流式消息数据结构体
|
||||
type StreamMessageData struct {
|
||||
Delta string `json:"delta"` // 增量消息内容
|
||||
SessionId string `json:"sessionId"` // 会话ID
|
||||
RequestId string `json:"requestId"` // 请求ID
|
||||
}
|
||||
|
||||
// StreamErrorData 流式错误数据结构体
|
||||
type StreamErrorData struct {
|
||||
Error string `json:"error"` // 错误信息
|
||||
Code string `json:"code"` // 错误代码
|
||||
RequestId string `json:"requestId"` // 请求ID
|
||||
}
|
||||
|
||||
// StreamDoneData 流式完成数据结构体
|
||||
type StreamDoneData struct {
|
||||
Message string `json:"message"` // 完整消息内容
|
||||
SessionId string `json:"sessionId"` // 会话ID
|
||||
RequestId string `json:"requestId"` // 请求ID
|
||||
TokenCount int `json:"tokenCount"` // Token消耗数量
|
||||
ResponseTime int64 `json:"responseTime"` // 响应时间(毫秒)
|
||||
IsSensitive bool `json:"isSensitive"` // 是否包含敏感词
|
||||
FinishReason string `json:"finishReason"` // 完成原因
|
||||
}
|
||||
|
||||
// ChatHistoryResponse 聊天历史响应结构体
|
||||
type ChatHistoryResponse struct {
|
||||
List []pet.PetAiAssistantConversations `json:"list"` // 对话记录列表
|
||||
Total int64 `json:"total"` // 总记录数
|
||||
Page int `json:"page"` // 当前页码
|
||||
PageSize int `json:"pageSize"` // 每页大小
|
||||
}
|
||||
|
||||
// SessionInfo 会话信息结构体
|
||||
type SessionInfo struct {
|
||||
SessionId string `json:"sessionId"` // 会话ID
|
||||
LastUpdated time.Time `json:"lastUpdated"` // 最后更新时间
|
||||
MessageCount int `json:"messageCount"` // 消息数量
|
||||
FirstMessage string `json:"firstMessage"` // 第一条消息内容(用作会话标题)
|
||||
CreatedAt time.Time `json:"createdAt"` // 创建时间
|
||||
}
|
||||
|
||||
// SessionsResponse 会话列表响应结构体
|
||||
type SessionsResponse struct {
|
||||
List []SessionInfo `json:"list"` // 会话列表
|
||||
Total int64 `json:"total"` // 总会话数
|
||||
Page int `json:"page"` // 当前页码
|
||||
PageSize int `json:"pageSize"` // 每页大小
|
||||
}
|
||||
|
||||
// TokenUsage Token使用情况结构体
|
||||
type TokenUsage struct {
|
||||
PromptTokens int `json:"promptTokens"` // 输入token数
|
||||
CompletionTokens int `json:"completionTokens"` // 输出token数
|
||||
TotalTokens int `json:"totalTokens"` // 总token数
|
||||
}
|
||||
|
||||
// ModelInfo 模型信息结构体
|
||||
type ModelInfo struct {
|
||||
Name string `json:"name"` // 模型名称
|
||||
Description string `json:"description"` // 模型描述
|
||||
MaxTokens int `json:"maxTokens"` // 最大token数
|
||||
Available bool `json:"available"` // 是否可用
|
||||
}
|
||||
|
||||
// ModelsResponse 可用模型列表响应结构体
|
||||
type ModelsResponse struct {
|
||||
Models []ModelInfo `json:"models"` // 模型列表
|
||||
}
|
||||
|
||||
// StatsResponse 统计信息响应结构体
|
||||
type StatsResponse struct {
|
||||
TotalSessions int64 `json:"totalSessions"` // 总会话数
|
||||
TotalMessages int64 `json:"totalMessages"` // 总消息数
|
||||
TotalTokens int64 `json:"totalTokens"` // 总token消耗
|
||||
LastChatTime *time.Time `json:"lastChatTime"` // 最后聊天时间
|
||||
AverageResponse float64 `json:"averageResponse"` // 平均响应时间(毫秒)
|
||||
SensitiveCount int64 `json:"sensitiveCount"` // 敏感词触发次数
|
||||
PopularQuestions []string `json:"popularQuestions"` // 热门问题
|
||||
}
|
||||
|
||||
// ExportResponse 导出响应结构体
|
||||
type ExportResponse struct {
|
||||
FileName string `json:"fileName"` // 文件名
|
||||
FileSize int64 `json:"fileSize"` // 文件大小(字节)
|
||||
DownloadUrl string `json:"downloadUrl"` // 下载链接
|
||||
ExpiresAt int64 `json:"expiresAt"` // 过期时间戳
|
||||
}
|
||||
|
||||
// HealthResponse 健康检查响应结构体
|
||||
type HealthResponse struct {
|
||||
Status string `json:"status"` // 服务状态:healthy/unhealthy
|
||||
LLMService string `json:"llmService"` // LLM服务状态
|
||||
SensitiveFilter string `json:"sensitiveFilter"` // 敏感词过滤状态
|
||||
Database string `json:"database"` // 数据库状态
|
||||
LastChecked time.Time `json:"lastChecked"` // 最后检查时间
|
||||
ResponseTime int64 `json:"responseTime"` // 响应时间(毫秒)
|
||||
ActiveSessions int `json:"activeSessions"` // 活跃会话数
|
||||
QueuedRequests int `json:"queuedRequests"` // 排队请求数
|
||||
}
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
package api
|
||||
|
||||
type ApiGroup struct {
|
||||
LLMApi
|
||||
}
|
||||
|
||||
var ApiGroupApp = new(ApiGroup)
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/request"
|
||||
volcengineResponse "github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type LLMApi struct{}
|
||||
|
||||
// ChatCompletion
|
||||
// @Tags Volcengine
|
||||
// @Summary LLM聊天完成
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.ChatRequest true "聊天请求参数"
|
||||
// @Success 200 {object} response.Response{data=volcengineResponse.ChatResponse} "聊天响应"
|
||||
// @Router /volcengine/llm/chat [post]
|
||||
func (l *LLMApi) ChatCompletion(c *gin.Context) {
|
||||
var chatReq request.ChatRequest
|
||||
err := c.ShouldBindJSON(&chatReq)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成请求ID
|
||||
if chatReq.RequestID == "" {
|
||||
chatReq.RequestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 根据stream参数选择响应模式
|
||||
if chatReq.Stream {
|
||||
// 流式响应 (SSE)
|
||||
l.handleStreamResponse(c, chatReq)
|
||||
} else {
|
||||
// 普通JSON响应
|
||||
l.handleNormalResponse(c, chatReq)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNormalResponse 处理普通JSON响应
|
||||
func (l *LLMApi) handleNormalResponse(c *gin.Context, chatReq request.ChatRequest) {
|
||||
resp, err := service.LLMServiceApp.ChatCompletion(chatReq)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("聊天完成失败!", zap.Error(err), zap.String("requestID", chatReq.RequestID))
|
||||
response.FailWithMessage("聊天完成失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("聊天完成成功", zap.String("requestID", chatReq.RequestID))
|
||||
response.OkWithDetailed(resp, "聊天完成成功", c)
|
||||
}
|
||||
|
||||
// handleStreamResponse 处理流式响应 (SSE)
|
||||
func (l *LLMApi) handleStreamResponse(c *gin.Context, chatReq request.ChatRequest) {
|
||||
// 设置SSE响应头
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Headers", "Cache-Control")
|
||||
|
||||
// 创建事件通道
|
||||
eventChan := make(chan volcengineResponse.StreamEvent, 100)
|
||||
defer close(eventChan)
|
||||
|
||||
// 启动流式处理
|
||||
go func() {
|
||||
err := service.LLMServiceApp.StreamChatCompletion(chatReq, eventChan)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("流式聊天失败!", zap.Error(err), zap.String("requestID", chatReq.RequestID))
|
||||
eventChan <- volcengineResponse.StreamEvent{
|
||||
Event: "error",
|
||||
Data: volcengineResponse.ChatResponse{
|
||||
ID: chatReq.RequestID,
|
||||
Error: &volcengineResponse.APIError{
|
||||
Code: "stream_error",
|
||||
Message: err.Error(),
|
||||
Type: "internal_error",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 发送流式数据
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case event, ok := <-eventChan:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
switch event.Event {
|
||||
case "message":
|
||||
// 发送消息事件
|
||||
c.SSEvent("message", event.Data)
|
||||
case "error":
|
||||
// 发送错误事件
|
||||
c.SSEvent("error", event.Data)
|
||||
return false
|
||||
case "done":
|
||||
// 发送完成事件
|
||||
c.SSEvent("done", event.Data)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
case <-time.After(30 * time.Second):
|
||||
// 超时处理
|
||||
c.SSEvent("error", volcengineResponse.ChatResponse{
|
||||
ID: chatReq.RequestID,
|
||||
Error: &volcengineResponse.APIError{
|
||||
Code: "timeout",
|
||||
Message: "Stream timeout",
|
||||
Type: "timeout_error",
|
||||
},
|
||||
})
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// StopGeneration
|
||||
// @Tags Volcengine
|
||||
// @Summary 停止LLM生成
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.StopRequest true "停止请求参数"
|
||||
// @Success 200 {object} response.Response{data=volcengineResponse.StopResponse} "停止响应"
|
||||
// @Router /volcengine/llm/stop [post]
|
||||
func (l *LLMApi) StopGeneration(c *gin.Context) {
|
||||
var stopReq request.StopRequest
|
||||
err := c.ShouldBindJSON(&stopReq)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := service.LLMServiceApp.StopGeneration(stopReq)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("停止生成失败!", zap.Error(err), zap.String("requestID", stopReq.RequestID))
|
||||
response.FailWithMessage("停止生成失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Success {
|
||||
global.GVA_LOG.Info("停止生成成功", zap.String("requestID", stopReq.RequestID))
|
||||
response.OkWithDetailed(resp, "停止生成成功", c)
|
||||
} else {
|
||||
global.GVA_LOG.Warn("停止生成失败", zap.String("requestID", stopReq.RequestID), zap.String("reason", resp.Message))
|
||||
response.FailWithDetailed(resp, resp.Message, c)
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveSessionsCount
|
||||
// @Tags Volcengine
|
||||
// @Summary 获取活跃会话数量
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Success 200 {object} response.Response{data=int} "活跃会话数量"
|
||||
// @Router /volcengine/llm/sessions [get]
|
||||
func (l *LLMApi) GetActiveSessionsCount(c *gin.Context) {
|
||||
count := service.LLMServiceApp.GetActiveSessionsCount()
|
||||
response.OkWithDetailed(gin.H{"count": count}, fmt.Sprintf("当前活跃会话数量: %d", count), c)
|
||||
}
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
package volcengine
|
||||
|
||||
import (
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/router"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type volcenginePlugin struct{}
|
||||
|
||||
// CreateVolcenginePlug 创建volcengine插件实例
|
||||
func CreateVolcenginePlug() *volcenginePlugin {
|
||||
return &volcenginePlugin{}
|
||||
}
|
||||
|
||||
// Register 注册路由
|
||||
func (*volcenginePlugin) Register(group *gin.RouterGroup) {
|
||||
router.RouterGroupApp.InitLLMRouter(group)
|
||||
}
|
||||
|
||||
// RouterPath 返回注册路由路径
|
||||
func (*volcenginePlugin) RouterPath() string {
|
||||
return "volcengine"
|
||||
}
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
package request
|
||||
|
||||
// Message 聊天消息结构体
|
||||
type Message struct {
|
||||
Role string `json:"role" form:"role" binding:"required"` // 角色: system, user, assistant
|
||||
Content string `json:"content" form:"content" binding:"required"` // 消息内容
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求结构体
|
||||
type ChatRequest struct {
|
||||
Messages []Message `json:"messages" form:"messages" binding:"required"` // 对话消息列表
|
||||
Model string `json:"model" form:"model"` // 模型名称,如果为空则使用默认模型
|
||||
Stream bool `json:"stream" form:"stream"` // 是否流式响应,默认false
|
||||
Temperature float64 `json:"temperature" form:"temperature"` // 温度参数,控制随机性,范围0-1
|
||||
MaxTokens int `json:"max_tokens" form:"max_tokens"` // 最大生成token数
|
||||
TopP float64 `json:"top_p" form:"top_p"` // 核采样参数
|
||||
RequestID string `json:"request_id,omitempty" form:"request_id,omitempty"` // 请求ID,用于流式响应管理
|
||||
}
|
||||
|
||||
// StopRequest 停止生成请求结构体
|
||||
type StopRequest struct {
|
||||
RequestID string `json:"request_id" form:"request_id" binding:"required"` // 要停止的请求ID
|
||||
}
|
||||
|
|
@ -1,66 +0,0 @@
|
|||
package response
|
||||
|
||||
import "time"
|
||||
|
||||
// Message 聊天消息结构体
|
||||
type Message struct {
|
||||
Role string `json:"role"` // 角色: system, user, assistant
|
||||
Content string `json:"content"` // 消息内容
|
||||
}
|
||||
|
||||
// Choice 响应选择结构体
|
||||
type Choice struct {
|
||||
Index int `json:"index"` // 选择索引
|
||||
Message Message `json:"message,omitempty"` // 完整响应消息(非流式)
|
||||
Delta Message `json:"delta,omitempty"` // 增量消息(流式)
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 结束原因: stop, length, content_filter
|
||||
}
|
||||
|
||||
// Usage 使用量统计结构体
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"` // 输入token数
|
||||
CompletionTokens int `json:"completion_tokens"` // 输出token数
|
||||
TotalTokens int `json:"total_tokens"` // 总token数
|
||||
}
|
||||
|
||||
// APIError API错误结构体
|
||||
type APIError struct {
|
||||
Code string `json:"code"` // 错误代码
|
||||
Message string `json:"message"` // 错误消息
|
||||
Type string `json:"type"` // 错误类型
|
||||
}
|
||||
|
||||
// ChatResponse 聊天响应结构体
|
||||
type ChatResponse struct {
|
||||
ID string `json:"id"` // 响应ID
|
||||
Object string `json:"object"` // 对象类型: chat.completion 或 chat.completion.chunk
|
||||
Created int64 `json:"created"` // 创建时间戳
|
||||
Model string `json:"model"` // 使用的模型
|
||||
Choices []Choice `json:"choices"` // 响应选择列表
|
||||
Usage *Usage `json:"usage,omitempty"` // 使用量统计(非流式响应)
|
||||
Error *APIError `json:"error,omitempty"` // 错误信息
|
||||
}
|
||||
|
||||
// StopResponse 停止生成响应结构体
|
||||
type StopResponse struct {
|
||||
Success bool `json:"success"` // 是否成功停止
|
||||
Message string `json:"message"` // 响应消息
|
||||
RequestID string `json:"request_id"` // 请求ID
|
||||
StoppedAt int64 `json:"stopped_at"` // 停止时间戳
|
||||
}
|
||||
|
||||
// StreamEvent 流式事件结构体
|
||||
type StreamEvent struct {
|
||||
Event string `json:"event"` // 事件类型: message, error, done
|
||||
Data ChatResponse `json:"data"` // 事件数据
|
||||
}
|
||||
|
||||
// ChatSession 聊天会话管理结构体(用于内部管理)
|
||||
type ChatSession struct {
|
||||
RequestID string `json:"request_id"` // 请求ID
|
||||
UserID uint `json:"user_id"` // 用户ID
|
||||
Model string `json:"model"` // 使用的模型
|
||||
StartTime time.Time `json:"start_time"` // 开始时间
|
||||
IsStreaming bool `json:"is_streaming"` // 是否流式响应
|
||||
Status string `json:"status"` // 状态: running, completed, stopped, error
|
||||
}
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
package router
|
||||
|
||||
type RouterGroup struct {
|
||||
LLMRouter
|
||||
}
|
||||
|
||||
var RouterGroupApp = new(RouterGroup)
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/middleware"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/api"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type LLMRouter struct{}
|
||||
|
||||
// InitLLMRouter 初始化LLM路由
|
||||
func (l *LLMRouter) InitLLMRouter(Router *gin.RouterGroup) {
|
||||
// LLM相关路由组,使用操作记录中间件
|
||||
llmRouter := Router.Group("llm").Use(middleware.OperationRecord())
|
||||
// LLM查询路由组,不使用操作记录中间件(用于GET请求)
|
||||
llmRouterWithoutRecord := Router.Group("llm")
|
||||
|
||||
// 获取API实例
|
||||
llmApi := api.ApiGroupApp.LLMApi
|
||||
|
||||
{
|
||||
// 需要记录操作的路由(POST请求)
|
||||
llmRouter.POST("chat", llmApi.ChatCompletion) // LLM聊天完成
|
||||
llmRouter.POST("stop", llmApi.StopGeneration) // 停止LLM生成
|
||||
}
|
||||
{
|
||||
// 不需要记录操作的路由(GET请求)
|
||||
llmRouterWithoutRecord.GET("sessions", llmApi.GetActiveSessionsCount) // 获取活跃会话数量
|
||||
}
|
||||
}
|
||||
|
|
@ -1,316 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/request"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/plugin/volcengine/model/response"
|
||||
"github.com/google/uuid"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type LLMService struct {
|
||||
client *arkruntime.Client
|
||||
activeSessions sync.Map // requestID -> context.CancelFunc
|
||||
}
|
||||
|
||||
var LLMServiceApp = new(LLMService)
|
||||
|
||||
// InitVolcengineClient 初始化火山引擎客户端
|
||||
func (l *LLMService) InitVolcengineClient() error {
|
||||
config := global.GVA_CONFIG.Volcengine
|
||||
|
||||
// 验证配置
|
||||
if config.AccessKey == "" || config.SecretKey == "" {
|
||||
return fmt.Errorf("volcengine access key or secret key is empty")
|
||||
}
|
||||
|
||||
// 创建ARK Runtime客户端
|
||||
l.client = arkruntime.NewClientWithAkSk(config.AccessKey, config.SecretKey)
|
||||
|
||||
global.GVA_LOG.Info("Volcengine client initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChatCompletion 非流式聊天完成
|
||||
func (l *LLMService) ChatCompletion(req request.ChatRequest) (*response.ChatResponse, error) {
|
||||
if l.client == nil {
|
||||
if err := l.InitVolcengineClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 生成请求ID
|
||||
if req.RequestID == "" {
|
||||
req.RequestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 转换消息格式
|
||||
messages := make([]*model.ChatCompletionMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
messages[i] = &model.ChatCompletionMessage{
|
||||
Role: msg.Role,
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: &msg.Content,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 设置模型
|
||||
modelName := req.Model
|
||||
if modelName == "" {
|
||||
modelName = global.GVA_CONFIG.Volcengine.DefaultModel
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
chatReq := &model.ChatCompletionRequest{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
// 设置可选参数
|
||||
if req.Temperature > 0 {
|
||||
chatReq.Temperature = float32(req.Temperature)
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
chatReq.MaxTokens = req.MaxTokens
|
||||
}
|
||||
if req.TopP > 0 {
|
||||
chatReq.TopP = float32(req.TopP)
|
||||
}
|
||||
|
||||
// 创建上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 存储取消函数
|
||||
l.activeSessions.Store(req.RequestID, cancel)
|
||||
defer l.activeSessions.Delete(req.RequestID)
|
||||
|
||||
// 调用API
|
||||
resp, err := l.client.CreateChatCompletion(ctx, chatReq)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("Chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
|
||||
return nil, fmt.Errorf("chat completion failed: %w", err)
|
||||
}
|
||||
|
||||
// 转换响应格式
|
||||
chatResp := &response.ChatResponse{
|
||||
ID: req.RequestID,
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: modelName,
|
||||
Choices: make([]response.Choice, len(resp.Choices)),
|
||||
}
|
||||
|
||||
for i, choice := range resp.Choices {
|
||||
messageContent := ""
|
||||
if choice.Message.Content != nil && choice.Message.Content.StringValue != nil {
|
||||
messageContent = *choice.Message.Content.StringValue
|
||||
}
|
||||
|
||||
chatResp.Choices[i] = response.Choice{
|
||||
Index: i,
|
||||
Message: response.Message{
|
||||
Role: choice.Message.Role,
|
||||
Content: messageContent,
|
||||
},
|
||||
FinishReason: string(choice.FinishReason),
|
||||
}
|
||||
}
|
||||
|
||||
// 设置使用量统计
|
||||
chatResp.Usage = &response.Usage{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("Chat completion successful", zap.String("requestID", req.RequestID))
|
||||
return chatResp, nil
|
||||
}
|
||||
|
||||
// StreamChatCompletion 流式聊天完成
|
||||
func (l *LLMService) StreamChatCompletion(req request.ChatRequest, eventChan chan<- response.StreamEvent) error {
|
||||
if l.client == nil {
|
||||
if err := l.InitVolcengineClient(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 生成请求ID
|
||||
if req.RequestID == "" {
|
||||
req.RequestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 转换消息格式
|
||||
messages := make([]*model.ChatCompletionMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
messages[i] = &model.ChatCompletionMessage{
|
||||
Role: msg.Role,
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: &msg.Content,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 设置模型
|
||||
modelName := req.Model
|
||||
if modelName == "" {
|
||||
modelName = global.GVA_CONFIG.Volcengine.DefaultModel
|
||||
}
|
||||
|
||||
// 创建流式请求
|
||||
chatReq := &model.ChatCompletionRequest{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
// 设置可选参数
|
||||
if req.Temperature > 0 {
|
||||
chatReq.Temperature = float32(req.Temperature)
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
chatReq.MaxTokens = req.MaxTokens
|
||||
}
|
||||
if req.TopP > 0 {
|
||||
chatReq.TopP = float32(req.TopP)
|
||||
}
|
||||
|
||||
// 创建上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) // 流式响应需要更长时间
|
||||
defer cancel()
|
||||
|
||||
// 存储取消函数
|
||||
l.activeSessions.Store(req.RequestID, cancel)
|
||||
defer l.activeSessions.Delete(req.RequestID)
|
||||
|
||||
// 调用流式API
|
||||
stream_resp, err := l.client.CreateChatCompletionStream(ctx, chatReq)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("Stream chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
|
||||
return fmt.Errorf("stream chat completion failed: %w", err)
|
||||
}
|
||||
defer stream_resp.Close()
|
||||
|
||||
// 处理流式响应
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 上下文取消
|
||||
eventChan <- response.StreamEvent{
|
||||
Event: "error",
|
||||
Data: response.ChatResponse{
|
||||
ID: req.RequestID,
|
||||
Error: &response.APIError{
|
||||
Code: "context_cancelled",
|
||||
Message: "Request was cancelled",
|
||||
Type: "request_cancelled",
|
||||
},
|
||||
},
|
||||
}
|
||||
return ctx.Err()
|
||||
default:
|
||||
// 接收流式数据
|
||||
recv, err := stream_resp.Recv()
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
// 流结束
|
||||
eventChan <- response.StreamEvent{
|
||||
Event: "done",
|
||||
Data: response.ChatResponse{
|
||||
ID: req.RequestID,
|
||||
Object: "chat.completion.chunk",
|
||||
},
|
||||
}
|
||||
global.GVA_LOG.Info("Stream chat completion finished", zap.String("requestID", req.RequestID))
|
||||
return nil
|
||||
}
|
||||
global.GVA_LOG.Error("Stream receive error", zap.Error(err), zap.String("requestID", req.RequestID))
|
||||
eventChan <- response.StreamEvent{
|
||||
Event: "error",
|
||||
Data: response.ChatResponse{
|
||||
ID: req.RequestID,
|
||||
Error: &response.APIError{
|
||||
Code: "stream_error",
|
||||
Message: err.Error(),
|
||||
Type: "stream_error",
|
||||
},
|
||||
},
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 转换并发送响应
|
||||
chatResp := response.ChatResponse{
|
||||
ID: req.RequestID,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: time.Now().Unix(),
|
||||
Model: modelName,
|
||||
Choices: make([]response.Choice, len(recv.Choices)),
|
||||
}
|
||||
|
||||
for i, choice := range recv.Choices {
|
||||
chatResp.Choices[i] = response.Choice{
|
||||
Index: i,
|
||||
Delta: response.Message{
|
||||
Role: choice.Delta.Role,
|
||||
Content: choice.Delta.Content,
|
||||
},
|
||||
}
|
||||
if choice.FinishReason != "" {
|
||||
chatResp.Choices[i].FinishReason = string(choice.FinishReason)
|
||||
}
|
||||
}
|
||||
|
||||
eventChan <- response.StreamEvent{
|
||||
Event: "message",
|
||||
Data: chatResp,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopGeneration 停止生成
|
||||
func (l *LLMService) StopGeneration(req request.StopRequest) (*response.StopResponse, error) {
|
||||
// 查找并取消对应的请求
|
||||
if cancelFunc, ok := l.activeSessions.Load(req.RequestID); ok {
|
||||
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
||||
cancel()
|
||||
l.activeSessions.Delete(req.RequestID)
|
||||
|
||||
global.GVA_LOG.Info("Request stopped successfully", zap.String("requestID", req.RequestID))
|
||||
return &response.StopResponse{
|
||||
Success: true,
|
||||
Message: "Request stopped successfully",
|
||||
RequestID: req.RequestID,
|
||||
StoppedAt: time.Now().Unix(),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
global.GVA_LOG.Warn("Request not found or already completed", zap.String("requestID", req.RequestID))
|
||||
return &response.StopResponse{
|
||||
Success: false,
|
||||
Message: "Request not found or already completed",
|
||||
RequestID: req.RequestID,
|
||||
StoppedAt: time.Now().Unix(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetActiveSessionsCount 获取活跃会话数量
|
||||
func (l *LLMService) GetActiveSessionsCount() int {
|
||||
count := 0
|
||||
l.activeSessions.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
|
@ -13,6 +13,7 @@ type RouterGroup struct {
|
|||
PetFamilyPetsRouter
|
||||
PetPetsRouter
|
||||
PetRecordsRouter
|
||||
PetAssistantRouter
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
@ -26,4 +27,5 @@ var (
|
|||
petFamilyPetsApi = api.ApiGroupApp.PetApiGroup.PetFamilyPetsApi
|
||||
petPetsApi = api.ApiGroupApp.PetApiGroup.PetPetsApi
|
||||
petRecordsApi = api.ApiGroupApp.PetApiGroup.PetRecordsApi
|
||||
petUserApiGroup = api.ApiGroupApp.PetApiGroup.PetUserApiGroup
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
package pet
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PetAssistantRouter struct{}
|
||||
|
||||
// InitPetAssistantRouter 初始化宠物助手路由信息
|
||||
func (p *PetAssistantRouter) InitPetAssistantRouter(UserRouter *gin.RouterGroup, PublicRouter *gin.RouterGroup) {
|
||||
// 宠物助手路由组,UserRouter已经应用了UserJWTAuth中间件
|
||||
petAssistantRouter := UserRouter.Group("pet/user/assistant")
|
||||
|
||||
// 获取宠物助手API实例
|
||||
petAssistantApi := petUserApiGroup.PetAssistantApi
|
||||
|
||||
{
|
||||
// 宠物助手问答相关路由
|
||||
petAssistantRouter.POST("ask", petAssistantApi.AskPetAssistant) // 向宠物助手提问
|
||||
petAssistantRouter.POST("stream-ask", petAssistantApi.StreamAskPetAssistant) // 向宠物助手流式提问
|
||||
petAssistantRouter.GET("history", petAssistantApi.GetAssistantHistory) // 获取宠物助手对话历史
|
||||
petAssistantRouter.DELETE("clear-history", petAssistantApi.ClearAssistantHistory) // 清空宠物助手对话历史
|
||||
petAssistantRouter.GET("sessions", petAssistantApi.GetAssistantSessions) // 获取宠物助手会话列表
|
||||
}
|
||||
}
|
||||
|
|
@ -11,4 +11,5 @@ type ServiceGroup struct {
|
|||
PetFamilyPetsService
|
||||
PetPetsService
|
||||
PetRecordsService
|
||||
PetChatService
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,438 @@
|
|||
package pet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/pet"
|
||||
petRequest "github.com/flipped-aurora/gin-vue-admin/server/model/pet/request"
|
||||
petResponse "github.com/flipped-aurora/gin-vue-admin/server/model/pet/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 使用新定义的模型结构
|
||||
type ChatRequest = petRequest.ChatRequest
|
||||
type ChatResponse = petResponse.ChatResponse
|
||||
type StreamEvent = petResponse.StreamEvent
|
||||
|
||||
// PetChatService 宠物聊天服务
|
||||
type PetChatService struct{}
|
||||
|
||||
// SendMessage 发送消息(非流式)
|
||||
func (p *PetChatService) SendMessage(ctx context.Context, userId uint, req ChatRequest) (*ChatResponse, error) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 敏感词检测
|
||||
sensitiveUtil := utils.GetSensitiveWordUtil()
|
||||
filtered, hasSensitive, err := sensitiveUtil.FilterText(req.Message)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("敏感词检测失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("敏感词检测失败: %v", err)
|
||||
}
|
||||
|
||||
// 如果包含敏感词,记录并返回提示
|
||||
if hasSensitive {
|
||||
global.GVA_LOG.Warn("用户消息包含敏感词", zap.Uint("userId", userId), zap.String("original", req.Message), zap.String("filtered", filtered))
|
||||
|
||||
// 保存用户消息(包含敏感词标记)
|
||||
if err := p.saveUserMessage(ctx, userId, req.SessionId, req.Message, true); err != nil {
|
||||
global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
return &ChatResponse{
|
||||
Message: "您的消息包含不当内容,请重新输入。",
|
||||
SessionId: req.SessionId,
|
||||
IsSensitive: true,
|
||||
TokenCount: 0,
|
||||
ResponseTime: time.Since(startTime).Milliseconds(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 2. 生成会话ID(如果没有提供)
|
||||
sessionId := req.SessionId
|
||||
if sessionId == "" {
|
||||
sessionId = uuid.New().String()
|
||||
}
|
||||
|
||||
// 3. 获取对话历史构建上下文
|
||||
history, err := p.GetChatHistory(ctx, userId, sessionId, 10) // 获取最近10条消息
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取对话历史失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取对话历史失败: %v", err)
|
||||
}
|
||||
|
||||
// 4. 构建LLM请求
|
||||
messages := p.buildMessages(history, req.Message)
|
||||
llmReq := utils.LLMRequest{
|
||||
Messages: messages,
|
||||
Model: req.Model,
|
||||
Stream: false,
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
RequestID: uuid.New().String(),
|
||||
}
|
||||
|
||||
// 5. 调用LLM服务
|
||||
llmUtil := utils.GetVolcengineLLMUtil()
|
||||
llmResp, err := llmUtil.ChatCompletion(llmReq)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("LLM调用失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("LLM调用失败: %v", err)
|
||||
}
|
||||
|
||||
// 6. 提取AI回复
|
||||
var aiMessage string
|
||||
if len(llmResp.Choices) > 0 {
|
||||
aiMessage = llmResp.Choices[0].Message.Content
|
||||
}
|
||||
|
||||
// 7. 对AI回复进行敏感词检测
|
||||
aiFiltered, aiHasSensitive, err := sensitiveUtil.FilterText(aiMessage)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("AI回复敏感词检测失败", zap.Error(err))
|
||||
aiFiltered = aiMessage // 如果检测失败,使用原始消息
|
||||
}
|
||||
|
||||
if aiHasSensitive {
|
||||
global.GVA_LOG.Warn("AI回复包含敏感词", zap.String("original", aiMessage), zap.String("filtered", aiFiltered))
|
||||
aiMessage = aiFiltered
|
||||
}
|
||||
|
||||
// 8. 保存对话记录
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
tokenCount := 0
|
||||
if llmResp.Usage != nil {
|
||||
tokenCount = llmResp.Usage.TotalTokens
|
||||
}
|
||||
|
||||
// 保存用户消息
|
||||
if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, false); err != nil {
|
||||
global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 保存AI回复
|
||||
if err := p.saveAssistantMessage(ctx, userId, sessionId, aiMessage, aiHasSensitive, tokenCount, responseTime); err != nil {
|
||||
global.GVA_LOG.Error("保存AI回复失败", zap.Error(err))
|
||||
}
|
||||
|
||||
return &ChatResponse{
|
||||
Message: aiMessage,
|
||||
SessionId: sessionId,
|
||||
IsSensitive: aiHasSensitive,
|
||||
TokenCount: tokenCount,
|
||||
ResponseTime: responseTime,
|
||||
RequestId: llmReq.RequestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StreamChat 流式聊天
|
||||
func (p *PetChatService) StreamChat(ctx context.Context, userId uint, req ChatRequest, eventChan chan<- StreamEvent) error {
|
||||
startTime := time.Now()
|
||||
|
||||
// 1. 敏感词检测
|
||||
sensitiveUtil := utils.GetSensitiveWordUtil()
|
||||
filtered, hasSensitive, err := sensitiveUtil.FilterText(req.Message)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("敏感词检测失败", zap.Error(err))
|
||||
eventChan <- StreamEvent{
|
||||
Event: "error",
|
||||
Data: map[string]interface{}{
|
||||
"error": "敏感词检测失败",
|
||||
},
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果包含敏感词,返回提示并结束
|
||||
if hasSensitive {
|
||||
global.GVA_LOG.Warn("用户消息包含敏感词", zap.Uint("userId", userId), zap.String("original", req.Message), zap.String("filtered", filtered))
|
||||
|
||||
// 保存用户消息(包含敏感词标记)
|
||||
if err := p.saveUserMessage(ctx, userId, req.SessionId, req.Message, true); err != nil {
|
||||
global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
eventChan <- StreamEvent{
|
||||
Event: "message",
|
||||
Data: ChatResponse{
|
||||
Message: "您的消息包含不当内容,请重新输入。",
|
||||
SessionId: req.SessionId,
|
||||
IsSensitive: true,
|
||||
TokenCount: 0,
|
||||
ResponseTime: time.Since(startTime).Milliseconds(),
|
||||
},
|
||||
}
|
||||
|
||||
eventChan <- StreamEvent{
|
||||
Event: "done",
|
||||
Data: map[string]interface{}{"finished": true},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 2. 生成会话ID(如果没有提供)
|
||||
sessionId := req.SessionId
|
||||
if sessionId == "" {
|
||||
sessionId = uuid.New().String()
|
||||
}
|
||||
|
||||
// 3. 获取对话历史构建上下文
|
||||
history, err := p.GetChatHistory(ctx, userId, sessionId, 10)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取对话历史失败", zap.Error(err))
|
||||
eventChan <- StreamEvent{
|
||||
Event: "error",
|
||||
Data: map[string]interface{}{
|
||||
"error": "获取对话历史失败",
|
||||
},
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 4. 构建LLM请求
|
||||
messages := p.buildMessages(history, req.Message)
|
||||
llmReq := utils.LLMRequest{
|
||||
Messages: messages,
|
||||
Model: req.Model,
|
||||
Stream: true,
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
RequestID: uuid.New().String(),
|
||||
}
|
||||
|
||||
// 5. 创建LLM流式响应通道
|
||||
llmEventChan := make(chan utils.LLMStreamEvent, 100)
|
||||
defer close(llmEventChan)
|
||||
|
||||
// 6. 启动LLM流式调用
|
||||
llmUtil := utils.GetVolcengineLLMUtil()
|
||||
go func() {
|
||||
if err := llmUtil.StreamChatCompletion(llmReq, llmEventChan); err != nil {
|
||||
global.GVA_LOG.Error("LLM流式调用失败", zap.Error(err))
|
||||
eventChan <- StreamEvent{
|
||||
Event: "error",
|
||||
Data: map[string]interface{}{
|
||||
"error": "LLM调用失败",
|
||||
},
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 7. 处理流式响应
|
||||
var fullMessage string
|
||||
var tokenCount int
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case llmEvent, ok := <-llmEventChan:
|
||||
if !ok {
|
||||
// 通道关闭,流式响应结束
|
||||
return nil
|
||||
}
|
||||
|
||||
switch llmEvent.Event {
|
||||
case "message":
|
||||
// 处理消息事件
|
||||
if len(llmEvent.Data.Choices) > 0 {
|
||||
delta := llmEvent.Data.Choices[0].Delta.Content
|
||||
fullMessage += delta
|
||||
|
||||
// 转发消息给客户端
|
||||
eventChan <- StreamEvent{
|
||||
Event: "message",
|
||||
Data: map[string]interface{}{
|
||||
"delta": delta,
|
||||
"sessionId": sessionId,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
case "done":
|
||||
// 流式响应完成
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
if llmEvent.Data.Usage != nil {
|
||||
tokenCount = llmEvent.Data.Usage.TotalTokens
|
||||
}
|
||||
|
||||
// 对完整消息进行敏感词检测
|
||||
aiFiltered, aiHasSensitive, err := sensitiveUtil.FilterText(fullMessage)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("AI回复敏感词检测失败", zap.Error(err))
|
||||
aiFiltered = fullMessage
|
||||
}
|
||||
|
||||
if aiHasSensitive {
|
||||
global.GVA_LOG.Warn("AI回复包含敏感词", zap.String("original", fullMessage), zap.String("filtered", aiFiltered))
|
||||
fullMessage = aiFiltered
|
||||
}
|
||||
|
||||
// 保存对话记录
|
||||
if err := p.saveUserMessage(ctx, userId, sessionId, req.Message, false); err != nil {
|
||||
global.GVA_LOG.Error("保存用户消息失败", zap.Error(err))
|
||||
}
|
||||
|
||||
if err := p.saveAssistantMessage(ctx, userId, sessionId, fullMessage, aiHasSensitive, tokenCount, responseTime); err != nil {
|
||||
global.GVA_LOG.Error("保存AI回复失败", zap.Error(err))
|
||||
}
|
||||
|
||||
// 发送完成事件
|
||||
eventChan <- StreamEvent{
|
||||
Event: "done",
|
||||
Data: ChatResponse{
|
||||
Message: fullMessage,
|
||||
SessionId: sessionId,
|
||||
IsSensitive: aiHasSensitive,
|
||||
TokenCount: tokenCount,
|
||||
ResponseTime: responseTime,
|
||||
RequestId: llmReq.RequestID,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
|
||||
case "error":
|
||||
// 处理错误事件
|
||||
global.GVA_LOG.Error("LLM流式响应错误", zap.Any("error", llmEvent.Data.Error))
|
||||
eventChan <- StreamEvent{
|
||||
Event: "error",
|
||||
Data: map[string]interface{}{
|
||||
"error": "LLM响应错误",
|
||||
},
|
||||
}
|
||||
return fmt.Errorf("LLM响应错误")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetChatHistory 获取对话历史
|
||||
func (p *PetChatService) GetChatHistory(ctx context.Context, userId uint, sessionId string, limit int) ([]pet.PetAiAssistantConversations, error) {
|
||||
var conversations []pet.PetAiAssistantConversations
|
||||
|
||||
query := global.GVA_DB.WithContext(ctx).Where("user_id = ?", userId)
|
||||
|
||||
// 如果提供了会话ID,则按会话ID过滤
|
||||
if sessionId != "" {
|
||||
query = query.Where("session_id = ?", sessionId)
|
||||
}
|
||||
|
||||
// 按创建时间倒序,限制数量
|
||||
if err := query.Order("created_at DESC").Limit(limit).Find(&conversations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 反转切片,使其按时间正序排列
|
||||
for i, j := 0, len(conversations)-1; i < j; i, j = i+1, j-1 {
|
||||
conversations[i], conversations[j] = conversations[j], conversations[i]
|
||||
}
|
||||
|
||||
return conversations, nil
|
||||
}
|
||||
|
||||
// SaveConversation 保存对话记录
|
||||
func (p *PetChatService) SaveConversation(ctx context.Context, conversation *pet.PetAiAssistantConversations) error {
|
||||
return global.GVA_DB.WithContext(ctx).Create(conversation).Error
|
||||
}
|
||||
|
||||
// ClearChatHistory 清空对话历史
|
||||
func (p *PetChatService) ClearChatHistory(ctx context.Context, userId uint, sessionId string) error {
|
||||
query := global.GVA_DB.WithContext(ctx).Where("user_id = ?", userId)
|
||||
|
||||
// 如果提供了会话ID,则只清空指定会话
|
||||
if sessionId != "" {
|
||||
query = query.Where("session_id = ?", sessionId)
|
||||
}
|
||||
|
||||
return query.Delete(&pet.PetAiAssistantConversations{}).Error
|
||||
}
|
||||
|
||||
// GetChatSessions 获取用户的聊天会话列表
|
||||
func (p *PetChatService) GetChatSessions(ctx context.Context, userId uint) ([]map[string]interface{}, error) {
|
||||
var sessions []map[string]interface{}
|
||||
|
||||
// 查询用户的所有会话,按最后更新时间分组
|
||||
err := global.GVA_DB.WithContext(ctx).
|
||||
Model(&pet.PetAiAssistantConversations{}).
|
||||
Select("session_id, MAX(updated_at) as last_updated, COUNT(*) as message_count").
|
||||
Where("user_id = ? AND session_id IS NOT NULL AND session_id != ''", userId).
|
||||
Group("session_id").
|
||||
Order("last_updated DESC").
|
||||
Scan(&sessions).Error
|
||||
|
||||
return sessions, err
|
||||
}
|
||||
|
||||
// buildMessages 构建LLM请求消息
|
||||
func (p *PetChatService) buildMessages(history []pet.PetAiAssistantConversations, userMessage string) []utils.LLMMessage {
|
||||
messages := []utils.LLMMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "你是一个专业的宠物助手,专门为宠物主人提供关于宠物护理、健康、训练和日常生活的建议。请用友善、专业的语气回答问题。",
|
||||
},
|
||||
}
|
||||
|
||||
// 添加历史对话
|
||||
for _, conv := range history {
|
||||
if conv.MessageContent != nil && conv.Role != nil {
|
||||
messages = append(messages, utils.LLMMessage{
|
||||
Role: *conv.Role,
|
||||
Content: *conv.MessageContent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 添加当前用户消息
|
||||
messages = append(messages, utils.LLMMessage{
|
||||
Role: "user",
|
||||
Content: userMessage,
|
||||
})
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// saveUserMessage 保存用户消息
|
||||
func (p *PetChatService) saveUserMessage(ctx context.Context, userId uint, sessionId, message string, isSensitive bool) error {
|
||||
userIdPtr := int(userId)
|
||||
rolePtr := "user"
|
||||
messagePtr := message
|
||||
sessionIdPtr := sessionId
|
||||
isSensitivePtr := isSensitive
|
||||
|
||||
conversation := &pet.PetAiAssistantConversations{
|
||||
UserId: &userIdPtr,
|
||||
MessageContent: &messagePtr,
|
||||
Role: &rolePtr,
|
||||
SessionId: &sessionIdPtr,
|
||||
IsSensitive: &isSensitivePtr,
|
||||
}
|
||||
|
||||
return p.SaveConversation(ctx, conversation)
|
||||
}
|
||||
|
||||
// saveAssistantMessage 保存AI助手消息
|
||||
func (p *PetChatService) saveAssistantMessage(ctx context.Context, userId uint, sessionId, message string, isSensitive bool, tokenCount int, responseTime int64) error {
|
||||
userIdPtr := int(userId)
|
||||
rolePtr := "assistant"
|
||||
messagePtr := message
|
||||
sessionIdPtr := sessionId
|
||||
isSensitivePtr := isSensitive
|
||||
tokenCountPtr := tokenCount
|
||||
responseTimePtr := int(responseTime)
|
||||
|
||||
conversation := &pet.PetAiAssistantConversations{
|
||||
UserId: &userIdPtr,
|
||||
MessageContent: &messagePtr,
|
||||
Role: &rolePtr,
|
||||
SessionId: &sessionIdPtr,
|
||||
IsSensitive: &isSensitivePtr,
|
||||
TokenCount: &tokenCountPtr,
|
||||
ResponseTime: &responseTimePtr,
|
||||
}
|
||||
|
||||
return p.SaveConversation(ctx, conversation)
|
||||
}
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
swd "github.com/kirklin/go-swd"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type SensitiveWordUtil struct {
|
||||
detector *swd.SWD
|
||||
once sync.Once
|
||||
initErr error
|
||||
}
|
||||
|
||||
// 全局单例实例
|
||||
var (
|
||||
sensitiveWordInstance *SensitiveWordUtil
|
||||
sensitiveWordOnce sync.Once
|
||||
)
|
||||
|
||||
// GetSensitiveWordUtil 获取敏感词工具单例实例
|
||||
func GetSensitiveWordUtil() *SensitiveWordUtil {
|
||||
sensitiveWordOnce.Do(func() {
|
||||
sensitiveWordInstance = &SensitiveWordUtil{}
|
||||
})
|
||||
return sensitiveWordInstance
|
||||
}
|
||||
|
||||
// InitDetector 初始化敏感词检测器(单例模式)
|
||||
func (s *SensitiveWordUtil) InitDetector() error {
|
||||
s.once.Do(func() {
|
||||
detector, err := swd.New()
|
||||
if err != nil {
|
||||
s.initErr = err
|
||||
// 只有在global.GVA_LOG不为nil时才记录日志
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Error("Failed to initialize sensitive word detector", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
s.detector = detector
|
||||
// 只有在global.GVA_LOG不为nil时才记录日志
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Info("Sensitive word detector initialized successfully")
|
||||
}
|
||||
})
|
||||
return s.initErr
|
||||
}
|
||||
|
||||
// DetectSensitive 检测文本中的敏感词
|
||||
func (s *SensitiveWordUtil) DetectSensitive(text string) ([]swd.SensitiveWord, error) {
|
||||
if s.detector == nil {
|
||||
if err := s.InitDetector(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
words := s.detector.MatchAll(text)
|
||||
return words, nil
|
||||
}
|
||||
|
||||
// HasSensitive 检查文本是否包含敏感词
|
||||
func (s *SensitiveWordUtil) HasSensitive(text string) (bool, error) {
|
||||
if s.detector == nil {
|
||||
if err := s.InitDetector(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
return s.detector.Detect(text), nil
|
||||
}
|
||||
|
||||
// FilterText 过滤文本中的敏感词,返回过滤后的文本和是否包含敏感词
|
||||
func (s *SensitiveWordUtil) FilterText(text string) (filtered string, hasSensitive bool, err error) {
|
||||
if s.detector == nil {
|
||||
if err := s.InitDetector(); err != nil {
|
||||
return text, false, err
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否包含敏感词
|
||||
hasSensitive = s.detector.Detect(text)
|
||||
|
||||
if !hasSensitive {
|
||||
return text, false, nil
|
||||
}
|
||||
|
||||
// 使用星号替换敏感词
|
||||
filtered = s.detector.ReplaceWithAsterisk(text)
|
||||
return filtered, true, nil
|
||||
}
|
||||
|
||||
// FilterTextWithCustomReplace 使用自定义替换策略过滤敏感词
|
||||
func (s *SensitiveWordUtil) FilterTextWithCustomReplace(text string, replaceFunc func(swd.SensitiveWord) string) (filtered string, hasSensitive bool, err error) {
|
||||
if s.detector == nil {
|
||||
if err := s.InitDetector(); err != nil {
|
||||
return text, false, err
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否包含敏感词
|
||||
hasSensitive = s.detector.Detect(text)
|
||||
|
||||
if !hasSensitive {
|
||||
return text, false, nil
|
||||
}
|
||||
|
||||
// 使用自定义替换策略
|
||||
filtered = s.detector.ReplaceWithStrategy(text, replaceFunc)
|
||||
return filtered, true, nil
|
||||
}
|
||||
|
||||
// GetFirstSensitiveWord 获取文本中第一个敏感词
|
||||
func (s *SensitiveWordUtil) GetFirstSensitiveWord(text string) (*swd.SensitiveWord, error) {
|
||||
if s.detector == nil {
|
||||
if err := s.InitDetector(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
word := s.detector.Match(text)
|
||||
return word, nil
|
||||
}
|
||||
|
||||
// AddCustomWords 添加自定义敏感词
|
||||
func (s *SensitiveWordUtil) AddCustomWords(words map[string]swd.Category) error {
|
||||
if s.detector == nil {
|
||||
if err := s.InitDetector(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.detector.AddWords(words)
|
||||
}
|
||||
|
||||
// RemoveWord 移除指定敏感词
|
||||
func (s *SensitiveWordUtil) RemoveWord(word string) error {
|
||||
if s.detector == nil {
|
||||
if err := s.InitDetector(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.detector.RemoveWord(word)
|
||||
}
|
||||
|
||||
// Clear 清空词库
|
||||
func (s *SensitiveWordUtil) Clear() error {
|
||||
if s.detector == nil {
|
||||
if err := s.InitDetector(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return s.detector.Clear()
|
||||
}
|
||||
|
|
@ -0,0 +1,379 @@
|
|||
package utils
|
||||
|
||||
import (
	"context"
	"errors"
	"fmt"
	"io"
	"sync"

	"github.com/flipped-aurora/gin-vue-admin/server/global"
	"github.com/google/uuid"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime"
	"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
	"go.uber.org/zap"
)
|
||||
|
||||
// LLMMessage is a single chat message exchanged with the model.
type LLMMessage struct {
	Role    string `json:"role"`    // Role: system, user, assistant
	Content string `json:"content"` // Message text
}

// LLMRequest describes one chat-completion call.
type LLMRequest struct {
	Messages    []LLMMessage `json:"messages" binding:"required"` // Conversation messages, in order
	Model       string       `json:"model"`                       // Model name; empty selects the configured default
	Stream      bool         `json:"stream"`                      // Whether to stream the response; default false
	Temperature float64      `json:"temperature"`                 // Sampling temperature, 0-1 (0 means "use server default")
	MaxTokens   int          `json:"max_tokens"`                  // Maximum number of generated tokens
	TopP        float64      `json:"top_p"`                       // Nucleus-sampling parameter
	RequestID   string       `json:"request_id,omitempty"`        // Request ID, used to track/cancel streaming sessions
}

// LLMChoice is one candidate completion in a response.
type LLMChoice struct {
	Index        int        `json:"index"`                   // Choice index
	Message      LLMMessage `json:"message,omitempty"`       // Full message (non-streaming responses)
	Delta        LLMMessage `json:"delta,omitempty"`         // Incremental message (streaming chunks)
	FinishReason string     `json:"finish_reason,omitempty"` // Finish reason: stop, length, content_filter
}

// LLMUsage reports token accounting for a completion.
type LLMUsage struct {
	PromptTokens     int `json:"prompt_tokens"`     // Tokens in the prompt
	CompletionTokens int `json:"completion_tokens"` // Tokens in the generated output
	TotalTokens      int `json:"total_tokens"`      // Prompt + completion tokens
}

// LLMError carries a structured API error.
type LLMError struct {
	Code    string `json:"code"`    // Error code
	Message string `json:"message"` // Human-readable error message
	Type    string `json:"type"`    // Error category
}

// LLMResponse is a chat-completion result or a single streaming chunk.
type LLMResponse struct {
	ID      string      `json:"id"`              // Response ID
	Object  string      `json:"object"`          // Object type: chat.completion or chat.completion.chunk
	Created int64       `json:"created"`         // Creation timestamp (Unix seconds)
	Model   string      `json:"model"`           // Model that produced the response
	Choices []LLMChoice `json:"choices"`         // Candidate completions
	Usage   *LLMUsage   `json:"usage,omitempty"` // Token usage (non-streaming, or final chunk)
	Error   *LLMError   `json:"error,omitempty"` // Error information, if any
}

// LLMStreamEvent is one event emitted on a streaming channel.
type LLMStreamEvent struct {
	Event string      `json:"event"` // Event type: message, error, done
	Data  LLMResponse `json:"data"`  // Event payload
}

// VolcengineLLMUtil is a lazily-initialized client wrapper for the Volcengine
// ARK runtime, tracking in-flight streaming sessions so they can be cancelled.
type VolcengineLLMUtil struct {
	client         *arkruntime.Client
	activeSessions sync.Map // requestID -> context.CancelFunc
	once           sync.Once
	initErr        error
}
|
||||
|
||||
// Package-level singleton state.
var (
	volcengineLLMInstance *VolcengineLLMUtil
	volcengineLLMOnce     sync.Once
)

// GetVolcengineLLMUtil returns the process-wide Volcengine LLM utility.
func GetVolcengineLLMUtil() *VolcengineLLMUtil {
	volcengineLLMOnce.Do(func() {
		volcengineLLMInstance = &VolcengineLLMUtil{}
	})
	return volcengineLLMInstance
}
|
||||
|
||||
// InitClient 初始化火山引擎客户端(单例模式)
|
||||
func (v *VolcengineLLMUtil) InitClient() error {
|
||||
v.once.Do(func() {
|
||||
config := global.GVA_CONFIG.Volcengine
|
||||
|
||||
// 验证配置
|
||||
if config.ApiKey == "" {
|
||||
v.initErr = fmt.Errorf("volcengine api key is empty")
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Error("Volcengine configuration error", zap.Error(v.initErr))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 创建ARK Runtime客户端,使用API Key
|
||||
v.client = arkruntime.NewClientWithApiKey(config.ApiKey)
|
||||
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Info("Volcengine LLM client initialized successfully")
|
||||
}
|
||||
})
|
||||
return v.initErr
|
||||
}
|
||||
|
||||
// ChatCompletion 非流式聊天完成
|
||||
func (v *VolcengineLLMUtil) ChatCompletion(req LLMRequest) (*LLMResponse, error) {
|
||||
if v.client == nil {
|
||||
if err := v.InitClient(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 生成请求ID
|
||||
if req.RequestID == "" {
|
||||
req.RequestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 转换消息格式
|
||||
messages := make([]*model.ChatCompletionMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
messages[i] = &model.ChatCompletionMessage{
|
||||
Role: msg.Role,
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: &msg.Content,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 设置模型
|
||||
modelName := req.Model
|
||||
if modelName == "" {
|
||||
modelName = global.GVA_CONFIG.Volcengine.Model
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
chatReq := &model.ChatCompletionRequest{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
}
|
||||
|
||||
// 设置可选参数
|
||||
if req.Temperature > 0 {
|
||||
chatReq.Temperature = float32(req.Temperature)
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
chatReq.MaxTokens = req.MaxTokens
|
||||
}
|
||||
if req.TopP > 0 {
|
||||
chatReq.TopP = float32(req.TopP)
|
||||
}
|
||||
|
||||
// 调用API
|
||||
resp, err := v.client.CreateChatCompletion(context.Background(), chatReq)
|
||||
if err != nil {
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Error("Chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换响应格式
|
||||
choices := make([]LLMChoice, len(resp.Choices))
|
||||
for i, choice := range resp.Choices {
|
||||
choices[i] = LLMChoice{
|
||||
Index: choice.Index,
|
||||
Message: LLMMessage{
|
||||
Role: choice.Message.Role,
|
||||
Content: *choice.Message.Content.StringValue,
|
||||
},
|
||||
FinishReason: string(choice.FinishReason),
|
||||
}
|
||||
}
|
||||
|
||||
usage := &LLMUsage{
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
CompletionTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
chatResp := &LLMResponse{
|
||||
ID: resp.ID,
|
||||
Object: resp.Object,
|
||||
Created: resp.Created,
|
||||
Model: resp.Model,
|
||||
Choices: choices,
|
||||
Usage: usage,
|
||||
}
|
||||
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Info("Chat completion successful", zap.String("requestID", req.RequestID))
|
||||
}
|
||||
return chatResp, nil
|
||||
}
|
||||
|
||||
// StreamChatCompletion 流式聊天完成
|
||||
func (v *VolcengineLLMUtil) StreamChatCompletion(req LLMRequest, eventChan chan<- LLMStreamEvent) error {
|
||||
if v.client == nil {
|
||||
if err := v.InitClient(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 生成请求ID
|
||||
if req.RequestID == "" {
|
||||
req.RequestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 转换消息格式
|
||||
messages := make([]*model.ChatCompletionMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
messages[i] = &model.ChatCompletionMessage{
|
||||
Role: msg.Role,
|
||||
Content: &model.ChatCompletionMessageContent{
|
||||
StringValue: &msg.Content,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 设置模型
|
||||
modelName := req.Model
|
||||
if modelName == "" {
|
||||
modelName = global.GVA_CONFIG.Volcengine.Model
|
||||
}
|
||||
|
||||
// 创建流式请求
|
||||
chatReq := &model.ChatCompletionRequest{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
// 设置可选参数
|
||||
if req.Temperature > 0 {
|
||||
chatReq.Temperature = float32(req.Temperature)
|
||||
}
|
||||
if req.MaxTokens > 0 {
|
||||
chatReq.MaxTokens = req.MaxTokens
|
||||
}
|
||||
if req.TopP > 0 {
|
||||
chatReq.TopP = float32(req.TopP)
|
||||
}
|
||||
|
||||
// 创建上下文和取消函数
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
v.activeSessions.Store(req.RequestID, cancel)
|
||||
defer func() {
|
||||
v.activeSessions.Delete(req.RequestID)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// 调用流式API
|
||||
stream_resp, err := v.client.CreateChatCompletionStream(ctx, chatReq)
|
||||
if err != nil {
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Error("Stream chat completion failed", zap.Error(err), zap.String("requestID", req.RequestID))
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer stream_resp.Close()
|
||||
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Info("Stream chat completion started", zap.String("requestID", req.RequestID))
|
||||
}
|
||||
|
||||
// 处理流式响应
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
// 接收流式数据
|
||||
recv, err := stream_resp.Recv()
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
// 流结束
|
||||
eventChan <- LLMStreamEvent{
|
||||
Event: "done",
|
||||
Data: LLMResponse{
|
||||
ID: req.RequestID,
|
||||
Object: "chat.completion.chunk",
|
||||
},
|
||||
}
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Info("Stream chat completion finished", zap.String("requestID", req.RequestID))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Error("Stream receive error", zap.Error(err), zap.String("requestID", req.RequestID))
|
||||
}
|
||||
eventChan <- LLMStreamEvent{
|
||||
Event: "error",
|
||||
Data: LLMResponse{
|
||||
ID: req.RequestID,
|
||||
Error: &LLMError{
|
||||
Code: "stream_error",
|
||||
Message: err.Error(),
|
||||
Type: "stream_error",
|
||||
},
|
||||
},
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 转换响应格式
|
||||
choices := make([]LLMChoice, len(recv.Choices))
|
||||
for i, choice := range recv.Choices {
|
||||
choices[i] = LLMChoice{
|
||||
Index: choice.Index,
|
||||
Delta: LLMMessage{
|
||||
Role: choice.Delta.Role,
|
||||
Content: choice.Delta.Content,
|
||||
},
|
||||
FinishReason: string(choice.FinishReason),
|
||||
}
|
||||
}
|
||||
|
||||
var usage *LLMUsage
|
||||
if recv.Usage != nil {
|
||||
usage = &LLMUsage{
|
||||
PromptTokens: recv.Usage.PromptTokens,
|
||||
CompletionTokens: recv.Usage.CompletionTokens,
|
||||
TotalTokens: recv.Usage.TotalTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// 发送消息事件
|
||||
eventChan <- LLMStreamEvent{
|
||||
Event: "message",
|
||||
Data: LLMResponse{
|
||||
ID: recv.ID,
|
||||
Object: recv.Object,
|
||||
Created: recv.Created,
|
||||
Model: recv.Model,
|
||||
Choices: choices,
|
||||
Usage: usage,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// StopGeneration 停止生成
|
||||
func (v *VolcengineLLMUtil) StopGeneration(requestID string) error {
|
||||
if cancel, ok := v.activeSessions.Load(requestID); ok {
|
||||
if cancelFunc, ok := cancel.(context.CancelFunc); ok {
|
||||
cancelFunc()
|
||||
v.activeSessions.Delete(requestID)
|
||||
if global.GVA_LOG != nil {
|
||||
global.GVA_LOG.Info("Generation stopped", zap.String("requestID", requestID))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("session not found: %s", requestID)
|
||||
}
|
||||
|
||||
// GetActiveSessionsCount 获取活跃会话数量
|
||||
func (v *VolcengineLLMUtil) GetActiveSessionsCount() int {
|
||||
count := 0
|
||||
v.activeSessions.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
Loading…
Reference in New Issue