package autocode

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io"
	"os"
	"strconv"
)

// parsedGoFile pairs a parsed *ast.File with the FileSet that holds its
// position information. go/printer interprets every token.Pos relative to the
// FileSet it is given, so printing with any other FileSet loses line layout
// and comment placement.
type parsedGoFile struct {
	fset *token.FileSet
	file *ast.File
}

// fileSetFor returns the recorded FileSet when file is the recorded AST, and a
// fresh empty FileSet otherwise (e.g. Format called without a prior Parse).
func (p parsedGoFile) fileSetFor(file *ast.File) *token.FileSet {
	if p.fset != nil && p.file == file {
		return p.fset
	}
	return token.NewFileSet()
}

// parseGoFile parses filename with comments and returns the AST together with
// the FileSet it was parsed into.
func parseGoFile(filename string) (parsedGoFile, error) {
	if filename == "" {
		return parsedGoFile{}, fmt.Errorf("文件路径为空")
	}
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
	if err != nil {
		return parsedGoFile{}, fmt.Errorf("解析文件失败 %s: %w", filename, err)
	}
	return parsedGoFile{fset: fset, file: file}, nil
}

// writeGoFile gofmt-prints file using fset. The output goes to writer when it
// is non-nil; otherwise the existing file at filename is truncated and
// overwritten (it is not created if missing). Formatting is done in memory
// first so a printer error never leaves a truncated, half-written file behind.
func writeGoFile(filename string, writer io.Writer, fset *token.FileSet, file *ast.File) error {
	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return err
	}
	if writer != nil {
		_, err := buf.WriteTo(writer)
		return err
	}
	f, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, 0666)
	if err != nil {
		return fmt.Errorf("打开文件失败: %s: %w", filename, err)
	}
	if _, err := buf.WriteTo(f); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return fmt.Errorf("写入文件失败: %s: %w", filename, err)
	}
	// Close can surface delayed write errors, so it is not deferred-and-ignored.
	return f.Close()
}

// WireProviderAst injects a Provider into the `var ProviderSet = wire.NewSet(...)`
// declaration of a layer file (biz, data, handler, service, router).
type WireProviderAst struct {
	Layer        string // layer name: biz, data, handler, service, router
	PackageName  string // package name
	ProviderName string // Provider name, e.g. NewArticleUsecase
	Path         string // target file path

	parsed parsedGoFile // result of the most recent Parse, reused by Format
}

// NewWireProviderAst creates a Wire Provider AST injector.
func NewWireProviderAst(layer, packageName, providerName, path string) *WireProviderAst {
	return &WireProviderAst{
		Layer:        layer,
		PackageName:  packageName,
		ProviderName: providerName,
		Path:         path,
	}
}

// Parse parses filename (or a.Path when filename is empty), keeping comments.
// writer is accepted but not used. The FileSet is remembered so Format can
// print the returned AST with its original positions.
func (a *WireProviderAst) Parse(filename string, writer io.Writer) (*ast.File, error) {
	if filename == "" {
		filename = a.Path
	}
	parsed, err := parseGoFile(filename)
	if err != nil {
		return nil, err
	}
	a.parsed = parsed
	return parsed.file, nil
}

// Rollback removes ProviderName from every ProviderSet wire.NewSet call.
func (a *WireProviderAst) Rollback(file *ast.File) error {
	return removeWireProvider(file, a.ProviderName)
}

// Injection appends ProviderName to every ProviderSet wire.NewSet call.
func (a *WireProviderAst) Injection(file *ast.File) error {
	return addWireProvider(file, a.ProviderName)
}

// Format writes file to writer, or to filename (a.Path when empty) when writer
// is nil. The existing target file is truncated and overwritten.
func (a *WireProviderAst) Format(filename string, writer io.Writer, file *ast.File) error {
	if filename == "" {
		filename = a.Path
	}
	return writeGoFile(filename, writer, a.parsed.fileSetFor(file), file)
}

// isWireNewSet reports whether call has the form wire.NewSet(...).
func isWireNewSet(call *ast.CallExpr) bool {
	sel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok {
		return false
	}
	pkg, ok := sel.X.(*ast.Ident)
	return ok && pkg.Name == "wire" && sel.Sel.Name == "NewSet"
}

