619 lines
15 KiB
Go
619 lines
15 KiB
Go
package autocode
|
||
|
||
import (
|
||
"fmt"
|
||
"go/ast"
|
||
"go/format"
|
||
"go/parser"
|
||
"go/token"
|
||
"io"
|
||
"os"
|
||
|
||
"github.com/pkg/errors"
|
||
)
|
||
|
||
// WireProviderAst Wire Provider AST 注入结构体
|
||
// 用于向各层的 ProviderSet 中注入新的 Provider
|
||
type WireProviderAst struct {
|
||
Layer string // 层名称: biz, data, handler, service, router
|
||
PackageName string // 包名
|
||
ProviderName string // Provider 名称,如 NewArticleUsecase
|
||
Path string // 目标文件路径
|
||
}
|
||
|
||
// NewWireProviderAst 创建 Wire Provider AST 注入器
|
||
func NewWireProviderAst(layer, packageName, providerName, path string) *WireProviderAst {
|
||
return &WireProviderAst{
|
||
Layer: layer,
|
||
PackageName: packageName,
|
||
ProviderName: providerName,
|
||
Path: path,
|
||
}
|
||
}
|
||
|
||
// Parse 解析文件
|
||
func (a *WireProviderAst) Parse(filename string, writer io.Writer) (*ast.File, error) {
|
||
if filename == "" {
|
||
filename = a.Path
|
||
}
|
||
if filename == "" {
|
||
return nil, fmt.Errorf("文件路径为空")
|
||
}
|
||
fset := token.NewFileSet()
|
||
file, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("解析文件失败 %s: %w", filename, err)
|
||
}
|
||
return file, nil
|
||
}
|
||
|
||
// Rollback 回滚注入
|
||
func (a *WireProviderAst) Rollback(file *ast.File) error {
|
||
return removeWireProvider(file, a.ProviderName)
|
||
}
|
||
|
||
// Injection 注入代码
|
||
func (a *WireProviderAst) Injection(file *ast.File) error {
|
||
return addWireProvider(file, a.ProviderName)
|
||
}
|
||
|
||
// Format 格式化输出
|
||
func (a *WireProviderAst) Format(filename string, writer io.Writer, file *ast.File) error {
|
||
if filename == "" {
|
||
filename = a.Path
|
||
}
|
||
fset := token.NewFileSet()
|
||
if writer == nil {
|
||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, 0666)
|
||
if err != nil {
|
||
return errors.Wrapf(err, "打开文件失败: %s", filename)
|
||
}
|
||
defer f.Close()
|
||
writer = f
|
||
}
|
||
return format.Node(writer, fset, file)
|
||
}
|
||
|
||
// addWireProvider appends providerName, as a plain identifier, to the
// arguments of the package-level `var ProviderSet = wire.NewSet(...)`
// declaration in file.
//
// If an identifier with that name is already an argument, the file is left
// unchanged and nil is returned. Only top-level declarations are searched, so
// a function-local variable that happens to be called ProviderSet is never
// modified.
//
// It returns an error when file is nil or has no package-level ProviderSet
// initialised with a wire.NewSet call.
func addWireProvider(file *ast.File, providerName string) error {
	if file != nil {
		for _, decl := range file.Decls {
			genDecl, ok := decl.(*ast.GenDecl)
			if !ok || genDecl.Tok != token.VAR {
				continue
			}
			for _, spec := range genDecl.Specs {
				valueSpec, ok := spec.(*ast.ValueSpec)
				if !ok {
					continue
				}
				for i, name := range valueSpec.Names {
					// Names without a matching value (e.g. `var ProviderSet T`)
					// have nothing to extend.
					if name.Name != "ProviderSet" || i >= len(valueSpec.Values) {
						continue
					}
					callExpr, ok := valueSpec.Values[i].(*ast.CallExpr)
					if !ok {
						continue
					}
					selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
					if !ok || selExpr.Sel.Name != "NewSet" {
						continue
					}
					if pkg, ok := selExpr.X.(*ast.Ident); !ok || pkg.Name != "wire" {
						continue
					}

					for _, arg := range callExpr.Args {
						if ident, ok := arg.(*ast.Ident); ok && ident.Name == providerName {
							return nil // already registered
						}
					}
					callExpr.Args = append(callExpr.Args, ast.NewIdent(providerName))
					return nil
				}
			}
		}
	}
	return fmt.Errorf("未找到 ProviderSet 变量或 wire.NewSet 调用")
}
|
||
|
||
// removeWireProvider deletes every plain-identifier argument named
// providerName from the package-level `var ProviderSet = wire.NewSet(...)`
// declaration in file. Qualified arguments such as pkg.NewX are kept.
//
// Only top-level declarations are searched, so function-local variables named
// ProviderSet are left untouched. The removal is best-effort: nil is returned
// even when file is nil or contains no matching ProviderSet.
func removeWireProvider(file *ast.File, providerName string) error {
	if file == nil {
		return nil
	}
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.VAR {
			continue
		}
		for _, spec := range genDecl.Specs {
			valueSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for i, name := range valueSpec.Names {
				if name.Name != "ProviderSet" || i >= len(valueSpec.Values) {
					continue
				}
				callExpr, ok := valueSpec.Values[i].(*ast.CallExpr)
				if !ok {
					continue
				}
				selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
				if !ok || selExpr.Sel.Name != "NewSet" {
					continue
				}
				if pkg, ok := selExpr.X.(*ast.Ident); !ok || pkg.Name != "wire" {
					continue
				}

				kept := make([]ast.Expr, 0, len(callExpr.Args))
				for _, arg := range callExpr.Args {
					if ident, ok := arg.(*ast.Ident); ok && ident.Name == providerName {
						continue // drop the provider being rolled back
					}
					kept = append(kept, arg)
				}
				callExpr.Args = kept
				return nil
			}
		}
	}
	return nil
}
|
||
|
||
// MainProviderSetAst 主 ProviderSet AST 注入结构体
|
||
// 用于向主层的 ProviderSet(如 internal/biz/biz.go)中注入新包的 Provider
|
||
type MainProviderSetAst struct {
|
||
Layer string // 层名称: biz, data, service
|
||
PackageName string // 新包名
|
||
ProviderName string // Provider 名称,如 NewArticleUsecase
|
||
ImportPath string // 导入路径
|
||
Path string // 目标文件路径
|
||
}
|
||
|
||
// NewMainProviderSetAst 创建主 ProviderSet AST 注入器
|
||
func NewMainProviderSetAst(layer, packageName, providerName, importPath, path string) *MainProviderSetAst {
|
||
return &MainProviderSetAst{
|
||
Layer: layer,
|
||
PackageName: packageName,
|
||
ProviderName: providerName,
|
||
ImportPath: importPath,
|
||
Path: path,
|
||
}
|
||
}
|
||
|
||
// Parse 解析文件
|
||
func (a *MainProviderSetAst) Parse(filename string, writer io.Writer) (*ast.File, error) {
|
||
if filename == "" {
|
||
filename = a.Path
|
||
}
|
||
if filename == "" {
|
||
return nil, fmt.Errorf("文件路径为空")
|
||
}
|
||
fset := token.NewFileSet()
|
||
file, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("解析文件失败 %s: %w", filename, err)
|
||
}
|
||
return file, nil
|
||
}
|
||
|
||
// Rollback 回滚注入
|
||
func (a *MainProviderSetAst) Rollback(file *ast.File) error {
|
||
// 移除导入
|
||
removeImport(file, a.ImportPath)
|
||
// 移除 Provider
|
||
return removeWireProviderWithPackage(file, a.PackageName, a.ProviderName)
|
||
}
|
||
|
||
// Injection 注入代码
|
||
func (a *MainProviderSetAst) Injection(file *ast.File) error {
|
||
// 添加导入
|
||
addImport(file, a.ImportPath)
|
||
// 添加 Provider
|
||
return addWireProviderWithPackage(file, a.PackageName, a.ProviderName)
|
||
}
|
||
|
||
// Format 格式化输出
|
||
func (a *MainProviderSetAst) Format(filename string, writer io.Writer, file *ast.File) error {
|
||
if filename == "" {
|
||
filename = a.Path
|
||
}
|
||
fset := token.NewFileSet()
|
||
if writer == nil {
|
||
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, 0666)
|
||
if err != nil {
|
||
return errors.Wrapf(err, "打开文件失败: %s", filename)
|
||
}
|
||
defer f.Close()
|
||
writer = f
|
||
}
|
||
return format.Node(writer, fset, file)
|
||
}
|
||
|
||
// addImport adds importPath to file's imports unless it is already imported
// (with or without a name, in double-quoted or back-quoted form).
//
// The new spec is appended to the first import declaration; when the file has
// none, a new import declaration is placed before all other declarations.
// file.Imports is updated as well, so calling addImport repeatedly on the same
// AST does not add duplicates. An empty importPath or a nil file is ignored.
func addImport(file *ast.File, importPath string) {
	if file == nil || importPath == "" {
		return
	}

	// pathOf strips the delimiters ("..." or `...`) from an import path
	// literal. Import paths in practice contain no escape sequences, so no
	// further unquoting is needed.
	pathOf := func(spec *ast.ImportSpec) string {
		if spec.Path == nil {
			return ""
		}
		v := spec.Path.Value
		if len(v) >= 2 {
			return v[1 : len(v)-1]
		}
		return v
	}

	for _, imp := range file.Imports {
		if pathOf(imp) == importPath {
			return
		}
	}

	spec := &ast.ImportSpec{
		Path: &ast.BasicLit{
			Kind:  token.STRING,
			Value: fmt.Sprintf("\"%s\"", importPath),
		},
	}
	// Keep file.Imports in sync with the declarations so later lookups see it.
	file.Imports = append(file.Imports, spec)

	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if ok && genDecl.Tok == token.IMPORT {
			genDecl.Specs = append(genDecl.Specs, spec)
			return
		}
	}

	// No import declaration yet: imports must precede all other declarations.
	newDecl := &ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{spec}}
	file.Decls = append([]ast.Decl{newDecl}, file.Decls...)
}
|
||
|
||
// removeImport deletes the first import spec whose path equals importPath
// (double-quoted or back-quoted form). If that leaves its import declaration
// empty, the declaration is removed as well. file.Imports is updated to match.
// An empty importPath, a nil file, or a path that is not imported is a no-op.
func removeImport(file *ast.File, importPath string) {
	if file == nil || importPath == "" {
		return
	}

	// pathOf strips the delimiters ("..." or `...`) from an import path
	// literal. Import paths in practice contain no escape sequences.
	pathOf := func(spec *ast.ImportSpec) string {
		if spec.Path == nil {
			return ""
		}
		v := spec.Path.Value
		if len(v) >= 2 {
			return v[1 : len(v)-1]
		}
		return v
	}

	for i, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		for j, spec := range genDecl.Specs {
			imp, ok := spec.(*ast.ImportSpec)
			if !ok || pathOf(imp) != importPath {
				continue
			}
			// Slices are mutated in place, which is safe because we return
			// right after.
			genDecl.Specs = append(genDecl.Specs[:j], genDecl.Specs[j+1:]...)
			if len(genDecl.Specs) == 0 {
				file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
			}
			for k, known := range file.Imports {
				if known == imp {
					file.Imports = append(file.Imports[:k], file.Imports[k+1:]...)
					break
				}
			}
			return
		}
	}
}
|
||
|
||
// addWireProviderWithPackage appends the qualified provider
// packageName.providerName (e.g. article.NewArticleUsecase) to the arguments
// of the package-level `var ProviderSet = wire.NewSet(...)` declaration.
//
// If an identical selector argument already exists, the file is left
// unchanged and nil is returned. Only top-level declarations are searched, so
// function-local variables named ProviderSet are never modified.
//
// It returns an error when file is nil or has no package-level ProviderSet
// initialised with a wire.NewSet call.
func addWireProviderWithPackage(file *ast.File, packageName, providerName string) error {
	if file != nil {
		for _, decl := range file.Decls {
			genDecl, ok := decl.(*ast.GenDecl)
			if !ok || genDecl.Tok != token.VAR {
				continue
			}
			for _, spec := range genDecl.Specs {
				valueSpec, ok := spec.(*ast.ValueSpec)
				if !ok {
					continue
				}
				for i, name := range valueSpec.Names {
					if name.Name != "ProviderSet" || i >= len(valueSpec.Values) {
						continue
					}
					callExpr, ok := valueSpec.Values[i].(*ast.CallExpr)
					if !ok {
						continue
					}
					selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
					if !ok || selExpr.Sel.Name != "NewSet" {
						continue
					}
					if pkg, ok := selExpr.X.(*ast.Ident); !ok || pkg.Name != "wire" {
						continue
					}

					// Compare qualifier and name separately instead of a joined
					// "pkg.Name" string.
					for _, arg := range callExpr.Args {
						sel, ok := arg.(*ast.SelectorExpr)
						if !ok {
							continue
						}
						if pkg, ok := sel.X.(*ast.Ident); ok && pkg.Name == packageName && sel.Sel.Name == providerName {
							return nil // already registered
						}
					}

					callExpr.Args = append(callExpr.Args, &ast.SelectorExpr{
						X:   ast.NewIdent(packageName),
						Sel: ast.NewIdent(providerName),
					})
					return nil
				}
			}
		}
	}
	return fmt.Errorf("未找到 ProviderSet 变量或 wire.NewSet 调用")
}
|
||
|
||
// removeWireProviderWithPackage deletes every packageName.providerName
// selector argument from the package-level `var ProviderSet = wire.NewSet(...)`
// declaration. Plain identifiers and selectors with another qualifier are kept.
//
// Only top-level declarations are searched. The removal is best-effort: nil is
// returned even when file is nil or contains no matching ProviderSet.
func removeWireProviderWithPackage(file *ast.File, packageName, providerName string) error {
	if file == nil {
		return nil
	}
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.VAR {
			continue
		}
		for _, spec := range genDecl.Specs {
			valueSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for i, name := range valueSpec.Names {
				if name.Name != "ProviderSet" || i >= len(valueSpec.Values) {
					continue
				}
				callExpr, ok := valueSpec.Values[i].(*ast.CallExpr)
				if !ok {
					continue
				}
				selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
				if !ok || selExpr.Sel.Name != "NewSet" {
					continue
				}
				if pkg, ok := selExpr.X.(*ast.Ident); !ok || pkg.Name != "wire" {
					continue
				}

				kept := make([]ast.Expr, 0, len(callExpr.Args))
				for _, arg := range callExpr.Args {
					if sel, ok := arg.(*ast.SelectorExpr); ok {
						if pkg, ok := sel.X.(*ast.Ident); ok && pkg.Name == packageName && sel.Sel.Name == providerName {
							continue // drop the provider being rolled back
						}
					}
					kept = append(kept, arg)
				}
				callExpr.Args = kept
				return nil
			}
		}
	}
	return nil
}
|
||
|
||
// GetLayerMainFile returns the path, relative to the project root, of the file
// that holds the ProviderSet for layer. "api" is an alias for "handler".
// Unknown layers yield "".
func GetLayerMainFile(layer string) string {
	mainFiles := map[string]string{
		"biz":     "internal/biz/biz.go",
		"data":    "internal/data/data.go",
		"service": "internal/service/service.go",
		"handler": "internal/server/handler/enter.go",
		"api":     "internal/server/handler/enter.go",
		"router":  "internal/server/router/enter.go",
	}
	// A missing key yields the zero value "", which callers treat as "unknown".
	return mainFiles[layer]
}
|
||
|
||
// GetLayerImportPath returns the import path of the sub-package packageName
// inside layer for the Go module module. "api" is an alias for "handler".
// Unknown layers yield "".
func GetLayerImportPath(module, layer, packageName string) string {
	patterns := map[string]string{
		"biz":     "%s/internal/biz/%s",
		"data":    "%s/internal/data/%s",
		"service": "%s/internal/service/%s",
		"handler": "%s/internal/server/handler/%s",
		"api":     "%s/internal/server/handler/%s",
		"router":  "%s/internal/server/router/%s",
	}
	pattern, known := patterns[layer]
	if !known {
		return ""
	}
	return fmt.Sprintf(pattern, module, packageName)
}
|
||
|
||
// GetProviderName returns the constructor name generated for structName in
// layer, for example NewArticleUsecase for ("biz", "Article"). "api" is an
// alias for "handler". Unknown layers yield "".
func GetProviderName(layer, structName string) string {
	patterns := map[string]string{
		"biz":     "New%sUsecase",
		"data":    "New%sRepo",
		"service": "New%sService",
		"handler": "New%sHandler",
		"api":     "New%sHandler",
		"router":  "New%sRouter",
	}
	if pattern, known := patterns[layer]; known {
		return fmt.Sprintf(pattern, structName)
	}
	return ""
}
|
||
|
||
// InjectWireProvider 注入 Wire Provider 到主文件
|
||
// 这是一个便捷函数,用于一次性完成导入和 Provider 注入
|
||
func InjectWireProvider(module, layer, packageName, structName string) error {
|
||
mainFile := GetLayerMainFile(layer)
|
||
if mainFile == "" {
|
||
return fmt.Errorf("未知的层: %s", layer)
|
||
}
|
||
|
||
importPath := GetLayerImportPath(module, layer, packageName)
|
||
providerName := GetProviderName(layer, structName)
|
||
|
||
injector := NewMainProviderSetAst(layer, packageName, providerName, importPath, mainFile)
|
||
|
||
file, err := injector.Parse("", nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := injector.Injection(file); err != nil {
|
||
return err
|
||
}
|
||
|
||
return injector.Format("", nil, file)
|
||
}
|
||
|
||
// RollbackWireProvider 回滚 Wire Provider 注入
|
||
func RollbackWireProvider(module, layer, packageName, structName string) error {
|
||
mainFile := GetLayerMainFile(layer)
|
||
if mainFile == "" {
|
||
return fmt.Errorf("未知的层: %s", layer)
|
||
}
|
||
|
||
importPath := GetLayerImportPath(module, layer, packageName)
|
||
providerName := GetProviderName(layer, structName)
|
||
|
||
injector := NewMainProviderSetAst(layer, packageName, providerName, importPath, mainFile)
|
||
|
||
file, err := injector.Parse("", nil)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := injector.Rollback(file); err != nil {
|
||
return err
|
||
}
|
||
|
||
return injector.Format("", nil, file)
|
||
}
|
||
|
||
// WireInjectionConfig describes one provider to inject into, or roll back
// from, the ProviderSets of several layers.
type WireInjectionConfig struct {
	Module      string   // module path declared in go.mod
	PackageName string   // name of the generated sub-package
	StructName  string   // struct name used to derive provider names
	Layers      []string // layers to process, in order
}
|
||
|
||
// InjectAllWireProviders 注入所有层的 Wire Provider
|
||
func InjectAllWireProviders(config *WireInjectionConfig) error {
|
||
for _, layer := range config.Layers {
|
||
if err := InjectWireProvider(config.Module, layer, config.PackageName, config.StructName); err != nil {
|
||
// 如果注入失败,尝试回滚已注入的
|
||
for _, rollbackLayer := range config.Layers {
|
||
if rollbackLayer == layer {
|
||
break
|
||
}
|
||
_ = RollbackWireProvider(config.Module, rollbackLayer, config.PackageName, config.StructName)
|
||
}
|
||
return fmt.Errorf("注入 %s 层 Wire Provider 失败: %w", layer, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// RollbackAllWireProviders 回滚所有层的 Wire Provider
|
||
func RollbackAllWireProviders(config *WireInjectionConfig) error {
|
||
var lastErr error
|
||
for _, layer := range config.Layers {
|
||
if err := RollbackWireProvider(config.Module, layer, config.PackageName, config.StructName); err != nil {
|
||
lastErr = err
|
||
}
|
||
}
|
||
return lastErr
|
||
}
|