kra/pkg/utils/ast/ast_test.go

519 lines
12 KiB
Go

package ast
import (
"go/ast"
"go/parser"
"go/token"
"strings"
"testing"
)
// TestAddImport 测试添加导入
func TestAddImport(t *testing.T) {
testCode := `package main
import "fmt"
func main() {}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 添加新导入
AddImport(file, "kra/internal/biz/article")
// 验证导入是否添加
var found bool
ast.Inspect(file, func(node ast.Node) bool {
if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
for _, spec := range genDecl.Specs {
if impSpec, ok := spec.(*ast.ImportSpec); ok {
if impSpec.Path.Value == "\"kra/internal/biz/article\"" {
found = true
return false
}
}
}
}
return true
})
if !found {
t.Error("应该添加 kra/internal/biz/article 导入")
}
}
// TestAddImport_Duplicate 测试重复添加导入
func TestAddImport_Duplicate(t *testing.T) {
testCode := `package main
import (
"fmt"
"kra/internal/biz/article"
)
func main() {}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 统计初始导入数量
initialCount := 0
ast.Inspect(file, func(node ast.Node) bool {
if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
initialCount = len(genDecl.Specs)
return false
}
return true
})
// 尝试添加已存在的导入
AddImport(file, "kra/internal/biz/article")
// 验证导入数量没有增加
finalCount := 0
ast.Inspect(file, func(node ast.Node) bool {
if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
finalCount = len(genDecl.Specs)
return false
}
return true
})
if finalCount != initialCount {
t.Errorf("重复添加导入后数量应保持 %d, 实际为 %d", initialCount, finalCount)
}
}
// TestRemoveImport 测试移除导入
func TestRemoveImport(t *testing.T) {
testCode := `package main
import (
"fmt"
"kra/internal/biz/article"
)
func main() {}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 移除导入
RemoveImport(file, "kra/internal/biz/article")
// 验证导入是否移除
var found bool
ast.Inspect(file, func(node ast.Node) bool {
if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
for _, spec := range genDecl.Specs {
if impSpec, ok := spec.(*ast.ImportSpec); ok {
if impSpec.Path.Value == "\"kra/internal/biz/article\"" {
found = true
return false
}
}
}
}
return true
})
if found {
t.Error("应该移除 kra/internal/biz/article 导入")
}
}
// TestAddStructField 测试添加结构体字段
func TestAddStructField(t *testing.T) {
testCode := `package main
type MyStruct struct {
Name string
}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 添加新字段
AddStructField(file, "MyStruct", "Age", "int")
// 验证字段是否添加
var found bool
ast.Inspect(file, func(node ast.Node) bool {
if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" {
if structType, ok := typeSpec.Type.(*ast.StructType); ok {
for _, field := range structType.Fields.List {
for _, name := range field.Names {
if name.Name == "Age" {
found = true
return false
}
}
}
}
}
return true
})
if !found {
t.Error("应该添加 Age 字段")
}
}
// TestAddStructField_Duplicate 测试重复添加结构体字段
func TestAddStructField_Duplicate(t *testing.T) {
testCode := `package main
type MyStruct struct {
Name string
Age int
}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 统计初始字段数量
initialCount := 0
ast.Inspect(file, func(node ast.Node) bool {
if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" {
if structType, ok := typeSpec.Type.(*ast.StructType); ok {
initialCount = len(structType.Fields.List)
return false
}
}
return true
})
// 尝试添加已存在的字段
AddStructField(file, "MyStruct", "Age", "int")
// 验证字段数量没有增加
finalCount := 0
ast.Inspect(file, func(node ast.Node) bool {
if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" {
if structType, ok := typeSpec.Type.(*ast.StructType); ok {
finalCount = len(structType.Fields.List)
return false
}
}
return true
})
if finalCount != initialCount {
t.Errorf("重复添加字段后数量应保持 %d, 实际为 %d", initialCount, finalCount)
}
}
// TestRemoveStructField 测试移除结构体字段
func TestRemoveStructField(t *testing.T) {
testCode := `package main
type MyStruct struct {
Name string
Age int
}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 移除字段
RemoveStructField(file, "MyStruct", "Age")
// 验证字段是否移除
var found bool
ast.Inspect(file, func(node ast.Node) bool {
if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" {
if structType, ok := typeSpec.Type.(*ast.StructType); ok {
for _, field := range structType.Fields.List {
for _, name := range field.Names {
if name.Name == "Age" {
found = true
return false
}
}
}
}
}
return true
})
if found {
t.Error("应该移除 Age 字段")
}
}
// TestFindFunction 测试查找函数
func TestFindFunction(t *testing.T) {
testCode := `package main
func Hello() {}
func World() {}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 查找存在的函数
funcDecl := FindFunction(file, "Hello")
if funcDecl == nil {
t.Error("应该找到 Hello 函数")
}
// 查找不存在的函数
funcDecl = FindFunction(file, "NotExist")
if funcDecl != nil {
t.Error("不应该找到 NotExist 函数")
}
}
// TestCheckImport 测试检查导入是否存在
func TestCheckImport(t *testing.T) {
testCode := `package main
import (
"fmt"
"kra/internal/biz/article"
)
func main() {}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 检查存在的导入
if !CheckImport(file, "fmt") {
t.Error("应该找到 fmt 导入")
}
// 检查存在的导入(完整路径)
if !CheckImport(file, "kra/internal/biz/article") {
t.Error("应该找到 kra/internal/biz/article 导入")
}
// 检查不存在的导入
if CheckImport(file, "not/exist") {
t.Error("不应该找到 not/exist 导入")
}
}
// TestAddFuncCall 测试在函数体中添加函数调用
func TestAddFuncCall(t *testing.T) {
testCode := `package main
func initRouter() {
fmt.Println("init")
}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 添加函数调用
AddFuncCall(file, "initRouter", "router.Article.InitArticleRouter(group)")
// 验证函数调用是否添加
funcDecl := FindFunction(file, "initRouter")
if funcDecl == nil || funcDecl.Body == nil {
t.Fatal("找不到 initRouter 函数")
}
// 检查函数体中是否包含新的调用
if len(funcDecl.Body.List) < 2 {
t.Error("应该添加新的函数调用")
}
}
// TestRemoveFuncCall 测试从函数体中移除函数调用
func TestRemoveFuncCall(t *testing.T) {
testCode := `package main
func initRouter() {
fmt.Println("init")
router.Article.InitArticleRouter(group)
}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 移除函数调用
RemoveFuncCall(file, "initRouter", "InitArticleRouter")
// 验证函数调用是否移除
funcDecl := FindFunction(file, "initRouter")
if funcDecl == nil || funcDecl.Body == nil {
t.Fatal("找不到 initRouter 函数")
}
// 检查函数体中是否还包含该调用
for _, stmt := range funcDecl.Body.List {
if exprStmt, ok := stmt.(*ast.ExprStmt); ok {
if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok {
if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
if selExpr.Sel.Name == "InitArticleRouter" {
t.Error("应该移除 InitArticleRouter 调用")
}
}
}
}
}
}
// TestCreateStmt 测试创建语句
func TestCreateStmt(t *testing.T) {
stmt := CreateStmt("fmt.Println(\"hello\")")
if stmt == nil {
t.Error("应该创建语句")
}
if stmt.X == nil {
t.Error("语句表达式不应为空")
}
}
// TestVariableExistsInBlock 测试检查变量是否存在于块中
func TestVariableExistsInBlock(t *testing.T) {
testCode := `package main
func test() {
x := 1
y := 2
}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
funcDecl := FindFunction(file, "test")
if funcDecl == nil || funcDecl.Body == nil {
t.Fatal("找不到 test 函数")
}
// 检查存在的变量
if !VariableExistsInBlock(funcDecl.Body, "x") {
t.Error("应该找到变量 x")
}
// 检查不存在的变量
if VariableExistsInBlock(funcDecl.Body, "z") {
t.Error("不应该找到变量 z")
}
}
// TestBase_RelativePath 测试相对路径转换
func TestBase_RelativePath(t *testing.T) {
base := &Base{
AutoCodeRoot: "/home/user/project",
AutoCodeServer: "server",
}
testCases := []struct {
input string
expected string
}{
{"/home/user/project/server/internal/biz/article.go", "internal/biz/article.go"},
{"/other/path/file.go", "/other/path/file.go"},
}
for _, tc := range testCases {
result := base.RelativePath(tc.input)
// 由于路径分隔符可能不同,只检查是否包含关键部分
if !strings.Contains(result, "article.go") && tc.input != "/other/path/file.go" {
t.Errorf("RelativePath(%s) 结果不正确: %s", tc.input, result)
}
}
}
// TestAddAutoMigrateModel 测试添加 AutoMigrate 模型
func TestAddAutoMigrateModel(t *testing.T) {
testCode := `package main
import "gorm.io/gorm"
func migrate(db *gorm.DB) {
db.AutoMigrate(&User{})
}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 添加新模型
AddAutoMigrateModel(file, "&article.Article{}")
// 验证模型是否添加
var found bool
ast.Inspect(file, func(node ast.Node) bool {
if callExpr, ok := node.(*ast.CallExpr); ok {
if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
if selExpr.Sel.Name == "AutoMigrate" {
// 检查参数数量是否增加
if len(callExpr.Args) > 1 {
found = true
return false
}
}
}
}
return true
})
if !found {
t.Error("应该添加新的 AutoMigrate 模型")
}
}
// TestCheckGormGenModelExists 测试检查 GORM Gen 模型是否存在
func TestCheckGormGenModelExists(t *testing.T) {
testCode := `package main
func main() {
g.ApplyBasic(
g.GenerateModel("users"),
g.GenerateModel("articles"),
)
}
`
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
if err != nil {
t.Fatalf("解析测试代码失败: %v", err)
}
// 检查存在的模型
if !CheckGormGenModelExists(file, "users") {
t.Error("应该找到 users 模型")
}
// 检查不存在的模型
if CheckGormGenModelExists(file, "orders") {
t.Error("不应该找到 orders 模型")
}
}