// providerSetCalls returns every wire.NewSet(...) call assigned to a variable
// named ProviderSet in any var declaration of file.
func providerSetCalls(file *ast.File) []*ast.CallExpr {
	var calls []*ast.CallExpr
	ast.Inspect(file, func(node ast.Node) bool {
		genDecl, ok := node.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.VAR {
			return true
		}
		for _, spec := range genDecl.Specs {
			valueSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for i, name := range valueSpec.Names {
				if name.Name != "ProviderSet" || i >= len(valueSpec.Values) {
					continue
				}
				if call, ok := valueSpec.Values[i].(*ast.CallExpr); ok && isWireNewSet(call) {
					calls = append(calls, call)
				}
			}
		}
		return true
	})
	return calls
}

// isIdentNamed reports whether expr is the bare identifier name.
func isIdentNamed(expr ast.Expr, name string) bool {
	ident, ok := expr.(*ast.Ident)
	return ok && ident.Name == name
}

// isSelectorNamed reports whether expr is the qualified identifier pkg.name.
func isSelectorNamed(expr ast.Expr, pkg, name string) bool {
	sel, ok := expr.(*ast.SelectorExpr)
	if !ok {
		return false
	}
	x, ok := sel.X.(*ast.Ident)
	return ok && x.Name == pkg && sel.Sel.Name == name
}

// appendProviderArg appends a node produced by build to every ProviderSet
// wire.NewSet call that has no argument matching exists. build is called once
// per call so no AST node is shared between two parents. It fails when the
// file contains no ProviderSet = wire.NewSet(...) declaration.
func appendProviderArg(file *ast.File, exists func(ast.Expr) bool, build func() ast.Expr) error {
	calls := providerSetCalls(file)
	if len(calls) == 0 {
		return fmt.Errorf("未找到 ProviderSet 变量或 wire.NewSet 调用")
	}
	for _, call := range calls {
		present := false
		for _, arg := range call.Args {
			if exists(arg) {
				present = true // already registered, nothing to add
				break
			}
		}
		if !present {
			call.Args = append(call.Args, build())
		}
	}
	return nil
}

// removeProviderArgs drops every argument matching match from every
// ProviderSet wire.NewSet call. A missing ProviderSet is not an error.
func removeProviderArgs(file *ast.File, match func(ast.Expr) bool) {
	for _, call := range providerSetCalls(file) {
		kept := make([]ast.Expr, 0, len(call.Args))
		for _, arg := range call.Args {
			if !match(arg) {
				kept = append(kept, arg)
			}
		}
		call.Args = kept
	}
}

// addWireProvider adds the bare identifier providerName to wire.NewSet, unless
// it is already present. An empty providerName is rejected because it would
// produce invalid Go source.
func addWireProvider(file *ast.File, providerName string) error {
	if providerName == "" {
		return fmt.Errorf("Provider 名称为空")
	}
	return appendProviderArg(file,
		func(arg ast.Expr) bool { return isIdentNamed(arg, providerName) },
		func() ast.Expr { return ast.NewIdent(providerName) },
	)
}

// removeWireProvider removes the bare identifier providerName from wire.NewSet.
// It always returns nil (rollback is best-effort).
func removeWireProvider(file *ast.File, providerName string) error {
	removeProviderArgs(file, func(arg ast.Expr) bool { return isIdentNamed(arg, providerName) })
	return nil
}

// MainProviderSetAst injects a sub-package's Provider (e.g. article.NewArticleUsecase)
// and its import into a layer's main ProviderSet file, such as internal/biz/biz.go.
type MainProviderSetAst struct {
	Layer        string // layer name: biz, data, service
	PackageName  string // new package name
	ProviderName string // Provider name, e.g. NewArticleUsecase
	ImportPath   string // import path of the new package
	Path         string // target file path

	parsed parsedGoFile // result of the most recent Parse, reused by Format
}

// NewMainProviderSetAst creates a main ProviderSet AST injector.
func NewMainProviderSetAst(layer, packageName, providerName, importPath, path string) *MainProviderSetAst {
	return &MainProviderSetAst{
		Layer:        layer,
		PackageName:  packageName,
		ProviderName: providerName,
		ImportPath:   importPath,
		Path:         path,
	}
}

// Parse parses filename (or a.Path when filename is empty), keeping comments.
// writer is accepted but not used. The FileSet is remembered so Format can
// print the returned AST with its original positions.
func (a *MainProviderSetAst) Parse(filename string, writer io.Writer) (*ast.File, error) {
	if filename == "" {
		filename = a.Path
	}
	parsed, err := parseGoFile(filename)
	if err != nil {
		return nil, err
	}
	a.parsed = parsed
	return parsed.file, nil
}

// Rollback removes ImportPath and PackageName.ProviderName from file.
func (a *MainProviderSetAst) Rollback(file *ast.File) error {
	removeImport(file, a.ImportPath)
	return removeWireProviderWithPackage(file, a.PackageName, a.ProviderName)
}

// Injection adds PackageName.ProviderName to wire.NewSet and, only if that
// succeeded, the ImportPath import, so a failed injection leaves no stray import.
func (a *MainProviderSetAst) Injection(file *ast.File) error {
	if err := addWireProviderWithPackage(file, a.PackageName, a.ProviderName); err != nil {
		return err
	}
	addImport(file, a.ImportPath)
	return nil
}

// Format writes file to writer, or to filename (a.Path when empty) when writer
// is nil. The existing target file is truncated and overwritten.
func (a *MainProviderSetAst) Format(filename string, writer io.Writer, file *ast.File) error {
	if filename == "" {
		filename = a.Path
	}
	return writeGoFile(filename, writer, a.parsed.fileSetFor(file), file)
}

// importPathOf returns the unquoted path of imp, or "" if it cannot be unquoted.
func importPathOf(imp *ast.ImportSpec) string {
	if imp.Path == nil {
		return ""
	}
	path, err := strconv.Unquote(imp.Path.Value)
	if err != nil {
		return ""
	}
	return path
}

// addImport adds an unnamed import of importPath unless file already imports
// it. The spec goes into the first import declaration, or into a new one
// placed before all other declarations. file.Imports is kept in sync so that
// repeated calls on the same AST do not add duplicates.
func addImport(file *ast.File, importPath string) {
	if importPath == "" {
		return
	}
	for _, imp := range file.Imports {
		if importPathOf(imp) == importPath {
			return
		}
	}
	spec := &ast.ImportSpec{
		Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath)},
	}
	file.Imports = append(file.Imports, spec)
	for _, decl := range file.Decls {
		if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
			genDecl.Specs = append(genDecl.Specs, spec)
			return
		}
	}
	newDecl := &ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{spec}}
	file.Decls = append([]ast.Decl{newDecl}, file.Decls...)
}

// removeImport removes the first import of importPath from file. An import
// declaration left empty is removed entirely. file.Imports is kept in sync.
func removeImport(file *ast.File, importPath string) {
	if importPath == "" {
		return
	}
	for i, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		for j, spec := range genDecl.Specs {
			impSpec, ok := spec.(*ast.ImportSpec)
			if !ok || importPathOf(impSpec) != importPath {
				continue
			}
			genDecl.Specs = append(genDecl.Specs[:j], genDecl.Specs[j+1:]...)
			if len(genDecl.Specs) == 0 {
				file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
			}
			for k, imp := range file.Imports {
				if imp == impSpec {
					file.Imports = append(file.Imports[:k], file.Imports[k+1:]...)
					break
				}
			}
			return
		}
	}
}

// addWireProviderWithPackage adds the qualified Provider packageName.providerName
// (e.g. article.NewArticleUsecase) to wire.NewSet, unless it is already present.
// Empty names are rejected because they would produce invalid Go source.
func addWireProviderWithPackage(file *ast.File, packageName, providerName string) error {
	if packageName == "" {
		return fmt.Errorf("包名为空")
	}
	if providerName == "" {
		return fmt.Errorf("Provider 名称为空")
	}
	return appendProviderArg(file,
		func(arg ast.Expr) bool { return isSelectorNamed(arg, packageName, providerName) },
		func() ast.Expr {
			return &ast.SelectorExpr{X: ast.NewIdent(packageName), Sel: ast.NewIdent(providerName)}
		},
	)
}

// removeWireProviderWithPackage removes packageName.providerName from
// wire.NewSet. It always returns nil (rollback is best-effort).
func removeWireProviderWithPackage(file *ast.File, packageName, providerName string) error {
	removeProviderArgs(file, func(arg ast.Expr) bool {
		return isSelectorNamed(arg, packageName, providerName)
	})
	return nil
}

// GetLayerMainFile returns the main ProviderSet file of layer (a path relative
// to the working directory), or "" for an unknown layer.
func GetLayerMainFile(layer string) string {
	switch layer {
	case "biz":
		return "internal/biz/biz.go"
	case "data":
		return "internal/data/data.go"
	case "service":
		return "internal/service/service.go"
	case "handler", "api":
		return "internal/server/handler/enter.go"
	case "router":
		return "internal/server/router/enter.go"
	default:
		return ""
	}
}

