kra/internal/biz/system/auto_code_integration_test.go

1375 lines
35 KiB
Go

package system
import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"testing"
"text/template"
"kra/pkg/utils/autocode"
)
// ==================== End-to-end code generation tests ====================

// TestEndToEndCodeGeneration builds one shared AutoCodeInfo fixture and runs
// the per-layer generation check for every backend and frontend layer as its
// own named subtest.
func TestEndToEndCodeGeneration(t *testing.T) {
	fixture := createTestAutoCodeInfo()

	// Subtests run in the order listed here.
	layers := []struct {
		name  string
		check func(*testing.T, *AutoCodeInfo)
	}{
		{"BizLayer", testBizLayerGeneration},
		{"DataLayer", testDataLayerGeneration},
		{"HandlerLayer", testHandlerLayerGeneration},
		{"ServiceLayer", testServiceLayerGeneration},
		{"RouterLayer", testRouterLayerGeneration},
		{"TypesLayer", testTypesLayerGeneration},
		{"ModelLayer", testModelLayerGeneration},
		{"FrontendPages", testFrontendPagesGeneration},
		{"FrontendServices", testFrontendServicesGeneration},
		{"FrontendTypes", testFrontendTypesGeneration},
	}

	for i := range layers {
		name, check := layers[i].name, layers[i].check
		t.Run(name, func(t *testing.T) {
			check(t, fixture)
		})
	}
}
// createTestAutoCodeInfo returns the fixture shared by the generation tests:
// an "Article" model (table "articles", package "article", module "kra") with
// a uint primary key and six fields covering string, int, time.Time and bool
// types, including LIKE, "=" and BETWEEN search types and one dictionary field.
func createTestAutoCodeInfo() *AutoCodeInfo {
	primaryKey := &AutoCodeField{
		FieldName: "ID",
		FieldType: "uint",
		FieldJson: "id",
		FieldDesc: "主键ID",
	}

	// Required string field searched with LIKE.
	title := &AutoCodeField{
		FieldName:       "Title",
		FieldType:       "string",
		FieldJson:       "title",
		FieldDesc:       "标题",
		ColumnName:      "title",
		FieldSearchType: "LIKE",
		Form:            true,
		Table:           true,
		Desc:            true,
		Require:         true,
		ErrorText:       "请输入标题",
	}

	// String field that appears in the form but not in the table.
	content := &AutoCodeField{
		FieldName:  "Content",
		FieldType:  "string",
		FieldJson:  "content",
		FieldDesc:  "内容",
		ColumnName: "content",
		Form:       true,
		Table:      false,
		Desc:       true,
	}

	// Dictionary-backed int field searched by equality.
	status := &AutoCodeField{
		FieldName:       "Status",
		FieldType:       "int",
		FieldJson:       "status",
		FieldDesc:       "状态",
		ColumnName:      "status",
		FieldSearchType: "=",
		DictType:        "articleStatus",
		Form:            true,
		Table:           true,
		Desc:            true,
	}

	// Time field searched with BETWEEN.
	publishTime := &AutoCodeField{
		FieldName:       "PublishTime",
		FieldType:       "time.Time",
		FieldJson:       "publishTime",
		FieldDesc:       "发布时间",
		ColumnName:      "publish_time",
		FieldSearchType: "BETWEEN",
		Form:            true,
		Table:           true,
		Desc:            true,
	}

	// Sortable int field that is excluded from the form.
	viewCount := &AutoCodeField{
		FieldName:  "ViewCount",
		FieldType:  "int",
		FieldJson:  "viewCount",
		FieldDesc:  "浏览量",
		ColumnName: "view_count",
		Form:       false,
		Table:      true,
		Desc:       true,
		Sort:       true,
	}

	// Bool field.
	isTop := &AutoCodeField{
		FieldName:  "IsTop",
		FieldType:  "bool",
		FieldJson:  "isTop",
		FieldDesc:  "是否置顶",
		ColumnName: "is_top",
		Form:       true,
		Table:      true,
		Desc:       true,
	}

	return &AutoCodeInfo{
		StructName:   "Article",
		TableName:    "articles",
		Package:      "article",
		PackageName:  "article",
		Abbreviation: "article",
		Description:  "文章",
		Module:       "kra",
		GvaModel:     true,
		PrimaryField: primaryKey,
		Fields: []*AutoCodeField{
			title,
			content,
			status,
			publishTime,
			viewCount,
			isTop,
		},
	}
}
// renderTemplate parses tplContent as a text/template named name, with the
// function map returned by autocode.GetTemplateFuncMap installed, executes it
// against data and returns the rendered text.
//
// A parse or execution error aborts the calling test through t.Fatalf.
func renderTemplate(t *testing.T, name, tplContent string, data interface{}) string {
	// Mark this function as a helper so failures are attributed to the
	// calling test's line rather than to this function.
	t.Helper()

	tmpl, err := template.New(name).Funcs(autocode.GetTemplateFuncMap()).Parse(tplContent)
	if err != nil {
		t.Fatalf("解析模板失败: %v", err)
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		t.Fatalf("执行模板失败: %v", err)
	}
	return buf.String()
}
// validateGoSyntax parses code as a Go source file and reports a test error,
// labelled with layer and followed by the offending code, if parsing fails.
// Only syntax is checked; the code is not type-checked and its imports are
// not resolved.
func validateGoSyntax(t *testing.T, code, layer string) {
	// Mark this function as a helper so failures are attributed to the
	// calling test's line rather than to this function.
	t.Helper()

	fset := token.NewFileSet()
	if _, err := parser.ParseFile(fset, "", code, parser.AllErrors); err != nil {
		t.Errorf("%s 代码语法错误: %v\n代码:\n%s", layer, err, code)
	}
}
// validateFormatIdempotence checks that gofmt-style formatting of code is
// idempotent (Property 4): formatting the already formatted output a second
// time must yield identical bytes. A formatting failure on either pass, or a
// difference between the two passes, is reported as a test error labelled
// with layer.
func validateFormatIdempotence(t *testing.T, code, layer string) {
	// Mark this function as a helper so failures are attributed to the
	// calling test's line rather than to this function.
	t.Helper()

	formatted1, err := format.Source([]byte(code))
	if err != nil {
		t.Errorf("%s 第一次格式化失败: %v", layer, err)
		return
	}

	formatted2, err := format.Source(formatted1)
	if err != nil {
		t.Errorf("%s 第二次格式化失败: %v", layer, err)
		return
	}

	// Compare the byte slices directly; no string copies are needed.
	if !bytes.Equal(formatted1, formatted2) {
		t.Errorf("%s gofmt 不是幂等的", layer)
	}
}
// TestFileRollback simulates rolling back generated files: it creates one
// article.go under biz, data and service directories inside a temporary
// directory, confirms each exists, removes them, and confirms each is gone.
func TestFileRollback(t *testing.T) {
	root := t.TempDir()

	layers := []string{"biz", "data", "service"}
	paths := make([]string, 0, len(layers))
	for _, layer := range layers {
		paths = append(paths, filepath.Join(root, layer, "article.go"))
	}

	// Phase 1: create the parent directories and the files.
	for _, p := range paths {
		if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
			t.Fatalf("创建目录失败: %v", err)
		}
		if err := os.WriteFile(p, []byte("// generated file"), 0644); err != nil {
			t.Fatalf("创建文件失败: %v", err)
		}
	}

	// Phase 2: every file must be present.
	for _, p := range paths {
		_, err := os.Stat(p)
		if os.IsNotExist(err) {
			t.Errorf("文件应该存在: %s", p)
		}
	}

	// Phase 3: simulate the rollback by deleting the generated files.
	for _, p := range paths {
		err := os.Remove(p)
		if err != nil {
			t.Errorf("删除文件失败: %v", err)
		}
	}

	// Phase 4: every file must be gone.
	for _, p := range paths {
		_, err := os.Stat(p)
		if os.IsNotExist(err) {
			continue
		}
		t.Errorf("文件应该已被删除: %s", p)
	}
}
// TestRollbackAllWireProviders verifies that the article provider can be
// rolled back from the main ProviderSet file of every layer (biz and data).
//
// Each layer runs as its own subtest, in a fixed order, so a failure in one
// layer does not hide the result of the other. For each layer the test writes
// a ProviderSet file that already contains the article provider and import,
// rolls the provider back with autocode.NewMainProviderSetAst, and asserts
// that neither the provider name nor the article import path remains in the
// formatted output.
func TestRollbackAllWireProviders(t *testing.T) {
	tempDir := t.TempDir()

	// The provider name is stored next to its layer so that adding a layer
	// cannot silently leave the name empty.
	cases := []struct {
		layer        string
		providerName string
		content      string
	}{
		{
			layer:        "biz",
			providerName: "NewArticleUsecase",
			content: `package biz
import (
"github.com/google/wire"
"kra/internal/biz/article"
)
var ProviderSet = wire.NewSet(
article.NewArticleUsecase,
)
`,
		},
		{
			layer:        "data",
			providerName: "NewArticleRepo",
			content: `package data
import (
"github.com/google/wire"
"kra/internal/data/article"
)
var ProviderSet = wire.NewSet(
article.NewArticleRepo,
)
`,
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.layer, func(t *testing.T) {
			file := filepath.Join(tempDir, tc.layer+".go")
			importPath := "kra/internal/" + tc.layer + "/article"

			if err := os.WriteFile(file, []byte(tc.content), 0644); err != nil {
				t.Fatalf("创建 %s 文件失败: %v", tc.layer, err)
			}

			injector := autocode.NewMainProviderSetAst(
				tc.layer, "article", tc.providerName,
				importPath, file,
			)
			astFile, err := injector.Parse(file, nil)
			if err != nil {
				t.Fatalf("解析 %s 文件失败: %v", tc.layer, err)
			}
			if err := injector.Rollback(astFile); err != nil {
				t.Fatalf("回滚 %s 层失败: %v", tc.layer, err)
			}

			var buf bytes.Buffer
			if err := injector.Format(file, &buf, astFile); err != nil {
				t.Fatalf("格式化 %s 失败: %v", tc.layer, err)
			}

			result := buf.String()
			if strings.Contains(result, tc.providerName) {
				t.Errorf("%s 层回滚后不应包含 %s", tc.layer, tc.providerName)
			}
			if strings.Contains(result, importPath) {
				t.Errorf("%s 层回滚后不应包含 article 导入", tc.layer)
			}
		})
	}
}
// ==================== Backend layer tests ====================

