kra/pkg/utils/autocode/injection.go

619 lines
15 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package autocode
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io"
	"os"
	"strconv"

	"github.com/pkg/errors"
)
// WireProviderAst injects a provider into the ProviderSet of a single layer
// file, i.e. a package-level `var ProviderSet = wire.NewSet(...)` declaration.
//
// Intended usage is Parse, then Injection or Rollback, then Format on the same
// value. Parse records the token.FileSet it parsed with and Format prints with
// that same set, so comments and line layout of the original file survive.
type WireProviderAst struct {
	Layer        string // layer name: biz, data, handler, service, router
	PackageName  string // package name
	ProviderName string // provider name, e.g. NewArticleUsecase
	Path         string // target file path

	// fset is the file set populated by the last successful Parse. AST
	// positions are only meaningful relative to the FileSet they were parsed
	// with, so Format must print with this same set.
	fset *token.FileSet
}

// NewWireProviderAst creates a WireProviderAst for the given layer, package,
// provider name and target file path.
func NewWireProviderAst(layer, packageName, providerName, path string) *WireProviderAst {
	return &WireProviderAst{
		Layer:        layer,
		PackageName:  packageName,
		ProviderName: providerName,
		Path:         path,
	}
}

// Parse parses filename (or a.Path when filename is empty) with comments
// retained and remembers the FileSet so Format can reuse it. The writer
// argument is currently unused.
func (a *WireProviderAst) Parse(filename string, writer io.Writer) (*ast.File, error) {
	if filename == "" {
		filename = a.Path
	}
	if filename == "" {
		return nil, fmt.Errorf("文件路径为空")
	}
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("解析文件失败 %s: %w", filename, err)
	}
	a.fset = fset
	return file, nil
}

// Rollback removes a.ProviderName from the ProviderSet declared in file.
func (a *WireProviderAst) Rollback(file *ast.File) error {
	return removeWireProvider(file, a.ProviderName)
}

// Injection adds a.ProviderName to the ProviderSet declared in file.
func (a *WireProviderAst) Injection(file *ast.File) error {
	return addWireProvider(file, a.ProviderName)
}

// Format prints file in gofmt style. If writer is non-nil the output goes to
// writer; otherwise it replaces the contents of filename (or a.Path when
// filename is empty), which must already exist. The output is rendered in
// memory first so that a formatting error never leaves the target truncated.
func (a *WireProviderAst) Format(filename string, writer io.Writer, file *ast.File) error {
	if filename == "" {
		filename = a.Path
	}
	fset := a.fset
	if fset == nil {
		// file did not come from Parse on this value; its positions cannot be
		// resolved, so fall back to an empty set as a last resort.
		fset = token.NewFileSet()
	}

	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return err
	}

	if writer != nil {
		_, err := buf.WriteTo(writer)
		return err
	}

	// Only now truncate the target: the formatted bytes are already in hand.
	f, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, 0666)
	if err != nil {
		return errors.Wrapf(err, "打开文件失败: %s", filename)
	}
	if _, err := f.Write(buf.Bytes()); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return errors.Wrapf(err, "写入文件失败: %s", filename)
	}
	// Close can surface delayed write errors, so it must be checked.
	if err := f.Close(); err != nil {
		return errors.Wrapf(err, "写入文件失败: %s", filename)
	}
	return nil
}
// wireProviderSetCalls returns the `wire.NewSet(...)` call expressions assigned
// to package-level variables named ProviderSet in file. Only top-level
// declarations are examined, so a local variable called ProviderSet inside a
// function body is never touched. A nil file yields no calls.
func wireProviderSetCalls(file *ast.File) []*ast.CallExpr {
	if file == nil {
		return nil
	}
	var calls []*ast.CallExpr
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.VAR {
			continue
		}
		for _, spec := range genDecl.Specs {
			valueSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for i, name := range valueSpec.Names {
				if name.Name != "ProviderSet" || i >= len(valueSpec.Values) {
					continue
				}
				call, ok := valueSpec.Values[i].(*ast.CallExpr)
				if !ok {
					continue
				}
				sel, ok := call.Fun.(*ast.SelectorExpr)
				if !ok || sel.Sel.Name != "NewSet" {
					continue
				}
				if pkg, ok := sel.X.(*ast.Ident); ok && pkg.Name == "wire" {
					calls = append(calls, call)
				}
			}
		}
	}
	return calls
}

// addWireProvider appends providerName as a bare identifier argument to the
// package-level `ProviderSet = wire.NewSet(...)` call in file. It is
// idempotent: an argument with the same identifier is not added twice.
//
// It returns an error when providerName is not a valid Go identifier (which
// would produce uncompilable code) or when no such ProviderSet exists.
func addWireProvider(file *ast.File, providerName string) error {
	if !token.IsIdentifier(providerName) {
		return fmt.Errorf("无效的 Provider 名称: %q", providerName)
	}
	calls := wireProviderSetCalls(file)
	if len(calls) == 0 {
		return fmt.Errorf("未找到 ProviderSet 变量或 wire.NewSet 调用")
	}
	for _, call := range calls {
		present := false
		for _, arg := range call.Args {
			if ident, ok := arg.(*ast.Ident); ok && ident.Name == providerName {
				present = true
				break
			}
		}
		if !present {
			call.Args = append(call.Args, ast.NewIdent(providerName))
		}
	}
	return nil
}

// removeWireProvider drops every bare identifier argument named providerName
// from the package-level `ProviderSet = wire.NewSet(...)` call in file.
// It is best-effort for rollbacks: a missing ProviderSet is not an error and
// it always returns nil.
func removeWireProvider(file *ast.File, providerName string) error {
	for _, call := range wireProviderSetCalls(file) {
		kept := call.Args[:0] // filter in place
		for _, arg := range call.Args {
			if ident, ok := arg.(*ast.Ident); ok && ident.Name == providerName {
				continue
			}
			kept = append(kept, arg)
		}
		call.Args = kept
	}
	return nil
}
// MainProviderSetAst injects a package-qualified provider (e.g.
// article.NewArticleUsecase) plus the matching import into a layer's main
// ProviderSet file, such as internal/biz/biz.go.
//
// Intended usage is Parse, then Injection or Rollback, then Format on the same
// value. Parse records the token.FileSet it parsed with and Format prints with
// that same set, so comments and line layout of the original file survive.
type MainProviderSetAst struct {
	Layer        string // layer name: biz, data, service
	PackageName  string // name of the newly added package
	ProviderName string // provider name, e.g. NewArticleUsecase
	ImportPath   string // import path of the new package
	Path         string // target file path

	// fset is the file set populated by the last successful Parse. AST
	// positions are only meaningful relative to the FileSet they were parsed
	// with, so Format must print with this same set.
	fset *token.FileSet
}

// NewMainProviderSetAst creates a MainProviderSetAst for the given layer,
// package, provider name, import path and target file path.
func NewMainProviderSetAst(layer, packageName, providerName, importPath, path string) *MainProviderSetAst {
	return &MainProviderSetAst{
		Layer:        layer,
		PackageName:  packageName,
		ProviderName: providerName,
		ImportPath:   importPath,
		Path:         path,
	}
}

// Parse parses filename (or a.Path when filename is empty) with comments
// retained and remembers the FileSet so Format can reuse it. The writer
// argument is currently unused.
func (a *MainProviderSetAst) Parse(filename string, writer io.Writer) (*ast.File, error) {
	if filename == "" {
		filename = a.Path
	}
	if filename == "" {
		return nil, fmt.Errorf("文件路径为空")
	}
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("解析文件失败 %s: %w", filename, err)
	}
	a.fset = fset
	return file, nil
}

// Rollback removes the import of a.ImportPath and the
// a.PackageName.a.ProviderName argument from the ProviderSet in file.
func (a *MainProviderSetAst) Rollback(file *ast.File) error {
	removeImport(file, a.ImportPath)
	return removeWireProviderWithPackage(file, a.PackageName, a.ProviderName)
}

// Injection adds the import of a.ImportPath and then appends
// a.PackageName.a.ProviderName to the ProviderSet in file.
func (a *MainProviderSetAst) Injection(file *ast.File) error {
	addImport(file, a.ImportPath)
	return addWireProviderWithPackage(file, a.PackageName, a.ProviderName)
}

// Format prints file in gofmt style. If writer is non-nil the output goes to
// writer; otherwise it replaces the contents of filename (or a.Path when
// filename is empty), which must already exist. The output is rendered in
// memory first so that a formatting error never leaves the target truncated.
func (a *MainProviderSetAst) Format(filename string, writer io.Writer, file *ast.File) error {
	if filename == "" {
		filename = a.Path
	}
	fset := a.fset
	if fset == nil {
		// file did not come from Parse on this value; its positions cannot be
		// resolved, so fall back to an empty set as a last resort.
		fset = token.NewFileSet()
	}

	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return err
	}

	if writer != nil {
		_, err := buf.WriteTo(writer)
		return err
	}

	// Only now truncate the target: the formatted bytes are already in hand.
	f, err := os.OpenFile(filename, os.O_WRONLY|os.O_TRUNC, 0666)
	if err != nil {
		return errors.Wrapf(err, "打开文件失败: %s", filename)
	}
	if _, err := f.Write(buf.Bytes()); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return errors.Wrapf(err, "写入文件失败: %s", filename)
	}
	// Close can surface delayed write errors, so it must be checked.
	if err := f.Close(); err != nil {
		return errors.Wrapf(err, "写入文件失败: %s", filename)
	}
	return nil
}
// importSpecPath returns the unquoted path of an import spec, accepting both
// interpreted ("x/y") and raw (`x/y`) string literals. It returns "" when the
// literal is missing or malformed.
func importSpecPath(spec *ast.ImportSpec) string {
	if spec == nil || spec.Path == nil {
		return ""
	}
	path, err := strconv.Unquote(spec.Path.Value)
	if err != nil {
		return ""
	}
	return path
}

// addImport adds an import of importPath to file unless some import spec with
// that path (aliased or not) already exists. The spec is appended to the first
// import declaration; if there is none, a new import declaration is placed
// before all other declarations. file.Imports is kept in sync so repeated
// calls stay idempotent. An empty importPath or nil file is a no-op.
func addImport(file *ast.File, importPath string) {
	if file == nil || importPath == "" {
		return
	}

	var firstImportDecl *ast.GenDecl
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		if firstImportDecl == nil {
			firstImportDecl = genDecl
		}
		for _, spec := range genDecl.Specs {
			if imp, ok := spec.(*ast.ImportSpec); ok && importSpecPath(imp) == importPath {
				return // already imported
			}
		}
	}

	newSpec := &ast.ImportSpec{
		Path: &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath)},
	}
	file.Imports = append(file.Imports, newSpec)

	if firstImportDecl != nil {
		firstImportDecl.Specs = append(firstImportDecl.Specs, newSpec)
		return
	}
	newDecl := &ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{newSpec}}
	file.Decls = append([]ast.Decl{newDecl}, file.Decls...)
}

// removeImport removes the first import spec whose path equals importPath,
// from both its declaration and file.Imports. An import declaration left with
// no specs is removed entirely. An empty importPath or nil file is a no-op.
func removeImport(file *ast.File, importPath string) {
	if file == nil || importPath == "" {
		return
	}
	for i, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		for j, spec := range genDecl.Specs {
			imp, ok := spec.(*ast.ImportSpec)
			if !ok || importSpecPath(imp) != importPath {
				continue
			}
			genDecl.Specs = append(genDecl.Specs[:j], genDecl.Specs[j+1:]...)
			if len(genDecl.Specs) == 0 {
				file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
			}
			for k, known := range file.Imports {
				if known == imp {
					file.Imports = append(file.Imports[:k], file.Imports[k+1:]...)
					break
				}
			}
			return // slices were mutated; stop iterating
		}
	}
}
// packageWireProviderSetCalls returns the `wire.NewSet(...)` call expressions
// assigned to package-level variables named ProviderSet in file. Only
// top-level declarations are examined, so a local variable called ProviderSet
// inside a function body is never touched. A nil file yields no calls.
func packageWireProviderSetCalls(file *ast.File) []*ast.CallExpr {
	if file == nil {
		return nil
	}
	var calls []*ast.CallExpr
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.VAR {
			continue
		}
		for _, spec := range genDecl.Specs {
			valueSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for i, name := range valueSpec.Names {
				if name.Name != "ProviderSet" || i >= len(valueSpec.Values) {
					continue
				}
				call, ok := valueSpec.Values[i].(*ast.CallExpr)
				if !ok {
					continue
				}
				sel, ok := call.Fun.(*ast.SelectorExpr)
				if !ok || sel.Sel.Name != "NewSet" {
					continue
				}
				if pkg, ok := sel.X.(*ast.Ident); ok && pkg.Name == "wire" {
					calls = append(calls, call)
				}
			}
		}
	}
	return calls
}

// isQualifiedProvider reports whether expr is the selector
// packageName.providerName.
func isQualifiedProvider(expr ast.Expr, packageName, providerName string) bool {
	sel, ok := expr.(*ast.SelectorExpr)
	if !ok || sel.Sel.Name != providerName {
		return false
	}
	pkg, ok := sel.X.(*ast.Ident)
	return ok && pkg.Name == packageName
}

// addWireProviderWithPackage appends packageName.providerName (for example
// article.NewArticleUsecase) to the package-level
// `ProviderSet = wire.NewSet(...)` call in file. It is idempotent: an existing
// identical selector argument is not added twice.
//
// It returns an error when either name is not a valid Go identifier (which
// would produce uncompilable code) or when no such ProviderSet exists.
func addWireProviderWithPackage(file *ast.File, packageName, providerName string) error {
	if !token.IsIdentifier(packageName) || !token.IsIdentifier(providerName) {
		return fmt.Errorf("无效的包名或 Provider 名称: %q.%q", packageName, providerName)
	}
	calls := packageWireProviderSetCalls(file)
	if len(calls) == 0 {
		return fmt.Errorf("未找到 ProviderSet 变量或 wire.NewSet 调用")
	}
	for _, call := range calls {
		present := false
		for _, arg := range call.Args {
			if isQualifiedProvider(arg, packageName, providerName) {
				present = true
				break
			}
		}
		if !present {
			call.Args = append(call.Args, &ast.SelectorExpr{
				X:   ast.NewIdent(packageName),
				Sel: ast.NewIdent(providerName),
			})
		}
	}
	return nil
}

// removeWireProviderWithPackage drops every packageName.providerName selector
// argument from the package-level `ProviderSet = wire.NewSet(...)` call in
// file. It is best-effort for rollbacks: a missing ProviderSet is not an
// error and it always returns nil.
func removeWireProviderWithPackage(file *ast.File, packageName, providerName string) error {
	for _, call := range packageWireProviderSetCalls(file) {
		kept := call.Args[:0] // filter in place
		for _, arg := range call.Args {
			if isQualifiedProvider(arg, packageName, providerName) {
				continue
			}
			kept = append(kept, arg)
		}
		call.Args = kept
	}
	return nil
}
// GetLayerMainFile returns the relative path of the file that holds the
// aggregate ProviderSet for layer. "api" is an alias for "handler"; any
// unrecognised layer yields the empty string.
func GetLayerMainFile(layer string) string {
	mainFiles := map[string]string{
		"biz":     "internal/biz/biz.go",
		"data":    "internal/data/data.go",
		"service": "internal/service/service.go",
		"handler": "internal/server/handler/enter.go",
		"api":     "internal/server/handler/enter.go",
		"router":  "internal/server/router/enter.go",
	}
	// A missing key yields the zero value "", which signals an unknown layer.
	return mainFiles[layer]
}
// GetLayerImportPath builds the import path of packageName inside layer for the
// given module, e.g. ("example.com/app", "biz", "article") ->
// "example.com/app/internal/biz/article". "api" is an alias for "handler";
// an unrecognised layer yields the empty string.
func GetLayerImportPath(module, layer, packageName string) string {
	var layerDir string
	switch layer {
	case "biz", "data", "service":
		// These layers live directly under internal/ in a directory named after the layer.
		layerDir = "internal/" + layer
	case "handler", "api":
		layerDir = "internal/server/handler"
	case "router":
		layerDir = "internal/server/router"
	default:
		return ""
	}
	return module + "/" + layerDir + "/" + packageName
}
// GetProviderName returns the conventional constructor name for structName in
// layer: "New" + structName + a layer-specific suffix (Usecase, Repo, Service,
// Handler or Router). "api" is an alias for "handler"; an unrecognised layer
// yields the empty string.
func GetProviderName(layer, structName string) string {
	var suffix string
	switch layer {
	case "biz":
		suffix = "Usecase"
	case "data":
		suffix = "Repo"
	case "service":
		suffix = "Service"
	case "handler", "api":
		suffix = "Handler"
	case "router":
		suffix = "Router"
	default:
		return ""
	}
	return "New" + structName + suffix
}
// InjectWireProvider 注入 Wire Provider 到主文件
// 这是一个便捷函数,用于一次性完成导入和 Provider 注入
func InjectWireProvider(module, layer, packageName, structName string) error {
mainFile := GetLayerMainFile(layer)
if mainFile == "" {
return fmt.Errorf("未知的层: %s", layer)
}
importPath := GetLayerImportPath(module, layer, packageName)
providerName := GetProviderName(layer, structName)
injector := NewMainProviderSetAst(layer, packageName, providerName, importPath, mainFile)
file, err := injector.Parse("", nil)
if err != nil {
return err
}
if err := injector.Injection(file); err != nil {
return err
}
return injector.Format("", nil, file)
}
// RollbackWireProvider 回滚 Wire Provider 注入
func RollbackWireProvider(module, layer, packageName, structName string) error {
mainFile := GetLayerMainFile(layer)
if mainFile == "" {
return fmt.Errorf("未知的层: %s", layer)
}
importPath := GetLayerImportPath(module, layer, packageName)
providerName := GetProviderName(layer, structName)
injector := NewMainProviderSetAst(layer, packageName, providerName, importPath, mainFile)
file, err := injector.Parse("", nil)
if err != nil {
return err
}
if err := injector.Rollback(file); err != nil {
return err
}
return injector.Format("", nil, file)
}
// WireInjectionConfig describes one Wire provider injection applied to
// several layers.
type WireInjectionConfig struct {
	Module      string   // module name (the `module` line of go.mod)
	PackageName string   // package name
	StructName  string   // struct name
	Layers      []string // layers to inject into
}

// InjectAllWireProviders injects the provider into every layer in
// config.Layers, in order. If a layer fails, the layers injected before it are
// rolled back, newest first, and the injection error is returned wrapped.
// Rollback is best-effort: layers whose rollback also fails are listed in the
// returned error rather than silently ignored. A nil config is an error.
func InjectAllWireProviders(config *WireInjectionConfig) error {
	if config == nil {
		return fmt.Errorf("wire 注入配置为空")
	}
	for idx, layer := range config.Layers {
		err := InjectWireProvider(config.Module, layer, config.PackageName, config.StructName)
		if err == nil {
			continue
		}
		// Walk back by index rather than comparing layer names: name comparison
		// stops early when config.Layers contains duplicates.
		var rollbackFailed []string
		for j := idx - 1; j >= 0; j-- {
			prev := config.Layers[j]
			if rbErr := RollbackWireProvider(config.Module, prev, config.PackageName, config.StructName); rbErr != nil {
				rollbackFailed = append(rollbackFailed, prev)
			}
		}
		if len(rollbackFailed) > 0 {
			return fmt.Errorf("注入 %s 层 Wire Provider 失败: %w (回滚失败的层: %v)", layer, err, rollbackFailed)
		}
		return fmt.Errorf("注入 %s 层 Wire Provider 失败: %w", layer, err)
	}
	return nil
}

// RollbackAllWireProviders rolls back the provider in every layer of
// config.Layers. It attempts every layer even when some fail, and returns the
// error of the last failing layer (nil if all succeed). A nil config is an
// error.
func RollbackAllWireProviders(config *WireInjectionConfig) error {
	if config == nil {
		return fmt.Errorf("wire 注入配置为空")
	}
	var lastErr error
	for _, layer := range config.Layers {
		if err := RollbackWireProvider(config.Module, layer, config.PackageName, config.StructName); err != nil {
			lastErr = err
		}
	}
	return lastErr
}