519 lines
12 KiB
Go
519 lines
12 KiB
Go
package ast
|
|
|
|
import (
|
|
"go/ast"
|
|
"go/parser"
|
|
"go/token"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// TestAddImport 测试添加导入
|
|
func TestAddImport(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
import "fmt"
|
|
|
|
func main() {}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 添加新导入
|
|
AddImport(file, "kra/internal/biz/article")
|
|
|
|
// 验证导入是否添加
|
|
var found bool
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
|
|
for _, spec := range genDecl.Specs {
|
|
if impSpec, ok := spec.(*ast.ImportSpec); ok {
|
|
if impSpec.Path.Value == "\"kra/internal/biz/article\"" {
|
|
found = true
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
if !found {
|
|
t.Error("应该添加 kra/internal/biz/article 导入")
|
|
}
|
|
}
|
|
|
|
// TestAddImport_Duplicate 测试重复添加导入
|
|
func TestAddImport_Duplicate(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
import (
|
|
"fmt"
|
|
"kra/internal/biz/article"
|
|
)
|
|
|
|
func main() {}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 统计初始导入数量
|
|
initialCount := 0
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
|
|
initialCount = len(genDecl.Specs)
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
|
|
// 尝试添加已存在的导入
|
|
AddImport(file, "kra/internal/biz/article")
|
|
|
|
// 验证导入数量没有增加
|
|
finalCount := 0
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
|
|
finalCount = len(genDecl.Specs)
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
|
|
if finalCount != initialCount {
|
|
t.Errorf("重复添加导入后数量应保持 %d, 实际为 %d", initialCount, finalCount)
|
|
}
|
|
}
|
|
|
|
// TestRemoveImport 测试移除导入
|
|
func TestRemoveImport(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
import (
|
|
"fmt"
|
|
"kra/internal/biz/article"
|
|
)
|
|
|
|
func main() {}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 移除导入
|
|
RemoveImport(file, "kra/internal/biz/article")
|
|
|
|
// 验证导入是否移除
|
|
var found bool
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if genDecl, ok := node.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
|
|
for _, spec := range genDecl.Specs {
|
|
if impSpec, ok := spec.(*ast.ImportSpec); ok {
|
|
if impSpec.Path.Value == "\"kra/internal/biz/article\"" {
|
|
found = true
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
if found {
|
|
t.Error("应该移除 kra/internal/biz/article 导入")
|
|
}
|
|
}
|
|
|
|
// TestAddStructField 测试添加结构体字段
|
|
func TestAddStructField(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
type MyStruct struct {
|
|
Name string
|
|
}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 添加新字段
|
|
AddStructField(file, "MyStruct", "Age", "int")
|
|
|
|
// 验证字段是否添加
|
|
var found bool
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" {
|
|
if structType, ok := typeSpec.Type.(*ast.StructType); ok {
|
|
for _, field := range structType.Fields.List {
|
|
for _, name := range field.Names {
|
|
if name.Name == "Age" {
|
|
found = true
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
if !found {
|
|
t.Error("应该添加 Age 字段")
|
|
}
|
|
}
|
|
|
|
// TestAddStructField_Duplicate 测试重复添加结构体字段
|
|
func TestAddStructField_Duplicate(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
type MyStruct struct {
|
|
Name string
|
|
Age int
|
|
}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 统计初始字段数量
|
|
initialCount := 0
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" {
|
|
if structType, ok := typeSpec.Type.(*ast.StructType); ok {
|
|
initialCount = len(structType.Fields.List)
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
// 尝试添加已存在的字段
|
|
AddStructField(file, "MyStruct", "Age", "int")
|
|
|
|
// 验证字段数量没有增加
|
|
finalCount := 0
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" {
|
|
if structType, ok := typeSpec.Type.(*ast.StructType); ok {
|
|
finalCount = len(structType.Fields.List)
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
if finalCount != initialCount {
|
|
t.Errorf("重复添加字段后数量应保持 %d, 实际为 %d", initialCount, finalCount)
|
|
}
|
|
}
|
|
|
|
// TestRemoveStructField 测试移除结构体字段
|
|
func TestRemoveStructField(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
type MyStruct struct {
|
|
Name string
|
|
Age int
|
|
}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 移除字段
|
|
RemoveStructField(file, "MyStruct", "Age")
|
|
|
|
// 验证字段是否移除
|
|
var found bool
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if typeSpec, ok := node.(*ast.TypeSpec); ok && typeSpec.Name.Name == "MyStruct" {
|
|
if structType, ok := typeSpec.Type.(*ast.StructType); ok {
|
|
for _, field := range structType.Fields.List {
|
|
for _, name := range field.Names {
|
|
if name.Name == "Age" {
|
|
found = true
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
if found {
|
|
t.Error("应该移除 Age 字段")
|
|
}
|
|
}
|
|
|
|
// TestFindFunction 测试查找函数
|
|
func TestFindFunction(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
func Hello() {}
|
|
|
|
func World() {}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 查找存在的函数
|
|
funcDecl := FindFunction(file, "Hello")
|
|
if funcDecl == nil {
|
|
t.Error("应该找到 Hello 函数")
|
|
}
|
|
|
|
// 查找不存在的函数
|
|
funcDecl = FindFunction(file, "NotExist")
|
|
if funcDecl != nil {
|
|
t.Error("不应该找到 NotExist 函数")
|
|
}
|
|
}
|
|
|
|
// TestCheckImport 测试检查导入是否存在
|
|
func TestCheckImport(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
import (
|
|
"fmt"
|
|
"kra/internal/biz/article"
|
|
)
|
|
|
|
func main() {}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 检查存在的导入
|
|
if !CheckImport(file, "fmt") {
|
|
t.Error("应该找到 fmt 导入")
|
|
}
|
|
|
|
// 检查存在的导入(完整路径)
|
|
if !CheckImport(file, "kra/internal/biz/article") {
|
|
t.Error("应该找到 kra/internal/biz/article 导入")
|
|
}
|
|
|
|
// 检查不存在的导入
|
|
if CheckImport(file, "not/exist") {
|
|
t.Error("不应该找到 not/exist 导入")
|
|
}
|
|
}
|
|
|
|
// TestAddFuncCall 测试在函数体中添加函数调用
|
|
func TestAddFuncCall(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
func initRouter() {
|
|
fmt.Println("init")
|
|
}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 添加函数调用
|
|
AddFuncCall(file, "initRouter", "router.Article.InitArticleRouter(group)")
|
|
|
|
// 验证函数调用是否添加
|
|
funcDecl := FindFunction(file, "initRouter")
|
|
if funcDecl == nil || funcDecl.Body == nil {
|
|
t.Fatal("找不到 initRouter 函数")
|
|
}
|
|
|
|
// 检查函数体中是否包含新的调用
|
|
if len(funcDecl.Body.List) < 2 {
|
|
t.Error("应该添加新的函数调用")
|
|
}
|
|
}
|
|
|
|
// TestRemoveFuncCall 测试从函数体中移除函数调用
|
|
func TestRemoveFuncCall(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
func initRouter() {
|
|
fmt.Println("init")
|
|
router.Article.InitArticleRouter(group)
|
|
}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 移除函数调用
|
|
RemoveFuncCall(file, "initRouter", "InitArticleRouter")
|
|
|
|
// 验证函数调用是否移除
|
|
funcDecl := FindFunction(file, "initRouter")
|
|
if funcDecl == nil || funcDecl.Body == nil {
|
|
t.Fatal("找不到 initRouter 函数")
|
|
}
|
|
|
|
// 检查函数体中是否还包含该调用
|
|
for _, stmt := range funcDecl.Body.List {
|
|
if exprStmt, ok := stmt.(*ast.ExprStmt); ok {
|
|
if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok {
|
|
if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
|
|
if selExpr.Sel.Name == "InitArticleRouter" {
|
|
t.Error("应该移除 InitArticleRouter 调用")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestCreateStmt 测试创建语句
|
|
func TestCreateStmt(t *testing.T) {
|
|
stmt := CreateStmt("fmt.Println(\"hello\")")
|
|
if stmt == nil {
|
|
t.Error("应该创建语句")
|
|
}
|
|
if stmt.X == nil {
|
|
t.Error("语句表达式不应为空")
|
|
}
|
|
}
|
|
|
|
// TestVariableExistsInBlock 测试检查变量是否存在于块中
|
|
func TestVariableExistsInBlock(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
func test() {
|
|
x := 1
|
|
y := 2
|
|
}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
funcDecl := FindFunction(file, "test")
|
|
if funcDecl == nil || funcDecl.Body == nil {
|
|
t.Fatal("找不到 test 函数")
|
|
}
|
|
|
|
// 检查存在的变量
|
|
if !VariableExistsInBlock(funcDecl.Body, "x") {
|
|
t.Error("应该找到变量 x")
|
|
}
|
|
|
|
// 检查不存在的变量
|
|
if VariableExistsInBlock(funcDecl.Body, "z") {
|
|
t.Error("不应该找到变量 z")
|
|
}
|
|
}
|
|
|
|
// TestBase_RelativePath 测试相对路径转换
|
|
func TestBase_RelativePath(t *testing.T) {
|
|
base := &Base{
|
|
AutoCodeRoot: "/home/user/project",
|
|
AutoCodeServer: "server",
|
|
}
|
|
|
|
testCases := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"/home/user/project/server/internal/biz/article.go", "internal/biz/article.go"},
|
|
{"/other/path/file.go", "/other/path/file.go"},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
result := base.RelativePath(tc.input)
|
|
// 由于路径分隔符可能不同,只检查是否包含关键部分
|
|
if !strings.Contains(result, "article.go") && tc.input != "/other/path/file.go" {
|
|
t.Errorf("RelativePath(%s) 结果不正确: %s", tc.input, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestAddAutoMigrateModel 测试添加 AutoMigrate 模型
|
|
func TestAddAutoMigrateModel(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
import "gorm.io/gorm"
|
|
|
|
func migrate(db *gorm.DB) {
|
|
db.AutoMigrate(&User{})
|
|
}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 添加新模型
|
|
AddAutoMigrateModel(file, "&article.Article{}")
|
|
|
|
// 验证模型是否添加
|
|
var found bool
|
|
ast.Inspect(file, func(node ast.Node) bool {
|
|
if callExpr, ok := node.(*ast.CallExpr); ok {
|
|
if selExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
|
|
if selExpr.Sel.Name == "AutoMigrate" {
|
|
// 检查参数数量是否增加
|
|
if len(callExpr.Args) > 1 {
|
|
found = true
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
if !found {
|
|
t.Error("应该添加新的 AutoMigrate 模型")
|
|
}
|
|
}
|
|
|
|
// TestCheckGormGenModelExists 测试检查 GORM Gen 模型是否存在
|
|
func TestCheckGormGenModelExists(t *testing.T) {
|
|
testCode := `package main
|
|
|
|
func main() {
|
|
g.ApplyBasic(
|
|
g.GenerateModel("users"),
|
|
g.GenerateModel("articles"),
|
|
)
|
|
}
|
|
`
|
|
fset := token.NewFileSet()
|
|
file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
|
|
if err != nil {
|
|
t.Fatalf("解析测试代码失败: %v", err)
|
|
}
|
|
|
|
// 检查存在的模型
|
|
if !CheckGormGenModelExists(file, "users") {
|
|
t.Error("应该找到 users 模型")
|
|
}
|
|
|
|
// 检查不存在的模型
|
|
if CheckGormGenModelExists(file, "orders") {
|
|
t.Error("不应该找到 orders 模型")
|
|
}
|
|
}
|