// testBizLayerGeneration renders an inline biz-layer template (entity struct,
// repository interface, usecase struct and constructor) against info, checks
// that the expected declarations appear in the output, then validates the
// output's Go syntax and gofmt idempotence.
func testBizLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	const bizTemplate = `package {{.Package}}
import (
"context"
"github.com/go-kratos/kratos/v2/log"
)
// {{.StructName}} {{.Description}}实体
type {{.StructName}} struct {
{{- range .Fields}}
{{.FieldName}} {{.FieldType}} ` + "`" + `json:"{{.FieldJson}}"` + "`" + `
{{- end}}
}
// {{.StructName}}Repo {{.Description}}仓储接口
type {{.StructName}}Repo interface {
Create(ctx context.Context, {{.Abbreviation}} *{{.StructName}}) error
Delete(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) error
Update(ctx context.Context, {{.Abbreviation}} *{{.StructName}}) error
FindByID(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) (*{{.StructName}}, error)
List(ctx context.Context, page, pageSize int) ([]*{{.StructName}}, int64, error)
}
// {{.StructName}}Usecase {{.Description}}用例
type {{.StructName}}Usecase struct {
repo {{.StructName}}Repo
log *log.Helper
}
// New{{.StructName}}Usecase 创建{{.Description}}用例
func New{{.StructName}}Usecase(repo {{.StructName}}Repo, logger log.Logger) *{{.StructName}}Usecase {
return &{{.StructName}}Usecase{
repo: repo,
log: log.NewHelper(logger),
}
}
`
	rendered := renderTemplate(t, "biz", bizTemplate, info)

	// Every fragment below must appear somewhere in the rendered code.
	for _, want := range []string{
		"package article",
		"Article struct",
		"ArticleRepo interface",
		"ArticleUsecase struct",
		"NewArticleUsecase",
		"Create(ctx context.Context",
		"Delete(ctx context.Context",
		"Update(ctx context.Context",
		"FindByID(ctx context.Context",
		"List(ctx context.Context",
		"log.Helper",
	} {
		if strings.Contains(rendered, want) {
			continue
		}
		t.Errorf("Biz 层代码应包含 '%s'", want)
	}

	validateGoSyntax(t, rendered, "biz layer")
	validateFormatIdempotence(t, rendered, "biz layer")
}
// testDataLayerGeneration renders an inline data-layer template (repository
// struct, constructor, Create and List methods) against info, checks that the
// expected fragments appear in the output, and validates its Go syntax.
func testDataLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	const dataTemplate = `package {{.Package}}
import (
"context"
"{{.Module}}/internal/biz/{{.Package}}"
"{{.Module}}/internal/data/query"
)
type {{.Abbreviation}}Repo struct {
data *Data
}
func New{{.StructName}}Repo(data *Data) {{.Package}}.{{.StructName}}Repo {
return &{{.Abbreviation}}Repo{data: data}
}
func (r *{{.Abbreviation}}Repo) Create(ctx context.Context, {{.Abbreviation}} *{{.Package}}.{{.StructName}}) error {
t := query.{{.StructName}}
return t.WithContext(ctx).Create({{.Abbreviation}})
}
func (r *{{.Abbreviation}}Repo) List(ctx context.Context, page, pageSize int) ([]*{{.Package}}.{{.StructName}}, int64, error) {
t := query.{{.StructName}}
q := t.WithContext(ctx)
total, _ := q.Count()
list, err := q.Offset((page - 1) * pageSize).Limit(pageSize).Find()
return list, total, err
}
`
	rendered := renderTemplate(t, "data", dataTemplate, info)

	// Every fragment below must appear somewhere in the rendered code.
	for _, want := range []string{
		"package article",
		"articleRepo struct",
		"NewArticleRepo",
		"query.Article",
		"WithContext(ctx)",
	} {
		if strings.Contains(rendered, want) {
			continue
		}
		t.Errorf("Data 层代码应包含 '%s'", want)
	}

	validateGoSyntax(t, rendered, "data layer")
}
// testHandlerLayerGeneration renders an inline handler-layer template (handler
// struct, constructor and five gin handlers carrying @Tags/@Router
// annotations) against info, checks that the expected fragments appear in the
// output, and validates its Go syntax.
func testHandlerLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	const handlerTemplate = `package {{.Package}}
import (
"{{.Module}}/internal/biz/{{.Package}}"
"{{.Module}}/pkg/response"
"github.com/gin-gonic/gin"
)
type {{.StructName}}Handler struct {
uc *{{.Package}}.{{.StructName}}Usecase
}
func New{{.StructName}}Handler(uc *{{.Package}}.{{.StructName}}Usecase) *{{.StructName}}Handler {
return &{{.StructName}}Handler{uc: uc}
}
// Create{{.StructName}} 创建{{.Description}}
// @Tags {{.Description}}
// @Summary 创建{{.Description}}
// @Router /{{.Abbreviation}}/create{{.StructName}} [post]
func (h *{{.StructName}}Handler) Create{{.StructName}}(c *gin.Context) {
response.OkWithMessage("创建成功", c)
}
// Delete{{.StructName}} 删除{{.Description}}
// @Tags {{.Description}}
// @Router /{{.Abbreviation}}/delete{{.StructName}} [delete]
func (h *{{.StructName}}Handler) Delete{{.StructName}}(c *gin.Context) {
response.OkWithMessage("删除成功", c)
}
// Update{{.StructName}} 更新{{.Description}}
// @Tags {{.Description}}
// @Router /{{.Abbreviation}}/update{{.StructName}} [put]
func (h *{{.StructName}}Handler) Update{{.StructName}}(c *gin.Context) {
response.OkWithMessage("更新成功", c)
}
// Find{{.StructName}} 根据ID获取{{.Description}}
// @Tags {{.Description}}
// @Router /{{.Abbreviation}}/find{{.StructName}} [get]
func (h *{{.StructName}}Handler) Find{{.StructName}}(c *gin.Context) {
response.OkWithMessage("获取成功", c)
}
// Get{{.StructName}}List 获取{{.Description}}列表
// @Tags {{.Description}}
// @Router /{{.Abbreviation}}/get{{.StructName}}List [get]
func (h *{{.StructName}}Handler) Get{{.StructName}}List(c *gin.Context) {
response.OkWithMessage("获取成功", c)
}
`
	rendered := renderTemplate(t, "handler", handlerTemplate, info)

	// Every fragment below must appear somewhere in the rendered code.
	for _, want := range []string{
		"package article",
		"ArticleHandler struct",
		"NewArticleHandler",
		"CreateArticle",
		"DeleteArticle",
		"UpdateArticle",
		"FindArticle",
		"GetArticleList",
		"@Tags 文章",
		"@Router /article/",
		"gin.Context",
	} {
		if strings.Contains(rendered, want) {
			continue
		}
		t.Errorf("Handler 层代码应包含 '%s'", want)
	}

	validateGoSyntax(t, rendered, "handler layer")
}
// testServiceLayerGeneration renders an inline service-layer template (service
// struct, constructor, Create and Delete methods) against info, checks that
// the expected fragments appear in the output, and validates its Go syntax.
func testServiceLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	const serviceTemplate = `package {{.Package}}
import (
"context"
"{{.Module}}/internal/biz/{{.Package}}"
"github.com/go-kratos/kratos/v2/log"
)
type {{.StructName}}Service struct {
uc *{{.Package}}.{{.StructName}}Usecase
log *log.Helper
}
func New{{.StructName}}Service(uc *{{.Package}}.{{.StructName}}Usecase, logger log.Logger) *{{.StructName}}Service {
return &{{.StructName}}Service{
uc: uc,
log: log.NewHelper(logger),
}
}
func (s *{{.StructName}}Service) Create{{.StructName}}(ctx context.Context) error {
return nil
}
func (s *{{.StructName}}Service) Delete{{.StructName}}(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) error {
return nil
}
`
	rendered := renderTemplate(t, "service", serviceTemplate, info)

	// Every fragment below must appear somewhere in the rendered code.
	for _, want := range []string{
		"package article",
		"ArticleService struct",
		"NewArticleService",
		"CreateArticle",
		"DeleteArticle",
	} {
		if strings.Contains(rendered, want) {
			continue
		}
		t.Errorf("Service 层代码应包含 '%s'", want)
	}

	validateGoSyntax(t, rendered, "service layer")
}
// testRouterLayerGeneration renders an inline router-layer template (router
// struct plus an Init method that registers the CRUD routes on two gin route
// groups) against info, checks that the expected fragments appear in the
// output, and validates its Go syntax.
func testRouterLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	const routerTemplate = `package {{.Package}}
import (
"{{.Module}}/internal/server/handler/{{.Package}}"
"{{.Module}}/internal/middleware"
"github.com/gin-gonic/gin"
)
type {{.StructName}}Router struct{}
func (r *{{.StructName}}Router) Init{{.StructName}}Router(Router *gin.RouterGroup, PublicRouter *gin.RouterGroup, handler *{{.Package}}.{{.StructName}}Handler) {
{{.Abbreviation}}Router := Router.Group("{{.Abbreviation}}").Use(middleware.OperationRecord())
{{.Abbreviation}}RouterWithoutRecord := Router.Group("{{.Abbreviation}}")
{
{{.Abbreviation}}Router.POST("create{{.StructName}}", handler.Create{{.StructName}})
{{.Abbreviation}}Router.DELETE("delete{{.StructName}}", handler.Delete{{.StructName}})
{{.Abbreviation}}Router.PUT("update{{.StructName}}", handler.Update{{.StructName}})
}
{
{{.Abbreviation}}RouterWithoutRecord.GET("find{{.StructName}}", handler.Find{{.StructName}})
{{.Abbreviation}}RouterWithoutRecord.GET("get{{.StructName}}List", handler.Get{{.StructName}}List)
}
}
`
	rendered := renderTemplate(t, "router", routerTemplate, info)

	// Every fragment below must appear somewhere in the rendered code.
	for _, want := range []string{
		"package article",
		"ArticleRouter struct",
		"InitArticleRouter",
		"gin.RouterGroup",
		"articleRouter",
		"middleware.OperationRecord",
	} {
		if strings.Contains(rendered, want) {
			continue
		}
		t.Errorf("Router 层代码应包含 '%s'", want)
	}

	validateGoSyntax(t, rendered, "router layer")
}
// testTypesLayerGeneration renders an inline request-types template (request
// struct, ID request struct and a search struct embedding request.PageInfo)
// against info, checks that the expected fragments appear in the output, and
// validates its Go syntax.
func testTypesLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	const typesTemplate = `package request
import (
"{{.Module}}/pkg/request"
)
type {{.StructName}}Request struct {
{{- range .Fields}}
{{.FieldName}} {{.FieldType}} ` + "`" + `json:"{{.FieldJson}}"` + "`" + `
{{- end}}
}
type {{.StructName}}IDRequest struct {
ID {{.PrimaryField.FieldType}} ` + "`" + `json:"{{.PrimaryField.FieldJson}}"` + "`" + `
}
type {{.StructName}}Search struct {
request.PageInfo
}
`
	rendered := renderTemplate(t, "types", typesTemplate, info)

	// Every fragment below must appear somewhere in the rendered code.
	for _, want := range []string{
		"package request",
		"ArticleRequest struct",
		"ArticleIDRequest struct",
		"ArticleSearch struct",
		"PageInfo",
	} {
		if strings.Contains(rendered, want) {
			continue
		}
		t.Errorf("Types 层代码应包含 '%s'", want)
	}

	validateGoSyntax(t, rendered, "types layer")
}
// testModelLayerGeneration renders an inline model-layer template (a struct
// that embeds GVA_MODEL when info.GvaModel is set, gorm/json-tagged fields and
// a TableName method) against info, checks that the expected fragments appear
// in the output, and validates its Go syntax like the other backend layers.
func testModelLayerGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `package model
type {{.StructName}} struct {
{{- if .GvaModel }}
GVA_MODEL
{{- end }}
{{- range .Fields}}
{{.FieldName}} {{.FieldType}} ` + "`" + `gorm:"column:{{.ColumnName}}" json:"{{.FieldJson}}"` + "`" + `
{{- end}}
}
func ({{.StructName}}) TableName() string {
return "{{.TableName}}"
}
`
	result := renderTemplate(t, "model", tplContent, info)

	requiredComponents := []string{
		"package model",
		"Article struct",
		"GVA_MODEL",
		"TableName()",
		"articles",
		"gorm:",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("Model 层代码应包含 '%s'", component)
		}
	}

	// The substring checks alone would accept syntactically broken output,
	// so parse the rendered model as well.
	validateGoSyntax(t, result, "model layer")
}
// ==================== Frontend layer tests ====================

