kra/cmd/kra-gen/interactive.go

1049 lines
31 KiB
Go

package main
import (
"bufio"
"context"
"fmt"
"os"
"strconv"
"strings"
"gopkg.in/yaml.v3"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// DBConfig describes how to connect to the database whose schema is read.
// It is filled either from configs/config.yaml or from interactive prompts.
type DBConfig struct {
Type string `yaml:"type"` // one of "mysql", "pgsql" (alias "postgres"), "sqlite"
Host string `yaml:"host"` // database host (not used for sqlite)
Port string `yaml:"port"` // database port, kept as text (not used for sqlite)
DBName string `yaml:"dbName"` // database name; for sqlite this is the database file path
Username string `yaml:"username"` // login user
Password string `yaml:"password"` // login password
Config string `yaml:"config"` // extra DSN query parameters; only the mysql connection uses it
}
// InteractiveConfig is a YAML wrapper holding a single "database" section.
// NOTE(review): not referenced in the visible part of this file; confirm it is still used.
type InteractiveConfig struct {
Database DBConfig `yaml:"database"`
}
// DBMetadata reads schema information (databases, tables, columns) through a
// GORM connection.
type DBMetadata struct {
db *gorm.DB
dbType string // DBConfig.Type the connection was opened with; selects the per-dialect SQL
}
// DatabaseInfo is the row shape scanned by GetDatabases; the listing queries
// alias the database-name column to "database".
type DatabaseInfo struct {
Database string `gorm:"column:database"`
}
// TableInfo is the row shape scanned by GetTables; every dialect's query
// aliases the table-name column to "table_name".
type TableInfo struct {
TableName string `gorm:"column:table_name"`
}
// ColumnInfo describes one table column as returned by DBMetadata.GetColumns.
type ColumnInfo struct {
ColumnName string `gorm:"column:column_name"`
DataType string `gorm:"column:data_type"` // type name as reported by the database
DataTypeLong string `gorm:"column:data_type_long"` // length/precision detail as text (MySQL/PostgreSQL report e.g. "255" or "10,2"); may be empty
ColumnComment string `gorm:"column:column_comment"` // column comment, "" when none
PrimaryKey int `gorm:"column:primary_key"` // non-zero marks a primary-key column; configureFields treats only the value 1 as a key
}
// NewDBMetadata opens a connection used only for schema inspection.
//
// Supported config.Type values are "mysql", "pgsql"/"postgres" and "sqlite"
// (for sqlite, config.DBName is the database file). GORM logging is silenced
// because only metadata queries are issued. An unsupported type or a failure
// to open the connection is returned as an error.
func NewDBMetadata(config *DBConfig) (*DBMetadata, error) {
	gormConfig := &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	}

	var dialector gorm.Dialector
	switch config.Type {
	case "mysql":
		// Use the configured DSN parameters, or sane defaults when none are given.
		params := config.Config
		if params == "" {
			params = "charset=utf8mb4&parseTime=True&loc=Local"
		}
		dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s",
			config.Username, config.Password, config.Host, config.Port, config.DBName, params)
		dialector = mysql.Open(dsn)
	case "pgsql", "postgres":
		dialector = postgres.Open(postgresDSN(config))
	case "sqlite":
		dialector = sqlite.Open(config.DBName)
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", config.Type)
	}

	db, err := gorm.Open(dialector, gormConfig)
	if err != nil {
		return nil, fmt.Errorf("连接数据库失败: %w", err)
	}
	return &DBMetadata{db: db, dbType: config.Type}, nil
}

