118 lines
2.6 KiB
Go
118 lines
2.6 KiB
Go
package initialize
|
||
|
||
import (
|
||
"kra/internal/conf"
|
||
mcpTool "kra/pkg/mcp/tool"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/mark3labs/mcp-go/mcp"
|
||
"github.com/mark3labs/mcp-go/server"
|
||
"go.uber.org/zap"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// MCPServer MCP服务器包装器
|
||
type MCPServer struct {
|
||
server *server.MCPServer
|
||
sseServer *server.SSEServer
|
||
config *conf.MCP
|
||
}
|
||
|
||
// NewMCPServer 创建MCP服务器
|
||
func NewMCPServer(config *conf.MCP) *MCPServer {
|
||
if config == nil || !config.Enabled {
|
||
return nil
|
||
}
|
||
|
||
s := server.NewMCPServer(
|
||
config.Name,
|
||
config.Version,
|
||
)
|
||
|
||
sseServer := server.NewSSEServer(s,
|
||
server.WithSSEEndpoint(config.SsePath),
|
||
server.WithMessageEndpoint(config.MessagePath),
|
||
server.WithBaseURL(config.UrlPrefix))
|
||
|
||
return &MCPServer{
|
||
server: s,
|
||
sseServer: sseServer,
|
||
config: config,
|
||
}
|
||
}
|
||
|
||
// Server 获取MCP服务器实例
|
||
func (m *MCPServer) Server() *server.MCPServer {
|
||
if m == nil {
|
||
return nil
|
||
}
|
||
return m.server
|
||
}
|
||
|
||
// SSEServer 获取SSE服务器实例
|
||
func (m *MCPServer) SSEServer() *server.SSEServer {
|
||
if m == nil {
|
||
return nil
|
||
}
|
||
return m.sseServer
|
||
}
|
||
|
||
// Config 获取MCP配置
|
||
func (m *MCPServer) Config() *conf.MCP {
|
||
if m == nil {
|
||
return nil
|
||
}
|
||
return m.config
|
||
}
|
||
|
||
// RegisterRoutes 注册MCP路由到Gin引擎
|
||
func (m *MCPServer) RegisterRoutes(engine *gin.Engine) {
|
||
if m == nil || m.sseServer == nil || m.config == nil {
|
||
return
|
||
}
|
||
|
||
// 如果不是独立模式,注册到主服务
|
||
if !m.config.Separate {
|
||
// 注册SSE端点
|
||
engine.GET(m.config.SsePath, func(c *gin.Context) {
|
||
m.sseServer.SSEHandler().ServeHTTP(c.Writer, c.Request)
|
||
})
|
||
|
||
// 注册消息端点
|
||
engine.POST(m.config.MessagePath, func(c *gin.Context) {
|
||
m.sseServer.MessageHandler().ServeHTTP(c.Writer, c.Request)
|
||
})
|
||
}
|
||
}
|
||
|
||
// RegisterTool 注册MCP工具
|
||
func (m *MCPServer) RegisterTool(tool mcp.Tool, handler server.ToolHandlerFunc) {
|
||
if m == nil || m.server == nil {
|
||
return
|
||
}
|
||
m.server.AddTool(tool, handler)
|
||
}
|
||
|
||
// MCPToolRegistrar MCP工具注册器接口
|
||
type MCPToolRegistrar interface {
|
||
RegisterTools(s *server.MCPServer)
|
||
}
|
||
|
||
// RegisterAllTools 注册所有MCP工具
|
||
func (m *MCPServer) RegisterAllTools(registrars ...MCPToolRegistrar) {
|
||
if m == nil || m.server == nil {
|
||
return
|
||
}
|
||
for _, registrar := range registrars {
|
||
registrar.RegisterTools(m.server)
|
||
}
|
||
}
|
||
|
||
// RegisterMCPTools 注册所有MCP工具(带依赖注入)
|
||
func (m *MCPServer) RegisterMCPTools(db *gorm.DB, logger *zap.Logger, autoCodeConfig *mcpTool.AutoCodeConfig) {
|
||
if m == nil || m.server == nil {
|
||
return
|
||
}
|
||
mcpTool.RegisterAllToolsWithDependencies(m.server, db, logger, autoCodeConfig)
|
||
}
|