package system

import (
	"bytes"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"text/template"

	"kra/pkg/utils/autocode"
)

// ==================== End-to-end code generation tests ====================

// TestEndToEndCodeGeneration renders a representative template for every
// generated layer (backend Go layers and frontend TypeScript artifacts) from a
// single shared AutoCodeInfo fixture and checks that each result contains the
// expected identifiers.
func TestEndToEndCodeGeneration(t *testing.T) {
	info := createTestAutoCodeInfo()

	// An ordered slice (rather than a map) keeps the subtest order stable.
	layers := []struct {
		name string
		run  func(*testing.T, *AutoCodeInfo)
	}{
		{"BizLayer", testBizLayerGeneration},
		{"DataLayer", testDataLayerGeneration},
		{"HandlerLayer", testHandlerLayerGeneration},
		{"ServiceLayer", testServiceLayerGeneration},
		{"RouterLayer", testRouterLayerGeneration},
		{"TypesLayer", testTypesLayerGeneration},
		{"ModelLayer", testModelLayerGeneration},
		{"FrontendPages", testFrontendPagesGeneration},
		{"FrontendServices", testFrontendServicesGeneration},
		{"FrontendTypes", testFrontendTypesGeneration},
	}
	for _, layer := range layers {
		run := layer.run
		t.Run(layer.name, func(t *testing.T) {
			run(t, info)
		})
	}
}

// createTestAutoCodeInfo builds the shared "Article" fixture: a GVA-model
// entity with a uint primary key and six fields of mixed types (string, int,
// time.Time, bool), including LIKE / "=" / BETWEEN search fields, a
// dictionary-backed field and a sortable field.
func createTestAutoCodeInfo() *AutoCodeInfo {
	return &AutoCodeInfo{
		StructName:   "Article",
		TableName:    "articles",
		Package:      "article",
		PackageName:  "article",
		Abbreviation: "article",
		Description:  "文章",
		Module:       "kra",
		GvaModel:     true,
		PrimaryField: &AutoCodeField{
			FieldName: "ID",
			FieldType: "uint",
			FieldJson: "id",
			FieldDesc: "主键ID",
		},
		Fields: []*AutoCodeField{
			{
				FieldName:       "Title",
				FieldType:       "string",
				FieldJson:       "title",
				FieldDesc:       "标题",
				ColumnName:      "title",
				FieldSearchType: "LIKE",
				Form:            true,
				Table:           true,
				Desc:            true,
				Require:         true,
				ErrorText:       "请输入标题",
			},
			{
				FieldName:  "Content",
				FieldType:  "string",
				FieldJson:  "content",
				FieldDesc:  "内容",
				ColumnName: "content",
				Form:       true,
				Table:      false,
				Desc:       true,
			},
			{
				FieldName:       "Status",
				FieldType:       "int",
				FieldJson:       "status",
				FieldDesc:       "状态",
				ColumnName:      "status",
				FieldSearchType: "=",
				DictType:        "articleStatus",
				Form:            true,
				Table:           true,
				Desc:            true,
			},
			{
				FieldName:       "PublishTime",
				FieldType:       "time.Time",
				FieldJson:       "publishTime",
				FieldDesc:       "发布时间",
				ColumnName:      "publish_time",
				FieldSearchType: "BETWEEN",
				Form:            true,
				Table:           true,
				Desc:            true,
			},
			{
				FieldName:  "ViewCount",
				FieldType:  "int",
				FieldJson:  "viewCount",
				FieldDesc:  "浏览量",
				ColumnName: "view_count",
				Form:       false,
				Table:      true,
				Desc:       true,
				Sort:       true,
			},
			{
				FieldName:  "IsTop",
				FieldType:  "bool",
				FieldJson:  "isTop",
				FieldDesc:  "是否置顶",
				ColumnName: "is_top",
				Form:       true,
				Table:      true,
				Desc:       true,
			},
		},
	}
}

// renderTemplate parses tplContent with the autocode template FuncMap,
// executes it against data and returns the rendered text. A parse or execution
// error aborts the calling test.
func renderTemplate(t *testing.T, name, tplContent string, data interface{}) string {
	t.Helper()

	tmpl, err := template.New(name).Funcs(autocode.GetTemplateFuncMap()).Parse(tplContent)
	if err != nil {
		t.Fatalf("解析模板失败: %v", err)
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		t.Fatalf("执行模板失败: %v", err)
	}
	return buf.String()
}

// validateGoSyntax reports a test error (including the offending code) when
// code does not parse as a Go source file. It checks syntax only; it does not
// type-check, so unresolved identifiers and unused imports are accepted.
func validateGoSyntax(t *testing.T, code, layer string) {
	t.Helper()

	fset := token.NewFileSet()
	if _, err := parser.ParseFile(fset, "", code, parser.AllErrors); err != nil {
		t.Errorf("%s 代码语法错误: %v\n代码:\n%s", layer, err, code)
	}
}

