refactor: graph from directory
This commit is contained in:
parent
90a512bd91
commit
e5d8483491
4 changed files with 56 additions and 54 deletions
|
@ -27,8 +27,8 @@ var up = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
migrationRoot := args[0]
|
migrationRoot := args[0]
|
||||||
migrationHistory, err := getMigrations(migrationRoot)
|
migrationGraph, _ := NewMigrationGraphFromDirectory(migrationRoot)
|
||||||
|
migrationHistory, _ := migrationGraph.GetLinearHistory()
|
||||||
for _, migration := range migrationHistory {
|
for _, migration := range migrationHistory {
|
||||||
log.Printf("%s", migration.Name)
|
log.Printf("%s", migration.Name)
|
||||||
if err = ApplyMigration(db, migration); err != nil {
|
if err = ApplyMigration(db, migration); err != nil {
|
||||||
|
|
40
migrate.go
40
migrate.go
|
@ -11,7 +11,6 @@ import (
|
||||||
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var commentPrefix = []byte("--")
|
var commentPrefix = []byte("--")
|
||||||
|
@ -79,42 +78,3 @@ func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders {
|
||||||
|
|
||||||
return MigrationHeaders{Requirements: requirements}
|
return MigrationHeaders{Requirements: requirements}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gathers available migrations from the migration root directory.
|
|
||||||
//
|
|
||||||
// Migrations are expected to be .sql files.
|
|
||||||
func getMigrations(migrationsRoot string) ([]Migration, error) {
|
|
||||||
files, err := os.ReadDir(migrationsRoot)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to read migration directory: %s", migrationsRoot)
|
|
||||||
return []Migration{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
migrationGraph := NewMigrationGraph()
|
|
||||||
|
|
||||||
for _, file := range files {
|
|
||||||
filename := file.Name()
|
|
||||||
|
|
||||||
if filepath.Ext(filename) != ".sql" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
migrationPath := filepath.Join(migrationsRoot, file.Name())
|
|
||||||
|
|
||||||
file, _ := os.Open(migrationPath)
|
|
||||||
|
|
||||||
headers := gatherMigrationHeaders(file)
|
|
||||||
|
|
||||||
requirements := ""
|
|
||||||
|
|
||||||
if len(headers.Requirements) > 0 {
|
|
||||||
requirements = headers.Requirements[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements)
|
|
||||||
migrationGraph.AddMigration(migration)
|
|
||||||
}
|
|
||||||
|
|
||||||
return migrationGraph.GetLinearHistory()
|
|
||||||
}
|
|
||||||
|
|
|
@ -39,7 +39,8 @@ func TestGetMigrationsGathersSqlFiles(t *testing.T) {
|
||||||
os.WriteFile(filepath.Join(root, "0002.sql"), []byte("--requires:0001.sql\nSELECT 1 from table;"), 0750)
|
os.WriteFile(filepath.Join(root, "0002.sql"), []byte("--requires:0001.sql\nSELECT 1 from table;"), 0750)
|
||||||
os.WriteFile(filepath.Join(root, "0003.sql"), []byte("--requires:0002.sql\nSELECT 1 from table;"), 0750)
|
os.WriteFile(filepath.Join(root, "0003.sql"), []byte("--requires:0002.sql\nSELECT 1 from table;"), 0750)
|
||||||
|
|
||||||
migrations, err := getMigrations(root)
|
migrationGraph, err := NewMigrationGraphFromDirectory(root)
|
||||||
|
migrations, _ := migrationGraph.GetLinearHistory()
|
||||||
|
|
||||||
if len(migrations) != 3 {
|
if len(migrations) != 3 {
|
||||||
t.Errorf("Expected three migrations collected, got %d instead.", len(migrations))
|
t.Errorf("Expected three migrations collected, got %d instead.", len(migrations))
|
||||||
|
@ -71,7 +72,8 @@ func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
|
||||||
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750)
|
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750)
|
||||||
os.WriteFile(filepath.Join(root, "0002.txt"), []byte("SELECT 1 from table;"), 0750)
|
os.WriteFile(filepath.Join(root, "0002.txt"), []byte("SELECT 1 from table;"), 0750)
|
||||||
|
|
||||||
migrations, err := getMigrations(root)
|
migrationGraph, err := NewMigrationGraphFromDirectory(root)
|
||||||
|
migrations, _ := migrationGraph.GetLinearHistory()
|
||||||
|
|
||||||
if len(migrations) != 1 {
|
if len(migrations) != 1 {
|
||||||
t.Errorf("Expected one migration collected, got %d instead.", len(migrations))
|
t.Errorf("Expected one migration collected, got %d instead.", len(migrations))
|
||||||
|
@ -95,7 +97,8 @@ func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) {
|
func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) {
|
||||||
migrations, err := getMigrations("not-root")
|
migrationGraph, _ := NewMigrationGraphFromDirectory("not-root")
|
||||||
|
migrations, err := migrationGraph.GetLinearHistory()
|
||||||
|
|
||||||
if len(migrations) != 0 {
|
if len(migrations) != 0 {
|
||||||
t.Errorf("Did not expect any migrations returned, got %#v", migrations)
|
t.Errorf("Did not expect any migrations returned, got %#v", migrations)
|
||||||
|
@ -111,8 +114,8 @@ func TestGetMigrationsParsesRequirementsInCommentHeaders(t *testing.T) {
|
||||||
os.WriteFile(filepath.Join(root, "0000.sql"), []byte("SELECT 1 from table;"), 0750)
|
os.WriteFile(filepath.Join(root, "0000.sql"), []byte("SELECT 1 from table;"), 0750)
|
||||||
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- requires: 0000.sql\nSELECT 1 from table;"), 0750)
|
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- requires: 0000.sql\nSELECT 1 from table;"), 0750)
|
||||||
|
|
||||||
migrations, err := getMigrations(root)
|
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||||
|
migrations, err := migrationGraph.GetLinearHistory()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error returned, got %#v", err)
|
t.Errorf("Expected no error returned, got %#v", err)
|
||||||
}
|
}
|
||||||
|
@ -136,8 +139,8 @@ func TestGetMigrationsIgnoresUnrecognizedCommentHeaders(t *testing.T) {
|
||||||
root := t.TempDir()
|
root := t.TempDir()
|
||||||
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0750)
|
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0750)
|
||||||
|
|
||||||
migrations, err := getMigrations(root)
|
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||||
|
migrations, err := migrationGraph.GetLinearHistory()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error returned, got %#v", err)
|
t.Errorf("Expected no error returned, got %#v", err)
|
||||||
}
|
}
|
||||||
|
@ -163,9 +166,9 @@ func TestMigrateRunsSqlOnDBConnection(t *testing.T) {
|
||||||
migrationSql := "SELECT 1 FROM table;"
|
migrationSql := "SELECT 1 FROM table;"
|
||||||
os.WriteFile(migrationPath, []byte(migrationSql), 0750)
|
os.WriteFile(migrationPath, []byte(migrationSql), 0750)
|
||||||
|
|
||||||
migrations, _ := getMigrations(root)
|
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||||
|
migrations, err := migrationGraph.GetLinearHistory()
|
||||||
err := ApplyMigration(mockDb, migrations[0])
|
err = ApplyMigration(mockDb, migrations[0])
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error returned, got %#v", err)
|
t.Errorf("Expected no error returned, got %#v", err)
|
||||||
|
@ -197,9 +200,10 @@ func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) {
|
||||||
|
|
||||||
migrationPath := filepath.Join(root, "0001.sql")
|
migrationPath := filepath.Join(root, "0001.sql")
|
||||||
os.WriteFile(migrationPath, []byte(migrationSql), 0750)
|
os.WriteFile(migrationPath, []byte(migrationSql), 0750)
|
||||||
|
migrationGraph, _ := NewMigrationGraphFromDirectory(root)
|
||||||
|
migrations, err := migrationGraph.GetLinearHistory()
|
||||||
|
|
||||||
migrations, _ := getMigrations(root)
|
err = ApplyMigration(mockDb, migrations[0])
|
||||||
err := ApplyMigration(mockDb, migrations[0])
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error returned, got nil")
|
t.Errorf("Expected error returned, got nil")
|
||||||
|
|
|
@ -3,6 +3,9 @@ package main
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MigrationGraph struct {
|
type MigrationGraph struct {
|
||||||
|
@ -18,6 +21,41 @@ func NewMigrationGraph() MigrationGraph {
|
||||||
return MigrationGraph{Migrations: map[string]Migration{}, parentage: map[string]*Migration{}}
|
return MigrationGraph{Migrations: map[string]Migration{}, parentage: map[string]*Migration{}}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewMigrationGraphFromDirectory(pathRoot string) (MigrationGraph, error) {
|
||||||
|
migrationGraph := NewMigrationGraph()
|
||||||
|
|
||||||
|
files, err := os.ReadDir(pathRoot)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to read migration directory: %s", pathRoot)
|
||||||
|
return migrationGraph, err
|
||||||
|
}
|
||||||
|
for _, file := range files {
|
||||||
|
filename := file.Name()
|
||||||
|
|
||||||
|
if filepath.Ext(filename) != ".sql" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
migrationPath := filepath.Join(pathRoot, file.Name())
|
||||||
|
|
||||||
|
file, _ := os.Open(migrationPath)
|
||||||
|
|
||||||
|
headers := gatherMigrationHeaders(file)
|
||||||
|
|
||||||
|
requirements := ""
|
||||||
|
|
||||||
|
if len(headers.Requirements) > 0 {
|
||||||
|
requirements = headers.Requirements[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements)
|
||||||
|
migrationGraph.AddMigration(migration)
|
||||||
|
}
|
||||||
|
|
||||||
|
return migrationGraph, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Adds a migration to the graph.
|
// Adds a migration to the graph.
|
||||||
//
|
//
|
||||||
// This also adds the migration to the parentage mappings to link it
|
// This also adds the migration to the parentage mappings to link it
|
||||||
|
|
Loading…
Reference in a new issue