package system

import (
	"context"
	"errors"
	"fmt"
	"path/filepath"
	"strings"

	"github.com/go-kratos/kratos/v2/log"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

// Connection defaults shared by the bootstrap DSN and the persisted config.
const (
	pgsqlDefaultHost   = "127.0.0.1"
	pgsqlDefaultPort   = "5432"
	pgsqlDefaultParams = "sslmode=disable TimeZone=Asia/Shanghai"
)

// PgsqlConfig holds the PostgreSQL connection settings.
type PgsqlConfig struct {
	Path         string // server host
	Port         string // server port
	Config       string // extra keyword/value parameters appended verbatim to the DSN
	Dbname       string // database name
	Username     string // login user
	Password     string // login password
	MaxIdleConns int
	MaxOpenConns int
	LogMode      string
}

// Dsn builds the keyword/value connection string for this configuration.
//
// Values that are empty or contain whitespace, single quotes or backslashes
// are single-quoted and escaped; all other values are emitted unchanged.
// Config is appended verbatim because it already holds key=value pairs.
func (c *PgsqlConfig) Dsn() string {
	return "host=" + pgsqlDsnValue(c.Path) +
		" user=" + pgsqlDsnValue(c.Username) +
		" password=" + pgsqlDsnValue(c.Password) +
		" dbname=" + pgsqlDsnValue(c.Dbname) +
		" port=" + pgsqlDsnValue(c.Port) +
		" " + c.Config
}

// pgsqlDsnValue renders v as a value of a keyword/value connection string.
//
// A raw empty value would make the parser read the following "key=value"
// pair as this key's value, and raw whitespace, quotes or backslashes would
// be mis-parsed, so those cases are wrapped in single quotes with `\` and `'`
// backslash-escaped. Plain values are returned as-is.
func pgsqlDsnValue(v string) string {
	if v != "" && !strings.ContainsAny(v, " \t\n\r\v\f'\\") {
		return v
	}
	// Escape backslashes first so the ones added for quotes are not doubled.
	escaped := strings.ReplaceAll(v, `\`, `\\`)
	escaped = strings.ReplaceAll(escaped, `'`, `\'`)
	return "'" + escaped + "'"
}

// pgsqlQuoteIdent renders name as a double-quoted SQL identifier so that it
// cannot terminate the statement or inject additional SQL. Embedded double
// quotes are doubled and NUL bytes are dropped.
func pgsqlQuoteIdent(name string) string {
	name = strings.ReplaceAll(name, "\x00", "")
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// PgsqlInitHandler initializes a PostgreSQL database.
type PgsqlInitHandler struct {
	log          *log.Helper
	configWriter func(ctx context.Context, dbType string, config interface{}) error
}

// NewPgsqlInitHandler creates a PostgreSQL init handler. configWriter may be
// nil, in which case WriteConfig does not persist anything.
func NewPgsqlInitHandler(logger *log.Helper, configWriter func(ctx context.Context, dbType string, config interface{}) error) *PgsqlInitHandler {
	return &PgsqlInitHandler{
		log:          logger,
		configWriter: configWriter,
	}
}

// WriteConfig passes the PgsqlConfig stored in ctx under the "config" key to
// the configured writer with database type "pgsql". It returns an error when
// ctx carries no PgsqlConfig value, and nil when no writer is configured.
func (h *PgsqlInitHandler) WriteConfig(ctx context.Context) error {
	c, ok := ctx.Value("config").(PgsqlConfig)
	if !ok {
		return errors.New("postgresql config invalid")
	}
	if h.configWriter != nil {
		return h.configWriter(ctx, "pgsql", c)
	}
	return nil
}

// EnsureDB creates the requested database and opens a gorm connection to it.
//
// ctx must carry "dbtype" == "pgsql", otherwise ErrDBTypeMismatch is returned.
// On success the returned context carries "config" (PgsqlConfig),
// "autoCodeRoot" (string) and "db" (*gorm.DB). On failure the original ctx is
// returned together with the error.
func (h *PgsqlInitHandler) EnsureDB(ctx context.Context, conf *InitDBRequest) (next context.Context, err error) {
	if s, ok := ctx.Value("dbtype").(string); !ok || s != "pgsql" {
		return ctx, ErrDBTypeMismatch
	}

	c := h.toPgsqlConfig(conf)
	next = context.WithValue(ctx, "config", c)
	if c.Dbname == "" {
		// Without a database name there is nothing to create or connect to.
		// NOTE(review): this returns the original ctx, so the "config" value
		// stored in next is dropped — confirm that this is intended.
		return ctx, nil
	}

	// The names come from the request, so they are quoted as identifiers
	// rather than spliced into the statement raw.
	var createSQL string
	if conf.Template != "" {
		createSQL = fmt.Sprintf("CREATE DATABASE %s WITH TEMPLATE %s;",
			pgsqlQuoteIdent(c.Dbname), pgsqlQuoteIdent(conf.Template))
	} else {
		createSQL = fmt.Sprintf("CREATE DATABASE %s;", pgsqlQuoteIdent(c.Dbname))
	}

	// Create the database through the maintenance connection.
	if err = createDatabase(h.pgsqlEmptyDsn(conf), "pgx", createSQL); err != nil {
		return ctx, err
	}

	var db *gorm.DB
	db, err = gorm.Open(postgres.New(postgres.Config{
		DSN:                  c.Dsn(), // data source name
		PreferSimpleProtocol: false,
	}), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true})
	if err != nil {
		return ctx, err
	}

	// Record the AutoCode root: the parent of the working directory.
	// The error is deliberately ignored (best effort), as before.
	autoCodeRoot, _ := filepath.Abs("..")
	next = context.WithValue(next, "autoCodeRoot", autoCodeRoot)
	next = context.WithValue(next, "db", db)
	return next, nil
}

// InitTables creates the tables for the given initializers.
func (h *PgsqlInitHandler) InitTables(ctx context.Context, inits InitSlice) error {
	return createTables(ctx, inits)
}

// InitData runs the data initializers in order, skipping those whose data is
// already present. Each initializer receives the context returned by the
// previous one. It stops at, logs and returns the first error.
func (h *PgsqlInitHandler) InitData(ctx context.Context, inits InitSlice) error {
	next, cancel := context.WithCancel(ctx)
	defer cancel()

	for i := 0; i < len(inits); i++ {
		if inits[i].DataInserted(next) {
			h.log.Infof(InitDataExist, Pgsql, inits[i].InitializerName())
			continue
		}

		n, err := inits[i].InitializeData(next)
		if err != nil {
			h.log.Errorf(InitDataFailed, Pgsql, inits[i].InitializerName(), err)
			return err
		}
		next = n
		h.log.Infof(InitDataSuccess, Pgsql, inits[i].InitializerName())
	}

	h.log.Infof(InitSuccess, Pgsql)
	return nil
}

// hostPort returns the host and port from the request, falling back to the
// defaults when either one is empty.
func (h *PgsqlInitHandler) hostPort(i *InitDBRequest) (host, port string) {
	host, port = i.Host, i.Port
	if host == "" {
		host = pgsqlDefaultHost
	}
	if port == "" {
		port = pgsqlDefaultPort
	}
	return host, port
}

// pgsqlEmptyDsn builds the DSN used to create the database. It connects to
// the "postgres" database instead of the one being created.
func (h *PgsqlInitHandler) pgsqlEmptyDsn(i *InitDBRequest) string {
	host, port := h.hostPort(i)
	return "host=" + pgsqlDsnValue(host) +
		" user=" + pgsqlDsnValue(i.UserName) +
		" password=" + pgsqlDsnValue(i.Password) +
		" port=" + pgsqlDsnValue(port) +
		" dbname=" + "postgres" +
		" " + pgsqlDefaultParams
}

// toPgsqlConfig converts an init request into a PgsqlConfig. Host and port
// get the same defaults as pgsqlEmptyDsn, so the database is created on and
// then opened against the same server.
func (h *PgsqlInitHandler) toPgsqlConfig(i *InitDBRequest) PgsqlConfig {
	host, port := h.hostPort(i)
	return PgsqlConfig{
		Path:         host,
		Port:         port,
		Dbname:       i.DBName,
		Username:     i.UserName,
		Password:     i.Password,
		MaxIdleConns: 10,
		MaxOpenConns: 100,
		LogMode:      "error",
		Config:       pgsqlDefaultParams,
	}
}