// validateFormatIdempotence checks that gofmt is idempotent on code
// (Property 4): formatting the already-formatted output must not change it.
func validateFormatIdempotence(t *testing.T, code, layer string) {
	t.Helper()

	formatted1, err := format.Source([]byte(code))
	if err != nil {
		t.Errorf("%s 第一次格式化失败: %v", layer, err)
		return
	}
	formatted2, err := format.Source(formatted1)
	if err != nil {
		t.Errorf("%s 第二次格式化失败: %v", layer, err)
		return
	}
	if !bytes.Equal(formatted1, formatted2) {
		t.Errorf("%s gofmt 不是幂等的", layer)
	}
}

// TestFileRollback simulates rolling back generated files. It creates files in
// nested directories, deletes them, and checks that each step took effect on
// disk.
func TestFileRollback(t *testing.T) {
	tempDir := t.TempDir()
	generatedFiles := []string{
		filepath.Join(tempDir, "biz", "article.go"),
		filepath.Join(tempDir, "data", "article.go"),
		filepath.Join(tempDir, "service", "article.go"),
	}

	// Create the directories and files.
	for _, file := range generatedFiles {
		if err := os.MkdirAll(filepath.Dir(file), 0755); err != nil {
			t.Fatalf("创建目录失败: %v", err)
		}
		if err := os.WriteFile(file, []byte("// generated file"), 0644); err != nil {
			t.Fatalf("创建文件失败: %v", err)
		}
	}

	// Every file must exist now. Any Stat error counts as a failure, not just
	// "does not exist", so e.g. permission errors are not silently ignored.
	for _, file := range generatedFiles {
		if _, err := os.Stat(file); err != nil {
			t.Errorf("文件应该存在: %s (%v)", file, err)
		}
	}

	// Simulate the rollback by deleting the generated files.
	for _, file := range generatedFiles {
		if err := os.Remove(file); err != nil {
			t.Errorf("删除文件失败: %v", err)
		}
	}

	// Every file must be gone now.
	for _, file := range generatedFiles {
		if _, err := os.Stat(file); !os.IsNotExist(err) {
			t.Errorf("文件应该已被删除: %s", file)
		}
	}
}

// TestRollbackAllWireProviders rolls back the article provider from the main
// ProviderSet file of each layer (biz and data) and checks that both the
// provider reference and its import path are removed.
func TestRollbackAllWireProviders(t *testing.T) {
	tempDir := t.TempDir()

	// An ordered table keeps iteration deterministic and binds each layer to
	// its provider name explicitly (no switch that could fall through to "").
	cases := []struct {
		layer    string
		provider string
		content  string
	}{
		{
			layer:    "biz",
			provider: "NewArticleUsecase",
			content: `package biz

import (
	"github.com/google/wire"
	"kra/internal/biz/article"
)

var ProviderSet = wire.NewSet(
	article.NewArticleUsecase,
)
`,
		},
		{
			layer:    "data",
			provider: "NewArticleRepo",
			content: `package data

import (
	"github.com/google/wire"
	"kra/internal/data/article"
)

var ProviderSet = wire.NewSet(
	article.NewArticleRepo,
)
`,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.layer, func(t *testing.T) {
			file := filepath.Join(tempDir, tc.layer+".go")
			if err := os.WriteFile(file, []byte(tc.content), 0644); err != nil {
				t.Fatalf("创建 %s 文件失败: %v", tc.layer, err)
			}

			importPath := "kra/internal/" + tc.layer + "/article"
			injector := autocode.NewMainProviderSetAst(tc.layer, "article", tc.provider, importPath, file)

			astFile, err := injector.Parse(file, nil)
			if err != nil {
				t.Fatalf("解析 %s 文件失败: %v", tc.layer, err)
			}
			if err := injector.Rollback(astFile); err != nil {
				t.Fatalf("回滚 %s 层失败: %v", tc.layer, err)
			}

			var buf bytes.Buffer
			if err := injector.Format(file, &buf, astFile); err != nil {
				t.Fatalf("格式化 %s 失败: %v", tc.layer, err)
			}

			result := buf.String()
			if strings.Contains(result, tc.provider) {
				t.Errorf("%s 层回滚后不应包含 %s", tc.layer, tc.provider)
			}
			if strings.Contains(result, importPath) {
				t.Errorf("%s 层回滚后不应包含 article 导入", tc.layer)
			}
		})
	}
}

// ==================== Backend layer tests ====================

