225 lines
6.6 KiB
Go
225 lines
6.6 KiB
Go
package system
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"sort"
|
||
|
||
"github.com/go-kratos/kratos/v2/log"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Database type identifiers accepted in InitDBRequest.DBType; InitDB switches
// on these exact strings to pick a TypedDBInitHandler.
const (
	Mysql  = "mysql"
	Pgsql  = "pgsql"
	Sqlite = "sqlite"
	Mssql  = "mssql"
)

// Printf-style templates for initialization progress messages (the text is
// Chinese). In order: "initial data succeeded", "initial data for %v already
// exists", "initial data for %v failed" followed by the error, and "initial
// data for %v succeeded". The arguments are supplied by callers elsewhere;
// presumably the leading [%v] is a tag such as the DB type — TODO confirm.
const (
	InitSuccess     = "\n[%v] --> 初始数据成功!\n"
	InitDataExist   = "\n[%v] --> %v 的初始数据已存在!\n"
	InitDataFailed  = "\n[%v] --> %v 初始数据失败! \nerr: %+v\n"
	InitDataSuccess = "\n[%v] --> %v 初始数据成功!\n"
)
|
||
|
||
// Order bands for RegisterInit. After sorting, initializers with smaller
// order values run first. Each band is two orders of magnitude above the
// previous one, which leaves room to slot initializers in between. The names
// suggest system-level, internal, and external (presumably plugin/third-party)
// initializers — TODO confirm the intended split.
const (
	InitOrderSystem   = 10
	InitOrderInternal = InitOrderSystem * 100
	InitOrderExternal = InitOrderInternal * 100
)
|
||
|
||
// Sentinel errors for the database initialization flow; compare with errors.Is.
var (
	// ErrMissingDBContext reports that no database handle was found in the context.
	ErrMissingDBContext = errors.New("missing db in context")

	// ErrMissingDependentContext reports that a value an initializer depends
	// on was not found in the context.
	ErrMissingDependentContext = errors.New("missing dependent value in context")

	// ErrDBTypeMismatch reports a database type that does not match what was expected.
	ErrDBTypeMismatch = errors.New("db type mismatch")

	// ErrNoInitializers is returned by InitDB when nothing is registered.
	// The message (Chinese) says: "no initialization procedures available,
	// check whether initialization has already completed".
	ErrNoInitializers = errors.New("无可用初始化过程,请检查初始化是否已执行完成")
)
|
||
|
||
// InitDBRequest is the payload for a database initialization request.
// AdminPassword and DBName are marked binding:"required"; the remaining
// fields describe the connection, and some only apply to particular DBType
// values (DBPath for sqlite, Template for postgresql).
type InitDBRequest struct {
	// AdminPassword is exposed to initializers by InitDB under the
	// "adminPassword" context key; presumably it becomes the seeded admin
	// user's password — TODO confirm.
	AdminPassword string `json:"adminPassword" binding:"required"`

	// DBType is the database type ("mysql", "pgsql", "sqlite", "mssql").
	DBType string `json:"dbType"`

	// Host is the database server address.
	Host string `json:"host"`

	// Port is the database connection port.
	Port string `json:"port"`

	// UserName is the database user name.
	UserName string `json:"userName"`

	// Password is the database password.
	Password string `json:"password"`

	// DBName is the name of the database.
	DBName string `json:"dbName" binding:"required"`

	// DBPath is the sqlite database file path.
	DBPath string `json:"dbPath"`

	// Template is the template database to use with postgresql.
	Template string `json:"template"`
}
|
||
|
||
// SubInitializer is the interface implemented by the initializers registered
// from source/*/init(); each implementation performs one initialization step.
type SubInitializer interface {
	// InitializerName identifies the step. It does not necessarily map to a
	// single table, hence the broader name. Names must be unique:
	// RegisterInit panics on duplicates.
	InitializerName() string

	// TableCreated reports whether this step's tables already exist;
	// createTables skips MigrateTable when it returns true.
	TableCreated(ctx context.Context) bool

	// MigrateTable creates this step's tables. The returned context is
	// handed to the initializers that run after it.
	MigrateTable(ctx context.Context) (next context.Context, err error)

	// DataInserted reports whether this step's data is already present;
	// presumably used to skip InitializeData — TODO confirm in the handlers.
	DataInserted(ctx context.Context) bool

	// InitializeData inserts this step's initial data and returns the
	// context for subsequent steps.
	InitializeData(ctx context.Context) (next context.Context, err error)
}
|
||
|
||
// TypedDBInitHandler 执行传入的 initializer
|
||
type TypedDBInitHandler interface {
|
||
EnsureDB(ctx context.Context, conf *InitDBRequest) (context.Context, error) // 建库,失败属于 fatal error
|
||
WriteConfig(ctx context.Context) error // 回写配置
|
||
InitTables(ctx context.Context, inits InitSlice) error // 建表 handler
|
||
InitData(ctx context.Context, inits InitSlice) error // 建数据 handler
|
||
}
|
||
|
||
// orderedInitializer 组合一个顺序字段,以供排序
|
||
type orderedInitializer struct {
|
||
order int
|
||
SubInitializer
|
||
}
|
||
|
||
// InitSlice 供 initializer 排序依赖时使用
|
||
type InitSlice []*orderedInitializer
|
||
|
||
// InitDBUsecase 数据库初始化用例
|
||
type InitDBUsecase struct {
|
||
log *log.Helper
|
||
initializers InitSlice
|
||
cache map[string]*orderedInitializer
|
||
dbSetter func(*gorm.DB) // 用于设置全局DB的回调
|
||
configWriter func(ctx context.Context, dbType string, config interface{}) error
|
||
}
|
||
|
||
// NewInitDBUsecase 创建数据库初始化用例
|
||
func NewInitDBUsecase(logger log.Logger) *InitDBUsecase {
|
||
return &InitDBUsecase{
|
||
log: log.NewHelper(logger),
|
||
initializers: InitSlice{},
|
||
cache: map[string]*orderedInitializer{},
|
||
}
|
||
}
|
||
|
||
// SetDBSetter 设置数据库设置回调
|
||
func (uc *InitDBUsecase) SetDBSetter(setter func(*gorm.DB)) {
|
||
uc.dbSetter = setter
|
||
}
|
||
|
||
// SetConfigWriter 设置配置写入回调
|
||
func (uc *InitDBUsecase) SetConfigWriter(writer func(ctx context.Context, dbType string, config interface{}) error) {
|
||
uc.configWriter = writer
|
||
}
|
||
|
||
// RegisterInit 注册要执行的初始化过程,会在 InitDB() 时调用
|
||
func (uc *InitDBUsecase) RegisterInit(order int, i SubInitializer) {
|
||
name := i.InitializerName()
|
||
if _, existed := uc.cache[name]; existed {
|
||
panic(fmt.Sprintf("Name conflict on %s", name))
|
||
}
|
||
ni := orderedInitializer{order, i}
|
||
uc.initializers = append(uc.initializers, &ni)
|
||
uc.cache[name] = &ni
|
||
}
|
||
|
||
// InitDB 创建数据库并初始化 总入口
|
||
func (uc *InitDBUsecase) InitDB(ctx context.Context, conf *InitDBRequest) error {
|
||
ctx = context.WithValue(ctx, "adminPassword", conf.AdminPassword)
|
||
if len(uc.initializers) == 0 {
|
||
return ErrNoInitializers
|
||
}
|
||
sort.Sort(&uc.initializers) // 保证有依赖的 initializer 排在后面执行
|
||
|
||
var initHandler TypedDBInitHandler
|
||
switch conf.DBType {
|
||
case "mysql":
|
||
initHandler = NewMysqlInitHandler(uc.log, uc.configWriter)
|
||
ctx = context.WithValue(ctx, "dbtype", "mysql")
|
||
case "pgsql":
|
||
initHandler = NewPgsqlInitHandler(uc.log, uc.configWriter)
|
||
ctx = context.WithValue(ctx, "dbtype", "pgsql")
|
||
case "sqlite":
|
||
initHandler = NewSqliteInitHandler(uc.log, uc.configWriter)
|
||
ctx = context.WithValue(ctx, "dbtype", "sqlite")
|
||
case "mssql":
|
||
initHandler = NewMssqlInitHandler(uc.log, uc.configWriter)
|
||
ctx = context.WithValue(ctx, "dbtype", "mssql")
|
||
default:
|
||
initHandler = NewMysqlInitHandler(uc.log, uc.configWriter)
|
||
ctx = context.WithValue(ctx, "dbtype", "mysql")
|
||
}
|
||
|
||
var err error
|
||
ctx, err = initHandler.EnsureDB(ctx, conf)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
db := ctx.Value("db").(*gorm.DB)
|
||
if uc.dbSetter != nil {
|
||
uc.dbSetter(db)
|
||
}
|
||
|
||
if err = initHandler.InitTables(ctx, uc.initializers); err != nil {
|
||
return err
|
||
}
|
||
if err = initHandler.InitData(ctx, uc.initializers); err != nil {
|
||
return err
|
||
}
|
||
|
||
if err = initHandler.WriteConfig(ctx); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 清空初始化器
|
||
uc.initializers = InitSlice{}
|
||
uc.cache = map[string]*orderedInitializer{}
|
||
return nil
|
||
}
|
||
|
||
// CheckDB 检查数据库是否需要初始化
|
||
func (uc *InitDBUsecase) CheckDB(ctx context.Context) bool {
|
||
return len(uc.initializers) > 0
|
||
}
|
||
|
||
// createDatabase opens a short-lived connection pool for driver/dsn, checks
// it with Ping, and executes createSql (the database-creation statement).
// It is called from the handlers' EnsureDB.
//
// The pool is always closed before returning. A Close failure is returned
// only when every earlier step succeeded, so it never masks the original
// error.
func createDatabase(dsn string, driver string, createSql string) (err error) {
	db, err := sql.Open(driver, dsn)
	if err != nil {
		return err
	}
	defer func() {
		// The named result lets the close error reach the caller instead of
		// being dropped; an earlier error always takes precedence.
		if cerr := db.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("close %s connection: %w", driver, cerr)
		}
	}()

	if err = db.Ping(); err != nil {
		return err
	}
	_, err = db.Exec(createSql)
	return err
}
|
||
|
||
// createTables 创建表(默认 dbInitHandler.initTables 行为)
|
||
func createTables(ctx context.Context, inits InitSlice) error {
|
||
next, cancel := context.WithCancel(ctx)
|
||
defer cancel()
|
||
for _, init := range inits {
|
||
if init.TableCreated(next) {
|
||
continue
|
||
}
|
||
if n, err := init.MigrateTable(next); err != nil {
|
||
return err
|
||
} else {
|
||
next = n
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
/* -- sortable interface -- */
|
||
|
||
func (a InitSlice) Len() int {
|
||
return len(a)
|
||
}
|
||
|
||
func (a InitSlice) Less(i, j int) bool {
|
||
return a[i].order < a[j].order
|
||
}
|
||
|
||
func (a InitSlice) Swap(i, j int) {
|
||
a[i], a[j] = a[j], a[i]
|
||
}
|