// testFrontendPagesGeneration exercises autocode.GenerateReactProTableColumn
// directly, instead of rendering a page template, for four field shapes: a
// sortable string, a dictionary-backed int, a time.Time and a bool. For each
// it checks that the generated column text contains the expected fragments.
// The info parameter is not used by this check.
func testFrontendPagesGeneration(t *testing.T, info *AutoCodeInfo) {
	// Sortable string column.
	titleColumn := autocode.GenerateReactProTableColumn(autocode.AutoCodeField{
		FieldName:       "Title",
		FieldType:       "string",
		FieldJson:       "title",
		FieldDesc:       "标题",
		FieldSearchType: "LIKE",
		Table:           true,
		Sort:            true,
		DictType:        "",
	})
	for _, want := range []string{
		"title:",
		"'标题'",
		"dataIndex:",
		"'title'",
		"sorter: true",
	} {
		if strings.Contains(titleColumn, want) {
			continue
		}
		t.Errorf("ProTable 列配置应包含 '%s', 实际结果: %s", want, titleColumn)
	}

	// Dictionary-backed column: must mention the dictionary and "select".
	statusColumn := autocode.GenerateReactProTableColumn(autocode.AutoCodeField{
		FieldName:       "Status",
		FieldType:       "int",
		FieldJson:       "status",
		FieldDesc:       "状态",
		FieldSearchType: "=",
		DictType:        "articleStatus",
		Table:           true,
	})
	if !strings.Contains(statusColumn, "articleStatus") {
		t.Errorf("字典类型列应包含 articleStatus, 实际结果: %s", statusColumn)
	}
	if !strings.Contains(statusColumn, "select") {
		t.Errorf("字典类型列应包含 select valueType, 实际结果: %s", statusColumn)
	}

	// Time column: must mention the JSON name and "dateTime".
	publishColumn := autocode.GenerateReactProTableColumn(autocode.AutoCodeField{
		FieldName:       "PublishTime",
		FieldType:       "time.Time",
		FieldJson:       "publishTime",
		FieldDesc:       "发布时间",
		FieldSearchType: "BETWEEN",
		Table:           true,
	})
	if !strings.Contains(publishColumn, "publishTime") {
		t.Errorf("时间类型列应包含 publishTime, 实际结果: %s", publishColumn)
	}
	if !strings.Contains(publishColumn, "dateTime") {
		t.Errorf("时间类型列应包含 dateTime valueType, 实际结果: %s", publishColumn)
	}

	// Bool column: only the JSON name is checked.
	topColumn := autocode.GenerateReactProTableColumn(autocode.AutoCodeField{
		FieldName: "IsTop",
		FieldType: "bool",
		FieldJson: "isTop",
		FieldDesc: "是否置顶",
		Table:     true,
	})
	if !strings.Contains(topColumn, "isTop") {
		t.Errorf("布尔类型列应包含 isTop, 实际结果: %s", topColumn)
	}
}
// testFrontendServicesGeneration renders an inline TypeScript API-service
// template (list, get-by-id, create, update, delete and batch-delete request
// functions) against info and checks that the expected function names, URLs,
// HTTP methods and API types appear in the output, and that the word "number"
// appears (the primary key is a uint passed through GenerateTSType).
func testFrontendServicesGeneration(t *testing.T, info *AutoCodeInfo) {
	const servicesTemplate = `import { request } from '@umijs/max';
/** 获取{{.Description}}列表 */
export async function get{{.StructName}}List(params: API.{{.StructName}}ListParams) {
return request<API.{{.StructName}}ListResult>('/api/{{.Abbreviation}}/get{{.StructName}}List', {
method: 'GET',
params,
});
}
/** 获取{{.Description}}详情 */
export async function get{{.StructName}}ById({{.PrimaryField.FieldJson}}: {{ GenerateTSType .PrimaryField.FieldType }}) {
return request<API.{{.StructName}}>('/api/{{.Abbreviation}}/find{{.StructName}}', {
method: 'GET',
params: { {{.PrimaryField.FieldJson}} },
});
}
/** 创建{{.Description}} */
export async function create{{.StructName}}(data: API.{{.StructName}}Request) {
return request<API.Response>('/api/{{.Abbreviation}}/create{{.StructName}}', {
method: 'POST',
data,
});
}
/** 更新{{.Description}} */
export async function update{{.StructName}}(data: API.{{.StructName}}Request) {
return request<API.Response>('/api/{{.Abbreviation}}/update{{.StructName}}', {
method: 'PUT',
data,
});
}
/** 删除{{.Description}} */
export async function delete{{.StructName}}({{.PrimaryField.FieldJson}}: {{ GenerateTSType .PrimaryField.FieldType }}) {
return request<API.Response>('/api/{{.Abbreviation}}/delete{{.StructName}}', {
method: 'DELETE',
params: { {{.PrimaryField.FieldJson}} },
});
}
/** 批量删除{{.Description}} */
export async function delete{{.StructName}}ByIds(ids: {{ GenerateTSType .PrimaryField.FieldType }}[]) {
return request<API.Response>('/api/{{.Abbreviation}}/delete{{.StructName}}ByIds', {
method: 'DELETE',
data: { ids },
});
}
`
	rendered := renderTemplate(t, "services", servicesTemplate, info)

	// Fragments every generated API service file must contain.
	for _, want := range []string{
		"import { request } from '@umijs/max'",
		"getArticleList",
		"getArticleById",
		"createArticle",
		"updateArticle",
		"deleteArticle",
		"deleteArticleByIds",
		"/api/article/",
		"method: 'GET'",
		"method: 'POST'",
		"method: 'PUT'",
		"method: 'DELETE'",
		"API.ArticleListParams",
		"API.ArticleRequest",
		"API.Response",
	} {
		if strings.Contains(rendered, want) {
			continue
		}
		t.Errorf("前端 API 服务代码应包含 '%s'", want)
	}

	// Type conversion check: the output must mention "number".
	if hasNumber := strings.Contains(rendered, "number"); !hasNumber {
		t.Error("前端 API 服务应包含 number 类型(从 uint 转换)")
	}
}
// testFrontendTypesGeneration renders an inline TypeScript declaration
// template (entity, request, list-params and list-result interfaces inside
// the API namespace) against info, checks that the expected fragments appear
// in the output, and checks that the TypeScript type names string, number and
// boolean each occur at least once.
func testFrontendTypesGeneration(t *testing.T, info *AutoCodeInfo) {
	tplContent := `declare namespace API {
/** {{.Description}}实体 */
interface {{.StructName}} {
{{.PrimaryField.FieldJson}}?: {{ GenerateTSType .PrimaryField.FieldType }};
{{- range .Fields}}
{{.FieldJson}}?: {{ GenerateTSType .FieldType }};
{{- end}}
createdAt?: string;
updatedAt?: string;
}
/** {{.Description}}请求参数 */
interface {{.StructName}}Request {
{{- range .Fields}}
{{- if .Form}}
{{.FieldJson}}{{if not .Require}}?{{end}}: {{ GenerateTSType .FieldType }};
{{- end}}
{{- end}}
}
/** {{.Description}}列表请求参数 */
interface {{.StructName}}ListParams {
page?: number;
pageSize?: number;
{{- range .Fields}}
{{- if .FieldSearchType}}
{{.FieldJson}}?: {{ GenerateTSType .FieldType }};
{{- end}}
{{- end}}
}
/** {{.Description}}列表响应 */
interface {{.StructName}}ListResult {
list: {{.StructName}}[];
total: number;
page: number;
pageSize: number;
}
}
`
	result := renderTemplate(t, "types", tplContent, info)

	// Fragments every generated declaration file must contain.
	requiredComponents := []string{
		"declare namespace API",
		"interface Article",
		"interface ArticleRequest",
		"interface ArticleListParams",
		"interface ArticleListResult",
		"id?:",
		"title",
		"content",
		"status",
		"publishTime",
		"viewCount",
		"createdAt?: string",
		"updatedAt?: string",
		"page?: number",
		"pageSize?: number",
		"list: Article[]",
		"total: number",
	}
	for _, component := range requiredComponents {
		if !strings.Contains(result, component) {
			t.Errorf("前端 TypeScript 类型定义应包含 '%s'", component)
		}
	}

	// TypeScript type names expected in the output. An ordered slice keeps
	// failure output deterministic; the map it replaces had identical keys
	// and values and was iterated in random order.
	// NOTE(review): "number" is presumably produced for int/uint fields and
	// "boolean" for bool fields by GenerateTSType; confirm against that helper.
	expectedTSTypes := []string{"string", "number", "boolean"}
	for _, expectedType := range expectedTSTypes {
		if !strings.Contains(result, expectedType) {
			t.Errorf("前端类型定义应包含 TypeScript 类型 '%s'", expectedType)
		}
	}
}
// ==================== Wire injection tests ====================

// TestWireProviderInjectionAndRollback injects NewArticleUsecase into a
// ProviderSet file with autocode.NewWireProviderAst, verifies the provider was
// added next to the existing one, writes the injected source back, rolls the
// provider back, and verifies it is gone while the existing provider remains.
func TestWireProviderInjectionAndRollback(t *testing.T) {
	wirePath := filepath.Join(t.TempDir(), "test_wire.go")
	const seed = `package test
import (
"github.com/google/wire"
)
var ProviderSet = wire.NewSet(
NewExistingProvider,
)
`
	if err := os.WriteFile(wirePath, []byte(seed), 0644); err != nil {
		t.Fatalf("创建测试文件失败: %v", err)
	}

	injector := autocode.NewWireProviderAst("biz", "article", "NewArticleUsecase", wirePath)

	// ---- injection ----
	parsed, err := injector.Parse(wirePath, nil)
	if err != nil {
		t.Fatalf("解析文件失败: %v", err)
	}
	if err := injector.Injection(parsed); err != nil {
		t.Fatalf("注入失败: %v", err)
	}
	var injectedOut bytes.Buffer
	if err := injector.Format(wirePath, &injectedOut, parsed); err != nil {
		t.Fatalf("格式化失败: %v", err)
	}
	injected := injectedOut.String()

	if hasNew := strings.Contains(injected, "NewArticleUsecase"); !hasNew {
		t.Error("注入后应包含 NewArticleUsecase")
	}
	if hasOld := strings.Contains(injected, "NewExistingProvider"); !hasOld {
		t.Error("注入后应保留 NewExistingProvider")
	}

	// Persist the injected source so the rollback starts from a fresh parse.
	if err := os.WriteFile(wirePath, []byte(injected), 0644); err != nil {
		t.Fatalf("写入注入后文件失败: %v", err)
	}

	// ---- rollback ----
	reparsed, err := injector.Parse(wirePath, nil)
	if err != nil {
		t.Fatalf("解析注入后文件失败: %v", err)
	}
	if err := injector.Rollback(reparsed); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}
	var rolledBackOut bytes.Buffer
	if err := injector.Format(wirePath, &rolledBackOut, reparsed); err != nil {
		t.Fatalf("格式化回滚结果失败: %v", err)
	}
	rolledBack := rolledBackOut.String()

	if strings.Contains(rolledBack, "NewArticleUsecase") {
		t.Error("回滚后不应包含 NewArticleUsecase")
	}
	if hasOld := strings.Contains(rolledBack, "NewExistingProvider"); !hasOld {
		t.Error("回滚后应保留 NewExistingProvider")
	}
}
// TestMainProviderSetInjectionAndRollback injects article.NewArticleUsecase
// and its import path into a main ProviderSet file with
// autocode.NewMainProviderSetAst, verifies both were added, writes the result
// back, rolls the injection back, and verifies that both are gone while the
// pre-existing system.NewUserUsecase provider remains.
func TestMainProviderSetInjectionAndRollback(t *testing.T) {
	bizPath := filepath.Join(t.TempDir(), "biz.go")
	const seed = `package biz
import (
"github.com/google/wire"
"kra/internal/biz/system"
)
var ProviderSet = wire.NewSet(
system.NewUserUsecase,
)
`
	if err := os.WriteFile(bizPath, []byte(seed), 0644); err != nil {
		t.Fatalf("创建测试文件失败: %v", err)
	}

	injector := autocode.NewMainProviderSetAst(
		"biz", "article", "NewArticleUsecase",
		"kra/internal/biz/article", bizPath,
	)

	// ---- injection ----
	parsed, err := injector.Parse(bizPath, nil)
	if err != nil {
		t.Fatalf("解析文件失败: %v", err)
	}
	if err := injector.Injection(parsed); err != nil {
		t.Fatalf("注入失败: %v", err)
	}
	var injectedOut bytes.Buffer
	if err := injector.Format(bizPath, &injectedOut, parsed); err != nil {
		t.Fatalf("格式化失败: %v", err)
	}
	injected := injectedOut.String()

	if hasProvider := strings.Contains(injected, "article.NewArticleUsecase"); !hasProvider {
		t.Error("注入后应包含 article.NewArticleUsecase")
	}
	if hasImport := strings.Contains(injected, "kra/internal/biz/article"); !hasImport {
		t.Error("注入后应包含导入路径")
	}

	// Persist the injected source so the rollback starts from a fresh parse.
	if err := os.WriteFile(bizPath, []byte(injected), 0644); err != nil {
		t.Fatalf("写入注入后文件失败: %v", err)
	}

	// ---- rollback ----
	reparsed, err := injector.Parse(bizPath, nil)
	if err != nil {
		t.Fatalf("解析注入后文件失败: %v", err)
	}
	if err := injector.Rollback(reparsed); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}
	var rolledBackOut bytes.Buffer
	if err := injector.Format(bizPath, &rolledBackOut, reparsed); err != nil {
		t.Fatalf("格式化回滚结果失败: %v", err)
	}
	rolledBack := rolledBackOut.String()

	if strings.Contains(rolledBack, "article.NewArticleUsecase") {
		t.Error("回滚后不应包含 article.NewArticleUsecase")
	}
	if strings.Contains(rolledBack, "kra/internal/biz/article") {
		t.Error("回滚后不应包含 article 导入路径")
	}
	if keptSystem := strings.Contains(rolledBack, "system.NewUserUsecase"); !keptSystem {
		t.Error("回滚后应保留 system.NewUserUsecase")
	}
}
// TestMultipleInjectionAndRollback injects three providers one after another
// into the same ProviderSet file, verifies all are present, then rolls them
// back in reverse order, checking after each rollback that the provider is
// gone. NewBaseProvider must still be present at the end.
func TestMultipleInjectionAndRollback(t *testing.T) {
	wirePath := filepath.Join(t.TempDir(), "test_multi.go")
	const seed = `package test
import (
"github.com/google/wire"
)
var ProviderSet = wire.NewSet(
NewBaseProvider,
)
`
	if err := os.WriteFile(wirePath, []byte(seed), 0644); err != nil {
		t.Fatalf("创建测试文件失败: %v", err)
	}

	names := []string{"NewProvider1", "NewProvider2", "NewProvider3"}

	// source always holds the latest formatted output; it is written to disk
	// before every parse so each step builds on the previous one.
	source := seed

	// Inject the providers in order.
	for _, name := range names {
		if err := os.WriteFile(wirePath, []byte(source), 0644); err != nil {
			t.Fatalf("写入文件失败: %v", err)
		}
		injector := autocode.NewWireProviderAst("biz", "test", name, wirePath)
		parsed, err := injector.Parse(wirePath, nil)
		if err != nil {
			t.Fatalf("解析文件失败: %v", err)
		}
		if err := injector.Injection(parsed); err != nil {
			t.Fatalf("注入 %s 失败: %v", name, err)
		}
		var out bytes.Buffer
		if err := injector.Format(wirePath, &out, parsed); err != nil {
			t.Fatalf("格式化失败: %v", err)
		}
		source = out.String()
	}

	// All providers must now be present.
	for _, name := range names {
		if strings.Contains(source, name) {
			continue
		}
		t.Errorf("应包含 %s", name)
	}

	// Roll the providers back, last injected first.
	for offset := range names {
		name := names[len(names)-1-offset]
		if err := os.WriteFile(wirePath, []byte(source), 0644); err != nil {
			t.Fatalf("写入文件失败: %v", err)
		}
		injector := autocode.NewWireProviderAst("biz", "test", name, wirePath)
		parsed, err := injector.Parse(wirePath, nil)
		if err != nil {
			t.Fatalf("解析文件失败: %v", err)
		}
		if err := injector.Rollback(parsed); err != nil {
			t.Fatalf("回滚 %s 失败: %v", name, err)
		}
		var out bytes.Buffer
		if err := injector.Format(wirePath, &out, parsed); err != nil {
			t.Fatalf("格式化失败: %v", err)
		}
		source = out.String()
		if strings.Contains(source, name) {
			t.Errorf("回滚后不应包含 %s", name)
		}
	}

	if keptBase := strings.Contains(source, "NewBaseProvider"); !keptBase {
		t.Error("回滚后应保留 NewBaseProvider")
	}
}
// TestRollbackIdempotency verifies that rolling back the same provider twice
// yields the same source as rolling it back once: the second rollback runs on
// the output of the first and must leave it unchanged. It also checks that
// NewTargetProvider is gone and NewExistingProvider is kept.
//
// Every Parse, Rollback and Format error is checked, so a failing step stops
// the test with a clear message instead of continuing on an invalid value.
func TestRollbackIdempotency(t *testing.T) {
	tempDir := t.TempDir()
	testFile := filepath.Join(tempDir, "test_idempotent.go")
	initialContent := `package test
import (
"github.com/google/wire"
)
var ProviderSet = wire.NewSet(
NewExistingProvider,
NewTargetProvider,
)
`
	if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil {
		t.Fatalf("创建测试文件失败: %v", err)
	}

	injector := autocode.NewWireProviderAst("biz", "test", "NewTargetProvider", testFile)

	// First rollback.
	file1, err := injector.Parse(testFile, nil)
	if err != nil {
		t.Fatalf("解析文件失败: %v", err)
	}
	if err := injector.Rollback(file1); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}
	var buf1 bytes.Buffer
	if err := injector.Format(testFile, &buf1, file1); err != nil {
		t.Fatalf("格式化失败: %v", err)
	}
	firstRollback := buf1.String()

	if err := os.WriteFile(testFile, []byte(firstRollback), 0644); err != nil {
		t.Fatalf("写入文件失败: %v", err)
	}

	// Second rollback, applied to the output of the first.
	file2, err := injector.Parse(testFile, nil)
	if err != nil {
		t.Fatalf("解析文件失败: %v", err)
	}
	if err := injector.Rollback(file2); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}
	var buf2 bytes.Buffer
	if err := injector.Format(testFile, &buf2, file2); err != nil {
		t.Fatalf("格式化失败: %v", err)
	}
	secondRollback := buf2.String()

	// Idempotence: both rollbacks must produce identical source.
	if firstRollback != secondRollback {
		t.Error("回滚应该是幂等的")
	}
	if strings.Contains(secondRollback, "NewTargetProvider") {
		t.Error("回滚后不应包含 NewTargetProvider")
	}
	if !strings.Contains(secondRollback, "NewExistingProvider") {
		t.Error("回滚后应保留 NewExistingProvider")
	}
}
// countProviderSetArgs returns the number of arguments in the call expression
// assigned to a variable named ProviderSet in file. It walks the whole AST, so
// var declarations nested inside functions are considered too; if several
// declarations match, the last one visited determines the result. It returns
// 0 when no ProviderSet variable is initialised with a call expression.
func countProviderSetArgs(file *ast.File) int {
	// providerSetCall returns the first call expression bound to a name
	// "ProviderSet" within decl, or nil when there is none.
	providerSetCall := func(decl *ast.GenDecl) *ast.CallExpr {
		for _, spec := range decl.Specs {
			vs, isValue := spec.(*ast.ValueSpec)
			if !isValue {
				continue
			}
			for idx, ident := range vs.Names {
				// Skip other names and names without a matching value.
				if ident.Name != "ProviderSet" || idx >= len(vs.Values) {
					continue
				}
				if call, isCall := vs.Values[idx].(*ast.CallExpr); isCall {
					return call
				}
			}
		}
		return nil
	}

	total := 0
	ast.Inspect(file, func(n ast.Node) bool {
		decl, isDecl := n.(*ast.GenDecl)
		if !isDecl || decl.Tok != token.VAR {
			return true
		}
		call := providerSetCall(decl)
		if call == nil {
			return true
		}
		total = len(call.Args)
		// Nothing of interest below a matched declaration.
		return false
	})
	return total
}
// TestRollbackPreservesFormatting rolls NewTargetProvider back out of a
// three-provider ProviderSet and verifies that the result (a) no longer
// contains the rolled-back provider, (b) still contains the other two
// providers, (c) is syntactically valid Go, and (d) is stable under a second
// gofmt pass.
//
// Checks (a) and (b) guard against a rollback that does nothing: an untouched
// file would otherwise satisfy the syntax and formatting checks on its own.
func TestRollbackPreservesFormatting(t *testing.T) {
	tempDir := t.TempDir()
	testFile := filepath.Join(tempDir, "format_test.go")
	initialContent := `package test
import (
"github.com/google/wire"
)
var ProviderSet = wire.NewSet(
NewProvider1,
NewProvider2,
NewTargetProvider,
)
`
	if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil {
		t.Fatalf("创建测试文件失败: %v", err)
	}

	injector := autocode.NewWireProviderAst("biz", "test", "NewTargetProvider", testFile)
	file, err := injector.Parse(testFile, nil)
	if err != nil {
		t.Fatalf("解析文件失败: %v", err)
	}
	if err := injector.Rollback(file); err != nil {
		t.Fatalf("回滚失败: %v", err)
	}

	var buf bytes.Buffer
	if err := injector.Format(testFile, &buf, file); err != nil {
		t.Fatalf("格式化失败: %v", err)
	}
	result := buf.String()

	// The rollback must actually have removed the target and nothing else.
	if strings.Contains(result, "NewTargetProvider") {
		t.Error("回滚后不应包含 NewTargetProvider")
	}
	for _, kept := range []string{"NewProvider1", "NewProvider2"} {
		if !strings.Contains(result, kept) {
			t.Errorf("回滚后应保留 %s", kept)
		}
	}

	// The result must still be valid Go source.
	fset := token.NewFileSet()
	if _, err := parser.ParseFile(fset, "", result, parser.AllErrors); err != nil {
		t.Errorf("回滚后代码语法错误: %v", err)
	}

	// gofmt idempotence: formatting the formatted output changes nothing.
	formatted1, err := format.Source([]byte(result))
	if err != nil {
		t.Errorf("第一次格式化失败: %v", err)
		return
	}
	formatted2, err := format.Source(formatted1)
	if err != nil {
		t.Errorf("第二次格式化失败: %v", err)
		return
	}
	if !bytes.Equal(formatted1, formatted2) {
		t.Error("回滚后代码 gofmt 不是幂等的")
	}
}
// TestRollbackNonExistentProvider rolls back a provider that was never
// registered in the ProviderSet and expects the rollback to return no error
// and the existing provider to remain in the formatted output.
func TestRollbackNonExistentProvider(t *testing.T) {
	wirePath := filepath.Join(t.TempDir(), "nonexistent.go")
	const seed = `package test
import (
"github.com/google/wire"
)
var ProviderSet = wire.NewSet(
NewExistingProvider,
)
`
	if err := os.WriteFile(wirePath, []byte(seed), 0644); err != nil {
		t.Fatalf("创建测试文件失败: %v", err)
	}

	injector := autocode.NewWireProviderAst("biz", "test", "NewNonExistentProvider", wirePath)
	parsed, err := injector.Parse(wirePath, nil)
	if err != nil {
		t.Fatalf("解析文件失败: %v", err)
	}

	// Rolling back an absent provider is expected to succeed; a failure is
	// recorded but the remaining checks still run.
	if rollbackErr := injector.Rollback(parsed); rollbackErr != nil {
		t.Errorf("回滚不存在的 Provider 不应该报错: %v", rollbackErr)
	}

	var out bytes.Buffer
	if err := injector.Format(wirePath, &out, parsed); err != nil {
		t.Fatalf("格式化失败: %v", err)
	}

	// The provider that was already there must still be present.
	if kept := strings.Contains(out.String(), "NewExistingProvider"); !kept {
		t.Error("回滚后应保留 NewExistingProvider")
	}
}