// testBizLayerGeneration renders a biz-layer template (entity, repository
// interface, use case) and checks its components, Go syntax and gofmt
// idempotence.
func testBizLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `package {{.Package}}

import (
	"context"

	"github.com/go-kratos/kratos/v2/log"
)

// {{.StructName}} {{.Description}}实体
type {{.StructName}} struct {
{{- range .Fields}}
	{{.FieldName}} {{.FieldType}} ` + "`" + `json:"{{.FieldJson}}"` + "`" + `
{{- end}}
}

// {{.StructName}}Repo {{.Description}}仓储接口
type {{.StructName}}Repo interface {
	Create(ctx context.Context, {{.Abbreviation}} *{{.StructName}}) error
	Delete(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) error
	Update(ctx context.Context, {{.Abbreviation}} *{{.StructName}}) error
	FindByID(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) (*{{.StructName}}, error)
	List(ctx context.Context, page, pageSize int) ([]*{{.StructName}}, int64, error)
}

// {{.StructName}}Usecase {{.Description}}用例
type {{.StructName}}Usecase struct {
	repo {{.StructName}}Repo
	log  *log.Helper
}

// New{{.StructName}}Usecase 创建{{.Description}}用例
func New{{.StructName}}Usecase(repo {{.StructName}}Repo, logger log.Logger) *{{.StructName}}Usecase {
	return &{{.StructName}}Usecase{
		repo: repo,
		log:  log.NewHelper(logger),
	}
}
`
	result := renderTemplate(t, "biz", tplContent, info)

	requiredComponents := []string{
		"package article",
		"Article struct",
		"ArticleRepo interface",
		"ArticleUsecase struct",
		"NewArticleUsecase",
		"Create(ctx context.Context",
		"Delete(ctx context.Context",
		"Update(ctx context.Context",
		"FindByID(ctx context.Context",
		"List(ctx context.Context",
		"log.Helper",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("Biz 层代码应包含 '%s'", component)
		}
	}

	validateGoSyntax(t, result, "biz layer")
	validateFormatIdempotence(t, result, "biz layer")
}

// testDataLayerGeneration renders a data-layer repository template and checks
// its components and Go syntax.
func testDataLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `package {{.Package}}

import (
	"context"

	"{{.Module}}/internal/biz/{{.Package}}"
	"{{.Module}}/internal/data/query"
)

type {{.Abbreviation}}Repo struct {
	data *Data
}

func New{{.StructName}}Repo(data *Data) {{.Package}}.{{.StructName}}Repo {
	return &{{.Abbreviation}}Repo{data: data}
}

func (r *{{.Abbreviation}}Repo) Create(ctx context.Context, {{.Abbreviation}} *{{.Package}}.{{.StructName}}) error {
	t := query.{{.StructName}}
	return t.WithContext(ctx).Create({{.Abbreviation}})
}

func (r *{{.Abbreviation}}Repo) List(ctx context.Context, page, pageSize int) ([]*{{.Package}}.{{.StructName}}, int64, error) {
	t := query.{{.StructName}}
	q := t.WithContext(ctx)
	total, _ := q.Count()
	list, err := q.Offset((page - 1) * pageSize).Limit(pageSize).Find()
	return list, total, err
}
`
	result := renderTemplate(t, "data", tplContent, info)

	requiredComponents := []string{
		"package article",
		"articleRepo struct",
		"NewArticleRepo",
		"query.Article",
		"WithContext(ctx)",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("Data 层代码应包含 '%s'", component)
		}
	}

	validateGoSyntax(t, result, "data layer")
}

// testHandlerLayerGeneration renders a gin handler template with swagger
// annotations and checks its components and Go syntax.
func testHandlerLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `package {{.Package}}

import (
	"{{.Module}}/internal/biz/{{.Package}}"
	"{{.Module}}/pkg/response"

	"github.com/gin-gonic/gin"
)

type {{.StructName}}Handler struct {
	uc *{{.Package}}.{{.StructName}}Usecase
}

func New{{.StructName}}Handler(uc *{{.Package}}.{{.StructName}}Usecase) *{{.StructName}}Handler {
	return &{{.StructName}}Handler{uc: uc}
}

// Create{{.StructName}} 创建{{.Description}}
// @Tags {{.Description}}
// @Summary 创建{{.Description}}
// @Router /{{.Abbreviation}}/create{{.StructName}} [post]
func (h *{{.StructName}}Handler) Create{{.StructName}}(c *gin.Context) {
	response.OkWithMessage("创建成功", c)
}

// Delete{{.StructName}} 删除{{.Description}}
// @Tags {{.Description}}
// @Router /{{.Abbreviation}}/delete{{.StructName}} [delete]
func (h *{{.StructName}}Handler) Delete{{.StructName}}(c *gin.Context) {
	response.OkWithMessage("删除成功", c)
}

// Update{{.StructName}} 更新{{.Description}}
// @Tags {{.Description}}
// @Router /{{.Abbreviation}}/update{{.StructName}} [put]
func (h *{{.StructName}}Handler) Update{{.StructName}}(c *gin.Context) {
	response.OkWithMessage("更新成功", c)
}

// Find{{.StructName}} 根据ID获取{{.Description}}
// @Tags {{.Description}}
// @Router /{{.Abbreviation}}/find{{.StructName}} [get]
func (h *{{.StructName}}Handler) Find{{.StructName}}(c *gin.Context) {
	response.OkWithMessage("获取成功", c)
}

// Get{{.StructName}}List 获取{{.Description}}列表
// @Tags {{.Description}}
// @Router /{{.Abbreviation}}/get{{.StructName}}List [get]
func (h *{{.StructName}}Handler) Get{{.StructName}}List(c *gin.Context) {
	response.OkWithMessage("获取成功", c)
}
`
	result := renderTemplate(t, "handler", tplContent, info)

	requiredComponents := []string{
		"package article",
		"ArticleHandler struct",
		"NewArticleHandler",
		"CreateArticle",
		"DeleteArticle",
		"UpdateArticle",
		"FindArticle",
		"GetArticleList",
		"@Tags 文章",
		"@Router /article/",
		"gin.Context",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("Handler 层代码应包含 '%s'", component)
		}
	}

	validateGoSyntax(t, result, "handler layer")
}

