refactor: isolate db handler
This commit is contained in:
parent
09fd9f1ba5
commit
e68ded1b3c
5 changed files with 55 additions and 34 deletions
11
cobble.go
11
cobble.go
|
@ -18,13 +18,8 @@ var up = &cobra.Command{
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT"))
|
dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT"))
|
||||||
|
|
||||||
dbConnection, err := ConnectToDatabase(DatabaseConfiguration{
|
db := NewDatabase(os.Getenv("DB_HOST"), dbPort)
|
||||||
User: os.Getenv("DB_APP_USER"),
|
err := db.Connect(os.Getenv("DB_APP_USER"), os.Getenv("DB_APP_PASSWORD"), os.Getenv("DB_NAME"))
|
||||||
Password: os.Getenv("DB_APP_PASSWORD"),
|
|
||||||
DatabaseName: os.Getenv("DB_NAME"),
|
|
||||||
Host: os.Getenv("DB_HOST"),
|
|
||||||
Port: dbPort,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to connect to the database: %s", err)
|
log.Printf("Failed to connect to the database: %s", err)
|
||||||
|
@ -36,7 +31,7 @@ var up = &cobra.Command{
|
||||||
|
|
||||||
for _, migration := range migrationHistory {
|
for _, migration := range migrationHistory {
|
||||||
log.Printf("%s", migration.Name)
|
log.Printf("%s", migration.Name)
|
||||||
if err = migrate(dbConnection, migration.Path); err != nil {
|
if err = migrate(db, migration.Path); err != nil {
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
45
database.go
Normal file
45
database.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Database struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
connection *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDatabase(host string, port int) Database {
|
||||||
|
return Database{
|
||||||
|
Host: host,
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) Connect(user string, password string, dbname string) error {
|
||||||
|
connectionString := fmt.Sprintf("postgresql://%s:%s@%s:%d?sslmode=disable", user, password, d.Host, d.Port)
|
||||||
|
|
||||||
|
conn, err := sql.Open(dbname, connectionString)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
d.connection = conn
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d Database) Execute(sql string) error {
|
||||||
|
if d.connection == nil {
|
||||||
|
return errors.New("Cannot execute SQL without being connected to a database.")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := d.connection.Exec(sql)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
11
migrate.go
11
migrate.go
|
@ -8,9 +8,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -41,7 +39,7 @@ func migrate(db DB, migrationPath string) error {
|
||||||
|
|
||||||
log.Printf("SQL: %s", migrationSql)
|
log.Printf("SQL: %s", migrationSql)
|
||||||
|
|
||||||
_, err = db.Exec(migrationSql)
|
err = db.Execute(migrationSql)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to run %s: %s", migrationPath, err)
|
log.Printf("Failed to run %s: %s", migrationPath, err)
|
||||||
|
@ -126,8 +124,3 @@ func getMigrations(migrationsRoot string) ([]Migration, error) {
|
||||||
|
|
||||||
return migrationGraph.GetLinearHistory()
|
return migrationGraph.GetLinearHistory()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConnectToDatabase(config DatabaseConfiguration) (*sql.DB, error) {
|
|
||||||
connectionString := fmt.Sprintf("postgresql://%s:%s@%s:%d?sslmode=disable", config.User, config.Password, config.Host, config.Port)
|
|
||||||
return sql.Open(config.DatabaseName, connectionString)
|
|
||||||
}
|
|
||||||
|
|
|
@ -19,18 +19,18 @@ type MockResult struct {
|
||||||
sql.Result
|
sql.Result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockSqlDB) Exec(sql string, args ...any) (sql.Result, error) {
|
func (m *MockSqlDB) Execute(sql string) error {
|
||||||
m.queries = append(m.queries, sql)
|
m.queries = append(m.queries, sql)
|
||||||
|
|
||||||
if m.mockedError != nil {
|
if m.mockedError != nil {
|
||||||
return MockResult{}, m.mockedError
|
return m.mockedError
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.mockedResult != nil {
|
if m.mockedResult != nil {
|
||||||
return m.mockedResult, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return MockResult{}, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetMigrationsGathersSqlFiles(t *testing.T) {
|
func TestGetMigrationsGathersSqlFiles(t *testing.T) {
|
||||||
|
|
14
models.go
14
models.go
|
@ -1,9 +1,5 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MigrationHeaders struct {
|
type MigrationHeaders struct {
|
||||||
Requirements []string
|
Requirements []string
|
||||||
}
|
}
|
||||||
|
@ -15,14 +11,6 @@ type Migration struct {
|
||||||
Run bool
|
Run bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type DatabaseConfiguration struct {
|
|
||||||
Host string
|
|
||||||
User string
|
|
||||||
Password string
|
|
||||||
DatabaseName string
|
|
||||||
Port int
|
|
||||||
}
|
|
||||||
|
|
||||||
type MigrationGraph struct {
|
type MigrationGraph struct {
|
||||||
// Reference to the root of the graph.
|
// Reference to the root of the graph.
|
||||||
Root *Migration
|
Root *Migration
|
||||||
|
@ -33,5 +21,5 @@ type MigrationGraph struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DB interface {
|
type DB interface {
|
||||||
Exec(sql string, args ...any) (sql.Result, error)
|
Execute(sql string) error
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue