package system

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sort"

	"github.com/go-kratos/kratos/v2/log"
	"gorm.io/gorm"
)

// Supported database types plus the log/message format strings used by
// initializers when reporting progress.
const (
	Mysql  = "mysql"
	Pgsql  = "pgsql"
	Sqlite = "sqlite"
	Mssql  = "mssql"

	InitSuccess     = "\n[%v] --> 初始数据成功!\n"
	InitDataExist   = "\n[%v] --> %v 的初始数据已存在!\n"
	InitDataFailed  = "\n[%v] --> %v 初始数据失败! \nerr: %+v\n"
	InitDataSuccess = "\n[%v] --> %v 初始数据成功!\n"
)

// Base order values for RegisterInit. Initializers with a lower order run
// first (see InitSlice.Less).
const (
	InitOrderSystem   = 10
	InitOrderInternal = 1000
	InitOrderExternal = 100000
)

var (
	// ErrMissingDBContext is returned when the context lacks the *gorm.DB
	// stored under the "db" key.
	ErrMissingDBContext = errors.New("missing db in context")
	// ErrMissingDependentContext reports a missing dependent value in the context.
	ErrMissingDependentContext = errors.New("missing dependent value in context")
	// ErrDBTypeMismatch reports a database type mismatch.
	ErrDBTypeMismatch = errors.New("db type mismatch")
	// ErrNoInitializers is returned by InitDB when no initializer is
	// registered (e.g. initialization has already completed).
	ErrNoInitializers = errors.New("无可用初始化过程,请检查初始化是否已执行完成")
	// ErrNilInitDBRequest is returned by InitDB when called with a nil request.
	ErrNilInitDBRequest = errors.New("nil init db request")
)

// InitDBRequest is the request payload for database initialization.
type InitDBRequest struct {
	AdminPassword string `json:"adminPassword" binding:"required"`
	// Database type (mysql, pgsql, sqlite, mssql).
	DBType string `json:"dbType"`
	// Server address.
	Host string `json:"host"`
	// Database connection port.
	Port string `json:"port"`
	// Database user name.
	UserName string `json:"userName"`
	// Database password.
	Password string `json:"password"`
	// Database name.
	DBName string `json:"dbName" binding:"required"`
	// SQLite database file path.
	DBPath string `json:"dbPath"`
	// Template to use when creating a PostgreSQL database.
	Template string `json:"template"`
}

// SubInitializer is the interface used by source/*/init(); each initializer
// performs one initialization step.
type SubInitializer interface {
	// InitializerName does not necessarily name a single table, hence the
	// broader wording.
	InitializerName() string
	MigrateTable(ctx context.Context) (next context.Context, err error)
	InitializeData(ctx context.Context) (next context.Context, err error)
	TableCreated(ctx context.Context) bool
	DataInserted(ctx context.Context) bool
}

// TypedDBInitHandler runs the registered initializers against one database type.
type TypedDBInitHandler interface {
	// EnsureDB creates the database; a failure here is fatal.
	EnsureDB(ctx context.Context, conf *InitDBRequest) (context.Context, error)
	// WriteConfig writes the resulting configuration back.
	WriteConfig(ctx context.Context) error
	// InitTables is the table-creation handler.
	InitTables(ctx context.Context, inits InitSlice) error
	// InitData is the data-insertion handler.
	InitData(ctx context.Context, inits InitSlice) error
}

// orderedInitializer pairs an initializer with the order value used for sorting.
type orderedInitializer struct {
	order int
	SubInitializer
}

// InitSlice is the sortable list of initializers, ordered so that dependents
// run after their dependencies.
type InitSlice []*orderedInitializer

// InitDBUsecase is the database initialization use case.
type InitDBUsecase struct {
	log          *log.Helper
	initializers InitSlice
	// cache indexes initializers by name to detect duplicate registrations.
	cache map[string]*orderedInitializer
	// dbSetter is the callback used to publish the (global) DB once created.
	dbSetter     func(*gorm.DB)
	configWriter func(ctx context.Context, dbType string, config interface{}) error
}

// NewInitDBUsecase creates a database initialization use case.
func NewInitDBUsecase(logger log.Logger) *InitDBUsecase {
	return &InitDBUsecase{
		log:          log.NewHelper(logger),
		initializers: InitSlice{},
		cache:        map[string]*orderedInitializer{},
	}
}

// SetDBSetter sets the callback invoked with the *gorm.DB after EnsureDB succeeds.
func (uc *InitDBUsecase) SetDBSetter(setter func(*gorm.DB)) {
	uc.dbSetter = setter
}

// SetConfigWriter sets the callback handed to the per-type init handlers for
// writing configuration.
func (uc *InitDBUsecase) SetConfigWriter(writer func(ctx context.Context, dbType string, config interface{}) error) {
	uc.configWriter = writer
}

// RegisterInit registers an initializer to run during InitDB. It panics when
// another initializer with the same InitializerName is already registered,
// since that is a programming error at registration time.
func (uc *InitDBUsecase) RegisterInit(order int, i SubInitializer) {
	name := i.InitializerName()
	if _, existed := uc.cache[name]; existed {
		panic(fmt.Sprintf("Name conflict on %s", name))
	}
	ni := &orderedInitializer{order: order, SubInitializer: i}
	uc.initializers = append(uc.initializers, ni)
	uc.cache[name] = ni
}

// InitDB is the top-level entry point. It creates the database described by
// conf, runs every registered initializer (tables first, then data), writes
// the configuration back and finally clears the registry.
//
// The context passed to the handlers carries these values under plain string
// keys. Other code in the project reads them back by these exact strings, so
// they are not switched to a private key type:
//   - "adminPassword": conf.AdminPassword
//   - "dbtype": the database type (unknown/empty types fall back to mysql)
//   - "db": the *gorm.DB, which EnsureDB must store in the returned context
//
// Errors: ErrNilInitDBRequest for a nil conf, ErrNoInitializers when nothing is
// registered, ErrMissingDBContext when EnsureDB leaves no *gorm.DB in the
// context. Otherwise it returns the first error from the handler.
func (uc *InitDBUsecase) InitDB(ctx context.Context, conf *InitDBRequest) error {
	if conf == nil {
		return ErrNilInitDBRequest
	}
	if len(uc.initializers) == 0 {
		return ErrNoInitializers
	}
	ctx = context.WithValue(ctx, "adminPassword", conf.AdminPassword)

	// Initializers with dependencies carry a larger order and run later.
	// A stable sort keeps registration order among equal orders, so runs are
	// deterministic.
	sort.Stable(uc.initializers)

	var initHandler TypedDBInitHandler
	dbType := conf.DBType
	switch dbType {
	case Pgsql:
		initHandler = NewPgsqlInitHandler(uc.log, uc.configWriter)
	case Sqlite:
		initHandler = NewSqliteInitHandler(uc.log, uc.configWriter)
	case Mssql:
		initHandler = NewMssqlInitHandler(uc.log, uc.configWriter)
	default:
		// Covers Mysql explicitly as well as any unknown or empty type.
		dbType = Mysql
		initHandler = NewMysqlInitHandler(uc.log, uc.configWriter)
	}
	ctx = context.WithValue(ctx, "dbtype", dbType)

	var err error
	ctx, err = initHandler.EnsureDB(ctx, conf)
	if err != nil {
		return err
	}

	// Comma-ok: an unchecked assertion would panic if a handler forgot to
	// store the connection.
	db, ok := ctx.Value("db").(*gorm.DB)
	if !ok || db == nil {
		return ErrMissingDBContext
	}
	if uc.dbSetter != nil {
		uc.dbSetter(db)
	}

	if err = initHandler.InitTables(ctx, uc.initializers); err != nil {
		return err
	}
	if err = initHandler.InitData(ctx, uc.initializers); err != nil {
		return err
	}
	if err = initHandler.WriteConfig(ctx); err != nil {
		return err
	}

	// Clear the registry so a completed initialization is not run again.
	uc.initializers = InitSlice{}
	uc.cache = map[string]*orderedInitializer{}
	return nil
}

// CheckDB reports whether initialization is still pending, i.e. at least one
// initializer is registered. ctx is currently unused.
func (uc *InitDBUsecase) CheckDB(ctx context.Context) bool {
	return len(uc.initializers) > 0
}

// createDatabase opens a raw database/sql connection with driver and dsn,
// verifies it with Ping and executes createSql (called from EnsureDB
// implementations). The connection is always closed. A Close failure is only
// printed (best effort) and does not change the returned error.
func createDatabase(dsn string, driver string, createSql string) error {
	db, err := sql.Open(driver, dsn)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := db.Close(); cerr != nil {
			fmt.Println(cerr)
		}
	}()

	if err := db.Ping(); err != nil {
		return err
	}
	_, err = db.Exec(createSql)
	return err
}

// createTables runs MigrateTable for every initializer whose tables do not yet
// exist (the default dbInitHandler.initTables behavior). The context returned
// by each MigrateTable is passed to the following initializers. Iteration
// stops at the first error.
func createTables(ctx context.Context, inits InitSlice) error {
	next, cancel := context.WithCancel(ctx)
	defer cancel()
	for _, ini := range inits {
		if ini.TableCreated(next) {
			continue
		}
		n, err := ini.MigrateTable(next)
		if err != nil {
			return err
		}
		next = n
	}
	return nil
}

/* -- sort.Interface implementation: ascending by order -- */

// Len implements sort.Interface.
func (a InitSlice) Len() int {
	return len(a)
}

// Less implements sort.Interface; lower order values sort first.
func (a InitSlice) Less(i, j int) bool {
	return a[i].order < a[j].order
}

// Swap implements sort.Interface.
func (a InitSlice) Swap(i, j int) {
	a[i], a[j] = a[j], a[i]
}