// testServiceLayerGeneration renders a service-layer template and checks its
// components and Go syntax.
func testServiceLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `package {{.Package}}

import (
	"context"

	"{{.Module}}/internal/biz/{{.Package}}"

	"github.com/go-kratos/kratos/v2/log"
)

type {{.StructName}}Service struct {
	uc  *{{.Package}}.{{.StructName}}Usecase
	log *log.Helper
}

func New{{.StructName}}Service(uc *{{.Package}}.{{.StructName}}Usecase, logger log.Logger) *{{.StructName}}Service {
	return &{{.StructName}}Service{
		uc:  uc,
		log: log.NewHelper(logger),
	}
}

func (s *{{.StructName}}Service) Create{{.StructName}}(ctx context.Context) error {
	return nil
}

func (s *{{.StructName}}Service) Delete{{.StructName}}(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) error {
	return nil
}
`
	result := renderTemplate(t, "service", tplContent, info)

	requiredComponents := []string{
		"package article",
		"ArticleService struct",
		"NewArticleService",
		"CreateArticle",
		"DeleteArticle",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("Service 层代码应包含 '%s'", component)
		}
	}

	validateGoSyntax(t, result, "service layer")
}

// testRouterLayerGeneration renders a gin router template (a recorded and an
// unrecorded route group) and checks its components and Go syntax.
func testRouterLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `package {{.Package}}

import (
	"{{.Module}}/internal/middleware"
	"{{.Module}}/internal/server/handler/{{.Package}}"

	"github.com/gin-gonic/gin"
)

type {{.StructName}}Router struct{}

func (r *{{.StructName}}Router) Init{{.StructName}}Router(Router *gin.RouterGroup, PublicRouter *gin.RouterGroup, handler *{{.Package}}.{{.StructName}}Handler) {
	{{.Abbreviation}}Router := Router.Group("{{.Abbreviation}}").Use(middleware.OperationRecord())
	{{.Abbreviation}}RouterWithoutRecord := Router.Group("{{.Abbreviation}}")
	{
		{{.Abbreviation}}Router.POST("create{{.StructName}}", handler.Create{{.StructName}})
		{{.Abbreviation}}Router.DELETE("delete{{.StructName}}", handler.Delete{{.StructName}})
		{{.Abbreviation}}Router.PUT("update{{.StructName}}", handler.Update{{.StructName}})
	}
	{
		{{.Abbreviation}}RouterWithoutRecord.GET("find{{.StructName}}", handler.Find{{.StructName}})
		{{.Abbreviation}}RouterWithoutRecord.GET("get{{.StructName}}List", handler.Get{{.StructName}}List)
	}
}
`
	result := renderTemplate(t, "router", tplContent, info)

	requiredComponents := []string{
		"package article",
		"ArticleRouter struct",
		"InitArticleRouter",
		"gin.RouterGroup",
		"articleRouter",
		"middleware.OperationRecord",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("Router 层代码应包含 '%s'", component)
		}
	}

	validateGoSyntax(t, result, "router layer")
}

// testTypesLayerGeneration renders request/search type definitions and checks
// their components and Go syntax.
func testTypesLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `package request

import (
	"{{.Module}}/pkg/request"
)

type {{.StructName}}Request struct {
{{- range .Fields}}
	{{.FieldName}} {{.FieldType}} ` + "`" + `json:"{{.FieldJson}}"` + "`" + `
{{- end}}
}

type {{.StructName}}IDRequest struct {
	ID {{.PrimaryField.FieldType}} ` + "`" + `json:"{{.PrimaryField.FieldJson}}"` + "`" + `
}

type {{.StructName}}Search struct {
	request.PageInfo
}
`
	result := renderTemplate(t, "types", tplContent, info)

	requiredComponents := []string{
		"package request",
		"ArticleRequest struct",
		"ArticleIDRequest struct",
		"ArticleSearch struct",
		"PageInfo",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("Types 层代码应包含 '%s'", component)
		}
	}

	validateGoSyntax(t, result, "types layer")
}

