cobble/migrate.go

133 lines
3.2 KiB
Go

// Migration tooling.
//
// Runs migrations in order, based on the .sql files found in the
// migration root directory provided as an argument.
package main
import (
"bufio"
"bytes"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"log"
"os"
"path/filepath"
)
// Byte prefixes recognised when scanning migration files for header comments.
var (
	// commentPrefix starts an SQL line comment.
	commentPrefix = []byte("--")
	// requirementPrefix introduces a "requires" header inside such a comment.
	requirementPrefix = []byte("requires:")
)
// NewMigration builds a Migration for the SQL file at path, identified by
// name, that depends on the migration named by requires ("" when it has no
// requirement). The returned migration is marked as not yet run.
func NewMigration(path string, name string, requires string) Migration {
	var m Migration
	m.Path = path
	m.Name = name
	m.Requires = requires
	m.Run = false
	return m
}
// migrate applies a single migration to db.
//
// The whole file at migrationPath is read and executed with one Exec call.
// The SQL text is logged before it runs. A read or execution failure is
// logged, together with the file path, and the underlying error is returned
// unchanged.
func migrate(db DB, migrationPath string) error {
	contents, readErr := os.ReadFile(migrationPath)
	if readErr != nil {
		log.Printf("Failed to read migration file: %s", migrationPath)
		return readErr
	}

	statement := string(contents)
	log.Printf("SQL: %s", statement)

	if _, execErr := db.Exec(statement); execErr != nil {
		log.Printf("Failed to run %s: %s", migrationPath, execErr)
		return execErr
	}
	return nil
}
// cutPrefixAndTrim strips prefix from the start of b (when b begins with it)
// and then trims leading and trailing white space from what remains.
//
// The prefix is matched against b as given, before any trimming, so input
// with white space ahead of the prefix keeps the prefix and only loses the
// surrounding white space.
func cutPrefixAndTrim(b []byte, prefix []byte) []byte {
	return bytes.TrimSpace(bytes.TrimPrefix(b, prefix))
}
// gatherMigrationHeaders extracts the migration headers present in a
// migration source file.
//
// Headers are written as SQL line comments of the form
//
//	-- <header-name>: ...
//
// where "--" must be the very first characters of the line. The only header
// currently recognised is "requires:". Its trimmed value is appended, in file
// order, to MigrationHeaders.Requirements. Every other line is ignored.
//
// The file is read from its current offset to the end. If reading fails
// part-way (for example on a line longer than bufio.Scanner's default token
// limit), the error is logged and the headers found so far are returned.
func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders {
	scanner := bufio.NewScanner(migrationFile)
	requirements := []string{}
	for scanner.Scan() {
		line := scanner.Bytes()
		if !bytes.HasPrefix(line, commentPrefix) {
			continue
		}
		comment := cutPrefixAndTrim(line, commentPrefix)
		if bytes.HasPrefix(comment, requirementPrefix) {
			requirement := cutPrefixAndTrim(comment, requirementPrefix)
			// string() copies: scanner.Bytes() is reused by the next Scan.
			requirements = append(requirements, string(requirement))
		}
	}
	// Scan returns false on read errors and over-long lines just as it does at
	// EOF. Report the failure instead of silently returning a possibly
	// incomplete header set. The file name is deliberately not used here,
	// because Name() would panic on a nil *os.File.
	if err := scanner.Err(); err != nil {
		log.Printf("Failed to scan migration headers: %s", err)
	}
	return MigrationHeaders{Requirements: requirements}
}
// getMigrations gathers the available migrations from the migration root
// directory and returns them in the order produced by the migration graph's
// GetLinearHistory.
//
// Migrations are the regular entries directly inside migrationsRoot whose
// names end in ".sql". Subdirectories are skipped even if their names end in
// ".sql". Each migration is named by its file name, and its requirement is
// taken from the file's "requires:" header comment. Only the first such
// header is used; any further ones are ignored.
//
// An error is returned if the directory cannot be listed or a migration file
// cannot be opened. Any error from GetLinearHistory is passed through.
func getMigrations(migrationsRoot string) ([]Migration, error) {
	entries, err := os.ReadDir(migrationsRoot)
	if err != nil {
		log.Printf("Failed to read migration directory: %s", migrationsRoot)
		return []Migration{}, err
	}
	migrationGraph := NewMigrationGraph()
	for _, entry := range entries {
		filename := entry.Name()
		if entry.IsDir() || filepath.Ext(filename) != ".sql" {
			continue
		}
		migrationPath := filepath.Join(migrationsRoot, filename)
		migrationFile, openErr := os.Open(migrationPath)
		if openErr != nil {
			log.Printf("Failed to open migration file: %s", migrationPath)
			return []Migration{}, openErr
		}
		headers := gatherMigrationHeaders(migrationFile)
		// Close immediately instead of deferring. A defer inside this loop
		// would keep every migration file open until the function returns.
		// The file was only read, so a Close error carries nothing actionable.
		_ = migrationFile.Close()

		requirements := ""
		if len(headers.Requirements) > 0 {
			requirements = headers.Requirements[0]
		}
		migration := NewMigration(migrationPath, filename, requirements)
		migrationGraph.AddMigration(migration)
	}
	return migrationGraph.GetLinearHistory()
}
// ConnectToDatabase opens a database/sql handle to the PostgreSQL database
// described by config, with TLS disabled (sslmode=disable).
//
// The handle uses the "postgres" driver, which is registered by the
// github.com/lib/pq blank import. Connection parameters are passed as a
// keyword/value string whose values are single-quoted and escaped. That way
// passwords or user names containing characters such as '@', ':', '/', '#'
// or spaces, as well as IPv6 host literals, are passed through intact.
// Empty host, user, password and database name fields are left out of the
// string rather than sent as empty values.
//
// As with sql.Open, no connection is established here. Ping the returned
// handle to check that the database is reachable.
func ConnectToDatabase(config DatabaseConfiguration) (*sql.DB, error) {
	// quote renders v as a single-quoted connection value. Backslashes are
	// escaped first so the backslashes added for quotes are not doubled.
	quote := func(v string) string {
		escaped := bytes.ReplaceAll([]byte(v), []byte(`\`), []byte(`\\`))
		escaped = bytes.ReplaceAll(escaped, []byte(`'`), []byte(`\'`))
		return "'" + string(escaped) + "'"
	}

	dsn := fmt.Sprintf("port=%d sslmode=disable", config.Port)
	params := []struct{ key, value string }{
		{"host", config.Host},
		{"user", config.User},
		{"password", config.Password},
		{"dbname", config.DatabaseName},
	}
	for _, p := range params {
		if p.value != "" {
			dsn += " " + p.key + "=" + quote(p.value)
		}
	}

	// sql.Open's first argument is the registered *driver* name, not the
	// database name. The database name travels in the DSN as dbname.
	return sql.Open("postgres", dsn)
}