From e5d848349199a63deb8cb0873b80eb2d000947cf Mon Sep 17 00:00:00 2001 From: Marc Cataford Date: Thu, 11 Jul 2024 20:24:57 -0400 Subject: [PATCH] refactor: graph from directory --- cobble.go | 4 ++-- migrate.go | 40 ---------------------------------------- migrate_test.go | 28 ++++++++++++++++------------ migration_graph.go | 38 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 54 deletions(-) diff --git a/cobble.go b/cobble.go index 77f1a2f..5a2d5f7 100644 --- a/cobble.go +++ b/cobble.go @@ -27,8 +27,8 @@ var up = &cobra.Command{ } migrationRoot := args[0] - migrationHistory, err := getMigrations(migrationRoot) - + migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot) + migrationHistory, _ := migrationGraph.GetLinearHistory() for _, migration := range migrationHistory { log.Printf("%s", migration.Name) if err = ApplyMigration(db, migration); err != nil { diff --git a/migrate.go b/migrate.go index 950f34e..063f64f 100644 --- a/migrate.go +++ b/migrate.go @@ -11,7 +11,6 @@ import ( "log" "os" - "path/filepath" ) var commentPrefix = []byte("--") @@ -79,42 +78,3 @@ func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders { return MigrationHeaders{Requirements: requirements} } - -// Gathers available migrations from the migration root directory. -// -// Migrations are expected to be .sql files. -func getMigrations(migrationsRoot string) ([]Migration, error) { - files, err := os.ReadDir(migrationsRoot) - - if err != nil { - log.Printf("Failed to read migration directory: %s", migrationsRoot) - return []Migration{}, err - } - - migrationGraph := NewMigrationGraph() - - for _, file := range files { - filename := file.Name() - - if filepath.Ext(filename) != ".sql" { - continue - } - - migrationPath := filepath.Join(migrationsRoot, file.Name()) - - file, _ := os.Open(migrationPath) - - headers := gatherMigrationHeaders(file) - - requirements := "" - - if len(headers.Requirements) > 0 { - requirements = headers.Requirements[0] - } - - migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements) - migrationGraph.AddMigration(migration) - } - - return migrationGraph.GetLinearHistory() -} diff --git a/migrate_test.go b/migrate_test.go index 6560d65..cbf03b0 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -39,7 +39,8 @@ func TestGetMigrationsGathersSqlFiles(t *testing.T) { os.WriteFile(filepath.Join(root, "0002.sql"), []byte("--requires:0001.sql\nSELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0003.sql"), []byte("--requires:0002.sql\nSELECT 1 from table;"), 0750) - migrations, err := getMigrations(root) + migrationGraph, err := NewMigrationGraphFromDirectory(root) + migrations, _ := migrationGraph.GetLinearHistory() if len(migrations) != 3 { t.Errorf("Expected three migrations collected, got %d instead.", len(migrations)) @@ -71,7 +72,8 @@ func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) { os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0002.txt"), []byte("SELECT 1 from table;"), 0750) - migrations, err := getMigrations(root) + migrationGraph, err := NewMigrationGraphFromDirectory(root) + migrations, _ := migrationGraph.GetLinearHistory() if len(migrations) != 1 { t.Errorf("Expected one migration collected, got %d instead.", len(migrations)) @@ -95,7 +97,8 @@ func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) { } func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) { - migrations, err := getMigrations("not-root") + migrationGraph, _ := NewMigrationGraphFromDirectory("not-root") + migrations, err := migrationGraph.GetLinearHistory() if len(migrations) != 0 { t.Errorf("Did not expect any migrations returned, got %#v", migrations) @@ -111,8 +114,8 @@ func TestGetMigrationsParsesRequirementsInCommentHeaders(t *testing.T) { os.WriteFile(filepath.Join(root, "0000.sql"), []byte("SELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- requires: 0000.sql\nSELECT 1 from table;"), 0750) - migrations, err := getMigrations(root) - + migrationGraph, _ := NewMigrationGraphFromDirectory(root) + migrations, err := migrationGraph.GetLinearHistory() if err != nil { t.Errorf("Expected no error returned, got %#v", err) } @@ -136,8 +139,8 @@ func TestGetMigrationsIgnoresUnrecognizedCommentHeaders(t *testing.T) { root := t.TempDir() os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0750) - migrations, err := getMigrations(root) - + migrationGraph, _ := NewMigrationGraphFromDirectory(root) + migrations, err := migrationGraph.GetLinearHistory() if err != nil { t.Errorf("Expected no error returned, got %#v", err) } @@ -163,9 +166,9 @@ func TestMigrateRunsSqlOnDBConnection(t *testing.T) { migrationSql := "SELECT 1 FROM table;" os.WriteFile(migrationPath, []byte(migrationSql), 0750) - migrations, _ := getMigrations(root) - - err := ApplyMigration(mockDb, migrations[0]) + migrationGraph, _ := NewMigrationGraphFromDirectory(root) + migrations, err := migrationGraph.GetLinearHistory() + err = ApplyMigration(mockDb, migrations[0]) if err != nil { t.Errorf("Expected no error returned, got %#v", err) @@ -197,9 +200,10 @@ func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) { migrationPath := filepath.Join(root, "0001.sql") os.WriteFile(migrationPath, []byte(migrationSql), 0750) + migrationGraph, _ := NewMigrationGraphFromDirectory(root) + migrations, err := migrationGraph.GetLinearHistory() - migrations, _ := getMigrations(root) - err := ApplyMigration(mockDb, migrations[0]) + err = ApplyMigration(mockDb, migrations[0]) if err == nil { t.Errorf("Expected error returned, got nil") diff --git a/migration_graph.go b/migration_graph.go index 688459e..28ac90a 100644 --- a/migration_graph.go +++ b/migration_graph.go @@ -3,6 +3,9 @@ package main import ( "errors" "fmt" + "log" + "os" + "path/filepath" ) type MigrationGraph struct { @@ -18,6 +21,41 @@ func NewMigrationGraph() MigrationGraph { return MigrationGraph{Migrations: map[string]Migration{}, parentage: map[string]*Migration{}} } +func NewMigrationGraphFromDirectory(pathRoot string) (MigrationGraph, error) { + migrationGraph := NewMigrationGraph() + + files, err := os.ReadDir(pathRoot) + + if err != nil { + log.Printf("Failed to read migration directory: %s", pathRoot) + return migrationGraph, err + } + for _, file := range files { + filename := file.Name() + + if filepath.Ext(filename) != ".sql" { + continue + } + + migrationPath := filepath.Join(pathRoot, file.Name()) + + file, _ := os.Open(migrationPath) + + headers := gatherMigrationHeaders(file) + + requirements := "" + + if len(headers.Requirements) > 0 { + requirements = headers.Requirements[0] + } + + migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements) + migrationGraph.AddMigration(migration) + } + + return migrationGraph, nil +} + // Adds a migration to the graph. // // This also adds the migration to the parentage mappings to link it