// GetLayerImportPath returns the import path of packageName inside layer of
// module, or "" for an unknown layer.
func GetLayerImportPath(module, layer, packageName string) string {
	switch layer {
	case "biz":
		return fmt.Sprintf("%s/internal/biz/%s", module, packageName)
	case "data":
		return fmt.Sprintf("%s/internal/data/%s", module, packageName)
	case "service":
		return fmt.Sprintf("%s/internal/service/%s", module, packageName)
	case "handler", "api":
		return fmt.Sprintf("%s/internal/server/handler/%s", module, packageName)
	case "router":
		return fmt.Sprintf("%s/internal/server/router/%s", module, packageName)
	default:
		return ""
	}
}

// GetProviderName returns the conventional constructor name of structName in
// layer (e.g. NewArticleUsecase for biz), or "" for an unknown layer.
func GetProviderName(layer, structName string) string {
	switch layer {
	case "biz":
		return fmt.Sprintf("New%sUsecase", structName)
	case "data":
		return fmt.Sprintf("New%sRepo", structName)
	case "service":
		return fmt.Sprintf("New%sService", structName)
	case "handler", "api":
		return fmt.Sprintf("New%sHandler", structName)
	case "router":
		return fmt.Sprintf("New%sRouter", structName)
	default:
		return ""
	}
}

// newLayerInjector builds the MainProviderSetAst for layer, deriving the
// target file, import path and provider name from the layer conventions.
func newLayerInjector(module, layer, packageName, structName string) (*MainProviderSetAst, error) {
	mainFile := GetLayerMainFile(layer)
	if mainFile == "" {
		return nil, fmt.Errorf("未知的层: %s", layer)
	}
	importPath := GetLayerImportPath(module, layer, packageName)
	providerName := GetProviderName(layer, structName)
	return NewMainProviderSetAst(layer, packageName, providerName, importPath, mainFile), nil
}

// InjectWireProvider adds the import and Provider of packageName/structName to
// the main ProviderSet file of layer and rewrites that file in place.
func InjectWireProvider(module, layer, packageName, structName string) error {
	injector, err := newLayerInjector(module, layer, packageName, structName)
	if err != nil {
		return err
	}
	file, err := injector.Parse("", nil)
	if err != nil {
		return err
	}
	if err := injector.Injection(file); err != nil {
		return err
	}
	return injector.Format("", nil, file)
}

// RollbackWireProvider removes the import and Provider of packageName/structName
// from the main ProviderSet file of layer and rewrites that file in place.
func RollbackWireProvider(module, layer, packageName, structName string) error {
	injector, err := newLayerInjector(module, layer, packageName, structName)
	if err != nil {
		return err
	}
	file, err := injector.Parse("", nil)
	if err != nil {
		return err
	}
	if err := injector.Rollback(file); err != nil {
		return err
	}
	return injector.Format("", nil, file)
}

// WireInjectionConfig describes a Wire Provider injection across layers.
type WireInjectionConfig struct {
	Module      string   // module name (module line of go.mod)
	PackageName string   // package name
	StructName  string   // struct name
	Layers      []string // layers to inject into
}

// InjectAllWireProviders injects the Provider into every layer in order. On the
// first failure, the layers already processed are rolled back (best-effort) and
// the injection error is returned.
// NOTE(review): rollback also removes a Provider that was already present before
// injection, since injecting an existing Provider is a silent no-op.
func InjectAllWireProviders(config *WireInjectionConfig) error {
	if config == nil {
		return fmt.Errorf("Wire 注入配置为空")
	}
	for i, layer := range config.Layers {
		if err := InjectWireProvider(config.Module, layer, config.PackageName, config.StructName); err != nil {
			// Roll back by index, not by name, so repeated layer names are handled
			// correctly; rollback errors must not mask the injection error.
			for _, done := range config.Layers[:i] {
				_ = RollbackWireProvider(config.Module, done, config.PackageName, config.StructName)
			}
			return fmt.Errorf("注入 %s 层 Wire Provider 失败: %w", layer, err)
		}
	}
	return nil
}

// RollbackAllWireProviders rolls back every layer, continuing past failures,
// and returns the last error encountered (nil if all succeeded).
func RollbackAllWireProviders(config *WireInjectionConfig) error {
	if config == nil {
		return fmt.Errorf("Wire 注入配置为空")
	}
	var lastErr error
	for _, layer := range config.Layers {
		if err := RollbackWireProvider(config.Module, layer, config.PackageName, config.StructName); err != nil {
			lastErr = err
		}
	}
	return lastErr
}