1375 lines
35 KiB
Go
1375 lines
35 KiB
Go
package system
|
|
|
|
import (
|
|
"bytes"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/token"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"text/template"
|
|
|
|
"kra/pkg/utils/autocode"
|
|
)
|
|
|
|
// ==================== 端到端代码生成测试 ====================
|
|
|
|
// TestEndToEndCodeGeneration 端到端代码生成测试
|
|
// 验证所有层的代码生成是否正确
|
|
func TestEndToEndCodeGeneration(t *testing.T) {
|
|
info := createTestAutoCodeInfo()
|
|
|
|
t.Run("BizLayer", func(t *testing.T) {
|
|
testBizLayerGeneration(t, info)
|
|
})
|
|
|
|
t.Run("DataLayer", func(t *testing.T) {
|
|
testDataLayerGeneration(t, info)
|
|
})
|
|
|
|
t.Run("HandlerLayer", func(t *testing.T) {
|
|
testHandlerLayerGeneration(t, info)
|
|
})
|
|
|
|
t.Run("ServiceLayer", func(t *testing.T) {
|
|
testServiceLayerGeneration(t, info)
|
|
})
|
|
|
|
t.Run("RouterLayer", func(t *testing.T) {
|
|
testRouterLayerGeneration(t, info)
|
|
})
|
|
|
|
t.Run("TypesLayer", func(t *testing.T) {
|
|
testTypesLayerGeneration(t, info)
|
|
})
|
|
|
|
t.Run("ModelLayer", func(t *testing.T) {
|
|
testModelLayerGeneration(t, info)
|
|
})
|
|
|
|
t.Run("FrontendPages", func(t *testing.T) {
|
|
testFrontendPagesGeneration(t, info)
|
|
})
|
|
|
|
t.Run("FrontendServices", func(t *testing.T) {
|
|
testFrontendServicesGeneration(t, info)
|
|
})
|
|
|
|
t.Run("FrontendTypes", func(t *testing.T) {
|
|
testFrontendTypesGeneration(t, info)
|
|
})
|
|
}
|
|
|
|
// createTestAutoCodeInfo 创建测试用的 AutoCodeInfo
|
|
func createTestAutoCodeInfo() *AutoCodeInfo {
|
|
return &AutoCodeInfo{
|
|
StructName: "Article",
|
|
TableName: "articles",
|
|
Package: "article",
|
|
PackageName: "article",
|
|
Abbreviation: "article",
|
|
Description: "文章",
|
|
Module: "kra",
|
|
GvaModel: true,
|
|
PrimaryField: &AutoCodeField{
|
|
FieldName: "ID",
|
|
FieldType: "uint",
|
|
FieldJson: "id",
|
|
FieldDesc: "主键ID",
|
|
},
|
|
Fields: []*AutoCodeField{
|
|
{
|
|
FieldName: "Title",
|
|
FieldType: "string",
|
|
FieldJson: "title",
|
|
FieldDesc: "标题",
|
|
ColumnName: "title",
|
|
FieldSearchType: "LIKE",
|
|
Form: true,
|
|
Table: true,
|
|
Desc: true,
|
|
Require: true,
|
|
ErrorText: "请输入标题",
|
|
},
|
|
{
|
|
FieldName: "Content",
|
|
FieldType: "string",
|
|
FieldJson: "content",
|
|
FieldDesc: "内容",
|
|
ColumnName: "content",
|
|
Form: true,
|
|
Table: false,
|
|
Desc: true,
|
|
},
|
|
{
|
|
FieldName: "Status",
|
|
FieldType: "int",
|
|
FieldJson: "status",
|
|
FieldDesc: "状态",
|
|
ColumnName: "status",
|
|
FieldSearchType: "=",
|
|
DictType: "articleStatus",
|
|
Form: true,
|
|
Table: true,
|
|
Desc: true,
|
|
},
|
|
{
|
|
FieldName: "PublishTime",
|
|
FieldType: "time.Time",
|
|
FieldJson: "publishTime",
|
|
FieldDesc: "发布时间",
|
|
ColumnName: "publish_time",
|
|
FieldSearchType: "BETWEEN",
|
|
Form: true,
|
|
Table: true,
|
|
Desc: true,
|
|
},
|
|
{
|
|
FieldName: "ViewCount",
|
|
FieldType: "int",
|
|
FieldJson: "viewCount",
|
|
FieldDesc: "浏览量",
|
|
ColumnName: "view_count",
|
|
Form: false,
|
|
Table: true,
|
|
Desc: true,
|
|
Sort: true,
|
|
},
|
|
{
|
|
FieldName: "IsTop",
|
|
FieldType: "bool",
|
|
FieldJson: "isTop",
|
|
FieldDesc: "是否置顶",
|
|
ColumnName: "is_top",
|
|
Form: true,
|
|
Table: true,
|
|
Desc: true,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// renderTemplate 渲染模板
|
|
func renderTemplate(t *testing.T, name, tplContent string, data interface{}) string {
|
|
tmpl, err := template.New(name).Funcs(autocode.GetTemplateFuncMap()).Parse(tplContent)
|
|
if err != nil {
|
|
t.Fatalf("解析模板失败: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := tmpl.Execute(&buf, data); err != nil {
|
|
t.Fatalf("执行模板失败: %v", err)
|
|
}
|
|
|
|
return buf.String()
|
|
}
|
|
|
|
// validateGoSyntax reports a test error (without stopping the test) when code
// is not syntactically valid Go. Only parsing is performed: imports and types
// are not resolved, so generated code referencing absent packages still
// passes. layer labels the error message.
func validateGoSyntax(t *testing.T, code, layer string) {
	// Attribute failures to the layer test that produced the code.
	t.Helper()

	fset := token.NewFileSet()
	// AllErrors makes the parser report every syntax error, not just the first ten.
	if _, err := parser.ParseFile(fset, "", code, parser.AllErrors); err != nil {
		t.Errorf("%s 代码语法错误: %v\n代码:\n%s", layer, err, code)
	}
}
|
|
|
|
// validateFormatIdempotence checks that formatting code with go/format is
// idempotent: formatting the already-formatted output must not change it
// again (Property 4). Formatting errors and non-idempotent output are reported
// as test errors. On mismatch, both passes are printed to aid debugging.
func validateFormatIdempotence(t *testing.T, code, layer string) {
	// Attribute failures to the layer test that produced the code.
	t.Helper()

	formatted1, err := format.Source([]byte(code))
	if err != nil {
		t.Errorf("%s 第一次格式化失败: %v", layer, err)
		return
	}

	formatted2, err := format.Source(formatted1)
	if err != nil {
		t.Errorf("%s 第二次格式化失败: %v", layer, err)
		return
	}

	if !bytes.Equal(formatted1, formatted2) {
		t.Errorf("%s gofmt 不是幂等的\n第一次:\n%s\n第二次:\n%s", layer, formatted1, formatted2)
	}
}
|
|
|
|
// TestFileRollback simulates rolling back generated files. It creates a file
// in each of the biz, data and service directories under a temp dir, checks
// that they exist, deletes them, and checks that they are gone.
func TestFileRollback(t *testing.T) {
	tempDir := t.TempDir()

	generatedFiles := []string{
		filepath.Join(tempDir, "biz", "article.go"),
		filepath.Join(tempDir, "data", "article.go"),
		filepath.Join(tempDir, "service", "article.go"),
	}

	// Simulate code generation: create each directory and file.
	for _, file := range generatedFiles {
		if err := os.MkdirAll(filepath.Dir(file), 0755); err != nil {
			t.Fatalf("创建目录失败: %v", err)
		}
		if err := os.WriteFile(file, []byte("// generated file"), 0644); err != nil {
			t.Fatalf("创建文件失败: %v", err)
		}
	}

	// Precondition: every file must be stat-able. Any Stat error fails here,
	// not only "does not exist", so a permission problem cannot slip through.
	for _, file := range generatedFiles {
		if _, err := os.Stat(file); err != nil {
			t.Errorf("文件应该存在: %s (%v)", file, err)
		}
	}

	// Simulate rollback: delete the generated files.
	for _, file := range generatedFiles {
		if err := os.Remove(file); err != nil {
			t.Errorf("删除文件失败: %s: %v", file, err)
		}
	}

	// Only a definite "does not exist" counts as removed; success or any other
	// Stat error is reported.
	for _, file := range generatedFiles {
		if _, err := os.Stat(file); !os.IsNotExist(err) {
			t.Errorf("文件应该已被删除: %s", file)
		}
	}
}
|
|
|
|
// TestRollbackAllWireProviders 测试批量回滚所有层的 Wire Provider
|
|
func TestRollbackAllWireProviders(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
|
|
layers := map[string]string{
|
|
"biz": `package biz
|
|
|
|
import (
|
|
"github.com/google/wire"
|
|
"kra/internal/biz/article"
|
|
)
|
|
|
|
var ProviderSet = wire.NewSet(
|
|
article.NewArticleUsecase,
|
|
)
|
|
`,
|
|
"data": `package data
|
|
|
|
import (
|
|
"github.com/google/wire"
|
|
"kra/internal/data/article"
|
|
)
|
|
|
|
var ProviderSet = wire.NewSet(
|
|
article.NewArticleRepo,
|
|
)
|
|
`,
|
|
}
|
|
|
|
files := make(map[string]string)
|
|
for layer, content := range layers {
|
|
file := filepath.Join(tempDir, layer+".go")
|
|
files[layer] = file
|
|
if err := os.WriteFile(file, []byte(content), 0644); err != nil {
|
|
t.Fatalf("创建 %s 文件失败: %v", layer, err)
|
|
}
|
|
}
|
|
|
|
for layer, file := range files {
|
|
var providerName, packageName string
|
|
switch layer {
|
|
case "biz":
|
|
providerName = "NewArticleUsecase"
|
|
packageName = "article"
|
|
case "data":
|
|
providerName = "NewArticleRepo"
|
|
packageName = "article"
|
|
}
|
|
|
|
injector := autocode.NewMainProviderSetAst(
|
|
layer, packageName, providerName,
|
|
"kra/internal/"+layer+"/article", file,
|
|
)
|
|
|
|
astFile, err := injector.Parse(file, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析 %s 文件失败: %v", layer, err)
|
|
}
|
|
|
|
if err := injector.Rollback(astFile); err != nil {
|
|
t.Fatalf("回滚 %s 层失败: %v", layer, err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := injector.Format(file, &buf, astFile); err != nil {
|
|
t.Fatalf("格式化 %s 失败: %v", layer, err)
|
|
}
|
|
|
|
result := buf.String()
|
|
if strings.Contains(result, providerName) {
|
|
t.Errorf("%s 层回滚后不应包含 %s", layer, providerName)
|
|
}
|
|
if strings.Contains(result, "kra/internal/"+layer+"/article") {
|
|
t.Errorf("%s 层回滚后不应包含 article 导入", layer)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ==================== 后端层测试 ====================
|
|
|
|
// testBizLayerGeneration 测试 Biz 层代码生成
|
|
func testBizLayerGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `package {{.Package}}
|
|
|
|
import (
|
|
"context"
|
|
"github.com/go-kratos/kratos/v2/log"
|
|
)
|
|
|
|
// {{.StructName}} {{.Description}}实体
|
|
type {{.StructName}} struct {
|
|
{{- range .Fields}}
|
|
{{.FieldName}} {{.FieldType}} ` + "`" + `json:"{{.FieldJson}}"` + "`" + `
|
|
{{- end}}
|
|
}
|
|
|
|
// {{.StructName}}Repo {{.Description}}仓储接口
|
|
type {{.StructName}}Repo interface {
|
|
Create(ctx context.Context, {{.Abbreviation}} *{{.StructName}}) error
|
|
Delete(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) error
|
|
Update(ctx context.Context, {{.Abbreviation}} *{{.StructName}}) error
|
|
FindByID(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) (*{{.StructName}}, error)
|
|
List(ctx context.Context, page, pageSize int) ([]*{{.StructName}}, int64, error)
|
|
}
|
|
|
|
// {{.StructName}}Usecase {{.Description}}用例
|
|
type {{.StructName}}Usecase struct {
|
|
repo {{.StructName}}Repo
|
|
log *log.Helper
|
|
}
|
|
|
|
// New{{.StructName}}Usecase 创建{{.Description}}用例
|
|
func New{{.StructName}}Usecase(repo {{.StructName}}Repo, logger log.Logger) *{{.StructName}}Usecase {
|
|
return &{{.StructName}}Usecase{
|
|
repo: repo,
|
|
log: log.NewHelper(logger),
|
|
}
|
|
}
|
|
`
|
|
result := renderTemplate(t, "biz", tplContent, info)
|
|
|
|
requiredComponents := []string{
|
|
"package article",
|
|
"Article struct",
|
|
"ArticleRepo interface",
|
|
"ArticleUsecase struct",
|
|
"NewArticleUsecase",
|
|
"Create(ctx context.Context",
|
|
"Delete(ctx context.Context",
|
|
"Update(ctx context.Context",
|
|
"FindByID(ctx context.Context",
|
|
"List(ctx context.Context",
|
|
"log.Helper",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("Biz 层代码应包含 '%s'", component)
|
|
}
|
|
}
|
|
|
|
validateGoSyntax(t, result, "biz layer")
|
|
validateFormatIdempotence(t, result, "biz layer")
|
|
}
|
|
|
|
// testDataLayerGeneration 测试 Data 层代码生成
|
|
func testDataLayerGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `package {{.Package}}
|
|
|
|
import (
|
|
"context"
|
|
"{{.Module}}/internal/biz/{{.Package}}"
|
|
"{{.Module}}/internal/data/query"
|
|
)
|
|
|
|
type {{.Abbreviation}}Repo struct {
|
|
data *Data
|
|
}
|
|
|
|
func New{{.StructName}}Repo(data *Data) {{.Package}}.{{.StructName}}Repo {
|
|
return &{{.Abbreviation}}Repo{data: data}
|
|
}
|
|
|
|
func (r *{{.Abbreviation}}Repo) Create(ctx context.Context, {{.Abbreviation}} *{{.Package}}.{{.StructName}}) error {
|
|
t := query.{{.StructName}}
|
|
return t.WithContext(ctx).Create({{.Abbreviation}})
|
|
}
|
|
|
|
func (r *{{.Abbreviation}}Repo) List(ctx context.Context, page, pageSize int) ([]*{{.Package}}.{{.StructName}}, int64, error) {
|
|
t := query.{{.StructName}}
|
|
q := t.WithContext(ctx)
|
|
total, _ := q.Count()
|
|
list, err := q.Offset((page - 1) * pageSize).Limit(pageSize).Find()
|
|
return list, total, err
|
|
}
|
|
`
|
|
result := renderTemplate(t, "data", tplContent, info)
|
|
|
|
requiredComponents := []string{
|
|
"package article",
|
|
"articleRepo struct",
|
|
"NewArticleRepo",
|
|
"query.Article",
|
|
"WithContext(ctx)",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("Data 层代码应包含 '%s'", component)
|
|
}
|
|
}
|
|
|
|
validateGoSyntax(t, result, "data layer")
|
|
}
|
|
|
|
// testHandlerLayerGeneration 测试 Handler 层代码生成
|
|
func testHandlerLayerGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `package {{.Package}}
|
|
|
|
import (
|
|
"{{.Module}}/internal/biz/{{.Package}}"
|
|
"{{.Module}}/pkg/response"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type {{.StructName}}Handler struct {
|
|
uc *{{.Package}}.{{.StructName}}Usecase
|
|
}
|
|
|
|
func New{{.StructName}}Handler(uc *{{.Package}}.{{.StructName}}Usecase) *{{.StructName}}Handler {
|
|
return &{{.StructName}}Handler{uc: uc}
|
|
}
|
|
|
|
// Create{{.StructName}} 创建{{.Description}}
|
|
// @Tags {{.Description}}
|
|
// @Summary 创建{{.Description}}
|
|
// @Router /{{.Abbreviation}}/create{{.StructName}} [post]
|
|
func (h *{{.StructName}}Handler) Create{{.StructName}}(c *gin.Context) {
|
|
response.OkWithMessage("创建成功", c)
|
|
}
|
|
|
|
// Delete{{.StructName}} 删除{{.Description}}
|
|
// @Tags {{.Description}}
|
|
// @Router /{{.Abbreviation}}/delete{{.StructName}} [delete]
|
|
func (h *{{.StructName}}Handler) Delete{{.StructName}}(c *gin.Context) {
|
|
response.OkWithMessage("删除成功", c)
|
|
}
|
|
|
|
// Update{{.StructName}} 更新{{.Description}}
|
|
// @Tags {{.Description}}
|
|
// @Router /{{.Abbreviation}}/update{{.StructName}} [put]
|
|
func (h *{{.StructName}}Handler) Update{{.StructName}}(c *gin.Context) {
|
|
response.OkWithMessage("更新成功", c)
|
|
}
|
|
|
|
// Find{{.StructName}} 根据ID获取{{.Description}}
|
|
// @Tags {{.Description}}
|
|
// @Router /{{.Abbreviation}}/find{{.StructName}} [get]
|
|
func (h *{{.StructName}}Handler) Find{{.StructName}}(c *gin.Context) {
|
|
response.OkWithMessage("获取成功", c)
|
|
}
|
|
|
|
// Get{{.StructName}}List 获取{{.Description}}列表
|
|
// @Tags {{.Description}}
|
|
// @Router /{{.Abbreviation}}/get{{.StructName}}List [get]
|
|
func (h *{{.StructName}}Handler) Get{{.StructName}}List(c *gin.Context) {
|
|
response.OkWithMessage("获取成功", c)
|
|
}
|
|
`
|
|
result := renderTemplate(t, "handler", tplContent, info)
|
|
|
|
requiredComponents := []string{
|
|
"package article",
|
|
"ArticleHandler struct",
|
|
"NewArticleHandler",
|
|
"CreateArticle",
|
|
"DeleteArticle",
|
|
"UpdateArticle",
|
|
"FindArticle",
|
|
"GetArticleList",
|
|
"@Tags 文章",
|
|
"@Router /article/",
|
|
"gin.Context",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("Handler 层代码应包含 '%s'", component)
|
|
}
|
|
}
|
|
|
|
validateGoSyntax(t, result, "handler layer")
|
|
}
|
|
|
|
// testServiceLayerGeneration 测试 Service 层代码生成
|
|
func testServiceLayerGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `package {{.Package}}
|
|
|
|
import (
|
|
"context"
|
|
"{{.Module}}/internal/biz/{{.Package}}"
|
|
"github.com/go-kratos/kratos/v2/log"
|
|
)
|
|
|
|
type {{.StructName}}Service struct {
|
|
uc *{{.Package}}.{{.StructName}}Usecase
|
|
log *log.Helper
|
|
}
|
|
|
|
func New{{.StructName}}Service(uc *{{.Package}}.{{.StructName}}Usecase, logger log.Logger) *{{.StructName}}Service {
|
|
return &{{.StructName}}Service{
|
|
uc: uc,
|
|
log: log.NewHelper(logger),
|
|
}
|
|
}
|
|
|
|
func (s *{{.StructName}}Service) Create{{.StructName}}(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (s *{{.StructName}}Service) Delete{{.StructName}}(ctx context.Context, {{.PrimaryField.FieldJson}} {{.PrimaryField.FieldType}}) error {
|
|
return nil
|
|
}
|
|
`
|
|
result := renderTemplate(t, "service", tplContent, info)
|
|
|
|
requiredComponents := []string{
|
|
"package article",
|
|
"ArticleService struct",
|
|
"NewArticleService",
|
|
"CreateArticle",
|
|
"DeleteArticle",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("Service 层代码应包含 '%s'", component)
|
|
}
|
|
}
|
|
|
|
validateGoSyntax(t, result, "service layer")
|
|
}
|
|
|
|
// testRouterLayerGeneration 测试 Router 层代码生成
|
|
func testRouterLayerGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `package {{.Package}}
|
|
|
|
import (
|
|
"{{.Module}}/internal/server/handler/{{.Package}}"
|
|
"{{.Module}}/internal/middleware"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type {{.StructName}}Router struct{}
|
|
|
|
func (r *{{.StructName}}Router) Init{{.StructName}}Router(Router *gin.RouterGroup, PublicRouter *gin.RouterGroup, handler *{{.Package}}.{{.StructName}}Handler) {
|
|
{{.Abbreviation}}Router := Router.Group("{{.Abbreviation}}").Use(middleware.OperationRecord())
|
|
{{.Abbreviation}}RouterWithoutRecord := Router.Group("{{.Abbreviation}}")
|
|
{
|
|
{{.Abbreviation}}Router.POST("create{{.StructName}}", handler.Create{{.StructName}})
|
|
{{.Abbreviation}}Router.DELETE("delete{{.StructName}}", handler.Delete{{.StructName}})
|
|
{{.Abbreviation}}Router.PUT("update{{.StructName}}", handler.Update{{.StructName}})
|
|
}
|
|
{
|
|
{{.Abbreviation}}RouterWithoutRecord.GET("find{{.StructName}}", handler.Find{{.StructName}})
|
|
{{.Abbreviation}}RouterWithoutRecord.GET("get{{.StructName}}List", handler.Get{{.StructName}}List)
|
|
}
|
|
}
|
|
`
|
|
result := renderTemplate(t, "router", tplContent, info)
|
|
|
|
requiredComponents := []string{
|
|
"package article",
|
|
"ArticleRouter struct",
|
|
"InitArticleRouter",
|
|
"gin.RouterGroup",
|
|
"articleRouter",
|
|
"middleware.OperationRecord",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("Router 层代码应包含 '%s'", component)
|
|
}
|
|
}
|
|
|
|
validateGoSyntax(t, result, "router layer")
|
|
}
|
|
|
|
// testTypesLayerGeneration 测试 Types 层代码生成
|
|
func testTypesLayerGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `package request
|
|
|
|
import (
|
|
"{{.Module}}/pkg/request"
|
|
)
|
|
|
|
type {{.StructName}}Request struct {
|
|
{{- range .Fields}}
|
|
{{.FieldName}} {{.FieldType}} ` + "`" + `json:"{{.FieldJson}}"` + "`" + `
|
|
{{- end}}
|
|
}
|
|
|
|
type {{.StructName}}IDRequest struct {
|
|
ID {{.PrimaryField.FieldType}} ` + "`" + `json:"{{.PrimaryField.FieldJson}}"` + "`" + `
|
|
}
|
|
|
|
type {{.StructName}}Search struct {
|
|
request.PageInfo
|
|
}
|
|
`
|
|
result := renderTemplate(t, "types", tplContent, info)
|
|
|
|
requiredComponents := []string{
|
|
"package request",
|
|
"ArticleRequest struct",
|
|
"ArticleIDRequest struct",
|
|
"ArticleSearch struct",
|
|
"PageInfo",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("Types 层代码应包含 '%s'", component)
|
|
}
|
|
}
|
|
|
|
validateGoSyntax(t, result, "types layer")
|
|
}
|
|
|
|
// testModelLayerGeneration 测试 Model 层代码生成
|
|
func testModelLayerGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `package model
|
|
|
|
type {{.StructName}} struct {
|
|
{{- if .GvaModel }}
|
|
GVA_MODEL
|
|
{{- end }}
|
|
{{- range .Fields}}
|
|
{{.FieldName}} {{.FieldType}} ` + "`" + `gorm:"column:{{.ColumnName}}" json:"{{.FieldJson}}"` + "`" + `
|
|
{{- end}}
|
|
}
|
|
|
|
func ({{.StructName}}) TableName() string {
|
|
return "{{.TableName}}"
|
|
}
|
|
`
|
|
result := renderTemplate(t, "model", tplContent, info)
|
|
|
|
requiredComponents := []string{
|
|
"package model",
|
|
"Article struct",
|
|
"GVA_MODEL",
|
|
"TableName()",
|
|
"articles",
|
|
"gorm:",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("Model 层代码应包含 '%s'", component)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ==================== 前端层测试 ====================
|
|
|
|
// testFrontendPagesGeneration 测试前端页面代码生成
|
|
func testFrontendPagesGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
// 简化测试:直接验证模板函数可以生成正确的列配置
|
|
// 使用 autocode 包的类型来测试
|
|
field := autocode.AutoCodeField{
|
|
FieldName: "Title",
|
|
FieldType: "string",
|
|
FieldJson: "title",
|
|
FieldDesc: "标题",
|
|
FieldSearchType: "LIKE",
|
|
Table: true,
|
|
Sort: true,
|
|
DictType: "",
|
|
}
|
|
|
|
// 测试 GenerateReactProTableColumn 函数
|
|
result := autocode.GenerateReactProTableColumn(field)
|
|
|
|
requiredComponents := []string{
|
|
"title:",
|
|
"'标题'",
|
|
"dataIndex:",
|
|
"'title'",
|
|
"sorter: true",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("ProTable 列配置应包含 '%s', 实际结果: %s", component, result)
|
|
}
|
|
}
|
|
|
|
// 测试字典类型字段
|
|
dictField := autocode.AutoCodeField{
|
|
FieldName: "Status",
|
|
FieldType: "int",
|
|
FieldJson: "status",
|
|
FieldDesc: "状态",
|
|
FieldSearchType: "=",
|
|
DictType: "articleStatus",
|
|
Table: true,
|
|
}
|
|
|
|
dictResult := autocode.GenerateReactProTableColumn(dictField)
|
|
if !strings.Contains(dictResult, "articleStatus") {
|
|
t.Errorf("字典类型列应包含 articleStatus, 实际结果: %s", dictResult)
|
|
}
|
|
if !strings.Contains(dictResult, "select") {
|
|
t.Errorf("字典类型列应包含 select valueType, 实际结果: %s", dictResult)
|
|
}
|
|
|
|
// 测试时间类型字段
|
|
timeField := autocode.AutoCodeField{
|
|
FieldName: "PublishTime",
|
|
FieldType: "time.Time",
|
|
FieldJson: "publishTime",
|
|
FieldDesc: "发布时间",
|
|
FieldSearchType: "BETWEEN",
|
|
Table: true,
|
|
}
|
|
|
|
timeResult := autocode.GenerateReactProTableColumn(timeField)
|
|
if !strings.Contains(timeResult, "publishTime") {
|
|
t.Errorf("时间类型列应包含 publishTime, 实际结果: %s", timeResult)
|
|
}
|
|
if !strings.Contains(timeResult, "dateTime") {
|
|
t.Errorf("时间类型列应包含 dateTime valueType, 实际结果: %s", timeResult)
|
|
}
|
|
|
|
// 测试布尔类型字段
|
|
boolField := autocode.AutoCodeField{
|
|
FieldName: "IsTop",
|
|
FieldType: "bool",
|
|
FieldJson: "isTop",
|
|
FieldDesc: "是否置顶",
|
|
Table: true,
|
|
}
|
|
|
|
boolResult := autocode.GenerateReactProTableColumn(boolField)
|
|
if !strings.Contains(boolResult, "isTop") {
|
|
t.Errorf("布尔类型列应包含 isTop, 实际结果: %s", boolResult)
|
|
}
|
|
}
|
|
|
|
// testFrontendServicesGeneration 测试前端 API 服务代码生成
|
|
func testFrontendServicesGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `import { request } from '@umijs/max';
|
|
|
|
/** 获取{{.Description}}列表 */
|
|
export async function get{{.StructName}}List(params: API.{{.StructName}}ListParams) {
|
|
return request<API.{{.StructName}}ListResult>('/api/{{.Abbreviation}}/get{{.StructName}}List', {
|
|
method: 'GET',
|
|
params,
|
|
});
|
|
}
|
|
|
|
/** 获取{{.Description}}详情 */
|
|
export async function get{{.StructName}}ById({{.PrimaryField.FieldJson}}: {{ GenerateTSType .PrimaryField.FieldType }}) {
|
|
return request<API.{{.StructName}}>('/api/{{.Abbreviation}}/find{{.StructName}}', {
|
|
method: 'GET',
|
|
params: { {{.PrimaryField.FieldJson}} },
|
|
});
|
|
}
|
|
|
|
/** 创建{{.Description}} */
|
|
export async function create{{.StructName}}(data: API.{{.StructName}}Request) {
|
|
return request<API.Response>('/api/{{.Abbreviation}}/create{{.StructName}}', {
|
|
method: 'POST',
|
|
data,
|
|
});
|
|
}
|
|
|
|
/** 更新{{.Description}} */
|
|
export async function update{{.StructName}}(data: API.{{.StructName}}Request) {
|
|
return request<API.Response>('/api/{{.Abbreviation}}/update{{.StructName}}', {
|
|
method: 'PUT',
|
|
data,
|
|
});
|
|
}
|
|
|
|
/** 删除{{.Description}} */
|
|
export async function delete{{.StructName}}({{.PrimaryField.FieldJson}}: {{ GenerateTSType .PrimaryField.FieldType }}) {
|
|
return request<API.Response>('/api/{{.Abbreviation}}/delete{{.StructName}}', {
|
|
method: 'DELETE',
|
|
params: { {{.PrimaryField.FieldJson}} },
|
|
});
|
|
}
|
|
|
|
/** 批量删除{{.Description}} */
|
|
export async function delete{{.StructName}}ByIds(ids: {{ GenerateTSType .PrimaryField.FieldType }}[]) {
|
|
return request<API.Response>('/api/{{.Abbreviation}}/delete{{.StructName}}ByIds', {
|
|
method: 'DELETE',
|
|
data: { ids },
|
|
});
|
|
}
|
|
`
|
|
result := renderTemplate(t, "services", tplContent, info)
|
|
|
|
// 验证 API 服务必要组件
|
|
requiredComponents := []string{
|
|
"import { request } from '@umijs/max'",
|
|
"getArticleList",
|
|
"getArticleById",
|
|
"createArticle",
|
|
"updateArticle",
|
|
"deleteArticle",
|
|
"deleteArticleByIds",
|
|
"/api/article/",
|
|
"method: 'GET'",
|
|
"method: 'POST'",
|
|
"method: 'PUT'",
|
|
"method: 'DELETE'",
|
|
"API.ArticleListParams",
|
|
"API.ArticleRequest",
|
|
"API.Response",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("前端 API 服务代码应包含 '%s'", component)
|
|
}
|
|
}
|
|
|
|
// 验证类型转换
|
|
if !strings.Contains(result, "number") {
|
|
t.Error("前端 API 服务应包含 number 类型(从 uint 转换)")
|
|
}
|
|
}
|
|
|
|
// testFrontendTypesGeneration 测试前端 TypeScript 类型定义生成
|
|
func testFrontendTypesGeneration(t *testing.T, info *AutoCodeInfo) {
|
|
tplContent := `declare namespace API {
|
|
/** {{.Description}}实体 */
|
|
interface {{.StructName}} {
|
|
{{.PrimaryField.FieldJson}}?: {{ GenerateTSType .PrimaryField.FieldType }};
|
|
{{- range .Fields}}
|
|
{{.FieldJson}}?: {{ GenerateTSType .FieldType }};
|
|
{{- end}}
|
|
createdAt?: string;
|
|
updatedAt?: string;
|
|
}
|
|
|
|
/** {{.Description}}请求参数 */
|
|
interface {{.StructName}}Request {
|
|
{{- range .Fields}}
|
|
{{- if .Form}}
|
|
{{.FieldJson}}{{if not .Require}}?{{end}}: {{ GenerateTSType .FieldType }};
|
|
{{- end}}
|
|
{{- end}}
|
|
}
|
|
|
|
/** {{.Description}}列表请求参数 */
|
|
interface {{.StructName}}ListParams {
|
|
page?: number;
|
|
pageSize?: number;
|
|
{{- range .Fields}}
|
|
{{- if .FieldSearchType}}
|
|
{{.FieldJson}}?: {{ GenerateTSType .FieldType }};
|
|
{{- end}}
|
|
{{- end}}
|
|
}
|
|
|
|
/** {{.Description}}列表响应 */
|
|
interface {{.StructName}}ListResult {
|
|
list: {{.StructName}}[];
|
|
total: number;
|
|
page: number;
|
|
pageSize: number;
|
|
}
|
|
}
|
|
`
|
|
result := renderTemplate(t, "types", tplContent, info)
|
|
|
|
// 验证 TypeScript 类型定义必要组件
|
|
requiredComponents := []string{
|
|
"declare namespace API",
|
|
"interface Article",
|
|
"interface ArticleRequest",
|
|
"interface ArticleListParams",
|
|
"interface ArticleListResult",
|
|
"id?:",
|
|
"title",
|
|
"content",
|
|
"status",
|
|
"publishTime",
|
|
"viewCount",
|
|
"createdAt?: string",
|
|
"updatedAt?: string",
|
|
"page?: number",
|
|
"pageSize?: number",
|
|
"list: Article[]",
|
|
"total: number",
|
|
}
|
|
|
|
for _, component := range requiredComponents {
|
|
if !strings.Contains(result, component) {
|
|
t.Errorf("前端 TypeScript 类型定义应包含 '%s'", component)
|
|
}
|
|
}
|
|
|
|
// 验证类型映射正确性
|
|
typeChecks := map[string]string{
|
|
"string": "string",
|
|
"number": "number", // int/uint -> number
|
|
"boolean": "boolean",
|
|
}
|
|
|
|
for _, expectedType := range typeChecks {
|
|
if !strings.Contains(result, expectedType) {
|
|
t.Errorf("前端类型定义应包含 TypeScript 类型 '%s'", expectedType)
|
|
}
|
|
}
|
|
}
|
|
|
|
// ==================== Wire 注入测试 ====================
|
|
|
|
// TestWireProviderInjectionAndRollback 测试 Wire Provider 注入和回滚
|
|
func TestWireProviderInjectionAndRollback(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
testFile := filepath.Join(tempDir, "test_wire.go")
|
|
|
|
initialContent := `package test
|
|
|
|
import (
|
|
"github.com/google/wire"
|
|
)
|
|
|
|
var ProviderSet = wire.NewSet(
|
|
NewExistingProvider,
|
|
)
|
|
`
|
|
if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil {
|
|
t.Fatalf("创建测试文件失败: %v", err)
|
|
}
|
|
|
|
injector := autocode.NewWireProviderAst("biz", "article", "NewArticleUsecase", testFile)
|
|
|
|
file, err := injector.Parse(testFile, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析文件失败: %v", err)
|
|
}
|
|
|
|
// 执行注入
|
|
if err := injector.Injection(file); err != nil {
|
|
t.Fatalf("注入失败: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := injector.Format(testFile, &buf, file); err != nil {
|
|
t.Fatalf("格式化失败: %v", err)
|
|
}
|
|
|
|
injectedContent := buf.String()
|
|
if !strings.Contains(injectedContent, "NewArticleUsecase") {
|
|
t.Error("注入后应包含 NewArticleUsecase")
|
|
}
|
|
if !strings.Contains(injectedContent, "NewExistingProvider") {
|
|
t.Error("注入后应保留 NewExistingProvider")
|
|
}
|
|
|
|
// 写入注入后的内容并重新解析
|
|
if err := os.WriteFile(testFile, []byte(injectedContent), 0644); err != nil {
|
|
t.Fatalf("写入注入后文件失败: %v", err)
|
|
}
|
|
|
|
file2, err := injector.Parse(testFile, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析注入后文件失败: %v", err)
|
|
}
|
|
|
|
// 执行回滚
|
|
if err := injector.Rollback(file2); err != nil {
|
|
t.Fatalf("回滚失败: %v", err)
|
|
}
|
|
|
|
var rollbackBuf bytes.Buffer
|
|
if err := injector.Format(testFile, &rollbackBuf, file2); err != nil {
|
|
t.Fatalf("格式化回滚结果失败: %v", err)
|
|
}
|
|
|
|
rollbackContent := rollbackBuf.String()
|
|
if strings.Contains(rollbackContent, "NewArticleUsecase") {
|
|
t.Error("回滚后不应包含 NewArticleUsecase")
|
|
}
|
|
if !strings.Contains(rollbackContent, "NewExistingProvider") {
|
|
t.Error("回滚后应保留 NewExistingProvider")
|
|
}
|
|
}
|
|
|
|
// TestMainProviderSetInjectionAndRollback 测试主 ProviderSet 注入和回滚
|
|
func TestMainProviderSetInjectionAndRollback(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
testFile := filepath.Join(tempDir, "biz.go")
|
|
|
|
initialContent := `package biz
|
|
|
|
import (
|
|
"github.com/google/wire"
|
|
"kra/internal/biz/system"
|
|
)
|
|
|
|
var ProviderSet = wire.NewSet(
|
|
system.NewUserUsecase,
|
|
)
|
|
`
|
|
if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil {
|
|
t.Fatalf("创建测试文件失败: %v", err)
|
|
}
|
|
|
|
injector := autocode.NewMainProviderSetAst(
|
|
"biz", "article", "NewArticleUsecase",
|
|
"kra/internal/biz/article", testFile,
|
|
)
|
|
|
|
file, err := injector.Parse(testFile, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析文件失败: %v", err)
|
|
}
|
|
|
|
if err := injector.Injection(file); err != nil {
|
|
t.Fatalf("注入失败: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := injector.Format(testFile, &buf, file); err != nil {
|
|
t.Fatalf("格式化失败: %v", err)
|
|
}
|
|
|
|
injectedContent := buf.String()
|
|
if !strings.Contains(injectedContent, "article.NewArticleUsecase") {
|
|
t.Error("注入后应包含 article.NewArticleUsecase")
|
|
}
|
|
if !strings.Contains(injectedContent, "kra/internal/biz/article") {
|
|
t.Error("注入后应包含导入路径")
|
|
}
|
|
|
|
// 写入并重新解析
|
|
if err := os.WriteFile(testFile, []byte(injectedContent), 0644); err != nil {
|
|
t.Fatalf("写入注入后文件失败: %v", err)
|
|
}
|
|
|
|
file2, err := injector.Parse(testFile, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析注入后文件失败: %v", err)
|
|
}
|
|
|
|
// 执行回滚
|
|
if err := injector.Rollback(file2); err != nil {
|
|
t.Fatalf("回滚失败: %v", err)
|
|
}
|
|
|
|
var rollbackBuf bytes.Buffer
|
|
if err := injector.Format(testFile, &rollbackBuf, file2); err != nil {
|
|
t.Fatalf("格式化回滚结果失败: %v", err)
|
|
}
|
|
|
|
rollbackContent := rollbackBuf.String()
|
|
if strings.Contains(rollbackContent, "article.NewArticleUsecase") {
|
|
t.Error("回滚后不应包含 article.NewArticleUsecase")
|
|
}
|
|
if strings.Contains(rollbackContent, "kra/internal/biz/article") {
|
|
t.Error("回滚后不应包含 article 导入路径")
|
|
}
|
|
if !strings.Contains(rollbackContent, "system.NewUserUsecase") {
|
|
t.Error("回滚后应保留 system.NewUserUsecase")
|
|
}
|
|
}
|
|
|
|
// TestMultipleInjectionAndRollback 测试多次注入和回滚
|
|
func TestMultipleInjectionAndRollback(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
testFile := filepath.Join(tempDir, "test_multi.go")
|
|
|
|
initialContent := `package test
|
|
|
|
import (
|
|
"github.com/google/wire"
|
|
)
|
|
|
|
var ProviderSet = wire.NewSet(
|
|
NewBaseProvider,
|
|
)
|
|
`
|
|
if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil {
|
|
t.Fatalf("创建测试文件失败: %v", err)
|
|
}
|
|
|
|
providers := []string{"NewProvider1", "NewProvider2", "NewProvider3"}
|
|
currentContent := initialContent
|
|
|
|
// 依次注入多个 Provider
|
|
for _, provider := range providers {
|
|
if err := os.WriteFile(testFile, []byte(currentContent), 0644); err != nil {
|
|
t.Fatalf("写入文件失败: %v", err)
|
|
}
|
|
|
|
injector := autocode.NewWireProviderAst("biz", "test", provider, testFile)
|
|
file, err := injector.Parse(testFile, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析文件失败: %v", err)
|
|
}
|
|
|
|
if err := injector.Injection(file); err != nil {
|
|
t.Fatalf("注入 %s 失败: %v", provider, err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := injector.Format(testFile, &buf, file); err != nil {
|
|
t.Fatalf("格式化失败: %v", err)
|
|
}
|
|
currentContent = buf.String()
|
|
}
|
|
|
|
// 验证所有 Provider 都已注入
|
|
for _, provider := range providers {
|
|
if !strings.Contains(currentContent, provider) {
|
|
t.Errorf("应包含 %s", provider)
|
|
}
|
|
}
|
|
|
|
// 逆序回滚所有 Provider
|
|
for i := len(providers) - 1; i >= 0; i-- {
|
|
provider := providers[i]
|
|
if err := os.WriteFile(testFile, []byte(currentContent), 0644); err != nil {
|
|
t.Fatalf("写入文件失败: %v", err)
|
|
}
|
|
|
|
injector := autocode.NewWireProviderAst("biz", "test", provider, testFile)
|
|
file, err := injector.Parse(testFile, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析文件失败: %v", err)
|
|
}
|
|
|
|
if err := injector.Rollback(file); err != nil {
|
|
t.Fatalf("回滚 %s 失败: %v", provider, err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := injector.Format(testFile, &buf, file); err != nil {
|
|
t.Fatalf("格式化失败: %v", err)
|
|
}
|
|
currentContent = buf.String()
|
|
|
|
if strings.Contains(currentContent, provider) {
|
|
t.Errorf("回滚后不应包含 %s", provider)
|
|
}
|
|
}
|
|
|
|
if !strings.Contains(currentContent, "NewBaseProvider") {
|
|
t.Error("回滚后应保留 NewBaseProvider")
|
|
}
|
|
}
|
|
|
|
// TestRollbackIdempotency 测试回滚幂等性
|
|
func TestRollbackIdempotency(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
testFile := filepath.Join(tempDir, "test_idempotent.go")
|
|
|
|
initialContent := `package test
|
|
|
|
import (
|
|
"github.com/google/wire"
|
|
)
|
|
|
|
var ProviderSet = wire.NewSet(
|
|
NewExistingProvider,
|
|
NewTargetProvider,
|
|
)
|
|
`
|
|
if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil {
|
|
t.Fatalf("创建测试文件失败: %v", err)
|
|
}
|
|
|
|
injector := autocode.NewWireProviderAst("biz", "test", "NewTargetProvider", testFile)
|
|
|
|
// 第一次回滚
|
|
file1, _ := injector.Parse(testFile, nil)
|
|
_ = injector.Rollback(file1)
|
|
var buf1 bytes.Buffer
|
|
_ = injector.Format(testFile, &buf1, file1)
|
|
firstRollback := buf1.String()
|
|
|
|
if err := os.WriteFile(testFile, []byte(firstRollback), 0644); err != nil {
|
|
t.Fatalf("写入文件失败: %v", err)
|
|
}
|
|
|
|
// 第二次回滚
|
|
file2, _ := injector.Parse(testFile, nil)
|
|
_ = injector.Rollback(file2)
|
|
var buf2 bytes.Buffer
|
|
_ = injector.Format(testFile, &buf2, file2)
|
|
secondRollback := buf2.String()
|
|
|
|
// 验证幂等性
|
|
if firstRollback != secondRollback {
|
|
t.Error("回滚应该是幂等的")
|
|
}
|
|
|
|
if strings.Contains(secondRollback, "NewTargetProvider") {
|
|
t.Error("回滚后不应包含 NewTargetProvider")
|
|
}
|
|
if !strings.Contains(secondRollback, "NewExistingProvider") {
|
|
t.Error("回滚后应保留 NewExistingProvider")
|
|
}
|
|
}
|
|
|
|
// countProviderSetArgs returns the number of arguments in the call that
// initialises a variable named ProviderSet; for example wire.NewSet(a, b, c)
// yields 3. It returns 0 when no var declaration gives ProviderSet a call
// initialiser. If several var declarations do, the one ast.Inspect visits
// last wins. Within a single declaration, the first match wins.
func countProviderSetArgs(file *ast.File) int {
	// argsIn scans the specs of one var declaration in source order. It
	// reports the argument count of the first ProviderSet whose positional
	// initialiser is a call expression.
	argsIn := func(decl *ast.GenDecl) (int, bool) {
		for _, spec := range decl.Specs {
			vs, isValue := spec.(*ast.ValueSpec)
			if !isValue {
				continue
			}
			for idx, ident := range vs.Names {
				// Names without a matching positional value (e.g. a
				// type-only declaration) have nothing to count.
				if ident.Name != "ProviderSet" || idx >= len(vs.Values) {
					continue
				}
				if call, isCall := vs.Values[idx].(*ast.CallExpr); isCall {
					return len(call.Args), true
				}
			}
		}
		return 0, false
	}

	count := 0
	ast.Inspect(file, func(n ast.Node) bool {
		decl, isGen := n.(*ast.GenDecl)
		if !isGen || decl.Tok != token.VAR {
			return true
		}
		if args, found := argsIn(decl); found {
			count = args
			// Do not descend into a declaration once ProviderSet was found in it.
			return false
		}
		return true
	})
	return count
}
|
|
|
|
// TestRollbackPreservesFormatting 测试回滚保持代码格式
|
|
func TestRollbackPreservesFormatting(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
testFile := filepath.Join(tempDir, "format_test.go")
|
|
|
|
initialContent := `package test
|
|
|
|
import (
|
|
"github.com/google/wire"
|
|
)
|
|
|
|
var ProviderSet = wire.NewSet(
|
|
NewProvider1,
|
|
NewProvider2,
|
|
NewTargetProvider,
|
|
)
|
|
`
|
|
if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil {
|
|
t.Fatalf("创建测试文件失败: %v", err)
|
|
}
|
|
|
|
injector := autocode.NewWireProviderAst("biz", "test", "NewTargetProvider", testFile)
|
|
|
|
file, err := injector.Parse(testFile, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析文件失败: %v", err)
|
|
}
|
|
|
|
if err := injector.Rollback(file); err != nil {
|
|
t.Fatalf("回滚失败: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := injector.Format(testFile, &buf, file); err != nil {
|
|
t.Fatalf("格式化失败: %v", err)
|
|
}
|
|
|
|
result := buf.String()
|
|
|
|
// 验证代码仍然是有效的 Go 代码
|
|
fset := token.NewFileSet()
|
|
_, err = parser.ParseFile(fset, "", result, parser.AllErrors)
|
|
if err != nil {
|
|
t.Errorf("回滚后代码语法错误: %v", err)
|
|
}
|
|
|
|
// 验证 gofmt 幂等性
|
|
formatted1, err := format.Source([]byte(result))
|
|
if err != nil {
|
|
t.Errorf("第一次格式化失败: %v", err)
|
|
return
|
|
}
|
|
formatted2, err := format.Source(formatted1)
|
|
if err != nil {
|
|
t.Errorf("第二次格式化失败: %v", err)
|
|
return
|
|
}
|
|
if string(formatted1) != string(formatted2) {
|
|
t.Error("回滚后代码 gofmt 不是幂等的")
|
|
}
|
|
}
|
|
|
|
// TestRollbackNonExistentProvider 测试回滚不存在的 Provider
|
|
func TestRollbackNonExistentProvider(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
testFile := filepath.Join(tempDir, "nonexistent.go")
|
|
|
|
initialContent := `package test
|
|
|
|
import (
|
|
"github.com/google/wire"
|
|
)
|
|
|
|
var ProviderSet = wire.NewSet(
|
|
NewExistingProvider,
|
|
)
|
|
`
|
|
if err := os.WriteFile(testFile, []byte(initialContent), 0644); err != nil {
|
|
t.Fatalf("创建测试文件失败: %v", err)
|
|
}
|
|
|
|
injector := autocode.NewWireProviderAst("biz", "test", "NewNonExistentProvider", testFile)
|
|
|
|
file, err := injector.Parse(testFile, nil)
|
|
if err != nil {
|
|
t.Fatalf("解析文件失败: %v", err)
|
|
}
|
|
|
|
// 回滚不存在的 Provider 不应该报错
|
|
if err := injector.Rollback(file); err != nil {
|
|
t.Errorf("回滚不存在的 Provider 不应该报错: %v", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := injector.Format(testFile, &buf, file); err != nil {
|
|
t.Fatalf("格式化失败: %v", err)
|
|
}
|
|
|
|
result := buf.String()
|
|
|
|
// 验证原有 Provider 仍然存在
|
|
if !strings.Contains(result, "NewExistingProvider") {
|
|
t.Error("回滚后应保留 NewExistingProvider")
|
|
}
|
|
}
|