diff --git a/cobble.go b/cobble.go index 5a2d5f7..7ce477b 100644 --- a/cobble.go +++ b/cobble.go @@ -31,7 +31,7 @@ var up = &cobra.Command{ migrationHistory, _ := migrationGraph.GetLinearHistory() for _, migration := range migrationHistory { log.Printf("%s", migration.Name) - if err = ApplyMigration(db, migration); err != nil { + if err = migration.ApplyMigration(db); err != nil { os.Exit(1) } } diff --git a/migrate.go b/migrate.go index 063f64f..69a5b9d 100644 --- a/migrate.go +++ b/migrate.go @@ -8,40 +8,12 @@ package main import ( "bufio" "bytes" - - "log" "os" ) var commentPrefix = []byte("--") var requirementPrefix = []byte("requires:") -// Applies a migration to the given database connection. -// -// If an error is returned while trying to read the migration file -// or execute the SQL it contains, the error is returned. -func ApplyMigration(db DB, migration Migration) error { - migrationPath := migration.Path - migrationBytes, err := os.ReadFile(migrationPath) - - if err != nil { - return err - } - - migrationSql := string(migrationBytes) - - log.Printf("SQL: %s", migrationSql) - - err = db.Execute(migrationSql) - - if err != nil { - return err - } - - return nil - -} - // Removes the provided prefix from the byte slice and removes // any spaces leading or trailing the resulting slice. // diff --git a/migrate_test.go b/migrate_test.go index cbf03b0..78c740b 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -168,7 +168,7 @@ func TestMigrateRunsSqlOnDBConnection(t *testing.T) { migrationGraph, _ := NewMigrationGraphFromDirectory(root) migrations, err := migrationGraph.GetLinearHistory() - err = ApplyMigration(mockDb, migrations[0]) + err = migrations[0].ApplyMigration(mockDb) if err != nil { t.Errorf("Expected no error returned, got %#v", err) @@ -186,7 +186,7 @@ func TestMigrateReturnsErrorOnFailToReadMigrationFile(t *testing.T) { migrationPath := filepath.Join(root, "0001.sql") migration := NewMigration(migrationPath, "0001.sql", "") - err := ApplyMigration(mockDb, migration) + err := migration.ApplyMigration(mockDb) if err == nil { t.Errorf("Expected error returned, got nil") @@ -203,7 +203,7 @@ func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) { migrationGraph, _ := NewMigrationGraphFromDirectory(root) migrations, err := migrationGraph.GetLinearHistory() - err = ApplyMigration(mockDb, migrations[0]) + err = migrations[0].ApplyMigration(mockDb) if err == nil { t.Errorf("Expected error returned, got nil") diff --git a/migration.go b/migration.go index 7c1ade2..952ea84 100644 --- a/migration.go +++ b/migration.go @@ -1,5 +1,10 @@ package main +import ( + "log" + "os" +) + type Migration struct { Path string Name string @@ -15,3 +20,29 @@ func NewMigration(path string, name string, requires string) Migration { Run: false, } } + +// Applies a migration to the given database connection. +// +// If an error is returned while trying to read the migration file +// or execute the SQL it contains, the error is returned. +func (m *Migration) ApplyMigration(db DB) error { + migrationPath := m.Path + migrationBytes, err := os.ReadFile(migrationPath) + + if err != nil { + return err + } + + migrationSql := string(migrationBytes) + + log.Printf("SQL: %s", migrationSql) + + err = db.Execute(migrationSql) + + if err != nil { + return err + } + + return nil + +}