// postgresDSN renders config as a libpq keyword/value connection string.
//
// Values are single-quoted with backslashes and quotes escaped, so passwords
// containing spaces or quotes survive intact. Empty values are omitted so the
// driver uses its own defaults. An unquoted empty value ("password= dbname=x")
// would otherwise make the parser read the next keyword as the value.
func postgresDSN(config *DBConfig) string {
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	var b strings.Builder
	add := func(key, value string) {
		if value == "" {
			return
		}
		b.WriteString(key)
		b.WriteString("='")
		b.WriteString(escaper.Replace(value))
		b.WriteString("' ")
	}
	add("host", config.Host)
	add("port", config.Port)
	add("user", config.Username)
	add("password", config.Password)
	add("dbname", config.DBName)
	b.WriteString("sslmode=disable")
	return b.String()
}
// GetDatabases lists the databases visible on the current connection.
//
// MySQL reads INFORMATION_SCHEMA.SCHEMATA and PostgreSQL reads the
// non-template rows of pg_database. SQLite has one logical database per file,
// so the fixed name "main" is returned without running a query.
func (m *DBMetadata) GetDatabases(ctx context.Context) ([]string, error) {
	var query string
	switch m.dbType {
	case "sqlite":
		// One database per SQLite file.
		return []string{"main"}, nil
	case "mysql":
		query = "SELECT SCHEMA_NAME AS `database` FROM INFORMATION_SCHEMA.SCHEMATA"
	case "pgsql", "postgres":
		query = "SELECT datname as database FROM pg_database WHERE datistemplate = false"
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", m.dbType)
	}

	var rows []DatabaseInfo
	if err := m.db.WithContext(ctx).Raw(query).Scan(&rows).Error; err != nil {
		return nil, err
	}
	names := make([]string, 0, len(rows))
	for _, row := range rows {
		names = append(names, row.Database)
	}
	return names, nil
}
// GetTables lists the base tables of a database, sorted by name.
//
// MySQL: tables of schema dbName from information_schema. Views are excluded
// so the result matches the other backends (pg_tables and sqlite_master with
// type='table' never return views).
// PostgreSQL: tables of the "public" schema of the connected database; dbName
// is not used.
// SQLite: user tables of the open file, excluding internal sqlite_* tables;
// dbName is not used.
func (m *DBMetadata) GetTables(ctx context.Context, dbName string) ([]string, error) {
	var (
		query string
		args  []interface{}
	)
	switch m.dbType {
	case "mysql":
		query = "SELECT table_name AS table_name FROM information_schema.tables " +
			"WHERE table_schema = ? AND table_type = 'BASE TABLE' ORDER BY table_name"
		args = []interface{}{dbName}
	case "pgsql", "postgres":
		query = "SELECT tablename AS table_name FROM pg_tables WHERE schemaname = 'public' ORDER BY tablename"
	case "sqlite":
		query = "SELECT name AS table_name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", m.dbType)
	}

	var tables []TableInfo
	if err := m.db.WithContext(ctx).Raw(query, args...).Scan(&tables).Error; err != nil {
		return nil, err
	}
	result := make([]string, len(tables))
	for i, t := range tables {
		result[i] = t.TableName
	}
	return result, nil
}
// GetColumns returns the columns of tableName in ordinal order.
//
// MySQL: read from INFORMATION_SCHEMA.COLUMNS of schema dbName; the primary
// key comes from KEY_COLUMN_USAGE rows of the "PRIMARY" constraint.
// PostgreSQL: read from information_schema for the "public" schema of the
// connected database (dbName is not used); comments come from pg_description.
// SQLite: read via PRAGMA table_info (dbName is not used); SQLite has no
// column comments. The declared type is split into DataType ("VARCHAR") and
// DataTypeLong ("255"), mirroring what the other backends report.
//
// On every backend PrimaryKey is 1 for primary-key columns and 0 otherwise.
func (m *DBMetadata) GetColumns(ctx context.Context, dbName, tableName string) ([]ColumnInfo, error) {
	var columns []ColumnInfo
	switch m.dbType {
	case "mysql":
		const query = `
SELECT
	c.COLUMN_NAME column_name,
	c.DATA_TYPE data_type,
	CASE c.DATA_TYPE
		WHEN 'longtext' THEN c.CHARACTER_MAXIMUM_LENGTH
		WHEN 'varchar' THEN c.CHARACTER_MAXIMUM_LENGTH
		WHEN 'double' THEN CONCAT_WS(',', c.NUMERIC_PRECISION, c.NUMERIC_SCALE)
		WHEN 'decimal' THEN CONCAT_WS(',', c.NUMERIC_PRECISION, c.NUMERIC_SCALE)
		WHEN 'int' THEN c.NUMERIC_PRECISION
		WHEN 'bigint' THEN c.NUMERIC_PRECISION
		ELSE ''
	END AS data_type_long,
	c.COLUMN_COMMENT column_comment,
	CASE WHEN kcu.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS primary_key
FROM
	INFORMATION_SCHEMA.COLUMNS c
LEFT JOIN
	INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
ON
	c.TABLE_SCHEMA = kcu.TABLE_SCHEMA
	AND c.TABLE_NAME = kcu.TABLE_NAME
	AND c.COLUMN_NAME = kcu.COLUMN_NAME
	AND kcu.CONSTRAINT_NAME = 'PRIMARY'
WHERE
	c.TABLE_NAME = ?
	AND c.TABLE_SCHEMA = ?
ORDER BY
	c.ORDINAL_POSITION`
		if err := m.db.WithContext(ctx).Raw(query, tableName, dbName).Scan(&columns).Error; err != nil {
			return nil, err
		}
	case "pgsql", "postgres":
		// The primary-key subquery is restricted to the "public" schema and joins
		// key_column_usage on schema + constraint + table. Joining on constraint
		// name alone let a same-named table in another schema contribute wrong
		// (and duplicated) key columns.
		const query = `
SELECT
	c.column_name,
	c.data_type,
	COALESCE(c.character_maximum_length::text, c.numeric_precision::text, '') as data_type_long,
	COALESCE(pgd.description, '') as column_comment,
	CASE WHEN pk.column_name IS NOT NULL THEN 1 ELSE 0 END as primary_key
FROM
	information_schema.columns c
LEFT JOIN
	pg_catalog.pg_statio_all_tables st ON c.table_schema = st.schemaname AND c.table_name = st.relname
LEFT JOIN
	pg_catalog.pg_description pgd ON pgd.objoid = st.relid AND pgd.objsubid = c.ordinal_position
LEFT JOIN (
	SELECT ku.column_name
	FROM information_schema.table_constraints tc
	JOIN information_schema.key_column_usage ku
		ON ku.constraint_schema = tc.constraint_schema
		AND ku.constraint_name = tc.constraint_name
		AND ku.table_name = tc.table_name
	WHERE tc.constraint_type = 'PRIMARY KEY'
		AND tc.table_schema = 'public'
		AND tc.table_name = $1
) pk ON c.column_name = pk.column_name
WHERE
	c.table_name = $1 AND c.table_schema = 'public'
ORDER BY
	c.ordinal_position`
		if err := m.db.WithContext(ctx).Raw(query, tableName).Scan(&columns).Error; err != nil {
			return nil, err
		}
	case "sqlite":
		// PRAGMA arguments cannot be bound as parameters, so quote the name as an
		// identifier. An unquoted name breaks on spaces, keywords or quotes and
		// allows SQL injection through a crafted table name.
		query := fmt.Sprintf("PRAGMA table_info(%s)", quoteSQLiteIdentifier(tableName))
		var pragmaInfo []struct {
			CID       int    `gorm:"column:cid"`
			Name      string `gorm:"column:name"`
			Type      string `gorm:"column:type"`
			NotNull   int    `gorm:"column:notnull"`
			DfltValue string `gorm:"column:dflt_value"`
			PK        int    `gorm:"column:pk"`
		}
		if err := m.db.WithContext(ctx).Raw(query).Scan(&pragmaInfo).Error; err != nil {
			return nil, err
		}
		for _, info := range pragmaInfo {
			baseType, size := splitSQLiteColumnType(info.Type)
			// pk is the 1-based position inside a (possibly composite) key.
			// Normalise it, because configureFields treats only 1 as a key.
			pk := 0
			if info.PK > 0 {
				pk = 1
			}
			columns = append(columns, ColumnInfo{
				ColumnName:    info.Name,
				DataType:      baseType,
				DataTypeLong:  size,
				ColumnComment: "",
				PrimaryKey:    pk,
			})
		}
	default:
		return nil, fmt.Errorf("不支持的数据库类型: %s", m.dbType)
	}
	return columns, nil
}

// quoteSQLiteIdentifier wraps name in double quotes, doubling any embedded
// double quote, so it can be spliced into SQL as an identifier.
func quoteSQLiteIdentifier(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// splitSQLiteColumnType splits a declared SQLite column type such as
// "VARCHAR(255)" or "DECIMAL(10, 2)" into its base name ("VARCHAR") and the
// parenthesised size with spaces removed ("255", "10,2"). A type without
// parentheses is returned whole with an empty size.
func splitSQLiteColumnType(declared string) (base, size string) {
	declared = strings.TrimSpace(declared)
	open := strings.IndexByte(declared, '(')
	if open < 0 {
		return declared, ""
	}
	base = strings.TrimSpace(declared[:open])
	size = declared[open+1:]
	if end := strings.IndexByte(size, ')'); end >= 0 {
		size = size[:end]
	}
	return base, strings.ReplaceAll(size, " ", "")
}
// Close releases the connection pool behind the GORM handle. It returns the
// error from obtaining the underlying *sql.DB, or else the error from closing it.
func (m *DBMetadata) Close() error {
	sqlDB, err := m.db.DB()
	if err == nil {
		err = sqlDB.Close()
	}
	return err
}
// InteractiveSession holds the state of one interactive code-generation run:
// where answers are read from, the schema-metadata connection, and the
// configuration being assembled.
type InteractiveSession struct {
reader *bufio.Reader // source of user answers (standard input, see NewInteractiveSession)
metadata *DBMetadata // connection opened in step 1; nil until connectDatabase succeeds
config *AutoCodeConfig // result filled in by the steps and returned by Run
dbName string // currently selected database; starts as the configured DBName
}
// NewInteractiveSession creates a session that reads answers from standard
// input and starts from the default generation options below. Later steps
// overwrite most of these options with the user's answers.
func NewInteractiveSession() *InteractiveSession {
	defaults := AutoCodeOptions{
		GvaModel:            true,
		AutoMigrate:         true,
		AutoCreateApiToSql:  true,
		AutoCreateMenuToSql: true,
		AutoCreateBtnAuth:   true,
		GenerateWeb:         true,
		GenerateServer:      true,
		// AutoCreateResource, OnlyTemplate, IsTree and IsAdd start false and
		// TreeJson starts empty: these are their zero values.
	}
	return &InteractiveSession{
		reader: bufio.NewReader(os.Stdin),
		config: &AutoCodeConfig{Options: defaults},
	}
}
// Run drives the six interactive steps in order and returns the assembled
// configuration: connect, pick the database, pick the table, basic info,
// fields, and generation options. It stops at the first failing step and
// returns that step's error. The metadata connection opened in step 1 is
// closed before Run returns on every path.
func (s *InteractiveSession) Run() (*AutoCodeConfig, error) {
	fmt.Println("========================================")
	fmt.Println("KRA 代码生成 - 交互模式")
	fmt.Println("========================================")
	fmt.Println()

	// Close on every exit, not only on success. Previously a failure in any step
	// after the connection was made returned without closing it. The Close
	// error is deliberately ignored: this is best-effort cleanup of a read-only
	// metadata connection.
	defer func() {
		if s.metadata != nil {
			_ = s.metadata.Close()
		}
	}()

	steps := []func() error{
		s.configureDatabase,  // 1/6
		s.selectDatabase,     // 2/6
		s.selectTable,        // 3/6
		s.configureBasicInfo, // 4/6
		s.configureFields,    // 5/6
		s.configureOptions,   // 6/6
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return nil, err
		}
	}
	return s.config, nil
}
// configureDatabase is step 1. It opens the metadata connection from
// configs/config.yaml when that file exists and the user agrees to use it;
// otherwise it prompts for the connection details.
func (s *InteractiveSession) configureDatabase() error {
	fmt.Println("步骤 1/6: 配置数据库连接")
	fmt.Println("----------------------------------------")

	const configPath = "configs/config.yaml"
	if _, statErr := os.Stat(configPath); statErr != nil {
		return s.manualDatabaseConfig()
	}
	fmt.Printf("检测到配置文件: %s\n", configPath)
	if !s.promptConfirm("是否使用配置文件中的数据库连接?") {
		return s.manualDatabaseConfig()
	}
	return s.loadDatabaseFromConfig(configPath)
}
// loadDatabaseFromConfig parses configPath as YAML and connects using the
// first complete section, checked in the order mysql, pgsql, sqlite.
// mysql/pgsql need both "path" (host) and "db_name"; sqlite needs "db_name",
// which is used as the database file. If no section qualifies, it falls back
// to manualDatabaseConfig. Read and parse failures are returned as errors.
// NOTE(review): keys are snake_case ("db_name"); confirm they match the
// project's configs/config.yaml.
func (s *InteractiveSession) loadDatabaseFromConfig(configPath string) error {
	raw, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("读取配置文件失败: %w", err)
	}

	var file struct {
		Mysql struct {
			Path     string `yaml:"path"`
			Port     string `yaml:"port"`
			DBName   string `yaml:"db_name"`
			Username string `yaml:"username"`
			Password string `yaml:"password"`
			Config   string `yaml:"config"`
		} `yaml:"mysql"`
		Pgsql struct {
			Path     string `yaml:"path"`
			Port     string `yaml:"port"`
			DBName   string `yaml:"db_name"`
			Username string `yaml:"username"`
			Password string `yaml:"password"`
		} `yaml:"pgsql"`
		Sqlite struct {
			DBName string `yaml:"db_name"`
		} `yaml:"sqlite"`
	}
	if err := yaml.Unmarshal(raw, &file); err != nil {
		return fmt.Errorf("解析配置文件失败: %w", err)
	}

	my, pg, lite := file.Mysql, file.Pgsql, file.Sqlite
	var target *DBConfig
	switch {
	case my.Path != "" && my.DBName != "":
		target = &DBConfig{
			Type:     "mysql",
			Host:     my.Path,
			Port:     my.Port,
			DBName:   my.DBName,
			Username: my.Username,
			Password: my.Password,
			Config:   my.Config,
		}
		fmt.Printf("使用 MySQL 数据库: %s@%s:%s/%s\n", target.Username, target.Host, target.Port, target.DBName)
	case pg.Path != "" && pg.DBName != "":
		target = &DBConfig{
			Type:     "pgsql",
			Host:     pg.Path,
			Port:     pg.Port,
			DBName:   pg.DBName,
			Username: pg.Username,
			Password: pg.Password,
		}
		fmt.Printf("使用 PostgreSQL 数据库: %s@%s:%s/%s\n", target.Username, target.Host, target.Port, target.DBName)
	case lite.DBName != "":
		target = &DBConfig{Type: "sqlite", DBName: lite.DBName}
		fmt.Printf("使用 SQLite 数据库: %s\n", target.DBName)
	default:
		fmt.Println("配置文件中未找到有效的数据库配置,请手动配置")
		return s.manualDatabaseConfig()
	}
	return s.connectDatabase(target)
}
// manualDatabaseConfig asks for the database type and connection details on
// the terminal, then connects. MySQL and PostgreSQL prompt for host, port,
// user, password and database name, with defaults for the first three.
// SQLite prompts only for the database file path. An unknown menu choice is
// returned as an error without connecting.
func (s *InteractiveSession) manualDatabaseConfig() error {
	for _, line := range []string{
		"\n请选择数据库类型:",
		" 1. MySQL",
		" 2. PostgreSQL",
		" 3. SQLite",
	} {
		fmt.Println(line)
	}
	choice := s.promptInput("请输入选项 (1-3): ")

	// askServer collects the network settings shared by MySQL and PostgreSQL.
	askServer := func(dbType, defaultPort, defaultUser string) *DBConfig {
		cfg := &DBConfig{Type: dbType}
		cfg.Host = s.promptInputWithDefault("数据库主机", "127.0.0.1")
		cfg.Port = s.promptInputWithDefault("数据库端口", defaultPort)
		cfg.Username = s.promptInputWithDefault("用户名", defaultUser)
		cfg.Password = s.promptInput("密码: ")
		cfg.DBName = s.promptInput("数据库名: ")
		return cfg
	}

	var target *DBConfig
	switch choice {
	case "1":
		target = askServer("mysql", "3306", "root")
	case "2":
		target = askServer("pgsql", "5432", "postgres")
	case "3":
		target = &DBConfig{Type: "sqlite", DBName: s.promptInput("数据库文件路径: ")}
	default:
		return fmt.Errorf("无效的选项: %s", choice)
	}
	return s.connectDatabase(target)
}
// connectDatabase opens the metadata connection described by config and
// reports the outcome on the terminal. On success it stores the connection
// and remembers config.DBName as the current database; on failure it returns
// NewDBMetadata's error unchanged.
func (s *InteractiveSession) connectDatabase(config *DBConfig) error {
	fmt.Print("正在连接数据库...")
	md, err := NewDBMetadata(config)
	if err == nil {
		fmt.Println(" 成功!")
		s.metadata, s.dbName = md, config.DBName
		return nil
	}
	fmt.Println(" 失败!")
	return err
}
// selectDatabase is step 2. It lists the available databases, marks the
// connected one, and lets the user pick one by number (Enter accepts the
// default). The choice becomes both the session's current database and
// config.BusinessDB.
//
// NOTE(review): for PostgreSQL the connection is not reopened, so GetTables
// and GetColumns still read the originally connected database whichever
// entry is chosen here.
func (s *InteractiveSession) selectDatabase() error {
	fmt.Println("\n步骤 2/6: 选择数据库")
	fmt.Println("----------------------------------------")

	databases, err := s.metadata.GetDatabases(context.Background())
	if err != nil {
		return fmt.Errorf("获取数据库列表失败: %w", err)
	}
	if len(databases) == 0 {
		return fmt.Errorf("未找到任何数据库")
	}

	// Default to the connected database. If it is not listed, fall back to the
	// first entry instead of 0: 0 is never a valid choice, so pressing Enter
	// always failed. This always happened for SQLite, where dbName is a file
	// path while the list is just "main".
	defaultIdx := 1
	foundCurrent := false
	fmt.Println("可用数据库:")
	for i, db := range databases {
		marker := ""
		if db == s.dbName {
			marker = " (当前)"
			if !foundCurrent {
				defaultIdx, foundCurrent = i+1, true
			}
		}
		fmt.Printf(" %d. %s%s\n", i+1, db, marker)
	}

	choice := s.promptInputWithDefault("请选择数据库", strconv.Itoa(defaultIdx))
	idx, err := strconv.Atoi(choice)
	if err != nil || idx < 1 || idx > len(databases) {
		return fmt.Errorf("无效的选项: %s", choice)
	}
	s.dbName = databases[idx-1]
	s.config.BusinessDB = s.dbName
	fmt.Printf("已选择数据库: %s\n", s.dbName)
	return nil
}
// selectTable is step 3. It lists the tables of the current database and
// stores the one the user picks by number in config.TableName. A lookup
// failure, an empty database, or an invalid number is returned as an error.
func (s *InteractiveSession) selectTable() error {
	fmt.Println("\n步骤 3/6: 选择数据表")
	fmt.Println("----------------------------------------")

	tables, err := s.metadata.GetTables(context.Background(), s.dbName)
	switch {
	case err != nil:
		return fmt.Errorf("获取表列表失败: %w", err)
	case len(tables) == 0:
		return fmt.Errorf("数据库 %s 中未找到任何表", s.dbName)
	}

	fmt.Printf("数据库 %s 中的表:\n", s.dbName)
	for n, name := range tables {
		fmt.Printf(" %d. %s\n", n+1, name)
	}

	answer := s.promptInput("请选择表 (输入序号): ")
	pick, convErr := strconv.Atoi(answer)
	if convErr != nil || pick < 1 || pick > len(tables) {
		return fmt.Errorf("无效的选项: %s", answer)
	}
	s.config.TableName = tables[pick-1]
	fmt.Printf("已选择表: %s\n", s.config.TableName)
	return nil
}
// configureBasicInfo is step 4. It prompts for the struct name, package name,
// abbreviation and description, offering defaults derived from the selected
// table name (e.g. "user_info" -> "UserInfo", "user_info", "userInfo").
// It always returns nil.
func (s *InteractiveSession) configureBasicInfo() error {
	fmt.Println("\n步骤 4/6: 配置基本信息")
	fmt.Println("----------------------------------------")

	tableName := s.config.TableName
	defaultStructName := toCamelCase(tableName)
	defaultPackage := strings.ToLower(tableName)
	// Same value the old code built by slicing defaultStructName[0], but
	// without panicking when the camel-cased name is empty (e.g. table "_").
	defaultAbbr := toLowerCamelCase(tableName)

	s.config.StructName = s.promptInputWithDefault("结构体名称", defaultStructName)
	s.config.Package = s.promptInputWithDefault("包名", defaultPackage)
	s.config.Abbreviation = s.promptInputWithDefault("缩写", defaultAbbr)
	s.config.Description = s.promptInputWithDefault("功能描述", s.config.StructName+"管理")
	return nil
}
// configureFields is step 5. It reads the selected table's columns and turns
// every column not matched by isSystemField into an AutoCodeField with
// default display and validation flags. It prints the resulting list and then
// optionally walks through per-field detail configuration.
func (s *InteractiveSession) configureFields() error {
	fmt.Println("\n步骤 5/6: 配置字段")
	fmt.Println("----------------------------------------")

	table := s.config.TableName
	columns, err := s.metadata.GetColumns(context.Background(), s.dbName, table)
	if err != nil {
		return fmt.Errorf("获取列信息失败: %w", err)
	}
	if len(columns) == 0 {
		return fmt.Errorf("表 %s 中未找到任何列", table)
	}

	fmt.Printf("表 %s 的字段:\n", table)
	fmt.Println("----------------------------------------")

	fields := make([]AutoCodeField, 0, len(columns))
	for _, col := range columns {
		name := col.ColumnName
		// Skip columns such as id/created_at; presumably the generated base
		// model supplies them. TODO confirm.
		if isSystemField(name) {
			fmt.Printf(" [跳过] %s (系统字段)\n", name)
			continue
		}

		pk := col.PrimaryKey == 1
		goName := toCamelCase(name)
		desc := col.ColumnComment
		if desc == "" {
			desc = goName // no column comment: describe the field by its Go name
		}
		f := AutoCodeField{
			FieldName:    goName,
			FieldDesc:    desc,
			FieldType:    mapDBTypeToGoType(col.DataType),
			FieldJson:    toLowerCamelCase(name),
			ColumnName:   name,
			DataTypeLong: formatDataTypeLong(col.DataType, col.DataTypeLong),
			Form:         true,
			Table:        true,
			Desc:         true,  // shown on the detail page by default
			Excel:        false, // not exported by default
			Require:      pk,
			Clearable:    !pk, // only non-key fields are clearable by default
			PrimaryKey:   pk,
		}
		fmt.Printf(" %s (%s) - %s\n", f.FieldName, f.FieldType, f.FieldDesc)
		fields = append(fields, f)
	}
	s.config.Fields = fields

	fmt.Println("----------------------------------------")
	fmt.Printf("共 %d 个字段\n", len(s.config.Fields))

	if !s.promptConfirm("是否需要详细配置每个字段?") {
		return nil
	}
	return s.configureFieldsDetail()
}
// configureFieldsDetail walks through s.config.Fields in order and, for each
// field, prompts for display, validation, default value, search, sort,
// dictionary, data source and index settings, updating the field in place.
// It always returns nil; a failed data-source configuration is only printed
// and the walk continues.
func (s *InteractiveSession) configureFieldsDetail() error {
for i := range s.config.Fields {
// Take a pointer so the assignments below modify the slice element itself.
field := &s.config.Fields[i]
fmt.Printf("\n========================================\n")
fmt.Printf("配置字段: %s (%s)\n", field.FieldName, field.FieldDesc)
fmt.Printf("类型: %s, 列名: %s\n", field.FieldType, field.ColumnName)
fmt.Printf("========================================\n")
// Display settings; each prompt defaults to the field's current value.
fmt.Println("\n【显示配置】")
field.Form = s.promptConfirmWithDefault(" 显示在表单中?", field.Form)
field.Table = s.promptConfirmWithDefault(" 显示在表格中?", field.Table)
field.Desc = s.promptConfirmWithDefault(" 显示在详情页?", field.Desc)
field.Excel = s.promptConfirmWithDefault(" 支持导入/导出?", field.Excel)
// Validation settings.
fmt.Println("\n【验证配置】")
field.Require = s.promptConfirmWithDefault(" 是否必填?", field.Require)
if field.Require {
// An empty answer leaves ErrorText unchanged.
errorText := s.promptInput(" 校验失败提示文字 (留空使用默认): ")
if errorText != "" {
field.ErrorText = errorText
}
}
field.Clearable = s.promptConfirmWithDefault(" 是否可清空?", field.Clearable)
// Default value; an empty answer leaves DefaultValue unchanged.
defaultValue := s.promptInput(" 默认值 (留空跳过): ")
if defaultValue != "" {
field.DefaultValue = defaultValue
}
// Search settings.
fmt.Println("\n【搜索配置】")
if s.promptConfirm(" 启用搜索?") {
fmt.Println(" 搜索类型:")
fmt.Println(" 1. LIKE (模糊匹配)")
fmt.Println(" 2. EQ (精确匹配)")
fmt.Println(" 3. BETWEEN (范围)")
fmt.Println(" 4. GT (大于)")
fmt.Println(" 5. GTE (大于等于)")
fmt.Println(" 6. LT (小于)")
fmt.Println(" 7. LTE (小于等于)")
fmt.Println(" 8. NEQ (不等于)")
searchChoice := s.promptInputWithDefault(" 请选择搜索类型", "1")
searchTypes := []string{"LIKE", "EQ", "BETWEEN", "GT", "GTE", "LT", "LTE", "NEQ"}
// A parse error yields 0, which fails the range check, so an invalid answer
// leaves FieldSearchType unchanged.
idx, _ := strconv.Atoi(searchChoice)
if idx >= 1 && idx <= len(searchTypes) {
field.FieldSearchType = searchTypes[idx-1]
}
field.FieldSearchHide = s.promptConfirm(" 隐藏搜索条件 (高级搜索)?")
}
// Sorting.
field.Sort = s.promptConfirm(" 启用排序?")
// Dictionary type; an empty answer leaves DictType unchanged.
fmt.Println("\n【关联配置】")
dictType := s.promptInput(" 字典类型 (留空跳过): ")
if dictType != "" {
field.DictType = dictType
}
// Data source (link to another table). Errors are reported, not returned.
if s.promptConfirm(" 配置数据源 (关联其他表)?") {
if err := s.configureDataSource(field); err != nil {
fmt.Printf(" 数据源配置失败: %v\n", err)
}
}
// Index settings; an invalid choice leaves FieldIndexType unchanged.
fmt.Println("\n【索引配置】")
if s.promptConfirm(" 配置索引?") {
fmt.Println(" 索引类型:")
fmt.Println(" 1. index (普通索引)")
fmt.Println(" 2. unique (唯一索引)")
indexChoice := s.promptInputWithDefault(" 请选择索引类型", "1")
indexTypes := []string{"index", "unique"}
idx, _ := strconv.Atoi(indexChoice)
if idx >= 1 && idx <= len(indexTypes) {
field.FieldIndexType = indexTypes[idx-1]
}
}
}
return nil
}
// configureDataSource links field to another table in the current database.
// The user picks that table, its value column (default: first column), its
// label column (default: second column, or the only one), the association
// kind (1 = one-to-one, 2 = one-to-many), and whether it uses soft deletes.
// It returns an error when listing tables or columns fails, when the table
// choice is invalid, or when the chosen table has no columns. Invalid column
// numbers silently fall back to the defaults.
func (s *InteractiveSession) configureDataSource(field *AutoCodeField) error {
	fmt.Println("\n 【数据源配置】")

	ctx := context.Background()
	tables, err := s.metadata.GetTables(ctx, s.dbName)
	if err != nil {
		return fmt.Errorf("获取表列表失败: %w", err)
	}
	fmt.Println(" 可用的表:")
	for i, table := range tables {
		fmt.Printf(" %d. %s\n", i+1, table)
	}
	tableChoice := s.promptInput(" 请选择关联表 (输入序号): ")
	idx, err := strconv.Atoi(tableChoice)
	if err != nil || idx < 1 || idx > len(tables) {
		return fmt.Errorf("无效的选项: %s", tableChoice)
	}
	selectedTable := tables[idx-1]

	columns, err := s.metadata.GetColumns(ctx, s.dbName, selectedTable)
	if err != nil {
		return fmt.Errorf("获取列信息失败: %w", err)
	}
	// Every pick below indexes columns; with no columns it would panic.
	if len(columns) == 0 {
		return fmt.Errorf("表 %s 中未找到任何列", selectedTable)
	}
	fmt.Printf(" 表 %s 的字段:\n", selectedTable)
	for i, col := range columns {
		fmt.Printf(" %d. %s (%s)\n", i+1, col.ColumnName, col.DataType)
	}

	// pickColumn prompts for a 1-based column number. An unparsable or
	// out-of-range answer falls back to fallback, which must be in range.
	pickColumn := func(prompt string, fallback int) string {
		choice := s.promptInputWithDefault(prompt, strconv.Itoa(fallback))
		n, convErr := strconv.Atoi(choice)
		if convErr != nil || n < 1 || n > len(columns) {
			n = fallback
		}
		return columns[n-1].ColumnName
	}

	valueField := pickColumn(" 选择 Value 字段 (通常是 ID)", 1)
	// The label defaults to the second column. A single-column table uses its
	// only column; the old fixed fallback of 2 indexed past the end and panicked.
	labelDefault := 2
	if len(columns) < labelDefault {
		labelDefault = len(columns)
	}
	labelField := pickColumn(" 选择 Label 字段 (显示名称)", labelDefault)

	fmt.Println(" 关联关系:")
	fmt.Println(" 1. 一对一")
	fmt.Println(" 2. 一对多")
	assocChoice := s.promptInputWithDefault(" 请选择关联关系", "1")
	association := 1
	if assocChoice == "2" {
		association = 2
	}

	hasDeletedAt := s.promptConfirm(" 关联表是否有软删除 (deleted_at)?")

	field.DataSource = &DataSource{
		DBName:       s.dbName,
		Table:        selectedTable,
		Value:        valueField,
		Label:        labelField,
		Association:  association,
		HasDeletedAt: hasDeletedAt,
	}
	fmt.Printf(" 数据源配置完成: %s.%s (Label: %s, Value: %s)\n",
		s.dbName, selectedTable, labelField, valueField)
	return nil
}
// configureOptions is step 6. It prompts for the generation options: base
// model, auto-migration, backend/frontend output, template-only and append
// modes, automatic registration (unless template-only), and tree-structure
// settings. It then prints a summary and always returns nil.
func (s *InteractiveSession) configureOptions() error {
	fmt.Println("\n步骤 6/6: 配置生成选项")
	fmt.Println("----------------------------------------")

	opts := &s.config.Options

	fmt.Println("\n【基础配置】")
	opts.GvaModel = s.promptConfirmWithDefault("使用 KRA_MODEL (包含 ID, CreatedAt, UpdatedAt, DeletedAt)?", true)
	opts.AutoMigrate = s.promptConfirmWithDefault("自动迁移数据库表结构?", true)

	fmt.Println("\n【生成范围】")
	opts.GenerateServer = s.promptConfirmWithDefault("生成后端代码?", true)
	opts.GenerateWeb = s.promptConfirmWithDefault("生成前端代码?", true)

	fmt.Println("\n【高级选项】")
	opts.OnlyTemplate = s.promptConfirmWithDefault("仅生成模板 (不注入代码到现有文件)?", false)
	opts.IsAdd = s.promptConfirmWithDefault("追加模式 (不覆盖已有文件)?", false)

	if !opts.OnlyTemplate {
		fmt.Println("\n【自动注册配置】")
		opts.AutoCreateApiToSql = s.promptConfirmWithDefault("自动创建 API 记录?", true)
		opts.AutoCreateMenuToSql = s.promptConfirmWithDefault("自动创建菜单记录?", true)
		// Button permissions are only offered when menus are created, so treat
		// them as dependent on the menu. Otherwise the constructor default
		// (true) survived and the summary claimed buttons would be created
		// without a menu.
		// NOTE(review): confirm the generator relies on this dependency.
		if opts.AutoCreateMenuToSql {
			opts.AutoCreateBtnAuth = s.promptConfirmWithDefault("自动创建按钮权限?", true)
		} else {
			opts.AutoCreateBtnAuth = false
		}
		opts.AutoCreateResource = s.promptConfirmWithDefault("自动创建资源标识?", false)
	}

	fmt.Println("\n【树形结构配置】")
	opts.IsTree = s.promptConfirm("是否为树形结构数据?")
	if opts.IsTree {
		fmt.Println(" 可用字段:")
		for i, field := range s.config.Fields {
			fmt.Printf(" %d. %s (%s)\n", i+1, field.FieldJson, field.FieldDesc)
		}
		treeJsonChoice := s.promptInput(" 选择树形展示字段 (输入序号): ")
		idx, err := strconv.Atoi(treeJsonChoice)
		if err == nil && idx >= 1 && idx <= len(s.config.Fields) {
			opts.TreeJson = s.config.Fields[idx-1].FieldJson
		} else {
			// Invalid or empty answer: pick a sensible default field.
			opts.TreeJson = s.findDefaultTreeField()
		}
		fmt.Printf(" 树形展示字段: %s\n", opts.TreeJson)
	}

	s.printOptionsSummary()
	fmt.Println("\n========================================")
	fmt.Println("配置完成!")
	fmt.Println("========================================")
	return nil
}
// findDefaultTreeField picks the JSON name of the field to display in a tree.
// Preference order: the first field whose JSON name is "name" or "title"
// (case-insensitive), then the first field of Go type string, then the first
// field. With no fields at all it returns "name".
func (s *InteractiveSession) findDefaultTreeField() string {
	fields := s.config.Fields
	if len(fields) == 0 {
		return "name"
	}
	firstString, haveString := "", false
	for _, f := range fields {
		switch lower := strings.ToLower(f.FieldJson); {
		case lower == "name" || lower == "title":
			return f.FieldJson
		case !haveString && f.FieldType == "string":
			firstString, haveString = f.FieldJson, true
		}
	}
	if haveString {
		return firstString
	}
	return fields[0].FieldJson
}
// printOptionsSummary prints the collected configuration: basic info, the
// generation flags, the auto-registration flags (unless template-only), and
// the tree display field (tree mode only).
func (s *InteractiveSession) printOptionsSummary() {
	const rule = "========================================"
	cfg := s.config
	opts := cfg.Options

	type flag struct {
		label string
		on    bool
	}
	printFlags := func(title string, flags []flag) {
		fmt.Println(title)
		for _, f := range flags {
			fmt.Printf("%s: %s\n", f.label, boolToYesNo(f.on))
		}
	}

	fmt.Println("\n" + rule)
	fmt.Println("配置摘要")
	fmt.Println(rule)

	fmt.Println("\n【基本信息】")
	fmt.Printf(" 结构体名称: %s\n", cfg.StructName)
	fmt.Printf(" 包名: %s\n", cfg.Package)
	fmt.Printf(" 表名: %s\n", cfg.TableName)
	fmt.Printf(" 描述: %s\n", cfg.Description)
	fmt.Printf(" 字段数量: %d\n", len(cfg.Fields))

	printFlags("\n【生成选项】", []flag{
		{" 使用 KRA_MODEL", opts.GvaModel},
		{" 自动迁移", opts.AutoMigrate},
		{" 生成后端", opts.GenerateServer},
		{" 生成前端", opts.GenerateWeb},
		{" 仅生成模板", opts.OnlyTemplate},
		{" 追加模式", opts.IsAdd},
	})
	if !opts.OnlyTemplate {
		printFlags("\n【自动注册】", []flag{
			{" 创建 API 记录", opts.AutoCreateApiToSql},
			{" 创建菜单记录", opts.AutoCreateMenuToSql},
			{" 创建按钮权限", opts.AutoCreateBtnAuth},
			{" 创建资源标识", opts.AutoCreateResource},
		})
	}
	if opts.IsTree {
		fmt.Println("\n【树形结构】")
		fmt.Printf(" 树形展示字段: %s\n", opts.TreeJson)
	}
}
// boolToYesNo renders b as the Chinese "是" (yes) or "否" (no) for summaries.
func boolToYesNo(b bool) string {
	answer := "否"
	if b {
		answer = "是"
	}
	return answer
}
// Helper functions.