// testModelLayerGeneration renders a GORM model template (optionally embedding
// GVA_MODEL) and checks its components and Go syntax.
func testModelLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `package model

type {{.StructName}} struct {
{{- if .GvaModel }}
	GVA_MODEL
{{- end }}
{{- range .Fields}}
	{{.FieldName}} {{.FieldType}} ` + "`" + `gorm:"column:{{.ColumnName}}" json:"{{.FieldJson}}"` + "`" + `
{{- end}}
}

func ({{.StructName}}) TableName() string {
	return "{{.TableName}}"
}
`
	result := renderTemplate(t, "model", tplContent, info)

	requiredComponents := []string{
		"package model",
		"Article struct",
		"GVA_MODEL",
		"TableName()",
		"articles",
		"gorm:",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("Model 层代码应包含 '%s'", component)
		}
	}

	// Like every other Go-emitting layer, the model output must parse.
	validateGoSyntax(t, result, "model layer")
}

// ==================== Frontend layer tests ====================

// testFrontendPagesGeneration exercises autocode.GenerateReactProTableColumn
// directly for plain, dictionary-backed, time and bool fields, checking the
// ProTable column configuration it emits.
func testFrontendPagesGeneration(t *testing.T, info *AutoCodeInfo) {
	// Plain sortable string column.
	field := autocode.AutoCodeField{
		FieldName:       "Title",
		FieldType:       "string",
		FieldJson:       "title",
		FieldDesc:       "标题",
		FieldSearchType: "LIKE",
		Table:           true,
		Sort:            true,
		DictType:        "",
	}
	result := autocode.GenerateReactProTableColumn(field)

	requiredComponents := []string{
		"title:",
		"'标题'",
		"dataIndex:",
		"'title'",
		"sorter: true",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("ProTable 列配置应包含 '%s', 实际结果: %s", component, result)
		}
	}

	// Dictionary-backed column.
	dictField := autocode.AutoCodeField{
		FieldName:       "Status",
		FieldType:       "int",
		FieldJson:       "status",
		FieldDesc:       "状态",
		FieldSearchType: "=",
		DictType:        "articleStatus",
		Table:           true,
	}
	dictResult := autocode.GenerateReactProTableColumn(dictField)
	if !strings.Contains(dictResult, "articleStatus") {
		t.Errorf("字典类型列应包含 articleStatus, 实际结果: %s", dictResult)
	}
	if !strings.Contains(dictResult, "select") {
		t.Errorf("字典类型列应包含 select valueType, 实际结果: %s", dictResult)
	}

	// time.Time column.
	timeField := autocode.AutoCodeField{
		FieldName:       "PublishTime",
		FieldType:       "time.Time",
		FieldJson:       "publishTime",
		FieldDesc:       "发布时间",
		FieldSearchType: "BETWEEN",
		Table:           true,
	}
	timeResult := autocode.GenerateReactProTableColumn(timeField)
	if !strings.Contains(timeResult, "publishTime") {
		t.Errorf("时间类型列应包含 publishTime, 实际结果: %s", timeResult)
	}
	if !strings.Contains(timeResult, "dateTime") {
		t.Errorf("时间类型列应包含 dateTime valueType, 实际结果: %s", timeResult)
	}

	// bool column.
	boolField := autocode.AutoCodeField{
		FieldName: "IsTop",
		FieldType: "bool",
		FieldJson: "isTop",
		FieldDesc: "是否置顶",
		Table:     true,
	}
	boolResult := autocode.GenerateReactProTableColumn(boolField)
	if !strings.Contains(boolResult, "isTop") {
		t.Errorf("布尔类型列应包含 isTop, 实际结果: %s", boolResult)
	}
}

// testFrontendServicesGeneration renders the frontend API service module
// (umi `request` wrappers for list/detail/create/update/delete/batch-delete)
// and checks its components and the uint -> number type mapping.
func testFrontendServicesGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `import { request } from '@umijs/max';

/** 获取{{.Description}}列表 */
export async function get{{.StructName}}List(params: API.{{.StructName}}ListParams) {
  return request('/api/{{.Abbreviation}}/get{{.StructName}}List', {
    method: 'GET',
    params,
  });
}

/** 获取{{.Description}}详情 */
export async function get{{.StructName}}ById({{.PrimaryField.FieldJson}}: {{ GenerateTSType .PrimaryField.FieldType }}) {
  return request('/api/{{.Abbreviation}}/find{{.StructName}}', {
    method: 'GET',
    params: { {{.PrimaryField.FieldJson}} },
  });
}

/** 创建{{.Description}} */
export async function create{{.StructName}}(data: API.{{.StructName}}Request) {
  return request('/api/{{.Abbreviation}}/create{{.StructName}}', {
    method: 'POST',
    data,
  });
}

/** 更新{{.Description}} */
export async function update{{.StructName}}(data: API.{{.StructName}}Request) {
  return request('/api/{{.Abbreviation}}/update{{.StructName}}', {
    method: 'PUT',
    data,
  });
}

/** 删除{{.Description}} */
export async function delete{{.StructName}}({{.PrimaryField.FieldJson}}: {{ GenerateTSType .PrimaryField.FieldType }}) {
  return request('/api/{{.Abbreviation}}/delete{{.StructName}}', {
    method: 'DELETE',
    params: { {{.PrimaryField.FieldJson}} },
  });
}

/** 批量删除{{.Description}} */
export async function delete{{.StructName}}ByIds(ids: {{ GenerateTSType .PrimaryField.FieldType }}[]) {
  return request('/api/{{.Abbreviation}}/delete{{.StructName}}ByIds', {
    method: 'DELETE',
    data: { ids },
  });
}
`
	result := renderTemplate(t, "services", tplContent, info)

	// Fragments the rendered service module must contain. Each one has to be
	// producible by the template above, otherwise the subtest can never pass.
	requiredComponents := []string{
		"import { request } from '@umijs/max'",
		"getArticleList",
		"getArticleById",
		"createArticle",
		"updateArticle",
		"deleteArticle",
		"deleteArticleByIds",
		"/api/article/",
		"method: 'GET'",
		"method: 'POST'",
		"method: 'PUT'",
		"method: 'DELETE'",
		"API.ArticleListParams",
		"API.ArticleRequest",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("前端 API 服务代码应包含 '%s'", component)
		}
	}

	// The uint primary key must be rendered as a TypeScript number.
	if !strings.Contains(result, "number") {
		t.Error("前端 API 服务应包含 number 类型(从 uint 转换)")
	}
}

