package autocode

import (
	"bytes"
	"go/ast"
	"go/parser"
	"go/token"
	"strconv"
	"strings"
	"testing"
)

// wireTestParse parses src as a Go source file (keeping comments) and aborts
// the calling test if parsing fails.
func wireTestParse(t *testing.T, src string) *ast.File {
	t.Helper()
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "test.go", src, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}
	return file
}

// wireTestProviderSetArgs collects the call arguments of every top-level
// `var ProviderSet = f(...)` declaration in file (nil if there is none).
func wireTestProviderSetArgs(file *ast.File) []ast.Expr {
	var args []ast.Expr
	for _, decl := range file.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok || gen.Tok != token.VAR {
			continue
		}
		for _, spec := range gen.Specs {
			vs, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for i, name := range vs.Names {
				if name.Name != "ProviderSet" || i >= len(vs.Values) {
					continue
				}
				if call, ok := vs.Values[i].(*ast.CallExpr); ok {
					args = append(args, call.Args...)
				}
			}
		}
	}
	return args
}

// wireTestCountIdentArgs counts ProviderSet arguments that are the bare
// identifier name (e.g. `NewArticleUsecase`).
func wireTestCountIdentArgs(file *ast.File, name string) int {
	n := 0
	for _, arg := range wireTestProviderSetArgs(file) {
		if id, ok := arg.(*ast.Ident); ok && id.Name == name {
			n++
		}
	}
	return n
}

// wireTestCountSelectorArgs counts ProviderSet arguments of the form pkg.name
// (e.g. `article.NewArticleUsecase`).
func wireTestCountSelectorArgs(file *ast.File, pkg, name string) int {
	n := 0
	for _, arg := range wireTestProviderSetArgs(file) {
		sel, ok := arg.(*ast.SelectorExpr)
		if !ok || sel.Sel.Name != name {
			continue
		}
		if x, ok := sel.X.(*ast.Ident); ok && x.Name == pkg {
			n++
		}
	}
	return n
}

// wireTestCountImportSpecs counts import specs whose unquoted path equals path.
// It walks the import declarations themselves instead of file.Imports, because
// that slice is a parser-populated cache that AST edits do not necessarily
// keep in sync.
func wireTestCountImportSpecs(file *ast.File, path string) int {
	n := 0
	for _, decl := range file.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if !ok || gen.Tok != token.IMPORT {
			continue
		}
		for _, spec := range gen.Specs {
			imp, ok := spec.(*ast.ImportSpec)
			if !ok || imp.Path == nil {
				continue
			}
			if p, err := strconv.Unquote(imp.Path.Value); err == nil && p == path {
				n++
			}
		}
	}
	return n
}

// TestWireProviderAst_Parse sanity-checks that the wire fixture source used by
// the WireProviderAst tests parses and declares package "biz". It does not
// exercise WireProviderAst itself.
func TestWireProviderAst_Parse(t *testing.T) {
	file := wireTestParse(t, `package biz

import "github.com/google/wire"

var ProviderSet = wire.NewSet(NewUserUsecase)
`)
	if file == nil {
		t.Fatal("解析结果为空")
	}
	if file.Name.Name != "biz" {
		t.Errorf("包名应为 'biz', 实际为 '%s'", file.Name.Name)
	}
}

// TestWireProviderAst_Injection checks that Injection appends the provider to
// the ProviderSet call.
func TestWireProviderAst_Injection(t *testing.T) {
	file := wireTestParse(t, `package biz

import "github.com/google/wire"

var ProviderSet = wire.NewSet(NewUserUsecase)
`)

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}
	if err := injector.Injection(file); err != nil {
		t.Fatalf("注入失败: %v", err)
	}

	if wireTestCountIdentArgs(file, "NewArticleUsecase") == 0 {
		t.Error("注入后应包含 NewArticleUsecase Provider")
	}
}

