package ast import ( "fmt" "go/ast" "go/parser" "go/token" "log" ) // AddImport 增加 import 方法 func AddImport(astNode ast.Node, imp string) { impStr := fmt.Sprintf("\"%s\"", imp) ast.Inspect(astNode, func(node ast.Node) bool { if genDecl, ok := node.(*ast.GenDecl); ok { if genDecl.Tok == token.IMPORT { for i := range genDecl.Specs { if impNode, ok := genDecl.Specs[i].(*ast.ImportSpec); ok { if impNode.Path.Value == impStr { return false } } } genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{ Path: &ast.BasicLit{ Kind: token.STRING, Value: impStr, }, }) } } return true }) } // FindFunction 查询特定function方法 func FindFunction(astNode ast.Node, FunctionName string) *ast.FuncDecl { var funcDeclP *ast.FuncDecl ast.Inspect(astNode, func(node ast.Node) bool { if funcDecl, ok := node.(*ast.FuncDecl); ok { if funcDecl.Name.String() == FunctionName { funcDeclP = funcDecl return false } } return true }) return funcDeclP } // FindArray 查询特定数组方法 func FindArray(astNode ast.Node, identName, selectorExprName string) *ast.CompositeLit { var assignStmt *ast.CompositeLit ast.Inspect(astNode, func(n ast.Node) bool { switch node := n.(type) { case *ast.AssignStmt: for _, expr := range node.Rhs { if exprType, ok := expr.(*ast.CompositeLit); ok { if arrayType, ok := exprType.Type.(*ast.ArrayType); ok { sel, ok1 := arrayType.Elt.(*ast.SelectorExpr) x, ok2 := sel.X.(*ast.Ident) if ok1 && ok2 && x.Name == identName && sel.Sel.Name == selectorExprName { assignStmt = exprType return false } } } } } return true }) return assignStmt } // CheckImport 检查是否存在Import func CheckImport(file *ast.File, importPath string) bool { for _, imp := range file.Imports { path := imp.Path.Value[1 : len(imp.Path.Value)-1] if path == importPath { return true } } return false } // ClearPosition 清除AST节点位置信息 func ClearPosition(astNode ast.Node) { ast.Inspect(astNode, func(n ast.Node) bool { switch node := n.(type) { case *ast.Ident: node.NamePos = token.NoPos case *ast.CallExpr: node.Lparen = token.NoPos node.Rparen = token.NoPos case *ast.BasicLit: node.ValuePos = token.NoPos case *ast.SelectorExpr: node.Sel.NamePos = token.NoPos case *ast.BinaryExpr: node.OpPos = token.NoPos case *ast.UnaryExpr: node.OpPos = token.NoPos case *ast.StarExpr: node.Star = token.NoPos } return true }) } // CreateStmt 创建语句 func CreateStmt(statement string) *ast.ExprStmt { expr, err := parser.ParseExpr(statement) if err != nil { log.Fatal(err) } ClearPosition(expr) return &ast.ExprStmt{X: expr} } // IsBlockStmt 判断是否为块语句 func IsBlockStmt(node ast.Node) bool { _, ok := node.(*ast.BlockStmt) return ok } // VariableExistsInBlock 检查变量是否存在于块中 func VariableExistsInBlock(block *ast.BlockStmt, varName string) bool { exists := false ast.Inspect(block, func(n ast.Node) bool { switch node := n.(type) { case *ast.AssignStmt: for _, expr := range node.Lhs { if ident, ok := expr.(*ast.Ident); ok && ident.Name == varName { exists = true return false } } } return true }) return exists } // RemoveImport 移除 import func RemoveImport(file *ast.File, importPath string) { impStr := fmt.Sprintf("\"%s\"", importPath) for i := 0; i < len(file.Decls); i++ { if genDecl, ok := file.Decls[i].(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT { for j := 0; j < len(genDecl.Specs); j++ { if impSpec, ok := genDecl.Specs[j].(*ast.ImportSpec); ok { if impSpec.Path.Value == impStr { genDecl.Specs = append(genDecl.Specs[:j], genDecl.Specs[j+1:]...) if len(genDecl.Specs) == 0 { file.Decls = append(file.Decls[:i], file.Decls[i+1:]...) } return } } } } } } // AddStructField 添加结构体字段 func AddStructField(file *ast.File, structName, fieldName, fieldType string) { ast.Inspect(file, func(node ast.Node) bool { if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == structName { if structType, ok := typeSpec.Type.(*ast.StructType); ok { // 检查字段是否已存在 for _, field := range structType.Fields.List { for _, name := range field.Names { if name.Name == fieldName { return false } } } // 添加新字段 newField := &ast.Field{ Names: []*ast.Ident{{Name: fieldName}}, Type: &ast.Ident{Name: fieldType}, } structType.Fields.List = append(structType.Fields.List, newField) return false } } return true }) } // RemoveStructField 移除结构体字段 func RemoveStructField(file *ast.File, structName, fieldName string) { ast.Inspect(file, func(node ast.Node) bool { if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == structName { if structType, ok := typeSpec.Type.(*ast.StructType); ok { for i, field := range structType.Fields.List { for _, name := range field.Names { if name.Name == fieldName { structType.Fields.List = append(structType.Fields.List[:i], structType.Fields.List[i+1:]...) return false } } } } } return true }) } // AddFuncCall 在函数体中添加函数调用 func AddFuncCall(file *ast.File, funcName, callExpr string) { funcDecl := FindFunction(file, funcName) if funcDecl == nil || funcDecl.Body == nil { return } stmt := CreateStmt(callExpr) funcDecl.Body.List = append(funcDecl.Body.List, stmt) } // RemoveFuncCall 从函数体中移除函数调用 func RemoveFuncCall(file *ast.File, funcName, targetFunc string) { funcDecl := FindFunction(file, funcName) if funcDecl == nil || funcDecl.Body == nil { return } for i := 0; i < len(funcDecl.Body.List); i++ { if exprStmt, ok := funcDecl.Body.List[i].(*ast.ExprStmt); ok { if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok { if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { if selExpr.Sel.Name == targetFunc { funcDecl.Body.List = append(funcDecl.Body.List[:i], funcDecl.Body.List[i+1:]...) return } } } } } } // AddAutoMigrateModel 添加 AutoMigrate 模型 func AddAutoMigrateModel(file *ast.File, modelExpr string) { ast.Inspect(file, func(node ast.Node) bool { if callExpr, ok := node.(*ast.CallExpr); ok { if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { if selExpr.Sel.Name == "AutoMigrate" { // 检查是否已存在 for _, arg := range callExpr.Args { if unaryExpr, ok := arg.(*ast.UnaryExpr); ok { if compLit, ok := unaryExpr.X.(*ast.CompositeLit); ok { if selType, ok := compLit.Type.(*ast.SelectorExpr); ok { existingModel := fmt.Sprintf("&%s.%s{}", selType.X.(*ast.Ident).Name, selType.Sel.Name) if existingModel == modelExpr { return false } } } } } // 添加新模型 expr, err := parser.ParseExpr(modelExpr) if err == nil { ClearPosition(expr) callExpr.Args = append(callExpr.Args, expr) } return false } } } return true }) } // RemoveAutoMigrateModel 移除 AutoMigrate 模型 func RemoveAutoMigrateModel(file *ast.File, modelName string) { ast.Inspect(file, func(node ast.Node) bool { if callExpr, ok := node.(*ast.CallExpr); ok { if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { if selExpr.Sel.Name == "AutoMigrate" { for i := 0; i < len(callExpr.Args); i++ { if unaryExpr, ok := callExpr.Args[i].(*ast.UnaryExpr); ok { if compLit, ok := unaryExpr.X.(*ast.CompositeLit); ok { if selType, ok := compLit.Type.(*ast.SelectorExpr); ok { existingModel := fmt.Sprintf("%s.%s", selType.X.(*ast.Ident).Name, selType.Sel.Name) if existingModel == modelName { callExpr.Args = append(callExpr.Args[:i], callExpr.Args[i+1:]...) return false } } } } } } } } return true }) } // AddGormGenModel 向 GORM Gen 的 g.ApplyBasic 调用中添加新模型 // 用于自动更新 cmd/gen/main.go 文件 func AddGormGenModel(file *ast.File, tableName string) { ast.Inspect(file, func(node ast.Node) bool { if callExpr, ok := node.(*ast.CallExpr); ok { if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { // 查找 g.ApplyBasic 调用 if ident, ok := selExpr.X.(*ast.Ident); ok && ident.Name == "g" && selExpr.Sel.Name == "ApplyBasic" { // 检查是否已存在该表 for _, arg := range callExpr.Args { if call, ok := arg.(*ast.CallExpr); ok { if sel, ok := call.Fun.(*ast.SelectorExpr); ok { if sel.Sel.Name == "GenerateModel" && len(call.Args) > 0 { if lit, ok := call.Args[0].(*ast.BasicLit); ok { existingTable := lit.Value[1 : len(lit.Value)-1] // 去掉引号 if existingTable == tableName { return false // 已存在,不需要添加 } } } } } } // 添加新模型: g.GenerateModel("tableName", fieldOpts...) newArg := &ast.CallExpr{ Fun: &ast.SelectorExpr{ X: &ast.Ident{Name: "g"}, Sel: &ast.Ident{Name: "GenerateModel"}, }, Args: []ast.Expr{ &ast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("\"%s\"", tableName)}, &ast.Ident{Name: "fieldOpts..."}, }, } callExpr.Args = append(callExpr.Args, newArg) return false } } } return true }) } // RemoveGormGenModel 从 GORM Gen 的 g.ApplyBasic 调用中移除模型 func RemoveGormGenModel(file *ast.File, tableName string) { ast.Inspect(file, func(node ast.Node) bool { if callExpr, ok := node.(*ast.CallExpr); ok { if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { if ident, ok := selExpr.X.(*ast.Ident); ok && ident.Name == "g" && selExpr.Sel.Name == "ApplyBasic" { newArgs := make([]ast.Expr, 0, len(callExpr.Args)) for _, arg := range callExpr.Args { if call, ok := arg.(*ast.CallExpr); ok { if sel, ok := call.Fun.(*ast.SelectorExpr); ok { if sel.Sel.Name == "GenerateModel" && len(call.Args) > 0 { if lit, ok := call.Args[0].(*ast.BasicLit); ok { existingTable := lit.Value[1 : len(lit.Value)-1] if existingTable == tableName { continue // 跳过要移除的模型 } } } } } newArgs = append(newArgs, arg) } callExpr.Args = newArgs return false } } } return true }) } // AddGormGenModelToSlice 向 models 切片中添加新模型 // 用于处理 cmd/gen/main.go 中的 models = append(models, g.GenerateModel(...)) 模式 func AddGormGenModelToSlice(file *ast.File, tableName string) bool { var found bool ast.Inspect(file, func(node ast.Node) bool { // 查找 for 循环中的 models = append(models, ...) 语句 if forStmt, ok := node.(*ast.ForStmt); ok { ast.Inspect(forStmt.Body, func(n ast.Node) bool { if assignStmt, ok := n.(*ast.AssignStmt); ok { for _, rhs := range assignStmt.Rhs { if callExpr, ok := rhs.(*ast.CallExpr); ok { if ident, ok := callExpr.Fun.(*ast.Ident); ok && ident.Name == "append" { // 找到 append 调用,检查是否是 models 切片 if len(callExpr.Args) >= 1 { if argIdent, ok := callExpr.Args[0].(*ast.Ident); ok && argIdent.Name == "models" { found = true return false } } } } } } return true }) } return true }) return found } // CheckGormGenModelExists 检查 GORM Gen 模型是否已存在 func CheckGormGenModelExists(file *ast.File, tableName string) bool { var exists bool ast.Inspect(file, func(node ast.Node) bool { if callExpr, ok := node.(*ast.CallExpr); ok { if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { if selExpr.Sel.Name == "GenerateModel" && len(callExpr.Args) > 0 { if lit, ok := callExpr.Args[0].(*ast.BasicLit); ok { existingTable := lit.Value[1 : len(lit.Value)-1] if existingTable == tableName { exists = true return false } } } } } return true }) return exists }