162 lines
3.9 KiB
Go
162 lines
3.9 KiB
Go
|
// Migration tooling
//
// Runs migrations in order, based on the .sql files found in the
// migration root directory passed as the first command-line argument.
|
||
|
|
||
|
package main
|
||
|
|
||
|
import (
	"bufio"
	"bytes"
	"database/sql"
	"fmt"
	"log"
	"net/url"
	"os"
	"path/filepath"
	"strconv"

	_ "github.com/lib/pq"
)
|
||
|
|
||
|
// Byte prefixes used when scanning migration files for header comments:
// a header line must start with commentPrefix, and a requirement header's
// comment text must start with requirementPrefix.
var (
	commentPrefix     = []byte("--")
	requirementPrefix = []byte("requires:")
)
|
||
|
|
||
|
func NewMigration(path string, name string, requires string) Migration {
|
||
|
return Migration{
|
||
|
Path: path,
|
||
|
Name: name,
|
||
|
Requires: requires,
|
||
|
Run: false,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Applies a migration to the given database.
|
||
|
func migrate(db DB, migrationPath string) error {
|
||
|
migrationBytes, err := os.ReadFile(migrationPath)
|
||
|
|
||
|
if err != nil {
|
||
|
log.Printf("Failed to read migration file: %s", migrationPath)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
migrationSql := string(migrationBytes)
|
||
|
|
||
|
log.Printf("SQL: %s", migrationSql)
|
||
|
|
||
|
_, err = db.Exec(migrationSql)
|
||
|
|
||
|
if err != nil {
|
||
|
log.Printf("Failed to run %s: %s", migrationPath, err)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// cutPrefixAndTrim strips prefix from the start of b, if b begins with it,
// and then trims leading and trailing white space from what remains.
//
// When b does not start with prefix, only the white space is trimmed. The
// prefix check happens before trimming, so a prefix preceded by spaces is kept.
// The result is a subslice of b, not a copy.
func cutPrefixAndTrim(b []byte, prefix []byte) []byte {
	return bytes.TrimSpace(bytes.TrimPrefix(b, prefix))
}
|
||
|
|
||
|
// Extracts any migration headers present in the migration source file.
|
||
|
// Headers are left in comments of the form:
|
||
|
// -- <header-name>: ...
|
||
|
// Anything else than the known headers is ignored.
|
||
|
func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders {
|
||
|
scanner := bufio.NewScanner(migrationFile)
|
||
|
|
||
|
requirements := []string{}
|
||
|
|
||
|
for scanner.Scan() {
|
||
|
currentBytes := scanner.Bytes()
|
||
|
|
||
|
if !bytes.HasPrefix(currentBytes, commentPrefix) {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
commentBytes := cutPrefixAndTrim(currentBytes, commentPrefix)
|
||
|
|
||
|
if bytes.HasPrefix(commentBytes, requirementPrefix) {
|
||
|
reqName := cutPrefixAndTrim(commentBytes, requirementPrefix)
|
||
|
requirements = append(requirements, string(reqName))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return MigrationHeaders{Requirements: requirements}
|
||
|
}
|
||
|
|
||
|
// Gathers available migrations from the migration root directory.
|
||
|
//
|
||
|
// Migrations are expected to be .sql files.
|
||
|
func getMigrations(migrationsRoot string) ([]Migration, error) {
|
||
|
files, err := os.ReadDir(migrationsRoot)
|
||
|
|
||
|
if err != nil {
|
||
|
log.Printf("Failed to read migration directory: %s", migrationsRoot)
|
||
|
return []Migration{}, err
|
||
|
}
|
||
|
|
||
|
migrationGraph := NewMigrationGraph()
|
||
|
|
||
|
for _, file := range files {
|
||
|
filename := file.Name()
|
||
|
|
||
|
if filepath.Ext(filename) != ".sql" {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
migrationPath := filepath.Join(migrationsRoot, file.Name())
|
||
|
|
||
|
file, _ := os.Open(migrationPath)
|
||
|
|
||
|
headers := gatherMigrationHeaders(file)
|
||
|
|
||
|
requirements := ""
|
||
|
|
||
|
if len(headers.Requirements) > 0 {
|
||
|
requirements = headers.Requirements[0]
|
||
|
}
|
||
|
|
||
|
migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements)
|
||
|
migrationGraph.AddMigration(migration)
|
||
|
}
|
||
|
|
||
|
return migrationGraph.GetLinearHistory()
|
||
|
}
|
||
|
|
||
|
func ConnectToDatabase(config DatabaseConfiguration) (*sql.DB, error) {
|
||
|
connectionString := fmt.Sprintf("postgresql://%s:%s@%s:%d?sslmode=disable", config.User, config.Password, config.Host, config.Port)
|
||
|
return sql.Open(config.DatabaseName, connectionString)
|
||
|
}
|
||
|
|
||
|
func main() {
|
||
|
dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT"))
|
||
|
|
||
|
dbConnection, err := ConnectToDatabase(DatabaseConfiguration{
|
||
|
User: os.Getenv("DB_APP_USER"),
|
||
|
Password: os.Getenv("DB_APP_PASSWORD"),
|
||
|
DatabaseName: os.Getenv("DB_NAME"),
|
||
|
Host: os.Getenv("DB_HOST"),
|
||
|
Port: dbPort,
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
log.Printf("Failed to connect to the database: %s", err)
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
|
||
|
migrationRoot := os.Args[1]
|
||
|
migrationHistory, err := getMigrations(migrationRoot)
|
||
|
|
||
|
for _, migration := range migrationHistory {
|
||
|
log.Printf("%s", migration.Name)
|
||
|
if err = migrate(dbConnection, migration.Path); err != nil {
|
||
|
os.Exit(1)
|
||
|
}
|
||
|
}
|
||
|
}
|