// kra/pkg/utils/autocode/injection_test.go
//
// 685 lines
// 18 KiB
// Go

package autocode
import (
"bytes"
"go/ast"
"go/parser"
"go/token"
"strings"
"testing"
)
// TestWireProviderAst_Parse checks that a minimal wire provider source parses
// without error and reports the package name "biz".
//
// NOTE(review): despite its name, this test only exercises go/parser; it never
// constructs or calls WireProviderAst. Consider pointing it at the real
// parsing entry point.
func TestWireProviderAst_Parse(t *testing.T) {
	src := `package biz
import "github.com/google/wire"
var ProviderSet = wire.NewSet(NewUserUsecase)
`
	file, err := parser.ParseFile(token.NewFileSet(), "test.go", src, parser.ParseComments)
	switch {
	case err != nil:
		t.Fatalf("解析测试代码失败: %v", err)
	case file == nil:
		t.Fatal("解析结果为空")
	}

	if got := file.Name.Name; got != "biz" {
		t.Errorf("包名应为 'biz', 实际为 '%s'", got)
	}
}
// TestWireProviderAst_Injection verifies that Injection adds the configured
// provider identifier (NewArticleUsecase) to the arguments of the call
// assigned to ProviderSet.
func TestWireProviderAst_Injection(t *testing.T) {
	src := `package biz
import "github.com/google/wire"
var ProviderSet = wire.NewSet(NewUserUsecase)
`
	file, err := parser.ParseFile(token.NewFileSet(), "test.go", src, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}
	if err = injector.Injection(file); err != nil {
		t.Fatalf("注入失败: %v", err)
	}

	// Look for NewArticleUsecase among the bare-identifier arguments of the
	// call on the right-hand side of `var ProviderSet = ...`.
	found := false
	ast.Inspect(file, func(n ast.Node) bool {
		gd, isGen := n.(*ast.GenDecl)
		if !isGen || gd.Tok != token.VAR {
			return true
		}
		for _, s := range gd.Specs {
			vs, isValue := s.(*ast.ValueSpec)
			if !isValue {
				continue
			}
			for idx, id := range vs.Names {
				if id.Name != "ProviderSet" || idx >= len(vs.Values) {
					continue
				}
				call, isCall := vs.Values[idx].(*ast.CallExpr)
				if !isCall {
					continue
				}
				for _, a := range call.Args {
					if ident, isIdent := a.(*ast.Ident); isIdent && ident.Name == "NewArticleUsecase" {
						found = true
						return false
					}
				}
			}
		}
		return true
	})

	if !found {
		t.Error("注入后应包含 NewArticleUsecase Provider")
	}
}
// TestWireProviderAst_Injection_Duplicate verifies that injecting a provider
// already listed in ProviderSet does not add a second copy of it.
func TestWireProviderAst_Injection_Duplicate(t *testing.T) {
	src := `package biz
import "github.com/google/wire"
var ProviderSet = wire.NewSet(NewUserUsecase, NewArticleUsecase)
`
	file, err := parser.ParseFile(token.NewFileSet(), "test.go", src, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}
	// NewArticleUsecase is already present in the source above, so this call
	// should leave the argument list as-is.
	if err = injector.Injection(file); err != nil {
		t.Fatalf("注入失败: %v", err)
	}

	// Count occurrences of NewArticleUsecase among the bare-identifier
	// arguments of the ProviderSet call.
	occurrences := 0
	ast.Inspect(file, func(n ast.Node) bool {
		gd, isGen := n.(*ast.GenDecl)
		if !isGen || gd.Tok != token.VAR {
			return true
		}
		for _, s := range gd.Specs {
			vs, isValue := s.(*ast.ValueSpec)
			if !isValue {
				continue
			}
			for idx, id := range vs.Names {
				if id.Name != "ProviderSet" || idx >= len(vs.Values) {
					continue
				}
				call, isCall := vs.Values[idx].(*ast.CallExpr)
				if !isCall {
					continue
				}
				for _, a := range call.Args {
					if ident, isIdent := a.(*ast.Ident); isIdent && ident.Name == "NewArticleUsecase" {
						occurrences++
					}
				}
			}
		}
		return true
	})

	if occurrences != 1 {
		t.Errorf("NewArticleUsecase 应该只出现 1 次, 实际出现 %d 次", occurrences)
	}
}
// TestWireProviderAst_Rollback verifies that Rollback removes the configured
// provider (NewArticleUsecase) from the ProviderSet call and leaves the other
// providers (NewUserUsecase) in place.
func TestWireProviderAst_Rollback(t *testing.T) {
	testCode := `package biz
import "github.com/google/wire"
var ProviderSet = wire.NewSet(NewUserUsecase, NewArticleUsecase)
`
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}
	if err = injector.Rollback(file); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}

	// Collect every bare-identifier argument still passed to the ProviderSet
	// call after the rollback.
	remaining := map[string]bool{}
	ast.Inspect(file, func(node ast.Node) bool {
		genDecl, ok := node.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.VAR {
			return true
		}
		for _, spec := range genDecl.Specs {
			valueSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			for i, name := range valueSpec.Names {
				if name.Name != "ProviderSet" || i >= len(valueSpec.Values) {
					continue
				}
				callExpr, ok := valueSpec.Values[i].(*ast.CallExpr)
				if !ok {
					continue
				}
				for _, arg := range callExpr.Args {
					if ident, ok := arg.(*ast.Ident); ok {
						remaining[ident.Name] = true
					}
				}
			}
		}
		return true
	})

	if remaining["NewArticleUsecase"] {
		t.Error("回滚后不应包含 NewArticleUsecase Provider")
	}
	// Checking only for the absence of the target would also pass for a
	// Rollback that emptied the whole set, so require the sibling to survive.
	if !remaining["NewUserUsecase"] {
		t.Error("回滚后应保留 NewUserUsecase Provider")
	}
}
// TestWireProviderAst_Format verifies that, after injecting NewArticleUsecase,
// Format writes source that still contains the package clause and the
// injected provider.
func TestWireProviderAst_Format(t *testing.T) {
	testCode := `package biz
import "github.com/google/wire"
var ProviderSet = wire.NewSet(NewUserUsecase)
`
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}

	// Fail fast on an injection error. Ignoring it would turn a broken
	// Injection into a confusing "output lacks NewArticleUsecase" failure in
	// the formatting assertions below.
	if err = injector.Injection(file); err != nil {
		t.Fatalf("注入失败: %v", err)
	}

	var buf bytes.Buffer
	if err = injector.Format("", &buf, file); err != nil {
		t.Fatalf("格式化失败: %v", err)
	}

	result := buf.String()
	if !strings.Contains(result, "NewArticleUsecase") {
		t.Error("格式化输出应包含 NewArticleUsecase")
	}
	if !strings.Contains(result, "package biz") {
		t.Error("格式化输出应包含 package biz")
	}
}
// TestMainProviderSetAst_Injection verifies that Injection on a layer's main
// provider file does two things: it adds the sub-package import path, and it
// appends the package-qualified provider (article.NewArticleUsecase) to the
// ProviderSet call.
func TestMainProviderSetAst_Injection(t *testing.T) {
	src := `package biz
import (
"github.com/google/wire"
"kra/internal/biz/user"
)
var ProviderSet = wire.NewSet(user.NewUserUsecase)
`
	file, err := parser.ParseFile(token.NewFileSet(), "test.go", src, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	injector := &MainProviderSetAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		ImportPath:   "kra/internal/biz/article",
		Path:         "test.go",
	}
	if err = injector.Injection(file); err != nil {
		t.Fatalf("注入失败: %v", err)
	}

	// Check the import GenDecl specs directly. ImportSpec.Path.Value is the
	// literal source text, so it includes the surrounding quotes.
	const quotedImport = `"kra/internal/biz/article"`
	importFound := false
	ast.Inspect(file, func(n ast.Node) bool {
		gd, isGen := n.(*ast.GenDecl)
		if !isGen || gd.Tok != token.IMPORT {
			return true
		}
		for _, s := range gd.Specs {
			if imp, isImport := s.(*ast.ImportSpec); isImport && imp.Path.Value == quotedImport {
				importFound = true
				return false
			}
		}
		return true
	})
	if !importFound {
		t.Error("注入后应包含 kra/internal/biz/article 导入")
	}

	// Look for the selector expression article.NewArticleUsecase among the
	// ProviderSet call's arguments.
	providerFound := false
	ast.Inspect(file, func(n ast.Node) bool {
		gd, isGen := n.(*ast.GenDecl)
		if !isGen || gd.Tok != token.VAR {
			return true
		}
		for _, s := range gd.Specs {
			vs, isValue := s.(*ast.ValueSpec)
			if !isValue {
				continue
			}
			for idx, id := range vs.Names {
				if id.Name != "ProviderSet" || idx >= len(vs.Values) {
					continue
				}
				call, isCall := vs.Values[idx].(*ast.CallExpr)
				if !isCall {
					continue
				}
				for _, a := range call.Args {
					sel, isSel := a.(*ast.SelectorExpr)
					if !isSel {
						continue
					}
					if pkg, isIdent := sel.X.(*ast.Ident); isIdent && pkg.Name == "article" && sel.Sel.Name == "NewArticleUsecase" {
						providerFound = true
						return false
					}
				}
			}
		}
		return true
	})
	if !providerFound {
		t.Error("注入后应包含 article.NewArticleUsecase Provider")
	}
}
// TestMainProviderSetAst_Rollback verifies that Rollback on a layer's main
// provider file removes both the sub-package import and the package-qualified
// provider (article.NewArticleUsecase) from ProviderSet.
func TestMainProviderSetAst_Rollback(t *testing.T) {
	src := `package biz
import (
"github.com/google/wire"
"kra/internal/biz/user"
"kra/internal/biz/article"
)
var ProviderSet = wire.NewSet(user.NewUserUsecase, article.NewArticleUsecase)
`
	file, err := parser.ParseFile(token.NewFileSet(), "test.go", src, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	injector := &MainProviderSetAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		ImportPath:   "kra/internal/biz/article",
		Path:         "test.go",
	}
	if err = injector.Rollback(file); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}

	// The import must be gone from the import GenDecl specs. Path.Value holds
	// the quoted literal.
	const quotedImport = `"kra/internal/biz/article"`
	importFound := false
	ast.Inspect(file, func(n ast.Node) bool {
		gd, isGen := n.(*ast.GenDecl)
		if !isGen || gd.Tok != token.IMPORT {
			return true
		}
		for _, s := range gd.Specs {
			if imp, isImport := s.(*ast.ImportSpec); isImport && imp.Path.Value == quotedImport {
				importFound = true
				return false
			}
		}
		return true
	})
	if importFound {
		t.Error("回滚后不应包含 kra/internal/biz/article 导入")
	}

	// The selector argument article.NewArticleUsecase must be gone from the
	// ProviderSet call.
	providerFound := false
	ast.Inspect(file, func(n ast.Node) bool {
		gd, isGen := n.(*ast.GenDecl)
		if !isGen || gd.Tok != token.VAR {
			return true
		}
		for _, s := range gd.Specs {
			vs, isValue := s.(*ast.ValueSpec)
			if !isValue {
				continue
			}
			for idx, id := range vs.Names {
				if id.Name != "ProviderSet" || idx >= len(vs.Values) {
					continue
				}
				call, isCall := vs.Values[idx].(*ast.CallExpr)
				if !isCall {
					continue
				}
				for _, a := range call.Args {
					sel, isSel := a.(*ast.SelectorExpr)
					if !isSel {
						continue
					}
					if pkg, isIdent := sel.X.(*ast.Ident); isIdent && pkg.Name == "article" && sel.Sel.Name == "NewArticleUsecase" {
						providerFound = true
						return false
					}
				}
			}
		}
		return true
	})
	if providerFound {
		t.Error("回滚后不应包含 article.NewArticleUsecase Provider")
	}
}
// TestGetLayerMainFile checks the main provider file path that
// GetLayerMainFile returns for each layer. "handler" and "api" share a file,
// and an unknown layer yields "".
func TestGetLayerMainFile(t *testing.T) {
	for _, c := range []struct {
		layer, want string
	}{
		{layer: "biz", want: "internal/biz/biz.go"},
		{layer: "data", want: "internal/data/data.go"},
		{layer: "service", want: "internal/service/service.go"},
		{layer: "handler", want: "internal/server/handler/enter.go"},
		{layer: "api", want: "internal/server/handler/enter.go"},
		{layer: "router", want: "internal/server/router/enter.go"},
		{layer: "unknown", want: ""},
	} {
		t.Run(c.layer, func(t *testing.T) {
			if got := GetLayerMainFile(c.layer); got != c.want {
				t.Errorf("GetLayerMainFile(%s) = %s, 期望 %s", c.layer, got, c.want)
			}
		})
	}
}
// TestGetLayerImportPath checks the import path GetLayerImportPath builds from
// the module name, layer, and package name. An unknown layer yields "".
func TestGetLayerImportPath(t *testing.T) {
	const (
		module      = "kra"
		packageName = "article"
	)
	for _, c := range []struct {
		layer, want string
	}{
		{layer: "biz", want: "kra/internal/biz/article"},
		{layer: "data", want: "kra/internal/data/article"},
		{layer: "service", want: "kra/internal/service/article"},
		{layer: "handler", want: "kra/internal/server/handler/article"},
		{layer: "api", want: "kra/internal/server/handler/article"},
		{layer: "router", want: "kra/internal/server/router/article"},
		{layer: "unknown", want: ""},
	} {
		t.Run(c.layer, func(t *testing.T) {
			if got := GetLayerImportPath(module, c.layer, packageName); got != c.want {
				t.Errorf("GetLayerImportPath(%s, %s, %s) = %s, 期望 %s", module, c.layer, packageName, got, c.want)
			}
		})
	}
}
// TestGetProviderName checks the constructor name GetProviderName derives from
// a layer and struct name (e.g. "biz" + "Article" -> "NewArticleUsecase").
// An unknown layer yields "".
func TestGetProviderName(t *testing.T) {
	const structName = "Article"
	for _, c := range []struct {
		layer, want string
	}{
		{layer: "biz", want: "NewArticleUsecase"},
		{layer: "data", want: "NewArticleRepo"},
		{layer: "service", want: "NewArticleService"},
		{layer: "handler", want: "NewArticleHandler"},
		{layer: "api", want: "NewArticleHandler"},
		{layer: "router", want: "NewArticleRouter"},
		{layer: "unknown", want: ""},
	} {
		t.Run(c.layer, func(t *testing.T) {
			if got := GetProviderName(c.layer, structName); got != c.want {
				t.Errorf("GetProviderName(%s, %s) = %s, 期望 %s", c.layer, structName, got, c.want)
			}
		})
	}
}
// TestAddImport verifies that addImport adds a new import path to a file that
// starts with a single, unparenthesized import.
func TestAddImport(t *testing.T) {
	src := `package main
import "fmt"
func main() {}
`
	file, err := parser.ParseFile(token.NewFileSet(), "test.go", src, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	addImport(file, "kra/internal/biz/article")

	// Check the import GenDecl specs themselves. Path.Value is the quoted
	// literal.
	const quotedImport = `"kra/internal/biz/article"`
	found := false
	ast.Inspect(file, func(n ast.Node) bool {
		gd, isGen := n.(*ast.GenDecl)
		if !isGen || gd.Tok != token.IMPORT {
			return true
		}
		for _, s := range gd.Specs {
			if imp, isImport := s.(*ast.ImportSpec); isImport && imp.Path.Value == quotedImport {
				found = true
				return false
			}
		}
		return true
	})

	if !found {
		t.Error("应该添加 kra/internal/biz/article 导入")
	}
}
// TestAddImport_Duplicate verifies that addImport does nothing when the path
// is already imported. The import must appear exactly once afterwards.
func TestAddImport_Duplicate(t *testing.T) {
	testCode := `package main
import (
"fmt"
"kra/internal/biz/article"
)
func main() {}
`
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, "test.go", testCode, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	const importPath = "kra/internal/biz/article"
	initialCount := len(file.Imports)

	// Try to add an import that already exists.
	addImport(file, importPath)

	if len(file.Imports) != initialCount {
		t.Errorf("重复添加导入后数量应保持 %d, 实际为 %d", initialCount, len(file.Imports))
	}

	// file.Imports is filled in by the parser. If addImport edits only the
	// import GenDecl (which the other tests here inspect directly), the length
	// check above would not catch a duplicate. So also count the matching specs
	// in the top-level import declarations. Path.Value includes the quotes.
	quoted := `"` + importPath + `"`
	occurrences := 0
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.IMPORT {
			continue
		}
		for _, spec := range genDecl.Specs {
			if impSpec, ok := spec.(*ast.ImportSpec); ok && impSpec.Path.Value == quoted {
				occurrences++
			}
		}
	}
	if occurrences != 1 {
		t.Errorf("导入 %s 应只出现 1 次, 实际出现 %d 次", importPath, occurrences)
	}
}
// TestRemoveImport verifies that removeImport deletes the given path from the
// file's parenthesized import declaration.
func TestRemoveImport(t *testing.T) {
	src := `package main
import (
"fmt"
"kra/internal/biz/article"
)
func main() {}
`
	file, err := parser.ParseFile(token.NewFileSet(), "test.go", src, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	removeImport(file, "kra/internal/biz/article")

	// Scan the import GenDecl specs for any remaining copy of the path.
	// Path.Value is the quoted literal.
	const quotedImport = `"kra/internal/biz/article"`
	stillImported := false
	ast.Inspect(file, func(n ast.Node) bool {
		gd, isGen := n.(*ast.GenDecl)
		if !isGen || gd.Tok != token.IMPORT {
			return true
		}
		for _, s := range gd.Specs {
			if imp, isImport := s.(*ast.ImportSpec); isImport && imp.Path.Value == quotedImport {
				stillImported = true
				return false
			}
		}
		return true
	})

	if stillImported {
		t.Error("应该移除 kra/internal/biz/article 导入")
	}
}
// TestNewWireProviderAst verifies that NewWireProviderAst stores each
// constructor argument in the matching field (Layer, PackageName,
// ProviderName, Path).
func TestNewWireProviderAst(t *testing.T) {
	// Named got rather than ast so it does not shadow the imported go/ast
	// package in this scope.
	got := NewWireProviderAst("biz", "article", "NewArticleUsecase", "internal/biz/article/enter.go")
	if got.Layer != "biz" {
		t.Errorf("Layer 应为 'biz', 实际为 '%s'", got.Layer)
	}
	if got.PackageName != "article" {
		t.Errorf("PackageName 应为 'article', 实际为 '%s'", got.PackageName)
	}
	if got.ProviderName != "NewArticleUsecase" {
		t.Errorf("ProviderName 应为 'NewArticleUsecase', 实际为 '%s'", got.ProviderName)
	}
	if got.Path != "internal/biz/article/enter.go" {
		t.Errorf("Path 应为 'internal/biz/article/enter.go', 实际为 '%s'", got.Path)
	}
}
// TestNewMainProviderSetAst verifies that NewMainProviderSetAst stores each
// constructor argument in the matching field (Layer, PackageName,
// ProviderName, ImportPath, Path).
func TestNewMainProviderSetAst(t *testing.T) {
	// Named got rather than ast so it does not shadow the imported go/ast
	// package in this scope.
	got := NewMainProviderSetAst("biz", "article", "NewArticleUsecase", "kra/internal/biz/article", "internal/biz/biz.go")
	if got.Layer != "biz" {
		t.Errorf("Layer 应为 'biz', 实际为 '%s'", got.Layer)
	}
	if got.PackageName != "article" {
		t.Errorf("PackageName 应为 'article', 实际为 '%s'", got.PackageName)
	}
	if got.ProviderName != "NewArticleUsecase" {
		t.Errorf("ProviderName 应为 'NewArticleUsecase', 实际为 '%s'", got.ProviderName)
	}
	if got.ImportPath != "kra/internal/biz/article" {
		t.Errorf("ImportPath 应为 'kra/internal/biz/article', 实际为 '%s'", got.ImportPath)
	}
	if got.Path != "internal/biz/biz.go" {
		t.Errorf("Path 应为 'internal/biz/biz.go', 实际为 '%s'", got.Path)
	}
}
// TestWireProviderAst_Injection_NoProviderSet verifies that Injection returns
// an error when the file declares no ProviderSet variable.
func TestWireProviderAst_Injection_NoProviderSet(t *testing.T) {
	src := `package biz
import "github.com/google/wire"
var SomeOtherVar = "test"
`
	file, err := parser.ParseFile(token.NewFileSet(), "test.go", src, parser.ParseComments)
	if err != nil {
		t.Fatalf("解析测试代码失败: %v", err)
	}

	injector := &WireProviderAst{
		Layer:        "biz",
		PackageName:  "article",
		ProviderName: "NewArticleUsecase",
		Path:         "test.go",
	}

	// With no ProviderSet to extend, Injection must report failure instead of
	// succeeding silently.
	if injector.Injection(file) == nil {
		t.Error("没有 ProviderSet 时注入应该返回错误")
	}
}
// TestWireInjectionConfig checks that a WireInjectionConfig literal keeps the
// module, package name, struct name, and layer list it was built with.
func TestWireInjectionConfig(t *testing.T) {
	cfg := &WireInjectionConfig{
		Module:      "kra",
		PackageName: "article",
		StructName:  "Article",
		Layers:      []string{"biz", "data", "service", "handler", "router"},
	}

	if got := cfg.Module; got != "kra" {
		t.Errorf("Module 应为 'kra', 实际为 '%s'", got)
	}
	if got := cfg.PackageName; got != "article" {
		t.Errorf("PackageName 应为 'article', 实际为 '%s'", got)
	}
	if got := cfg.StructName; got != "Article" {
		t.Errorf("StructName 应为 'Article', 实际为 '%s'", got)
	}
	if n := len(cfg.Layers); n != 5 {
		t.Errorf("Layers 长度应为 5, 实际为 %d", n)
	}
}