diff --git a/models/models.go b/models/models.go index 4e2e08cf839e..5558ec006239 100644 --- a/models/models.go +++ b/models/models.go @@ -55,11 +55,12 @@ func LoadModelsConfig() { DbCfg.Path = setting.Cfg.MustValue("database", "PATH", "data/gogs.db") } -func NewTestEngine(x *xorm.Engine) (err error) { +func getEngine() (*xorm.Engine, error) { + cnnstr := "" switch DbCfg.Type { case "mysql": - x, err = xorm.NewEngine("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8", - DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name)) + cnnstr = fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8", + DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name) case "postgres": var host, port = "127.0.0.1", "5432" fields := strings.Split(DbCfg.Host, ":") @@ -69,46 +70,31 @@ func NewTestEngine(x *xorm.Engine) (err error) { if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 { port = fields[1] } - cnnstr := fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", + cnnstr = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode) - x, err = xorm.NewEngine("postgres", cnnstr) case "sqlite3": if !EnableSQLite3 { - return fmt.Errorf("Unknown database type: %s", DbCfg.Type) + return nil, fmt.Errorf("Unknown database type: %s", DbCfg.Type) } os.MkdirAll(path.Dir(DbCfg.Path), os.ModePerm) - x, err = xorm.NewEngine("sqlite3", DbCfg.Path) + cnnstr = DbCfg.Path default: - return fmt.Errorf("Unknown database type: %s", DbCfg.Type) + return nil, fmt.Errorf("Unknown database type: %s", DbCfg.Type) } + return xorm.NewEngine(DbCfg.Type, cnnstr) +} + +func NewTestEngine(x *xorm.Engine) (err error) { + x, err = getEngine() if err != nil { return fmt.Errorf("models.init(fail to conntect database): %v", err) } + return x.Sync(tables...) } func SetEngine() (err error) { - switch DbCfg.Type { - case "mysql": - x, err = xorm.NewEngine("mysql", fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8", - DbCfg.User, DbCfg.Pwd, DbCfg.Host, DbCfg.Name)) - case "postgres": - var host, port = "127.0.0.1", "5432" - fields := strings.Split(DbCfg.Host, ":") - if len(fields) > 0 && len(strings.TrimSpace(fields[0])) > 0 { - host = fields[0] - } - if len(fields) > 1 && len(strings.TrimSpace(fields[1])) > 0 { - port = fields[1] - } - x, err = xorm.NewEngine("postgres", fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s sslmode=%s", - DbCfg.User, DbCfg.Pwd, host, port, DbCfg.Name, DbCfg.SslMode)) - case "sqlite3": - os.MkdirAll(path.Dir(DbCfg.Path), os.ModePerm) - x, err = xorm.NewEngine("sqlite3", DbCfg.Path) - default: - return fmt.Errorf("Unknown database type: %s", DbCfg.Type) - } + x, err = getEngine() if err != nil { return fmt.Errorf("models.init(fail to conntect database): %v", err) }