refactor: pull in ApplyMigration into migration
This commit is contained in:
parent
e5d8483491
commit
fba9720db0
4 changed files with 35 additions and 32 deletions
|
@ -31,7 +31,7 @@ var up = &cobra.Command{
|
|||
migrationHistory, _ := migrationGraph.GetLinearHistory()
|
||||
for _, migration := range migrationHistory {
|
||||
log.Printf("%s", migration.Name)
|
||||
if err = ApplyMigration(db, migration); err != nil {
|
||||
if err = migration.ApplyMigration(db); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
28
migrate.go
28
migrate.go
|
@ -8,40 +8,12 @@ package main
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
var commentPrefix = []byte("--")
|
||||
var requirementPrefix = []byte("requires:")
|
||||
|
||||
// Applies a migration to the given database connection.
|
||||
//
|
||||
// If an error is returned while trying to read the migration file
|
||||
// or execute the SQL it contains, the error is returned.
|
||||
func ApplyMigration(db DB, migration Migration) error {
|
||||
migrationPath := migration.Path
|
||||
migrationBytes, err := os.ReadFile(migrationPath)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrationSql := string(migrationBytes)
|
||||
|
||||
log.Printf("SQL: %s", migrationSql)
|
||||
|
||||
err = db.Execute(migrationSql)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Removes the provided prefix from the byte slice and removes
|
||||
// any spaces leading or trailing the resulting slice.
|
||||
//
|
||||
|
|
|
@ -168,7 +168,7 @@ func TestMigrateRunsSqlOnDBConnection(t *testing.T) {
|
|||
|
||||
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||
migrations, err := migrationGraph.GetLinearHistory()
|
||||
err = ApplyMigration(mockDb, migrations[0])
|
||||
err = migrations[0].ApplyMigration(mockDb)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error returned, got %#v", err)
|
||||
|
@ -186,7 +186,7 @@ func TestMigrateReturnsErrorOnFailToReadMigrationFile(t *testing.T) {
|
|||
migrationPath := filepath.Join(root, "0001.sql")
|
||||
|
||||
migration := NewMigration(migrationPath, "0001.sql", "")
|
||||
err := ApplyMigration(mockDb, migration)
|
||||
err := migration.ApplyMigration(mockDb)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error returned, got nil")
|
||||
|
@ -203,7 +203,7 @@ func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) {
|
|||
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||
migrations, err := migrationGraph.GetLinearHistory()
|
||||
|
||||
err = ApplyMigration(mockDb, migrations[0])
|
||||
err = migrations[0].ApplyMigration(mockDb)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error returned, got nil")
|
||||
|
|
31
migration.go
31
migration.go
|
@ -1,5 +1,10 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
type Migration struct {
|
||||
Path string
|
||||
Name string
|
||||
|
@ -15,3 +20,29 @@ func NewMigration(path string, name string, requires string) Migration {
|
|||
Run: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Applies a migration to the given database connection.
|
||||
//
|
||||
// If an error is returned while trying to read the migration file
|
||||
// or execute the SQL it contains, the error is returned.
|
||||
func (m *Migration) ApplyMigration(db DB) error {
|
||||
migrationPath := m.Path
|
||||
migrationBytes, err := os.ReadFile(migrationPath)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrationSql := string(migrationBytes)
|
||||
|
||||
log.Printf("SQL: %s", migrationSql)
|
||||
|
||||
err = db.Execute(migrationSql)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue