refactor: isolate db handler

This commit is contained in:
Marc 2024-07-11 00:09:24 -04:00
parent 09fd9f1ba5
commit e68ded1b3c
Signed by: marc
GPG key ID: 048E042F22B5DC79
5 changed files with 55 additions and 34 deletions

View file

@ -18,13 +18,8 @@ var up = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT")) dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT"))
dbConnection, err := ConnectToDatabase(DatabaseConfiguration{ db := NewDatabase(os.Getenv("DB_HOST"), dbPort)
User: os.Getenv("DB_APP_USER"), err := db.Connect(os.Getenv("DB_APP_USER"), os.Getenv("DB_APP_PASSWORD"), os.Getenv("DB_NAME"))
Password: os.Getenv("DB_APP_PASSWORD"),
DatabaseName: os.Getenv("DB_NAME"),
Host: os.Getenv("DB_HOST"),
Port: dbPort,
})
if err != nil { if err != nil {
log.Printf("Failed to connect to the database: %s", err) log.Printf("Failed to connect to the database: %s", err)
@ -36,7 +31,7 @@ var up = &cobra.Command{
for _, migration := range migrationHistory { for _, migration := range migrationHistory {
log.Printf("%s", migration.Name) log.Printf("%s", migration.Name)
if err = migrate(dbConnection, migration.Path); err != nil { if err = migrate(db, migration.Path); err != nil {
os.Exit(1) os.Exit(1)
} }
} }

45
database.go Normal file
View file

@ -0,0 +1,45 @@
package main
import (
"database/sql"
"errors"
"fmt"
_ "github.com/lib/pq"
)
type Database struct {
Host string
Port int
connection *sql.DB
}
func NewDatabase(host string, port int) Database {
return Database{
Host: host,
Port: port,
}
}
func (d *Database) Connect(user string, password string, dbname string) error {
connectionString := fmt.Sprintf("postgresql://%s:%s@%s:%d?sslmode=disable", user, password, d.Host, d.Port)
conn, err := sql.Open(dbname, connectionString)
if err != nil {
return err
}
d.connection = conn
return nil
}
func (d Database) Execute(sql string) error {
if d.connection == nil {
return errors.New("Cannot execute SQL without being connected to a database.")
}
_, err := d.connection.Exec(sql)
return err
}

View file

@ -8,9 +8,7 @@ package main
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
@ -41,7 +39,7 @@ func migrate(db DB, migrationPath string) error {
log.Printf("SQL: %s", migrationSql) log.Printf("SQL: %s", migrationSql)
_, err = db.Exec(migrationSql) err = db.Execute(migrationSql)
if err != nil { if err != nil {
log.Printf("Failed to run %s: %s", migrationPath, err) log.Printf("Failed to run %s: %s", migrationPath, err)
@ -126,8 +124,3 @@ func getMigrations(migrationsRoot string) ([]Migration, error) {
return migrationGraph.GetLinearHistory() return migrationGraph.GetLinearHistory()
} }
func ConnectToDatabase(config DatabaseConfiguration) (*sql.DB, error) {
connectionString := fmt.Sprintf("postgresql://%s:%s@%s:%d?sslmode=disable", config.User, config.Password, config.Host, config.Port)
return sql.Open(config.DatabaseName, connectionString)
}

View file

@ -19,18 +19,18 @@ type MockResult struct {
sql.Result sql.Result
} }
func (m *MockSqlDB) Exec(sql string, args ...any) (sql.Result, error) { func (m *MockSqlDB) Execute(sql string) error {
m.queries = append(m.queries, sql) m.queries = append(m.queries, sql)
if m.mockedError != nil { if m.mockedError != nil {
return MockResult{}, m.mockedError return m.mockedError
} }
if m.mockedResult != nil { if m.mockedResult != nil {
return m.mockedResult, nil return nil
} }
return MockResult{}, nil return nil
} }
func TestGetMigrationsGathersSqlFiles(t *testing.T) { func TestGetMigrationsGathersSqlFiles(t *testing.T) {

View file

@ -1,9 +1,5 @@
package main package main
import (
"database/sql"
)
type MigrationHeaders struct { type MigrationHeaders struct {
Requirements []string Requirements []string
} }
@ -15,14 +11,6 @@ type Migration struct {
Run bool Run bool
} }
type DatabaseConfiguration struct {
Host string
User string
Password string
DatabaseName string
Port int
}
type MigrationGraph struct { type MigrationGraph struct {
// Reference to the root of the graph. // Reference to the root of the graph.
Root *Migration Root *Migration
@ -33,5 +21,5 @@ type MigrationGraph struct {
} }
type DB interface { type DB interface {
Exec(sql string, args ...any) (sql.Result, error) Execute(sql string) error
} }