// promptInput prints prompt verbatim and returns the next input line with
// surrounding whitespace trimmed.
func (s *InteractiveSession) promptInput(prompt string) string {
	fmt.Print(prompt)
	// Best effort: on a read error (e.g. EOF), use whatever was read, possibly "".
	line, _ := s.reader.ReadString('\n')
	return strings.TrimSpace(line)
}
// promptInputWithDefault prints "prompt [default]: " and returns the trimmed
// answer, or defaultValue when the answer is empty.
func (s *InteractiveSession) promptInputWithDefault(prompt, defaultValue string) string {
	fmt.Printf("%s [%s]: ", prompt, defaultValue)
	// Best effort: a read error leaves a (possibly empty) partial line.
	line, _ := s.reader.ReadString('\n')
	if answer := strings.TrimSpace(line); answer != "" {
		return answer
	}
	return defaultValue
}
// promptConfirm asks a yes/no question and reports true only for the answers
// "y" or "yes" (case-insensitive). Any other answer, including empty input,
// means no.
func (s *InteractiveSession) promptConfirm(prompt string) bool {
	fmt.Printf("%s (y/n): ", prompt)
	line, _ := s.reader.ReadString('\n') // best effort; see promptInput
	switch strings.ToLower(strings.TrimSpace(line)) {
	case "y", "yes":
		return true
	}
	return false
}
// promptConfirmWithDefault asks a yes/no question showing the default as
// [y] or [n]. Empty input returns defaultValue; "y"/"yes" (case-insensitive)
// returns true; anything else returns false.
func (s *InteractiveSession) promptConfirmWithDefault(prompt string, defaultValue bool) bool {
	hint := "n"
	if defaultValue {
		hint = "y"
	}
	fmt.Printf("%s (y/n) [%s]: ", prompt, hint)
	line, _ := s.reader.ReadString('\n') // best effort; see promptInput
	switch strings.ToLower(strings.TrimSpace(line)) {
	case "":
		return defaultValue
	case "y", "yes":
		return true
	default:
		return false
	}
}
// toCamelCase converts a snake_case name to UpperCamelCase
// ("user_name" -> "UserName"). Empty segments from leading, trailing or
// repeated underscores are dropped.
//
// Only the first character (rune, not byte) of each segment is upper-cased,
// so multi-byte UTF-8 names stay intact. The previous byte-based version
// turned "é" into two unrelated characters.
func toCamelCase(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	for _, part := range strings.Split(s, "_") {
		if part == "" {
			continue
		}
		// Ranging over a string yields rune start offsets. The second offset is
		// the byte width of the first rune; a one-rune part keeps len(part).
		head := len(part)
		for offset := range part {
			if offset > 0 {
				head = offset
				break
			}
		}
		b.WriteString(strings.ToUpper(part[:head]))
		b.WriteString(part[head:])
	}
	return b.String()
}
// toLowerCamelCase converts a snake_case name to lowerCamelCase
// ("user_name" -> "userName"). It lower-cases the first character (rune, not
// byte) of toCamelCase's result, so multi-byte UTF-8 names are not corrupted.
func toLowerCamelCase(s string) string {
	camel := toCamelCase(s)
	// The second rune start offset is the byte width of the first rune.
	for offset := range camel {
		if offset > 0 {
			return strings.ToLower(camel[:offset]) + camel[offset:]
		}
	}
	// Zero or one rune: lower-case the whole (possibly empty) string.
	return strings.ToLower(camel)
}
// isSystemField reports whether a column name (case-insensitive) is one of
// the bookkeeping columns skipped during field generation: the id, the
// created/updated/deleted timestamps (both *_at and *_time spellings), and
// the *_by audit columns.
func isSystemField(name string) bool {
	switch strings.ToLower(name) {
	case "id",
		"created_at", "updated_at", "deleted_at",
		"create_time", "update_time", "delete_time",
		"created_by", "updated_by", "deleted_by":
		return true
	}
	return false
}
// mapDBTypeToGoType maps a column type name reported by the database to the
// Go type used in generated code. Unknown types map to "string".
//
// Matching is case-insensitive and ignores surrounding whitespace and any
// parenthesised size suffix ("VARCHAR(255)", "decimal(10,2)"). It also
// recognises the multi-word names PostgreSQL's information_schema reports
// ("timestamp without time zone", "double precision", ...), which previously
// fell through to "string".
func mapDBTypeToGoType(dbType string) string {
	t := strings.ToLower(strings.TrimSpace(dbType))
	if open := strings.IndexByte(t, '('); open >= 0 {
		t = strings.TrimSpace(t[:open])
	}
	switch t {
	case "int", "integer", "smallint", "mediumint", "tinyint":
		return "int"
	case "bigint":
		return "int64"
	case "float", "double", "double precision", "decimal", "numeric", "real":
		return "float64"
	case "bool", "boolean":
		return "bool"
	case "date", "datetime", "timestamp", "time",
		"timestamp without time zone", "timestamp with time zone",
		"time without time zone", "time with time zone":
		return "time.Time"
	case "json", "jsonb":
		return "datatypes.JSON"
	case "text", "longtext", "mediumtext", "tinytext", "varchar", "char",
		"character varying", "character":
		return "string"
	case "blob", "longblob", "mediumblob", "tinyblob", "bytea":
		return "[]byte"
	default:
		return "string"
	}
}
// formatDataTypeLong combines a type name with its size/precision detail as
// "type(detail)". It returns the bare type when the detail is empty, "0", or
// the literal "<nil>".
func formatDataTypeLong(dataType, dataTypeLong string) string {
	switch dataTypeLong {
	case "", "0", "<nil>":
		// No meaningful size detail.
		return dataType
	default:
		return fmt.Sprintf("%s(%s)", dataType, dataTypeLong)
	}
}
// runInteractiveModeWithDB runs a complete interactive session on standard
// input. It returns the resulting code-generation config, or the first error
// from the session.
func runInteractiveModeWithDB() (*AutoCodeConfig, error) {
	return NewInteractiveSession().Run()
}