// testFrontendTypesGeneration renders the TypeScript `API` namespace
// declarations (entity, request, list params, list result) and checks the
// interfaces, fields and Go -> TypeScript type mapping.
func testFrontendTypesGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `declare namespace API {
  /** {{.Description}}实体 */
  interface {{.StructName}} {
    {{.PrimaryField.FieldJson}}?: {{ GenerateTSType .PrimaryField.FieldType }};
{{- range .Fields}}
    {{.FieldJson}}?: {{ GenerateTSType .FieldType }};
{{- end}}
    createdAt?: string;
    updatedAt?: string;
  }

  /** {{.Description}}请求参数 */
  interface {{.StructName}}Request {
{{- range .Fields}}
{{- if .Form}}
    {{.FieldJson}}{{if not .Require}}?{{end}}: {{ GenerateTSType .FieldType }};
{{- end}}
{{- end}}
  }

  /** {{.Description}}列表请求参数 */
  interface {{.StructName}}ListParams {
    page?: number;
    pageSize?: number;
{{- range .Fields}}
{{- if .FieldSearchType}}
    {{.FieldJson}}?: {{ GenerateTSType .FieldType }};
{{- end}}
{{- end}}
  }

  /** {{.Description}}列表响应 */
  interface {{.StructName}}ListResult {
    list: {{.StructName}}[];
    total: number;
    page: number;
    pageSize: number;
  }
}
`
	result := renderTemplate(t, "types", tplContent, info)

	requiredComponents := []string{
		"declare namespace API",
		"interface Article",
		"interface ArticleRequest",
		"interface ArticleListParams",
		"interface ArticleListResult",
		"id?:",
		"title",
		"content",
		"status",
		"publishTime",
		"viewCount",
		"createdAt?: string",
		"updatedAt?: string",
		"page?: number",
		"pageSize?: number",
		"list: Article[]",
		"total: number",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("前端 TypeScript 类型定义应包含 '%s'", component)
		}
	}

	// Expected TypeScript types from the mapping: string -> string,
	// int/uint -> number, bool -> boolean.
	for _, expectedType := range []string{"string", "number", "boolean"} {
		if !strings.Contains(result, expectedType) {
			t.Errorf("前端类型定义应包含 TypeScript 类型 '%s'", expectedType)
		}
	}
}

// ==================== Wire injection tests ====================

