package system

import (
	"context"
	"errors"
	"net/url"
	"path/filepath"

	"github.com/go-kratos/kratos/v2/log"
	"gorm.io/driver/sqlserver"
	"gorm.io/gorm"
)

// MssqlConfig holds the connection settings for a Microsoft SQL Server
// database.
type MssqlConfig struct {
	Path         string // server host; becomes the host part of the DSN
	Port         string // server port; becomes the port part of the DSN
	Config       string // not read anywhere in this file
	Dbname       string // database name; sent as the "database" DSN parameter
	Username     string
	Password     string
	MaxIdleConns int
	MaxOpenConns int
	LogMode      string
}

// buildMssqlDsn assembles a sqlserver:// connection URL with the "database"
// and "encrypt=disable" query parameters.
//
// The URL is built with net/url rather than by string concatenation so that
// reserved characters in the credentials or the database name (for example
// '@', ':', '/', '?', '&', '%') are percent-encoded instead of corrupting the
// URL. Callers must pass raw, un-encoded values.
func buildMssqlDsn(username, password, host, port, dbname string) string {
	query := url.Values{}
	query.Set("database", dbname)
	query.Set("encrypt", "disable")

	dsn := url.URL{
		Scheme:   "sqlserver",
		User:     url.UserPassword(username, password),
		Host:     host + ":" + port,
		RawQuery: query.Encode(), // keys are emitted sorted: database, encrypt
	}
	return dsn.String()
}

// Dsn returns the sqlserver:// connection URL described by c. Credentials and
// the database name are percent-encoded as needed.
func (c *MssqlConfig) Dsn() string {
	return buildMssqlDsn(c.Username, c.Password, c.Path, c.Port, c.Dbname)
}

// MssqlInitHandler performs the database initialization steps for MSSQL.
type MssqlInitHandler struct {
	log *log.Helper
	// configWriter, when non-nil, is invoked by WriteConfig to persist the
	// configuration.
	configWriter func(ctx context.Context, dbType string, config interface{}) error
}

// NewMssqlInitHandler returns a handler that logs through logger and persists
// configuration through configWriter. configWriter may be nil, in which case
// WriteConfig does not write anything.
func NewMssqlInitHandler(logger *log.Helper, configWriter func(ctx context.Context, dbType string, config interface{}) error) *MssqlInitHandler {
	return &MssqlInitHandler{
		log:          logger,
		configWriter: configWriter,
	}
}

// WriteConfig passes the MssqlConfig stored in ctx under the "config" key to
// the handler's configWriter with the database type "mssql".
//
// It returns an error when ctx carries no MssqlConfig value under that key,
// and nil when no configWriter was supplied.
func (h *MssqlInitHandler) WriteConfig(ctx context.Context) error {
	c, ok := ctx.Value("config").(MssqlConfig)
	if !ok {
		return errors.New("mssql config invalid")
	}
	if h.configWriter != nil {
		return h.configWriter(ctx, "mssql", c)
	}
	return nil
}

// EnsureDB opens a GORM connection to the MSSQL server described by conf and
// returns a context derived from ctx that carries:
//
//   - "config":       the MssqlConfig built from conf
//   - "autoCodeRoot": the absolute path of the parent of the working directory
//   - "db":           the opened *gorm.DB
//
// ctx must carry the string "mssql" under the "dbtype" key, otherwise
// ErrDBTypeMismatch is returned. When conf has no database name, no
// connection is opened and the original ctx is returned unchanged with a nil
// error. On every error path the original ctx is returned, never a nil
// context.
func (h *MssqlInitHandler) EnsureDB(ctx context.Context, conf *InitDBRequest) (next context.Context, err error) {
	if s, ok := ctx.Value("dbtype").(string); !ok || s != "mssql" {
		return ctx, ErrDBTypeMismatch
	}
	// Guard against a nil request, which would otherwise panic when its
	// fields are read below.
	if conf == nil {
		return ctx, errors.New("mssql init request is nil")
	}

	c := h.toMssqlConfig(conf)
	if c.Dbname == "" {
		// Without a database name, skip initialization. The returned context
		// is the original one and therefore does not carry "config".
		return ctx, nil
	}
	next = context.WithValue(ctx, "config", c)

	mssqlConfig := sqlserver.Config{
		DSN:               h.mssqlEmptyDsn(conf), // data source name
		DefaultStringSize: 191,                   // default length for string columns
	}
	db, openErr := gorm.Open(sqlserver.New(mssqlConfig), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true})
	if openErr != nil {
		return ctx, openErr
	}

	// Record the AutoCode root directory. The error is deliberately ignored:
	// on failure an empty root is stored rather than aborting initialization.
	autoCodeRoot, _ := filepath.Abs("..")
	next = context.WithValue(next, "autoCodeRoot", autoCodeRoot)
	next = context.WithValue(next, "db", db)
	return next, nil
}

// InitTables creates the tables for the given initializers by delegating to
// createTables.
func (h *MssqlInitHandler) InitTables(ctx context.Context, inits InitSlice) error {
	return createTables(ctx, inits)
}

// InitData runs the data initializers in order. An initializer whose data is
// already present is skipped. Each successful initializer returns a context
// that is passed on to the following ones. The first failure is logged and
// returned, and the remaining initializers are not run.
func (h *MssqlInitHandler) InitData(ctx context.Context, inits InitSlice) error {
	next, cancel := context.WithCancel(ctx)
	defer cancel()

	for _, item := range inits {
		if item.DataInserted(next) {
			h.log.Infof(InitDataExist, Mssql, item.InitializerName())
			continue
		}
		n, err := item.InitializeData(next)
		if err != nil {
			h.log.Errorf(InitDataFailed, Mssql, item.InitializerName(), err)
			return err
		}
		next = n
		h.log.Infof(InitDataSuccess, Mssql, item.InitializerName())
	}
	h.log.Infof(InitSuccess, Mssql)
	return nil
}

// mssqlEmptyDsn returns the sqlserver:// connection URL used by EnsureDB,
// built from the fields of the init request. Credentials and the database
// name are percent-encoded as needed.
func (h *MssqlInitHandler) mssqlEmptyDsn(i *InitDBRequest) string {
	return buildMssqlDsn(i.UserName, i.Password, i.Host, i.Port, i.DBName)
}

// toMssqlConfig converts an init request into an MssqlConfig, filling in the
// fixed defaults of 10 idle connections, 100 open connections and the "error"
// log mode.
func (h *MssqlInitHandler) toMssqlConfig(i *InitDBRequest) MssqlConfig {
	return MssqlConfig{
		Path:         i.Host,
		Port:         i.Port,
		Dbname:       i.DBName,
		Username:     i.UserName,
		Password:     i.Password,
		MaxIdleConns: 10,
		MaxOpenConns: 100,
		LogMode:      "error",
	}
}