cobble/migrate.go

120 lines
2.8 KiB
Go

// Migration tooling
//
// Runs migrations in order, based on the files found in the migration
// root directory provided as an argument.
package main
import (
"bufio"
"bytes"
"log"
"os"
"path/filepath"
)
// Byte prefixes recognised while scanning a migration file for headers.
var (
	// commentPrefix marks a line as a comment that may carry a header.
	commentPrefix = []byte("--")
	// requirementPrefix introduces a requirement header inside a comment.
	requirementPrefix = []byte("requires:")
)
// ApplyMigration runs the SQL stored in the migration's file against db.
//
// The file at migration.Path is read in full, its contents are logged, and
// the whole text is handed to db.Execute as a single string. An error from
// reading the file or from db.Execute is returned to the caller unchanged.
func ApplyMigration(db DB, migration Migration) error {
	contents, err := os.ReadFile(migration.Path)
	if err != nil {
		return err
	}

	sql := string(contents)
	log.Printf("SQL: %s", sql)
	return db.Execute(sql)
}
// cutPrefixAndTrim strips prefix from the front of b, then trims leading and
// trailing white space from what remains.
//
// When b does not start with prefix, only the white space is trimmed.
func cutPrefixAndTrim(b []byte, prefix []byte) []byte {
	return bytes.TrimSpace(bytes.TrimPrefix(b, prefix))
}
// gatherMigrationHeaders extracts the known migration headers from a
// migration source file.
//
// Headers are left in comments of the form:
//
//	-- <header-name>: ...
//
// Only lines that begin with "--" are inspected, and only the "requires:"
// header is recognised; anything else is ignored. Every "requires:" value is
// appended to the returned Requirements in the order it appears in the file.
//
// The function has no error return, so a read failure (for example a line
// longer than bufio.Scanner's token limit) is logged and the headers
// collected before the failure are returned.
func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders {
	scanner := bufio.NewScanner(migrationFile)
	requirements := []string{}
	for scanner.Scan() {
		line := scanner.Bytes()
		if !bytes.HasPrefix(line, commentPrefix) {
			continue
		}
		comment := cutPrefixAndTrim(line, commentPrefix)
		if !bytes.HasPrefix(comment, requirementPrefix) {
			continue
		}
		// The string conversion copies the bytes, which is required because
		// the slice returned by scanner.Bytes may be overwritten by the
		// next call to Scan.
		reqName := cutPrefixAndTrim(comment, requirementPrefix)
		requirements = append(requirements, string(reqName))
	}
	// Scan returns false both at end of input and on failure; Err tells the
	// two apart. Without this check a failed read looks like a short file.
	if err := scanner.Err(); err != nil {
		log.Printf("Failed to read migration headers: %s", err)
	}
	return MigrationHeaders{Requirements: requirements}
}
// getMigrations gathers the available migrations from the migration root
// directory and returns whatever the migration graph's GetLinearHistory
// produces for them.
//
// Migrations are expected to be .sql files: directory entries with any other
// extension, and sub-directories, are skipped. Each file is scanned for
// headers; only the first "requires:" header, if present, is passed on to
// NewMigration, and an empty string is passed when there is none.
//
// An error is returned when the directory cannot be read, when a migration
// file cannot be opened, or when GetLinearHistory reports one.
func getMigrations(migrationsRoot string) ([]Migration, error) {
	entries, err := os.ReadDir(migrationsRoot)
	if err != nil {
		log.Printf("Failed to read migration directory: %s", migrationsRoot)
		return []Migration{}, err
	}
	migrationGraph := NewMigrationGraph()
	for _, entry := range entries {
		filename := entry.Name()
		// A directory that happens to be named "*.sql" is not a migration.
		if entry.IsDir() || filepath.Ext(filename) != ".sql" {
			continue
		}
		migrationPath := filepath.Join(migrationsRoot, filename)
		migrationFile, err := os.Open(migrationPath)
		if err != nil {
			// Previously this error was discarded and the nil file caused a
			// panic further down; report it to the caller instead.
			log.Printf("Failed to open migration file: %s", migrationPath)
			return []Migration{}, err
		}
		headers := gatherMigrationHeaders(migrationFile)
		// Close immediately rather than with defer so that descriptors are
		// not held until the whole directory has been processed. The file
		// was only read, so a Close error is not actionable and is ignored.
		_ = migrationFile.Close()

		requirements := ""
		if len(headers.Requirements) > 0 {
			requirements = headers.Requirements[0]
		}
		migration := NewMigration(migrationPath, filename, requirements)
		migrationGraph.AddMigration(migration)
	}
	return migrationGraph.GetLinearHistory()
}