feat: initial commit, extract from other repository
This commit is contained in:
parent
e5a22eddfc
commit
9c7c12b76d
6 changed files with 474 additions and 0 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -21,3 +21,5 @@
|
||||||
# Go workspace file
|
# Go workspace file
|
||||||
go.work
|
go.work
|
||||||
|
|
||||||
|
*.env
|
||||||
|
migrations
|
||||||
|
|
79
MigrationMap.go
Normal file
79
MigrationMap.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MigrationGraph struct {
|
||||||
|
// Reference to the root of the graph.
|
||||||
|
Root *Migration
|
||||||
|
// Name to struct mapping of all migrations part of the graph.
|
||||||
|
Migrations map[string]Migration
|
||||||
|
// Mapping of all migrations to their parent, if any.
|
||||||
|
parentage map[string]*Migration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMigrationGraph() MigrationGraph {
|
||||||
|
return MigrationGraph{Migrations: map[string]Migration{}, parentage: map[string]*Migration{}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adds a migration to the graph.
|
||||||
|
//
|
||||||
|
// This also adds the migration to the parentage mappings to link it
|
||||||
|
// to its parent. If the migration added has no parent, then it's also
|
||||||
|
// set to be the root of the graph.
|
||||||
|
func (g *MigrationGraph) AddMigration(m Migration) {
|
||||||
|
if m.Requires == "" {
|
||||||
|
g.addRoot(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Migrations[m.Name] = m
|
||||||
|
g.parentage[m.Requires] = &m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds the linear history of migrations.
|
||||||
|
func (g *MigrationGraph) GetLinearHistory() ([]Migration, error) {
|
||||||
|
if g.Root == nil {
|
||||||
|
return []Migration{}, errors.New("Cannot get linear history, the graph has no root.")
|
||||||
|
}
|
||||||
|
|
||||||
|
ordered := []Migration{}
|
||||||
|
visited := map[string]bool{}
|
||||||
|
unordered := []Migration{*g.Root}
|
||||||
|
|
||||||
|
for len(unordered) > 0 {
|
||||||
|
migration := unordered[0]
|
||||||
|
unordered = unordered[1:]
|
||||||
|
|
||||||
|
if _, hasVisited := visited[migration.Name]; hasVisited {
|
||||||
|
return []Migration{}, errors.New("Cycle detected, cannot generate linear history.")
|
||||||
|
}
|
||||||
|
|
||||||
|
child, hasChildren := g.parentage[migration.Name]
|
||||||
|
|
||||||
|
ordered = append(ordered, migration)
|
||||||
|
|
||||||
|
if hasChildren {
|
||||||
|
unordered = append(unordered, *child)
|
||||||
|
}
|
||||||
|
|
||||||
|
visited[migration.Name] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ordered) != len(g.Migrations) {
|
||||||
|
return ordered, errors.New("Not all migrations in the graph are part of the history.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return ordered, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *MigrationGraph) addRoot(m Migration) error {
|
||||||
|
if g.Root != nil {
|
||||||
|
return errors.New(fmt.Sprintf("Cannot have more than one root, tried to add %#v but already knew about %#v.", m, g.Root))
|
||||||
|
}
|
||||||
|
|
||||||
|
g.Root = &m
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
5
go.mod
Normal file
5
go.mod
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
module cobble
|
||||||
|
|
||||||
|
go 1.22.2
|
||||||
|
|
||||||
|
require github.com/lib/pq v1.10.9
|
2
go.sum
Normal file
2
go.sum
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
184
migrate.go
Normal file
184
migrate.go
Normal file
|
@ -0,0 +1,184 @@
|
||||||
|
// Migration tooling
|
||||||
|
//
|
||||||
|
// Runs migrations in-order based on files found in the migration
|
||||||
|
// root directory provided as argument.
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MigrationHeaders struct {
|
||||||
|
Requirements []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Migration struct {
|
||||||
|
Path string
|
||||||
|
Name string
|
||||||
|
Requires string
|
||||||
|
Run bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type DB interface {
|
||||||
|
Exec(sql string, args ...any) (sql.Result, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var commentPrefix = []byte("--")
|
||||||
|
var requirementPrefix = []byte("requires:")
|
||||||
|
|
||||||
|
func NewMigration(path string, name string, requires string) Migration {
|
||||||
|
return Migration{
|
||||||
|
Path: path,
|
||||||
|
Name: name,
|
||||||
|
Requires: requires,
|
||||||
|
Run: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Applies a migration to the given database.
|
||||||
|
func migrate(db DB, migrationPath string) error {
|
||||||
|
migrationBytes, err := os.ReadFile(migrationPath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to read migration file: %s", migrationPath)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
migrationSql := string(migrationBytes)
|
||||||
|
|
||||||
|
log.Printf("SQL: %s", migrationSql)
|
||||||
|
|
||||||
|
_, err = db.Exec(migrationSql)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to run %s: %s", migrationPath, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Removes the provided prefix from the byte slice and removes
|
||||||
|
// any spaces leading or trailing the resulting slice.
|
||||||
|
//
|
||||||
|
// If the prefix is not present, nothing is removed except
|
||||||
|
// the leading and trailing spaces.
|
||||||
|
func cutPrefixAndTrim(b []byte, prefix []byte) []byte {
|
||||||
|
withoutPrefix, _ := bytes.CutPrefix(b, prefix)
|
||||||
|
return bytes.TrimSpace(withoutPrefix)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts any migration headers present in the migration source file.
|
||||||
|
// Headers are left in comments of the form:
|
||||||
|
// -- <header-name>: ...
|
||||||
|
// Anything else than the known headers is ignored.
|
||||||
|
func gatherMigrationHeaders(migrationFile *os.File) MigrationHeaders {
|
||||||
|
scanner := bufio.NewScanner(migrationFile)
|
||||||
|
|
||||||
|
requirements := []string{}
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
currentBytes := scanner.Bytes()
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(currentBytes, commentPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
commentBytes := cutPrefixAndTrim(currentBytes, commentPrefix)
|
||||||
|
|
||||||
|
if bytes.HasPrefix(commentBytes, requirementPrefix) {
|
||||||
|
reqName := cutPrefixAndTrim(commentBytes, requirementPrefix)
|
||||||
|
requirements = append(requirements, string(reqName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return MigrationHeaders{Requirements: requirements}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gathers available migrations from the migration root directory.
|
||||||
|
//
|
||||||
|
// Migrations are expected to be .sql files.
|
||||||
|
func getMigrations(migrationsRoot string) ([]Migration, error) {
|
||||||
|
files, err := os.ReadDir(migrationsRoot)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to read migration directory: %s", migrationsRoot)
|
||||||
|
return []Migration{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
migrationGraph := NewMigrationGraph()
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
filename := file.Name()
|
||||||
|
|
||||||
|
if filepath.Ext(filename) != ".sql" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
migrationPath := filepath.Join(migrationsRoot, file.Name())
|
||||||
|
|
||||||
|
file, _ := os.Open(migrationPath)
|
||||||
|
|
||||||
|
headers := gatherMigrationHeaders(file)
|
||||||
|
|
||||||
|
requirements := ""
|
||||||
|
|
||||||
|
if len(headers.Requirements) > 0 {
|
||||||
|
requirements = headers.Requirements[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
migration := NewMigration(migrationPath, filepath.Base(file.Name()), requirements)
|
||||||
|
migrationGraph.AddMigration(migration)
|
||||||
|
}
|
||||||
|
|
||||||
|
return migrationGraph.GetLinearHistory()
|
||||||
|
}
|
||||||
|
|
||||||
|
type DatabaseConfiguration struct {
|
||||||
|
Host string
|
||||||
|
User string
|
||||||
|
Password string
|
||||||
|
DatabaseName string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectToDatabase(config DatabaseConfiguration) (*sql.DB, error) {
|
||||||
|
connectionString := fmt.Sprintf("postgresql://%s:%s@%s:%d?sslmode=disable", config.User, config.Password, config.Host, config.Port)
|
||||||
|
return sql.Open(config.DatabaseName, connectionString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
dbPort, _ := strconv.Atoi(os.Getenv("DB_PORT"))
|
||||||
|
|
||||||
|
dbConnection, err := ConnectToDatabase(DatabaseConfiguration{
|
||||||
|
User: os.Getenv("DB_APP_USER"),
|
||||||
|
Password: os.Getenv("DB_APP_PASSWORD"),
|
||||||
|
DatabaseName: os.Getenv("DB_NAME"),
|
||||||
|
Host: os.Getenv("DB_HOST"),
|
||||||
|
Port: dbPort,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to connect to the database: %s", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
migrationRoot := os.Args[1]
|
||||||
|
migrationHistory, err := getMigrations(migrationRoot)
|
||||||
|
|
||||||
|
for _, migration := range migrationHistory {
|
||||||
|
log.Printf("%s", migration.Name)
|
||||||
|
if err = migrate(dbConnection, migration.Path); err != nil {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
202
migrate_test.go
Normal file
202
migrate_test.go
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockSqlDB struct {
|
||||||
|
queries []string
|
||||||
|
mockedError error
|
||||||
|
mockedResult sql.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockResult struct {
|
||||||
|
sql.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSqlDB) Exec(sql string, args ...any) (sql.Result, error) {
|
||||||
|
m.queries = append(m.queries, sql)
|
||||||
|
|
||||||
|
if m.mockedError != nil {
|
||||||
|
return MockResult{}, m.mockedError
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.mockedResult != nil {
|
||||||
|
return m.mockedResult, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return MockResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMigrationsGathersSqlFiles(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750)
|
||||||
|
os.WriteFile(filepath.Join(root, "0002.sql"), []byte("--requires:0001.sql\nSELECT 1 from table;"), 0750)
|
||||||
|
os.WriteFile(filepath.Join(root, "0003.sql"), []byte("--requires:0002.sql\nSELECT 1 from table;"), 0750)
|
||||||
|
|
||||||
|
migrations, err := getMigrations(root)
|
||||||
|
|
||||||
|
if len(migrations) != 3 {
|
||||||
|
t.Errorf("Expected three migrations collected, got %d instead.", len(migrations))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error returned, got %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []Migration{
|
||||||
|
NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
|
||||||
|
NewMigration(filepath.Join(root, "0002.sql"), "0002.sql", "0001.sql"),
|
||||||
|
NewMigration(filepath.Join(root, "0003.sql"), "0003.sql", "0002.sql"),
|
||||||
|
}
|
||||||
|
|
||||||
|
collected := []Migration{}
|
||||||
|
|
||||||
|
for _, migration := range migrations {
|
||||||
|
collected = append(collected, migration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(collected, expected) {
|
||||||
|
t.Errorf("Expected collected migrations to equal %#v, got %3v instead.", expected, collected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMigrationsIgnoresNonMigrationFiles(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("SELECT 1 from table;"), 0750)
|
||||||
|
os.WriteFile(filepath.Join(root, "0002.txt"), []byte("SELECT 1 from table;"), 0750)
|
||||||
|
|
||||||
|
migrations, err := getMigrations(root)
|
||||||
|
|
||||||
|
if len(migrations) != 1 {
|
||||||
|
t.Errorf("Expected one migration collected, got %d instead.", len(migrations))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error returned, got %#v", err)
|
||||||
|
}
|
||||||
|
expected := []Migration{
|
||||||
|
NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
|
||||||
|
}
|
||||||
|
collected := []Migration{}
|
||||||
|
|
||||||
|
for _, migration := range migrations {
|
||||||
|
collected = append(collected, migration)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(collected, expected) {
|
||||||
|
t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMigrationsReturnsErrorIfInvalidRootDirectory(t *testing.T) {
|
||||||
|
migrations, err := getMigrations("not-root")
|
||||||
|
|
||||||
|
if len(migrations) != 0 {
|
||||||
|
t.Errorf("Did not expect any migrations returned, got %#v", migrations)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error, got nil.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMigrationsParsesRequirementsInCommentHeaders(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
os.WriteFile(filepath.Join(root, "0000.sql"), []byte("SELECT 1 from table;"), 0750)
|
||||||
|
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- requires: 0000.sql\nSELECT 1 from table;"), 0750)
|
||||||
|
|
||||||
|
migrations, err := getMigrations(root)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error returned, got %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
collected := []Migration{}
|
||||||
|
|
||||||
|
for _, migration := range migrations {
|
||||||
|
collected = append(collected, migration)
|
||||||
|
}
|
||||||
|
expected := []Migration{
|
||||||
|
NewMigration(filepath.Join(root, "0000.sql"), "0000.sql", ""),
|
||||||
|
NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", "0000.sql"),
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(collected, expected) {
|
||||||
|
t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMigrationsIgnoresUnrecognizedCommentHeaders(t *testing.T) {
|
||||||
|
root := t.TempDir()
|
||||||
|
os.WriteFile(filepath.Join(root, "0001.sql"), []byte("-- randomheader\nSELECT 1 from table;"), 0750)
|
||||||
|
|
||||||
|
migrations, err := getMigrations(root)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error returned, got %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []Migration{
|
||||||
|
NewMigration(filepath.Join(root, "0001.sql"), "0001.sql", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
collected := []Migration{}
|
||||||
|
|
||||||
|
for _, migration := range migrations {
|
||||||
|
collected = append(collected, migration)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(collected, expected) {
|
||||||
|
t.Errorf("Expected collected migrations to equal %#v, got %#v instead.", expected, collected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateRunsSqlOnDBConnection(t *testing.T) {
|
||||||
|
mockDb := &MockSqlDB{}
|
||||||
|
root := t.TempDir()
|
||||||
|
migrationPath := filepath.Join(root, "0001.sql")
|
||||||
|
migrationSql := "SELECT 1 FROM table;"
|
||||||
|
os.WriteFile(migrationPath, []byte(migrationSql), 0750)
|
||||||
|
|
||||||
|
err := migrate(mockDb, migrationPath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error returned, got %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(mockDb.queries, []string{migrationSql}) {
|
||||||
|
t.Errorf("Expected queries %#v to be run, got %#v instead.", []string{migrationSql}, mockDb.queries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateReturnsErrorOnFailToReadMigrationFile(t *testing.T) {
|
||||||
|
mockDb := &MockSqlDB{}
|
||||||
|
root := t.TempDir()
|
||||||
|
// Does not exist.
|
||||||
|
migrationPath := filepath.Join(root, "0001.sql")
|
||||||
|
err := migrate(mockDb, migrationPath)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error returned, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrateReturnsErrorOnFailToRunSQL(t *testing.T) {
|
||||||
|
mockDb := &MockSqlDB{mockedError: errors.New("Test error!")}
|
||||||
|
root := t.TempDir()
|
||||||
|
migrationSql := "SELECT 1 FROM table;"
|
||||||
|
|
||||||
|
migrationPath := filepath.Join(root, "0001.sql")
|
||||||
|
os.WriteFile(migrationPath, []byte(migrationSql), 0750)
|
||||||
|
|
||||||
|
err := migrate(mockDb, migrationPath)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("Expected error returned, got nil")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue