package ast import ( "go/ast" "go/parser" "go/token" "strings" "testing" ) // TestAddImport 测试添加导入 func TestAddImport(t *testing.T) { testCode := `package main import "fmt" func main() {} ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 添加新导入 AddImport(file, "kra/internal/biz/article") // 验证导入是否添加 var found bool ast.Inspect(file, func(node ast.Node) bool { if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT { for _, spec := range genDecl.Specs { if impSpec, ok := spec.(*ast.ImportSpec); ok { if impSpec.Path.Value == "\"kra/internal/biz/article\"" { found = true return false } } } } return true }) if !found { t.Error("应该添加 kra/internal/biz/article 导入") } } // TestAddImport_Duplicate 测试重复添加导入 func TestAddImport_Duplicate(t *testing.T) { testCode := `package main import ( "fmt" "kra/internal/biz/article" ) func main() {} ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 统计初始导入数量 initialCount := 0 ast.Inspect(file, func(node ast.Node) bool { if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT { initialCount = len(genDecl.Specs) return false } return true }) // 尝试添加已存在的导入 AddImport(file, "kra/internal/biz/article") // 验证导入数量没有增加 finalCount := 0 ast.Inspect(file, func(node ast.Node) bool { if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT { finalCount = len(genDecl.Specs) return false } return true }) if finalCount != initialCount { t.Errorf("重复添加导入后数量应保持 %d, 实际为 %d", initialCount, finalCount) } } // TestRemoveImport 测试移除导入 func TestRemoveImport(t *testing.T) { testCode := `package main import ( "fmt" "kra/internal/biz/article" ) func main() {} ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 移除导入 RemoveImport(file, "kra/internal/biz/article") // 验证导入是否移除 var found bool ast.Inspect(file, func(node ast.Node) bool { if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT { for _, spec := range genDecl.Specs { if impSpec, ok := spec.(*ast.ImportSpec); ok { if impSpec.Path.Value == "\"kra/internal/biz/article\"" { found = true return false } } } } return true }) if found { t.Error("应该移除 kra/internal/biz/article 导入") } } // TestAddStructField 测试添加结构体字段 func TestAddStructField(t *testing.T) { testCode := `package main type MyStruct struct { Name string } ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 添加新字段 AddStructField(file, "MyStruct", "Age", "int") // 验证字段是否添加 var found bool ast.Inspect(file, func(node ast.Node) bool { if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" { if structType, ok := typeSpec.Type.(*ast.StructType); ok { for _, field := range structType.Fields.List { for _, name := range field.Names { if name.Name == "Age" { found = true return false } } } } } return true }) if !found { t.Error("应该添加 Age 字段") } } // TestAddStructField_Duplicate 测试重复添加结构体字段 func TestAddStructField_Duplicate(t *testing.T) { testCode := `package main type MyStruct struct { Name string Age int } ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 统计初始字段数量 initialCount := 0 ast.Inspect(file, func(node ast.Node) bool { if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" { if structType, ok := typeSpec.Type.(*ast.StructType); ok { initialCount = len(structType.Fields.List) return false } } return true }) // 尝试添加已存在的字段 AddStructField(file, "MyStruct", "Age", "int") // 验证字段数量没有增加 finalCount := 0 ast.Inspect(file, func(node ast.Node) bool { if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" { if structType, ok := typeSpec.Type.(*ast.StructType); ok { finalCount = len(structType.Fields.List) return false } } return true }) if finalCount != initialCount { t.Errorf("重复添加字段后数量应保持 %d, 实际为 %d", initialCount, finalCount) } } // TestRemoveStructField 测试移除结构体字段 func TestRemoveStructField(t *testing.T) { testCode := `package main type MyStruct struct { Name string Age int } ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 移除字段 RemoveStructField(file, "MyStruct", "Age") // 验证字段是否移除 var found bool ast.Inspect(file, func(node ast.Node) bool { if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" { if structType, ok := typeSpec.Type.(*ast.StructType); ok { for _, field := range structType.Fields.List { for _, name := range field.Names { if name.Name == "Age" { found = true return false } } } } } return true }) if found { t.Error("应该移除 Age 字段") } } // TestFindFunction 测试查找函数 func TestFindFunction(t *testing.T) { testCode := `package main func Hello() {} func World() {} ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 查找存在的函数 funcDecl := FindFunction(file, "Hello") if funcDecl == nil { t.Error("应该找到 Hello 函数") } // 查找不存在的函数 funcDecl = FindFunction(file, "NotExist") if funcDecl != nil { t.Error("不应该找到 NotExist 函数") } } // TestCheckImport 测试检查导入是否存在 func TestCheckImport(t *testing.T) { testCode := `package main import ( "fmt" "kra/internal/biz/article" ) func main() {} ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 检查存在的导入 if !CheckImport(file, "fmt") { t.Error("应该找到 fmt 导入") } // 检查存在的导入(完整路径) if !CheckImport(file, "kra/internal/biz/article") { t.Error("应该找到 kra/internal/biz/article 导入") } // 检查不存在的导入 if CheckImport(file, "not/exist") { t.Error("不应该找到 not/exist 导入") } } // TestAddFuncCall 测试在函数体中添加函数调用 func TestAddFuncCall(t *testing.T) { testCode := `package main func initRouter() { fmt.Println("init") } ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 添加函数调用 AddFuncCall(file, "initRouter", "router.Article.InitArticleRouter(group)") // 验证函数调用是否添加 funcDecl := FindFunction(file, "initRouter") if funcDecl == nil || funcDecl.Body == nil { t.Fatal("找不到 initRouter 函数") } // 检查函数体中是否包含新的调用 if len(funcDecl.Body.List) < 2 { t.Error("应该添加新的函数调用") } } // TestRemoveFuncCall 测试从函数体中移除函数调用 func TestRemoveFuncCall(t *testing.T) { testCode := `package main func initRouter() { fmt.Println("init") router.Article.InitArticleRouter(group) } ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 移除函数调用 RemoveFuncCall(file, "initRouter", "InitArticleRouter") // 验证函数调用是否移除 funcDecl := FindFunction(file, "initRouter") if funcDecl == nil || funcDecl.Body == nil { t.Fatal("找不到 initRouter 函数") } // 检查函数体中是否还包含该调用 for _, stmt := range funcDecl.Body.List { if exprStmt, ok := stmt.(*ast.ExprStmt); ok { if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok { if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { if selExpr.Sel.Name == "InitArticleRouter" { t.Error("应该移除 InitArticleRouter 调用") } } } } } } // TestCreateStmt 测试创建语句 func TestCreateStmt(t *testing.T) { stmt := CreateStmt("fmt.Println(\"hello\")") if stmt == nil { t.Error("应该创建语句") } if stmt.X == nil { t.Error("语句表达式不应为空") } } // TestVariableExistsInBlock 测试检查变量是否存在于块中 func TestVariableExistsInBlock(t *testing.T) { testCode := `package main func test() { x := 1 y := 2 } ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } funcDecl := FindFunction(file, "test") if funcDecl == nil || funcDecl.Body == nil { t.Fatal("找不到 test 函数") } // 检查存在的变量 if !VariableExistsInBlock(funcDecl.Body, "x") { t.Error("应该找到变量 x") } // 检查不存在的变量 if VariableExistsInBlock(funcDecl.Body, "z") { t.Error("不应该找到变量 z") } } // TestBase_RelativePath 测试相对路径转换 func TestBase_RelativePath(t *testing.T) { base := &Base{ AutoCodeRoot: "/home/user/project", AutoCodeServer: "server", } testCases := []struct { input string expected string }{ {"/home/user/project/server/internal/biz/article.go", "internal/biz/article.go"}, {"/other/path/file.go", "/other/path/file.go"}, } for _, tc := range testCases { result := base.RelativePath(tc.input) // 由于路径分隔符可能不同,只检查是否包含关键部分 if !strings.Contains(result, "article.go") && tc.input != "/other/path/file.go" { t.Errorf("RelativePath(%s) 结果不正确: %s", tc.input, result) } } } // TestAddAutoMigrateModel 测试添加 AutoMigrate 模型 func TestAddAutoMigrateModel(t *testing.T) { testCode := `package main import "gorm.io/gorm" func migrate(db *gorm.DB) { db.AutoMigrate(&User{}) } ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 添加新模型 AddAutoMigrateModel(file, "&article.Article{}") // 验证模型是否添加 var found bool ast.Inspect(file, func(node ast.Node) bool { if callExpr, ok := node.(*ast.CallExpr); ok { if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { if selExpr.Sel.Name == "AutoMigrate" { // 检查参数数量是否增加 if len(callExpr.Args) > 1 { found = true return false } } } } return true }) if !found { t.Error("应该添加新的 AutoMigrate 模型") } } // TestCheckGormGenModelExists 测试检查 GORM Gen 模型是否存在 func TestCheckGormGenModelExists(t *testing.T) { testCode := `package main func main() { g.ApplyBasic( g.GenerateModel("users"), g.GenerateModel("articles"), ) } ` fset := token.NewFileSet() file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments) if err != nil { t.Fatalf("解析测试代码失败: %v", err) } // 检查存在的模型 if !CheckGormGenModelExists(file, "users") { t.Error("应该找到 users 模型") } // 检查不存在的模型 if CheckGormGenModelExists(file, "orders") { t.Error("不应该找到 orders 模型") } }