kra/internal/biz/system/initdb.go

225 lines
6.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package system
import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"github.com/go-kratos/kratos/v2/log"
"gorm.io/gorm"
)
// Database type identifiers. Their values are the dbType strings that InitDB
// recognises when choosing a database-specific handler.
const (
	Mysql  = "mysql"
	Pgsql  = "pgsql"
	Sqlite = "sqlite"
	Mssql  = "mssql"
)

// Printf-style templates (with %v / %+v verbs) for reporting the outcome of
// initial-data seeding: overall success, data already present, failure (with
// the error), and per-step success. They are not used in this file; presumably
// handlers elsewhere format them — verify against callers.
const (
	InitSuccess     = "\n[%v] --> 初始数据成功!\n"
	InitDataExist   = "\n[%v] --> %v 的初始数据已存在!\n"
	InitDataFailed  = "\n[%v] --> %v 初始数据失败! \nerr: %+v\n"
	InitDataSuccess = "\n[%v] --> %v 初始数据成功!\n"
)
// Base order values for registering initializers. InitDB sorts registered
// initializers by ascending order, so smaller values run first. Each tier is
// two orders of magnitude above the previous one, leaving room for offsets.
const (
	InitOrderSystem   = 10
	InitOrderInternal = 100 * InitOrderSystem   // 1000
	InitOrderExternal = 100 * InitOrderInternal // 100000
)
// ErrMissingDBContext reports that the expected database handle was not found
// in the context.
var ErrMissingDBContext = errors.New("missing db in context")

// ErrMissingDependentContext reports that a value produced by an earlier
// initialization step was not found in the context.
var ErrMissingDependentContext = errors.New("missing dependent value in context")

// ErrDBTypeMismatch reports that a value does not match the expected database
// type.
var ErrDBTypeMismatch = errors.New("db type mismatch")

// ErrNoInitializers is returned by InitDB when no initializer is registered.
// Message (translated): "no initialization process available; check whether
// initialization has already been completed".
var ErrNoInitializers = errors.New("无可用初始化过程,请检查初始化是否已执行完成")
// InitDBRequest is the request for initializing the database.
//
// The `json` tags name the request fields. The `binding:"required"` tags mark
// AdminPassword and DBName as mandatory; presumably they are enforced by the
// HTTP request binder — confirm against the transport layer.
type InitDBRequest struct {
	AdminPassword string `json:"adminPassword" binding:"required"` // admin account password; InitDB stores it in the context under "adminPassword"
	DBType        string `json:"dbType"`                           // database type ("mysql", "pgsql", "sqlite", "mssql"); InitDB falls back to MySQL for other values
	Host          string `json:"host"`                             // database server address
	Port          string `json:"port"`                             // database connection port
	UserName      string `json:"userName"`                         // database user name
	Password      string `json:"password"`                         // database password
	DBName        string `json:"dbName" binding:"required"`        // database name
	DBPath        string `json:"dbPath"`                           // SQLite database file path
	Template      string `json:"template"`                         // template database to use for PostgreSQL
}
// SubInitializer is implemented by each initialization step (registered from
// the source/*/init() packages). One implementation performs one unit of
// initialization. The name deliberately has broader meaning than "table",
// because a step does not necessarily map to a single table.
type SubInitializer interface {
	// InitializerName returns the step's name; RegisterInit rejects duplicate
	// names.
	InitializerName() string

	// TableCreated reports whether this step's schema already exists. When it
	// returns true, createTables skips MigrateTable for this step.
	TableCreated(ctx context.Context) bool
	// MigrateTable creates the step's schema. createTables passes the returned
	// context to the next step.
	MigrateTable(ctx context.Context) (next context.Context, err error)

	// DataInserted reports whether this step's initial data already exists.
	DataInserted(ctx context.Context) bool
	// InitializeData seeds the step's initial data and returns a possibly
	// extended context (presumably threaded to later steps by the handler).
	InitializeData(ctx context.Context) (next context.Context, err error)
}
// TypedDBInitHandler runs the registered initializers against one specific
// database type. InitDB calls these methods in the order they are declared
// below and stops at the first error.
type TypedDBInitHandler interface {
	// EnsureDB creates the database if needed and returns the context to use
	// for the remaining steps. Its failure is treated as fatal.
	EnsureDB(ctx context.Context, conf *InitDBRequest) (context.Context, error)
	// InitTables creates the tables for the given initializers.
	InitTables(ctx context.Context, inits InitSlice) error
	// InitData inserts the initial data for the given initializers.
	InitData(ctx context.Context, inits InitSlice) error
	// WriteConfig writes the resulting configuration back.
	WriteConfig(ctx context.Context) error
}
// orderedInitializer pairs a registered SubInitializer with its order value so
// that registered steps can be sorted. The SubInitializer is embedded, so its
// methods are promoted onto orderedInitializer.
type orderedInitializer struct {
	order int // sort key; InitSlice.Less orders by ascending value
	SubInitializer
}

