diff --git a/.gitignore b/.gitignore index adf8f72..84fff5b 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,5 @@ # Go workspace file go.work +*.env +migrations diff --git a/MigrationMap.go b/MigrationMap.go new file mode 100644 index 0000000..688459e --- /dev/null +++ b/MigrationMap.go @@ -0,0 +1,79 @@ +package main + +import ( + "errors" + "fmt" +) + +type MigrationGraph struct { + // Reference to the root of the graph. + Root *Migration + // Name to struct mapping of all migrations part of the graph. + Migrations map[string]Migration + // Mapping of all migrations to their parent, if any. + parentage map[string]*Migration +} + +func NewMigrationGraph() MigrationGraph { + return MigrationGraph{Migrations: map[string]Migration{}, parentage: map[string]*Migration{}} +} + +// Adds a migration to the graph. +// +// This also adds the migration to the parentage mappings to link it +// to its parent. If the migration added has no parent, then it's also +// set to be the root of the graph. +func (g *MigrationGraph) AddMigration(m Migration) { + if m.Requires == "" { + g.addRoot(m) + } + + g.Migrations[m.Name] = m + g.parentage[m.Requires] = &m +} + +// Builds the linear history of migrations. +func (g *MigrationGraph) GetLinearHistory() ([]Migration, error) { + if g.Root == nil { + return []Migration{}, errors.New("Cannot get linear history, the graph has no root.") + } + + ordered := []Migration{} + visited := map[string]bool{} + unordered := []Migration{*g.Root} + + for len(unordered) > 0 { + migration := unordered[0] + unordered = unordered[1:] + + if _, hasVisited := visited[migration.Name]; hasVisited { + return []Migration{}, errors.New("Cycle detected, cannot generate linear history.") + } + + child, hasChildren := g.parentage[migration.Name] + + ordered = append(ordered, migration) + + if hasChildren { + unordered = append(unordered, *child) + } + + visited[migration.Name] = true + } + + if len(ordered) != len(g.Migrations) { + return ordered, errors.New("Not all migrations in the graph are part of the history.") + } + + return ordered, nil +} + +func (g *MigrationGraph) addRoot(m Migration) error { + if g.Root != nil { + return errors.New(fmt.Sprintf("Cannot have more than one root, tried to add %#v but already knew about %#v.", m, g.Root)) + } + + g.Root = &m + + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6715750 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module cobble + +go 1.22.2 + +require github.com/lib/pq v1.10.9 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..aeddeae --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= diff --git a/migrate.go b/migrate.go new file mode 100644 index 0000000..bee1ac6 --- /dev/null +++ b/migrate.go @@ -0,0 +1,184 @@ +// Migration tooling +// +// Runs migrations in-order based on files found in the migration +// root directory provided as argument. + +package main + +import ( + "bufio" + "bytes" + "database/sql" + "fmt" + _ "github.com/lib/pq" + "log" + "os" + "path/filepath" + "strconv" +) + +type MigrationHeaders struct { + Requirements []string +} + +type Migration struct { + Path string + Name string + Requires string + Run bool +} + +type DB interface { + Exec(sql string, args ...any) (sql.Result, error) +} + +var commentPrefix = []byte("--") +var requirementPrefix = []byte("requires:") + +func NewMigration(path string, name string, requires string) Migration { + return Migration{ + Path: path, + Name: name, + Requires: requires, + Run: false, + } +} + +// Applies a migration to the given database. +func migrate(db DB, migrationPath string) error { + migrationBytes, err := os.ReadFile(migrationPath) + + if err != nil { + log.Printf("Failed to read migration file: %s", migrationPath) + return err + } + + migrationSql := string(migrationBytes) + + log.Printf("SQL: %s", migrationSql) + + _, err = db.Exec(migrationSql) + + if err != nil { + log.Printf("Failed to run %s: %s", migrationPath, err) + return err + } + + return nil +} + +// Removes the provided prefix from the byte slice and removes +// any spaces leading or trailing the resulting slice. +// +// If the prefix is not present, nothing is removed except +// the leading and trailing spaces. +func cutPrefixAndTrim(b []byte, prefix []byte) []byte { + withoutPrefix, _ := bytes.CutPrefix(b, prefix) + return bytes.TrimSpace(withoutPrefix) +} + +// Extracts any migration headers present in the migration source file. +// Headers are left in comments of the form: +// -- : ... +// Anything else than the known headers is ignored. +func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders { + scanner := bufio.NewScanner(migrationFile) + + requirements := []string{} + + for scanner.Scan() { + currentBytes := scanner.Bytes() + + if !bytes.HasPrefix(currentBytes, commentPrefix) { + continue + } + + commentBytes := cutPrefixAndTrim(currentBytes, commentPrefix) + + if bytes.HasPrefix(commentBytes, requirementPrefix) { + reqName := cutPrefixAndTrim(commentBytes, requirementPrefix) + requirements = append(requirements, string(reqName)) + } + } + + return MigrationHeaders{Requirements: requirements} +} + +// Gathers available migrations from the migration root directory. +// +// Migrations are expected to be .sql files. +func getMigrations(migrationsRoot string) ([]Migration, error) { + files, err := os.ReadDir(migrationsRoot) + + if err != nil { + log.Printf("Failed to read migration directory: %s", migrationsRoot) + return []Migration{}, err + } + + migrationGraph := NewMigrationGraph() + + for _, file := range files { + filename := file.Name() + + if filepath.Ext(filename) != ".sql" { + continue + } + + migrationPath := filepath.Join(migrationsRoot, file.Name()) + + file, _ := os.Open(migrationPath) + + headers := gatherMigrationHeaders(file) + + requirements := "" + + if len(headers.Requirements) > 0 { + requirements = headers.Requirements[0] + } + + migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements) + migrationGraph.AddMigration(migration) + } + + return migrationGraph.GetLinearHistory() +} + +type DatabaseConfiguration struct { + Host string + User string + Password string + DatabaseName string + Port int +} + +func ConnectToDatabase(config DatabaseConfiguration) (*sql.DB, error) { + connectionString := fmt.Sprintf("postgresql://%s:%s@%s:%d?sslmode=disable", config.User, config.Password, config.Host, config.Port) + return sql.Open(config.DatabaseName, connectionString) +} + +func main() { + dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT")) + + dbConnection, err := ConnectToDatabase(DatabaseConfiguration{ + User: os.Getenv("DB_APP_USER"), + Password: os.Getenv("DB_APP_PASSWORD"), + DatabaseName: os.Getenv("DB_NAME"), + Host: os.Getenv("DB_HOST"), + Port: dbPort, + }) + + if err != nil { + log.Printf("Failed to connect to the database: %s", err) + os.Exit(1) + } + + migrationRoot := os.Args[1] + migrationHistory, err := getMigrations(migrationRoot) + + for _, migration := range migrationHistory { + log.Printf("%s", migration.Name) + if err = migrate(dbConnection, migration.Path); err != nil { + os.Exit(1) + } + } +} diff --git a/migrate_test.go b/migrate_test.go new file mode 100644 index 0000000..da15370 --- /dev/null +++ b/migrate_test.go @@ -0,0 +1,202 @@ +package main + +import ( + "database/sql" + "errors" + "os" + "path/filepath" + "reflect" + "testing" +) + +type MockSqlDB struct { + queries []string + mockedError error + mockedResult sql.Result +} + +type MockResult struct { + sql.Result +} + +func (m *MockSqlDB) Exec(sql string, args ...any) (sql.Result, error) { + m.queries = append(m.queries, sql) + + if m.mockedError != nil { + return MockResult{}, m.mockedError + } + + if m.mockedResult != nil { + return m.mockedResult, nil + } + + return MockResult{}, nil +} + +func TestGetMigrationsGathersSqlFiles(t *testing.T) { + root := t.TempDir() + os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750) + os.WriteFile(filepath.Join(root, "0002.sql"), []byte("--requires:0001.sql\nSELECT 1 from table;"), 0750) + os.WriteFile(filepath.Join(root, "0003.sql"), []byte("--requires:0002.sql\nSELECT 1 from table;"), 0750) + + migrations, err := getMigrations(root) + + if len(migrations) != 3 { + t.Errorf("Expected three migrations collected, got %d instead.", len(migrations)) + } + + if err != nil { + t.Errorf("Expected no error returned, got %#v", err) + } + + expected := []Migration{ + NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""), + NewMigration(filepath.Join(root, "0002.sql"), "0002.sql", "0001.sql"), + NewMigration(filepath.Join(root, "0003.sql"), "0003.sql", "0002.sql"), + } + + collected := []Migration{} + + for _, migration := range migrations { + collected = append(collected, migration) + } + + if !reflect.DeepEqual(collected, expected) { + t.Errorf("Expected collected migrations to equal %#v, got %3v instead.", expected, collected) + } +} + +func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) { + root := t.TempDir() + os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750) + os.WriteFile(filepath.Join(root, "0002.txt"), []byte("SELECT 1 from table;"), 0750) + + migrations, err := getMigrations(root) + + if len(migrations) != 1 { + t.Errorf("Expected one migration collected, got %d instead.", len(migrations)) + } + + if err != nil { + t.Errorf("Expected no error returned, got %#v", err) + } + expected := []Migration{ + NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""), + } + collected := []Migration{} + + for _, migration := range migrations { + collected = append(collected, migration) + } + + if !reflect.DeepEqual(collected, expected) { + t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected) + } +} + +func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) { + migrations, err := getMigrations("not-root") + + if len(migrations) != 0 { + t.Errorf("Did not expect any migrations returned, got %#v", migrations) + } + + if err == nil { + t.Errorf("Expected error, got nil.") + } +} + +func TestGetMigrationsParsesRequirementsInCommentHeaders(t *testing.T) { + root := t.TempDir() + os.WriteFile(filepath.Join(root, "0000.sql"), []byte("SELECT 1 from table;"), 0750) + os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- requires: 0000.sql\nSELECT 1 from table;"), 0750) + + migrations, err := getMigrations(root) + + if err != nil { + t.Errorf("Expected no error returned, got %#v", err) + } + + collected := []Migration{} + + for _, migration := range migrations { + collected = append(collected, migration) + } + expected := []Migration{ + NewMigration(filepath.Join(root, "0000.sql"), "0000.sql", ""), + NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", "0000.sql"), + } + + if !reflect.DeepEqual(collected, expected) { + t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected) + } +} + +func TestGetMigrationsIgnoresUnrecognizedCommentHeaders(t *testing.T) { + root := t.TempDir() + os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0750) + + migrations, err := getMigrations(root) + + if err != nil { + t.Errorf("Expected no error returned, got %#v", err) + } + + expected := []Migration{ + NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""), + } + + collected := []Migration{} + + for _, migration := range migrations { + collected = append(collected, migration) + } + if !reflect.DeepEqual(collected, expected) { + t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected) + } +} + +func TestMigrateRunsSqlOnDBConnection(t *testing.T) { + mockDb := &MockSqlDB{} + root := t.TempDir() + migrationPath := filepath.Join(root, "0001.sql") + migrationSql := "SELECT 1 FROM table;" + os.WriteFile(migrationPath, []byte(migrationSql), 0750) + + err := migrate(mockDb, migrationPath) + + if err != nil { + t.Errorf("Expected no error returned, got %#v", err) + } + + if !reflect.DeepEqual(mockDb.queries, []string{migrationSql}) { + t.Errorf("Expected queries %#v to be run, got %#v instead.", []string{migrationSql}, mockDb.queries) + } +} + +func TestMigrateReturnsErrorOnFailToReadMigrationFile(t *testing.T) { + mockDb := &MockSqlDB{} + root := t.TempDir() + // Does not exist. + migrationPath := filepath.Join(root, "0001.sql") + err := migrate(mockDb, migrationPath) + + if err == nil { + t.Errorf("Expected error returned, got nil") + } +} + +func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) { + mockDb := &MockSqlDB{mockedError: errors.New("Test error!")} + root := t.TempDir() + migrationSql := "SELECT 1 FROM table;" + + migrationPath := filepath.Join(root, "0001.sql") + os.WriteFile(migrationPath, []byte(migrationSql), 0750) + + err := migrate(mockDb, migrationPath) + + if err == nil { + t.Errorf("Expected error returned, got nil") + } +}