kra/internal/initialize/mcp.go

118 lines
2.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package initialize
import (
"kra/internal/conf"
mcpTool "kra/pkg/mcp/tool"
"github.com/gin-gonic/gin"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"go.uber.org/zap"
"gorm.io/gorm"
)
// MCPServer wraps an MCP server together with its SSE server and the
// configuration it was created from.
type MCPServer struct {
// server is the underlying MCP server that tools are registered on.
server *server.MCPServer
// sseServer is the SSE server built on top of server.
sseServer *server.SSEServer
// config is the MCP configuration passed to NewMCPServer.
config *conf.MCP
}
// NewMCPServer builds an MCPServer from config.
//
// It returns nil when config is nil or config.Enabled is false. Otherwise it
// creates the MCP server from config.Name and config.Version and wraps it in
// an SSE server configured with config.SsePath, config.MessagePath and
// config.UrlPrefix.
func NewMCPServer(config *conf.MCP) *MCPServer {
	if config == nil {
		return nil
	}
	if !config.Enabled {
		return nil
	}

	core := server.NewMCPServer(config.Name, config.Version)
	sse := server.NewSSEServer(
		core,
		server.WithSSEEndpoint(config.SsePath),
		server.WithMessageEndpoint(config.MessagePath),
		server.WithBaseURL(config.UrlPrefix),
	)

	wrapper := &MCPServer{
		server:    core,
		sseServer: sse,
		config:    config,
	}
	return wrapper
}
// Server returns the wrapped MCP server instance. It is safe to call on a nil
// *MCPServer, in which case it returns nil.
func (m *MCPServer) Server() *server.MCPServer {
	if m != nil {
		return m.server
	}
	return nil
}
// SSEServer returns the wrapped SSE server instance. It is safe to call on a
// nil *MCPServer, in which case it returns nil.
func (m *MCPServer) SSEServer() *server.SSEServer {
	if m != nil {
		return m.sseServer
	}
	return nil
}
// Config returns the MCP configuration this wrapper was created with. It is
// safe to call on a nil *MCPServer, in which case it returns nil.
func (m *MCPServer) Config() *conf.MCP {
	if m != nil {
		return m.config
	}
	return nil
}
// RegisterRoutes mounts the MCP endpoints on the given Gin engine.
//
// It is a no-op when the receiver, its SSE server, its config, or engine is
// nil, and also when config.Separate is true (standalone mode; the endpoints
// are presumably served elsewhere in that case). Otherwise it registers:
//
//   - GET  config.SsePath     -> the SSE handler
//   - POST config.MessagePath -> the message handler
func (m *MCPServer) RegisterRoutes(engine *gin.Engine) {
	if m == nil || m.sseServer == nil || m.config == nil || engine == nil {
		return
	}
	// Only attach to the main service when not running in standalone mode.
	if m.config.Separate {
		return
	}

	// Build each http.Handler once at registration time instead of on every
	// request; gin.WrapH adapts it to a gin.HandlerFunc that forwards
	// c.Writer and c.Request.
	engine.GET(m.config.SsePath, gin.WrapH(m.sseServer.SSEHandler()))
	engine.POST(m.config.MessagePath, gin.WrapH(m.sseServer.MessageHandler()))
}
// RegisterTool adds a single MCP tool and its handler to the wrapped server.
// It does nothing when the receiver or its underlying server is nil.
func (m *MCPServer) RegisterTool(tool mcp.Tool, handler server.ToolHandlerFunc) {
	if m != nil && m.server != nil {
		m.server.AddTool(tool, handler)
	}
}
// MCPToolRegistrar is the interface for components that register MCP tools
// on a server.
type MCPToolRegistrar interface {
// RegisterTools receives the MCP server on which to register tools.
RegisterTools(s *server.MCPServer)
}
// RegisterAllTools asks each registrar, in argument order, to register its
// tools on the wrapped MCP server.
//
// It does nothing when the receiver or its underlying server is nil. Nil
// registrars are skipped so that one missing entry does not panic and prevent
// the remaining registrars from running. Note that an interface value holding
// a typed nil pointer is not nil and is still invoked.
func (m *MCPServer) RegisterAllTools(registrars ...MCPToolRegistrar) {
	if m == nil || m.server == nil {
		return
	}
	for _, registrar := range registrars {
		if registrar == nil {
			continue
		}
		registrar.RegisterTools(m.server)
	}
}
// RegisterMCPTools registers all MCP tools with their dependencies injected by
// forwarding the wrapped server, db, logger and autoCodeConfig to
// mcpTool.RegisterAllToolsWithDependencies. It does nothing when the receiver
// or its underlying server is nil.
func (m *MCPServer) RegisterMCPTools(db *gorm.DB, logger *zap.Logger, autoCodeConfig *mcpTool.AutoCodeConfig) {
	if m == nil {
		return
	}
	srv := m.server
	if srv == nil {
		return
	}
	mcpTool.RegisterAllToolsWithDependencies(srv, db, logger, autoCodeConfig)
}