// InitSlice is the list of registered initializers. It implements
// sort.Interface (ascending by order) so that steps with dependencies can be
// placed after the steps they depend on.
type InitSlice []*orderedInitializer
// InitDBUsecase is the database-initialization use case. It holds the registry
// of initialization steps and the callbacks used to publish the new connection
// and persist configuration.
type InitDBUsecase struct {
	log          *log.Helper
	initializers InitSlice                                                          // registered steps; InitDB sorts them before running
	cache        map[string]*orderedInitializer                                     // registered steps keyed by InitializerName, used to reject duplicates
	dbSetter     func(*gorm.DB)                                                     // callback used to set the global DB
	configWriter func(ctx context.Context, dbType string, config interface{}) error // config-writing callback, handed to each database-specific handler
}
// NewInitDBUsecase returns an InitDBUsecase that logs through logger and starts
// with an empty initializer registry. The DB setter and config writer start
// unset; install them with SetDBSetter and SetConfigWriter.
func NewInitDBUsecase(logger log.Logger) *InitDBUsecase {
	uc := new(InitDBUsecase)
	uc.log = log.NewHelper(logger)
	uc.initializers = make(InitSlice, 0)
	uc.cache = make(map[string]*orderedInitializer)
	return uc
}
// SetDBSetter installs the callback that InitDB invokes with the *gorm.DB
// obtained from EnsureDB (for example, to set a global DB). A nil setter
// means InitDB skips the callback.
func (uc *InitDBUsecase) SetDBSetter(setter func(*gorm.DB)) {
	uc.dbSetter = setter
}
// SetConfigWriter installs the config-writing callback. InitDB passes it to the
// database-specific handler it constructs, so it must be set before InitDB is
// called in order to take effect.
func (uc *InitDBUsecase) SetConfigWriter(writer func(ctx context.Context, dbType string, config interface{}) error) {
	uc.configWriter = writer
}
// RegisterInit registers an initialization step that InitDB will run.
//
// order is the sort key: InitDB sorts registered steps by ascending order
// before handing them to the database handler, and equal orders keep their
// registration order. RegisterInit panics when i is nil or when a step with
// the same InitializerName is already registered. Both cases are programming
// errors in the registration code.
func (uc *InitDBUsecase) RegisterInit(order int, i SubInitializer) {
	if i == nil {
		panic("RegisterInit: nil initializer")
	}
	name := i.InitializerName()
	// Create the registry lazily so that a zero-value InitDBUsecase does not
	// panic with "assignment to entry in nil map".
	if uc.cache == nil {
		uc.cache = make(map[string]*orderedInitializer)
	}
	if _, existed := uc.cache[name]; existed {
		panic(fmt.Sprintf("Name conflict on %s", name))
	}
	ni := &orderedInitializer{order: order, SubInitializer: i}
	uc.initializers = append(uc.initializers, ni)
	uc.cache[name] = ni
}
// InitDB is the top-level entry point that creates the database and runs all
// registered initializers.
//
// Steps:
//  1. Store conf.AdminPassword in the context under "adminPassword".
//  2. Sort the registered initializers by ascending order.
//  3. Select the handler for conf.DBType and store the effective type under
//     "dbtype". "mysql" and any unrecognized or empty type use MySQL.
//  4. Run EnsureDB, then install the resulting DB via the DB setter (if set).
//  5. Run InitTables, InitData and WriteConfig in that order.
//
// On success the registry is cleared, so CheckDB reports false afterwards.
//
// Errors:
//   - ErrNoInitializers when nothing has been registered.
//   - ErrMissingDBContext when EnsureDB does not return a context carrying a
//     non-nil *gorm.DB under "db".
//   - Otherwise, the first error returned by a handler step, unchanged.
//
// NOTE(review): the context keys are plain strings shared with handlers in
// other files. Switching to a typed key would require updating every reader.
func (uc *InitDBUsecase) InitDB(ctx context.Context, conf *InitDBRequest) error {
	ctx = context.WithValue(ctx, "adminPassword", conf.AdminPassword)
	if len(uc.initializers) == 0 {
		return ErrNoInitializers
	}
	// Initializers with dependencies must run later. A stable sort also keeps
	// registration order deterministic among equal order values.
	sort.Stable(uc.initializers)

	var initHandler TypedDBInitHandler
	dbType := conf.DBType
	switch dbType {
	case Pgsql:
		initHandler = NewPgsqlInitHandler(uc.log, uc.configWriter)
	case Sqlite:
		initHandler = NewSqliteInitHandler(uc.log, uc.configWriter)
	case Mssql:
		initHandler = NewMssqlInitHandler(uc.log, uc.configWriter)
	default:
		// "mysql" as well as any unrecognized or empty type uses MySQL.
		dbType = Mysql
		initHandler = NewMysqlInitHandler(uc.log, uc.configWriter)
	}
	ctx = context.WithValue(ctx, "dbtype", dbType)

	ctx, err := initHandler.EnsureDB(ctx, conf)
	if err != nil {
		return err
	}
	// EnsureDB must publish the connection under "db". Check the assertion so
	// that a misbehaving handler yields an error instead of a panic.
	db, ok := ctx.Value("db").(*gorm.DB)
	if !ok || db == nil {
		return ErrMissingDBContext
	}
	if uc.dbSetter != nil {
		uc.dbSetter(db)
	}

	if err = initHandler.InitTables(ctx, uc.initializers); err != nil {
		return err
	}
	if err = initHandler.InitData(ctx, uc.initializers); err != nil {
		return err
	}
	if err = initHandler.WriteConfig(ctx); err != nil {
		return err
	}

	// Clear the registry so the same steps are not run again.
	uc.initializers = InitSlice{}
	uc.cache = map[string]*orderedInitializer{}
	return nil
}
// CheckDB reports whether database initialization is still pending, that is,
// whether any initializer is currently registered. InitDB clears the registry
// after a successful run, so this becomes false once initialization completes.
// ctx is currently unused.
func (uc *InitDBUsecase) CheckDB(ctx context.Context) bool {
	pending := len(uc.initializers) != 0
	return pending
}
// createDatabase opens a short-lived connection with driver and dsn, checks it
// with Ping, and executes createSql (presumably a CREATE DATABASE statement).
// It is called from EnsureDB implementations.
//
// It returns the first error from opening, pinging or executing. Closing the
// connection is best-effort: a close failure is printed to stdout and never
// changes the returned result.
func createDatabase(dsn string, driver string, createSql string) error {
	conn, openErr := sql.Open(driver, dsn)
	if openErr != nil {
		return openErr
	}
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			fmt.Println(closeErr)
		}
	}()

	if pingErr := conn.Ping(); pingErr != nil {
		return pingErr
	}
	_, execErr := conn.Exec(createSql)
	return execErr
}
// createTables runs MigrateTable, in slice order, for every initializer whose
// TableCreated reports false. This is the default table-creation behaviour
// used by the handlers' InitTables.
//
// The context returned by each MigrateTable call is passed to later
// initializers, so later steps can read values published by earlier ones. A
// nil returned context is ignored and the previous one is kept, so nil is
// never handed to subsequent calls. The first migration error is returned
// unchanged.
func createTables(ctx context.Context, inits InitSlice) error {
	next, cancel := context.WithCancel(ctx)
	defer cancel()
	for _, initializer := range inits {
		if initializer.TableCreated(next) {
			continue
		}
		n, err := initializer.MigrateTable(next)
		if err != nil {
			return err
		}
		if n != nil {
			next = n
		}
	}
	return nil
}
/* -- sort.Interface: InitSlice is ordered by ascending order value -- */

// Len returns the number of registered initializers.
func (s InitSlice) Len() int { return len(s) }

// Less reports whether the initializer at i has a smaller order value than the one at j.
func (s InitSlice) Less(i, j int) bool { return s[i].order < s[j].order }

// Swap exchanges the initializers at positions i and j.
func (s InitSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }