refactor: graph from directory

This commit is contained in:
Marc 2024-07-11 20:24:57 -04:00
parent 90a512bd91
commit e5d8483491
Signed by: marc
GPG key ID: 048E042F22B5DC79
4 changed files with 56 additions and 54 deletions

View file

@ -27,8 +27,8 @@ var up = &cobra.Command{
} }
migrationRoot := args[0] migrationRoot := args[0]
migrationHistory, err := getMigrations(migrationRoot) migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot)
migrationHistory, _ := migrationGraph.GetLinearHistory()
for _, migration := range migrationHistory { for _, migration := range migrationHistory {
log.Printf("%s", migration.Name) log.Printf("%s", migration.Name)
if err = ApplyMigration(db, migration); err != nil { if err = ApplyMigration(db, migration); err != nil {

View file

@ -11,7 +11,6 @@ import (
"log" "log"
"os" "os"
"path/filepath"
) )
var commentPrefix = []byte("--") var commentPrefix = []byte("--")
@ -79,42 +78,3 @@ func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders {
return MigrationHeaders{Requirements: requirements} return MigrationHeaders{Requirements: requirements}
} }
// Gathers available migrations from the migration root directory.
//
// Migrations are expected to be .sql files.
func getMigrations(migrationsRoot string) ([]Migration, error) {
files, err := os.ReadDir(migrationsRoot)
if err != nil {
log.Printf("Failed to read migration directory: %s", migrationsRoot)
return []Migration{}, err
}
migrationGraph := NewMigrationGraph()
for _, file := range files {
filename := file.Name()
if filepath.Ext(filename) != ".sql" {
continue
}
migrationPath := filepath.Join(migrationsRoot, file.Name())
file, _ := os.Open(migrationPath)
headers := gatherMigrationHeaders(file)
requirements := ""
if len(headers.Requirements) > 0 {
requirements = headers.Requirements[0]
}
migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements)
migrationGraph.AddMigration(migration)
}
return migrationGraph.GetLinearHistory()
}

View file

@ -39,7 +39,8 @@ func TestGetMigrationsGathersSqlFiles(t *testing.T) {
os.WriteFile(filepath.Join(root, "0002.sql"), []byte("--requires:0001.sql\nSELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0002.sql"), []byte("--requires:0001.sql\nSELECT 1 from table;"), 0750)
os.WriteFile(filepath.Join(root, "0003.sql"), []byte("--requires:0002.sql\nSELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0003.sql"), []byte("--requires:0002.sql\nSELECT 1 from table;"), 0750)
migrations, err := getMigrations(root) migrationGraph, err := NewMigrationGraphFromDirectory(root)
migrations, _ := migrationGraph.GetLinearHistory()
if len(migrations) != 3 { if len(migrations) != 3 {
t.Errorf("Expected three migrations collected, got %d instead.", len(migrations)) t.Errorf("Expected three migrations collected, got %d instead.", len(migrations))
@ -71,7 +72,8 @@ func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750)
os.WriteFile(filepath.Join(root, "0002.txt"), []byte("SELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0002.txt"), []byte("SELECT 1 from table;"), 0750)
migrations, err := getMigrations(root) migrationGraph, err := NewMigrationGraphFromDirectory(root)
migrations, _ := migrationGraph.GetLinearHistory()
if len(migrations) != 1 { if len(migrations) != 1 {
t.Errorf("Expected one migration collected, got %d instead.", len(migrations)) t.Errorf("Expected one migration collected, got %d instead.", len(migrations))
@ -95,7 +97,8 @@ func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
} }
func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) { func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) {
migrations, err := getMigrations("not-root") migrationGraph, _ := NewMigrationGraphFromDirectory("not-root")
migrations, err := migrationGraph.GetLinearHistory()
if len(migrations) != 0 { if len(migrations) != 0 {
t.Errorf("Did not expect any migrations returned, got %#v", migrations) t.Errorf("Did not expect any migrations returned, got %#v", migrations)
@ -111,8 +114,8 @@ func TestGetMigrationsParsesRequirementsInCommentHeaders(t *testing.T) {
os.WriteFile(filepath.Join(root, "0000.sql"), []byte("SELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0000.sql"), []byte("SELECT 1 from table;"), 0750)
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- requires: 0000.sql\nSELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- requires: 0000.sql\nSELECT 1 from table;"), 0750)
migrations, err := getMigrations(root) migrationGraph, _ := NewMigrationGraphFromDirectory(root)
migrations, err := migrationGraph.GetLinearHistory()
if err != nil { if err != nil {
t.Errorf("Expected no error returned, got %#v", err) t.Errorf("Expected no error returned, got %#v", err)
} }
@ -136,8 +139,8 @@ func TestGetMigrationsIgnoresUnrecognizedCommentHeaders(t *testing.T) {
root := t.TempDir() root := t.TempDir()
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0750) os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0750)
migrations, err := getMigrations(root) migrationGraph, _ := NewMigrationGraphFromDirectory(root)
migrations, err := migrationGraph.GetLinearHistory()
if err != nil { if err != nil {
t.Errorf("Expected no error returned, got %#v", err) t.Errorf("Expected no error returned, got %#v", err)
} }
@ -163,9 +166,9 @@ func TestMigrateRunsSqlOnDBConnection(t *testing.T) {
migrationSql := "SELECT 1 FROM table;" migrationSql := "SELECT 1 FROM table;"
os.WriteFile(migrationPath, []byte(migrationSql), 0750) os.WriteFile(migrationPath, []byte(migrationSql), 0750)
migrations, _ := getMigrations(root) migrationGraph, _ := NewMigrationGraphFromDirectory(root)
migrations, err := migrationGraph.GetLinearHistory()
err := ApplyMigration(mockDb, migrations[0]) err = ApplyMigration(mockDb, migrations[0])
if err != nil { if err != nil {
t.Errorf("Expected no error returned, got %#v", err) t.Errorf("Expected no error returned, got %#v", err)
@ -197,9 +200,10 @@ func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) {
migrationPath := filepath.Join(root, "0001.sql") migrationPath := filepath.Join(root, "0001.sql")
os.WriteFile(migrationPath, []byte(migrationSql), 0750) os.WriteFile(migrationPath, []byte(migrationSql), 0750)
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
migrations, err := migrationGraph.GetLinearHistory()
migrations, _ := getMigrations(root) err = ApplyMigration(mockDb, migrations[0])
err := ApplyMigration(mockDb, migrations[0])
if err == nil { if err == nil {
t.Errorf("Expected error returned, got nil") t.Errorf("Expected error returned, got nil")

View file

@ -3,6 +3,9 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
"log"
"os"
"path/filepath"
) )
type MigrationGraph struct { type MigrationGraph struct {
@ -18,6 +21,41 @@ func NewMigrationGraph() MigrationGraph {
return MigrationGraph{Migrations: map[string]Migration{}, parentage: map[string]*Migration{}} return MigrationGraph{Migrations: map[string]Migration{}, parentage: map[string]*Migration{}}
} }
func NewMigrationGraphFromDirectory(pathRoot string) (MigrationGraph, error) {
migrationGraph := NewMigrationGraph()
files, err := os.ReadDir(pathRoot)
if err != nil {
log.Printf("Failed to read migration directory: %s", pathRoot)
return migrationGraph, err
}
for _, file := range files {
filename := file.Name()
if filepath.Ext(filename) != ".sql" {
continue
}
migrationPath := filepath.Join(pathRoot, file.Name())
file, _ := os.Open(migrationPath)
headers := gatherMigrationHeaders(file)
requirements := ""
if len(headers.Requirements) > 0 {
requirements = headers.Requirements[0]
}
migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements)
migrationGraph.AddMigration(migration)
}
return migrationGraph, nil
}
// Adds a migration to the graph. // Adds a migration to the graph.
// //
// This also adds the migration to the parentage mappings to link it // This also adds the migration to the parentage mappings to link it