// TestWireProviderAst_Injection_Duplicate checks that injecting a provider that
// is already present does not add a second copy.
func TestWireProviderAst_Injection_Duplicate(t *testing.T) {
	file := wireTestParse(t, `package biz

import "github.com/google/wire"

var ProviderSet = wire.NewSet(NewUserUsecase, NewArticleUsecase)
`)

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}
	// The provider already exists, so this must be a no-op.
	if err := injector.Injection(file); err != nil {
		t.Fatalf("注入失败: %v", err)
	}

	if count := wireTestCountIdentArgs(file, "NewArticleUsecase"); count != 1 {
		t.Errorf("NewArticleUsecase 应该只出现 1 次, 实际出现 %d 次", count)
	}
}

// TestWireProviderAst_Rollback checks that Rollback removes the provider from
// the ProviderSet call.
func TestWireProviderAst_Rollback(t *testing.T) {
	file := wireTestParse(t, `package biz

import "github.com/google/wire"

var ProviderSet = wire.NewSet(NewUserUsecase, NewArticleUsecase)
`)

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}
	if err := injector.Rollback(file); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}

	if wireTestCountIdentArgs(file, "NewArticleUsecase") > 0 {
		t.Error("回滚后不应包含 NewArticleUsecase Provider")
	}
}

// TestWireProviderAst_Format checks that Format renders the injected file,
// including the new provider and the package clause.
func TestWireProviderAst_Format(t *testing.T) {
	file := wireTestParse(t, `package biz

import "github.com/google/wire"

var ProviderSet = wire.NewSet(NewUserUsecase)
`)

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}
	// Check the injection error so a failure here is reported as such, rather
	// than as a confusing "output lacks provider" failure below.
	if err := injector.Injection(file); err != nil {
		t.Fatalf("注入失败: %v", err)
	}

	var buf bytes.Buffer
	if err := injector.Format("", &buf, file); err != nil {
		t.Fatalf("格式化失败: %v", err)
	}

	result := buf.String()
	if !strings.Contains(result, "NewArticleUsecase") {
		t.Error("格式化输出应包含 NewArticleUsecase")
	}
	if !strings.Contains(result, "package biz") {
		t.Error("格式化输出应包含 package biz")
	}
}

// TestMainProviderSetAst_Injection checks that Injection adds both the
// sub-package import and the package-qualified provider.
func TestMainProviderSetAst_Injection(t *testing.T) {
	file := wireTestParse(t, `package biz

import (
	"github.com/google/wire"
	"kra/internal/biz/user"
)

var ProviderSet = wire.NewSet(user.NewUserUsecase)
`)

	injector := &MainProviderSetAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		ImportPath:   "kra/internal/biz/article",
		Path:         "test.go",
	}
	if err := injector.Injection(file); err != nil {
		t.Fatalf("注入失败: %v", err)
	}

	if wireTestCountImportSpecs(file, "kra/internal/biz/article") == 0 {
		t.Error("注入后应包含 kra/internal/biz/article 导入")
	}
	if wireTestCountSelectorArgs(file, "article", "NewArticleUsecase") == 0 {
		t.Error("注入后应包含 article.NewArticleUsecase Provider")
	}
}

// TestMainProviderSetAst_Rollback checks that Rollback removes both the
// sub-package import and the package-qualified provider.
func TestMainProviderSetAst_Rollback(t *testing.T) {
	file := wireTestParse(t, `package biz

import (
	"github.com/google/wire"
	"kra/internal/biz/user"
	"kra/internal/biz/article"
)

var ProviderSet = wire.NewSet(user.NewUserUsecase, article.NewArticleUsecase)
`)

	injector := &MainProviderSetAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		ImportPath:   "kra/internal/biz/article",
		Path:         "test.go",
	}
	if err := injector.Rollback(file); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}

	if wireTestCountImportSpecs(file, "kra/internal/biz/article") > 0 {
		t.Error("回滚后不应包含 kra/internal/biz/article 导入")
	}
	if wireTestCountSelectorArgs(file, "article", "NewArticleUsecase") > 0 {
		t.Error("回滚后不应包含 article.NewArticleUsecase Provider")
	}
}