// TestWireProviderInjectionAndRollback injects a provider into a Wire
// ProviderSet and then rolls it back.
func TestWireProviderInjectionAndRollback(t *testing.T) { tempDir := t.TempDir() testFile := filepath.Join(tempDir, "test_wire.go") initialContent := `package test import ( "github.com/google/wire" ) var ProviderSet = wire.NewSet( NewExistingProvider, ) ` if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil { t.Fatalf("创建测试文件失败: %v", err) } injector := autocode.NewWireProviderAst("biz", "article", "NewArticleUsecase", testFile) file, err := injector.Parse(testFile, nil) if err != nil { t.Fatalf("解析文件失败: %v", err) } // 执行注入 if err := injector.Injection(file); err != nil { t.Fatalf("注入失败: %v", err) } var buf bytes.Buffer if err := injector.Format(testFile, &buf, file); err != nil { t.Fatalf("格式化失败: %v", err) } injectedContent := buf.String() if !strings.Contains(injectedContent, "NewArticleUsecase") { t.Error("注入后应包含 NewArticleUsecase") } if !strings.Contains(injectedContent, "NewExistingProvider") { t.Error("注入后应保留 NewExistingProvider") } // 写入注入后的内容并重新解析 if err := os.WriteFile(testFile, []byte(injectedContent), 0644); err != nil { t.Fatalf("写入注入后文件失败: %v", err) } file2, err := injector.Parse(testFile, nil) if err != nil { t.Fatalf("解析注入后文件失败: %v", err) } // 执行回滚 if err := injector.Rollback(file2); err != nil { t.Fatalf("回滚失败: %v", err) } var rollbackBuf bytes.Buffer if err := injector.Format(testFile, &rollbackBuf, file2); err != nil { t.Fatalf("格式化回滚结果失败: %v", err) } rollbackContent := rollbackBuf.String() if strings.Contains(rollbackContent, "NewArticleUsecase") { t.Error("回滚后不应包含 NewArticleUsecase") } if !strings.Contains(rollbackContent, "NewExistingProvider") { t.Error("回滚后应保留 NewExistingProvider") } } // TestMainProviderSetInjectionAndRollback 测试主 ProviderSet 注入和回滚 func TestMainProviderSetInjectionAndRollback(t *testing.T) { tempDir := t.TempDir() testFile := filepath.Join(tempDir, "biz.go") initialContent := `package biz import ( "github.com/google/wire" "kra/internal/biz/system" ) var ProviderSet = wire.NewSet( system.NewUserUsecase, ) 
` if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil { t.Fatalf("创建测试文件失败: %v", err) } injector := autocode.NewMainProviderSetAst( "biz", "article", "NewArticleUsecase", "kra/internal/biz/article", testFile, ) file, err := injector.Parse(testFile, nil) if err != nil { t.Fatalf("解析文件失败: %v", err) } if err := injector.Injection(file); err != nil { t.Fatalf("注入失败: %v", err) } var buf bytes.Buffer if err := injector.Format(testFile, &buf, file); err != nil { t.Fatalf("格式化失败: %v", err) } injectedContent := buf.String() if !strings.Contains(injectedContent, "article.NewArticleUsecase") { t.Error("注入后应包含 article.NewArticleUsecase") } if !strings.Contains(injectedContent, "kra/internal/biz/article") { t.Error("注入后应包含导入路径") } // 写入并重新解析 if err := os.WriteFile(testFile, []byte(injectedContent), 0644); err != nil { t.Fatalf("写入注入后文件失败: %v", err) } file2, err := injector.Parse(testFile, nil) if err != nil { t.Fatalf("解析注入后文件失败: %v", err) } // 执行回滚 if err := injector.Rollback(file2); err != nil { t.Fatalf("回滚失败: %v", err) } var rollbackBuf bytes.Buffer if err := injector.Format(testFile, &rollbackBuf, file2); err != nil { t.Fatalf("格式化回滚结果失败: %v", err) } rollbackContent := rollbackBuf.String() if strings.Contains(rollbackContent, "article.NewArticleUsecase") { t.Error("回滚后不应包含 article.NewArticleUsecase") } if strings.Contains(rollbackContent, "kra/internal/biz/article") { t.Error("回滚后不应包含 article 导入路径") } if !strings.Contains(rollbackContent, "system.NewUserUsecase") { t.Error("回滚后应保留 system.NewUserUsecase") } } // TestMultipleInjectionAndRollback 测试多次注入和回滚 func TestMultipleInjectionAndRollback(t *testing.T) { tempDir := t.TempDir() testFile := filepath.Join(tempDir, "test_multi.go") initialContent := `package test import ( "github.com/google/wire" ) var ProviderSet = wire.NewSet( NewBaseProvider, ) ` if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil { t.Fatalf("创建测试文件失败: %v", err) } providers := []string{"NewProvider1", 
"NewProvider2", "NewProvider3"} currentContent := initialContent // 依次注入多个 Provider for _, provider := range providers { if err := os.WriteFile(testFile, []byte(currentContent), 0644); err != nil { t.Fatalf("写入文件失败: %v", err) } injector := autocode.NewWireProviderAst("biz", "test", provider, testFile) file, err := injector.Parse(testFile, nil) if err != nil { t.Fatalf("解析文件失败: %v", err) } if err := injector.Injection(file); err != nil { t.Fatalf("注入 %s 失败: %v", provider, err) } var buf bytes.Buffer if err := injector.Format(testFile, &buf, file); err != nil { t.Fatalf("格式化失败: %v", err) } currentContent = buf.String() } // 验证所有 Provider 都已注入 for _, provider := range providers { if !strings.Contains(currentContent, provider) { t.Errorf("应包含 %s", provider) } } // 逆序回滚所有 Provider for i := len(providers) - 1; i >= 0; i-- { provider := providers[i] if err := os.WriteFile(testFile, []byte(currentContent), 0644); err != nil { t.Fatalf("写入文件失败: %v", err) } injector := autocode.NewWireProviderAst("biz", "test", provider, testFile) file, err := injector.Parse(testFile, nil) if err != nil { t.Fatalf("解析文件失败: %v", err) } if err := injector.Rollback(file); err != nil { t.Fatalf("回滚 %s 失败: %v", provider, err) } var buf bytes.Buffer if err := injector.Format(testFile, &buf, file); err != nil { t.Fatalf("格式化失败: %v", err) } currentContent = buf.String() if strings.Contains(currentContent, provider) { t.Errorf("回滚后不应包含 %s", provider) } } if !strings.Contains(currentContent, "NewBaseProvider") { t.Error("回滚后应保留 NewBaseProvider") } } // TestRollbackIdempotency 测试回滚幂等性 func TestRollbackIdempotency(t *testing.T) { tempDir := t.TempDir() testFile := filepath.Join(tempDir, "test_idempotent.go") initialContent := `package test import ( "github.com/google/wire" ) var ProviderSet = wire.NewSet( NewExistingProvider, NewTargetProvider, ) ` if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil { t.Fatalf("创建测试文件失败: %v", err) } injector := autocode.NewWireProviderAst("biz", 
"test", "NewTargetProvider", testFile) // 第一次回滚 file1, _ := injector.Parse(testFile, nil) _ = injector.Rollback(file1) var buf1 bytes.Buffer _ = injector.Format(testFile, &buf1, file1) firstRollback := buf1.String() if err := os.WriteFile(testFile, []byte(firstRollback), 0644); err != nil { t.Fatalf("写入文件失败: %v", err) } // 第二次回滚 file2, _ := injector.Parse(testFile, nil) _ = injector.Rollback(file2) var buf2 bytes.Buffer _ = injector.Format(testFile, &buf2, file2) secondRollback := buf2.String() // 验证幂等性 if firstRollback != secondRollback { t.Error("回滚应该是幂等的") } if strings.Contains(secondRollback, "NewTargetProvider") { t.Error("回滚后不应包含 NewTargetProvider") } if !strings.Contains(secondRollback, "NewExistingProvider") { t.Error("回滚后应保留 NewExistingProvider") } } // countProviderSetArgs 计算 ProviderSet 中的参数数量 func countProviderSetArgs(file *ast.File) int { var count int ast.Inspect(file, func(node ast.Node) bool { genDecl, ok := node.(*ast.GenDecl) if !ok || genDecl.Tok != token.VAR { return true } for _, spec := range genDecl.Specs { valueSpec, ok := spec.(*ast.ValueSpec) if !ok { continue } for i, name := range valueSpec.Names { if name.Name != "ProviderSet" { continue } if i >= len(valueSpec.Values) { continue } callExpr, ok := valueSpec.Values[i].(*ast.CallExpr) if !ok { continue } count = len(callExpr.Args) return false } } return true }) return count } // TestRollbackPreservesFormatting 测试回滚保持代码格式 func TestRollbackPreservesFormatting(t *testing.T) { tempDir := t.TempDir() testFile := filepath.Join(tempDir, "format_test.go") initialContent := `package test import ( "github.com/google/wire" ) var ProviderSet = wire.NewSet( NewProvider1, NewProvider2, NewTargetProvider, ) ` if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil { t.Fatalf("创建测试文件失败: %v", err) } injector := autocode.NewWireProviderAst("biz", "test", "NewTargetProvider", testFile) file, err := injector.Parse(testFile, nil) if err != nil { t.Fatalf("解析文件失败: %v", err) } if err := 
injector.Rollback(file); err != nil { t.Fatalf("回滚失败: %v", err) } var buf bytes.Buffer if err := injector.Format(testFile, &buf, file); err != nil { t.Fatalf("格式化失败: %v", err) } result := buf.String() // 验证代码仍然是有效的 Go 代码 fset := token.NewFileSet() _, err = parser.ParseFile(fset, "", result, parser.AllErrors) if err != nil { t.Errorf("回滚后代码语法错误: %v", err) } // 验证 gofmt 幂等性 formatted1, err := format.Source([]byte(result)) if err != nil { t.Errorf("第一次格式化失败: %v", err) return } formatted2, err := format.Source(formatted1) if err != nil { t.Errorf("第二次格式化失败: %v", err) return } if string(formatted1) != string(formatted2) { t.Error("回滚后代码 gofmt 不是幂等的") } } // TestRollbackNonExistentProvider 测试回滚不存在的 Provider func TestRollbackNonExistentProvider(t *testing.T) { tempDir := t.TempDir() testFile := filepath.Join(tempDir, "nonexistent.go") initialContent := `package test import ( "github.com/google/wire" ) var ProviderSet = wire.NewSet( NewExistingProvider, ) ` if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil { t.Fatalf("创建测试文件失败: %v", err) } injector := autocode.NewWireProviderAst("biz", "test", "NewNonExistentProvider", testFile) file, err := injector.Parse(testFile, nil) if err != nil { t.Fatalf("解析文件失败: %v", err) } // 回滚不存在的 Provider 不应该报错 if err := injector.Rollback(file); err != nil { t.Errorf("回滚不存在的 Provider 不应该报错: %v", err) } var buf bytes.Buffer if err := injector.Format(testFile, &buf, file); err != nil { t.Fatalf("格式化失败: %v", err) } result := buf.String() // 验证原有 Provider 仍然存在 if !strings.Contains(result, "NewExistingProvider") { t.Error("回滚后应保留 NewExistingProvider") } }