refactor: graph from directory
This commit is contained in:
parent
90a512bd91
commit
e5d8483491
4 changed files with 56 additions and 54 deletions
|
@ -27,8 +27,8 @@ var up = &cobra.Command{
|
|||
}
|
||||
|
||||
migrationRoot := args[0]
|
||||
migrationHistory, err := getMigrations(migrationRoot)
|
||||
|
||||
migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot)
|
||||
migrationHistory, _ := migrationGraph.GetLinearHistory()
|
||||
for _, migration := range migrationHistory {
|
||||
log.Printf("%s", migration.Name)
|
||||
if err = ApplyMigration(db, migration); err != nil {
|
||||
|
|
40
migrate.go
40
migrate.go
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
var commentPrefix = []byte("--")
|
||||
|
@ -79,42 +78,3 @@ func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders {
|
|||
|
||||
return MigrationHeaders{Requirements: requirements}
|
||||
}
|
||||
|
||||
// Gathers available migrations from the migration root directory.
|
||||
//
|
||||
// Migrations are expected to be .sql files.
|
||||
func getMigrations(migrationsRoot string) ([]Migration, error) {
|
||||
files, err := os.ReadDir(migrationsRoot)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to read migration directory: %s", migrationsRoot)
|
||||
return []Migration{}, err
|
||||
}
|
||||
|
||||
migrationGraph := NewMigrationGraph()
|
||||
|
||||
for _, file := range files {
|
||||
filename := file.Name()
|
||||
|
||||
if filepath.Ext(filename) != ".sql" {
|
||||
continue
|
||||
}
|
||||
|
||||
migrationPath := filepath.Join(migrationsRoot, file.Name())
|
||||
|
||||
file, _ := os.Open(migrationPath)
|
||||
|
||||
headers := gatherMigrationHeaders(file)
|
||||
|
||||
requirements := ""
|
||||
|
||||
if len(headers.Requirements) > 0 {
|
||||
requirements = headers.Requirements[0]
|
||||
}
|
||||
|
||||
migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements)
|
||||
migrationGraph.AddMigration(migration)
|
||||
}
|
||||
|
||||
return migrationGraph.GetLinearHistory()
|
||||
}
|
||||
|
|
|
@ -39,7 +39,8 @@ func TestGetMigrationsGathersSqlFiles(t *testing.T) {
|
|||
os.WriteFile(filepath.Join(root, "0002.sql"), []byte("--requires:0001.sql\nSELECT 1 from table;"), 0750)
|
||||
os.WriteFile(filepath.Join(root, "0003.sql"), []byte("--requires:0002.sql\nSELECT 1 from table;"), 0750)
|
||||
|
||||
migrations, err := getMigrations(root)
|
||||
migrationGraph, err := NewMigrationGraphFromDirectory(root)
|
||||
migrations, _ := migrationGraph.GetLinearHistory()
|
||||
|
||||
if len(migrations) != 3 {
|
||||
t.Errorf("Expected three migrations collected, got %d instead.", len(migrations))
|
||||
|
@ -71,7 +72,8 @@ func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
|
|||
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750)
|
||||
os.WriteFile(filepath.Join(root, "0002.txt"), []byte("SELECT 1 from table;"), 0750)
|
||||
|
||||
migrations, err := getMigrations(root)
|
||||
migrationGraph, err := NewMigrationGraphFromDirectory(root)
|
||||
migrations, _ := migrationGraph.GetLinearHistory()
|
||||
|
||||
if len(migrations) != 1 {
|
||||
t.Errorf("Expected one migration collected, got %d instead.", len(migrations))
|
||||
|
@ -95,7 +97,8 @@ func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) {
|
||||
migrations, err := getMigrations("not-root")
|
||||
migrationGraph, _ := NewMigrationGraphFromDirectory("not-root")
|
||||
migrations, err := migrationGraph.GetLinearHistory()
|
||||
|
||||
if len(migrations) != 0 {
|
||||
t.Errorf("Did not expect any migrations returned, got %#v", migrations)
|
||||
|
@ -111,8 +114,8 @@ func TestGetMigrationsParsesRequirementsInCommentHeaders(t *testing.T) {
|
|||
os.WriteFile(filepath.Join(root, "0000.sql"), []byte("SELECT 1 from table;"), 0750)
|
||||
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- requires: 0000.sql\nSELECT 1 from table;"), 0750)
|
||||
|
||||
migrations, err := getMigrations(root)
|
||||
|
||||
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||
migrations, err := migrationGraph.GetLinearHistory()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error returned, got %#v", err)
|
||||
}
|
||||
|
@ -136,8 +139,8 @@ func TestGetMigrationsIgnoresUnrecognizedCommentHeaders(t *testing.T) {
|
|||
root := t.TempDir()
|
||||
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0750)
|
||||
|
||||
migrations, err := getMigrations(root)
|
||||
|
||||
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||
migrations, err := migrationGraph.GetLinearHistory()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error returned, got %#v", err)
|
||||
}
|
||||
|
@ -163,9 +166,9 @@ func TestMigrateRunsSqlOnDBConnection(t *testing.T) {
|
|||
migrationSql := "SELECT 1 FROM table;"
|
||||
os.WriteFile(migrationPath, []byte(migrationSql), 0750)
|
||||
|
||||
migrations, _ := getMigrations(root)
|
||||
|
||||
err := ApplyMigration(mockDb, migrations[0])
|
||||
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||
migrations, err := migrationGraph.GetLinearHistory()
|
||||
err = ApplyMigration(mockDb, migrations[0])
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error returned, got %#v", err)
|
||||
|
@ -197,9 +200,10 @@ func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) {
|
|||
|
||||
migrationPath := filepath.Join(root, "0001.sql")
|
||||
os.WriteFile(migrationPath, []byte(migrationSql), 0750)
|
||||
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||
migrations, err := migrationGraph.GetLinearHistory()
|
||||
|
||||
migrations, _ := getMigrations(root)
|
||||
err := ApplyMigration(mockDb, migrations[0])
|
||||
err = ApplyMigration(mockDb, migrations[0])
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Expected error returned, got nil")
|
||||
|
|
|
@ -3,6 +3,9 @@ package main
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type MigrationGraph struct {
|
||||
|
@ -18,6 +21,41 @@ func NewMigrationGraph() MigrationGraph {
|
|||
return MigrationGraph{Migrations: map[string]Migration{}, parentage: map[string]*Migration{}}
|
||||
}
|
||||
|
||||
func NewMigrationGraphFromDirectory(pathRoot string) (MigrationGraph, error) {
|
||||
migrationGraph := NewMigrationGraph()
|
||||
|
||||
files, err := os.ReadDir(pathRoot)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to read migration directory: %s", pathRoot)
|
||||
return migrationGraph, err
|
||||
}
|
||||
for _, file := range files {
|
||||
filename := file.Name()
|
||||
|
||||
if filepath.Ext(filename) != ".sql" {
|
||||
continue
|
||||
}
|
||||
|
||||
migrationPath := filepath.Join(pathRoot, file.Name())
|
||||
|
||||
file, _ := os.Open(migrationPath)
|
||||
|
||||
headers := gatherMigrationHeaders(file)
|
||||
|
||||
requirements := ""
|
||||
|
||||
if len(headers.Requirements) > 0 {
|
||||
requirements = headers.Requirements[0]
|
||||
}
|
||||
|
||||
migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements)
|
||||
migrationGraph.AddMigration(migration)
|
||||
}
|
||||
|
||||
return migrationGraph, nil
|
||||
}
|
||||
|
||||
// Adds a migration to the graph.
|
||||
//
|
||||
// This also adds the migration to the parentage mappings to link it
|
||||
|
|
Loading…
Reference in a new issue