// TestGetLayerMainFile checks the per-layer main file path mapping.
func TestGetLayerMainFile(t *testing.T) {
	testCases := []struct {
		layer    string
		expected string
	}{
		{"biz", "internal/biz/biz.go"},
		{"data", "internal/data/data.go"},
		{"service", "internal/service/service.go"},
		{"handler", "internal/server/handler/enter.go"},
		{"api", "internal/server/handler/enter.go"},
		{"router", "internal/server/router/enter.go"},
		{"unknown", ""},
	}

	for _, tc := range testCases {
		t.Run(tc.layer, func(t *testing.T) {
			result := GetLayerMainFile(tc.layer)
			if result != tc.expected {
				t.Errorf("GetLayerMainFile(%s) = %s, 期望 %s", tc.layer, result, tc.expected)
			}
		})
	}
}

// TestGetLayerImportPath checks the per-layer import path mapping.
func TestGetLayerImportPath(t *testing.T) {
	module := "kra"
	packageName := "article"

	testCases := []struct {
		layer    string
		expected string
	}{
		{"biz", "kra/internal/biz/article"},
		{"data", "kra/internal/data/article"},
		{"service", "kra/internal/service/article"},
		{"handler", "kra/internal/server/handler/article"},
		{"api", "kra/internal/server/handler/article"},
		{"router", "kra/internal/server/router/article"},
		{"unknown", ""},
	}

	for _, tc := range testCases {
		t.Run(tc.layer, func(t *testing.T) {
			result := GetLayerImportPath(module, tc.layer, packageName)
			if result != tc.expected {
				t.Errorf("GetLayerImportPath(%s, %s, %s) = %s, 期望 %s", module, tc.layer, packageName, result, tc.expected)
			}
		})
	}
}

// TestGetProviderName checks the per-layer provider constructor naming.
func TestGetProviderName(t *testing.T) {
	structName := "Article"

	testCases := []struct {
		layer    string
		expected string
	}{
		{"biz", "NewArticleUsecase"},
		{"data", "NewArticleRepo"},
		{"service", "NewArticleService"},
		{"handler", "NewArticleHandler"},
		{"api", "NewArticleHandler"},
		{"router", "NewArticleRouter"},
		{"unknown", ""},
	}

	for _, tc := range testCases {
		t.Run(tc.layer, func(t *testing.T) {
			result := GetProviderName(tc.layer, structName)
			if result != tc.expected {
				t.Errorf("GetProviderName(%s, %s) = %s, 期望 %s", tc.layer, structName, result, tc.expected)
			}
		})
	}
}

// TestAddImport checks that addImport adds a new import declaration spec.
func TestAddImport(t *testing.T) {
	file := wireTestParse(t, `package main

import "fmt"

func main() {}
`)

	addImport(file, "kra/internal/biz/article")

	if wireTestCountImportSpecs(file, "kra/internal/biz/article") == 0 {
		t.Error("应该添加 kra/internal/biz/article 导入")
	}
}

// TestAddImport_Duplicate checks that adding an import that already exists is
// a no-op. It checks both file.Imports and the import declarations themselves,
// since a duplicate spec could be added to the declarations without updating
// file.Imports.
func TestAddImport_Duplicate(t *testing.T) {
	file := wireTestParse(t, `package main

import (
	"fmt"
	"kra/internal/biz/article"
)

func main() {}
`)

	initialCount := len(file.Imports)

	addImport(file, "kra/internal/biz/article")

	if len(file.Imports) != initialCount {
		t.Errorf("重复添加导入后数量应保持 %d, 实际为 %d", initialCount, len(file.Imports))
	}
	if n := wireTestCountImportSpecs(file, "kra/internal/biz/article"); n != 1 {
		t.Errorf("kra/internal/biz/article 导入应只出现 1 次, 实际出现 %d 次", n)
	}
}

