kra/pkg/utils/ast/ast.go

424 lines
12 KiB
Go

package ast
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"log"
)
// AddImport adds imp (an unquoted import path such as "fmt") to the import
// declarations found in astNode, unless one of them already imports it.
//
// The new spec is appended to the first import declaration. When astNode is
// an *ast.File, file.Imports is updated as well, and if the file has no import
// declaration at all a new one is inserted as its first declaration. For any
// other node type without an import declaration nothing is changed.
func AddImport(astNode ast.Node, imp string) {
	quoted := fmt.Sprintf("\"%s\"", imp)
	raw := "`" + imp + "`"

	// Gather every import declaration first so the duplicate check covers all
	// of them, not only the one the new spec ends up in.
	var importDecls []*ast.GenDecl
	ast.Inspect(astNode, func(node ast.Node) bool {
		switch decl := node.(type) {
		case *ast.GenDecl:
			if decl.Tok == token.IMPORT {
				importDecls = append(importDecls, decl)
			}
			return false // specs never contain further import declarations
		case *ast.FuncDecl:
			return false // imports cannot appear inside function bodies
		}
		return true
	})

	for _, decl := range importDecls {
		for _, spec := range decl.Specs {
			is, ok := spec.(*ast.ImportSpec)
			if ok && is.Path != nil && (is.Path.Value == quoted || is.Path.Value == raw) {
				return
			}
		}
	}

	newSpec := &ast.ImportSpec{
		Path: &ast.BasicLit{Kind: token.STRING, Value: quoted},
	}
	file, isFile := astNode.(*ast.File)
	switch {
	case len(importDecls) > 0:
		importDecls[0].Specs = append(importDecls[0].Specs, newSpec)
	case isFile:
		// No import declaration yet: create one ahead of all other declarations.
		decl := &ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{newSpec}}
		file.Decls = append([]ast.Decl{decl}, file.Decls...)
	default:
		return
	}
	if isFile {
		// file.Imports is not derived from Decls automatically; keep it in sync.
		file.Imports = append(file.Imports, newSpec)
	}
}
// FindFunction returns the first function or method declaration (in source
// order) named functionName within astNode, or nil if there is none. Methods
// are matched by name only; the receiver type is not considered.
func FindFunction(astNode ast.Node, functionName string) *ast.FuncDecl {
	var found *ast.FuncDecl
	ast.Inspect(astNode, func(node ast.Node) bool {
		// Returning false only prunes the current subtree, so stop explicitly
		// once a match is recorded; otherwise a later declaration with the
		// same name would overwrite the first one.
		if found != nil {
			return false
		}
		funcDecl, ok := node.(*ast.FuncDecl)
		if !ok {
			return true
		}
		if funcDecl.Name != nil && funcDecl.Name.Name == functionName {
			found = funcDecl
		}
		// FuncDecls never nest (closures are FuncLit), so skip the body.
		return false
	})
	return found
}
// FindArray returns the first composite literal of the form
// []identName.selectorExprName{...} that appears on the right-hand side of an
// assignment statement within astNode, or nil if there is none.
func FindArray(astNode ast.Node, identName, selectorExprName string) *ast.CompositeLit {
	var found *ast.CompositeLit
	ast.Inspect(astNode, func(n ast.Node) bool {
		if found != nil {
			return false // keep the first match
		}
		assign, ok := n.(*ast.AssignStmt)
		if !ok {
			return true
		}
		for _, rhs := range assign.Rhs {
			lit, ok := rhs.(*ast.CompositeLit)
			if !ok {
				continue
			}
			arrayType, ok := lit.Type.(*ast.ArrayType)
			if !ok {
				continue
			}
			// Check each assertion before dereferencing: element types such as
			// int or *T are not selector expressions and would yield nil here.
			sel, ok := arrayType.Elt.(*ast.SelectorExpr)
			if !ok {
				continue
			}
			x, ok := sel.X.(*ast.Ident)
			if ok && x.Name == identName && sel.Sel.Name == selectorExprName {
				found = lit
				return false
			}
		}
		return true
	})
	return found
}
// CheckImport reports whether file imports importPath (unquoted, e.g. "fmt").
//
// It scans the import declarations in file.Decls rather than file.Imports.
// file.Imports is filled in by the parser but is not kept in sync when import
// specs are added to or removed from Decls directly.
func CheckImport(file *ast.File, importPath string) bool {
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		for _, spec := range genDecl.Specs {
			imp, ok := spec.(*ast.ImportSpec)
			if !ok || imp.Path == nil || len(imp.Path.Value) < 2 {
				continue
			}
			// Strip the surrounding quotes (or backquotes) of the path literal.
			if imp.Path.Value[1:len(imp.Path.Value)-1] == importPath {
				return true
			}
		}
	}
	return false
}
// ClearPosition resets source positions in the subtree rooted at astNode to
// token.NoPos.
//
// Nodes produced by a separate parse (for example parser.ParseExpr) carry
// positions from their own FileSet. If they are spliced into another file's
// AST with those positions intact, go/printer can produce odd line breaks.
// CallExpr.Ellipsis is deliberately left alone: a valid position there is what
// marks a variadic call such as f(xs...), so clearing it would change the
// printed code.
func ClearPosition(astNode ast.Node) {
	ast.Inspect(astNode, func(n ast.Node) bool {
		switch node := n.(type) {
		case *ast.Ident:
			node.NamePos = token.NoPos
		case *ast.BasicLit:
			node.ValuePos = token.NoPos
		case *ast.CallExpr:
			node.Lparen = token.NoPos
			node.Rparen = token.NoPos
		case *ast.SelectorExpr:
			node.Sel.NamePos = token.NoPos
		case *ast.BinaryExpr:
			node.OpPos = token.NoPos
		case *ast.UnaryExpr:
			node.OpPos = token.NoPos
		case *ast.StarExpr:
			node.Star = token.NoPos
		case *ast.CompositeLit:
			node.Lbrace = token.NoPos
			node.Rbrace = token.NoPos
		case *ast.KeyValueExpr:
			node.Colon = token.NoPos
		case *ast.ParenExpr:
			node.Lparen = token.NoPos
			node.Rparen = token.NoPos
		case *ast.IndexExpr:
			node.Lbrack = token.NoPos
			node.Rbrack = token.NoPos
		case *ast.SliceExpr:
			node.Lbrack = token.NoPos
			node.Rbrack = token.NoPos
		case *ast.TypeAssertExpr:
			node.Lparen = token.NoPos
			node.Rparen = token.NoPos
		case *ast.ArrayType:
			node.Lbrack = token.NoPos
		case *ast.MapType:
			node.Map = token.NoPos
		case *ast.Ellipsis:
			node.Ellipsis = token.NoPos
		}
		return true
	})
}
// CreateStmt parses statement as a single Go expression (typically a call such
// as `router.Register(r)`) and wraps it in an *ast.ExprStmt. Its positions are
// reset with ClearPosition so it can be spliced into another file's AST.
//
// Only expressions are accepted; declarations or assignments such as
// `x := 1` fail to parse. A parse error terminates the process via log.Fatalf.
func CreateStmt(statement string) *ast.ExprStmt {
	expr, err := parser.ParseExpr(statement)
	if err != nil {
		// Include the offending input: a bare parser error ("1:5: expected
		// operand") does not say which generated statement was wrong.
		log.Fatalf("CreateStmt: parsing %q: %v", statement, err)
	}
	ClearPosition(expr)
	return &ast.ExprStmt{X: expr}
}
// IsBlockStmt reports whether node's dynamic type is *ast.BlockStmt (a braced
// statement list). A nil interface value reports false.
func IsBlockStmt(node ast.Node) bool {
	switch node.(type) {
	case *ast.BlockStmt:
		return true
	default:
		return false
	}
}
// VariableExistsInBlock reports whether varName is declared or assigned
// anywhere within block, including nested blocks and function literals.
//
// It recognizes identifiers on the left-hand side of assignments (both `=`
// and `:=`), names in `var`/`const` declarations, and range-loop key/value
// variables. No scope resolution is performed: a name introduced in a nested
// scope counts as existing. A nil block reports false.
func VariableExistsInBlock(block *ast.BlockStmt, varName string) bool {
	if block == nil {
		return false
	}
	isTarget := func(e ast.Expr) bool {
		ident, ok := e.(*ast.Ident)
		return ok && ident.Name == varName
	}
	exists := false
	ast.Inspect(block, func(n ast.Node) bool {
		if exists {
			return false // already found; prune the rest of the walk
		}
		switch node := n.(type) {
		case *ast.AssignStmt:
			for _, lhs := range node.Lhs {
				if isTarget(lhs) {
					exists = true
				}
			}
		case *ast.ValueSpec:
			// `var x T`, `var x = v`, `const x = v` inside a DeclStmt.
			for _, name := range node.Names {
				if name.Name == varName {
					exists = true
				}
			}
		case *ast.RangeStmt:
			if isTarget(node.Key) || isTarget(node.Value) {
				exists = true
			}
		}
		return !exists
	})
	return exists
}
// RemoveImport deletes the first import spec whose path is importPath
// (unquoted, e.g. "fmt") from file. If that leaves its import declaration
// empty, the declaration itself is removed. The same spec is also dropped
// from file.Imports. It is a no-op when the import is not present.
func RemoveImport(file *ast.File, importPath string) {
	quoted := fmt.Sprintf("\"%s\"", importPath)
	raw := "`" + importPath + "`"
	for i, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		for j, spec := range genDecl.Specs {
			impSpec, ok := spec.(*ast.ImportSpec)
			if !ok || impSpec.Path == nil {
				continue
			}
			if impSpec.Path.Value != quoted && impSpec.Path.Value != raw {
				continue
			}
			genDecl.Specs = append(genDecl.Specs[:j], genDecl.Specs[j+1:]...)
			if len(genDecl.Specs) == 0 {
				file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
			}
			// file.Imports is not derived from Decls automatically; remove the
			// same spec there so consumers of file.Imports see the change.
			for k, imp := range file.Imports {
				if imp == impSpec {
					file.Imports = append(file.Imports[:k], file.Imports[k+1:]...)
					break
				}
			}
			return
		}
	}
}
// AddStructField appends a field `fieldName fieldType` to the struct type
// named structName in file. The file is left unchanged if that struct already
// has a field named fieldName, or if no such struct type exists.
//
// fieldType is stored verbatim as the Name of an *ast.Ident. It is printed
// exactly as given (e.g. "*model.User") but is not parsed into a structured
// type expression.
func AddStructField(file *ast.File, structName, fieldName, fieldType string) {
	ast.Inspect(file, func(node ast.Node) bool {
		typeSpec, isTypeSpec := node.(*ast.TypeSpec)
		if !isTypeSpec || typeSpec.Name.Name != structName {
			return true
		}
		structType, isStruct := typeSpec.Type.(*ast.StructType)
		if !isStruct {
			return true
		}

		alreadyPresent := false
	scan:
		for _, field := range structType.Fields.List {
			for _, ident := range field.Names {
				if ident.Name == fieldName {
					alreadyPresent = true
					break scan
				}
			}
		}

		if !alreadyPresent {
			structType.Fields.List = append(structType.Fields.List, &ast.Field{
				Names: []*ast.Ident{{Name: fieldName}},
				Type:  &ast.Ident{Name: fieldType},
			})
		}
		return false
	})
}
// RemoveStructField removes the field named fieldName from the struct type
// named structName in file.
//
// When the name belongs to a grouped declaration such as `A, B int`, only that
// name is removed and the remaining names keep their shared type. Otherwise
// the whole *ast.Field is dropped from the list. Embedded fields have no names
// and are never matched. It is a no-op when the struct or field is not found.
func RemoveStructField(file *ast.File, structName, fieldName string) {
	ast.Inspect(file, func(node ast.Node) bool {
		typeSpec, ok := node.(*ast.TypeSpec)
		if !ok || typeSpec.Name.Name != structName {
			return true
		}
		structType, ok := typeSpec.Type.(*ast.StructType)
		if !ok {
			return true
		}
		fields := structType.Fields.List
		for i, field := range fields {
			for j, name := range field.Names {
				if name.Name != fieldName {
					continue
				}
				if len(field.Names) > 1 {
					// Keep sibling names that share this field's type.
					field.Names = append(field.Names[:j], field.Names[j+1:]...)
				} else {
					structType.Fields.List = append(fields[:i], fields[i+1:]...)
				}
				return false
			}
		}
		return true
	})
}
// AddFuncCall appends the expression statement built from callExpr (for
// example `router.Register(r)`) to the end of the body of the function named
// funcName in file. Nothing happens when that function is missing or has no
// body. The call is appended even if an identical call already exists.
// Parsing is delegated to CreateStmt.
func AddFuncCall(file *ast.File, funcName, callExpr string) {
	target := FindFunction(file, funcName)
	if target != nil && target.Body != nil {
		target.Body.List = append(target.Body.List, CreateStmt(callExpr))
	}
}
// RemoveFuncCall deletes the first expression statement in the body of the
// function funcName that calls targetFunc. The call may be qualified
// (`x.targetFunc(...)`) or plain (`targetFunc(...)`).
//
// Only statements directly in the function body are examined; calls nested in
// other statements or inner blocks are left alone. It is a no-op if the
// function or a matching call is not found.
func RemoveFuncCall(file *ast.File, funcName, targetFunc string) {
	funcDecl := FindFunction(file, funcName)
	if funcDecl == nil || funcDecl.Body == nil {
		return
	}
	for i, stmt := range funcDecl.Body.List {
		exprStmt, ok := stmt.(*ast.ExprStmt)
		if !ok {
			continue
		}
		callExpr, ok := exprStmt.X.(*ast.CallExpr)
		if !ok {
			continue
		}
		matched := false
		switch fn := callExpr.Fun.(type) {
		case *ast.SelectorExpr:
			matched = fn.Sel.Name == targetFunc
		case *ast.Ident:
			// Plain calls such as `Setup()` have no selector to inspect.
			matched = fn.Name == targetFunc
		}
		if matched {
			funcDecl.Body.List = append(funcDecl.Body.List[:i], funcDecl.Body.List[i+1:]...)
			return
		}
	}
}
// AddAutoMigrateModel appends modelExpr (for example `&model.User{}`) to the
// arguments of every `<x>.AutoMigrate(...)` call in file that does not already
// contain it.
//
// An existing argument `&pkg.Type{...}` counts as already present when
// `&pkg.Type{}` equals modelExpr; other argument shapes are ignored by the
// check. If modelExpr does not parse as a Go expression, nothing is added
// (best effort: the function has no error return).
func AddAutoMigrateModel(file *ast.File, modelExpr string) {
	ast.Inspect(file, func(node ast.Node) bool {
		callExpr, ok := node.(*ast.CallExpr)
		if !ok {
			return true
		}
		selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
		if !ok || selExpr.Sel.Name != "AutoMigrate" {
			return true
		}
		for _, arg := range callExpr.Args {
			unaryExpr, ok := arg.(*ast.UnaryExpr)
			if !ok {
				continue
			}
			compLit, ok := unaryExpr.X.(*ast.CompositeLit)
			if !ok {
				continue
			}
			selType, ok := compLit.Type.(*ast.SelectorExpr)
			if !ok {
				continue
			}
			// Checked assertion: hand-built ASTs may use a non-identifier
			// qualifier here, which previously caused a panic.
			pkgIdent, ok := selType.X.(*ast.Ident)
			if !ok {
				continue
			}
			if fmt.Sprintf("&%s.%s{}", pkgIdent.Name, selType.Sel.Name) == modelExpr {
				return false // already registered in this call
			}
		}
		// Parse per call so each AutoMigrate call gets its own node rather
		// than sharing one expression between several parents.
		expr, err := parser.ParseExpr(modelExpr)
		if err != nil {
			return false
		}
		ClearPosition(expr)
		callExpr.Args = append(callExpr.Args, expr)
		return false
	})
}
// RemoveAutoMigrateModel removes from every `<x>.AutoMigrate(...)` call in
// file the first argument of the form `&pkg.Type{...}` whose "pkg.Type"
// equals modelName (for example "model.User"). Arguments of any other shape
// are kept untouched.
func RemoveAutoMigrateModel(file *ast.File, modelName string) {
	ast.Inspect(file, func(node ast.Node) bool {
		callExpr, ok := node.(*ast.CallExpr)
		if !ok {
			return true
		}
		selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
		if !ok || selExpr.Sel.Name != "AutoMigrate" {
			return true
		}
		for i, arg := range callExpr.Args {
			unaryExpr, ok := arg.(*ast.UnaryExpr)
			if !ok {
				continue
			}
			compLit, ok := unaryExpr.X.(*ast.CompositeLit)
			if !ok {
				continue
			}
			selType, ok := compLit.Type.(*ast.SelectorExpr)
			if !ok {
				continue
			}
			// Checked assertion: a non-identifier qualifier (possible in
			// hand-built ASTs) is skipped instead of panicking.
			pkgIdent, ok := selType.X.(*ast.Ident)
			if !ok {
				continue
			}
			if fmt.Sprintf("%s.%s", pkgIdent.Name, selType.Sel.Name) == modelName {
				callExpr.Args = append(callExpr.Args[:i], callExpr.Args[i+1:]...)
				return false
			}
		}
		return true
	})
}
// AddGormGenModel appends `g.GenerateModel("<tableName>", fieldOpts...)` to
// the arguments of every `g.ApplyBasic(...)` call in file that does not
// already register tableName. It is used to keep cmd/gen/main.go up to date.
//
// An existing registration is recognized only when the first argument of its
// GenerateModel call is a string literal. Other forms (constants, variables,
// non-string literals) are skipped rather than misread.
func AddGormGenModel(file *ast.File, tableName string) {
	ast.Inspect(file, func(node ast.Node) bool {
		callExpr, ok := node.(*ast.CallExpr)
		if !ok {
			return true
		}
		selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
		if !ok {
			return true
		}
		recv, ok := selExpr.X.(*ast.Ident)
		if !ok || recv.Name != "g" || selExpr.Sel.Name != "ApplyBasic" {
			return true
		}
		for _, arg := range callExpr.Args {
			call, ok := arg.(*ast.CallExpr)
			if !ok || len(call.Args) == 0 {
				continue
			}
			sel, ok := call.Fun.(*ast.SelectorExpr)
			if !ok || sel.Sel.Name != "GenerateModel" {
				continue
			}
			lit, ok := call.Args[0].(*ast.BasicLit)
			// Only string literals carry a table name. Slicing the "quotes" off
			// anything else (e.g. the one-character literal 1) would panic.
			if !ok || lit.Kind != token.STRING || len(lit.Value) < 2 {
				continue
			}
			if lit.Value[1:len(lit.Value)-1] == tableName {
				return false // already registered
			}
		}
		// The spread is written into the identifier name rather than via
		// CallExpr.Ellipsis, which requires a valid source position to take
		// effect. go/printer emits an Ident's Name verbatim, producing
		// `fieldOpts...`.
		callExpr.Args = append(callExpr.Args, &ast.CallExpr{
			Fun: &ast.SelectorExpr{
				X:   &ast.Ident{Name: "g"},
				Sel: &ast.Ident{Name: "GenerateModel"},
			},
			Args: []ast.Expr{
				&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", tableName)},
				&ast.Ident{Name: "fieldOpts..."},
			},
		})
		return false
	})
}
// RemoveGormGenModel removes every `<x>.GenerateModel("<tableName>", ...)`
// argument from each `g.ApplyBasic(...)` call in file. Arguments whose first
// parameter is not a string literal are always kept.
func RemoveGormGenModel(file *ast.File, tableName string) {
	// registersTable reports whether arg is a GenerateModel call whose first
	// argument is the string literal tableName.
	registersTable := func(arg ast.Expr) bool {
		call, ok := arg.(*ast.CallExpr)
		if !ok || len(call.Args) == 0 {
			return false
		}
		sel, ok := call.Fun.(*ast.SelectorExpr)
		if !ok || sel.Sel.Name != "GenerateModel" {
			return false
		}
		lit, ok := call.Args[0].(*ast.BasicLit)
		// Guard the quote-stripping slice: non-string literals such as 1 are
		// too short and would panic.
		if !ok || lit.Kind != token.STRING || len(lit.Value) < 2 {
			return false
		}
		return lit.Value[1:len(lit.Value)-1] == tableName
	}

	ast.Inspect(file, func(node ast.Node) bool {
		callExpr, ok := node.(*ast.CallExpr)
		if !ok {
			return true
		}
		selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
		if !ok {
			return true
		}
		recv, ok := selExpr.X.(*ast.Ident)
		if !ok || recv.Name != "g" || selExpr.Sel.Name != "ApplyBasic" {
			return true
		}
		kept := make([]ast.Expr, 0, len(callExpr.Args))
		for _, arg := range callExpr.Args {
			if !registersTable(arg) {
				kept = append(kept, arg)
			}
		}
		callExpr.Args = kept
		return false
	})
}
// AddGormGenModelToSlice reports whether file contains a loop that registers
// models dynamically, i.e. a loop whose body contains an assignment like
//
//	models = append(models, ...)
//
// as used in cmd/gen/main.go. Both three-clause `for` loops and
// `for ... range` loops are recognized.
//
// NOTE(review): despite its name, this function does not modify file and does
// not use tableName; it only detects the pattern. Presumably callers treat
// true as "tables are registered dynamically, nothing to add". Confirm
// against the call sites.
func AddGormGenModelToSlice(file *ast.File, tableName string) bool {
	found := false
	ast.Inspect(file, func(node ast.Node) bool {
		if found {
			return false
		}
		var body *ast.BlockStmt
		switch loop := node.(type) {
		case *ast.ForStmt:
			body = loop.Body
		case *ast.RangeStmt:
			body = loop.Body
		default:
			return true
		}
		if body == nil {
			return true
		}
		ast.Inspect(body, func(n ast.Node) bool {
			if found {
				return false
			}
			assign, ok := n.(*ast.AssignStmt)
			if !ok {
				return true
			}
			for _, rhs := range assign.Rhs {
				call, ok := rhs.(*ast.CallExpr)
				if !ok || len(call.Args) == 0 {
					continue
				}
				fn, ok := call.Fun.(*ast.Ident)
				if !ok || fn.Name != "append" {
					continue
				}
				if slice, ok := call.Args[0].(*ast.Ident); ok && slice.Name == "models" {
					found = true
					return false
				}
			}
			return true
		})
		return !found
	})
	return found
}
// CheckGormGenModelExists reports whether file contains any
// `<x>.GenerateModel("<tableName>", ...)` call. Only calls whose first
// argument is a string literal are considered.
func CheckGormGenModelExists(file *ast.File, tableName string) bool {
	exists := false
	ast.Inspect(file, func(node ast.Node) bool {
		if exists {
			return false // found; prune the rest of the walk
		}
		callExpr, ok := node.(*ast.CallExpr)
		if !ok || len(callExpr.Args) == 0 {
			return true
		}
		selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
		if !ok || selExpr.Sel.Name != "GenerateModel" {
			return true
		}
		lit, ok := callExpr.Args[0].(*ast.BasicLit)
		// Guard the quote-stripping slice: a non-string literal such as 1 is
		// too short and would panic.
		if !ok || lit.Kind != token.STRING || len(lit.Value) < 2 {
			return true
		}
		if lit.Value[1:len(lit.Value)-1] == tableName {
			exists = true
			return false
		}
		return true
	})
	return exists
}