424 lines
12 KiB
Go
424 lines
12 KiB
Go
package ast
|
|
|
|
import (
|
|
"fmt"
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"log"
|
|
)
|
|
|
|
// AddImport adds an import of path imp to astNode unless it is already present.
//
// Existing imports are detected by comparing each ImportSpec literal with the
// double-quoted form of imp (e.g. `"fmt"`) across every import declaration.
// This means an import that lives in a later declaration is not duplicated
// into an earlier one. The new spec is appended to the first import
// declaration found.
//
// When astNode is an *ast.File with no import declaration yet, a new
// `import "imp"` declaration is inserted at the front of file.Decls. For an
// *ast.File, the new spec is also appended to file.Imports so that list stays
// consistent with the declarations.
func AddImport(astNode ast.Node, imp string) {
	impStr := fmt.Sprintf("\"%s\"", imp)

	// First pass: look at every import declaration before changing anything.
	var firstImportDecl *ast.GenDecl
	exists := false
	ast.Inspect(astNode, func(node ast.Node) bool {
		if exists {
			return false
		}
		genDecl, ok := node.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			return true
		}
		if firstImportDecl == nil {
			firstImportDecl = genDecl
		}
		for _, spec := range genDecl.Specs {
			if impSpec, ok := spec.(*ast.ImportSpec); ok && impSpec.Path != nil && impSpec.Path.Value == impStr {
				exists = true
				break
			}
		}
		// Import specs contain no further declarations; no need to descend.
		return false
	})
	if exists {
		return
	}

	newSpec := &ast.ImportSpec{
		Path: &ast.BasicLit{Kind: token.STRING, Value: impStr},
	}
	file, isFile := astNode.(*ast.File)
	switch {
	case firstImportDecl != nil:
		firstImportDecl.Specs = append(firstImportDecl.Specs, newSpec)
	case isFile:
		// Imports must precede all other declarations, so the new
		// declaration goes first.
		decl := &ast.GenDecl{Tok: token.IMPORT, Specs: []ast.Spec{newSpec}}
		file.Decls = append([]ast.Decl{decl}, file.Decls...)
	default:
		// Not a file and no import declaration: nowhere to put the import.
		return
	}
	if isFile {
		file.Imports = append(file.Imports, newSpec)
	}
}
|
|
|
|
// FindFunction searches astNode for a function or method declaration named
// FunctionName.
//
// The whole tree is walked and every match overwrites the previous one, so
// when several declarations share the name (e.g. multiple init functions, or
// a method and a function), the last one encountered is returned. It returns
// nil when nothing matches.
func FindFunction(astNode ast.Node, FunctionName string) *ast.FuncDecl {
	var match *ast.FuncDecl
	ast.Inspect(astNode, func(n ast.Node) bool {
		decl, isFunc := n.(*ast.FuncDecl)
		if !isFunc || decl.Name.String() != FunctionName {
			return true
		}
		match = decl
		// No need to look inside the body of the match.
		return false
	})
	return match
}
|
|
|
|
// FindArray returns a slice/array composite literal whose element type is
// identName.selectorExprName (e.g. `[]pkg.Type{...}`) and which appears on
// the right-hand side of an assignment in astNode.
//
// The walk continues after a match, so if several assignments qualify the
// last one found wins. Literals with any other element type (e.g. `[]int`,
// `[]*pkg.Type`) are skipped rather than causing a nil-pointer panic. It
// returns nil when nothing matches.
func FindArray(astNode ast.Node, identName, selectorExprName string) *ast.CompositeLit {
	var found *ast.CompositeLit
	ast.Inspect(astNode, func(n ast.Node) bool {
		assign, ok := n.(*ast.AssignStmt)
		if !ok {
			return true
		}
		for _, expr := range assign.Rhs {
			lit, ok := expr.(*ast.CompositeLit)
			if !ok {
				continue
			}
			arrayType, ok := lit.Type.(*ast.ArrayType)
			if !ok {
				continue
			}
			// The element type must be a selector before its X can be
			// inspected; previously sel was dereferenced even when nil.
			sel, ok := arrayType.Elt.(*ast.SelectorExpr)
			if !ok {
				continue
			}
			x, ok := sel.X.(*ast.Ident)
			if ok && x.Name == identName && sel.Sel.Name == selectorExprName {
				found = lit
				return false
			}
		}
		return true
	})
	return found
}
|
|
|
|
// CheckImport reports whether file imports importPath.
//
// It scans the import declarations in file.Decls instead of file.Imports.
// For a freshly parsed file the two agree. However, only Decls reflects specs
// that were later added to or removed from the declarations without
// file.Imports being updated. The surrounding quotes of each path literal
// ("..." or `...`) are stripped before comparing. Malformed literals shorter
// than two characters are ignored instead of causing a slice panic.
func CheckImport(file *ast.File, importPath string) bool {
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		for _, spec := range genDecl.Specs {
			imp, ok := spec.(*ast.ImportSpec)
			if !ok || imp.Path == nil {
				continue
			}
			value := imp.Path.Value
			if len(value) >= 2 && value[1:len(value)-1] == importPath {
				return true
			}
		}
	}
	return false
}
|
|
|
|
// ClearPosition resets source positions in the AST rooted at astNode to
// token.NoPos.
//
// It is intended for nodes produced by a separate parse (for example
// parser.ParseExpr) that are about to be spliced into another file's AST.
// Their positions refer to a different FileSet and would otherwise be
// interpreted as positions in the target file when printing.
//
// CallExpr.Ellipsis is deliberately left alone. go/ast documents it as NoPos
// exactly when the call has no "...", so clearing it would turn f(xs...)
// into f(xs).
func ClearPosition(astNode ast.Node) {
	ast.Inspect(astNode, func(n ast.Node) bool {
		switch node := n.(type) {
		case *ast.Ident:
			node.NamePos = token.NoPos
		case *ast.BasicLit:
			node.ValuePos = token.NoPos
		case *ast.Ellipsis:
			node.Ellipsis = token.NoPos
		case *ast.CompositeLit:
			node.Lbrace = token.NoPos
			node.Rbrace = token.NoPos
		case *ast.ParenExpr:
			node.Lparen = token.NoPos
			node.Rparen = token.NoPos
		case *ast.SelectorExpr:
			node.Sel.NamePos = token.NoPos
		case *ast.IndexExpr:
			node.Lbrack = token.NoPos
			node.Rbrack = token.NoPos
		case *ast.SliceExpr:
			node.Lbrack = token.NoPos
			node.Rbrack = token.NoPos
		case *ast.TypeAssertExpr:
			node.Lparen = token.NoPos
			node.Rparen = token.NoPos
		case *ast.CallExpr:
			node.Lparen = token.NoPos
			node.Rparen = token.NoPos
		case *ast.StarExpr:
			node.Star = token.NoPos
		case *ast.UnaryExpr:
			node.OpPos = token.NoPos
		case *ast.BinaryExpr:
			node.OpPos = token.NoPos
		case *ast.KeyValueExpr:
			node.Colon = token.NoPos
		case *ast.ArrayType:
			node.Lbrack = token.NoPos
		case *ast.MapType:
			node.Map = token.NoPos
		}
		return true
	})
}
|
|
|
|
// CreateStmt 创建语句
|
|
func CreateStmt(statement string) *ast.ExprStmt {
|
|
expr, err := parser.ParseExpr(statement)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
ClearPosition(expr)
|
|
return &ast.ExprStmt{X: expr}
|
|
}
|
|
|
|
// IsBlockStmt reports whether the dynamic type of node is *ast.BlockStmt.
// A nil interface yields false; a typed nil *ast.BlockStmt yields true.
func IsBlockStmt(node ast.Node) bool {
	switch node.(type) {
	case *ast.BlockStmt:
		return true
	default:
		return false
	}
}
|
|
|
|
// VariableExistsInBlock reports whether block binds or assigns the
// identifier varName anywhere inside it, including nested blocks and
// function literals.
//
// Recognised forms:
//   - assignments and short variable declarations (x = ..., x := ...);
//   - var/const declarations (var x = ..., var x T, const x = ...);
//   - range loops that declare the name (for x := range ..., for _, x := range ...).
//
// A nil block contains nothing and yields false instead of panicking inside
// ast.Inspect.
func VariableExistsInBlock(block *ast.BlockStmt, varName string) bool {
	if block == nil {
		return false
	}
	isTarget := func(expr ast.Expr) bool {
		ident, ok := expr.(*ast.Ident)
		return ok && ident.Name == varName
	}

	exists := false
	ast.Inspect(block, func(n ast.Node) bool {
		if exists {
			return false
		}
		switch node := n.(type) {
		case *ast.AssignStmt:
			for _, lhs := range node.Lhs {
				if isTarget(lhs) {
					exists = true
					break
				}
			}
		case *ast.ValueSpec:
			for _, name := range node.Names {
				if name.Name == varName {
					exists = true
					break
				}
			}
		case *ast.RangeStmt:
			if node.Tok == token.DEFINE && (isTarget(node.Key) || isTarget(node.Value)) {
				exists = true
			}
		}
		return !exists
	})
	return exists
}
|
|
|
|
// RemoveImport removes the import of importPath from file.
//
// The path is matched against the double-quoted literal form (`"path"`).
// Only the first matching spec is removed. If that leaves its import
// declaration empty, the declaration itself is dropped from file.Decls. The
// removed spec is also deleted from file.Imports so that list does not keep
// reporting an import that no longer exists.
func RemoveImport(file *ast.File, importPath string) {
	impStr := fmt.Sprintf("\"%s\"", importPath)
	for i, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		for j, spec := range genDecl.Specs {
			impSpec, ok := spec.(*ast.ImportSpec)
			if !ok || impSpec.Path == nil || impSpec.Path.Value != impStr {
				continue
			}
			genDecl.Specs = append(genDecl.Specs[:j], genDecl.Specs[j+1:]...)
			if len(genDecl.Specs) == 0 {
				file.Decls = append(file.Decls[:i], file.Decls[i+1:]...)
			}
			// Keep file.Imports in sync with the declarations.
			for k, imp := range file.Imports {
				if imp == impSpec {
					file.Imports = append(file.Imports[:k], file.Imports[k+1:]...)
					break
				}
			}
			// The slices ranged over above were modified; stop immediately.
			return
		}
	}
}
|
|
|
|
// AddStructField appends a field `fieldName fieldType` to every struct type
// declaration named structName in file, unless a field with that name is
// already declared there.
//
// fieldType is stored verbatim as the name of an *ast.Ident. It is not
// parsed, so it is printed exactly as given (e.g. "*pkg.T" or "[]string").
// Type declarations with a matching name that are not structs are left
// unchanged.
func AddStructField(file *ast.File, structName, fieldName, fieldType string) {
	ast.Inspect(file, func(n ast.Node) bool {
		spec, isTypeSpec := n.(*ast.TypeSpec)
		if !isTypeSpec || spec.Name.Name != structName {
			return true
		}
		st, isStruct := spec.Type.(*ast.StructType)
		if !isStruct {
			return true
		}

		for _, existing := range st.Fields.List {
			for _, ident := range existing.Names {
				if ident.Name == fieldName {
					// Already declared: leave the struct untouched.
					return false
				}
			}
		}

		st.Fields.List = append(st.Fields.List, &ast.Field{
			Names: []*ast.Ident{{Name: fieldName}},
			Type:  &ast.Ident{Name: fieldType},
		})
		return false
	})
}
|
|
|
|
// RemoveStructField removes the field named fieldName from struct type
// declarations named structName in file.
//
// When the name is one of several in a single field declaration (for example
// A in `A, B int`), only that name is dropped and the siblings keep their
// type. A declaration whose only name matches is removed entirely. Only the
// first occurrence inside each matching struct is removed. Structs without a
// field list, and non-struct types with a matching name, are left unchanged.
func RemoveStructField(file *ast.File, structName, fieldName string) {
	ast.Inspect(file, func(node ast.Node) bool {
		typeSpec, ok := node.(*ast.TypeSpec)
		if !ok || typeSpec.Name.Name != structName {
			return true
		}
		structType, ok := typeSpec.Type.(*ast.StructType)
		if !ok || structType.Fields == nil {
			return true
		}
		fields := structType.Fields.List
		for i, field := range fields {
			for j, name := range field.Names {
				if name.Name != fieldName {
					continue
				}
				if len(field.Names) > 1 {
					// Shared declaration: drop just this name.
					field.Names = append(field.Names[:j], field.Names[j+1:]...)
				} else {
					structType.Fields.List = append(fields[:i], fields[i+1:]...)
				}
				return false
			}
		}
		return true
	})
}
|
|
|
|
// AddFuncCall 在函数体中添加函数调用
|
|
func AddFuncCall(file *ast.File, funcName, callExpr string) {
|
|
funcDecl := FindFunction(file, funcName)
|
|
if funcDecl == nil || funcDecl.Body == nil {
|
|
return
|
|
}
|
|
stmt := CreateStmt(callExpr)
|
|
funcDecl.Body.List = append(funcDecl.Body.List, stmt)
|
|
}
|
|
|
|
// RemoveFuncCall 从函数体中移除函数调用
|
|
func RemoveFuncCall(file *ast.File, funcName, targetFunc string) {
|
|
funcDecl := FindFunction(file, funcName)
|
|
if funcDecl == nil || funcDecl.Body == nil {
|
|
return
|
|
}
|
|
for i := 0; i < len(funcDecl.Body.List); i++ {
|
|
if exprStmt, ok := funcDecl.Body.List[i].(*ast.ExprStmt); ok {
|
|
if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok {
|
|
if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
|
|
if selExpr.Sel.Name == targetFunc {
|
|
funcDecl.Body.List = append(funcDecl.Body.List[:i], funcDecl.Body.List[i+1:]...)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// AddAutoMigrateModel 添加 AutoMigrate 模型
|
|
func AddAutoMigrateModel(file *ast.File, modelExpr string) {
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if callExpr, ok := node.(*ast.CallExpr); ok {
|
|
if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
|
|
if selExpr.Sel.Name == "AutoMigrate" {
|
|
// 检查是否已存在
|
|
for _, arg := range callExpr.Args {
|
|
if unaryExpr, ok := arg.(*ast.UnaryExpr); ok {
|
|
if compLit, ok := unaryExpr.X.(*ast.CompositeLit); ok {
|
|
if selType, ok := compLit.Type.(*ast.SelectorExpr); ok {
|
|
existingModel := fmt.Sprintf("&%s.%s{}", selType.X.(*ast.Ident).Name, selType.Sel.Name)
|
|
if existingModel == modelExpr {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// 添加新模型
|
|
expr, err := parser.ParseExpr(modelExpr)
|
|
if err == nil {
|
|
ClearPosition(expr)
|
|
callExpr.Args = append(callExpr.Args, expr)
|
|
}
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
// RemoveAutoMigrateModel removes a model argument from each x.AutoMigrate(...)
// call in file.
//
// modelName must be written as "pkg.Type" (no leading & or braces). It is
// matched against arguments of the form &pkg.Type{...}. Only the first
// matching argument of each AutoMigrate call is removed. Arguments of other
// shapes, such as the multi-level selector &a.b.Type{}, are skipped instead
// of causing a panic.
func RemoveAutoMigrateModel(file *ast.File, modelName string) {
	ast.Inspect(file, func(node ast.Node) bool {
		callExpr, ok := node.(*ast.CallExpr)
		if !ok {
			return true
		}
		selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
		if !ok || selExpr.Sel.Name != "AutoMigrate" {
			return true
		}
		for i, arg := range callExpr.Args {
			unaryExpr, ok := arg.(*ast.UnaryExpr)
			if !ok {
				continue
			}
			compLit, ok := unaryExpr.X.(*ast.CompositeLit)
			if !ok {
				continue
			}
			selType, ok := compLit.Type.(*ast.SelectorExpr)
			if !ok {
				continue
			}
			pkgIdent, ok := selType.X.(*ast.Ident)
			if !ok {
				continue
			}
			if pkgIdent.Name+"."+selType.Sel.Name == modelName {
				callExpr.Args = append(callExpr.Args[:i], callExpr.Args[i+1:]...)
				return false
			}
		}
		return true
	})
}
|
|
|
|
// AddGormGenModel appends g.GenerateModel("<tableName>", fieldOpts...) to each
// g.ApplyBasic(...) call in file. It is used to keep cmd/gen/main.go up to date.
//
// The receiver must be the identifier literally named g. A call is left
// unchanged when one of its arguments is already a GenerateModel call whose
// first argument is a literal naming tableName. That comparison strips the
// first and last characters of the literal (its quotes) and does not decode
// escape sequences.
//
// The variadic spread is emitted as a single identifier named "fieldOpts...".
// NOTE(review): this presumably relies on go/printer writing identifier
// names verbatim.
func AddGormGenModel(file *ast.File, tableName string) {
	// alreadyListed reports whether args contains GenerateModel("<tableName>", ...).
	alreadyListed := func(args []ast.Expr) bool {
		for _, arg := range args {
			inner, ok := arg.(*ast.CallExpr)
			if !ok {
				continue
			}
			sel, ok := inner.Fun.(*ast.SelectorExpr)
			if !ok || sel.Sel.Name != "GenerateModel" || len(inner.Args) == 0 {
				continue
			}
			lit, ok := inner.Args[0].(*ast.BasicLit)
			if ok && lit.Value[1:len(lit.Value)-1] == tableName {
				return true
			}
		}
		return false
	}

	ast.Inspect(file, func(n ast.Node) bool {
		apply, ok := n.(*ast.CallExpr)
		if !ok {
			return true
		}
		sel, ok := apply.Fun.(*ast.SelectorExpr)
		if !ok {
			return true
		}
		recv, ok := sel.X.(*ast.Ident)
		if !ok || recv.Name != "g" || sel.Sel.Name != "ApplyBasic" {
			return true
		}
		if alreadyListed(apply.Args) {
			return false
		}

		generate := &ast.CallExpr{
			Fun: &ast.SelectorExpr{
				X:   &ast.Ident{Name: "g"},
				Sel: &ast.Ident{Name: "GenerateModel"},
			},
			Args: []ast.Expr{
				&ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", tableName)},
				&ast.Ident{Name: "fieldOpts..."},
			},
		}
		apply.Args = append(apply.Args, generate)
		return false
	})
}
|
|
|
|
// RemoveGormGenModel removes, from each g.ApplyBasic(...) call in file, every
// argument of the form x.GenerateModel("<tableName>", ...).
//
// The ApplyBasic receiver must be the identifier literally named g. The table
// name comparison strips the first and last characters of the literal (its
// quotes). Every matching argument is dropped and all others keep their
// order. The filtered arguments are collected in a freshly allocated slice.
func RemoveGormGenModel(file *ast.File, tableName string) {
	// isTarget reports whether arg is a GenerateModel call for tableName.
	isTarget := func(arg ast.Expr) bool {
		inner, ok := arg.(*ast.CallExpr)
		if !ok {
			return false
		}
		sel, ok := inner.Fun.(*ast.SelectorExpr)
		if !ok || sel.Sel.Name != "GenerateModel" || len(inner.Args) == 0 {
			return false
		}
		lit, ok := inner.Args[0].(*ast.BasicLit)
		return ok && lit.Value[1:len(lit.Value)-1] == tableName
	}

	ast.Inspect(file, func(n ast.Node) bool {
		apply, ok := n.(*ast.CallExpr)
		if !ok {
			return true
		}
		sel, ok := apply.Fun.(*ast.SelectorExpr)
		if !ok {
			return true
		}
		recv, ok := sel.X.(*ast.Ident)
		if !ok || recv.Name != "g" || sel.Sel.Name != "ApplyBasic" {
			return true
		}

		kept := make([]ast.Expr, 0, len(apply.Args))
		for _, arg := range apply.Args {
			if !isTarget(arg) {
				kept = append(kept, arg)
			}
		}
		apply.Args = kept
		return false
	})
}
|
|
|
|
// AddGormGenModelToSlice reports whether file contains the
// `models = append(models, ...)` pattern inside a classic for loop, as used in
// cmd/gen/main.go.
//
// Despite its name, it does not modify file. Its only effect is the returned
// flag, and tableName is currently unused. Only *ast.ForStmt bodies are
// searched; range loops (*ast.RangeStmt) are not.
// NOTE(review): the name suggests it was meant to add tableName to the slice
// — confirm the intended behaviour with callers before relying on either.
func AddGormGenModelToSlice(file *ast.File, tableName string) bool {
	// appendsToModels reports whether expr is append(models, ...).
	appendsToModels := func(expr ast.Expr) bool {
		call, ok := expr.(*ast.CallExpr)
		if !ok {
			return false
		}
		fn, ok := call.Fun.(*ast.Ident)
		if !ok || fn.Name != "append" || len(call.Args) < 1 {
			return false
		}
		first, ok := call.Args[0].(*ast.Ident)
		return ok && first.Name == "models"
	}

	found := false
	ast.Inspect(file, func(outer ast.Node) bool {
		loop, isFor := outer.(*ast.ForStmt)
		if !isFor {
			return true
		}
		ast.Inspect(loop.Body, func(inner ast.Node) bool {
			assign, isAssign := inner.(*ast.AssignStmt)
			if !isAssign {
				return true
			}
			for _, rhs := range assign.Rhs {
				if appendsToModels(rhs) {
					found = true
					return false
				}
			}
			return true
		})
		return true
	})
	return found
}
|
|
|
|
// CheckGormGenModelExists reports whether file contains any call
// x.GenerateModel(<literal>, ...) whose first argument names tableName.
//
// The whole file is searched, regardless of where the call appears (inside
// g.ApplyBasic, an append, etc.). The literal comparison strips its first and
// last characters (the quotes) and does not decode escape sequences.
func CheckGormGenModelExists(file *ast.File, tableName string) bool {
	found := false
	ast.Inspect(file, func(n ast.Node) bool {
		call, isCall := n.(*ast.CallExpr)
		if !isCall {
			return true
		}
		sel, isSel := call.Fun.(*ast.SelectorExpr)
		if !isSel || sel.Sel.Name != "GenerateModel" || len(call.Args) == 0 {
			return true
		}
		lit, isLit := call.Args[0].(*ast.BasicLit)
		if !isLit {
			return true
		}
		if name := lit.Value[1 : len(lit.Value)-1]; name != tableName {
			return true
		}
		found = true
		return false
	})
	return found
}
|