// TestRemoveImport checks that removeImport deletes the import declaration spec.
func TestRemoveImport(t *testing.T) {
	file := wireTestParse(t, `package main

import (
	"fmt"
	"kra/internal/biz/article"
)

func main() {}
`)

	removeImport(file, "kra/internal/biz/article")

	if wireTestCountImportSpecs(file, "kra/internal/biz/article") > 0 {
		t.Error("应该移除 kra/internal/biz/article 导入")
	}
}

// TestNewWireProviderAst checks that the constructor stores every argument.
func TestNewWireProviderAst(t *testing.T) {
	// Named p (not ast) so the go/ast package is not shadowed.
	p := NewWireProviderAst("biz", "article", "NewArticleUsecase", "internal/biz/article/enter.go")

	if p.Layer != "biz" {
		t.Errorf("Layer 应为 'biz', 实际为 '%s'", p.Layer)
	}
	if p.PackageName != "article" {
		t.Errorf("PackageName 应为 'article', 实际为 '%s'", p.PackageName)
	}
	if p.ProviderName != "NewArticleUsecase" {
		t.Errorf("ProviderName 应为 'NewArticleUsecase', 实际为 '%s'", p.ProviderName)
	}
	if p.Path != "internal/biz/article/enter.go" {
		t.Errorf("Path 应为 'internal/biz/article/enter.go', 实际为 '%s'", p.Path)
	}
}

// TestNewMainProviderSetAst checks that the constructor stores every argument.
func TestNewMainProviderSetAst(t *testing.T) {
	// Named p (not ast) so the go/ast package is not shadowed.
	p := NewMainProviderSetAst("biz", "article", "NewArticleUsecase", "kra/internal/biz/article", "internal/biz/biz.go")

	if p.Layer != "biz" {
		t.Errorf("Layer 应为 'biz', 实际为 '%s'", p.Layer)
	}
	if p.PackageName != "article" {
		t.Errorf("PackageName 应为 'article', 实际为 '%s'", p.PackageName)
	}
	if p.ProviderName != "NewArticleUsecase" {
		t.Errorf("ProviderName 应为 'NewArticleUsecase', 实际为 '%s'", p.ProviderName)
	}
	if p.ImportPath != "kra/internal/biz/article" {
		t.Errorf("ImportPath 应为 'kra/internal/biz/article', 实际为 '%s'", p.ImportPath)
	}
	if p.Path != "internal/biz/biz.go" {
		t.Errorf("Path 应为 'internal/biz/biz.go', 实际为 '%s'", p.Path)
	}
}

// TestWireProviderAst_Injection_NoProviderSet checks that Injection reports an
// error when the file has no ProviderSet variable.
func TestWireProviderAst_Injection_NoProviderSet(t *testing.T) {
	file := wireTestParse(t, `package biz

import "github.com/google/wire"

var SomeOtherVar = "test"
`)

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}

	if err := injector.Injection(file); err == nil {
		t.Error("没有 ProviderSet 时注入应该返回错误")
	}
}

// TestWireInjectionConfig checks that WireInjectionConfig fields hold the
// values they are initialized with.
func TestWireInjectionConfig(t *testing.T) {
	config := &WireInjectionConfig{
		Module:      "kra",
		PackageName: "article",
		StructName:  "Article",
		Layers:      []string{"biz", "data", "service", "handler", "router"},
	}

	if config.Module != "kra" {
		t.Errorf("Module 应为 'kra', 实际为 '%s'", config.Module)
	}
	if config.PackageName != "article" {
		t.Errorf("PackageName 应为 'article', 实际为 '%s'", config.PackageName)
	}
	if config.StructName != "Article" {
		t.Errorf("StructName 应为 'Article', 实际为 '%s'", config.StructName)
	}
	if len(config.Layers) != 5 {
		t.Errorf("Layers 长度应为 5, 实际为 %d", len(config.Layers))
	}
}